use std::collections::HashMap;
use std::sync::Arc;
use std::time::Duration;
use tonic::metadata::{Ascii, AsciiMetadataValue, MetadataKey};
use tonic::transport::{Channel, Endpoint};
use crate::config::{ClientConfig, TlsConfig};
use crate::error::OrleansError;
use crate::generated::pb;
use crate::grain::GrainRef;
use crate::key::GrainKey;
use crate::request_context::RequestContext;
use crate::retry::RetryPolicy;
/// Generated tonic client for the Orleans bridge service, running over a transport [`Channel`].
type BridgeClient = pb::orleans_bridge_client::OrleansBridgeClient<Channel>;
/// Client for the Orleans bridge gRPC service.
///
/// Cloning copies the tonic client handle and bumps the `Arc`s, so every clone shares the same
/// configuration, retry policy and pre-parsed request metadata.
#[derive(Clone)]
pub struct OrleansClient {
// Generated gRPC client. Each RPC clones it into a local `mut` binding because the generated
// call methods need mutable access.
inner: BridgeClient,
// Settings this client was built from; exposed read-only via `config()`.
config: Arc<ClientConfig>,
// Retry policy consulted by `invoke_raw` (disabled unless set through the builder).
retry: Arc<RetryPolicy>,
// `ClientConfig::metadata` validated once at build time; attached to every outgoing request.
metadata: Arc<Vec<(MetadataKey<Ascii>, AsciiMetadataValue)>>,
}
/// Arguments for one grain-method invocation, consumed by `OrleansClient::invoke_raw`.
pub(crate) struct InvokeCall<'a> {
/// Interface the method belongs to (sent as `GrainTarget::interface_name`).
pub interface_name: &'a str,
/// Grain type name (sent as `GrainTarget::grain_type`).
pub grain_type: &'a str,
/// Identity of the target grain; converted with `GrainKey::to_proto`.
pub key: &'a GrainKey,
/// Name of the method to invoke.
pub method: &'a str,
/// Already-encoded argument bytes.
pub payload: Vec<u8>,
/// Name of the codec `payload` was encoded with (sent as `payload_codec`).
pub codec: &'a str,
/// Request context; converted with `RequestContext::into_map` and sent as `request_context`.
pub context: &'a RequestContext,
/// Per-call timeout; `None` falls back to `ClientConfig::default_timeout`.
pub timeout: Option<Duration>,
}
/// Undecoded result of a grain invocation, as returned by the bridge.
#[derive(Debug, Clone)]
pub struct RawResponse {
/// Encoded return value.
pub payload: Vec<u8>,
/// Codec the server reports for `payload` (its `payload_codec` field).
pub codec: String,
/// Response-context entries returned by the server.
pub response_context: HashMap<String, String>,
}
impl OrleansClient {
    /// Extra time granted to the local guard in `invoke_once` beyond the request timeout.
    ///
    /// The grace period presumably lets the server report its own deadline (sent as
    /// `timeout_ms`) before the client gives up locally.
    const INVOKE_TIMEOUT_GRACE: Duration = Duration::from_secs(5);

    /// Connects to `endpoint` using `ClientConfig::new(endpoint)`, with retries disabled.
    ///
    /// Use [`OrleansClient::builder`] to customise timeouts, TLS, metadata or retries.
    ///
    /// # Errors
    ///
    /// Same as [`OrleansClient::from_config`].
    pub async fn connect(endpoint: impl Into<String>) -> Result<Self, OrleansError> {
        Self::from_config(ClientConfig::new(endpoint)).await
    }

    /// Starts an [`OrleansClientBuilder`] targeting `endpoint`.
    #[must_use]
    pub fn builder(endpoint: impl Into<String>) -> OrleansClientBuilder {
        OrleansClientBuilder::new(endpoint)
    }

    /// Connects using an explicit `config`, with retries disabled.
    ///
    /// # Errors
    ///
    /// Returns [`OrleansError::InvalidConfig`] for an invalid endpoint URI or metadata entry,
    /// or when TLS is configured without the `tls` cargo feature. TLS setup failures and
    /// transport connection failures are propagated as converted errors.
    pub async fn from_config(config: ClientConfig) -> Result<Self, OrleansError> {
        Self::build(config, RetryPolicy::disabled()).await
    }

