1use std::collections::HashMap;
4use std::sync::Arc;
5use std::time::Duration;
6
7use tonic::metadata::{Ascii, AsciiMetadataValue, MetadataKey};
8use tonic::transport::{Channel, Endpoint};
9
10use crate::config::{ClientConfig, TlsConfig};
11use crate::error::OrleansError;
12use crate::generated::pb;
13use crate::grain::GrainRef;
14use crate::key::GrainKey;
15use crate::request_context::RequestContext;
16use crate::retry::RetryPolicy;
17
/// Concrete generated tonic client (`pb::orleans_bridge_client`) over a transport `Channel`;
/// this is the client that `OrleansClient` wraps and clones per RPC.
type BridgeClient = pb::orleans_bridge_client::OrleansBridgeClient<Channel>;
19
20#[derive(Clone)]
25pub struct OrleansClient {
26 inner: BridgeClient,
27 config: Arc<ClientConfig>,
28 retry: Arc<RetryPolicy>,
29 metadata: Arc<Vec<(MetadataKey<Ascii>, AsciiMetadataValue)>>,
30}
31
32pub(crate) struct InvokeCall<'a> {
34 pub interface_name: &'a str,
35 pub grain_type: &'a str,
36 pub key: &'a GrainKey,
37 pub method: &'a str,
38 pub payload: Vec<u8>,
39 pub codec: &'a str,
40 pub context: &'a RequestContext,
41 pub timeout: Option<Duration>,
42}
43
/// Undecoded result of a grain invocation as returned by the bridge.
///
/// Besides `Debug`/`Clone`, this derives `PartialEq`, `Eq` and `Default` so that
/// responses can be compared and constructed (e.g. in tests) without helpers.
#[derive(Debug, Clone, PartialEq, Eq, Default)]
pub struct RawResponse {
    /// Encoded return-value bytes.
    pub payload: Vec<u8>,
    /// Codec identifier the bridge reported for `payload`.
    pub codec: String,
    /// Response-context entries sent back by the bridge.
    pub response_context: HashMap<String, String>,
}
55
56impl OrleansClient {
57 pub async fn connect(endpoint: impl Into<String>) -> Result<Self, OrleansError> {
64 Self::from_config(ClientConfig::new(endpoint)).await
65 }
66
67 #[must_use]
69 pub fn builder(endpoint: impl Into<String>) -> OrleansClientBuilder {
70 OrleansClientBuilder::new(endpoint)
71 }
72
73 pub async fn from_config(config: ClientConfig) -> Result<Self, OrleansError> {
78 Self::build(config, RetryPolicy::disabled()).await
79 }
80
81 async fn build(config: ClientConfig, retry: RetryPolicy) -> Result<Self, OrleansError> {
82 let mut endpoint = Endpoint::from_shared(config.endpoint.clone())
83 .map_err(|e| OrleansError::InvalidConfig(format!("invalid endpoint: {e}")))?;
84 if let Some(connect_timeout) = config.connect_timeout {
85 endpoint = endpoint.connect_timeout(connect_timeout);
86 }
87 endpoint = configure_tls(endpoint, config.tls.as_ref())?;
88
89 let metadata = build_metadata(&config.metadata)?;
90
91 let channel = endpoint.connect().await?;
92 let mut client = BridgeClient::new(channel);
93 if let Some(n) = config.max_decoding_message_size {
94 client = client.max_decoding_message_size(n);
95 }
96 if let Some(n) = config.max_encoding_message_size {
97 client = client.max_encoding_message_size(n);
98 }
99
100 Ok(Self {
101 inner: client,
102 config: Arc::new(config),
103 retry: Arc::new(retry),
104 metadata: Arc::new(metadata),
105 })
106 }
107
108 fn request<T>(&self, message: T) -> tonic::Request<T> {
111 let mut request = tonic::Request::new(message);
112 let metadata = request.metadata_mut();
113 for (key, value) in self.metadata.iter() {
114 metadata.insert(key.clone(), value.clone());
115 }
116 request
117 }
118
119 #[must_use]
121 pub fn config(&self) -> &ClientConfig {
122 &self.config
123 }
124
125 pub async fn health(&self) -> Result<pb::HealthResponse, OrleansError> {
130 let mut client = self.inner.clone();
131 let response = client
132 .health(self.request(pb::HealthRequest {}))
133 .await
134 .map_err(OrleansError::from_status)?;
135 Ok(response.into_inner())
136 }
137
138 pub async fn manifest(&self) -> Result<pb::ContractManifest, OrleansError> {
143 let mut client = self.inner.clone();
144 let response = client
145 .get_manifest(self.request(pb::GetManifestRequest {}))
146 .await
147 .map_err(OrleansError::from_status)?;
148 Ok(response.into_inner().manifest.unwrap_or_default())
149 }
150
151 #[must_use]
153 pub fn grain(
154 &self,
155 interface_name: impl Into<String>,
156 grain_type: impl Into<String>,
157 key: impl Into<GrainKey>,
158 ) -> GrainRef {
159 GrainRef::new(
160 self.clone(),
161 interface_name.into(),
162 grain_type.into(),
163 key.into(),
164 )
165 }
166
167 pub(crate) async fn invoke_raw(
168 &self,
169 call: InvokeCall<'_>,
170 ) -> Result<RawResponse, OrleansError> {
171 let effective_timeout = call.timeout.unwrap_or(self.config.default_timeout);
172 let target = pb::GrainTarget {
173 interface_name: call.interface_name.to_owned(),
174 grain_type: call.grain_type.to_owned(),
175 key: Some(call.key.to_proto()),
176 };
177 let context_map = call.context.clone().into_map();
178
179 let mut attempt: u32 = 0;
180 loop {
181 let request = pb::InvokeRequest {
182 target: Some(target.clone()),
183 method: call.method.to_owned(),
184 payload: call.payload.clone(),
185 payload_codec: call.codec.to_owned(),
186 request_context: context_map.clone(),
187 timeout_ms: u32::try_from(effective_timeout.as_millis()).unwrap_or(u32::MAX),
188 };
189
190 match self.invoke_once(request, effective_timeout).await {
191 Ok(response) => return Ok(response),
192 Err(error) => {
193 let can_retry = self.retry.is_enabled()
194 && attempt < self.retry.max_retries
195 && error.is_retryable();
196 if !can_retry {
197 return Err(error);
198 }
199 let backoff = self.retry.backoff_for(attempt + 1);
200 if !backoff.is_zero() {
201 tokio::time::sleep(backoff).await;
202 }
203 attempt += 1;
204 }
205 }
206 }
207 }
208
209 async fn invoke_once(
210 &self,
211 message: pb::InvokeRequest,
212 timeout: Duration,
213 ) -> Result<RawResponse, OrleansError> {
214 let mut client = self.inner.clone();
215 let request = self.request(message);
222 let guard = timeout.saturating_add(Duration::from_secs(5));
223 let call = client.invoke(request);
224 let result = match tokio::time::timeout(guard, call).await {
225 Ok(result) => result,
226 Err(_) => return Err(OrleansError::Timeout),
227 };
228
229 match result {
230 Ok(response) => {
231 let inner = response.into_inner();
232 Ok(RawResponse {
233 payload: inner.payload,
234 codec: inner.payload_codec,
235 response_context: inner.response_context,
236 })
237 }
238 Err(status) => Err(OrleansError::from_status(status)),
239 }
240 }
241}
242
243pub struct OrleansClientBuilder {
245 config: ClientConfig,
246 retry: RetryPolicy,
247}
248
249impl OrleansClientBuilder {
250 fn new(endpoint: impl Into<String>) -> Self {
251 Self {
252 config: ClientConfig::new(endpoint),
253 retry: RetryPolicy::disabled(),
254 }
255 }
256
257 #[must_use]
259 pub fn default_timeout(mut self, timeout: Duration) -> Self {
260 self.config.default_timeout = timeout;
261 self
262 }
263
264 #[must_use]
266 pub fn connect_timeout(mut self, timeout: Duration) -> Self {
267 self.config.connect_timeout = Some(timeout);
268 self
269 }
270
271 #[must_use]
273 pub fn max_decoding_message_size(mut self, bytes: usize) -> Self {
274 self.config.max_decoding_message_size = Some(bytes);
275 self
276 }
277
278 #[must_use]
280 pub fn max_encoding_message_size(mut self, bytes: usize) -> Self {
281 self.config.max_encoding_message_size = Some(bytes);
282 self
283 }
284
285 #[must_use]
287 pub fn default_context(mut self, context: RequestContext) -> Self {
288 self.config.default_context = context;
289 self
290 }
291
292 #[must_use]
294 pub fn retry_policy(mut self, policy: RetryPolicy) -> Self {
295 self.retry = policy;
296 self
297 }
298
299 #[must_use]
301 pub fn tls(mut self, tls: TlsConfig) -> Self {
302 self.config.tls = Some(tls);
303 self
304 }
305
306 #[must_use]
310 pub fn metadata(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
311 self.config.metadata.push((key.into(), value.into()));
312 self
313 }
314
315 #[must_use]
318 pub fn bearer_token(self, token: impl AsRef<str>) -> Self {
319 self.metadata("authorization", format!("Bearer {}", token.as_ref()))
320 }
321
322 #[must_use]
324 pub fn api_key(self, header: impl Into<String>, value: impl Into<String>) -> Self {
325 self.metadata(header, value)
326 }
327
328 pub async fn connect(self) -> Result<OrleansClient, OrleansError> {
333 OrleansClient::build(self.config, self.retry).await
334 }
335}
336
#[cfg(feature = "tls")]
// `OrleansError` is returned unboxed, matching the crate's other synchronous helpers.
#[allow(clippy::result_large_err)]
fn configure_tls(endpoint: Endpoint, tls: Option<&TlsConfig>) -> Result<Endpoint, OrleansError> {
    use tonic::transport::{Certificate, ClientTlsConfig, Identity};

    // No TLS requested: the endpoint is used as-is.
    let Some(settings) = tls else {
        return Ok(endpoint);
    };

    // A custom CA replaces the WebPKI root store rather than extending it.
    let trust = match &settings.ca_certificate_pem {
        Some(ca) => ClientTlsConfig::new().ca_certificate(Certificate::from_pem(ca)),
        None => ClientTlsConfig::new().with_webpki_roots(),
    };
    let with_domain = match &settings.domain_name {
        Some(domain) => trust.domain_name(domain.clone()),
        None => trust,
    };
    // Optional client certificate for mutual TLS.
    let tls_config = match &settings.client_identity_pem {
        Some((certificate, key)) => with_domain.identity(Identity::from_pem(certificate, key)),
        None => with_domain,
    };

    endpoint.tls_config(tls_config).map_err(OrleansError::from)
}
362
363#[cfg(not(feature = "tls"))]
364#[allow(clippy::result_large_err)]
365fn configure_tls(endpoint: Endpoint, tls: Option<&TlsConfig>) -> Result<Endpoint, OrleansError> {
366 if tls.is_some() {
367 return Err(OrleansError::InvalidConfig(
368 "TLS was configured but the `tls` cargo feature is not enabled".to_owned(),
369 ));
370 }
371 Ok(endpoint)
372}
373
374#[allow(clippy::result_large_err)]
377fn build_metadata(
378 entries: &[(String, String)],
379) -> Result<Vec<(MetadataKey<Ascii>, AsciiMetadataValue)>, OrleansError> {
380 let mut out = Vec::with_capacity(entries.len());
381 for (key, value) in entries {
382 let parsed_key = MetadataKey::<Ascii>::from_bytes(key.to_ascii_lowercase().as_bytes())
383 .map_err(|_| OrleansError::InvalidConfig(format!("invalid metadata key: {key:?}")))?;
384 let parsed_value = AsciiMetadataValue::try_from(value.as_str()).map_err(|_| {
385 OrleansError::InvalidConfig(format!("invalid metadata value for {key:?}"))
386 })?;
387 out.push((parsed_key, parsed_value));
388 }
389 Ok(out)
390}
391
#[cfg(test)]
mod tests {
    //! Unit tests for metadata parsing and TLS configuration helpers.
    use super::*;

