1use anyhow::{Context, Result};
18use reqwest::header::{HeaderMap, HeaderName, HeaderValue};
19use reqwest::Client;
20use serde::Deserialize;
21use std::sync::Arc;
22use std::time::{Duration, Instant};
23use tokio::sync::RwLock;
24
25use crate::config::{ApiKeyLocation, AuthConfig};
26
/// Returns the largest byte index not exceeding `max` that falls on a UTF-8
/// character boundary of `s`.
///
/// When `max` is at or past the end of the string, `s.len()` is returned.
/// The result is always a valid end index for slicing `&s[..idx]`.
fn find_char_boundary(s: &str, max: usize) -> usize {
    let len = s.len();
    if max >= len {
        return len;
    }
    // Walk backwards from `max`; index 0 is always a boundary, so the
    // fallback below exists only to satisfy the type.
    (0..=max)
        .rev()
        .find(|&idx| s.is_char_boundary(idx))
        .unwrap_or(0)
}
38
39pub enum ResolvedAuth {
41 Bearer {
42 token: String,
43 },
44 ApiKeyHeader {
45 name: String,
46 value: String,
47 },
48 ApiKeyQuery {
49 name: String,
50 value: String,
51 },
52 Basic {
53 username: String,
54 password: String,
55 },
56 OAuth2 {
57 token_provider: Arc<OAuth2TokenProvider>,
58 },
59}
60
61pub struct OAuth2TokenProvider {
63 token_url: String,
64 client_id: String,
65 client_secret: String,
66 scopes: Vec<String>,
67 client: Client,
68 cached_token: RwLock<Option<CachedToken>>,
69}
70
/// An access token paired with the moment it stops being served from cache.
#[derive(Clone)]
struct CachedToken {
    /// The token string handed back to callers.
    access_token: String,
    /// The token is reused only while `Instant::now()` is before this instant.
    expires_at: Instant,
}
76
77#[derive(Deserialize)]
78struct OAuth2TokenResponse {
79 access_token: String,
80 #[serde(default)]
81 expires_in: Option<u64>,
82 #[allow(dead_code)]
83 #[serde(default)]
84 token_type: Option<String>,
85}
86
87impl OAuth2TokenProvider {
88 pub fn new(
89 token_url: String,
90 client_id: String,
91 client_secret: String,
92 scopes: Vec<String>,
93 client: Client,
94 ) -> Self {
95 Self {
96 token_url,
97 client_id,
98 client_secret,
99 scopes,
100 client,
101 cached_token: RwLock::new(None),
102 }
103 }
104
105 pub async fn get_token(&self) -> Result<String> {
107 {
109 let cache = self.cached_token.read().await;
110 if let Some(ref cached) = *cache {
111 if Instant::now() < cached.expires_at {
112 return Ok(cached.access_token.clone());
113 }
114 }
115 }
116
117 let mut cache = self.cached_token.write().await;
119 if let Some(ref cached) = *cache {
120 if Instant::now() < cached.expires_at {
121 return Ok(cached.access_token.clone());
122 }
123 }
124
125 let token = self.fetch_token().await?;
127 let access_token = token.access_token.clone();
128 *cache = Some(token);
129
130 Ok(access_token)
131 }
132
133 async fn fetch_token(&self) -> Result<CachedToken> {
134 let mut form = vec![
135 ("grant_type", "client_credentials".to_string()),
136 ("client_id", self.client_id.clone()),
137 ("client_secret", self.client_secret.clone()),
138 ];
139
140 if !self.scopes.is_empty() {
141 form.push(("scope", self.scopes.join(" ")));
142 }
143
144 let response = self
145 .client
146 .post(&self.token_url)
147 .form(&form)
148 .send()
149 .await
150 .context("Failed to request OAuth2 token")?;
151
152 if !response.status().is_success() {
153 let status = response.status();
154 let body = response
155 .text()
156 .await
157 .unwrap_or_else(|_| "Unable to read response".to_string());
158 let truncated = if body.len() > 256 {
159 let end = find_char_boundary(&body, 256);
160 format!("{}... (truncated)", &body[..end])
161 } else {
162 body
163 };
164 return Err(anyhow::anyhow!(
165 "OAuth2 token request failed with status {status}: {truncated}"
166 ));
167 }
168
169 let token_response: OAuth2TokenResponse = response
170 .json()
171 .await
172 .context("Failed to parse OAuth2 token response")?;
173
174 let expires_in = token_response.expires_in.unwrap_or(3600);
176 let expires_at = Instant::now() + Duration::from_secs(expires_in.saturating_sub(60));
177
178 Ok(CachedToken {
179 access_token: token_response.access_token,
180 expires_at,
181 })
182 }
183}
184
185pub fn resolve_auth(config: &AuthConfig, client: &Client) -> Result<ResolvedAuth> {
187 match config {
188 AuthConfig::Bearer { token_env } => {
189 let token = std::env::var(token_env)
190 .with_context(|| format!("Environment variable '{token_env}' not set"))?;
191 Ok(ResolvedAuth::Bearer { token })
192 }
193 AuthConfig::ApiKey {
194 location,
195 name,
196 value_env,
197 } => {
198 let value = std::env::var(value_env)
199 .with_context(|| format!("Environment variable '{value_env}' not set"))?;
200 match location {
201 ApiKeyLocation::Header => Ok(ResolvedAuth::ApiKeyHeader {
202 name: name.clone(),
203 value,
204 }),
205 ApiKeyLocation::Query => Ok(ResolvedAuth::ApiKeyQuery {
206 name: name.clone(),
207 value,
208 }),
209 }
210 }
211 AuthConfig::Basic {
212 username_env,
213 password_env,
214 } => {
215 let username = std::env::var(username_env)
216 .with_context(|| format!("Environment variable '{username_env}' not set"))?;
217 let password = match password_env {
218 Some(env) => std::env::var(env)
219 .with_context(|| format!("Environment variable '{env}' not set"))?,
220 None => String::new(),
221 };
222 Ok(ResolvedAuth::Basic { username, password })
223 }
224 AuthConfig::OAuth2ClientCredentials {
225 token_url,
226 client_id_env,
227 client_secret_env,
228 scopes,
229 } => {
230 let client_id = std::env::var(client_id_env)
231 .with_context(|| format!("Environment variable '{client_id_env}' not set"))?;
232 let client_secret = std::env::var(client_secret_env)
233 .with_context(|| format!("Environment variable '{client_secret_env}' not set"))?;
234
235 let provider = OAuth2TokenProvider::new(
236 token_url.clone(),
237 client_id,
238 client_secret,
239 scopes.clone(),
240 client.clone(),
241 );
242
243 Ok(ResolvedAuth::OAuth2 {
244 token_provider: Arc::new(provider),
245 })
246 }
247 }
248}
249
250pub async fn apply_auth(
252 builder: reqwest::RequestBuilder,
253 auth: &ResolvedAuth,
254) -> Result<reqwest::RequestBuilder> {
255 match auth {
256 ResolvedAuth::Bearer { token } => Ok(builder.bearer_auth(token)),
257 ResolvedAuth::ApiKeyHeader { name, value } => {
258 let mut headers = HeaderMap::new();
259 let header_name = HeaderName::try_from(name.as_str())
260 .with_context(|| format!("Invalid header name: {name}"))?;
261 let header_value = HeaderValue::from_str(value)
262 .with_context(|| format!("Invalid header value for {name}"))?;
263 headers.insert(header_name, header_value);
264 Ok(builder.headers(headers))
265 }
266 ResolvedAuth::ApiKeyQuery { name, value } => Ok(builder.query(&[(name, value)])),
267 ResolvedAuth::Basic { username, password } => {
268 Ok(builder.basic_auth(username, Some(password)))
269 }
270 ResolvedAuth::OAuth2 { token_provider } => {
271 let token = token_provider
272 .get_token()
273 .await
274 .context("Failed to get OAuth2 token")?;
275 Ok(builder.bearer_auth(token))
276 }
277 }
278}
279
#[cfg(test)]
mod tests {
    use super::*;
    use axum::response::IntoResponse;
    use std::sync::atomic::{AtomicU64, Ordering};

