Skip to main content

nono_proxy/
oauth2.rs

1//! OAuth2 `client_credentials` token exchange and caching.
2//!
3//! Provides [`TokenCache`] — a thread-safe cache that holds an OAuth2 access
4//! token and refreshes it on demand before expiry. Designed for the reverse
5//! proxy credential injection flow where the agent never sees the real
6//! client_id/client_secret.
7//!
8//! ## Design
9//!
//! - **No background tasks**: Token validity is checked on each use via
//!   [`TokenCache::get_or_refresh()`]. If the cached token is about to expire
//!   (within 30 seconds), a refresh is awaited inline before returning.
13//! - **Graceful degradation**: If a refresh attempt fails but a stale token
14//!   exists, the stale token is returned with a warning log. This avoids
15//!   transient auth-server outages from cascading into request failures.
16//! - **TLS via rustls**: Uses the same `webpki-roots` + `tokio-rustls` stack
17//!   as the rest of the proxy. No additional HTTP client dependencies.
18
19use crate::error::{ProxyError, Result};
20use std::sync::Arc;
21use std::time::{Duration, Instant};
22use tokio::io::{AsyncReadExt, AsyncWriteExt};
23use tokio::net::TcpStream;
24use tokio::sync::RwLock;
25use tokio_rustls::TlsConnector;
26use tracing::{debug, warn};
27use zeroize::Zeroizing;
28
/// Safety margin, in seconds, applied when checking token validity: a cached
/// token is refreshed once it is within this many seconds of its expiry.
/// Avoids edge cases where a token expires between check and use.
const EXPIRY_BUFFER_SECS: u64 = 30;

/// Default TTL, in seconds, when the token endpoint omits `expires_in`.
const DEFAULT_EXPIRES_IN_SECS: u64 = 3600;

/// Timeout for the TCP connect + TLS handshake + HTTP exchange.
const EXCHANGE_TIMEOUT: Duration = Duration::from_secs(30);

