// armada_client/client.rs

use std::sync::Arc;
use std::time::Duration;

use futures::StreamExt;
use futures::stream::BoxStream;
use tonic::transport::{Channel, ClientTlsConfig};
use tracing::instrument;

use crate::api::{
    CancellationResult, EventStreamMessage, JobCancelRequest, JobSetCancelRequest, JobSetRequest,
    JobSubmitRequest, JobSubmitResponse, event_client::EventClient, submit_client::SubmitClient,
};
use crate::auth::TokenProvider;
use crate::error::Error;
15
16/// Armada gRPC client providing job submission and event-stream watching.
17///
18/// # Construction
19///
20/// Use [`ArmadaClient::connect`] for plaintext (dev / in-cluster) connections
21/// and [`ArmadaClient::connect_tls`] for production clusters behind TLS.
22/// Both accept any [`TokenProvider`] — pass [`crate::StaticTokenProvider`] for
23/// a static bearer token or supply your own implementation for dynamic auth.
24///
25/// ```no_run
26/// # use armada_client::{ArmadaClient, StaticTokenProvider};
27/// # async fn example() -> Result<(), armada_client::Error> {
28/// // Plaintext
29/// let client = ArmadaClient::connect("http://localhost:50051", StaticTokenProvider::new("tok"))
30///     .await?;
31///
32/// // TLS (uses system root certificates)
33/// let client = ArmadaClient::connect_tls("https://armada.example.com:443", StaticTokenProvider::new("tok"))
34///     .await?;
35/// # Ok(())
36/// # }
37/// ```
38///
39/// # Cloning
40///
41/// `ArmadaClient` is `Clone`. All clones share the same underlying channel and
42/// connection pool — cloning is `O(1)` and is the correct way to distribute the
43/// client across tasks:
44///
45/// ```no_run
46/// # use armada_client::{ArmadaClient, StaticTokenProvider};
47/// # async fn example() -> Result<(), armada_client::Error> {
48/// let client = ArmadaClient::connect("http://localhost:50051", StaticTokenProvider::new("tok"))
49///     .await?;
50///
51/// let c1 = client.clone();
52/// let c2 = client.clone();
53/// tokio::spawn(async move { /* use c1 */ });
54/// tokio::spawn(async move { /* use c2 */ });
55/// # Ok(())
56/// # }
57/// ```
58///
59/// # Timeouts
60///
61/// Apply a per-call deadline with [`ArmadaClient::with_timeout`]. When set,
62/// every RPC is governed by the deadline for its **entire duration**. For
63/// unary calls (`submit`) the deadline covers the round-trip. For streaming
64/// calls (`watch`) it covers the full lifetime of the stream — if the timeout
65/// elapses while events are still being received the stream is cancelled with
66/// a `DEADLINE_EXCEEDED` status.
67#[derive(Clone)]
68pub struct ArmadaClient {
69    submit_client: SubmitClient<Channel>,
70    event_client: EventClient<Channel>,
71    token_provider: Arc<dyn TokenProvider + Send + Sync>,
72    timeout: Option<Duration>,
73}
74
75impl ArmadaClient {
76    /// Connect to an Armada server at `endpoint` using plaintext (no TLS).
77    ///
78    /// `endpoint` must be a valid URI, e.g. `"http://localhost:50051"`.
79    ///
80    /// # Errors
81    ///
82    /// Returns [`Error::InvalidUri`] if the URI is malformed, or
83    /// [`Error::Transport`] if the connection cannot be established.
84    pub async fn connect(
85        endpoint: impl Into<String>,
86        token_provider: impl TokenProvider + 'static,
87    ) -> Result<Self, Error> {
88        let channel = Channel::from_shared(endpoint.into())
89            .map_err(|e| Error::InvalidUri(e.to_string()))?
90            .connect()
91            .await?;
92        Ok(Self::from_parts(channel, token_provider))
93    }
94
95    /// Connect to an Armada server at `endpoint` using TLS.
96    ///
97    /// Uses the system's native root certificates to verify the server
98    /// certificate. `endpoint` should use the `https://` scheme,
99    /// e.g. `"https://armada.example.com:443"`.
100    ///
101    /// # Errors
102    ///
103    /// Returns [`Error::InvalidUri`] if the URI is malformed, or
104    /// [`Error::Transport`] if TLS configuration or the connection fails.
105    pub async fn connect_tls(
106        endpoint: impl Into<String>,
107        token_provider: impl TokenProvider + 'static,
108    ) -> Result<Self, Error> {
109        let channel = Channel::from_shared(endpoint.into())
110            .map_err(|e| Error::InvalidUri(e.to_string()))?
111            .tls_config(ClientTlsConfig::new())?
112            .connect()
113            .await?;
114        Ok(Self::from_parts(channel, token_provider))
115    }
116
117    fn from_parts(channel: Channel, token_provider: impl TokenProvider + 'static) -> Self {
118        Self {
119            submit_client: SubmitClient::new(channel.clone()),
120            event_client: EventClient::new(channel),
121            token_provider: Arc::new(token_provider),
122            timeout: None,
123        }
124    }
125
126    /// Set a default timeout applied to every RPC call.
127    ///
128    /// When the timeout elapses the call fails with [`Error::Grpc`] wrapping a
129    /// `DEADLINE_EXCEEDED` status. For streaming calls like
130    /// [`ArmadaClient::watch`], the deadline covers the **entire stream
131    /// duration** — if it elapses while events are still arriving the stream
132    /// is cancelled immediately.
133    ///
134    /// Returns `self` so the call can be chained directly after construction:
135    ///
136    /// ```no_run
137    /// # use std::time::Duration;
138    /// # use armada_client::{ArmadaClient, StaticTokenProvider};
139    /// # async fn example() -> Result<(), armada_client::Error> {
140    /// let client = ArmadaClient::connect("http://localhost:50051", StaticTokenProvider::new("tok"))
141    ///     .await?
142    ///     .with_timeout(Duration::from_secs(30));
143    /// # Ok(())
144    /// # }
145    /// ```
146    pub fn with_timeout(mut self, timeout: Duration) -> Self {
147        self.timeout = Some(timeout);
148        self
149    }
150
151    fn apply_timeout<T>(&self, req: &mut tonic::Request<T>) {
152        if let Some(t) = self.timeout {
153            req.set_timeout(t);
154        }
155    }
156
157    /// Submit a batch of jobs to Armada.
158    ///
159    /// Attaches an `authorization` header on every call using the configured
160    /// [`TokenProvider`] (e.g. `Bearer <token>` or `Basic <credentials>`). Multiple job items can be included in
161    /// a single request — they are all submitted atomically to the same queue
162    /// and job set.
163    ///
164    /// # Example
165    ///
166    /// ```no_run
167    /// use armada_client::{
168    ///     ArmadaClient, JobRequestItemBuilder, JobSubmitRequest, StaticTokenProvider,
169    /// };
170    /// use armada_client::k8s::io::api::core::v1::PodSpec;
171    ///
172    /// # async fn example() -> Result<(), armada_client::Error> {
173    /// # let client = ArmadaClient::connect("http://localhost:50051", StaticTokenProvider::new("")).await?;
174    /// let item = JobRequestItemBuilder::new()
175    ///     .namespace("default")
176    ///     .pod_spec(PodSpec { containers: vec![], ..Default::default() })
177    ///     .build();
178    ///
179    /// let response = client
180    ///     .submit(JobSubmitRequest {
181    ///         queue: "my-queue".into(),
182    ///         job_set_id: "my-job-set".into(),
183    ///         job_request_items: vec![item],
184    ///     })
185    ///     .await?;
186    ///
187    /// for r in &response.job_response_items {
188    ///     if r.error.is_empty() {
189    ///         println!("submitted: {}", r.job_id);
190    ///     } else {
191    ///         eprintln!("rejected: {}", r.error);
192    ///     }
193    /// }
194    /// # Ok(())
195    /// # }
196    /// ```
197    ///
198    /// # Errors
199    ///
200    /// - [`Error::Auth`] if the token provider fails.
201    /// - [`Error::InvalidMetadata`] if the token contains invalid header characters.
202    /// - [`Error::Grpc`] if the server returns a non-OK status.
203    #[instrument(skip(self, request), fields(queue = %request.queue, job_set_id = %request.job_set_id))]
204    pub async fn submit(&self, request: JobSubmitRequest) -> Result<JobSubmitResponse, Error> {
205        let token = self.token_provider.token().await?;
206        let mut req = tonic::Request::new(request);
207        if !token.is_empty() {
208            req.metadata_mut().insert("authorization", token.parse()?);
209        }
210        self.apply_timeout(&mut req);
211        // `.clone()` on a tonic client is O(1) — it clones an Arc over the
212        // shared channel, not the underlying connection.
213        let resp = self.submit_client.clone().submit_jobs(req).await?;
214        Ok(resp.into_inner())
215    }
216
217    /// Cancel one or more jobs.
218    ///
219    /// # Arguments
220    ///
221    /// * `request.queue` — Armada queue name.
222    /// * `request.job_set_id` — Job set the jobs belong to.
223    /// * `request.job_id` — Single job ID to cancel (legacy, optional).
224    /// * `request.job_ids` — Multiple job IDs to cancel in one call.
225    /// * `request.reason` — Human-readable cancellation reason (optional).
226    ///
227    /// # Example
228    ///
229    /// ```no_run
230    /// use armada_client::{ArmadaClient, JobCancelRequest, StaticTokenProvider};
231    ///
232    /// # async fn example() -> Result<(), armada_client::Error> {
233    /// # let client = ArmadaClient::connect("http://localhost:50051", StaticTokenProvider::new("")).await?;
234    /// let result = client
235    ///     .cancel_jobs(JobCancelRequest {
236    ///         queue: "my-queue".into(),
237    ///         job_set_id: "my-job-set".into(),
238    ///         job_ids: vec!["01abc".into(), "01def".into()],
239    ///         reason: "no longer needed".into(),
240    ///         ..Default::default()
241    ///     })
242    ///     .await?;
243    ///
244    /// println!("cancelled: {:?}", result.cancelled_ids);
245    /// # Ok(())
246    /// # }
247    /// ```
248    ///
249    /// # Errors
250    ///
251    /// - [`Error::Auth`] if the token provider fails.
252    /// - [`Error::InvalidMetadata`] if the token contains invalid header characters.
253    /// - [`Error::Grpc`] if the server returns a non-OK status.
254    #[instrument(skip(self, request), fields(queue = %request.queue, job_set_id = %request.job_set_id))]
255    pub async fn cancel_jobs(
256        &self,
257        request: JobCancelRequest,
258    ) -> Result<CancellationResult, Error> {
259        let token = self.token_provider.token().await?;
260        let mut req = tonic::Request::new(request);
261        if !token.is_empty() {
262            req.metadata_mut().insert("authorization", token.parse()?);
263        }
264        self.apply_timeout(&mut req);
265        let resp = self.submit_client.clone().cancel_jobs(req).await?;
266        Ok(resp.into_inner())
267    }
268
269    /// Cancel all (or a filtered subset of) jobs in a job set.
270    ///
271    /// # Arguments
272    ///
273    /// * `request.queue` — Armada queue name.
274    /// * `request.job_set_id` — Job set to cancel.
275    /// * `request.filter` — Optional [`crate::JobSetFilter`] limiting cancellation to
276    ///   jobs in specific states (e.g. only `Queued` and `Running`). Pass
277    ///   `None` to cancel all non-terminal jobs.
278    /// * `request.reason` — Human-readable cancellation reason (optional).
279    ///
280    /// # Example
281    ///
282    /// ```no_run
283    /// use armada_client::{ArmadaClient, JobSetCancelRequest, JobSetFilter, JobState, StaticTokenProvider};
284    ///
285    /// # async fn example() -> Result<(), armada_client::Error> {
286    /// # let client = ArmadaClient::connect("http://localhost:50051", StaticTokenProvider::new("")).await?;
287    /// client
288    ///     .cancel_job_set(JobSetCancelRequest {
289    ///         queue: "my-queue".into(),
290    ///         job_set_id: "my-job-set".into(),
291    ///         filter: Some(JobSetFilter {
292    ///             states: vec![JobState::Queued as i32, JobState::Running as i32],
293    ///         }),
294    ///         reason: "aborting experiment".into(),
295    ///     })
296    ///     .await?;
297    /// # Ok(())
298    /// # }
299    /// ```
300    ///
301    /// # Errors
302    ///
303    /// - [`Error::Auth`] if the token provider fails.
304    /// - [`Error::InvalidMetadata`] if the token contains invalid header characters.
305    /// - [`Error::Grpc`] if the server returns a non-OK status.
306    #[instrument(skip(self, request), fields(queue = %request.queue, job_set_id = %request.job_set_id))]
307    pub async fn cancel_job_set(&self, request: JobSetCancelRequest) -> Result<(), Error> {
308        let token = self.token_provider.token().await?;
309        let mut req = tonic::Request::new(request);
310        if !token.is_empty() {
311            req.metadata_mut().insert("authorization", token.parse()?);
312        }
313        self.apply_timeout(&mut req);
314        self.submit_client.clone().cancel_job_set(req).await?;
315        Ok(())
316    }
317
318    /// Watch a job set, returning a stream of events.
319    ///
320    /// Opens a server-streaming gRPC call and returns a [`BoxStream`] that
321    /// yields [`EventStreamMessage`] values as the server pushes them. The
322    /// stream ends when the server closes the connection.
323    ///
324    /// **Reconnection is the caller's responsibility.** Store the last
325    /// `message_id` you received and pass it back as `from_message_id` when
326    /// reconnecting to avoid replaying events you have already processed.
327    ///
328    /// # Arguments
329    ///
330    /// * `queue` — Armada queue name.
331    /// * `job_set_id` — Job set to watch.
332    /// * `from_message_id` — Optional resume cursor. Pass `Some(id)` to
333    ///   receive only events that occurred after `id`; pass `None` to receive
334    ///   all events from the beginning.
335    ///
336    /// # Example
337    ///
338    /// ```no_run
339    /// use futures::StreamExt;
340    /// use armada_client::{ArmadaClient, StaticTokenProvider};
341    ///
342    /// # async fn example() -> Result<(), armada_client::Error> {
343    /// # let client = ArmadaClient::connect("http://localhost:50051", StaticTokenProvider::new("")).await?;
344    /// let mut stream = client
345    ///     .watch("my-queue", "my-job-set", None)
346    ///     .await?;
347    ///
348    /// let mut last_id = String::new();
349    /// while let Some(result) = stream.next().await {
350    ///     match result {
351    ///         Ok(msg) => {
352    ///             last_id = msg.id.clone();
353    ///             println!("event id={} message={:?}", msg.id, msg.message);
354    ///         }
355    ///         Err(e) => {
356    ///             eprintln!("stream error: {e}");
357    ///             break; // reconnect using `from_message_id: Some(last_id)`
358    ///         }
359    ///     }
360    /// }
361    /// # Ok(())
362    /// # }
363    /// ```
364    ///
365    /// # Errors
366    ///
367    /// - [`Error::Auth`] if the token provider fails.
368    /// - [`Error::InvalidMetadata`] if the token contains invalid header characters.
369    /// - [`Error::Grpc`] if the server returns a non-OK status on the initial call.
370    ///   The stream will not error simply because the job set does not exist yet —
371    ///   it will wait for events, which avoids races when `watch` is called
372    ///   immediately after `submit`.
373    /// - Individual stream items may also be [`Err(Error::Grpc)`] if the server
374    ///   sends a trailing error status.
375    #[instrument(skip_all, fields(queue, job_set_id))]
376    pub async fn watch(
377        &self,
378        queue: impl Into<String>,
379        job_set_id: impl Into<String>,
380        from_message_id: Option<String>,
381    ) -> Result<BoxStream<'static, Result<EventStreamMessage, Error>>, Error> {
382        let queue: String = queue.into();
383        let job_set_id: String = job_set_id.into();
384        tracing::Span::current()
385            .record("queue", queue.as_str())
386            .record("job_set_id", job_set_id.as_str());
387
388        let token = self.token_provider.token().await?;
389        let job_set_request = JobSetRequest {
390            id: job_set_id,
391            queue,
392            from_message_id: from_message_id.unwrap_or_default(),
393            // Keep the stream open for new events.
394            watch: true,
395            // Do not fail immediately if the job set does not exist yet — this
396            // avoids a race between submit() and watch() where the server has
397            // not yet created the job set by the time the watch RPC arrives.
398            error_if_missing: false,
399        };
400        let mut req = tonic::Request::new(job_set_request);
401        if !token.is_empty() {
402            req.metadata_mut().insert("authorization", token.parse()?);
403        }
404        self.apply_timeout(&mut req);
405        // `.clone()` on a tonic client is O(1) — it clones an Arc over the
406        // shared channel, not the underlying connection.
407        let stream = self
408            .event_client
409            .clone()
410            .get_job_set_events(req)
411            .await?
412            .into_inner();
413        Ok(Box::pin(stream.map(|r| r.map_err(Error::from))))
414    }
415}