Skip to main content

drasi_bootstrap_http/
auth.rs

// Copyright 2025 The Drasi Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

//! Authentication strategies for HTTP bootstrap requests.

17use anyhow::{Context, Result};
18use reqwest::header::{HeaderMap, HeaderName, HeaderValue};
19use reqwest::Client;
20use serde::Deserialize;
21use std::sync::Arc;
22use std::time::{Duration, Instant};
23use tokio::sync::RwLock;
24
25use crate::config::{ApiKeyLocation, AuthConfig};
26
/// Return the greatest byte index `<= max` that falls on a UTF-8 character
/// boundary of `s`, so `&s[..index]` never splits a multi-byte character.
///
/// If `max` is at or beyond the end of `s`, `s.len()` is returned. Index 0 is
/// always a boundary, so the result is always a valid slice end.
fn find_char_boundary(s: &str, max: usize) -> usize {
    if max >= s.len() {
        return s.len();
    }
    // Walk backwards from `max` to the nearest boundary.
    (0..=max)
        .rev()
        .find(|&idx| s.is_char_boundary(idx))
        .unwrap_or(0)
}
38
39/// Resolved authentication that can be applied to requests.
40pub enum ResolvedAuth {
41    Bearer {
42        token: String,
43    },
44    ApiKeyHeader {
45        name: String,
46        value: String,
47    },
48    ApiKeyQuery {
49        name: String,
50        value: String,
51    },
52    Basic {
53        username: String,
54        password: String,
55    },
56    OAuth2 {
57        token_provider: Arc<OAuth2TokenProvider>,
58    },
59}
60
61/// OAuth2 token provider with caching.
62pub struct OAuth2TokenProvider {
63    token_url: String,
64    client_id: String,
65    client_secret: String,
66    scopes: Vec<String>,
67    client: Client,
68    cached_token: RwLock<Option<CachedToken>>,
69}
70
/// An access token together with the instant after which it must be refetched.
#[derive(Clone)]
struct CachedToken {
    /// Raw access token returned by the token endpoint.
    access_token: String,
    /// Point at which the cached token is considered stale.
    expires_at: Instant,
}
76
77#[derive(Deserialize)]
78struct OAuth2TokenResponse {
79    access_token: String,
80    #[serde(default)]
81    expires_in: Option<u64>,
82    #[allow(dead_code)]
83    #[serde(default)]
84    token_type: Option<String>,
85}
86
87impl OAuth2TokenProvider {
88    pub fn new(
89        token_url: String,
90        client_id: String,
91        client_secret: String,
92        scopes: Vec<String>,
93        client: Client,
94    ) -> Self {
95        Self {
96            token_url,
97            client_id,
98            client_secret,
99            scopes,
100            client,
101            cached_token: RwLock::new(None),
102        }
103    }
104
105    /// Get a valid access token, refreshing if expired.
106    pub async fn get_token(&self) -> Result<String> {
107        // Check cache first under read lock
108        {
109            let cache = self.cached_token.read().await;
110            if let Some(ref cached) = *cache {
111                if Instant::now() < cached.expires_at {
112                    return Ok(cached.access_token.clone());
113                }
114            }
115        }
116
117        // Acquire write lock and re-check to avoid stampede
118        let mut cache = self.cached_token.write().await;
119        if let Some(ref cached) = *cache {
120            if Instant::now() < cached.expires_at {
121                return Ok(cached.access_token.clone());
122            }
123        }
124
125        // Token expired or not cached, fetch new one
126        let token = self.fetch_token().await?;
127        let access_token = token.access_token.clone();
128        *cache = Some(token);
129
130        Ok(access_token)
131    }
132
133    async fn fetch_token(&self) -> Result<CachedToken> {
134        let mut form = vec![
135            ("grant_type", "client_credentials".to_string()),
136            ("client_id", self.client_id.clone()),
137            ("client_secret", self.client_secret.clone()),
138        ];
139
140        if !self.scopes.is_empty() {
141            form.push(("scope", self.scopes.join(" ")));
142        }
143
144        let response = self
145            .client
146            .post(&self.token_url)
147            .form(&form)
148            .send()
149            .await
150            .context("Failed to request OAuth2 token")?;
151
152        if !response.status().is_success() {
153            let status = response.status();
154            let body = response
155                .text()
156                .await
157                .unwrap_or_else(|_| "Unable to read response".to_string());
158            let truncated = if body.len() > 256 {
159                let end = find_char_boundary(&body, 256);
160                format!("{}... (truncated)", &body[..end])
161            } else {
162                body
163            };
164            return Err(anyhow::anyhow!(
165                "OAuth2 token request failed with status {status}: {truncated}"
166            ));
167        }
168
169        let token_response: OAuth2TokenResponse = response
170            .json()
171            .await
172            .context("Failed to parse OAuth2 token response")?;
173
174        // Default to 1 hour expiry with 60-second safety margin
175        let expires_in = token_response.expires_in.unwrap_or(3600);
176        let expires_at = Instant::now() + Duration::from_secs(expires_in.saturating_sub(60));
177
178        Ok(CachedToken {
179            access_token: token_response.access_token,
180            expires_at,
181        })
182    }
183}
184
185/// Resolve an AuthConfig into a ResolvedAuth by reading environment variables.
186pub fn resolve_auth(config: &AuthConfig, client: &Client) -> Result<ResolvedAuth> {
187    match config {
188        AuthConfig::Bearer { token_env } => {
189            let token = std::env::var(token_env)
190                .with_context(|| format!("Environment variable '{token_env}' not set"))?;
191            Ok(ResolvedAuth::Bearer { token })
192        }
193        AuthConfig::ApiKey {
194            location,
195            name,
196            value_env,
197        } => {
198            let value = std::env::var(value_env)
199                .with_context(|| format!("Environment variable '{value_env}' not set"))?;
200            match location {
201                ApiKeyLocation::Header => Ok(ResolvedAuth::ApiKeyHeader {
202                    name: name.clone(),
203                    value,
204                }),
205                ApiKeyLocation::Query => Ok(ResolvedAuth::ApiKeyQuery {
206                    name: name.clone(),
207                    value,
208                }),
209            }
210        }
211        AuthConfig::Basic {
212            username_env,
213            password_env,
214        } => {
215            let username = std::env::var(username_env)
216                .with_context(|| format!("Environment variable '{username_env}' not set"))?;
217            let password = match password_env {
218                Some(env) => std::env::var(env)
219                    .with_context(|| format!("Environment variable '{env}' not set"))?,
220                None => String::new(),
221            };
222            Ok(ResolvedAuth::Basic { username, password })
223        }
224        AuthConfig::OAuth2ClientCredentials {
225            token_url,
226            client_id_env,
227            client_secret_env,
228            scopes,
229        } => {
230            let client_id = std::env::var(client_id_env)
231                .with_context(|| format!("Environment variable '{client_id_env}' not set"))?;
232            let client_secret = std::env::var(client_secret_env)
233                .with_context(|| format!("Environment variable '{client_secret_env}' not set"))?;
234
235            let provider = OAuth2TokenProvider::new(
236                token_url.clone(),
237                client_id,
238                client_secret,
239                scopes.clone(),
240                client.clone(),
241            );
242
243            Ok(ResolvedAuth::OAuth2 {
244                token_provider: Arc::new(provider),
245            })
246        }
247    }
248}
249
250/// Apply resolved authentication to a request builder.
251pub async fn apply_auth(
252    builder: reqwest::RequestBuilder,
253    auth: &ResolvedAuth,
254) -> Result<reqwest::RequestBuilder> {
255    match auth {
256        ResolvedAuth::Bearer { token } => Ok(builder.bearer_auth(token)),
257        ResolvedAuth::ApiKeyHeader { name, value } => {
258            let mut headers = HeaderMap::new();
259            let header_name = HeaderName::try_from(name.as_str())
260                .with_context(|| format!("Invalid header name: {name}"))?;
261            let header_value = HeaderValue::from_str(value)
262                .with_context(|| format!("Invalid header value for {name}"))?;
263            headers.insert(header_name, header_value);
264            Ok(builder.headers(headers))
265        }
266        ResolvedAuth::ApiKeyQuery { name, value } => Ok(builder.query(&[(name, value)])),
267        ResolvedAuth::Basic { username, password } => {
268            Ok(builder.basic_auth(username, Some(password)))
269        }
270        ResolvedAuth::OAuth2 { token_provider } => {
271            let token = token_provider
272                .get_token()
273                .await
274                .context("Failed to get OAuth2 token")?;
275            Ok(builder.bearer_auth(token))
276        }
277    }
278}
279
#[cfg(test)]
mod tests {
    use super::*;
    use axum::response::IntoResponse;
    use std::sync::atomic::{AtomicU64, Ordering};

