Skip to main content

cli_engine/transport/
injector.rs

1use std::{collections::BTreeMap, future::Future, pin::Pin, sync::Arc, time::Duration};
2
3use base64::{Engine, engine::general_purpose::STANDARD};
4use reqwest::header::{AUTHORIZATION, COOKIE, HeaderName, HeaderValue};
5use serde::Deserialize;
6use tokio::{
7    sync::Mutex,
8    time::{Instant, timeout},
9};
10
11use crate::{AuthProvider, CliCoreError, Result};
12
13/// Async callback that returns a token for request injection.
14pub type TokenFunc =
15    Arc<dyn Fn() -> Pin<Box<dyn Future<Output = Result<String>> + Send>> + Send + Sync>;
16
17#[async_trait::async_trait]
18/// Mutates an outbound request with authentication material.
19pub trait AuthInjector: Send + Sync + std::fmt::Debug {
20    /// Adds auth headers or cookies to `request`.
21    async fn inject(&self, request: &mut reqwest::Request) -> Result<()>;
22}
23
24/// Injects `Authorization: Bearer <token>`.
25#[derive(Clone)]
26pub struct BearerTokenInjector {
27    token: TokenFunc,
28}
29
30impl std::fmt::Debug for BearerTokenInjector {
31    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
32        f.debug_struct("BearerTokenInjector")
33            .finish_non_exhaustive()
34    }
35}
36
37impl BearerTokenInjector {
38    /// Creates a bearer-token injector from an async token callback.
39    #[must_use]
40    pub fn new(token: TokenFunc) -> Self {
41        Self { token }
42    }
43}
44
45#[async_trait::async_trait]
46impl AuthInjector for BearerTokenInjector {
47    async fn inject(&self, request: &mut reqwest::Request) -> Result<()> {
48        let token = (self.token)()
49            .await
50            .map_err(|err| CliCoreError::message(format!("transport: bearer inject: {err}")))?;
51        set_header(request, AUTHORIZATION, &format!("Bearer {token}"))
52    }
53}
54
55/// Appends a named token cookie to the request.
56#[derive(Clone)]
57pub struct CookieInjector {
58    name: String,
59    token: TokenFunc,
60}
61
62impl std::fmt::Debug for CookieInjector {
63    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
64        f.debug_struct("CookieInjector")
65            .field("name", &self.name)
66            .finish_non_exhaustive()
67    }
68}
69
70impl CookieInjector {
71    /// Creates a cookie injector from a cookie name and async token callback.
72    #[must_use]
73    pub fn new(name: impl Into<String>, token: TokenFunc) -> Self {
74        Self {
75            name: name.into(),
76            token,
77        }
78    }
79}
80
81#[async_trait::async_trait]
82impl AuthInjector for CookieInjector {
83    async fn inject(&self, request: &mut reqwest::Request) -> Result<()> {
84        let token = (self.token)()
85            .await
86            .map_err(|err| CliCoreError::message(format!("transport: cookie inject: {err}")))?;
87        let cookie = format!("{}={}", self.name, token);
88        append_cookie(request, &cookie)
89    }
90}
91
92/// Injects HTTP basic auth.
93#[derive(Clone, Debug)]
94pub struct BasicAuthInjector {
95    username: String,
96    password: String,
97}
98
99impl BasicAuthInjector {
100    /// Creates a basic-auth injector.
101    #[must_use]
102    pub fn new(username: impl Into<String>, password: impl Into<String>) -> Self {
103        Self {
104            username: username.into(),
105            password: password.into(),
106        }
107    }
108}
109
110#[async_trait::async_trait]
111impl AuthInjector for BasicAuthInjector {
112    async fn inject(&self, request: &mut reqwest::Request) -> Result<()> {
113        let encoded = STANDARD.encode(format!("{}:{}", self.username, self.password));
114        set_header(request, AUTHORIZATION, &format!("Basic {encoded}"))
115    }
116}
117
118/// Injects an `x-api-key` header.
119#[derive(Clone, Debug)]
120pub struct ApiKeyInjector {
121    key: String,
122}
123
124impl ApiKeyInjector {
125    /// Creates an API-key injector.
126    #[must_use]
127    pub fn new(key: impl Into<String>) -> Self {
128        Self { key: key.into() }
129    }
130}
131
132#[async_trait::async_trait]
133impl AuthInjector for ApiKeyInjector {
134    async fn inject(&self, request: &mut reqwest::Request) -> Result<()> {
135        set_header(request, HeaderName::from_static("x-api-key"), &self.key)
136    }
137}
138
139/// Auth injector that leaves requests unchanged.
140#[derive(Clone, Copy, Debug, Default)]
141pub struct NoopInjector;
142
143#[async_trait::async_trait]
144impl AuthInjector for NoopInjector {
145    async fn inject(&self, _request: &mut reqwest::Request) -> Result<()> {
146        Ok(())
147    }
148}
149
150/// Resolves a credential from an auth provider and injects its token as bearer auth.
151#[derive(Clone, Debug)]
152pub struct ProviderBearerInjector {
153    provider: Arc<dyn AuthProvider>,
154    env: String,
155    token: Arc<Mutex<Option<String>>>,
156}
157
158impl ProviderBearerInjector {
159    /// Creates a provider-backed bearer injector for one environment.
160    #[must_use]
161    pub fn new(provider: Arc<dyn AuthProvider>, env: impl Into<String>) -> Self {
162        Self {
163            provider,
164            env: env.into(),
165            token: Arc::new(Mutex::new(None)),
166        }
167    }
168}
169
170#[async_trait::async_trait]
171impl AuthInjector for ProviderBearerInjector {
172    async fn inject(&self, request: &mut reqwest::Request) -> Result<()> {
173        let mut cached = self.token.lock().await;
174        if cached.as_deref().is_none_or(str::is_empty) {
175            // Scope-unaware on purpose: this fetches whatever token the provider
176            // has for `env` (no command scopes) and caches it for the injector's
177            // lifetime. A handler needing OAuth scope step-up over HTTP must
178            // resolve the wider scopes first (CredentialResolver::resolve_with_scopes),
179            // which populates the provider cache so the token fetched here already
180            // covers them; resolving after the first inject would send the
181            // narrower token.
182            let credential = self
183                .provider
184                .get_credential(&self.env, "", "")
185                .await
186                .map_err(|err| {
187                    CliCoreError::message(format!("transport: provider bearer: {err}"))
188                })?;
189            *cached = Some(credential.token);
190        }
191        let Some(token) = cached.as_ref() else {
192            return Err(CliCoreError::message(
193                "transport: provider bearer: empty token cache",
194            ));
195        };
196        set_header(request, AUTHORIZATION, &format!("Bearer {token}"))
197    }
198}
199
200/// OAuth2 client-credentials injector with in-memory token caching.
201#[derive(Clone, Debug)]
202pub struct ClientCredentialsInjector {
203    token_url: String,
204    client_id: String,
205    client_secret: String,
206    scopes: String,
207    client: reqwest::Client,
208    token: Arc<Mutex<Option<CachedToken>>>,
209}
210
211#[derive(Clone, Debug)]
212struct CachedToken {
213    token: String,
214    expiry: Instant,
215}
216
217impl ClientCredentialsInjector {
218    /// Creates a client-credentials injector.
219    #[must_use]
220    pub fn new(
221        token_url: impl Into<String>,
222        client_id: impl Into<String>,
223        client_secret: impl Into<String>,
224        scopes: impl Into<String>,
225    ) -> Self {
226        Self {
227            token_url: token_url.into(),
228            client_id: client_id.into(),
229            client_secret: client_secret.into(),
230            scopes: scopes.into(),
231            client: reqwest::Client::new(),
232            token: Arc::new(Mutex::new(None)),
233        }
234    }
235
236    async fn get_token(&self) -> Result<String> {
237        let mut cached = self.token.lock().await;
238        if let Some(token) = cached.as_ref()
239            && !token.token.is_empty()
240            && Instant::now() < token.expiry
241        {
242            return Ok(token.token.clone());
243        }
244
245        let mut form = BTreeMap::from([
246            ("grant_type", "client_credentials"),
247            ("client_id", self.client_id.as_str()),
248            ("client_secret", self.client_secret.as_str()),
249        ]);
250        if !self.scopes.is_empty() {
251            form.insert("scope", self.scopes.as_str());
252        }
253
254        let response = timeout(
255            Duration::from_secs(30),
256            self.client
257                .post(&self.token_url)
258                .header(
259                    reqwest::header::CONTENT_TYPE,
260                    "application/x-www-form-urlencoded",
261                )
262                .header(
263                    reqwest::header::USER_AGENT,
264                    crate::transport::client::default_user_agent(),
265                )
266                .form(&form)
267                .send(),
268        )
269        .await
270        .map_err(|_| CliCoreError::message("token request: timed out"))?
271        .map_err(|err| CliCoreError::message(format!("token request: {err}")))?;
272
273        if response.status() != reqwest::StatusCode::OK {
274            return Err(CliCoreError::message(format!(
275                "token request: status {}",
276                response.status().as_u16()
277            )));
278        }
279
280        #[derive(Deserialize)]
281        struct TokenResponse {
282            #[serde(default)]
283            access_token: String,
284            #[serde(default)]
285            expires_in: i64,
286        }
287
288        let token_response = response
289            .json::<TokenResponse>()
290            .await
291            .map_err(|err| CliCoreError::message(format!("decode token response: {err}")))?;
292
293        let expiry = if token_response.expires_in > 30 {
294            Instant::now() + Duration::from_secs((token_response.expires_in - 30) as u64)
295        } else {
296            Instant::now()
297        };
298        *cached = Some(CachedToken {
299            token: token_response.access_token.clone(),
300            expiry,
301        });
302        Ok(token_response.access_token)
303    }
304}
305
306#[async_trait::async_trait]
307impl AuthInjector for ClientCredentialsInjector {
308    async fn inject(&self, request: &mut reqwest::Request) -> Result<()> {
309        let token = self.get_token().await.map_err(|err| {
310            CliCoreError::message(format!("transport: client credentials inject: {err}"))
311        })?;
312        set_header(request, AUTHORIZATION, &format!("Bearer {token}"))
313    }
314}
315
316fn set_header(request: &mut reqwest::Request, name: HeaderName, value: &str) -> Result<()> {
317    let value = HeaderValue::from_str(value)
318        .map_err(|err| CliCoreError::message(format!("transport: invalid header value: {err}")))?;
319    request.headers_mut().insert(name, value);
320    Ok(())
321}
322
323fn append_cookie(request: &mut reqwest::Request, cookie: &str) -> Result<()> {
324    let value = match request.headers().get(COOKIE) {
325        Some(existing) => {
326            let existing = existing.to_str().map_err(|err| {
327                CliCoreError::message(format!("transport: invalid header value: {err}"))
328            })?;
329            format!("{existing}; {cookie}")
330        }
331        None => cookie.to_owned(),
332    };
333    set_header(request, COOKIE, &value)
334}