Skip to main content

aion_client/
client.rs

1//! `Client` and `ClientBuilder` connection, auth, and TLS support.
2
3use std::collections::HashMap;
4use std::fmt;
5use std::sync::Arc;
6
7use tokio::sync::Mutex;
8
9use crate::error::ClientError;
10use crate::handle::WorkflowHandle;
11use crate::ops::StartFingerprint;
12use crate::transport::{GrpcWorkflowTransport, WorkflowTransport};
13
14/// Reusable caller-side SDK client for an `aion-server` deployment.
15///
16/// # Examples
17///
18/// ```no_run
19/// # async fn connect() -> Result<(), aion_client::ClientError> {
20/// use aion_client::{ClientAuth, ClientBuilder};
21///
22/// let client = ClientBuilder::new("https://aion.example.com")
23///     .with_auth(ClientAuth::bearer("secret-token"))
24///     .with_namespace("tenant-a")
25///     .build()
26///     .await?;
27///
28/// let shared = client.clone();
29/// # let _ = shared;
30/// # Ok(())
31/// # }
32/// ```
33#[derive(Clone)]
34pub struct Client {
35    pub(crate) transport: Arc<dyn WorkflowTransport>,
36    pub(crate) config: ClientConfig,
37    idempotent_starts: Arc<Mutex<HashMap<String, (StartFingerprint, WorkflowHandle)>>>,
38}
39
40impl Client {
41    /// Creates a builder for an `aion-server` endpoint.
42    #[must_use]
43    pub fn builder(endpoint: impl Into<String>) -> ClientBuilder {
44        ClientBuilder::new(endpoint)
45    }
46
47    pub(crate) fn from_transport(
48        config: ClientConfig,
49        transport: Arc<dyn WorkflowTransport>,
50    ) -> Self {
51        Self {
52            transport,
53            config,
54            idempotent_starts: Arc::new(Mutex::new(HashMap::new())),
55        }
56    }
57
58    #[cfg(feature = "embedded")]
59    /// Creates a client backed by an in-process embedded engine.
60    #[must_use]
61    pub fn embedded(engine: Arc<aion::Engine>) -> Self {
62        let config = ClientConfig {
63            endpoint: String::from("embedded://engine"),
64            stream_endpoint: None,
65            auth: None,
66            tls: None,
67            namespace: String::from("default"),
68            subject: None,
69            authorized_namespaces: Vec::new(),
70        };
71        Self::from_transport(
72            config,
73            Arc::new(crate::transport::EmbeddedWorkflowTransport::new(engine)),
74        )
75    }
76
77    pub(crate) fn namespace(&self) -> &str {
78        &self.config.namespace
79    }
80
81    pub(crate) async fn cached_start(
82        &self,
83        fingerprint: &StartFingerprint,
84    ) -> Result<Option<WorkflowHandle>, ClientError> {
85        let cache = self.idempotent_starts.lock().await;
86        let Some((cached_fingerprint, handle)) = cache.get(fingerprint.key()) else {
87            return Ok(None);
88        };
89        if cached_fingerprint == fingerprint {
90            Ok(Some(handle.clone()))
91        } else {
92            Err(idempotency_conflict())
93        }
94    }
95
96    pub(crate) async fn record_start(
97        &self,
98        fingerprint: StartFingerprint,
99        handle: WorkflowHandle,
100    ) -> Result<(), ClientError> {
101        let mut cache = self.idempotent_starts.lock().await;
102        match cache.get(fingerprint.key()) {
103            Some((cached_fingerprint, _)) if cached_fingerprint == &fingerprint => Ok(()),
104            Some(_) => Err(idempotency_conflict()),
105            None => {
106                cache.insert(fingerprint.key().to_owned(), (fingerprint, handle));
107                Ok(())
108            }
109        }
110    }
111}
112
113/// The SDK-boundary idempotency conflict: the same key was reused with a
114/// different start request.
115fn idempotency_conflict() -> ClientError {
116    ClientError::already_exists(
117        "idempotency key was already used by a different start request \
118         (namespace, workflow type, or input differ)",
119    )
120}
121
/// Builder for [`Client`] connection, authentication, and TLS options.
///
/// Start from [`ClientBuilder::new`] or [`Client::builder`], chain the
/// `with_*` setters, then call [`ClientBuilder::build`] to connect. The
/// derived `Debug` output does not reveal the bearer token, because
/// [`ClientAuth`]'s `Debug` impl redacts it.
#[derive(Clone, Debug)]
pub struct ClientBuilder {
    /// gRPC endpoint of the `aion-server` deployment.
    endpoint: String,
    /// WebSocket `/events/stream` URL used by subscribe operations; no default.
    stream_endpoint: Option<String>,
    /// Credential attached to every request, if configured.
    auth: Option<ClientAuth>,
    /// TLS options for the channel, if configured.
    tls: Option<TlsOptions>,
    /// Namespace used unless an operation option overrides it.
    namespace: String,
    /// Caller subject metadata sent to the server, if configured.
    subject: Option<String>,
    /// Namespaces advertised in auth metadata.
    authorized_namespaces: Vec<String>,
}
133
134impl ClientBuilder {
135    /// Creates a builder for the supplied server endpoint.
136    #[must_use]
137    pub fn new(endpoint: impl Into<String>) -> Self {
138        Self {
139            endpoint: endpoint.into(),
140            stream_endpoint: None,
141            auth: None,
142            tls: None,
143            namespace: String::from("default"),
144            subject: None,
145            authorized_namespaces: Vec::new(),
146        }
147    }
148
149    /// Configures the WebSocket event-stream endpoint used by subscribe
150    /// operations: the full URL of the server's `/events/stream` route, e.g.
151    /// `ws://127.0.0.1:8080/events/stream` (`http`/`https` URLs are accepted
152    /// and protocol-mapped to `ws`/`wss`).
153    ///
154    /// There is no default and nothing is derived: the gRPC endpoint and the
155    /// HTTP/WebSocket listener are separate addresses. Subscribing without
156    /// this option returns [`ClientError::InvalidArgument`] with a precise
157    /// message.
158    #[must_use]
159    pub fn with_stream_endpoint(mut self, stream_endpoint: impl Into<String>) -> Self {
160        self.stream_endpoint = Some(stream_endpoint.into());
161        self
162    }
163
164    /// Configures the credential attached to every request.
165    #[must_use]
166    pub fn with_auth(mut self, auth: ClientAuth) -> Self {
167        self.auth = Some(auth);
168        self
169    }
170
171    /// Configures TLS options for the tonic channel.
172    #[must_use]
173    pub fn with_tls(mut self, tls: TlsOptions) -> Self {
174        self.tls = Some(tls);
175        self
176    }
177
178    /// Configures the namespace used by operations unless an operation option overrides it.
179    #[must_use]
180    pub fn with_namespace(mut self, namespace: impl Into<String>) -> Self {
181        self.namespace = namespace.into();
182        self
183    }
184
185    /// Configures the caller subject metadata sent to the server.
186    #[must_use]
187    pub fn with_subject(mut self, subject: impl Into<String>) -> Self {
188        self.subject = Some(subject.into());
189        self
190    }
191
192    /// Configures the namespaces advertised in auth metadata.
193    #[must_use]
194    pub fn with_authorized_namespaces<I, S>(mut self, namespaces: I) -> Self
195    where
196        I: IntoIterator<Item = S>,
197        S: Into<String>,
198    {
199        self.authorized_namespaces = namespaces.into_iter().map(Into::into).collect();
200        self
201    }
202
203    /// Connects once and returns a cheaply cloneable [`Client`].
204    ///
205    /// # Errors
206    ///
207    /// Returns [`ClientError::Unavailable`] for malformed endpoints and failed
208    /// channel/TLS handshakes. Server-side credential rejection is surfaced as
209    /// [`ClientError::Unauthenticated`] when AW returns gRPC `Unauthenticated`.
210    pub async fn build(self) -> Result<Client, ClientError> {
211        let config = ClientConfig::from(self);
212        let transport = GrpcWorkflowTransport::connect(config.clone()).await?;
213        Ok(Client::from_transport(config, Arc::new(transport)))
214    }
215}
216
/// Bearer-token credential attached to server calls.
///
/// The `Debug` representation prints a placeholder instead of the token, so
/// the secret does not appear in debug output.
#[derive(Clone, PartialEq, Eq)]
pub struct ClientAuth {
    token: String,
}

