Skip to main content

armada_client/
auth.rs

1use futures::future::BoxFuture;
2
3use crate::error::Error;
4
5/// Provides an `Authorization` header value for each outgoing gRPC call.
6///
7/// Implement this trait to integrate with any auth source — static bearer
8/// token, HTTP Basic Auth, OAuth2, OIDC, service-account token rotation, etc.
9///
10/// # Contract
11///
12/// [`token`](TokenProvider::token) must return the **complete** value to be
13/// set in the `authorization` gRPC metadata field, including the scheme
14/// prefix:
15///
16/// - Bearer auth: `"Bearer <token>"`
17/// - Basic auth: `"Basic <base64(user:pass)>"`
18/// - No auth (unauthenticated clusters): `""` — the client omits the header.
19///
20/// # Custom implementation
21///
22/// The future must be `Send` because the client may call it from any async
23/// task. Use `Box::pin(async move { … })` to construct the return value:
24///
25/// ```no_run
26/// use futures::future::BoxFuture;
27/// use armada_client::{Error, TokenProvider};
28///
29/// struct MyTokenProvider;
30///
31/// impl TokenProvider for MyTokenProvider {
32///     fn token(&self) -> BoxFuture<'_, Result<String, Error>> {
33///         Box::pin(async move {
34///             // Fetch or refresh from your auth backend here.
35///             Ok("Bearer my-dynamic-token".to_string())
36///         })
37///     }
38/// }
39/// ```
40///
41/// Return [`Error::auth`] to signal that token retrieval failed:
42///
43/// ```no_run
44/// # use futures::future::BoxFuture;
45/// # use armada_client::{Error, TokenProvider};
46/// # struct Failing;
47/// # impl TokenProvider for Failing {
48/// #     fn token(&self) -> BoxFuture<'_, Result<String, Error>> {
49///         Box::pin(async move {
50///             Err(Error::auth("token expired"))
51///         })
52/// #     }
53/// # }
54/// ```
55pub trait TokenProvider: Send + Sync {
56    /// Retrieve the current `Authorization` header value asynchronously.
57    ///
58    /// Return the full scheme-prefixed value (e.g. `"Bearer <token>"`) or an
59    /// empty string to send no `Authorization` header. The client calls this
60    /// before every RPC; implementations that cache tokens should handle expiry
61    /// and refresh internally.
62    fn token(&self) -> BoxFuture<'_, Result<String, Error>>;
63}
64
/// A [`TokenProvider`] that always returns the same static bearer token.
///
/// Suitable for development, testing, or clusters where a single long-lived
/// token is acceptable. For production workloads with token rotation, implement
/// [`TokenProvider`] directly.
///
/// # Debug output
///
/// `StaticTokenProvider` implements [`Debug`] but redacts the token value so
/// that secrets are not accidentally leaked into logs:
///
/// ```
/// use armada_client::StaticTokenProvider;
///
/// let p = StaticTokenProvider::new("super-secret");
/// assert_eq!(format!("{p:?}"), "StaticTokenProvider { token: \"[redacted]\" }");
/// ```
///
/// [`Debug`]: std::fmt::Debug
pub struct StaticTokenProvider {
    /// Complete `authorization` value: `"Bearer <token>"`, or `""` when no
    /// token was supplied.
    token: String,
}

impl std::fmt::Debug for StaticTokenProvider {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("StaticTokenProvider")
            .field("token", &"[redacted]")
            .finish()
    }
}

