1use crate::error::{ProxyError, Result};
20use std::sync::Arc;
21use std::time::{Duration, Instant};
22use tokio::io::{AsyncReadExt, AsyncWriteExt};
23use tokio::net::TcpStream;
24use tokio::sync::RwLock;
25use tokio_rustls::TlsConnector;
26use tracing::{debug, warn};
27use zeroize::Zeroizing;
28
/// Refresh window: a cached token is renewed once it is within this many seconds of expiry.
const EXPIRY_BUFFER_SECS: u64 = 30;

/// Lifetime (seconds) assumed when the token response has no usable `expires_in`.
const DEFAULT_EXPIRES_IN_SECS: u64 = 60 * 60;

/// Upper bound on one whole token exchange: connect, TLS handshake, request and response.
const EXCHANGE_TIMEOUT: Duration = Duration::from_secs(30);

/// Token-endpoint responses larger than this many bytes are rejected.
const MAX_TOKEN_RESPONSE: usize = 65_536;
41
42pub struct OAuth2ExchangeConfig {
50 pub token_url: String,
51 pub client_id: Zeroizing<String>,
52 pub client_secret: Zeroizing<String>,
53 pub scope: String,
54}
55
56impl std::fmt::Debug for OAuth2ExchangeConfig {
58 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
59 f.debug_struct("OAuth2ExchangeConfig")
60 .field("token_url", &self.token_url)
61 .field("client_id", &"[REDACTED]")
62 .field("client_secret", &"[REDACTED]")
63 .field("scope", &self.scope)
64 .finish()
65 }
66}
67
68pub struct TokenCache {
70 token: Arc<RwLock<CachedToken>>,
71 config: OAuth2ExchangeConfig,
72 tls_connector: TlsConnector,
73}
74
75impl std::fmt::Debug for TokenCache {
76 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
77 f.debug_struct("TokenCache")
78 .field("config", &self.config)
79 .finish()
80 }
81}
82
83struct CachedToken {
88 access_token: Zeroizing<String>,
89 expires_at: Instant,
90}
91
92impl TokenCache {
97 pub fn new(config: OAuth2ExchangeConfig, tls_connector: TlsConnector) -> Result<Self> {
109 let (access_token, expires_in) = tokio::task::block_in_place(|| {
113 tokio::runtime::Handle::current().block_on(exchange_token(&config, &tls_connector))
114 })?;
115
116 let expires_at = Instant::now() + expires_in;
117 debug!(
118 "OAuth2 initial token acquired, expires in {}s",
119 expires_in.as_secs()
120 );
121
122 Ok(Self {
123 token: Arc::new(RwLock::new(CachedToken {
124 access_token,
125 expires_at,
126 })),
127 config,
128 tls_connector,
129 })
130 }
131
132 #[cfg(test)]
137 pub(crate) fn new_from_parts(
138 config: OAuth2ExchangeConfig,
139 tls_connector: TlsConnector,
140 token: &str,
141 ttl: Duration,
142 ) -> Self {
143 Self {
144 token: Arc::new(RwLock::new(CachedToken {
145 access_token: Zeroizing::new(token.to_string()),
146 expires_at: Instant::now() + ttl,
147 })),
148 config,
149 tls_connector,
150 }
151 }
152
153 pub async fn get_or_refresh(&self) -> Zeroizing<String> {
162 {
164 let guard = self.token.read().await;
165 if Instant::now() + Duration::from_secs(EXPIRY_BUFFER_SECS) < guard.expires_at {
166 return guard.access_token.clone();
167 }
168 }
169
170 let mut guard = self.token.write().await;
172
173 if Instant::now() + Duration::from_secs(EXPIRY_BUFFER_SECS) < guard.expires_at {
175 return guard.access_token.clone();
176 }
177
178 match exchange_token(&self.config, &self.tls_connector).await {
179 Ok((new_token, expires_in)) => {
180 debug!(
181 "OAuth2 token refreshed, expires in {}s",
182 expires_in.as_secs()
183 );
184 guard.access_token = new_token;
185 guard.expires_at = Instant::now() + expires_in;
186 guard.access_token.clone()
187 }
188 Err(e) => {
189 warn!("OAuth2 token refresh failed, returning stale token: {}", e);
190 guard.access_token.clone()
191 }
192 }
193 }
194}
195
196async fn exchange_token(
205 config: &OAuth2ExchangeConfig,
206 tls_connector: &TlsConnector,
207) -> Result<(Zeroizing<String>, Duration)> {
208 let parsed = url::Url::parse(&config.token_url).map_err(|e| {
209 ProxyError::OAuth2Exchange(format!("invalid token_url '{}': {}", config.token_url, e))
210 })?;
211
212 let scheme = parsed.scheme();
213 let is_https = match scheme {
214 "https" => true,
215 "http" => false,
216 other => {
217 return Err(ProxyError::OAuth2Exchange(format!(
218 "unsupported scheme '{}' in token_url",
219 other
220 )));
221 }
222 };
223
224 let host = parsed
225 .host_str()
226 .ok_or_else(|| {
227 ProxyError::OAuth2Exchange(format!("missing host in token_url '{}'", config.token_url))
228 })?
229 .to_string();
230
231 let default_port: u16 = if is_https { 443 } else { 80 };
232 let port = parsed.port().unwrap_or(default_port);
233 let path = if parsed.path().is_empty() {
234 "/"
235 } else {
236 parsed.path()
237 };
238 let path_with_query = match parsed.query() {
239 Some(q) => format!("{}?{}", path, q),
240 None => path.to_string(),
241 };
242
243 let body = build_token_request_body(&config.client_id, &config.client_secret, &config.scope);
245
246 let request = Zeroizing::new(format!(
248 "POST {} HTTP/1.1\r\n\
249 Host: {}\r\n\
250 Content-Type: application/x-www-form-urlencoded\r\n\
251 Content-Length: {}\r\n\
252 Accept: application/json\r\n\
253 Connection: close\r\n\
254 \r\n\
255 {}",
256 path_with_query,
257 host,
258 body.len(),
259 body.as_str()
260 ));
261
262 let addr = format!("{}:{}", host, port);
264
265 let response_bytes = tokio::time::timeout(EXCHANGE_TIMEOUT, async {
266 let tcp = TcpStream::connect(&addr)
267 .await
268 .map_err(|e| ProxyError::OAuth2Exchange(format!("TCP connect to {}: {}", addr, e)))?;
269
270 async fn send_and_read<S: tokio::io::AsyncWrite + tokio::io::AsyncRead + Unpin>(
271 stream: &mut S,
272 request: &[u8],
273 host: &str,
274 ) -> Result<Vec<u8>> {
275 stream
276 .write_all(request)
277 .await
278 .map_err(|e| ProxyError::OAuth2Exchange(format!("write to {}: {}", host, e)))?;
279 stream
280 .flush()
281 .await
282 .map_err(|e| ProxyError::OAuth2Exchange(format!("flush to {}: {}", host, e)))?;
283 read_http_response(stream).await
284 }
285
286 if is_https {
287 let server_name =
288 rustls::pki_types::ServerName::try_from(host.clone()).map_err(|_| {
289 ProxyError::OAuth2Exchange(format!("invalid TLS server name: {}", host))
290 })?;
291
292 let mut tls = tls_connector.connect(server_name, tcp).await.map_err(|e| {
293 ProxyError::OAuth2Exchange(format!("TLS handshake with {}: {}", host, e))
294 })?;
295
296 send_and_read(&mut tls, request.as_bytes(), &host).await
297 } else {
298 let mut tcp = tcp;
299 send_and_read(&mut tcp, request.as_bytes(), &host).await
300 }
301 })
302 .await
303 .map_err(|_| ProxyError::OAuth2Exchange(format!("token exchange with {} timed out", addr)))??;
304
305 let response_str = String::from_utf8(response_bytes).map_err(|_| {
307 ProxyError::OAuth2Exchange("token endpoint returned non-UTF-8 response".to_string())
308 })?;
309
310 let body_start = response_str
312 .find("\r\n\r\n")
313 .map(|i| i + 4)
314 .or_else(|| response_str.find("\n\n").map(|i| i + 2))
315 .ok_or_else(|| {
316 ProxyError::OAuth2Exchange(
317 "malformed HTTP response: no header/body separator".to_string(),
318 )
319 })?;
320
321 let status_line = response_str.lines().next().unwrap_or("");
323 let status_code = parse_status_code(status_line);
324 if !(200..300).contains(&status_code) {
325 let body_preview: String = response_str[body_start..].chars().take(200).collect();
326 return Err(ProxyError::OAuth2Exchange(format!(
327 "token endpoint returned HTTP {}: {}",
328 status_code, body_preview
329 )));
330 }
331
332 let json_body = &response_str[body_start..];
333 parse_token_response(json_body)
334}
335
336async fn read_http_response<S: tokio::io::AsyncRead + Unpin>(stream: &mut S) -> Result<Vec<u8>> {
338 let mut buf = Vec::with_capacity(4096);
339 let mut tmp = [0u8; 4096];
340 loop {
341 let n = stream
342 .read(&mut tmp)
343 .await
344 .map_err(|e| ProxyError::OAuth2Exchange(format!("read response: {}", e)))?;
345 if n == 0 {
346 break;
347 }
348 buf.extend_from_slice(&tmp[..n]);
349 if buf.len() > MAX_TOKEN_RESPONSE {
350 return Err(ProxyError::OAuth2Exchange(format!(
351 "token response exceeds {} bytes",
352 MAX_TOKEN_RESPONSE
353 )));
354 }
355 }
356 Ok(buf)
357}
358
/// Extracts the numeric status from an HTTP status line such as `HTTP/1.1 200 OK`.
///
/// Returns `0` when the line has no second whitespace-separated field, or when that
/// field is not a valid `u16`. Callers then see a status outside `200..300`.
fn parse_status_code(line: &str) -> u16 {
    match line.split_whitespace().nth(1) {
        Some(code) => code.parse::<u16>().unwrap_or(0),
        None => 0,
    }
}
365
366fn build_token_request_body(
374 client_id: &str,
375 client_secret: &str,
376 scope: &str,
377) -> Zeroizing<String> {
378 let mut body = Zeroizing::new(format!(
379 "grant_type=client_credentials&client_id={}&client_secret={}",
380 urlencoding::encode(client_id),
381 urlencoding::encode(client_secret),
382 ));
383 if !scope.is_empty() {
384 body.push_str(&format!("&scope={}", urlencoding::encode(scope)));
385 }
386 body
387}
388
389fn parse_token_response(json: &str) -> Result<(Zeroizing<String>, Duration)> {
395 let value: serde_json::Value = serde_json::from_str(json).map_err(|e| {
396 ProxyError::OAuth2Exchange(format!("invalid JSON from token endpoint: {}", e))
397 })?;
398
399 let access_token = value
400 .get("access_token")
401 .and_then(|v| v.as_str())
402 .ok_or_else(|| {
403 ProxyError::OAuth2Exchange("token response missing 'access_token' field".to_string())
404 })?;
405
406 let expires_in_secs = value
407 .get("expires_in")
408 .and_then(|v| v.as_u64())
409 .unwrap_or(DEFAULT_EXPIRES_IN_SECS);
410
411 Ok((
412 Zeroizing::new(access_token.to_string()),
413 Duration::from_secs(expires_in_secs),
414 ))
415}
416
#[cfg(test)]
// Test code runs on fixed fixtures, so an `unwrap` panic is exactly the failure we want.
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    /// Builds a cache from parts without any network round-trip.
    ///
    /// Its endpoint is `https://127.0.0.1:1`, where presumably nothing listens, so any
    /// refresh attempt is expected to fail quickly.
    fn make_test_cache(token: &str, ttl: Duration) -> TokenCache {
        let mut roots = rustls::RootCertStore::empty();
        roots.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());

        let provider = Arc::new(rustls::crypto::ring::default_provider());
        let client_config = rustls::ClientConfig::builder_with_provider(provider)
            .with_safe_default_protocol_versions()
            .unwrap()
            .with_root_certificates(roots)
            .with_no_client_auth();

        let config = OAuth2ExchangeConfig {
            token_url: "https://127.0.0.1:1/oauth/token".to_string(),
            client_id: Zeroizing::new("test-client".to_string()),
            client_secret: Zeroizing::new("test-secret".to_string()),
            scope: String::new(),
        };

        TokenCache::new_from_parts(
            config,
            TlsConnector::from(Arc::new(client_config)),
            token,
            ttl,
        )
    }

    /// Runs `parse_token_response` on input that must be rejected and returns the message.
    fn rejection_message(json: &str) -> String {
        parse_token_response(json).unwrap_err().to_string()
    }

    #[test]
    fn test_parse_token_response_success() {
        let json =
            r#"{"access_token":"eyJhbGciOiJSUzI1NiJ9","token_type":"Bearer","expires_in":3600}"#;
        let parsed = parse_token_response(json).unwrap();
        assert_eq!(parsed.0.as_str(), "eyJhbGciOiJSUzI1NiJ9");
        assert_eq!(parsed.1, Duration::from_secs(3600));
    }

    #[test]
    fn test_parse_token_response_missing_expires_defaults() {
        let parsed = parse_token_response(r#"{"access_token":"tok_abc","token_type":"Bearer"}"#)
            .unwrap();
        assert_eq!(parsed.0.as_str(), "tok_abc");
        assert_eq!(parsed.1, Duration::from_secs(DEFAULT_EXPIRES_IN_SECS));
    }

    #[test]
    fn test_parse_token_response_missing_access_token_errors() {
        let msg = rejection_message(r#"{"token_type":"Bearer","expires_in":3600}"#);
        assert!(
            msg.contains("access_token"),
            "error should mention access_token: {}",
            msg
        );
    }

    #[test]
    fn test_parse_token_response_non_json_errors() {
        let msg = rejection_message("this is not json");
        assert!(
            msg.contains("invalid JSON"),
            "error should mention invalid JSON: {}",
            msg
        );
    }

    #[test]
    fn test_build_token_request_body() {
        let body = build_token_request_body("my-client", "s3cret!", "read write");
        assert!(body.contains("grant_type=client_credentials"));
        assert!(body.contains("client_id=my-client"));
        assert!(body.contains("client_secret=s3cret%21"));
        assert!(body.contains("scope=read%20write"));
    }

    #[test]
    fn test_build_token_request_body_no_scope() {
        let body = build_token_request_body("cid", "csec", "");
        assert!(body.contains("grant_type=client_credentials"));
        assert!(body.contains("client_id=cid"));
        assert!(body.contains("client_secret=csec"));
        assert!(!body.contains("scope="), "empty scope should be omitted");
    }

    #[test]
    fn test_parse_status_code_200() {
        assert_eq!(parse_status_code("HTTP/1.1 200 OK"), 200);
    }

    #[test]
    fn test_parse_status_code_401() {
        assert_eq!(parse_status_code("HTTP/1.1 401 Unauthorized"), 401);
    }

    #[test]
    fn test_parse_status_code_garbage() {
        assert_eq!(parse_status_code("not http"), 0);
    }

    #[tokio::test]
    async fn test_token_cache_returns_valid_token() {
        let token = make_test_cache("valid_token", Duration::from_secs(3600))
            .get_or_refresh()
            .await;
        assert_eq!(token.as_str(), "valid_token");
    }

    #[tokio::test]
    async fn test_token_cache_detects_expiry() {
        let cache = make_test_cache("stale_token", Duration::from_secs(0));

        // Push the expiry into the past so the refresh path is taken; the temporary
        // write guard is released at the end of this statement.
        cache.token.write().await.expires_at = Instant::now() - Duration::from_secs(10);

        // The refresh cannot reach an endpoint, so the stale token must be returned.
        let token = cache.get_or_refresh().await;
        assert_eq!(token.as_str(), "stale_token");
    }
}