Skip to main content

armada_client/
client.rs

1use std::sync::Arc;
2use std::time::Duration;
3
4use futures::StreamExt;
5use futures::stream::BoxStream;
6use tonic::transport::{Channel, ClientTlsConfig};
7use tracing::instrument;
8
9use crate::api::{
10    CancellationResult, EventStreamMessage, JobCancelRequest, JobSetCancelRequest, JobSetRequest,
11    JobSubmitRequest, JobSubmitResponse, event_client::EventClient, submit_client::SubmitClient,
12};
13use crate::auth::TokenProvider;
14use crate::error::Error;
15
16/// Armada gRPC client providing job submission and event-stream watching.
17///
18/// # Construction
19///
20/// Use [`ArmadaClient::connect`] for plaintext (dev / in-cluster) connections,
21/// [`ArmadaClient::connect_tls`] for production clusters using system root certificates,
22/// or [`ArmadaClient::connect_tls_with_config`] when you need a custom CA, domain
23/// override, or mutual TLS. All constructors accept any [`TokenProvider`] — pass
24/// [`crate::StaticTokenProvider`] for a static bearer token or supply your own
25/// implementation for dynamic auth.
26///
27/// ```no_run
28/// # use armada_client::{ArmadaClient, StaticTokenProvider};
29/// # use armada_client::tonic::transport::{Certificate, ClientTlsConfig};
30/// # async fn example() -> Result<(), armada_client::Error> {
31/// // Plaintext
32/// let client = ArmadaClient::connect("http://localhost:50051", StaticTokenProvider::new("tok"))
33///     .await?;
34///
35/// // TLS (uses system root certificates)
36/// let client = ArmadaClient::connect_tls("https://armada.example.com:443", StaticTokenProvider::new("tok"))
37///     .await?;
38///
39/// // TLS with a custom CA
40/// let pem = std::fs::read("ca.pem")?;
41/// let client = ArmadaClient::connect_tls_with_config(
42///     "https://armada.example.com:443",
43///     ClientTlsConfig::new().ca_certificate(Certificate::from_pem(pem)),
44///     StaticTokenProvider::new("tok"),
45/// ).await?;
46/// # Ok(())
47/// # }
48/// ```
49///
50/// # Cloning
51///
52/// `ArmadaClient` is `Clone`. All clones share the same underlying channel and
53/// connection pool — cloning is `O(1)` and is the correct way to distribute the
54/// client across tasks:
55///
56/// ```no_run
57/// # use armada_client::{ArmadaClient, StaticTokenProvider};
58/// # async fn example() -> Result<(), armada_client::Error> {
59/// let client = ArmadaClient::connect("http://localhost:50051", StaticTokenProvider::new("tok"))
60///     .await?;
61///
62/// let c1 = client.clone();
63/// let c2 = client.clone();
64/// tokio::spawn(async move { /* use c1 */ });
65/// tokio::spawn(async move { /* use c2 */ });
66/// # Ok(())
67/// # }
68/// ```
69///
70/// # Timeouts
71///
72/// Apply a per-call deadline with [`ArmadaClient::with_timeout`]. When set,
73/// every RPC is governed by the deadline for its **entire duration**. For
74/// unary calls (`submit`) the deadline covers the round-trip. For streaming
75/// calls (`watch`) it covers the full lifetime of the stream — if the timeout
76/// elapses while events are still being received the stream is cancelled with
77/// a `DEADLINE_EXCEEDED` status.
78#[derive(Clone)]
79pub struct ArmadaClient {
80    submit_client: SubmitClient<Channel>,
81    event_client: EventClient<Channel>,
82    token_provider: Arc<dyn TokenProvider + Send + Sync>,
83    timeout: Option<Duration>,
84}
85
86impl ArmadaClient {
87    /// Connect to an Armada server at `endpoint` using plaintext (no TLS).
88    ///
89    /// `endpoint` must be a valid URI, e.g. `"http://localhost:50051"`.
90    ///
91    /// # Errors
92    ///
93    /// Returns [`Error::InvalidUri`] if the URI is malformed, or
94    /// [`Error::Transport`] if the connection cannot be established.
95    pub async fn connect(
96        endpoint: impl Into<String>,
97        token_provider: impl TokenProvider + 'static,
98    ) -> Result<Self, Error> {
99        let channel = Channel::from_shared(endpoint.into())
100            .map_err(|e| Error::InvalidUri(e.to_string()))?
101            .connect()
102            .await?;
103        Ok(Self::from_parts(channel, token_provider))
104    }
105
106    /// Connect to an Armada server at `endpoint` using TLS.
107    ///
108    /// Uses the system's native root certificates to verify the server
109    /// certificate. `endpoint` should use the `https://` scheme,
110    /// e.g. `"https://armada.example.com:443"`.
111    ///
112    /// For clusters with a private or self-signed CA, use
113    /// [`ArmadaClient::connect_tls_with_config`] instead.
114    ///
115    /// # Errors
116    ///
117    /// Returns [`Error::InvalidUri`] if the URI is malformed, or
118    /// [`Error::Transport`] if TLS configuration or the connection fails.
119    pub async fn connect_tls(
120        endpoint: impl Into<String>,
121        token_provider: impl TokenProvider + 'static,
122    ) -> Result<Self, Error> {
123        let channel = Channel::from_shared(endpoint.into())
124            .map_err(|e| Error::InvalidUri(e.to_string()))?
125            .tls_config(ClientTlsConfig::new())?
126            .connect()
127            .await?;
128        Ok(Self::from_parts(channel, token_provider))
129    }
130
131    /// Connect to an Armada server at `endpoint` using a caller-supplied TLS config.
132    ///
133    /// Use this when you need to supply a custom CA certificate (e.g. a private or
134    /// self-signed CA), override the server domain name, or configure mutual TLS.
135    /// Build the config with [`tonic::transport::ClientTlsConfig`], accessible via
136    /// `armada_client::tonic::transport::ClientTlsConfig` — no direct tonic dependency needed.
137    ///
138    /// # Example — custom CA
139    ///
140    /// ```no_run
141    /// use armada_client::{ArmadaClient, StaticTokenProvider};
142    /// use armada_client::tonic::transport::{Certificate, ClientTlsConfig};
143    ///
144    /// # async fn example() -> Result<(), armada_client::Error> {
145    /// let pem = std::fs::read("ca.pem")?;
146    /// let tls = ClientTlsConfig::new().ca_certificate(Certificate::from_pem(pem));
147    /// let client = ArmadaClient::connect_tls_with_config(
148    ///     "https://armada.example.com:443",
149    ///     tls,
150    ///     StaticTokenProvider::new("tok"),
151    /// ).await?;
152    /// # Ok(())
153    /// # }
154    /// ```
155    ///
156    /// # Errors
157    ///
158    /// Returns [`Error::InvalidUri`] if the URI is malformed, or
159    /// [`Error::Transport`] if TLS configuration or the connection fails.
160    pub async fn connect_tls_with_config(
161        endpoint: impl Into<String>,
162        tls_config: ClientTlsConfig,
163        token_provider: impl TokenProvider + 'static,
164    ) -> Result<Self, Error> {
165        let channel = Channel::from_shared(endpoint.into())
166            .map_err(|e| Error::InvalidUri(e.to_string()))?
167            .tls_config(tls_config)?
168            .connect()
169            .await?;
170        Ok(Self::from_parts(channel, token_provider))
171    }
172
173    fn from_parts(channel: Channel, token_provider: impl TokenProvider + 'static) -> Self {
174        Self {
175            submit_client: SubmitClient::new(channel.clone()),
176            event_client: EventClient::new(channel),
177            token_provider: Arc::new(token_provider),
178            timeout: None,
179        }
180    }
181
182    /// Set a default timeout applied to every RPC call.
183    ///
184    /// When the timeout elapses the call fails with [`Error::Grpc`] wrapping a
185    /// `DEADLINE_EXCEEDED` status. For streaming calls like
186    /// [`ArmadaClient::watch`], the deadline covers the **entire stream
187    /// duration** — if it elapses while events are still arriving the stream
188    /// is cancelled immediately.
189    ///
190    /// Returns `self` so the call can be chained directly after construction:
191    ///
192    /// ```no_run
193    /// # use std::time::Duration;
194    /// # use armada_client::{ArmadaClient, StaticTokenProvider};
195    /// # async fn example() -> Result<(), armada_client::Error> {
196    /// let client = ArmadaClient::connect("http://localhost:50051", StaticTokenProvider::new("tok"))
197    ///     .await?
198    ///     .with_timeout(Duration::from_secs(30));
199    /// # Ok(())
200    /// # }
201    /// ```
202    pub fn with_timeout(mut self, timeout: Duration) -> Self {
203        self.timeout = Some(timeout);
204        self
205    }
206
207    fn apply_timeout<T>(&self, req: &mut tonic::Request<T>) {
208        if let Some(t) = self.timeout {
209            req.set_timeout(t);
210        }
211    }
212
213    /// Submit a batch of jobs to Armada.
214    ///
215    /// Attaches an `authorization` header on every call using the configured
216    /// [`TokenProvider`] (e.g. `Bearer <token>` or `Basic <credentials>`). Multiple job items can be included in
217    /// a single request — they are all submitted atomically to the same queue
218    /// and job set.
219    ///
220    /// # Example
221    ///
222    /// ```no_run
223    /// use armada_client::{
224    ///     ArmadaClient, JobRequestItemBuilder, JobSubmitRequest, StaticTokenProvider,
225    /// };
226    /// use armada_client::k8s::io::api::core::v1::PodSpec;
227    ///
228    /// # async fn example() -> Result<(), armada_client::Error> {
229    /// # let client = ArmadaClient::connect("http://localhost:50051", StaticTokenProvider::new("")).await?;
230    /// let item = JobRequestItemBuilder::new()
231    ///     .namespace("default")
232    ///     .pod_spec(PodSpec { containers: vec![], ..Default::default() })
233    ///     .build();
234    ///
235    /// let response = client
236    ///     .submit(JobSubmitRequest {
237    ///         queue: "my-queue".into(),
238    ///         job_set_id: "my-job-set".into(),
239    ///         job_request_items: vec![item],
240    ///     })
241    ///     .await?;
242    ///
243    /// for r in &response.job_response_items {
244    ///     if r.error.is_empty() {
245    ///         println!("submitted: {}", r.job_id);
246    ///     } else {
247    ///         eprintln!("rejected: {}", r.error);
248    ///     }
249    /// }
250    /// # Ok(())
251    /// # }
252    /// ```
253    ///
254    /// # Errors
255    ///
256    /// - [`Error::Auth`] if the token provider fails.
257    /// - [`Error::InvalidMetadata`] if the token contains invalid header characters.
258    /// - [`Error::Grpc`] if the server returns a non-OK status.
259    #[instrument(skip(self, request), fields(queue = %request.queue, job_set_id = %request.job_set_id))]
260    pub async fn submit(&self, request: JobSubmitRequest) -> Result<JobSubmitResponse, Error> {
261        let token = self.token_provider.token().await?;
262        let mut req = tonic::Request::new(request);
263        if !token.is_empty() {
264            req.metadata_mut().insert("authorization", token.parse()?);
265        }
266        self.apply_timeout(&mut req);
267        // `.clone()` on a tonic client is O(1) — it clones an Arc over the
268        // shared channel, not the underlying connection.
269        let resp = self.submit_client.clone().submit_jobs(req).await?;
270        Ok(resp.into_inner())
271    }
272
273    /// Cancel one or more jobs.
274    ///
275    /// # Arguments
276    ///
277    /// * `request.queue` — Armada queue name.
278    /// * `request.job_set_id` — Job set the jobs belong to.
279    /// * `request.job_id` — Single job ID to cancel (legacy, optional).
280    /// * `request.job_ids` — Multiple job IDs to cancel in one call.
281    /// * `request.reason` — Human-readable cancellation reason (optional).
282    ///
283    /// # Example
284    ///
285    /// ```no_run
286    /// use armada_client::{ArmadaClient, JobCancelRequest, StaticTokenProvider};
287    ///
288    /// # async fn example() -> Result<(), armada_client::Error> {
289    /// # let client = ArmadaClient::connect("http://localhost:50051", StaticTokenProvider::new("")).await?;
290    /// let result = client
291    ///     .cancel_jobs(JobCancelRequest {
292    ///         queue: "my-queue".into(),
293    ///         job_set_id: "my-job-set".into(),
294    ///         job_ids: vec!["01abc".into(), "01def".into()],
295    ///         reason: "no longer needed".into(),
296    ///         ..Default::default()
297    ///     })
298    ///     .await?;
299    ///
300    /// println!("cancelled: {:?}", result.cancelled_ids);
301    /// # Ok(())
302    /// # }
303    /// ```
304    ///
305    /// # Errors
306    ///
307    /// - [`Error::Auth`] if the token provider fails.
308    /// - [`Error::InvalidMetadata`] if the token contains invalid header characters.
309    /// - [`Error::Grpc`] if the server returns a non-OK status.
310    #[instrument(skip(self, request), fields(queue = %request.queue, job_set_id = %request.job_set_id))]
311    pub async fn cancel_jobs(
312        &self,
313        request: JobCancelRequest,
314    ) -> Result<CancellationResult, Error> {
315        let token = self.token_provider.token().await?;
316        let mut req = tonic::Request::new(request);
317        if !token.is_empty() {
318            req.metadata_mut().insert("authorization", token.parse()?);
319        }
320        self.apply_timeout(&mut req);
321        let resp = self.submit_client.clone().cancel_jobs(req).await?;
322        Ok(resp.into_inner())
323    }
324
325    /// Cancel all (or a filtered subset of) jobs in a job set.
326    ///
327    /// # Arguments
328    ///
329    /// * `request.queue` — Armada queue name.
330    /// * `request.job_set_id` — Job set to cancel.
331    /// * `request.filter` — Optional [`crate::JobSetFilter`] limiting cancellation to
332    ///   jobs in specific states (e.g. only `Queued` and `Running`). Pass
333    ///   `None` to cancel all non-terminal jobs.
334    /// * `request.reason` — Human-readable cancellation reason (optional).
335    ///
336    /// # Example
337    ///
338    /// ```no_run
339    /// use armada_client::{ArmadaClient, JobSetCancelRequest, JobSetFilter, JobState, StaticTokenProvider};
340    ///
341    /// # async fn example() -> Result<(), armada_client::Error> {
342    /// # let client = ArmadaClient::connect("http://localhost:50051", StaticTokenProvider::new("")).await?;
343    /// client
344    ///     .cancel_job_set(JobSetCancelRequest {
345    ///         queue: "my-queue".into(),
346    ///         job_set_id: "my-job-set".into(),
347    ///         filter: Some(JobSetFilter {
348    ///             states: vec![JobState::Queued as i32, JobState::Running as i32],
349    ///         }),
350    ///         reason: "aborting experiment".into(),
351    ///     })
352    ///     .await?;
353    /// # Ok(())
354    /// # }
355    /// ```
356    ///
357    /// # Errors
358    ///
359    /// - [`Error::Auth`] if the token provider fails.
360    /// - [`Error::InvalidMetadata`] if the token contains invalid header characters.
361    /// - [`Error::Grpc`] if the server returns a non-OK status.
362    #[instrument(skip(self, request), fields(queue = %request.queue, job_set_id = %request.job_set_id))]
363    pub async fn cancel_job_set(&self, request: JobSetCancelRequest) -> Result<(), Error> {
364        let token = self.token_provider.token().await?;
365        let mut req = tonic::Request::new(request);
366        if !token.is_empty() {
367            req.metadata_mut().insert("authorization", token.parse()?);
368        }
369        self.apply_timeout(&mut req);
370        self.submit_client.clone().cancel_job_set(req).await?;
371        Ok(())
372    }
373
374    /// Watch a job set, returning a stream of events.
375    ///
376    /// Opens a server-streaming gRPC call and returns a [`BoxStream`] that
377    /// yields [`EventStreamMessage`] values as the server pushes them. The
378    /// stream ends when the server closes the connection.
379    ///
380    /// **Reconnection is the caller's responsibility.** Store the last
381    /// `message_id` you received and pass it back as `from_message_id` when
382    /// reconnecting to avoid replaying events you have already processed.
383    ///
384    /// # Arguments
385    ///
386    /// * `queue` — Armada queue name.
387    /// * `job_set_id` — Job set to watch.
388    /// * `from_message_id` — Optional resume cursor. Pass `Some(id)` to
389    ///   receive only events that occurred after `id`; pass `None` to receive
390    ///   all events from the beginning.
391    ///
392    /// # Example
393    ///
394    /// ```no_run
395    /// use futures::StreamExt;
396    /// use armada_client::{ArmadaClient, StaticTokenProvider};
397    ///
398    /// # async fn example() -> Result<(), armada_client::Error> {
399    /// # let client = ArmadaClient::connect("http://localhost:50051", StaticTokenProvider::new("")).await?;
400    /// let mut stream = client
401    ///     .watch("my-queue", "my-job-set", None)
402    ///     .await?;
403    ///
404    /// let mut last_id = String::new();
405    /// while let Some(result) = stream.next().await {
406    ///     match result {
407    ///         Ok(msg) => {
408    ///             last_id = msg.id.clone();
409    ///             println!("event id={} message={:?}", msg.id, msg.message);
410    ///         }
411    ///         Err(e) => {
412    ///             eprintln!("stream error: {e}");
413    ///             break; // reconnect using `from_message_id: Some(last_id)`
414    ///         }
415    ///     }
416    /// }
417    /// # Ok(())
418    /// # }
419    /// ```
420    ///
421    /// # Errors
422    ///
423    /// - [`Error::Auth`] if the token provider fails.
424    /// - [`Error::InvalidMetadata`] if the token contains invalid header characters.
425    /// - [`Error::Grpc`] if the server returns a non-OK status on the initial call.
426    ///   The stream will not error simply because the job set does not exist yet —
427    ///   it will wait for events, which avoids races when `watch` is called
428    ///   immediately after `submit`.
429    /// - Individual stream items may also be [`Err(Error::Grpc)`] if the server
430    ///   sends a trailing error status.
431    #[instrument(skip_all, fields(queue, job_set_id))]
432    pub async fn watch(
433        &self,
434        queue: impl Into<String>,
435        job_set_id: impl Into<String>,
436        from_message_id: Option<String>,
437    ) -> Result<BoxStream<'static, Result<EventStreamMessage, Error>>, Error> {
438        let queue: String = queue.into();
439        let job_set_id: String = job_set_id.into();
440        tracing::Span::current()
441            .record("queue", queue.as_str())
442            .record("job_set_id", job_set_id.as_str());
443
444        let token = self.token_provider.token().await?;
445        let job_set_request = JobSetRequest {
446            id: job_set_id,
447            queue,
448            from_message_id: from_message_id.unwrap_or_default(),
449            // Keep the stream open for new events.
450            watch: true,
451            // Do not fail immediately if the job set does not exist yet — this
452            // avoids a race between submit() and watch() where the server has
453            // not yet created the job set by the time the watch RPC arrives.
454            error_if_missing: false,
455        };
456        let mut req = tonic::Request::new(job_set_request);
457        if !token.is_empty() {
458            req.metadata_mut().insert("authorization", token.parse()?);
459        }
460        self.apply_timeout(&mut req);
461        // `.clone()` on a tonic client is O(1) — it clones an Arc over the
462        // shared channel, not the underlying connection.
463        let stream = self
464            .event_client
465            .clone()
466            .get_job_set_events(req)
467            .await?
468            .into_inner();
469        Ok(Box::pin(stream.map(|r| r.map_err(Error::from))))
470    }
471}