Skip to main content

orleans_rust_client/
client.rs

1//! The gRPC client over the Orleans bridge.
2
3use std::collections::HashMap;
4use std::sync::Arc;
5use std::time::Duration;
6
7use tonic::metadata::{Ascii, AsciiMetadataValue, MetadataKey};
8use tonic::transport::{Channel, Endpoint};
9
10use crate::config::{ClientConfig, TlsConfig};
11use crate::error::OrleansError;
12use crate::generated::pb;
13use crate::grain::GrainRef;
14use crate::key::GrainKey;
15use crate::request_context::RequestContext;
16use crate::retry::RetryPolicy;
17
/// The generated bridge service client, bound to a tonic [`Channel`].
type BridgeClient = pb::orleans_bridge_client::OrleansBridgeClient<Channel>;
19
/// A cheaply-cloneable handle to an Orleans bridge.
///
/// Cloning shares the underlying gRPC channel and configuration, so a single
/// connected client can be shared across tasks.
#[derive(Clone)]
pub struct OrleansClient {
    // gRPC client stub; each call works on its own clone of it.
    inner: BridgeClient,
    // Settings the client was built with, shared between clones.
    config: Arc<ClientConfig>,
    // Retry behaviour consulted by `invoke_raw`, shared between clones.
    retry: Arc<RetryPolicy>,
    // Static headers, validated at build time and attached to every request.
    metadata: Arc<Vec<(MetadataKey<Ascii>, AsciiMetadataValue)>>,
}
31
/// Parameters for a single raw invocation.
///
/// Everything is borrowed from the caller except `payload`, which is owned so
/// it can be handed to the outgoing request.
pub(crate) struct InvokeCall<'a> {
    /// Grain interface name, sent as `GrainTarget.interface_name`.
    pub interface_name: &'a str,
    /// Grain type, sent as `GrainTarget.grain_type`.
    pub grain_type: &'a str,
    /// Identity of the target grain.
    pub key: &'a GrainKey,
    /// Name of the grain method to invoke.
    pub method: &'a str,
    /// Already-encoded argument bytes.
    pub payload: Vec<u8>,
    /// Name of the codec used for `payload`, sent as `payload_codec`.
    pub codec: &'a str,
    /// Request-context entries to send with the call.
    pub context: &'a RequestContext,
    /// Per-call deadline; `None` falls back to the client's `default_timeout`.
    pub timeout: Option<Duration>,
}
43
/// The raw result of an [`OrleansClient`] invocation: opaque payload bytes plus
/// any response-context entries the grain produced.
///
/// The type is plain data, so it implements the common comparison and
/// construction traits: two responses are equal when all three fields are
/// equal, and [`Default`] yields an empty payload, an empty codec name and no
/// context entries.
#[derive(Debug, Clone, PartialEq, Eq, Default)]
pub struct RawResponse {
    /// Codec-encoded response bytes.
    pub payload: Vec<u8>,
    /// The codec the bridge used to encode `payload`.
    pub codec: String,
    /// Response-context entries returned by the bridge.
    pub response_context: HashMap<String, String>,
}
55
56impl OrleansClient {
57    /// Connect to a bridge at `endpoint` using default settings.
58    ///
59    /// # Errors
60    /// Returns [`OrleansError::Transport`] if the channel cannot be
61    /// established, or [`OrleansError::InvalidConfig`] for a malformed
62    /// endpoint.
63    pub async fn connect(endpoint: impl Into<String>) -> Result<Self, OrleansError> {
64        Self::from_config(ClientConfig::new(endpoint)).await
65    }
66
67    /// Start building a client with non-default settings.
68    #[must_use]
69    pub fn builder(endpoint: impl Into<String>) -> OrleansClientBuilder {
70        OrleansClientBuilder::new(endpoint)
71    }
72
73    /// Connect using an explicit [`ClientConfig`] and no retries.
74    ///
75    /// # Errors
76    /// See [`OrleansClient::connect`].
77    pub async fn from_config(config: ClientConfig) -> Result<Self, OrleansError> {
78        Self::build(config, RetryPolicy::disabled()).await
79    }
80
81    async fn build(config: ClientConfig, retry: RetryPolicy) -> Result<Self, OrleansError> {
82        let mut endpoint = Endpoint::from_shared(config.endpoint.clone())
83            .map_err(|e| OrleansError::InvalidConfig(format!("invalid endpoint: {e}")))?;
84        if let Some(connect_timeout) = config.connect_timeout {
85            endpoint = endpoint.connect_timeout(connect_timeout);
86        }
87        endpoint = configure_tls(endpoint, config.tls.as_ref())?;
88
89        let metadata = build_metadata(&config.metadata)?;
90
91        let channel = endpoint.connect().await?;
92        let mut client = BridgeClient::new(channel);
93        if let Some(n) = config.max_decoding_message_size {
94            client = client.max_decoding_message_size(n);
95        }
96        if let Some(n) = config.max_encoding_message_size {
97            client = client.max_encoding_message_size(n);
98        }
99
100        Ok(Self {
101            inner: client,
102            config: Arc::new(config),
103            retry: Arc::new(retry),
104            metadata: Arc::new(metadata),
105        })
106    }
107
108    /// Wrap a message in a request carrying the client's configured metadata
109    /// (e.g. an `authorization` header).
110    fn request<T>(&self, message: T) -> tonic::Request<T> {
111        let mut request = tonic::Request::new(message);
112        let metadata = request.metadata_mut();
113        for (key, value) in self.metadata.iter() {
114            metadata.insert(key.clone(), value.clone());
115        }
116        request
117    }
118
119    /// The configuration this client was built with.
120    #[must_use]
121    pub fn config(&self) -> &ClientConfig {
122        &self.config
123    }
124
125    /// Query bridge and cluster identity.
126    ///
127    /// # Errors
128    /// Returns an [`OrleansError`] if the bridge is unreachable.
129    pub async fn health(&self) -> Result<pb::HealthResponse, OrleansError> {
130        let mut client = self.inner.clone();
131        let response = client
132            .health(self.request(pb::HealthRequest {}))
133            .await
134            .map_err(OrleansError::from_status)?;
135        Ok(response.into_inner())
136    }
137
138    /// Fetch the contract manifest describing dispatchable grains.
139    ///
140    /// # Errors
141    /// Returns an [`OrleansError`] if the bridge is unreachable.
142    pub async fn manifest(&self) -> Result<pb::ContractManifest, OrleansError> {
143        let mut client = self.inner.clone();
144        let response = client
145            .get_manifest(self.request(pb::GetManifestRequest {}))
146            .await
147            .map_err(OrleansError::from_status)?;
148        Ok(response.into_inner().manifest.unwrap_or_default())
149    }
150
151    /// Obtain a reference to a specific grain.
152    #[must_use]
153    pub fn grain(
154        &self,
155        interface_name: impl Into<String>,
156        grain_type: impl Into<String>,
157        key: impl Into<GrainKey>,
158    ) -> GrainRef {
159        GrainRef::new(
160            self.clone(),
161            interface_name.into(),
162            grain_type.into(),
163            key.into(),
164        )
165    }
166
167    pub(crate) async fn invoke_raw(
168        &self,
169        call: InvokeCall<'_>,
170    ) -> Result<RawResponse, OrleansError> {
171        let effective_timeout = call.timeout.unwrap_or(self.config.default_timeout);
172        let target = pb::GrainTarget {
173            interface_name: call.interface_name.to_owned(),
174            grain_type: call.grain_type.to_owned(),
175            key: Some(call.key.to_proto()),
176        };
177        let context_map = call.context.clone().into_map();
178
179        let mut attempt: u32 = 0;
180        loop {
181            let request = pb::InvokeRequest {
182                target: Some(target.clone()),
183                method: call.method.to_owned(),
184                payload: call.payload.clone(),
185                payload_codec: call.codec.to_owned(),
186                request_context: context_map.clone(),
187                timeout_ms: u32::try_from(effective_timeout.as_millis()).unwrap_or(u32::MAX),
188            };
189
190            match self.invoke_once(request, effective_timeout).await {
191                Ok(response) => return Ok(response),
192                Err(error) => {
193                    let can_retry = self.retry.is_enabled()
194                        && attempt < self.retry.max_retries
195                        && error.is_retryable();
196                    if !can_retry {
197                        return Err(error);
198                    }
199                    let backoff = self.retry.backoff_for(attempt + 1);
200                    if !backoff.is_zero() {
201                        tokio::time::sleep(backoff).await;
202                    }
203                    attempt += 1;
204                }
205            }
206        }
207    }
208
209    async fn invoke_once(
210        &self,
211        message: pb::InvokeRequest,
212        timeout: Duration,
213    ) -> Result<RawResponse, OrleansError> {
214        let mut client = self.inner.clone();
215        // The deadline is enforced server-side via `InvokeRequest.timeout_ms`,
216        // which lets the bridge return a structured `orleans_timeout`. We do
217        // not set a gRPC deadline here: tonic would surface its own expiry as a
218        // `Cancelled` status ("Timeout expired"), masking the richer error.
219        // Instead we apply a slightly longer client-side backstop so a hung
220        // connection still fails rather than hanging forever.
221        let request = self.request(message);
222        let guard = timeout.saturating_add(Duration::from_secs(5));
223        let call = client.invoke(request);
224        let result = match tokio::time::timeout(guard, call).await {
225            Ok(result) => result,
226            Err(_) => return Err(OrleansError::Timeout),
227        };
228
229        match result {
230            Ok(response) => {
231                let inner = response.into_inner();
232                Ok(RawResponse {
233                    payload: inner.payload,
234                    codec: inner.payload_codec,
235                    response_context: inner.response_context,
236                })
237            }
238            Err(status) => Err(OrleansError::from_status(status)),
239        }
240    }
241}
242
/// Builder for [`OrleansClient`] with non-default connection settings.
///
/// Created by [`OrleansClient::builder`] and consumed by
/// [`OrleansClientBuilder::connect`].
pub struct OrleansClientBuilder {
    // Connection settings accumulated so far.
    config: ClientConfig,
    // Retry policy handed to the client; disabled unless replaced.
    retry: RetryPolicy,
}
248
249impl OrleansClientBuilder {
250    fn new(endpoint: impl Into<String>) -> Self {
251        Self {
252            config: ClientConfig::new(endpoint),
253            retry: RetryPolicy::disabled(),
254        }
255    }
256
257    /// Set the default per-call deadline.
258    #[must_use]
259    pub fn default_timeout(mut self, timeout: Duration) -> Self {
260        self.config.default_timeout = timeout;
261        self
262    }
263
264    /// Set the channel connect timeout.
265    #[must_use]
266    pub fn connect_timeout(mut self, timeout: Duration) -> Self {
267        self.config.connect_timeout = Some(timeout);
268        self
269    }
270
271    /// Set the maximum decodable response size in bytes.
272    #[must_use]
273    pub fn max_decoding_message_size(mut self, bytes: usize) -> Self {
274        self.config.max_decoding_message_size = Some(bytes);
275        self
276    }
277
278    /// Set the maximum encodable request size in bytes.
279    #[must_use]
280    pub fn max_encoding_message_size(mut self, bytes: usize) -> Self {
281        self.config.max_encoding_message_size = Some(bytes);
282        self
283    }
284
285    /// Set request-context entries applied to every call.
286    #[must_use]
287    pub fn default_context(mut self, context: RequestContext) -> Self {
288        self.config.default_context = context;
289        self
290    }
291
292    /// Enable a retry policy (disabled by default).
293    #[must_use]
294    pub fn retry_policy(mut self, policy: RetryPolicy) -> Self {
295        self.retry = policy;
296        self
297    }
298
299    /// Configure transport security (see [`TlsConfig`]).
300    #[must_use]
301    pub fn tls(mut self, tls: TlsConfig) -> Self {
302        self.config.tls = Some(tls);
303        self
304    }
305
306    /// Attach a static gRPC metadata header to every request. The key must be
307    /// a valid ASCII header name and the value valid ASCII; both are validated
308    /// when the client is built.
309    #[must_use]
310    pub fn metadata(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
311        self.config.metadata.push((key.into(), value.into()));
312        self
313    }
314
315    /// Attach an `authorization: Bearer <token>` header to every request, for a
316    /// JWT-validating proxy in front of the bridge.
317    #[must_use]
318    pub fn bearer_token(self, token: impl AsRef<str>) -> Self {
319        self.metadata("authorization", format!("Bearer {}", token.as_ref()))
320    }
321
322    /// Attach an API-key header (e.g. `x-api-key`) to every request.
323    #[must_use]
324    pub fn api_key(self, header: impl Into<String>, value: impl Into<String>) -> Self {
325        self.metadata(header, value)
326    }
327
328    /// Connect using the accumulated settings.
329    ///
330    /// # Errors
331    /// See [`OrleansClient::connect`].
332    pub async fn connect(self) -> Result<OrleansClient, OrleansError> {
333        OrleansClient::build(self.config, self.retry).await
334    }
335}
336
/// Apply the optional [`TlsConfig`] to `endpoint` (builds with the `tls`
/// feature).
///
/// With no TLS settings the endpoint is returned untouched. Otherwise the
/// supplied CA certificate is used, or the webpki roots when none is given,
/// followed by the optional domain-name override and client identity.
// Cold path (runs once at connect time), so returning the large error enum by
// value is fine.
#[cfg(feature = "tls")]
#[allow(clippy::result_large_err)]
fn configure_tls(endpoint: Endpoint, tls: Option<&TlsConfig>) -> Result<Endpoint, OrleansError> {
    use tonic::transport::{Certificate, ClientTlsConfig, Identity};

    let Some(settings) = tls else {
        return Ok(endpoint);
    };

    let trusted = match &settings.ca_certificate_pem {
        Some(ca) => ClientTlsConfig::new().ca_certificate(Certificate::from_pem(ca)),
        None => ClientTlsConfig::new().with_webpki_roots(),
    };
    let named = match &settings.domain_name {
        Some(domain) => trusted.domain_name(domain.clone()),
        None => trusted,
    };
    let tls_config = match &settings.client_identity_pem {
        Some((certificate, key)) => named.identity(Identity::from_pem(certificate, key)),
        None => named,
    };

    endpoint.tls_config(tls_config).map_err(OrleansError::from)
}
362
363#[cfg(not(feature = "tls"))]
364#[allow(clippy::result_large_err)]
365fn configure_tls(endpoint: Endpoint, tls: Option<&TlsConfig>) -> Result<Endpoint, OrleansError> {
366    if tls.is_some() {
367        return Err(OrleansError::InvalidConfig(
368            "TLS was configured but the `tls` cargo feature is not enabled".to_owned(),
369        ));
370    }
371    Ok(endpoint)
372}
373
374// Cold path (runs once at connect time), so returning the large error enum by
375// value is fine.
376#[allow(clippy::result_large_err)]
377fn build_metadata(
378    entries: &[(String, String)],
379) -> Result<Vec<(MetadataKey<Ascii>, AsciiMetadataValue)>, OrleansError> {
380    let mut out = Vec::with_capacity(entries.len());
381    for (key, value) in entries {
382        let parsed_key = MetadataKey::<Ascii>::from_bytes(key.to_ascii_lowercase().as_bytes())
383            .map_err(|_| OrleansError::InvalidConfig(format!("invalid metadata key: {key:?}")))?;
384        let parsed_value = AsciiMetadataValue::try_from(value.as_str()).map_err(|_| {
385            OrleansError::InvalidConfig(format!("invalid metadata value for {key:?}"))
386        })?;
387        out.push((parsed_key, parsed_value));
388    }
389    Ok(out)
390}
391
#[cfg(test)]
mod tests {
    use super::*;

