armada_client/client.rs
1use std::sync::Arc;
2use std::time::Duration;
3
4use futures::StreamExt;
5use futures::stream::BoxStream;
6use tonic::transport::{Channel, ClientTlsConfig};
7use tracing::instrument;
8
9use crate::api::{
10 CancellationResult, EventStreamMessage, JobCancelRequest, JobSetCancelRequest, JobSetRequest,
11 JobSubmitRequest, JobSubmitResponse, event_client::EventClient, submit_client::SubmitClient,
12};
13use crate::auth::TokenProvider;
14use crate::error::Error;
15
16/// Armada gRPC client providing job submission and event-stream watching.
17///
18/// # Construction
19///
20/// Use [`ArmadaClient::connect`] for plaintext (dev / in-cluster) connections
21/// and [`ArmadaClient::connect_tls`] for production clusters behind TLS.
22/// Both accept any [`TokenProvider`] — pass [`crate::StaticTokenProvider`] for
23/// a static bearer token or supply your own implementation for dynamic auth.
24///
25/// ```no_run
26/// # use armada_client::{ArmadaClient, StaticTokenProvider};
27/// # async fn example() -> Result<(), armada_client::Error> {
28/// // Plaintext
29/// let client = ArmadaClient::connect("http://localhost:50051", StaticTokenProvider::new("tok"))
30/// .await?;
31///
32/// // TLS (uses system root certificates)
33/// let client = ArmadaClient::connect_tls("https://armada.example.com:443", StaticTokenProvider::new("tok"))
34/// .await?;
35/// # Ok(())
36/// # }
37/// ```
38///
39/// # Cloning
40///
41/// `ArmadaClient` is `Clone`. All clones share the same underlying channel and
42/// connection pool — cloning is `O(1)` and is the correct way to distribute the
43/// client across tasks:
44///
45/// ```no_run
46/// # use armada_client::{ArmadaClient, StaticTokenProvider};
47/// # async fn example() -> Result<(), armada_client::Error> {
48/// let client = ArmadaClient::connect("http://localhost:50051", StaticTokenProvider::new("tok"))
49/// .await?;
50///
51/// let c1 = client.clone();
52/// let c2 = client.clone();
53/// tokio::spawn(async move { /* use c1 */ });
54/// tokio::spawn(async move { /* use c2 */ });
55/// # Ok(())
56/// # }
57/// ```
58///
59/// # Timeouts
60///
61/// Apply a per-call deadline with [`ArmadaClient::with_timeout`]. When set,
62/// every RPC is governed by the deadline for its **entire duration**. For
63/// unary calls (`submit`) the deadline covers the round-trip. For streaming
64/// calls (`watch`) it covers the full lifetime of the stream — if the timeout
65/// elapses while events are still being received the stream is cancelled with
66/// a `DEADLINE_EXCEEDED` status.
67#[derive(Clone)]
68pub struct ArmadaClient {
69 submit_client: SubmitClient<Channel>,
70 event_client: EventClient<Channel>,
71 token_provider: Arc<dyn TokenProvider + Send + Sync>,
72 timeout: Option<Duration>,
73}
74
75impl ArmadaClient {
76 /// Connect to an Armada server at `endpoint` using plaintext (no TLS).
77 ///
78 /// `endpoint` must be a valid URI, e.g. `"http://localhost:50051"`.
79 ///
80 /// # Errors
81 ///
82 /// Returns [`Error::InvalidUri`] if the URI is malformed, or
83 /// [`Error::Transport`] if the connection cannot be established.
84 pub async fn connect(
85 endpoint: impl Into<String>,
86 token_provider: impl TokenProvider + 'static,
87 ) -> Result<Self, Error> {
88 let channel = Channel::from_shared(endpoint.into())
89 .map_err(|e| Error::InvalidUri(e.to_string()))?
90 .connect()
91 .await?;
92 Ok(Self::from_parts(channel, token_provider))
93 }
94
95 /// Connect to an Armada server at `endpoint` using TLS.
96 ///
97 /// Uses the system's native root certificates to verify the server
98 /// certificate. `endpoint` should use the `https://` scheme,
99 /// e.g. `"https://armada.example.com:443"`.
100 ///
101 /// # Errors
102 ///
103 /// Returns [`Error::InvalidUri`] if the URI is malformed, or
104 /// [`Error::Transport`] if TLS configuration or the connection fails.
105 pub async fn connect_tls(
106 endpoint: impl Into<String>,
107 token_provider: impl TokenProvider + 'static,
108 ) -> Result<Self, Error> {
109 let channel = Channel::from_shared(endpoint.into())
110 .map_err(|e| Error::InvalidUri(e.to_string()))?
111 .tls_config(ClientTlsConfig::new())?
112 .connect()
113 .await?;
114 Ok(Self::from_parts(channel, token_provider))
115 }
116
117 fn from_parts(channel: Channel, token_provider: impl TokenProvider + 'static) -> Self {
118 Self {
119 submit_client: SubmitClient::new(channel.clone()),
120 event_client: EventClient::new(channel),
121 token_provider: Arc::new(token_provider),
122 timeout: None,
123 }
124 }
125
126 /// Set a default timeout applied to every RPC call.
127 ///
128 /// When the timeout elapses the call fails with [`Error::Grpc`] wrapping a
129 /// `DEADLINE_EXCEEDED` status. For streaming calls like
130 /// [`ArmadaClient::watch`], the deadline covers the **entire stream
131 /// duration** — if it elapses while events are still arriving the stream
132 /// is cancelled immediately.
133 ///
134 /// Returns `self` so the call can be chained directly after construction:
135 ///
136 /// ```no_run
137 /// # use std::time::Duration;
138 /// # use armada_client::{ArmadaClient, StaticTokenProvider};
139 /// # async fn example() -> Result<(), armada_client::Error> {
140 /// let client = ArmadaClient::connect("http://localhost:50051", StaticTokenProvider::new("tok"))
141 /// .await?
142 /// .with_timeout(Duration::from_secs(30));
143 /// # Ok(())
144 /// # }
145 /// ```
146 pub fn with_timeout(mut self, timeout: Duration) -> Self {
147 self.timeout = Some(timeout);
148 self
149 }
150
151 fn apply_timeout<T>(&self, req: &mut tonic::Request<T>) {
152 if let Some(t) = self.timeout {
153 req.set_timeout(t);
154 }
155 }
156
157 /// Submit a batch of jobs to Armada.
158 ///
159 /// Attaches an `authorization` header on every call using the configured
160 /// [`TokenProvider`] (e.g. `Bearer <token>` or `Basic <credentials>`). Multiple job items can be included in
161 /// a single request — they are all submitted atomically to the same queue
162 /// and job set.
163 ///
164 /// # Example
165 ///
166 /// ```no_run
167 /// use armada_client::{
168 /// ArmadaClient, JobRequestItemBuilder, JobSubmitRequest, StaticTokenProvider,
169 /// };
170 /// use armada_client::k8s::io::api::core::v1::PodSpec;
171 ///
172 /// # async fn example() -> Result<(), armada_client::Error> {
173 /// # let client = ArmadaClient::connect("http://localhost:50051", StaticTokenProvider::new("")).await?;
174 /// let item = JobRequestItemBuilder::new()
175 /// .namespace("default")
176 /// .pod_spec(PodSpec { containers: vec![], ..Default::default() })
177 /// .build();
178 ///
179 /// let response = client
180 /// .submit(JobSubmitRequest {
181 /// queue: "my-queue".into(),
182 /// job_set_id: "my-job-set".into(),
183 /// job_request_items: vec![item],
184 /// })
185 /// .await?;
186 ///
187 /// for r in &response.job_response_items {
188 /// if r.error.is_empty() {
189 /// println!("submitted: {}", r.job_id);
190 /// } else {
191 /// eprintln!("rejected: {}", r.error);
192 /// }
193 /// }
194 /// # Ok(())
195 /// # }
196 /// ```
197 ///
198 /// # Errors
199 ///
200 /// - [`Error::Auth`] if the token provider fails.
201 /// - [`Error::InvalidMetadata`] if the token contains invalid header characters.
202 /// - [`Error::Grpc`] if the server returns a non-OK status.
203 #[instrument(skip(self, request), fields(queue = %request.queue, job_set_id = %request.job_set_id))]
204 pub async fn submit(&self, request: JobSubmitRequest) -> Result<JobSubmitResponse, Error> {
205 let token = self.token_provider.token().await?;
206 let mut req = tonic::Request::new(request);
207 if !token.is_empty() {
208 req.metadata_mut().insert("authorization", token.parse()?);
209 }
210 self.apply_timeout(&mut req);
211 // `.clone()` on a tonic client is O(1) — it clones an Arc over the
212 // shared channel, not the underlying connection.
213 let resp = self.submit_client.clone().submit_jobs(req).await?;
214 Ok(resp.into_inner())
215 }
216
217 /// Cancel one or more jobs.
218 ///
219 /// # Arguments
220 ///
221 /// * `request.queue` — Armada queue name.
222 /// * `request.job_set_id` — Job set the jobs belong to.
223 /// * `request.job_id` — Single job ID to cancel (legacy, optional).
224 /// * `request.job_ids` — Multiple job IDs to cancel in one call.
225 /// * `request.reason` — Human-readable cancellation reason (optional).
226 ///
227 /// # Example
228 ///
229 /// ```no_run
230 /// use armada_client::{ArmadaClient, JobCancelRequest, StaticTokenProvider};
231 ///
232 /// # async fn example() -> Result<(), armada_client::Error> {
233 /// # let client = ArmadaClient::connect("http://localhost:50051", StaticTokenProvider::new("")).await?;
234 /// let result = client
235 /// .cancel_jobs(JobCancelRequest {
236 /// queue: "my-queue".into(),
237 /// job_set_id: "my-job-set".into(),
238 /// job_ids: vec!["01abc".into(), "01def".into()],
239 /// reason: "no longer needed".into(),
240 /// ..Default::default()
241 /// })
242 /// .await?;
243 ///
244 /// println!("cancelled: {:?}", result.cancelled_ids);
245 /// # Ok(())
246 /// # }
247 /// ```
248 ///
249 /// # Errors
250 ///
251 /// - [`Error::Auth`] if the token provider fails.
252 /// - [`Error::InvalidMetadata`] if the token contains invalid header characters.
253 /// - [`Error::Grpc`] if the server returns a non-OK status.
254 #[instrument(skip(self, request), fields(queue = %request.queue, job_set_id = %request.job_set_id))]
255 pub async fn cancel_jobs(
256 &self,
257 request: JobCancelRequest,
258 ) -> Result<CancellationResult, Error> {
259 let token = self.token_provider.token().await?;
260 let mut req = tonic::Request::new(request);
261 if !token.is_empty() {
262 req.metadata_mut().insert("authorization", token.parse()?);
263 }
264 self.apply_timeout(&mut req);
265 let resp = self.submit_client.clone().cancel_jobs(req).await?;
266 Ok(resp.into_inner())
267 }
268
269 /// Cancel all (or a filtered subset of) jobs in a job set.
270 ///
271 /// # Arguments
272 ///
273 /// * `request.queue` — Armada queue name.
274 /// * `request.job_set_id` — Job set to cancel.
275 /// * `request.filter` — Optional [`crate::JobSetFilter`] limiting cancellation to
276 /// jobs in specific states (e.g. only `Queued` and `Running`). Pass
277 /// `None` to cancel all non-terminal jobs.
278 /// * `request.reason` — Human-readable cancellation reason (optional).
279 ///
280 /// # Example
281 ///
282 /// ```no_run
283 /// use armada_client::{ArmadaClient, JobSetCancelRequest, JobSetFilter, JobState, StaticTokenProvider};
284 ///
285 /// # async fn example() -> Result<(), armada_client::Error> {
286 /// # let client = ArmadaClient::connect("http://localhost:50051", StaticTokenProvider::new("")).await?;
287 /// client
288 /// .cancel_job_set(JobSetCancelRequest {
289 /// queue: "my-queue".into(),
290 /// job_set_id: "my-job-set".into(),
291 /// filter: Some(JobSetFilter {
292 /// states: vec![JobState::Queued as i32, JobState::Running as i32],
293 /// }),
294 /// reason: "aborting experiment".into(),
295 /// })
296 /// .await?;
297 /// # Ok(())
298 /// # }
299 /// ```
300 ///
301 /// # Errors
302 ///
303 /// - [`Error::Auth`] if the token provider fails.
304 /// - [`Error::InvalidMetadata`] if the token contains invalid header characters.
305 /// - [`Error::Grpc`] if the server returns a non-OK status.
306 #[instrument(skip(self, request), fields(queue = %request.queue, job_set_id = %request.job_set_id))]
307 pub async fn cancel_job_set(&self, request: JobSetCancelRequest) -> Result<(), Error> {
308 let token = self.token_provider.token().await?;
309 let mut req = tonic::Request::new(request);
310 if !token.is_empty() {
311 req.metadata_mut().insert("authorization", token.parse()?);
312 }
313 self.apply_timeout(&mut req);
314 self.submit_client.clone().cancel_job_set(req).await?;
315 Ok(())
316 }
317
318 /// Watch a job set, returning a stream of events.
319 ///
320 /// Opens a server-streaming gRPC call and returns a [`BoxStream`] that
321 /// yields [`EventStreamMessage`] values as the server pushes them. The
322 /// stream ends when the server closes the connection.
323 ///
324 /// **Reconnection is the caller's responsibility.** Store the last
325 /// `message_id` you received and pass it back as `from_message_id` when
326 /// reconnecting to avoid replaying events you have already processed.
327 ///
328 /// # Arguments
329 ///
330 /// * `queue` — Armada queue name.
331 /// * `job_set_id` — Job set to watch.
332 /// * `from_message_id` — Optional resume cursor. Pass `Some(id)` to
333 /// receive only events that occurred after `id`; pass `None` to receive
334 /// all events from the beginning.
335 ///
336 /// # Example
337 ///
338 /// ```no_run
339 /// use futures::StreamExt;
340 /// use armada_client::{ArmadaClient, StaticTokenProvider};
341 ///
342 /// # async fn example() -> Result<(), armada_client::Error> {
343 /// # let client = ArmadaClient::connect("http://localhost:50051", StaticTokenProvider::new("")).await?;
344 /// let mut stream = client
345 /// .watch("my-queue", "my-job-set", None)
346 /// .await?;
347 ///
348 /// let mut last_id = String::new();
349 /// while let Some(result) = stream.next().await {
350 /// match result {
351 /// Ok(msg) => {
352 /// last_id = msg.id.clone();
353 /// println!("event id={} message={:?}", msg.id, msg.message);
354 /// }
355 /// Err(e) => {
356 /// eprintln!("stream error: {e}");
357 /// break; // reconnect using `from_message_id: Some(last_id)`
358 /// }
359 /// }
360 /// }
361 /// # Ok(())
362 /// # }
363 /// ```
364 ///
365 /// # Errors
366 ///
367 /// - [`Error::Auth`] if the token provider fails.
368 /// - [`Error::InvalidMetadata`] if the token contains invalid header characters.
369 /// - [`Error::Grpc`] if the server returns a non-OK status on the initial call.
370 /// The stream will not error simply because the job set does not exist yet —
371 /// it will wait for events, which avoids races when `watch` is called
372 /// immediately after `submit`.
373 /// - Individual stream items may also be [`Err(Error::Grpc)`] if the server
374 /// sends a trailing error status.
375 #[instrument(skip_all, fields(queue, job_set_id))]
376 pub async fn watch(
377 &self,
378 queue: impl Into<String>,
379 job_set_id: impl Into<String>,
380 from_message_id: Option<String>,
381 ) -> Result<BoxStream<'static, Result<EventStreamMessage, Error>>, Error> {
382 let queue: String = queue.into();
383 let job_set_id: String = job_set_id.into();
384 tracing::Span::current()
385 .record("queue", queue.as_str())
386 .record("job_set_id", job_set_id.as_str());
387
388 let token = self.token_provider.token().await?;
389 let job_set_request = JobSetRequest {
390 id: job_set_id,
391 queue,
392 from_message_id: from_message_id.unwrap_or_default(),
393 // Keep the stream open for new events.
394 watch: true,
395 // Do not fail immediately if the job set does not exist yet — this
396 // avoids a race between submit() and watch() where the server has
397 // not yet created the job set by the time the watch RPC arrives.
398 error_if_missing: false,
399 };
400 let mut req = tonic::Request::new(job_set_request);
401 if !token.is_empty() {
402 req.metadata_mut().insert("authorization", token.parse()?);
403 }
404 self.apply_timeout(&mut req);
405 // `.clone()` on a tonic client is O(1) — it clones an Arc over the
406 // shared channel, not the underlying connection.
407 let stream = self
408 .event_client
409 .clone()
410 .get_job_set_events(req)
411 .await?
412 .into_inner();
413 Ok(Box::pin(stream.map(|r| r.map_err(Error::from))))
414 }
415}