    /// Serves `app` on an ephemeral localhost port in a background task and
    /// returns the URL of its `/token` route.
    async fn serve_token_endpoint(app: axum::Router) -> String {
        let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
        let port = listener.local_addr().unwrap().port();
        tokio::spawn(async move {
            axum::serve(listener, app).await.unwrap();
        });
        format!("http://127.0.0.1:{port}/token")
    }

    /// Builds a provider with fixed dummy credentials pointing at `token_url`.
    fn provider_for(token_url: String, scopes: Vec<String>) -> OAuth2TokenProvider {
        OAuth2TokenProvider::new(
            token_url,
            "client-id".to_string(),
            "client-secret".to_string(),
            scopes,
            Client::new(),
        )
    }

    #[tokio::test]
    async fn test_oauth2_token_caching() {
        let hits = Arc::new(AtomicU64::new(0));

        let handler_hits = hits.clone();
        let app = axum::Router::new().route(
            "/token",
            axum::routing::post(move || {
                let hits = handler_hits.clone();
                async move {
                    hits.fetch_add(1, Ordering::SeqCst);
                    axum::Json(serde_json::json!({
                        "access_token": "test-token-123",
                        "expires_in": 3600,
                        "token_type": "Bearer"
                    }))
                }
            }),
        );

        let token_url = serve_token_endpoint(app).await;
        let provider = provider_for(token_url, vec!["read".to_string()]);

        let first = provider.get_token().await.unwrap();
        assert_eq!(first, "test-token-123");
        assert_eq!(
            hits.load(Ordering::SeqCst),
            1,
            "First call should hit the server"
        );

        let second = provider.get_token().await.unwrap();
        assert_eq!(second, "test-token-123");
        assert_eq!(
            hits.load(Ordering::SeqCst),
            1,
            "Second call should use cache, not hit server"
        );
    }