/// Maximum size of the raw response read from the token endpoint (64 KiB).
/// The limit is applied to everything read from the socket, i.e. status line
/// and headers as well as the body.
const MAX_TOKEN_RESPONSE: usize = 64 * 1024;
41
42// ────────────────────────────────────────────────────────────────────────────
43// Public types
44// ────────────────────────────────────────────────────────────────────────────
45
/// Resolved OAuth2 credentials ready for token exchange.
///
/// All secret fields use [`Zeroizing`] so they are zeroed on drop.
pub struct OAuth2ExchangeConfig {
    /// Absolute URL of the token endpoint. Only the `http` and `https`
    /// schemes are accepted by the exchange.
    pub token_url: String,
    /// OAuth2 client identifier, sent URL-encoded in the form body.
    pub client_id: Zeroizing<String>,
    /// OAuth2 client secret, sent URL-encoded in the form body.
    pub client_secret: Zeroizing<String>,
    /// Requested scope. An empty string means no `scope` parameter is sent.
    pub scope: String,
}
55
56/// Custom Debug that redacts secrets.
57impl std::fmt::Debug for OAuth2ExchangeConfig {
58    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
59        f.debug_struct("OAuth2ExchangeConfig")
60            .field("token_url", &self.token_url)
61            .field("client_id", &"[REDACTED]")
62            .field("client_secret", &"[REDACTED]")
63            .field("scope", &self.scope)
64            .finish()
65    }
66}
67
/// Thread-safe OAuth2 access-token cache with on-demand refresh.
pub struct TokenCache {
    /// Current access token and its expiry, guarded by an async `RwLock`.
    token: Arc<RwLock<CachedToken>>,
    /// Endpoint and credentials used for every token exchange.
    config: OAuth2ExchangeConfig,
    /// TLS connector used when the token endpoint is `https`.
    tls_connector: TlsConnector,
}
74
75impl std::fmt::Debug for TokenCache {
76    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
77        f.debug_struct("TokenCache")
78            .field("config", &self.config)
79            .finish()
80    }
81}
82
83// ────────────────────────────────────────────────────────────────────────────
84// Internal types
85// ────────────────────────────────────────────────────────────────────────────
86
/// The most recently obtained access token together with its expiry.
struct CachedToken {
    /// Bearer token as returned by the token endpoint; zeroed on drop.
    access_token: Zeroizing<String>,
    /// Monotonic instant at which the token expires (acquisition time plus
    /// the lifetime reported by the endpoint).
    expires_at: Instant,
}
91
92// ────────────────────────────────────────────────────────────────────────────
93// TokenCache implementation
94// ────────────────────────────────────────────────────────────────────────────
95
96impl TokenCache {
97    /// Create a new cache and perform the **initial** token exchange.
98    ///
99    /// Called during [`CredentialStore::load()`](crate::credential::CredentialStore::load)
100    /// which is synchronous. We bridge into async via
101    /// [`tokio::runtime::Handle::current().block_on()`].
102    ///
103    /// # Errors
104    ///
105    /// Returns [`ProxyError::OAuth2Exchange`] if the initial exchange fails
106    /// (DNS, TCP, TLS, non-200, malformed JSON). The calling code skips the
107    /// route so the proxy can still start for other routes.
108    pub fn new(config: OAuth2ExchangeConfig, tls_connector: TlsConnector) -> Result<Self> {
109        // Use block_in_place to avoid panicking when called from within an
110        // async context (e.g., server::start() which is async). This moves
111        // the blocking work off the async worker thread.
112        let (access_token, expires_in) = tokio::task::block_in_place(|| {
113            tokio::runtime::Handle::current().block_on(exchange_token(&config, &tls_connector))
114        })?;
115
116        let expires_at = Instant::now() + expires_in;
117        debug!(
118            "OAuth2 initial token acquired, expires in {}s",
119            expires_in.as_secs()
120        );
121
122        Ok(Self {
123            token: Arc::new(RwLock::new(CachedToken {
124                access_token,
125                expires_at,
126            })),
127            config,
128            tls_connector,
129        })
130    }
131
132    /// Create a `TokenCache` with a pre-populated token (for testing).
133    ///
134    /// Skips the initial token exchange. Used by tests that need a cache
135    /// without a real OAuth2 server.
136    #[cfg(test)]
137    pub(crate) fn new_from_parts(
138        config: OAuth2ExchangeConfig,
139        tls_connector: TlsConnector,
140        token: &str,
141        ttl: Duration,
142    ) -> Self {
143        Self {
144            token: Arc::new(RwLock::new(CachedToken {
145                access_token: Zeroizing::new(token.to_string()),
146                expires_at: Instant::now() + ttl,
147            })),
148            config,
149            tls_connector,
150        }
151    }
152
153    /// Return a valid access token, refreshing if needed.
154    ///
155    /// If the cached token is still valid (expires > 30 s from now), returns
156    /// the cached value without any network call.
157    ///
158    /// If expired, attempts one exchange. On failure, returns the **stale**
159    /// token with a warning — better to try a possibly-expired token than to
160    /// fail the request outright.
161    pub async fn get_or_refresh(&self) -> Zeroizing<String> {
162        // Fast path — token still valid.
163        {
164            let guard = self.token.read().await;
165            if Instant::now() + Duration::from_secs(EXPIRY_BUFFER_SECS) < guard.expires_at {
166                return guard.access_token.clone();
167            }
168        }
169
170        // Slow path — need to refresh.
171        let mut guard = self.token.write().await;
172
173        // Double-check after acquiring write lock (another task may have refreshed).
174        if Instant::now() + Duration::from_secs(EXPIRY_BUFFER_SECS) < guard.expires_at {
175            return guard.access_token.clone();
176        }
177
178        match exchange_token(&self.config, &self.tls_connector).await {
179            Ok((new_token, expires_in)) => {
180                debug!(
181                    "OAuth2 token refreshed, expires in {}s",
182                    expires_in.as_secs()
183                );
184                guard.access_token = new_token;
185                guard.expires_at = Instant::now() + expires_in;
186                guard.access_token.clone()
187            }
188            Err(e) => {
189                warn!("OAuth2 token refresh failed, returning stale token: {}", e);
190                guard.access_token.clone()
191            }
192        }
193    }
194}
195
196// ────────────────────────────────────────────────────────────────────────────
197// Token exchange (HTTP POST)
198// ────────────────────────────────────────────────────────────────────────────
199
200/// Perform a single `client_credentials` token exchange against the token
201/// endpoint described in `config`.
202///
203/// Returns `(access_token, expires_in_duration)`.
204async fn exchange_token(
205    config: &OAuth2ExchangeConfig,
206    tls_connector: &TlsConnector,
207) -> Result<(Zeroizing<String>, Duration)> {
208    let parsed = url::Url::parse(&config.token_url).map_err(|e| {
209        ProxyError::OAuth2Exchange(format!("invalid token_url '{}': {}", config.token_url, e))
210    })?;
211
212    let scheme = parsed.scheme();
213    let is_https = match scheme {
214        "https" => true,
215        "http" => false,
216        other => {
217            return Err(ProxyError::OAuth2Exchange(format!(
218                "unsupported scheme '{}' in token_url",
219                other
220            )));
221        }
222    };
223
224    let host = parsed
225        .host_str()
226        .ok_or_else(|| {
227            ProxyError::OAuth2Exchange(format!("missing host in token_url '{}'", config.token_url))
228        })?
229        .to_string();
230
231    let default_port: u16 = if is_https { 443 } else { 80 };
232    let port = parsed.port().unwrap_or(default_port);
233    let path = if parsed.path().is_empty() {
234        "/"
235    } else {
236        parsed.path()
237    };
238    let path_with_query = match parsed.query() {
239        Some(q) => format!("{}?{}", path, q),
240        None => path.to_string(),
241    };
242
243    // ── Build form body ──────────────────────────────────────────────────
244    let body = build_token_request_body(&config.client_id, &config.client_secret, &config.scope);
245
246    // ── Build HTTP/1.1 request ───────────────────────────────────────────
247    let request = Zeroizing::new(format!(
248        "POST {} HTTP/1.1\r\n\
249         Host: {}\r\n\
250         Content-Type: application/x-www-form-urlencoded\r\n\
251         Content-Length: {}\r\n\
252         Accept: application/json\r\n\
253         Connection: close\r\n\
254         \r\n\
255         {}",
256        path_with_query,
257        host,
258        body.len(),
259        body.as_str()
260    ));
261
262    // ── TCP + optional TLS ───────────────────────────────────────────────
263    let addr = format!("{}:{}", host, port);
264
265    let response_bytes = tokio::time::timeout(EXCHANGE_TIMEOUT, async {
266        let tcp = TcpStream::connect(&addr)
267            .await
268            .map_err(|e| ProxyError::OAuth2Exchange(format!("TCP connect to {}: {}", addr, e)))?;
269
270        async fn send_and_read<S: tokio::io::AsyncWrite + tokio::io::AsyncRead + Unpin>(
271            stream: &mut S,
272            request: &[u8],
273            host: &str,
274        ) -> Result<Vec<u8>> {
275            stream
276                .write_all(request)
277                .await
278                .map_err(|e| ProxyError::OAuth2Exchange(format!("write to {}: {}", host, e)))?;
279            stream
280                .flush()
281                .await
282                .map_err(|e| ProxyError::OAuth2Exchange(format!("flush to {}: {}", host, e)))?;
283            read_http_response(stream).await
284        }
285
286        if is_https {
287            let server_name =
288                rustls::pki_types::ServerName::try_from(host.clone()).map_err(|_| {
289                    ProxyError::OAuth2Exchange(format!("invalid TLS server name: {}", host))
290                })?;
291
292            let mut tls = tls_connector.connect(server_name, tcp).await.map_err(|e| {
293                ProxyError::OAuth2Exchange(format!("TLS handshake with {}: {}", host, e))
294            })?;
295
296            send_and_read(&mut tls, request.as_bytes(), &host).await
297        } else {
298            let mut tcp = tcp;
299            send_and_read(&mut tcp, request.as_bytes(), &host).await
300        }
301    })
302    .await
303    .map_err(|_| ProxyError::OAuth2Exchange(format!("token exchange with {} timed out", addr)))??;
304
305    // ── Parse HTTP response ──────────────────────────────────────────────
306    let response_str = String::from_utf8(response_bytes).map_err(|_| {
307        ProxyError::OAuth2Exchange("token endpoint returned non-UTF-8 response".to_string())
308    })?;
309
310    // Split headers from body
311    let body_start = response_str
312        .find("\r\n\r\n")
313        .map(|i| i + 4)
314        .or_else(|| response_str.find("\n\n").map(|i| i + 2))
315        .ok_or_else(|| {
316            ProxyError::OAuth2Exchange(
317                "malformed HTTP response: no header/body separator".to_string(),
318            )
319        })?;
320
321    // Check status code
322    let status_line = response_str.lines().next().unwrap_or("");
323    let status_code = parse_status_code(status_line);
324    if !(200..300).contains(&status_code) {
325        let body_preview: String = response_str[body_start..].chars().take(200).collect();
326        return Err(ProxyError::OAuth2Exchange(format!(
327            "token endpoint returned HTTP {}: {}",
328            status_code, body_preview
329        )));
330    }
331
332    let json_body = &response_str[body_start..];
333    parse_token_response(json_body)
334}
335
336/// Read a full HTTP response from a stream up to [`MAX_TOKEN_RESPONSE`] bytes.
337async fn read_http_response<S: tokio::io::AsyncRead + Unpin>(stream: &mut S) -> Result<Vec<u8>> {
338    let mut buf = Vec::with_capacity(4096);
339    let mut tmp = [0u8; 4096];
340    loop {
341        let n = stream
342            .read(&mut tmp)
343            .await
344            .map_err(|e| ProxyError::OAuth2Exchange(format!("read response: {}", e)))?;
345        if n == 0 {
346            break;
347        }
348        buf.extend_from_slice(&tmp[..n]);
349        if buf.len() > MAX_TOKEN_RESPONSE {
350            return Err(ProxyError::OAuth2Exchange(format!(
351                "token response exceeds {} bytes",
352                MAX_TOKEN_RESPONSE
353            )));
354        }
355    }
356    Ok(buf)
357}
358
/// Extract the numeric status code from an HTTP status line.
///
/// Returns `0` when the line has no second whitespace-separated token or
/// when that token is not a valid `u16`.
fn parse_status_code(line: &str) -> u16 {
    // "HTTP/1.1 200 OK" -> the second token, "200"
    match line.split_whitespace().nth(1) {
        Some(code) => code.parse().unwrap_or(0),
        None => 0,
    }
}
365
366// ────────────────────────────────────────────────────────────────────────────
367// Request / response helpers (pub(crate) for testing)
368// ────────────────────────────────────────────────────────────────────────────
369
370/// Build the `application/x-www-form-urlencoded` body for the token request.
371///
372/// The `scope` parameter is omitted when empty.
373fn build_token_request_body(
374    client_id: &str,
375    client_secret: &str,
376    scope: &str,
377) -> Zeroizing<String> {
378    let mut body = Zeroizing::new(format!(
379        "grant_type=client_credentials&client_id={}&client_secret={}",
380        urlencoding::encode(client_id),
381        urlencoding::encode(client_secret),
382    ));
383    if !scope.is_empty() {
384        body.push_str(&format!("&scope={}", urlencoding::encode(scope)));
385    }
386    body
387}
388
389/// Parse a standard OAuth2 token response JSON.
390///
391/// Expects `{"access_token": "...", "expires_in": 3600, ...}`.
392/// - `access_token` is required.
393/// - `expires_in` defaults to [`DEFAULT_EXPIRES_IN_SECS`] if missing.
394fn parse_token_response(json: &str) -> Result<(Zeroizing<String>, Duration)> {
395    let value: serde_json::Value = serde_json::from_str(json).map_err(|e| {
396        ProxyError::OAuth2Exchange(format!("invalid JSON from token endpoint: {}", e))
397    })?;
398
399    let access_token = value
400        .get("access_token")
401        .and_then(|v| v.as_str())
402        .ok_or_else(|| {
403            ProxyError::OAuth2Exchange("token response missing 'access_token' field".to_string())
404        })?;
405
406    let expires_in_secs = value
407        .get("expires_in")
408        .and_then(|v| v.as_u64())
409        .unwrap_or(DEFAULT_EXPIRES_IN_SECS);
410
411    Ok((
412        Zeroizing::new(access_token.to_string()),
413        Duration::from_secs(expires_in_secs),
414    ))
415}
416
417// ────────────────────────────────────────────────────────────────────────────
418// Tests
419// ────────────────────────────────────────────────────────────────────────────
420
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    // ── parse_token_response ─────────────────────────────────────────────

    #[test]
    fn test_parse_token_response_success() {
        let json =
            r#"{"access_token":"eyJhbGciOiJSUzI1NiJ9","token_type":"Bearer","expires_in":3600}"#;
        let (token, expires) = parse_token_response(json).unwrap();
        assert_eq!(token.as_str(), "eyJhbGciOiJSUzI1NiJ9");
        assert_eq!(expires, Duration::from_secs(3600));
    }

    #[test]
    fn test_parse_token_response_missing_expires_defaults() {
        let json = r#"{"access_token":"tok_abc","token_type":"Bearer"}"#;
        let (token, expires) = parse_token_response(json).unwrap();
        assert_eq!(token.as_str(), "tok_abc");
        assert_eq!(expires, Duration::from_secs(DEFAULT_EXPIRES_IN_SECS));
    }

    #[test]
    fn test_parse_token_response_missing_access_token_errors() {
        let json = r#"{"token_type":"Bearer","expires_in":3600}"#;
        let err = parse_token_response(json).unwrap_err();
        let msg = err.to_string();
        assert!(
            msg.contains("access_token"),
            "error should mention access_token: {}",
            msg
        );
    }

    #[test]
    fn test_parse_token_response_non_json_errors() {
        let err = parse_token_response("this is not json").unwrap_err();
        let msg = err.to_string();
        assert!(
            msg.contains("invalid JSON"),
            "error should mention invalid JSON: {}",
            msg
        );
    }

    // ── build_token_request_body ─────────────────────────────────────────

    #[test]
    fn test_build_token_request_body() {
        let body = build_token_request_body("my-client", "s3cret!", "read write");
        assert!(body.contains("grant_type=client_credentials"));
        assert!(body.contains("client_id=my-client"));
        assert!(body.contains("client_secret=s3cret%21"));
        assert!(body.contains("scope=read%20write"));
    }

    #[test]
    fn test_build_token_request_body_no_scope() {
        let body = build_token_request_body("cid", "csec", "");
        assert!(body.contains("grant_type=client_credentials"));
        assert!(body.contains("client_id=cid"));
        assert!(body.contains("client_secret=csec"));
        assert!(!body.contains("scope="), "empty scope should be omitted");
    }

    // ── parse_status_code ────────────────────────────────────────────────

    #[test]
    fn test_parse_status_code_200() {
        assert_eq!(parse_status_code("HTTP/1.1 200 OK"), 200);
    }

    #[test]
    fn test_parse_status_code_401() {
        assert_eq!(parse_status_code("HTTP/1.1 401 Unauthorized"), 401);
    }

    #[test]
    fn test_parse_status_code_garbage() {
        assert_eq!(parse_status_code("not http"), 0);
    }

    // ── TokenCache expiry logic ──────────────────────────────────────────

    #[tokio::test]
    async fn test_token_cache_returns_valid_token() {
        // Construct a cache with a token that expires far in the future.
        let cache = make_test_cache("valid_token", Duration::from_secs(3600));
        let token = cache.get_or_refresh().await;
        assert_eq!(token.as_str(), "valid_token");
    }

    #[tokio::test]
    async fn test_token_cache_detects_expiry() {
        // Token that "expired" 10 seconds ago. The refresh attempt is expected
        // to fail (no token server is running), so the stale token is returned.
        let cache = make_test_cache("stale_token", Duration::from_secs(0));
        // Manually set expires_at to the past. `Instant - Duration` panics on
        // underflow, so fall back to "now" — still inside the expiry buffer —
        // if the monotonic clock cannot represent an earlier instant.
        {
            let mut guard = cache.token.write().await;
            guard.expires_at = Instant::now()
                .checked_sub(Duration::from_secs(10))
                .unwrap_or_else(Instant::now);
        }
        let token = cache.get_or_refresh().await;
        // Should still get the stale token (graceful degradation).
        assert_eq!(token.as_str(), "stale_token");
    }

    // ── Helpers ──────────────────────────────────────────────────────────

    /// Build a `TokenCache` with a pre-populated token for unit tests.
    /// The token URL points at port 1 on loopback, where no server is
    /// expected to listen, so any actual exchange attempt should fail
    /// (which is fine — we test cache logic).
    fn make_test_cache(token: &str, ttl: Duration) -> TokenCache {
        let config = OAuth2ExchangeConfig {
            token_url: "https://127.0.0.1:1/oauth/token".to_string(),
            client_id: Zeroizing::new("test-client".to_string()),
            client_secret: Zeroizing::new("test-secret".to_string()),
            scope: String::new(),
        };

        // Build a minimal TLS connector. It is required by the constructor but
        // not expected to be exercised: the TCP connect fails before any
        // handshake.
        let mut root_store = rustls::RootCertStore::empty();
        root_store.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
        let tls_config = rustls::ClientConfig::builder_with_provider(Arc::new(
            rustls::crypto::ring::default_provider(),
        ))
        .with_safe_default_protocol_versions()
        .unwrap()
        .with_root_certificates(root_store)
        .with_no_client_auth();
        let tls_connector = TlsConnector::from(Arc::new(tls_config));

        TokenCache::new_from_parts(config, tls_connector, token, ttl)
    }
}
557}