impl ClientAuth {
    /// Builds a credential that carries `token` as a bearer token.
    #[must_use]
    pub fn bearer(token: impl Into<String>) -> Self {
        let token: String = token.into();
        Self { token }
    }

    /// Borrows the raw token for the transport layer.
    pub(crate) fn token(&self) -> &str {
        self.token.as_str()
    }
}

impl fmt::Debug for ClientAuth {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Deliberately never formats `self.token`.
        const REDACTED: &str = "<redacted>";
        let mut view = f.debug_struct("ClientAuth");
        view.field("token", &REDACTED);
        view.finish()
    }
}
245
/// TLS options for connecting to an HTTPS/TLS endpoint.
///
/// Both settings are optional. The [`Default`] value overrides nothing.
#[derive(Clone, Debug, Default, PartialEq, Eq)]
pub struct TlsOptions {
    /// Domain name checked during the handshake, when overridden.
    pub(crate) domain_name: Option<String>,
    /// PEM-encoded CA certificate to trust, when configured.
    pub(crate) ca_certificate_pem: Option<Vec<u8>>,
}

impl TlsOptions {
    /// Creates empty TLS options using platform/webpki roots.
    ///
    /// Equivalent to [`TlsOptions::default`].
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Overrides the TLS domain name checked during the handshake.
    ///
    /// Calling this again replaces the previously set name.
    #[must_use]
    pub fn with_domain_name(mut self, domain_name: impl Into<String>) -> Self {
        self.domain_name = Some(domain_name.into());
        self
    }