    #[test]
    fn builds_valid_metadata() {
        let entries = vec![
            ("authorization".to_owned(), "Bearer abc.def".to_owned()),
            ("x-api-key".to_owned(), "key123".to_owned()),
        ];
        let built = build_metadata(&entries).expect("valid metadata");
        assert_eq!(built.len(), 2);
        assert_eq!(built[0].0.as_str(), "authorization");
    }

    #[test]
    fn lowercases_header_names() {
        let entries = vec![("Authorization".to_owned(), "Bearer t".to_owned())];
        let built = build_metadata(&entries).unwrap();
        assert_eq!(built[0].0.as_str(), "authorization");
    }

    #[test]
    fn rejects_invalid_key() {
        let entries = vec![("bad key".to_owned(), "v".to_owned())];
        let error = build_metadata(&entries).unwrap_err();
        assert!(matches!(error, OrleansError::InvalidConfig(_)));
    }

    #[test]
    fn rejects_invalid_value() {
        let entries = vec![("authorization".to_owned(), "bad\nvalue".to_owned())];
        let error = build_metadata(&entries).unwrap_err();
        assert!(matches!(error, OrleansError::InvalidConfig(_)));
    }

    #[test]
    fn empty_metadata_builds_empty_list() {
        let built = build_metadata(&[]).expect("empty metadata is valid");
        assert!(built.is_empty());
    }

    #[test]
    fn rejects_binary_suffixed_key_for_ascii_metadata() {
        // gRPC reserves the `-bin` suffix for binary metadata; ASCII keys may not use it.
        let entries = vec![("trace-bin".to_owned(), "v".to_owned())];
        let error = build_metadata(&entries).unwrap_err();
        assert!(matches!(error, OrleansError::InvalidConfig(_)));
    }

    #[test]
    fn configure_tls_passes_endpoint_through_when_unset() {
        // Holds with or without the `tls` feature: no TLS config means no change.
        let endpoint = Endpoint::from_static("http://127.0.0.1:50051");
        assert!(configure_tls(endpoint, None).is_ok());
    }
}
427}