    /// Serve `app` on an ephemeral loopback port and return the URL of its `/token` route.
    async fn spawn_token_server(app: axum::Router) -> String {
        let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap(); // DevSkim: ignore DS137138
        let port = listener.local_addr().unwrap().port();
        tokio::spawn(async move {
            axum::serve(listener, app).await.unwrap();
        });
        format!("http://127.0.0.1:{port}/token") // DevSkim: ignore DS137138
    }

    /// Build a provider with fixed dummy credentials that talks to `token_url`.
    fn provider_for(token_url: String, scopes: Vec<String>) -> OAuth2TokenProvider {
        OAuth2TokenProvider::new(
            token_url,
            "client-id".to_string(),
            "client-secret".to_string(),
            scopes,
            Client::new(),
        )
    }

    #[tokio::test]
    async fn test_oauth2_token_caching() {
        // Mock token server that counts how many requests it receives.
        let hits = Arc::new(AtomicU64::new(0));
        let server_hits = Arc::clone(&hits);
        let app = axum::Router::new().route(
            "/token",
            axum::routing::post(move || {
                let server_hits = Arc::clone(&server_hits);
                async move {
                    server_hits.fetch_add(1, Ordering::SeqCst);
                    axum::Json(serde_json::json!({
                        "access_token": "test-token-123",
                        "expires_in": 3600,
                        "token_type": "Bearer"
                    }))
                }
            }),
        );
        let provider = provider_for(spawn_token_server(app).await, vec!["read".to_string()]);

        // The first call has nothing cached and must go to the server.
        assert_eq!(provider.get_token().await.unwrap(), "test-token-123");
        assert_eq!(hits.load(Ordering::SeqCst), 1, "First call should hit the server");

        // The token is valid for about an hour, so this one comes from the cache.
        assert_eq!(provider.get_token().await.unwrap(), "test-token-123");
        assert_eq!(
            hits.load(Ordering::SeqCst),
            1,
            "Second call should use cache, not hit server"
        );
    }