impl StaticTokenProvider {
    /// Create a new `StaticTokenProvider` from a bearer token string.
    ///
    /// Pass the raw token value; the `"Bearer "` scheme prefix is added
    /// automatically. The input is normalised first, so common copy/paste and
    /// file-read artefacts do not produce a malformed header:
    ///
    /// - leading and trailing whitespace is trimmed (including the trailing
    ///   newline left behind when a token is read from a file);
    /// - an existing `Bearer` scheme (any ASCII case, followed by any
    ///   whitespace) is stripped, so the header is never double-prefixed.
    ///
    /// If nothing remains after normalisation (`""`, whitespace only, or a bare
    /// `"Bearer "`), the provider returns an empty string, which tells the
    /// client to send no `authorization` header. Use this for unauthenticated
    /// clusters:
    ///
    /// ```
    /// use armada_client::StaticTokenProvider;
    ///
    /// let provider = StaticTokenProvider::new("my-bearer-token");
    /// let same     = StaticTokenProvider::new("Bearer my-bearer-token"); // identical result
    /// let also     = StaticTokenProvider::new("bearer my-bearer-token"); // also identical
    /// let file     = StaticTokenProvider::new("my-bearer-token\n");      // newline trimmed
    /// let empty    = StaticTokenProvider::new("");   // unauthenticated
    /// ```
    pub fn new(token: impl Into<String>) -> Self {
        let token = token.into();
        // Surrounding whitespace is never part of a credential, and a CR/LF is
        // not a legal HTTP/2 header value, so trim before anything else.
        let trimmed = token.trim();
        // Strip a pre-existing "Bearer" scheme (case-insensitive, any
        // separating whitespace) so callers who paste a full header value don't
        // end up with "Bearer Bearer <token>".
        let raw = match trimmed.split_once(char::is_whitespace) {
            Some((scheme, rest)) if scheme.eq_ignore_ascii_case("bearer") => rest.trim_start(),
            // A bare scheme with no credential after it (e.g. "Bearer ").
            _ if trimmed.eq_ignore_ascii_case("bearer") => "",
            _ => trimmed,
        };
        let header = if raw.is_empty() {
            String::new()
        } else {
            format!("Bearer {raw}")
        };
        Self { token: header }
    }
}
135
136impl TokenProvider for Box<dyn TokenProvider + Send + Sync> {
137    fn token(&self) -> BoxFuture<'_, Result<String, Error>> {
138        (**self).token()
139    }
140}
141
142impl TokenProvider for StaticTokenProvider {
143    fn token(&self) -> BoxFuture<'_, Result<String, Error>> {
144        let token = self.token.clone();
145        Box::pin(async move { Ok(token) })
146    }
147}
148
149/// A [`TokenProvider`] that authenticates using HTTP Basic Auth.
150///
151/// Encodes `username:password` as Base64 and returns the full
152/// `Basic <credentials>` header value. Suitable for Armada clusters
153/// configured with `basicAuth.enableAuthentication: true`.
154///
155/// # Example
156///
157/// ```
158/// use armada_client::BasicAuthProvider;
159///
160/// let provider = BasicAuthProvider::new("admin", "admin");
161/// ```
162pub struct BasicAuthProvider {
163    header: String,
164}
165
166impl BasicAuthProvider {
167    /// Create a new `BasicAuthProvider` from a username and password.
168    pub fn new(username: impl AsRef<str>, password: impl AsRef<str>) -> Self {
169        use base64::Engine as _;
170        let raw = format!("{}:{}", username.as_ref(), password.as_ref());
171        let encoded = base64::engine::general_purpose::STANDARD.encode(raw);
172        Self {
173            header: format!("Basic {encoded}"),
174        }
175    }
176}
177
178impl std::fmt::Debug for BasicAuthProvider {
179    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
180        f.debug_struct("BasicAuthProvider")
181            .field("header", &"[redacted]")
182            .finish()
183    }
184}
185
186impl TokenProvider for BasicAuthProvider {
187    fn token(&self) -> BoxFuture<'_, Result<String, Error>> {
188        let header = self.header.clone();
189        Box::pin(async move { Ok(header) })
190    }
191}
192
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn static_provider_debug_redacts_token() {
        let p = StaticTokenProvider::new("super-secret");
        assert_eq!(
            format!("{p:?}"),
            "StaticTokenProvider { token: \"[redacted]\" }"
        );
    }

    #[tokio::test]
    async fn static_provider_returns_bearer_header() {
        let provider = StaticTokenProvider::new("tok");
        assert_eq!(provider.token().await.unwrap(), "Bearer tok");
    }

    #[tokio::test]
    async fn static_provider_strips_bearer_prefix_case_insensitive() {
        for prefix in &["Bearer ", "bearer ", "BEARER "] {
            let provider = StaticTokenProvider::new(format!("{prefix}tok"));
            assert_eq!(
                provider.token().await.unwrap(),
                "Bearer tok",
                "failed for prefix {prefix:?}"
            );
        }
    }

    #[tokio::test]
    async fn static_provider_empty_token_returns_empty() {
        let provider = StaticTokenProvider::new("");
        assert_eq!(provider.token().await.unwrap(), "");
    }

    #[tokio::test]
    async fn static_provider_bare_bearer_prefix_returns_empty() {
        // A header value with the scheme but no credential means "no auth",
        // not "Bearer Bearer".
        let provider = StaticTokenProvider::new("Bearer ");
        assert_eq!(provider.token().await.unwrap(), "");
    }

    #[tokio::test]
    async fn static_provider_multibyte_input_does_not_panic() {
        // Byte index 7 falls inside the two-byte 'ä', so the prefix check must
        // not slice through a char boundary.
        let provider = StaticTokenProvider::new("Bearerä");
        assert_eq!(provider.token().await.unwrap(), "Bearer Bearerä");
    }

    #[tokio::test]
    async fn boxed_provider_forwards_to_inner() {
        let boxed: Box<dyn TokenProvider + Send + Sync> =
            Box::new(StaticTokenProvider::new("tok"));
        assert_eq!(boxed.token().await.unwrap(), "Bearer tok");
    }

    #[test]
    fn basic_provider_debug_redacts_header() {
        // Base64 credentials are trivially reversible, so they must never
        // appear in Debug output.
        let p = BasicAuthProvider::new("admin", "admin");
        assert_eq!(
            format!("{p:?}"),
            "BasicAuthProvider { header: \"[redacted]\" }"
        );
    }

    #[tokio::test]
    async fn basic_provider_returns_basic_header() {
        let provider = BasicAuthProvider::new("admin", "admin");
        let result = provider.token().await.unwrap();
        // base64("admin:admin") = "YWRtaW46YWRtaW4="
        assert_eq!(result, "Basic YWRtaW46YWRtaW4=");
    }

    #[tokio::test]
    async fn basic_provider_empty_credentials() {
        let provider = BasicAuthProvider::new("", "");
        // base64(":") = "Og=="
        assert_eq!(provider.token().await.unwrap(), "Basic Og==");
    }
}