1use std::collections::HashMap;
4use std::fmt;
5use std::sync::Arc;
6
7use tokio::sync::Mutex;
8
9use crate::error::ClientError;
10use crate::handle::WorkflowHandle;
11use crate::ops::StartFingerprint;
12use crate::transport::{GrpcWorkflowTransport, WorkflowTransport};
13
14#[derive(Clone)]
34pub struct Client {
35 pub(crate) transport: Arc<dyn WorkflowTransport>,
36 pub(crate) config: ClientConfig,
37 idempotent_starts: Arc<Mutex<HashMap<String, (StartFingerprint, WorkflowHandle)>>>,
38}
39
40impl Client {
41 #[must_use]
43 pub fn builder(endpoint: impl Into<String>) -> ClientBuilder {
44 ClientBuilder::new(endpoint)
45 }
46
47 pub(crate) fn from_transport(
48 config: ClientConfig,
49 transport: Arc<dyn WorkflowTransport>,
50 ) -> Self {
51 Self {
52 transport,
53 config,
54 idempotent_starts: Arc::new(Mutex::new(HashMap::new())),
55 }
56 }
57
58 #[cfg(feature = "embedded")]
59 #[must_use]
61 pub fn embedded(engine: Arc<aion::Engine>) -> Self {
62 let config = ClientConfig {
63 endpoint: String::from("embedded://engine"),
64 stream_endpoint: None,
65 auth: None,
66 tls: None,
67 namespace: String::from("default"),
68 subject: None,
69 authorized_namespaces: Vec::new(),
70 };
71 Self::from_transport(
72 config,
73 Arc::new(crate::transport::EmbeddedWorkflowTransport::new(engine)),
74 )
75 }
76
77 pub(crate) fn namespace(&self) -> &str {
78 &self.config.namespace
79 }
80
81 pub(crate) async fn cached_start(
82 &self,
83 fingerprint: &StartFingerprint,
84 ) -> Result<Option<WorkflowHandle>, ClientError> {
85 let cache = self.idempotent_starts.lock().await;
86 let Some((cached_fingerprint, handle)) = cache.get(fingerprint.key()) else {
87 return Ok(None);
88 };
89 if cached_fingerprint == fingerprint {
90 Ok(Some(handle.clone()))
91 } else {
92 Err(idempotency_conflict())
93 }
94 }
95
96 pub(crate) async fn record_start(
97 &self,
98 fingerprint: StartFingerprint,
99 handle: WorkflowHandle,
100 ) -> Result<(), ClientError> {
101 let mut cache = self.idempotent_starts.lock().await;
102 match cache.get(fingerprint.key()) {
103 Some((cached_fingerprint, _)) if cached_fingerprint == &fingerprint => Ok(()),
104 Some(_) => Err(idempotency_conflict()),
105 None => {
106 cache.insert(fingerprint.key().to_owned(), (fingerprint, handle));
107 Ok(())
108 }
109 }
110 }
111}
112
113fn idempotency_conflict() -> ClientError {
116 ClientError::already_exists(
117 "idempotency key was already used by a different start request \
118 (namespace, workflow type, or input differ)",
119 )
120}
121
122#[derive(Clone, Debug)]
124pub struct ClientBuilder {
125 endpoint: String,
126 stream_endpoint: Option<String>,
127 auth: Option<ClientAuth>,
128 tls: Option<TlsOptions>,
129 namespace: String,
130 subject: Option<String>,
131 authorized_namespaces: Vec<String>,
132}
133
134impl ClientBuilder {
135 #[must_use]
137 pub fn new(endpoint: impl Into<String>) -> Self {
138 Self {
139 endpoint: endpoint.into(),
140 stream_endpoint: None,
141 auth: None,
142 tls: None,
143 namespace: String::from("default"),
144 subject: None,
145 authorized_namespaces: Vec::new(),
146 }
147 }
148
149 #[must_use]
159 pub fn with_stream_endpoint(mut self, stream_endpoint: impl Into<String>) -> Self {
160 self.stream_endpoint = Some(stream_endpoint.into());
161 self
162 }
163
164 #[must_use]
166 pub fn with_auth(mut self, auth: ClientAuth) -> Self {
167 self.auth = Some(auth);
168 self
169 }
170
171 #[must_use]
173 pub fn with_tls(mut self, tls: TlsOptions) -> Self {
174 self.tls = Some(tls);
175 self
176 }
177
178 #[must_use]
180 pub fn with_namespace(mut self, namespace: impl Into<String>) -> Self {
181 self.namespace = namespace.into();
182 self
183 }
184
185 #[must_use]
187 pub fn with_subject(mut self, subject: impl Into<String>) -> Self {
188 self.subject = Some(subject.into());
189 self
190 }
191
192 #[must_use]
194 pub fn with_authorized_namespaces<I, S>(mut self, namespaces: I) -> Self
195 where
196 I: IntoIterator<Item = S>,
197 S: Into<String>,
198 {
199 self.authorized_namespaces = namespaces.into_iter().map(Into::into).collect();
200 self
201 }
202
203 pub async fn build(self) -> Result<Client, ClientError> {
211 let config = ClientConfig::from(self);
212 let transport = GrpcWorkflowTransport::connect(config.clone()).await?;
213 Ok(Client::from_transport(config, Arc::new(transport)))
214 }
215}
216
/// Bearer-token credentials for the client.
///
/// `Debug` is implemented by hand (not derived) so the token is never printed.
#[derive(Clone, Eq, PartialEq)]
pub struct ClientAuth {
    token: String,
}
222
223impl ClientAuth {
224 #[must_use]
226 pub fn bearer(token: impl Into<String>) -> Self {
227 Self {
228 token: token.into(),
229 }
230 }
231
232 pub(crate) fn token(&self) -> &str {
233 &self.token
234 }
235}
236
237impl fmt::Debug for ClientAuth {
238 fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
239 formatter
240 .debug_struct("ClientAuth")
241 .field("token", &"<redacted>")
242 .finish()
243 }
244}
245
/// TLS settings for the connection; both fields are optional.
#[derive(Debug, Clone, Default, Eq, PartialEq)]
pub struct TlsOptions {
    /// Presumably the server name to verify the certificate against — confirm in transport.
    pub(crate) domain_name: Option<String>,
    /// CA certificate bytes, expected (by name) to be PEM-encoded.
    pub(crate) ca_certificate_pem: Option<Vec<u8>>,
}
252
253impl TlsOptions {
254 #[must_use]
256 pub fn new() -> Self {
257 Self::default()
258 }
259
260 #[must_use]
262 pub fn with_domain_name(mut self, domain_name: impl Into<String>) -> Self {
263 self.domain_name = Some(domain_name.into());
264 self
265 }
266
267 #[must_use]
269 pub fn with_ca_certificate_pem(mut self, ca_certificate_pem: impl Into<Vec<u8>>) -> Self {
270 self.ca_certificate_pem = Some(ca_certificate_pem.into());
271 self
272 }
273}
274
275#[derive(Clone, Debug, PartialEq, Eq)]
277pub struct ClientConfig {
278 pub(crate) endpoint: String,
279 pub(crate) stream_endpoint: Option<String>,
280 pub(crate) auth: Option<ClientAuth>,
281 pub(crate) tls: Option<TlsOptions>,
282 pub(crate) namespace: String,
283 pub(crate) subject: Option<String>,
284 pub(crate) authorized_namespaces: Vec<String>,
285}
286
287impl From<ClientBuilder> for ClientConfig {
288 fn from(builder: ClientBuilder) -> Self {
289 Self {
290 endpoint: builder.endpoint,
291 stream_endpoint: builder.stream_endpoint,
292 auth: builder.auth,
293 tls: builder.tls,
294 namespace: builder.namespace,
295 subject: builder.subject,
296 authorized_namespaces: builder.authorized_namespaces,
297 }
298 }
299}
300
#[cfg(test)]
mod tests {
    use super::{Client, ClientAuth, ClientBuilder, ClientConfig, TlsOptions};

