armada_client/client.rs
1use std::sync::Arc;
2use std::time::Duration;
3
4use futures::StreamExt;
5use futures::stream::BoxStream;
6use tonic::transport::{Channel, ClientTlsConfig};
7use tracing::instrument;
8
9use crate::api::{
10 EventStreamMessage, JobSetRequest, JobSubmitRequest, JobSubmitResponse,
11 event_client::EventClient, submit_client::SubmitClient,
12};
13use crate::auth::TokenProvider;
14use crate::error::Error;
15
16/// Armada gRPC client providing job submission and event-stream watching.
17///
18/// # Construction
19///
20/// Use [`ArmadaClient::connect`] for plaintext (dev / in-cluster) connections
21/// and [`ArmadaClient::connect_tls`] for production clusters behind TLS.
22/// Both accept any [`TokenProvider`] — pass [`crate::StaticTokenProvider`] for
23/// a static bearer token or supply your own implementation for dynamic auth.
24///
25/// ```no_run
26/// # use armada_client::{ArmadaClient, StaticTokenProvider};
27/// # async fn example() -> Result<(), armada_client::Error> {
28/// // Plaintext
29/// let client = ArmadaClient::connect("http://localhost:50051", StaticTokenProvider::new("tok"))
30/// .await?;
31///
32/// // TLS (uses system root certificates)
33/// let client = ArmadaClient::connect_tls("https://armada.example.com:443", StaticTokenProvider::new("tok"))
34/// .await?;
35/// # Ok(())
36/// # }
37/// ```
38///
39/// # Cloning
40///
41/// `ArmadaClient` is `Clone`. All clones share the same underlying channel and
42/// connection pool — cloning is `O(1)` and is the correct way to distribute the
43/// client across tasks:
44///
45/// ```no_run
46/// # use armada_client::{ArmadaClient, StaticTokenProvider};
47/// # async fn example() -> Result<(), armada_client::Error> {
48/// let client = ArmadaClient::connect("http://localhost:50051", StaticTokenProvider::new("tok"))
49/// .await?;
50///
51/// let c1 = client.clone();
52/// let c2 = client.clone();
53/// tokio::spawn(async move { /* use c1 */ });
54/// tokio::spawn(async move { /* use c2 */ });
55/// # Ok(())
56/// # }
57/// ```
58///
59/// # Timeouts
60///
61/// Apply a per-call deadline with [`ArmadaClient::with_timeout`]. When set,
62/// every RPC is governed by the deadline for its **entire duration**. For
63/// unary calls (`submit`) the deadline covers the round-trip. For streaming
64/// calls (`watch`) it covers the full lifetime of the stream — if the timeout
65/// elapses while events are still being received the stream is cancelled with
66/// a `DEADLINE_EXCEEDED` status.
67#[derive(Clone)]
68pub struct ArmadaClient {
69 submit_client: SubmitClient<Channel>,
70 event_client: EventClient<Channel>,
71 token_provider: Arc<dyn TokenProvider + Send + Sync>,
72 timeout: Option<Duration>,
73}
74
75impl ArmadaClient {
76 /// Connect to an Armada server at `endpoint` using plaintext (no TLS).
77 ///
78 /// `endpoint` must be a valid URI, e.g. `"http://localhost:50051"`.
79 ///
80 /// # Errors
81 ///
82 /// Returns [`Error::InvalidUri`] if the URI is malformed, or
83 /// [`Error::Transport`] if the connection cannot be established.
84 pub async fn connect(
85 endpoint: impl Into<String>,
86 token_provider: impl TokenProvider + 'static,
87 ) -> Result<Self, Error> {
88 let channel = Channel::from_shared(endpoint.into())
89 .map_err(|e| Error::InvalidUri(e.to_string()))?
90 .connect()
91 .await?;
92 Ok(Self::from_parts(channel, token_provider))
93 }
94
95 /// Connect to an Armada server at `endpoint` using TLS.
96 ///
97 /// Uses the system's native root certificates to verify the server
98 /// certificate. `endpoint` should use the `https://` scheme,
99 /// e.g. `"https://armada.example.com:443"`.
100 ///
101 /// # Errors
102 ///
103 /// Returns [`Error::InvalidUri`] if the URI is malformed, or
104 /// [`Error::Transport`] if TLS configuration or the connection fails.
105 pub async fn connect_tls(
106 endpoint: impl Into<String>,
107 token_provider: impl TokenProvider + 'static,
108 ) -> Result<Self, Error> {
109 let channel = Channel::from_shared(endpoint.into())
110 .map_err(|e| Error::InvalidUri(e.to_string()))?
111 .tls_config(ClientTlsConfig::new())?
112 .connect()
113 .await?;
114 Ok(Self::from_parts(channel, token_provider))
115 }
116
117 fn from_parts(channel: Channel, token_provider: impl TokenProvider + 'static) -> Self {
118 Self {
119 submit_client: SubmitClient::new(channel.clone()),
120 event_client: EventClient::new(channel),
121 token_provider: Arc::new(token_provider),
122 timeout: None,
123 }
124 }
125
126 /// Set a default timeout applied to every RPC call.
127 ///
128 /// When the timeout elapses the call fails with [`Error::Grpc`] wrapping a
129 /// `DEADLINE_EXCEEDED` status. For streaming calls like
130 /// [`ArmadaClient::watch`], the deadline covers the **entire stream
131 /// duration** — if it elapses while events are still arriving the stream
132 /// is cancelled immediately.
133 ///
134 /// Returns `self` so the call can be chained directly after construction:
135 ///
136 /// ```no_run
137 /// # use std::time::Duration;
138 /// # use armada_client::{ArmadaClient, StaticTokenProvider};
139 /// # async fn example() -> Result<(), armada_client::Error> {
140 /// let client = ArmadaClient::connect("http://localhost:50051", StaticTokenProvider::new("tok"))
141 /// .await?
142 /// .with_timeout(Duration::from_secs(30));
143 /// # Ok(())
144 /// # }
145 /// ```
146 pub fn with_timeout(mut self, timeout: Duration) -> Self {
147 self.timeout = Some(timeout);
148 self
149 }
150
151 fn apply_timeout<T>(&self, req: &mut tonic::Request<T>) {
152 if let Some(t) = self.timeout {
153 req.set_timeout(t);
154 }
155 }
156
157 /// Submit a batch of jobs to Armada.
158 ///
159 /// Attaches an `authorization` header on every call using the configured
160 /// [`TokenProvider`] (e.g. `Bearer <token>` or `Basic <credentials>`). Multiple job items can be included in
161 /// a single request — they are all submitted atomically to the same queue
162 /// and job set.
163 ///
164 /// # Example
165 ///
166 /// ```no_run
167 /// use armada_client::{
168 /// ArmadaClient, JobRequestItemBuilder, JobSubmitRequest, StaticTokenProvider,
169 /// };
170 /// use armada_client::k8s::io::api::core::v1::PodSpec;
171 ///
172 /// # async fn example() -> Result<(), armada_client::Error> {
173 /// # let client = ArmadaClient::connect("http://localhost:50051", StaticTokenProvider::new("")).await?;
174 /// let item = JobRequestItemBuilder::new()
175 /// .namespace("default")
176 /// .pod_spec(PodSpec { containers: vec![], ..Default::default() })
177 /// .build();
178 ///
179 /// let response = client
180 /// .submit(JobSubmitRequest {
181 /// queue: "my-queue".into(),
182 /// job_set_id: "my-job-set".into(),
183 /// job_request_items: vec![item],
184 /// })
185 /// .await?;
186 ///
187 /// for r in &response.job_response_items {
188 /// if r.error.is_empty() {
189 /// println!("submitted: {}", r.job_id);
190 /// } else {
191 /// eprintln!("rejected: {}", r.error);
192 /// }
193 /// }
194 /// # Ok(())
195 /// # }
196 /// ```
197 ///
198 /// # Errors
199 ///
200 /// - [`Error::Auth`] if the token provider fails.
201 /// - [`Error::InvalidMetadata`] if the token contains invalid header characters.
202 /// - [`Error::Grpc`] if the server returns a non-OK status.
203 #[instrument(skip(self, request), fields(queue = %request.queue, job_set_id = %request.job_set_id))]
204 pub async fn submit(&self, request: JobSubmitRequest) -> Result<JobSubmitResponse, Error> {
205 let token = self.token_provider.token().await?;
206 let mut req = tonic::Request::new(request);
207 if !token.is_empty() {
208 req.metadata_mut().insert("authorization", token.parse()?);
209 }
210 self.apply_timeout(&mut req);
211 // `.clone()` on a tonic client is O(1) — it clones an Arc over the
212 // shared channel, not the underlying connection.
213 let resp = self.submit_client.clone().submit_jobs(req).await?;
214 Ok(resp.into_inner())
215 }
216
217 /// Watch a job set, returning a stream of events.
218 ///
219 /// Opens a server-streaming gRPC call and returns a [`BoxStream`] that
220 /// yields [`EventStreamMessage`] values as the server pushes them. The
221 /// stream ends when the server closes the connection.
222 ///
223 /// **Reconnection is the caller's responsibility.** Store the last
224 /// `message_id` you received and pass it back as `from_message_id` when
225 /// reconnecting to avoid replaying events you have already processed.
226 ///
227 /// # Arguments
228 ///
229 /// * `queue` — Armada queue name.
230 /// * `job_set_id` — Job set to watch.
231 /// * `from_message_id` — Optional resume cursor. Pass `Some(id)` to
232 /// receive only events that occurred after `id`; pass `None` to receive
233 /// all events from the beginning.
234 ///
235 /// # Example
236 ///
237 /// ```no_run
238 /// use futures::StreamExt;
239 /// use armada_client::{ArmadaClient, StaticTokenProvider};
240 ///
241 /// # async fn example() -> Result<(), armada_client::Error> {
242 /// # let client = ArmadaClient::connect("http://localhost:50051", StaticTokenProvider::new("")).await?;
243 /// let mut stream = client
244 /// .watch("my-queue", "my-job-set", None)
245 /// .await?;
246 ///
247 /// let mut last_id = String::new();
248 /// while let Some(result) = stream.next().await {
249 /// match result {
250 /// Ok(msg) => {
251 /// last_id = msg.id.clone();
252 /// println!("event id={} message={:?}", msg.id, msg.message);
253 /// }
254 /// Err(e) => {
255 /// eprintln!("stream error: {e}");
256 /// break; // reconnect using `from_message_id: Some(last_id)`
257 /// }
258 /// }
259 /// }
260 /// # Ok(())
261 /// # }
262 /// ```
263 ///
264 /// # Errors
265 ///
266 /// - [`Error::Auth`] if the token provider fails.
267 /// - [`Error::InvalidMetadata`] if the token contains invalid header characters.
268 /// - [`Error::Grpc`] if the server returns a non-OK status on the initial call.
269 /// The stream will not error simply because the job set does not exist yet —
270 /// it will wait for events, which avoids races when `watch` is called
271 /// immediately after `submit`.
272 /// - Individual stream items may also be [`Err(Error::Grpc)`] if the server
273 /// sends a trailing error status.
274 #[instrument(skip_all, fields(queue, job_set_id))]
275 pub async fn watch(
276 &self,
277 queue: impl Into<String>,
278 job_set_id: impl Into<String>,
279 from_message_id: Option<String>,
280 ) -> Result<BoxStream<'static, Result<EventStreamMessage, Error>>, Error> {
281 let queue: String = queue.into();
282 let job_set_id: String = job_set_id.into();
283 tracing::Span::current()
284 .record("queue", queue.as_str())
285 .record("job_set_id", job_set_id.as_str());
286
287 let token = self.token_provider.token().await?;
288 let job_set_request = JobSetRequest {
289 id: job_set_id,
290 queue,
291 from_message_id: from_message_id.unwrap_or_default(),
292 // Keep the stream open for new events.
293 watch: true,
294 // Do not fail immediately if the job set does not exist yet — this
295 // avoids a race between submit() and watch() where the server has
296 // not yet created the job set by the time the watch RPC arrives.
297 error_if_missing: false,
298 };
299 let mut req = tonic::Request::new(job_set_request);
300 if !token.is_empty() {
301 req.metadata_mut().insert("authorization", token.parse()?);
302 }
303 self.apply_timeout(&mut req);
304 // `.clone()` on a tonic client is O(1) — it clones an Arc over the
305 // shared channel, not the underlying connection.
306 let stream = self
307 .event_client
308 .clone()
309 .get_job_set_events(req)
310 .await?
311 .into_inner();
312 Ok(Box::pin(stream.map(|r| r.map_err(Error::from))))
313 }
314}