    #[tokio::test]
    async fn test_oauth2_token_refresh_on_expiry() {
        let hits = Arc::new(AtomicU64::new(0));

        let handler_hits = hits.clone();
        let app = axum::Router::new().route(
            "/token",
            axum::routing::post(move || {
                let hits = handler_hits.clone();
                async move {
                    let previous = hits.fetch_add(1, Ordering::SeqCst);
                    // A one-second lifetime is shorter than the provider's
                    // expiry margin, so every issued token is stale at once.
                    axum::Json(serde_json::json!({
                        "access_token": format!("token-{}", previous + 1),
                        "expires_in": 1,
                        "token_type": "Bearer"
                    }))
                }
            }),
        );

        let token_url = serve_token_endpoint(app).await;
        let provider = provider_for(token_url, vec![]);

        let first = provider.get_token().await.unwrap();
        assert_eq!(first, "token-1");

        let second = provider.get_token().await.unwrap();
        assert_eq!(second, "token-2");
        assert_eq!(
            hits.load(Ordering::SeqCst),
            2,
            "Expired token should trigger refresh"
        );
    }

    #[tokio::test]
    async fn test_oauth2_error_is_truncated() {
        let app = axum::Router::new().route(
            "/token",
            axum::routing::post(|| async {
                let body = "x".repeat(500);
                (axum::http::StatusCode::BAD_REQUEST, body).into_response()
            }),
        );

        let token_url = serve_token_endpoint(app).await;
        let provider = provider_for(token_url, vec![]);

        let err = provider.get_token().await.unwrap_err();
        let err_msg = err.to_string();
        assert!(
            err_msg.contains("truncated"),
            "Error should be truncated: {err_msg}"
        );
        assert!(err_msg.len() < 400, "Error message should be bounded");
    }
}