    /// Shared constructor behind `connect`, `from_config` and the builder.
    ///
    /// Everything that can be validated without the network is checked before dialing:
    /// the endpoint URI, TLS settings and metadata. Configuration mistakes are therefore
    /// reported without attempting a connection.
    async fn build(config: ClientConfig, retry: RetryPolicy) -> Result<Self, OrleansError> {
        let mut endpoint = Endpoint::from_shared(config.endpoint.clone())
            .map_err(|e| OrleansError::InvalidConfig(format!("invalid endpoint: {e}")))?;
        if let Some(connect_timeout) = config.connect_timeout {
            endpoint = endpoint.connect_timeout(connect_timeout);
        }
        endpoint = configure_tls(endpoint, config.tls.as_ref())?;
        let metadata = build_metadata(&config.metadata)?;
        let channel = endpoint.connect().await?;
        let mut client = BridgeClient::new(channel);
        if let Some(n) = config.max_decoding_message_size {
            client = client.max_decoding_message_size(n);
        }
        if let Some(n) = config.max_encoding_message_size {
            client = client.max_encoding_message_size(n);
        }
        Ok(Self {
            inner: client,
            config: Arc::new(config),
            retry: Arc::new(retry),
            metadata: Arc::new(metadata),
        })
    }

    /// Wraps `message` in a tonic request carrying the configured metadata.
    ///
    /// `insert` replaces any existing value for a key. If the same key was configured more
    /// than once, the last entry wins.
    fn request<T>(&self, message: T) -> tonic::Request<T> {
        let mut request = tonic::Request::new(message);
        let metadata = request.metadata_mut();
        for (key, value) in self.metadata.iter() {
            metadata.insert(key.clone(), value.clone());
        }
        request
    }

    /// Returns the configuration this client was built from.
    #[must_use]
    pub fn config(&self) -> &ClientConfig {
        &self.config
    }

    /// Calls the bridge's `Health` RPC.
    ///
    /// No client-side timeout guard is applied here (unlike grain invocations).
    ///
    /// # Errors
    ///
    /// Returns the RPC failure converted with [`OrleansError::from_status`].
    pub async fn health(&self) -> Result<pb::HealthResponse, OrleansError> {
        let mut client = self.inner.clone();
        let response = client
            .health(self.request(pb::HealthRequest {}))
            .await
            .map_err(OrleansError::from_status)?;
        Ok(response.into_inner())
    }

    /// Fetches the bridge's contract manifest.
    ///
    /// Returns `ContractManifest::default()` if the response carries no manifest. No
    /// client-side timeout guard is applied here.
    ///
    /// # Errors
    ///
    /// Returns the RPC failure converted with [`OrleansError::from_status`].
    pub async fn manifest(&self) -> Result<pb::ContractManifest, OrleansError> {
        let mut client = self.inner.clone();
        let response = client
            .get_manifest(self.request(pb::GetManifestRequest {}))
            .await
            .map_err(OrleansError::from_status)?;
        Ok(response.into_inner().manifest.unwrap_or_default())
    }

    /// Creates a [`GrainRef`] bound to a clone of this client.
    #[must_use]
    pub fn grain(
        &self,
        interface_name: impl Into<String>,
        grain_type: impl Into<String>,
        key: impl Into<GrainKey>,
    ) -> GrainRef {
        GrainRef::new(
            self.clone(),
            interface_name.into(),
            grain_type.into(),
            key.into(),
        )
    }

    /// Invokes a grain method, retrying according to the configured [`RetryPolicy`].
    ///
    /// The per-attempt timeout is `call.timeout`, falling back to `config.default_timeout`.
    /// It is sent to the server as `timeout_ms` (saturated to `u32::MAX`) and also bounds
    /// each attempt locally.
    ///
    /// A failed attempt is retried only when all of these hold:
    /// - the policy is enabled;
    /// - fewer than `max_retries` retries have happened;
    /// - the error reports itself as retryable.
    ///
    /// The policy's backoff is slept before each retry.
    ///
    /// # Errors
    ///
    /// Returns the error of the last attempt once no further retry is allowed.
    pub(crate) async fn invoke_raw(
        &self,
        call: InvokeCall<'_>,
    ) -> Result<RawResponse, OrleansError> {
        let effective_timeout = call.timeout.unwrap_or(self.config.default_timeout);
        // Build the request once, moving the payload in. It is cloned only for attempts that
        // may be followed by a retry. The last permitted attempt — the only attempt when
        // retries are disabled — consumes it, so large payloads are not copied needlessly.
        let request = pb::InvokeRequest {
            target: Some(pb::GrainTarget {
                interface_name: call.interface_name.to_owned(),
                grain_type: call.grain_type.to_owned(),
                key: Some(call.key.to_proto()),
            }),
            method: call.method.to_owned(),
            payload: call.payload,
            payload_codec: call.codec.to_owned(),
            request_context: call.context.clone().into_map(),
            timeout_ms: u32::try_from(effective_timeout.as_millis()).unwrap_or(u32::MAX),
        };
        let mut attempt: u32 = 0;
        loop {
            let retries_remaining = self.retry.is_enabled() && attempt < self.retry.max_retries;
            if !retries_remaining {
                // Whatever this attempt yields is final, so hand over the request itself.
                return self.invoke_once(request, effective_timeout).await;
            }
            match self.invoke_once(request.clone(), effective_timeout).await {
                Ok(response) => return Ok(response),
                Err(error) if error.is_retryable() => {
                    let backoff = self.retry.backoff_for(attempt + 1);
                    if !backoff.is_zero() {
                        tokio::time::sleep(backoff).await;
                    }
                    attempt += 1;
                }
                Err(error) => return Err(error),
            }
        }
    }

