Skip to main content

armada_client/
client.rs

1use std::sync::Arc;
2use std::time::Duration;
3
4use futures::StreamExt;
5use futures::stream::BoxStream;
6use tonic::transport::{Channel, ClientTlsConfig};
7use tracing::instrument;
8
9use crate::api::{
10    EventStreamMessage, JobSetRequest, JobSubmitRequest, JobSubmitResponse,
11    event_client::EventClient, submit_client::SubmitClient,
12};
13use crate::auth::TokenProvider;
14use crate::error::Error;
15
16/// Armada gRPC client providing job submission and event-stream watching.
17///
18/// # Construction
19///
20/// Use [`ArmadaClient::connect`] for plaintext (dev / in-cluster) connections
21/// and [`ArmadaClient::connect_tls`] for production clusters behind TLS.
22/// Both accept any [`TokenProvider`] — pass [`crate::StaticTokenProvider`] for
23/// a static bearer token or supply your own implementation for dynamic auth.
24///
25/// ```no_run
26/// # use armada_client::{ArmadaClient, StaticTokenProvider};
27/// # async fn example() -> Result<(), armada_client::Error> {
28/// // Plaintext
29/// let client = ArmadaClient::connect("http://localhost:50051", StaticTokenProvider::new("tok"))
30///     .await?;
31///
32/// // TLS (uses system root certificates)
33/// let client = ArmadaClient::connect_tls("https://armada.example.com:443", StaticTokenProvider::new("tok"))
34///     .await?;
35/// # Ok(())
36/// # }
37/// ```
38///
39/// # Cloning
40///
41/// `ArmadaClient` is `Clone`. All clones share the same underlying channel and
42/// connection pool — cloning is `O(1)` and is the correct way to distribute the
43/// client across tasks:
44///
45/// ```no_run
46/// # use armada_client::{ArmadaClient, StaticTokenProvider};
47/// # async fn example() -> Result<(), armada_client::Error> {
48/// let client = ArmadaClient::connect("http://localhost:50051", StaticTokenProvider::new("tok"))
49///     .await?;
50///
51/// let c1 = client.clone();
52/// let c2 = client.clone();
53/// tokio::spawn(async move { /* use c1 */ });
54/// tokio::spawn(async move { /* use c2 */ });
55/// # Ok(())
56/// # }
57/// ```
58///
59/// # Timeouts
60///
61/// Apply a per-call deadline with [`ArmadaClient::with_timeout`]. When set,
62/// every RPC is governed by the deadline for its **entire duration**. For
63/// unary calls (`submit`) the deadline covers the round-trip. For streaming
64/// calls (`watch`) it covers the full lifetime of the stream — if the timeout
65/// elapses while events are still being received the stream is cancelled with
66/// a `DEADLINE_EXCEEDED` status.
67#[derive(Clone)]
68pub struct ArmadaClient {
69    submit_client: SubmitClient<Channel>,
70    event_client: EventClient<Channel>,
71    token_provider: Arc<dyn TokenProvider + Send + Sync>,
72    timeout: Option<Duration>,
73}
74
75impl ArmadaClient {
76    /// Connect to an Armada server at `endpoint` using plaintext (no TLS).
77    ///
78    /// `endpoint` must be a valid URI, e.g. `"http://localhost:50051"`.
79    ///
80    /// # Errors
81    ///
82    /// Returns [`Error::InvalidUri`] if the URI is malformed, or
83    /// [`Error::Transport`] if the connection cannot be established.
84    pub async fn connect(
85        endpoint: impl Into<String>,
86        token_provider: impl TokenProvider + 'static,
87    ) -> Result<Self, Error> {
88        let channel = Channel::from_shared(endpoint.into())
89            .map_err(|e| Error::InvalidUri(e.to_string()))?
90            .connect()
91            .await?;
92        Ok(Self::from_parts(channel, token_provider))
93    }
94
95    /// Connect to an Armada server at `endpoint` using TLS.
96    ///
97    /// Uses the system's native root certificates to verify the server
98    /// certificate. `endpoint` should use the `https://` scheme,
99    /// e.g. `"https://armada.example.com:443"`.
100    ///
101    /// # Errors
102    ///
103    /// Returns [`Error::InvalidUri`] if the URI is malformed, or
104    /// [`Error::Transport`] if TLS configuration or the connection fails.
105    pub async fn connect_tls(
106        endpoint: impl Into<String>,
107        token_provider: impl TokenProvider + 'static,
108    ) -> Result<Self, Error> {
109        let channel = Channel::from_shared(endpoint.into())
110            .map_err(|e| Error::InvalidUri(e.to_string()))?
111            .tls_config(ClientTlsConfig::new())?
112            .connect()
113            .await?;
114        Ok(Self::from_parts(channel, token_provider))
115    }
116
117    fn from_parts(channel: Channel, token_provider: impl TokenProvider + 'static) -> Self {
118        Self {
119            submit_client: SubmitClient::new(channel.clone()),
120            event_client: EventClient::new(channel),
121            token_provider: Arc::new(token_provider),
122            timeout: None,
123        }
124    }
125
126    /// Set a default timeout applied to every RPC call.
127    ///
128    /// When the timeout elapses the call fails with [`Error::Grpc`] wrapping a
129    /// `DEADLINE_EXCEEDED` status. For streaming calls like
130    /// [`ArmadaClient::watch`], the deadline covers the **entire stream
131    /// duration** — if it elapses while events are still arriving the stream
132    /// is cancelled immediately.
133    ///
134    /// Returns `self` so the call can be chained directly after construction:
135    ///
136    /// ```no_run
137    /// # use std::time::Duration;
138    /// # use armada_client::{ArmadaClient, StaticTokenProvider};
139    /// # async fn example() -> Result<(), armada_client::Error> {
140    /// let client = ArmadaClient::connect("http://localhost:50051", StaticTokenProvider::new("tok"))
141    ///     .await?
142    ///     .with_timeout(Duration::from_secs(30));
143    /// # Ok(())
144    /// # }
145    /// ```
146    pub fn with_timeout(mut self, timeout: Duration) -> Self {
147        self.timeout = Some(timeout);
148        self
149    }
150
151    fn apply_timeout<T>(&self, req: &mut tonic::Request<T>) {
152        if let Some(t) = self.timeout {
153            req.set_timeout(t);
154        }
155    }
156
157    /// Submit a batch of jobs to Armada.
158    ///
159    /// Attaches an `authorization` header on every call using the configured
160    /// [`TokenProvider`] (e.g. `Bearer <token>` or `Basic <credentials>`). Multiple job items can be included in
161    /// a single request — they are all submitted atomically to the same queue
162    /// and job set.
163    ///
164    /// # Example
165    ///
166    /// ```no_run
167    /// use armada_client::{
168    ///     ArmadaClient, JobRequestItemBuilder, JobSubmitRequest, StaticTokenProvider,
169    /// };
170    /// use armada_client::k8s::io::api::core::v1::PodSpec;
171    ///
172    /// # async fn example() -> Result<(), armada_client::Error> {
173    /// # let client = ArmadaClient::connect("http://localhost:50051", StaticTokenProvider::new("")).await?;
174    /// let item = JobRequestItemBuilder::new()
175    ///     .namespace("default")
176    ///     .pod_spec(PodSpec { containers: vec![], ..Default::default() })
177    ///     .build();
178    ///
179    /// let response = client
180    ///     .submit(JobSubmitRequest {
181    ///         queue: "my-queue".into(),
182    ///         job_set_id: "my-job-set".into(),
183    ///         job_request_items: vec![item],
184    ///     })
185    ///     .await?;
186    ///
187    /// for r in &response.job_response_items {
188    ///     if r.error.is_empty() {
189    ///         println!("submitted: {}", r.job_id);
190    ///     } else {
191    ///         eprintln!("rejected: {}", r.error);
192    ///     }
193    /// }
194    /// # Ok(())
195    /// # }
196    /// ```
197    ///
198    /// # Errors
199    ///
200    /// - [`Error::Auth`] if the token provider fails.
201    /// - [`Error::InvalidMetadata`] if the token contains invalid header characters.
202    /// - [`Error::Grpc`] if the server returns a non-OK status.
203    #[instrument(skip(self, request), fields(queue = %request.queue, job_set_id = %request.job_set_id))]
204    pub async fn submit(&self, request: JobSubmitRequest) -> Result<JobSubmitResponse, Error> {
205        let token = self.token_provider.token().await?;
206        let mut req = tonic::Request::new(request);
207        if !token.is_empty() {
208            req.metadata_mut().insert("authorization", token.parse()?);
209        }
210        self.apply_timeout(&mut req);
211        // `.clone()` on a tonic client is O(1) — it clones an Arc over the
212        // shared channel, not the underlying connection.
213        let resp = self.submit_client.clone().submit_jobs(req).await?;
214        Ok(resp.into_inner())
215    }
216
217    /// Watch a job set, returning a stream of events.
218    ///
219    /// Opens a server-streaming gRPC call and returns a [`BoxStream`] that
220    /// yields [`EventStreamMessage`] values as the server pushes them. The
221    /// stream ends when the server closes the connection.
222    ///
223    /// **Reconnection is the caller's responsibility.** Store the last
224    /// `message_id` you received and pass it back as `from_message_id` when
225    /// reconnecting to avoid replaying events you have already processed.
226    ///
227    /// # Arguments
228    ///
229    /// * `queue` — Armada queue name.
230    /// * `job_set_id` — Job set to watch.
231    /// * `from_message_id` — Optional resume cursor. Pass `Some(id)` to
232    ///   receive only events that occurred after `id`; pass `None` to receive
233    ///   all events from the beginning.
234    ///
235    /// # Example
236    ///
237    /// ```no_run
238    /// use futures::StreamExt;
239    /// use armada_client::{ArmadaClient, StaticTokenProvider};
240    ///
241    /// # async fn example() -> Result<(), armada_client::Error> {
242    /// # let client = ArmadaClient::connect("http://localhost:50051", StaticTokenProvider::new("")).await?;
243    /// let mut stream = client
244    ///     .watch("my-queue", "my-job-set", None)
245    ///     .await?;
246    ///
247    /// let mut last_id = String::new();
248    /// while let Some(result) = stream.next().await {
249    ///     match result {
250    ///         Ok(msg) => {
251    ///             last_id = msg.id.clone();
252    ///             println!("event id={} message={:?}", msg.id, msg.message);
253    ///         }
254    ///         Err(e) => {
255    ///             eprintln!("stream error: {e}");
256    ///             break; // reconnect using `from_message_id: Some(last_id)`
257    ///         }
258    ///     }
259    /// }
260    /// # Ok(())
261    /// # }
262    /// ```
263    ///
264    /// # Errors
265    ///
266    /// - [`Error::Auth`] if the token provider fails.
267    /// - [`Error::InvalidMetadata`] if the token contains invalid header characters.
268    /// - [`Error::Grpc`] if the server returns a non-OK status on the initial call.
269    ///   The stream will not error simply because the job set does not exist yet —
270    ///   it will wait for events, which avoids races when `watch` is called
271    ///   immediately after `submit`.
272    /// - Individual stream items may also be [`Err(Error::Grpc)`] if the server
273    ///   sends a trailing error status.
274    #[instrument(skip_all, fields(queue, job_set_id))]
275    pub async fn watch(
276        &self,
277        queue: impl Into<String>,
278        job_set_id: impl Into<String>,
279        from_message_id: Option<String>,
280    ) -> Result<BoxStream<'static, Result<EventStreamMessage, Error>>, Error> {
281        let queue: String = queue.into();
282        let job_set_id: String = job_set_id.into();
283        tracing::Span::current()
284            .record("queue", queue.as_str())
285            .record("job_set_id", job_set_id.as_str());
286
287        let token = self.token_provider.token().await?;
288        let job_set_request = JobSetRequest {
289            id: job_set_id,
290            queue,
291            from_message_id: from_message_id.unwrap_or_default(),
292            // Keep the stream open for new events.
293            watch: true,
294            // Do not fail immediately if the job set does not exist yet — this
295            // avoids a race between submit() and watch() where the server has
296            // not yet created the job set by the time the watch RPC arrives.
297            error_if_missing: false,
298        };
299        let mut req = tonic::Request::new(job_set_request);
300        if !token.is_empty() {
301            req.metadata_mut().insert("authorization", token.parse()?);
302        }
303        self.apply_timeout(&mut req);
304        // `.clone()` on a tonic client is O(1) — it clones an Arc over the
305        // shared channel, not the underlying connection.
306        let stream = self
307            .event_client
308            .clone()
309            .get_job_set_events(req)
310            .await?
311            .into_inner();
312        Ok(Box::pin(stream.map(|r| r.map_err(Error::from))))
313    }
314}