    /// Compile-time check that `T` satisfies every bound the test name promises.
    fn assert_clone_send_sync<T: Clone + Send + Sync>() {}

    #[test]
    fn client_is_clone_send_sync() {
        assert_clone_send_sync::<Client>();
    }

    #[test]
    fn auth_debug_redacts_token() {
        let auth = ClientAuth::bearer("secret-token");
        assert_eq!(format!("{auth:?}"), "ClientAuth { token: \"<redacted>\" }");
    }

    #[test]
    fn builder_and_config_debug_do_not_leak_token() {
        // Both types derive `Debug`; they must rely on `ClientAuth`'s redaction.
        let builder =
            ClientBuilder::new("https://aion.example.com").with_auth(ClientAuth::bearer("secret-token"));
        let builder_rendered = format!("{builder:?}");
        assert!(builder_rendered.contains("<redacted>"));
        assert!(!builder_rendered.contains("secret-token"));

        let config = ClientConfig::from(builder);
        let config_rendered = format!("{config:?}");
        assert!(config_rendered.contains("<redacted>"));
        assert!(!config_rendered.contains("secret-token"));
    }

    #[test]
    fn builder_captures_connection_options() {
        let config = ClientConfig::from(
            ClientBuilder::new("https://aion.example.com")
                .with_stream_endpoint("wss://aion-http.example.com/events/stream")
                .with_auth(ClientAuth::bearer("secret-token"))
                .with_tls(TlsOptions::new().with_domain_name("aion.example.com"))
                .with_namespace("tenant-a")
                .with_subject("alice")
                .with_authorized_namespaces(["tenant-a", "tenant-b"]),
        );

        assert_eq!(config.endpoint, "https://aion.example.com");
        assert_eq!(
            config.stream_endpoint,
            Some(String::from("wss://aion-http.example.com/events/stream"))
        );
        // Compare the actual values, not just presence.
        assert_eq!(config.auth, Some(ClientAuth::bearer("secret-token")));
        assert_eq!(config.auth.as_ref().map(ClientAuth::token), Some("secret-token"));
        assert_eq!(
            config.tls,
            Some(TlsOptions::new().with_domain_name("aion.example.com"))
        );
        assert_eq!(config.namespace, "tenant-a");
        assert_eq!(config.subject, Some(String::from("alice")));
        assert_eq!(
            config.authorized_namespaces,
            vec![String::from("tenant-a"), String::from("tenant-b")]
        );
    }

    #[test]
    fn stream_endpoint_has_no_default() {
        let config = ClientConfig::from(ClientBuilder::new("https://aion.example.com"));
        assert_eq!(config.stream_endpoint, None);
    }

    #[test]
    fn builder_defaults() {
        let config = ClientConfig::from(ClientBuilder::new("https://aion.example.com"));
        assert_eq!(config.namespace, "default");
        assert_eq!(config.auth, None);
        assert_eq!(config.tls, None);
        assert_eq!(config.subject, None);
        assert!(config.authorized_namespaces.is_empty());
    }
}