    /// Performs a single `Invoke` RPC with a local guard of `timeout + INVOKE_TIMEOUT_GRACE`.
    ///
    /// # Errors
    ///
    /// - [`OrleansError::Timeout`] if the local guard elapses.
    /// - Otherwise, the RPC status converted with [`OrleansError::from_status`].
    async fn invoke_once(
        &self,
        message: pb::InvokeRequest,
        timeout: Duration,
    ) -> Result<RawResponse, OrleansError> {
        let mut client = self.inner.clone();
        let request = self.request(message);
        let guard = timeout.saturating_add(Self::INVOKE_TIMEOUT_GRACE);
        let inner = tokio::time::timeout(guard, client.invoke(request))
            .await
            .map_err(|_elapsed| OrleansError::Timeout)?
            .map_err(OrleansError::from_status)?
            .into_inner();
        Ok(RawResponse {
            payload: inner.payload,
            codec: inner.payload_codec,
            response_context: inner.response_context,
        })
    }
}
/// Fluent builder for [`OrleansClient`]; obtain one with [`OrleansClient::builder`].
pub struct OrleansClientBuilder {
// Accumulated connection settings, handed to `OrleansClient::build` by `connect`.
config: ClientConfig,
// Retry policy for invocations; `RetryPolicy::disabled()` unless overridden.
retry: RetryPolicy,
}
impl OrleansClientBuilder {
fn new(endpoint: impl Into<String>) -> Self {
Self {
config: ClientConfig::new(endpoint),
retry: RetryPolicy::disabled(),
}
}
#[must_use]
pub fn default_timeout(mut self, timeout: Duration) -> Self {
self.config.default_timeout = timeout;
self
}
#[must_use]
pub fn connect_timeout(mut self, timeout: Duration) -> Self {
self.config.connect_timeout = Some(timeout);
self
}
#[must_use]
pub fn max_decoding_message_size(mut self, bytes: usize) -> Self {
self.config.max_decoding_message_size = Some(bytes);
self
}
#[must_use]
pub fn max_encoding_message_size(mut self, bytes: usize) -> Self {
self.config.max_encoding_message_size = Some(bytes);
self
}
#[must_use]
pub fn default_context(mut self, context: RequestContext) -> Self {
self.config.default_context = context;
self
}
#[must_use]
pub fn retry_policy(mut self, policy: RetryPolicy) -> Self {
self.retry = policy;
self
}
#[must_use]
pub fn tls(mut self, tls: TlsConfig) -> Self {
self.config.tls = Some(tls);
self
}
#[must_use]
pub fn metadata(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
self.config.metadata.push((key.into(), value.into()));
self
}
#[must_use]
pub fn bearer_token(self, token: impl AsRef<str>) -> Self {
self.metadata("authorization", format!("Bearer {}", token.as_ref()))
}
#[must_use]
pub fn api_key(self, header: impl Into<String>, value: impl Into<String>) -> Self {
self.metadata(header, value)
}
pub async fn connect(self) -> Result<OrleansClient, OrleansError> {
OrleansClient::build(self.config, self.retry).await
}
}
#[cfg(feature = "tls")]
// NOTE(review): presumably `OrleansError` trips clippy's size threshold; boxing it would
// change the crate-wide error type, hence the allow.
#[allow(clippy::result_large_err)]
/// Applies `tls` to `endpoint`; returns `endpoint` untouched when `tls` is `None`.
fn configure_tls(endpoint: Endpoint, tls: Option<&TlsConfig>) -> Result<Endpoint, OrleansError> {
    use tonic::transport::{Certificate, ClientTlsConfig, Identity};

    let settings = match tls {
        None => return Ok(endpoint),
        Some(settings) => settings,
    };

    // An explicit CA replaces the bundled web PKI roots rather than adding to them.
    let mut client_tls = if let Some(ca_pem) = &settings.ca_certificate_pem {
        ClientTlsConfig::new().ca_certificate(Certificate::from_pem(ca_pem))
    } else {
        ClientTlsConfig::new().with_webpki_roots()
    };

    if let Some(domain) = settings.domain_name.as_ref() {
        client_tls = client_tls.domain_name(domain.clone());
    }
    // Optional mutual-TLS identity: (certificate PEM, private key PEM).
    if let Some((cert_pem, key_pem)) = settings.client_identity_pem.as_ref() {
        client_tls = client_tls.identity(Identity::from_pem(cert_pem, key_pem));
    }

    endpoint.tls_config(client_tls).map_err(OrleansError::from)
}
#[cfg(not(feature = "tls"))]
#[allow(clippy::result_large_err)]
/// Fallback used without the `tls` feature: passes `endpoint` through unchanged, and rejects
/// any TLS configuration instead of silently connecting in plaintext.
fn configure_tls(endpoint: Endpoint, tls: Option<&TlsConfig>) -> Result<Endpoint, OrleansError> {
    match tls {
        None => Ok(endpoint),
        Some(_) => Err(OrleansError::InvalidConfig(
            "TLS was configured but the `tls` cargo feature is not enabled".to_owned(),
        )),
    }
}
#[allow(clippy::result_large_err)]
/// Parses `(name, value)` pairs into ASCII gRPC metadata, preserving their order.
///
/// Names are lowercased before parsing. The first invalid name or value aborts with
/// `InvalidConfig`. The value-error message names only the key, so secrets such as bearer
/// tokens are not echoed.
fn build_metadata(
    entries: &[(String, String)],
) -> Result<Vec<(MetadataKey<Ascii>, AsciiMetadataValue)>, OrleansError> {
    entries
        .iter()
        .map(|(key, value)| -> Result<_, OrleansError> {
            let lowered = key.to_ascii_lowercase();
            let name = MetadataKey::<Ascii>::from_bytes(lowered.as_bytes()).map_err(|_| {
                OrleansError::InvalidConfig(format!("invalid metadata key: {key:?}"))
            })?;
            let parsed = AsciiMetadataValue::try_from(value.as_str()).map_err(|_| {
                OrleansError::InvalidConfig(format!("invalid metadata value for {key:?}"))
            })?;
            Ok((name, parsed))
        })
        .collect()
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn builds_valid_metadata() {
        let entries = vec![
            ("authorization".to_owned(), "Bearer abc.def".to_owned()),
            ("x-api-key".to_owned(), "key123".to_owned()),
        ];
        let built = build_metadata(&entries).expect("valid metadata");
        assert_eq!(built.len(), 2);
        assert_eq!(built[0].0.as_str(), "authorization");
    }

    #[test]
    fn lowercases_header_names() {
        let entries = vec![("Authorization".to_owned(), "Bearer t".to_owned())];
        let built = build_metadata(&entries).unwrap();
        assert_eq!(built[0].0.as_str(), "authorization");
    }

    #[test]
    fn rejects_invalid_key() {
        let entries = vec![("bad key".to_owned(), "v".to_owned())];
        let error = build_metadata(&entries).unwrap_err();
        assert!(matches!(error, OrleansError::InvalidConfig(_)));
    }

    #[test]
    fn rejects_invalid_value() {
        let entries = vec![("authorization".to_owned(), "bad\nvalue".to_owned())];
        let error = build_metadata(&entries).unwrap_err();
        assert!(matches!(error, OrleansError::InvalidConfig(_)));
    }

    /// No configured metadata is valid and yields nothing to attach.
    #[test]
    fn empty_entries_build_empty_metadata() {
        let built = build_metadata(&[]).expect("no entries is valid");
        assert!(built.is_empty());
    }

    /// Order matters: `OrleansClient::request` inserts entries in sequence, so the last entry
    /// for a repeated key wins.
    #[test]
    fn preserves_entry_order() {
        let entries = vec![
            ("x-first".to_owned(), "1".to_owned()),
            ("x-second".to_owned(), "2".to_owned()),
        ];
        let built = build_metadata(&entries).unwrap();
        let names: Vec<&str> = built.iter().map(|(name, _)| name.as_str()).collect();
        assert_eq!(names, ["x-first", "x-second"]);
        assert_eq!(built[1].1.to_str().unwrap(), "2");
    }
}