    /// Build an owned `(key, value)` metadata entry from string literals.
    fn entry(key: &str, value: &str) -> (String, String) {
        (key.to_owned(), value.to_owned())
    }

    // Well-formed entries are all kept, in their original order.
    #[test]
    fn builds_valid_metadata() {
        let built = build_metadata(&[
            entry("authorization", "Bearer abc.def"),
            entry("x-api-key", "key123"),
        ])
        .expect("valid metadata");
        assert_eq!(built.len(), 2);
        assert_eq!(built[0].0.as_str(), "authorization");
    }

    // Mixed-case header names are normalised to lower case.
    #[test]
    fn lowercases_header_names() {
        let built = build_metadata(&[entry("Authorization", "Bearer t")]).unwrap();
        assert_eq!(built[0].0.as_str(), "authorization");
    }

    // A space is not allowed in a header name.
    #[test]
    fn rejects_invalid_key() {
        let outcome = build_metadata(&[entry("bad key", "v")]);
        assert!(matches!(outcome, Err(OrleansError::InvalidConfig(_))));
    }

    // A newline is not allowed in a header value.
    #[test]
    fn rejects_invalid_value() {
        let outcome = build_metadata(&[entry("authorization", "bad\nvalue")]);
        assert!(matches!(outcome, Err(OrleansError::InvalidConfig(_))));
    }
}