    /// Sets the PEM-encoded CA certificate trusted for this connection.
    ///
    /// Only one PEM value is stored. Calling this again **replaces** the
    /// previously configured certificate rather than adding a second one.
    #[must_use]
    pub fn with_ca_certificate_pem(mut self, ca_certificate_pem: impl Into<Vec<u8>>) -> Self {
        self.ca_certificate_pem = Some(ca_certificate_pem.into());
        self
    }
}
274
/// Fully resolved client connection configuration.
///
/// Produced from a [`ClientBuilder`] through its `From` impl; each field
/// mirrors the builder option of the same name. The derived `Debug` output
/// does not reveal the bearer token, because [`ClientAuth`]'s `Debug`
/// redacts it.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct ClientConfig {
    /// gRPC endpoint (`embedded://engine` for embedded clients).
    pub(crate) endpoint: String,
    /// WebSocket `/events/stream` URL for subscribe operations, if configured.
    pub(crate) stream_endpoint: Option<String>,
    /// Credential attached to every request, if configured.
    pub(crate) auth: Option<ClientAuth>,
    /// TLS options for the channel, if configured.
    pub(crate) tls: Option<TlsOptions>,
    /// Namespace used unless an operation option overrides it.
    pub(crate) namespace: String,
    /// Caller subject metadata sent to the server, if configured.
    pub(crate) subject: Option<String>,
    /// Namespaces advertised in auth metadata.
    pub(crate) authorized_namespaces: Vec<String>,
}
286
287impl From<ClientBuilder> for ClientConfig {
288    fn from(builder: ClientBuilder) -> Self {
289        Self {
290            endpoint: builder.endpoint,
291            stream_endpoint: builder.stream_endpoint,
292            auth: builder.auth,
293            tls: builder.tls,
294            namespace: builder.namespace,
295            subject: builder.subject,
296            authorized_namespaces: builder.authorized_namespaces,
297        }
298    }
299}
300
#[cfg(test)]
mod tests {
    use super::{Client, ClientAuth, ClientBuilder, ClientConfig, TlsOptions};

    fn assert_send_sync<T: Send + Sync>() {}

    fn assert_clone<T: Clone>() {}

    #[test]
    fn client_is_clone_send_sync() {
        // The test name promises all three bounds, so check all three.
        assert_clone::<Client>();
        assert_send_sync::<Client>();
    }

    #[test]
    fn auth_debug_redacts_token() {
        let auth = ClientAuth::bearer("secret-token");
        assert_eq!(format!("{auth:?}"), "ClientAuth { token: \"<redacted>\" }");
        assert_eq!(auth.token(), "secret-token");
    }

    #[test]
    fn builder_and_config_debug_never_leak_token() {
        let builder =
            ClientBuilder::new("https://aion.example.com").with_auth(ClientAuth::bearer("secret-token"));
        assert!(!format!("{builder:?}").contains("secret-token"));
        assert!(!format!("{:?}", ClientConfig::from(builder)).contains("secret-token"));
    }

    #[test]
    fn builder_captures_connection_options() {
        let config = ClientConfig::from(
            ClientBuilder::new("https://aion.example.com")
                .with_stream_endpoint("wss://aion-http.example.com/events/stream")
                .with_auth(ClientAuth::bearer("secret-token"))
                .with_tls(TlsOptions::new().with_domain_name("aion.example.com"))
                .with_namespace("tenant-a")
                .with_subject("alice")
                .with_authorized_namespaces(["tenant-a", "tenant-b"]),
        );

        assert_eq!(config.endpoint, "https://aion.example.com");
        assert_eq!(
            config.stream_endpoint,
            Some(String::from("wss://aion-http.example.com/events/stream"))
        );
        // Compare values, not just presence, so a mis-wired setter is caught.
        assert_eq!(config.auth, Some(ClientAuth::bearer("secret-token")));
        assert_eq!(
            config.tls,
            Some(TlsOptions::new().with_domain_name("aion.example.com"))
        );
        assert_eq!(config.namespace, "tenant-a");
        assert_eq!(config.subject, Some(String::from("alice")));
        assert_eq!(
            config.authorized_namespaces,
            vec![String::from("tenant-a"), String::from("tenant-b")]
        );
    }

    #[test]
    fn stream_endpoint_has_no_default() {
        let config = ClientConfig::from(ClientBuilder::new("https://aion.example.com"));
        assert_eq!(config.stream_endpoint, None);
    }

    #[test]
    fn builder_defaults() {
        let config = ClientConfig::from(ClientBuilder::new("https://aion.example.com"));
        assert_eq!(config.namespace, "default");
        assert_eq!(config.auth, None);
        assert_eq!(config.tls, None);
        assert_eq!(config.subject, None);
        assert!(config.authorized_namespaces.is_empty());
    }
}