Skip to main content

cli_engine/transport/
injector.rs

1use std::{collections::BTreeMap, future::Future, pin::Pin, sync::Arc, time::Duration};
2
3use base64::{Engine, engine::general_purpose::STANDARD};
4use reqwest::header::{AUTHORIZATION, COOKIE, HeaderName, HeaderValue};
5use serde::Deserialize;
6use tokio::{
7    sync::Mutex,
8    time::{Instant, timeout},
9};
10
11use crate::{AuthProvider, CliCoreError, Result};
12
13/// Async callback that returns a token for request injection.
14pub type TokenFunc =
15    Arc<dyn Fn() -> Pin<Box<dyn Future<Output = Result<String>> + Send>> + Send + Sync>;
16
17#[async_trait::async_trait]
18/// Mutates an outbound request with authentication material.
19pub trait AuthInjector: Send + Sync + std::fmt::Debug {
20    /// Adds auth headers or cookies to `request`.
21    async fn inject(&self, request: &mut reqwest::Request) -> Result<()>;
22}
23
24/// Injects `Authorization: Bearer <token>`.
25#[derive(Clone)]
26pub struct BearerTokenInjector {
27    token: TokenFunc,
28}
29
30impl std::fmt::Debug for BearerTokenInjector {
31    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
32        f.debug_struct("BearerTokenInjector")
33            .finish_non_exhaustive()
34    }
35}
36
37impl BearerTokenInjector {
38    /// Creates a bearer-token injector from an async token callback.
39    #[must_use]
40    pub fn new(token: TokenFunc) -> Self {
41        Self { token }
42    }
43}
44
45#[async_trait::async_trait]
46impl AuthInjector for BearerTokenInjector {
47    async fn inject(&self, request: &mut reqwest::Request) -> Result<()> {
48        let token = (self.token)()
49            .await
50            .map_err(|err| CliCoreError::message(format!("transport: bearer inject: {err}")))?;
51        set_header(request, AUTHORIZATION, &format!("Bearer {token}"))
52    }
53}
54
55/// Appends a named token cookie to the request.
56#[derive(Clone)]
57pub struct CookieInjector {
58    name: String,
59    token: TokenFunc,
60}
61
62impl std::fmt::Debug for CookieInjector {
63    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
64        f.debug_struct("CookieInjector")
65            .field("name", &self.name)
66            .finish_non_exhaustive()
67    }
68}
69
70impl CookieInjector {
71    /// Creates a cookie injector from a cookie name and async token callback.
72    #[must_use]
73    pub fn new(name: impl Into<String>, token: TokenFunc) -> Self {
74        Self {
75            name: name.into(),
76            token,
77        }
78    }
79}
80
81#[async_trait::async_trait]
82impl AuthInjector for CookieInjector {
83    async fn inject(&self, request: &mut reqwest::Request) -> Result<()> {
84        let token = (self.token)()
85            .await
86            .map_err(|err| CliCoreError::message(format!("transport: cookie inject: {err}")))?;
87        let cookie = format!("{}={}", self.name, token);
88        append_cookie(request, &cookie)
89    }
90}
91
92/// Injects HTTP basic auth.
93#[derive(Clone, Debug)]
94pub struct BasicAuthInjector {
95    username: String,
96    password: String,
97}
98
99impl BasicAuthInjector {
100    /// Creates a basic-auth injector.
101    #[must_use]
102    pub fn new(username: impl Into<String>, password: impl Into<String>) -> Self {
103        Self {
104            username: username.into(),
105            password: password.into(),
106        }
107    }
108}
109
110#[async_trait::async_trait]
111impl AuthInjector for BasicAuthInjector {
112    async fn inject(&self, request: &mut reqwest::Request) -> Result<()> {
113        let encoded = STANDARD.encode(format!("{}:{}", self.username, self.password));
114        set_header(request, AUTHORIZATION, &format!("Basic {encoded}"))
115    }
116}
117
118/// Injects an `x-api-key` header.
119#[derive(Clone, Debug)]
120pub struct ApiKeyInjector {
121    key: String,
122}
123
124impl ApiKeyInjector {
125    /// Creates an API-key injector.
126    #[must_use]
127    pub fn new(key: impl Into<String>) -> Self {
128        Self { key: key.into() }
129    }
130}
131
132#[async_trait::async_trait]
133impl AuthInjector for ApiKeyInjector {
134    async fn inject(&self, request: &mut reqwest::Request) -> Result<()> {
135        set_header(request, HeaderName::from_static("x-api-key"), &self.key)
136    }
137}
138
139/// Auth injector that leaves requests unchanged.
140#[derive(Clone, Copy, Debug, Default)]
141pub struct NoopInjector;
142
143#[async_trait::async_trait]
144impl AuthInjector for NoopInjector {
145    async fn inject(&self, _request: &mut reqwest::Request) -> Result<()> {
146        Ok(())
147    }
148}
149
150/// Resolves a credential from an auth provider and injects its token as bearer auth.
151#[derive(Clone, Debug)]
152pub struct ProviderBearerInjector {
153    provider: Arc<dyn AuthProvider>,
154    env: String,
155    token: Arc<Mutex<Option<String>>>,
156}
157
158impl ProviderBearerInjector {
159    /// Creates a provider-backed bearer injector for one environment.
160    #[must_use]
161    pub fn new(provider: Arc<dyn AuthProvider>, env: impl Into<String>) -> Self {
162        Self {
163            provider,
164            env: env.into(),
165            token: Arc::new(Mutex::new(None)),
166        }
167    }
168}
169
170#[async_trait::async_trait]
171impl AuthInjector for ProviderBearerInjector {
172    async fn inject(&self, request: &mut reqwest::Request) -> Result<()> {
173        let mut cached = self.token.lock().await;
174        if cached.as_deref().is_none_or(str::is_empty) {
175            let credential = self
176                .provider
177                .get_credential(&self.env, "", "")
178                .await
179                .map_err(|err| {
180                    CliCoreError::message(format!("transport: provider bearer: {err}"))
181                })?;
182            *cached = Some(credential.token);
183        }
184        let Some(token) = cached.as_ref() else {
185            return Err(CliCoreError::message(
186                "transport: provider bearer: empty token cache",
187            ));
188        };
189        set_header(request, AUTHORIZATION, &format!("Bearer {token}"))
190    }
191}
192
193/// OAuth2 client-credentials injector with in-memory token caching.
194#[derive(Clone, Debug)]
195pub struct ClientCredentialsInjector {
196    token_url: String,
197    client_id: String,
198    client_secret: String,
199    scopes: String,
200    client: reqwest::Client,
201    token: Arc<Mutex<Option<CachedToken>>>,
202}
203
204#[derive(Clone, Debug)]
205struct CachedToken {
206    token: String,
207    expiry: Instant,
208}
209
210impl ClientCredentialsInjector {
211    /// Creates a client-credentials injector.
212    #[must_use]
213    pub fn new(
214        token_url: impl Into<String>,
215        client_id: impl Into<String>,
216        client_secret: impl Into<String>,
217        scopes: impl Into<String>,
218    ) -> Self {
219        Self {
220            token_url: token_url.into(),
221            client_id: client_id.into(),
222            client_secret: client_secret.into(),
223            scopes: scopes.into(),
224            client: reqwest::Client::new(),
225            token: Arc::new(Mutex::new(None)),
226        }
227    }
228
229    async fn get_token(&self) -> Result<String> {
230        let mut cached = self.token.lock().await;
231        if let Some(token) = cached.as_ref()
232            && !token.token.is_empty()
233            && Instant::now() < token.expiry
234        {
235            return Ok(token.token.clone());
236        }
237
238        let mut form = BTreeMap::from([
239            ("grant_type", "client_credentials"),
240            ("client_id", self.client_id.as_str()),
241            ("client_secret", self.client_secret.as_str()),
242        ]);
243        if !self.scopes.is_empty() {
244            form.insert("scope", self.scopes.as_str());
245        }
246
247        let response = timeout(
248            Duration::from_secs(30),
249            self.client
250                .post(&self.token_url)
251                .header(
252                    reqwest::header::CONTENT_TYPE,
253                    "application/x-www-form-urlencoded",
254                )
255                .form(&form)
256                .send(),
257        )
258        .await
259        .map_err(|_| CliCoreError::message("token request: timed out"))?
260        .map_err(|err| CliCoreError::message(format!("token request: {err}")))?;
261
262        if response.status() != reqwest::StatusCode::OK {
263            return Err(CliCoreError::message(format!(
264                "token request: status {}",
265                response.status().as_u16()
266            )));
267        }
268
269        #[derive(Deserialize)]
270        struct TokenResponse {
271            #[serde(default)]
272            access_token: String,
273            #[serde(default)]
274            expires_in: i64,
275        }
276
277        let token_response = response
278            .json::<TokenResponse>()
279            .await
280            .map_err(|err| CliCoreError::message(format!("decode token response: {err}")))?;
281
282        let expiry = if token_response.expires_in > 30 {
283            Instant::now() + Duration::from_secs((token_response.expires_in - 30) as u64)
284        } else {
285            Instant::now()
286        };
287        *cached = Some(CachedToken {
288            token: token_response.access_token.clone(),
289            expiry,
290        });
291        Ok(token_response.access_token)
292    }
293}
294
295#[async_trait::async_trait]
296impl AuthInjector for ClientCredentialsInjector {
297    async fn inject(&self, request: &mut reqwest::Request) -> Result<()> {
298        let token = self.get_token().await.map_err(|err| {
299            CliCoreError::message(format!("transport: client credentials inject: {err}"))
300        })?;
301        set_header(request, AUTHORIZATION, &format!("Bearer {token}"))
302    }
303}
304
305fn set_header(request: &mut reqwest::Request, name: HeaderName, value: &str) -> Result<()> {
306    let value = HeaderValue::from_str(value)
307        .map_err(|err| CliCoreError::message(format!("transport: invalid header value: {err}")))?;
308    request.headers_mut().insert(name, value);
309    Ok(())
310}
311
312fn append_cookie(request: &mut reqwest::Request, cookie: &str) -> Result<()> {
313    let value = match request.headers().get(COOKIE) {
314        Some(existing) => {
315            let existing = existing.to_str().map_err(|err| {
316                CliCoreError::message(format!("transport: invalid header value: {err}"))
317            })?;
318            format!("{existing}; {cookie}")
319        }
320        None => cookie.to_owned(),
321    };
322    set_header(request, COOKIE, &value)
323}