armada_client/client.rs
1use std::sync::Arc;
2use std::time::Duration;
3
4use futures::StreamExt;
5use futures::stream::BoxStream;
6use tonic::transport::{Channel, ClientTlsConfig};
7use tracing::instrument;
8
9use crate::api::{
10 CancellationResult, EventStreamMessage, JobCancelRequest, JobSetCancelRequest, JobSetRequest,
11 JobSubmitRequest, JobSubmitResponse, event_client::EventClient, submit_client::SubmitClient,
12};
13use crate::auth::TokenProvider;
14use crate::error::Error;
15
16/// Armada gRPC client providing job submission and event-stream watching.
17///
18/// # Construction
19///
20/// Use [`ArmadaClient::connect`] for plaintext (dev / in-cluster) connections,
21/// [`ArmadaClient::connect_tls`] for production clusters using system root certificates,
22/// or [`ArmadaClient::connect_tls_with_config`] when you need a custom CA, domain
23/// override, or mutual TLS. All constructors accept any [`TokenProvider`] — pass
24/// [`crate::StaticTokenProvider`] for a static bearer token or supply your own
25/// implementation for dynamic auth.
26///
27/// ```no_run
28/// # use armada_client::{ArmadaClient, StaticTokenProvider};
29/// # use armada_client::tonic::transport::{Certificate, ClientTlsConfig};
30/// # async fn example() -> Result<(), armada_client::Error> {
31/// // Plaintext
32/// let client = ArmadaClient::connect("http://localhost:50051", StaticTokenProvider::new("tok"))
33/// .await?;
34///
35/// // TLS (uses system root certificates)
36/// let client = ArmadaClient::connect_tls("https://armada.example.com:443", StaticTokenProvider::new("tok"))
37/// .await?;
38///
39/// // TLS with a custom CA
40/// let pem = std::fs::read("ca.pem")?;
41/// let client = ArmadaClient::connect_tls_with_config(
42/// "https://armada.example.com:443",
43/// ClientTlsConfig::new().ca_certificate(Certificate::from_pem(pem)),
44/// StaticTokenProvider::new("tok"),
45/// ).await?;
46/// # Ok(())
47/// # }
48/// ```
49///
50/// # Cloning
51///
52/// `ArmadaClient` is `Clone`. All clones share the same underlying channel and
53/// connection pool — cloning is `O(1)` and is the correct way to distribute the
54/// client across tasks:
55///
56/// ```no_run
57/// # use armada_client::{ArmadaClient, StaticTokenProvider};
58/// # async fn example() -> Result<(), armada_client::Error> {
59/// let client = ArmadaClient::connect("http://localhost:50051", StaticTokenProvider::new("tok"))
60/// .await?;
61///
62/// let c1 = client.clone();
63/// let c2 = client.clone();
64/// tokio::spawn(async move { /* use c1 */ });
65/// tokio::spawn(async move { /* use c2 */ });
66/// # Ok(())
67/// # }
68/// ```
69///
70/// # Timeouts
71///
72/// Apply a per-call deadline with [`ArmadaClient::with_timeout`]. When set,
73/// every RPC is governed by the deadline for its **entire duration**. For
74/// unary calls (`submit`) the deadline covers the round-trip. For streaming
75/// calls (`watch`) it covers the full lifetime of the stream — if the timeout
76/// elapses while events are still being received the stream is cancelled with
77/// a `DEADLINE_EXCEEDED` status.
78#[derive(Clone)]
79pub struct ArmadaClient {
80 submit_client: SubmitClient<Channel>,
81 event_client: EventClient<Channel>,
82 token_provider: Arc<dyn TokenProvider + Send + Sync>,
83 timeout: Option<Duration>,
84}
85
86impl ArmadaClient {
87 /// Connect to an Armada server at `endpoint` using plaintext (no TLS).
88 ///
89 /// `endpoint` must be a valid URI, e.g. `"http://localhost:50051"`.
90 ///
91 /// # Errors
92 ///
93 /// Returns [`Error::InvalidUri`] if the URI is malformed, or
94 /// [`Error::Transport`] if the connection cannot be established.
95 pub async fn connect(
96 endpoint: impl Into<String>,
97 token_provider: impl TokenProvider + 'static,
98 ) -> Result<Self, Error> {
99 let channel = Channel::from_shared(endpoint.into())
100 .map_err(|e| Error::InvalidUri(e.to_string()))?
101 .connect()
102 .await?;
103 Ok(Self::from_parts(channel, token_provider))
104 }
105
106 /// Connect to an Armada server at `endpoint` using TLS.
107 ///
108 /// Uses the system's native root certificates to verify the server
109 /// certificate. `endpoint` should use the `https://` scheme,
110 /// e.g. `"https://armada.example.com:443"`.
111 ///
112 /// For clusters with a private or self-signed CA, use
113 /// [`ArmadaClient::connect_tls_with_config`] instead.
114 ///
115 /// # Errors
116 ///
117 /// Returns [`Error::InvalidUri`] if the URI is malformed, or
118 /// [`Error::Transport`] if TLS configuration or the connection fails.
119 pub async fn connect_tls(
120 endpoint: impl Into<String>,
121 token_provider: impl TokenProvider + 'static,
122 ) -> Result<Self, Error> {
123 let channel = Channel::from_shared(endpoint.into())
124 .map_err(|e| Error::InvalidUri(e.to_string()))?
125 .tls_config(ClientTlsConfig::new())?
126 .connect()
127 .await?;
128 Ok(Self::from_parts(channel, token_provider))
129 }
130
131 /// Connect to an Armada server at `endpoint` using a caller-supplied TLS config.
132 ///
133 /// Use this when you need to supply a custom CA certificate (e.g. a private or
134 /// self-signed CA), override the server domain name, or configure mutual TLS.
135 /// Build the config with [`tonic::transport::ClientTlsConfig`], accessible via
136 /// `armada_client::tonic::transport::ClientTlsConfig` — no direct tonic dependency needed.
137 ///
138 /// # Example — custom CA
139 ///
140 /// ```no_run
141 /// use armada_client::{ArmadaClient, StaticTokenProvider};
142 /// use armada_client::tonic::transport::{Certificate, ClientTlsConfig};
143 ///
144 /// # async fn example() -> Result<(), armada_client::Error> {
145 /// let pem = std::fs::read("ca.pem")?;
146 /// let tls = ClientTlsConfig::new().ca_certificate(Certificate::from_pem(pem));
147 /// let client = ArmadaClient::connect_tls_with_config(
148 /// "https://armada.example.com:443",
149 /// tls,
150 /// StaticTokenProvider::new("tok"),
151 /// ).await?;
152 /// # Ok(())
153 /// # }
154 /// ```
155 ///
156 /// # Errors
157 ///
158 /// Returns [`Error::InvalidUri`] if the URI is malformed, or
159 /// [`Error::Transport`] if TLS configuration or the connection fails.
160 pub async fn connect_tls_with_config(
161 endpoint: impl Into<String>,
162 tls_config: ClientTlsConfig,
163 token_provider: impl TokenProvider + 'static,
164 ) -> Result<Self, Error> {
165 let channel = Channel::from_shared(endpoint.into())
166 .map_err(|e| Error::InvalidUri(e.to_string()))?
167 .tls_config(tls_config)?
168 .connect()
169 .await?;
170 Ok(Self::from_parts(channel, token_provider))
171 }
172
173 fn from_parts(channel: Channel, token_provider: impl TokenProvider + 'static) -> Self {
174 Self {
175 submit_client: SubmitClient::new(channel.clone()),
176 event_client: EventClient::new(channel),
177 token_provider: Arc::new(token_provider),
178 timeout: None,
179 }
180 }
181
182 /// Set a default timeout applied to every RPC call.
183 ///
184 /// When the timeout elapses the call fails with [`Error::Grpc`] wrapping a
185 /// `DEADLINE_EXCEEDED` status. For streaming calls like
186 /// [`ArmadaClient::watch`], the deadline covers the **entire stream
187 /// duration** — if it elapses while events are still arriving the stream
188 /// is cancelled immediately.
189 ///
190 /// Returns `self` so the call can be chained directly after construction:
191 ///
192 /// ```no_run
193 /// # use std::time::Duration;
194 /// # use armada_client::{ArmadaClient, StaticTokenProvider};
195 /// # async fn example() -> Result<(), armada_client::Error> {
196 /// let client = ArmadaClient::connect("http://localhost:50051", StaticTokenProvider::new("tok"))
197 /// .await?
198 /// .with_timeout(Duration::from_secs(30));
199 /// # Ok(())
200 /// # }
201 /// ```
202 pub fn with_timeout(mut self, timeout: Duration) -> Self {
203 self.timeout = Some(timeout);
204 self
205 }
206
207 fn apply_timeout<T>(&self, req: &mut tonic::Request<T>) {
208 if let Some(t) = self.timeout {
209 req.set_timeout(t);
210 }
211 }
212
213 /// Submit a batch of jobs to Armada.
214 ///
215 /// Attaches an `authorization` header on every call using the configured
216 /// [`TokenProvider`] (e.g. `Bearer <token>` or `Basic <credentials>`). Multiple job items can be included in
217 /// a single request — they are all submitted atomically to the same queue
218 /// and job set.
219 ///
220 /// # Example
221 ///
222 /// ```no_run
223 /// use armada_client::{
224 /// ArmadaClient, JobRequestItemBuilder, JobSubmitRequest, StaticTokenProvider,
225 /// };
226 /// use armada_client::k8s::io::api::core::v1::PodSpec;
227 ///
228 /// # async fn example() -> Result<(), armada_client::Error> {
229 /// # let client = ArmadaClient::connect("http://localhost:50051", StaticTokenProvider::new("")).await?;
230 /// let item = JobRequestItemBuilder::new()
231 /// .namespace("default")
232 /// .pod_spec(PodSpec { containers: vec![], ..Default::default() })
233 /// .build();
234 ///
235 /// let response = client
236 /// .submit(JobSubmitRequest {
237 /// queue: "my-queue".into(),
238 /// job_set_id: "my-job-set".into(),
239 /// job_request_items: vec![item],
240 /// })
241 /// .await?;
242 ///
243 /// for r in &response.job_response_items {
244 /// if r.error.is_empty() {
245 /// println!("submitted: {}", r.job_id);
246 /// } else {
247 /// eprintln!("rejected: {}", r.error);
248 /// }
249 /// }
250 /// # Ok(())
251 /// # }
252 /// ```
253 ///
254 /// # Errors
255 ///
256 /// - [`Error::Auth`] if the token provider fails.
257 /// - [`Error::InvalidMetadata`] if the token contains invalid header characters.
258 /// - [`Error::Grpc`] if the server returns a non-OK status.
259 #[instrument(skip(self, request), fields(queue = %request.queue, job_set_id = %request.job_set_id))]
260 pub async fn submit(&self, request: JobSubmitRequest) -> Result<JobSubmitResponse, Error> {
261 let token = self.token_provider.token().await?;
262 let mut req = tonic::Request::new(request);
263 if !token.is_empty() {
264 req.metadata_mut().insert("authorization", token.parse()?);
265 }
266 self.apply_timeout(&mut req);
267 // `.clone()` on a tonic client is O(1) — it clones an Arc over the
268 // shared channel, not the underlying connection.
269 let resp = self.submit_client.clone().submit_jobs(req).await?;
270 Ok(resp.into_inner())
271 }
272
273 /// Cancel one or more jobs.
274 ///
275 /// # Arguments
276 ///
277 /// * `request.queue` — Armada queue name.
278 /// * `request.job_set_id` — Job set the jobs belong to.
279 /// * `request.job_id` — Single job ID to cancel (legacy, optional).
280 /// * `request.job_ids` — Multiple job IDs to cancel in one call.
281 /// * `request.reason` — Human-readable cancellation reason (optional).
282 ///
283 /// # Example
284 ///
285 /// ```no_run
286 /// use armada_client::{ArmadaClient, JobCancelRequest, StaticTokenProvider};
287 ///
288 /// # async fn example() -> Result<(), armada_client::Error> {
289 /// # let client = ArmadaClient::connect("http://localhost:50051", StaticTokenProvider::new("")).await?;
290 /// let result = client
291 /// .cancel_jobs(JobCancelRequest {
292 /// queue: "my-queue".into(),
293 /// job_set_id: "my-job-set".into(),
294 /// job_ids: vec!["01abc".into(), "01def".into()],
295 /// reason: "no longer needed".into(),
296 /// ..Default::default()
297 /// })
298 /// .await?;
299 ///
300 /// println!("cancelled: {:?}", result.cancelled_ids);
301 /// # Ok(())
302 /// # }
303 /// ```
304 ///
305 /// # Errors
306 ///
307 /// - [`Error::Auth`] if the token provider fails.
308 /// - [`Error::InvalidMetadata`] if the token contains invalid header characters.
309 /// - [`Error::Grpc`] if the server returns a non-OK status.
310 #[instrument(skip(self, request), fields(queue = %request.queue, job_set_id = %request.job_set_id))]
311 pub async fn cancel_jobs(
312 &self,
313 request: JobCancelRequest,
314 ) -> Result<CancellationResult, Error> {
315 let token = self.token_provider.token().await?;
316 let mut req = tonic::Request::new(request);
317 if !token.is_empty() {
318 req.metadata_mut().insert("authorization", token.parse()?);
319 }
320 self.apply_timeout(&mut req);
321 let resp = self.submit_client.clone().cancel_jobs(req).await?;
322 Ok(resp.into_inner())
323 }
324
325 /// Cancel all (or a filtered subset of) jobs in a job set.
326 ///
327 /// # Arguments
328 ///
329 /// * `request.queue` — Armada queue name.
330 /// * `request.job_set_id` — Job set to cancel.
331 /// * `request.filter` — Optional [`crate::JobSetFilter`] limiting cancellation to
332 /// jobs in specific states (e.g. only `Queued` and `Running`). Pass
333 /// `None` to cancel all non-terminal jobs.
334 /// * `request.reason` — Human-readable cancellation reason (optional).
335 ///
336 /// # Example
337 ///
338 /// ```no_run
339 /// use armada_client::{ArmadaClient, JobSetCancelRequest, JobSetFilter, JobState, StaticTokenProvider};
340 ///
341 /// # async fn example() -> Result<(), armada_client::Error> {
342 /// # let client = ArmadaClient::connect("http://localhost:50051", StaticTokenProvider::new("")).await?;
343 /// client
344 /// .cancel_job_set(JobSetCancelRequest {
345 /// queue: "my-queue".into(),
346 /// job_set_id: "my-job-set".into(),
347 /// filter: Some(JobSetFilter {
348 /// states: vec![JobState::Queued as i32, JobState::Running as i32],
349 /// }),
350 /// reason: "aborting experiment".into(),
351 /// })
352 /// .await?;
353 /// # Ok(())
354 /// # }
355 /// ```
356 ///
357 /// # Errors
358 ///
359 /// - [`Error::Auth`] if the token provider fails.
360 /// - [`Error::InvalidMetadata`] if the token contains invalid header characters.
361 /// - [`Error::Grpc`] if the server returns a non-OK status.
362 #[instrument(skip(self, request), fields(queue = %request.queue, job_set_id = %request.job_set_id))]
363 pub async fn cancel_job_set(&self, request: JobSetCancelRequest) -> Result<(), Error> {
364 let token = self.token_provider.token().await?;
365 let mut req = tonic::Request::new(request);
366 if !token.is_empty() {
367 req.metadata_mut().insert("authorization", token.parse()?);
368 }
369 self.apply_timeout(&mut req);
370 self.submit_client.clone().cancel_job_set(req).await?;
371 Ok(())
372 }
373
374 /// Watch a job set, returning a stream of events.
375 ///
376 /// Opens a server-streaming gRPC call and returns a [`BoxStream`] that
377 /// yields [`EventStreamMessage`] values as the server pushes them. The
378 /// stream ends when the server closes the connection.
379 ///
380 /// **Reconnection is the caller's responsibility.** Store the last
381 /// `message_id` you received and pass it back as `from_message_id` when
382 /// reconnecting to avoid replaying events you have already processed.
383 ///
384 /// # Arguments
385 ///
386 /// * `queue` — Armada queue name.
387 /// * `job_set_id` — Job set to watch.
388 /// * `from_message_id` — Optional resume cursor. Pass `Some(id)` to
389 /// receive only events that occurred after `id`; pass `None` to receive
390 /// all events from the beginning.
391 ///
392 /// # Example
393 ///
394 /// ```no_run
395 /// use futures::StreamExt;
396 /// use armada_client::{ArmadaClient, StaticTokenProvider};
397 ///
398 /// # async fn example() -> Result<(), armada_client::Error> {
399 /// # let client = ArmadaClient::connect("http://localhost:50051", StaticTokenProvider::new("")).await?;
400 /// let mut stream = client
401 /// .watch("my-queue", "my-job-set", None)
402 /// .await?;
403 ///
404 /// let mut last_id = String::new();
405 /// while let Some(result) = stream.next().await {
406 /// match result {
407 /// Ok(msg) => {
408 /// last_id = msg.id.clone();
409 /// println!("event id={} message={:?}", msg.id, msg.message);
410 /// }
411 /// Err(e) => {
412 /// eprintln!("stream error: {e}");
413 /// break; // reconnect using `from_message_id: Some(last_id)`
414 /// }
415 /// }
416 /// }
417 /// # Ok(())
418 /// # }
419 /// ```
420 ///
421 /// # Errors
422 ///
423 /// - [`Error::Auth`] if the token provider fails.
424 /// - [`Error::InvalidMetadata`] if the token contains invalid header characters.
425 /// - [`Error::Grpc`] if the server returns a non-OK status on the initial call.
426 /// The stream will not error simply because the job set does not exist yet —
427 /// it will wait for events, which avoids races when `watch` is called
428 /// immediately after `submit`.
429 /// - Individual stream items may also be [`Err(Error::Grpc)`] if the server
430 /// sends a trailing error status.
431 #[instrument(skip_all, fields(queue, job_set_id))]
432 pub async fn watch(
433 &self,
434 queue: impl Into<String>,
435 job_set_id: impl Into<String>,
436 from_message_id: Option<String>,
437 ) -> Result<BoxStream<'static, Result<EventStreamMessage, Error>>, Error> {
438 let queue: String = queue.into();
439 let job_set_id: String = job_set_id.into();
440 tracing::Span::current()
441 .record("queue", queue.as_str())
442 .record("job_set_id", job_set_id.as_str());
443
444 let token = self.token_provider.token().await?;
445 let job_set_request = JobSetRequest {
446 id: job_set_id,
447 queue,
448 from_message_id: from_message_id.unwrap_or_default(),
449 // Keep the stream open for new events.
450 watch: true,
451 // Do not fail immediately if the job set does not exist yet — this
452 // avoids a race between submit() and watch() where the server has
453 // not yet created the job set by the time the watch RPC arrives.
454 error_if_missing: false,
455 };
456 let mut req = tonic::Request::new(job_set_request);
457 if !token.is_empty() {
458 req.metadata_mut().insert("authorization", token.parse()?);
459 }
460 self.apply_timeout(&mut req);
461 // `.clone()` on a tonic client is O(1) — it clones an Arc over the
462 // shared channel, not the underlying connection.
463 let stream = self
464 .event_client
465 .clone()
466 .get_job_set_events(req)
467 .await?
468 .into_inner();
469 Ok(Box::pin(stream.map(|r| r.map_err(Error::from))))
470 }
471}