    #[tokio::test]
    async fn test_oauth2_token_refresh_on_expiry() {
        let hits = Arc::new(AtomicU64::new(0));
        let server_hits = Arc::clone(&hits);
        let app = axum::Router::new().route(
            "/token",
            axum::routing::post(move || {
                let server_hits = Arc::clone(&server_hits);
                async move {
                    // Issue a distinct token per request: token-1, token-2, ...
                    let previous = server_hits.fetch_add(1, Ordering::SeqCst);
                    axum::Json(serde_json::json!({
                        "access_token": format!("token-{}", previous + 1),
                        "expires_in": 1,
                        "token_type": "Bearer"
                    }))
                }
            }),
        );
        let provider = provider_for(spawn_token_server(app).await, vec![]);

        // A 1-second lifetime minus the 60-second safety margin saturates to
        // zero, so the first token is already stale once cached.
        assert_eq!(provider.get_token().await.unwrap(), "token-1");

        // Hence the next call must refresh.
        assert_eq!(provider.get_token().await.unwrap(), "token-2");
        assert_eq!(
            hits.load(Ordering::SeqCst),
            2,
            "Expired token should trigger refresh"
        );
    }

    #[tokio::test]
    async fn test_oauth2_error_is_truncated() {
        // Token endpoint that always answers 400 Bad Request with a 500-byte body.
        let app = axum::Router::new().route(
            "/token",
            axum::routing::post(|| async {
                (axum::http::StatusCode::BAD_REQUEST, "x".repeat(500)).into_response()
            }),
        );
        let provider = provider_for(spawn_token_server(app).await, vec![]);

        let err_msg = provider.get_token().await.unwrap_err().to_string();
        assert!(
            err_msg.contains("truncated"),
            "Error should be truncated: {err_msg}"
        );
        assert!(err_msg.len() < 400, "Error message should be bounded");
    }
}