Skip to main content

brainwires_tool_runtime/
oauth.rs

1//! OAuth 2.0 middleware for tool integrations.
2//!
3//! Gives agents access to OAuth-protected APIs (Google, GitHub, Salesforce,
4//! Slack, …) without hard-coding tokens. The framework handles:
5//!
6//! - Authorization Code + PKCE flow (user-delegated)
7//! - Client Credentials flow (service-to-service)
8//! - Automatic token refresh on expiry or `401 Unauthorized`
9//! - Pluggable [`OAuthTokenStore`] for per-user token storage
10//!
11//! ## Example — client credentials
12//!
13//! ```rust,no_run
14//! use brainwires_tool_runtime::oauth::{OAuthConfig, OAuthFlow, OAuthClient, InMemoryTokenStore};
15//!
16//! # async fn example() -> anyhow::Result<()> {
17//! let config = OAuthConfig::client_credentials(
18//!     "https://provider.example.com/token",
19//!     "my-client-id",
20//!     "my-client-secret",
21//!     &["read:data", "write:data"],
22//! );
23//!
24//! let store = InMemoryTokenStore::new();
25//! let client = OAuthClient::new(config, store)?;
26//!
27//! // Returns a valid Bearer token, refreshing if necessary.
28//! let token = client.access_token("service-account").await?;
29//! println!("Bearer {token}");
30//! # Ok(())
31//! # }
32//! ```
33
34use std::{
35    collections::HashMap,
36    sync::{Arc, Mutex},
37    time::{Duration, SystemTime},
38};
39
40use async_trait::async_trait;
41use reqwest::Client;
42use serde::{Deserialize, Serialize};
43
44// ── Token types ───────────────────────────────────────────────────────────────
45
46/// An OAuth 2.0 token pair.
47#[derive(Debug, Clone, Serialize, Deserialize)]
48pub struct OAuthToken {
49    /// The bearer access token.
50    pub access_token: String,
51    /// Refresh token (absent for client-credentials flows that don't issue one).
52    pub refresh_token: Option<String>,
53    /// UTC Unix timestamp (seconds) when the access token expires.
54    pub expires_at: Option<u64>,
55    /// Granted scopes.
56    pub scope: Option<String>,
57    /// Token type (usually `"Bearer"`).
58    pub token_type: String,
59}
60
61impl OAuthToken {
62    /// Returns `true` if the token is known to have expired (with a 30 s buffer).
63    pub fn is_expired(&self) -> bool {
64        let Some(exp) = self.expires_at else {
65            return false;
66        };
67        let now = SystemTime::now()
68            .duration_since(SystemTime::UNIX_EPOCH)
69            .unwrap_or_default()
70            .as_secs();
71        now + 30 >= exp
72    }
73}
74
75// ── Token store ───────────────────────────────────────────────────────────────
76
77/// Pluggable storage for OAuth tokens.
78///
79/// Keys are `(user_id, provider)` pairs. Implement this to persist tokens
80/// across agent restarts (e.g. in a keyring, SQLite, or secrets manager).
81#[async_trait]
82pub trait OAuthTokenStore: Send + Sync + 'static {
83    /// Retrieve a stored token.
84    async fn get(&self, user_id: &str, provider: &str) -> Option<OAuthToken>;
85    /// Store a token.
86    async fn set(&self, user_id: &str, provider: &str, token: OAuthToken);
87    /// Delete a token (e.g. on revocation).
88    async fn delete(&self, user_id: &str, provider: &str);
89}
90
91/// In-memory token store — tokens are lost when the process exits.
92#[derive(Clone, Default)]
93pub struct InMemoryTokenStore {
94    tokens: Arc<Mutex<HashMap<(String, String), OAuthToken>>>,
95}
96
97impl InMemoryTokenStore {
98    /// Create an empty store.
99    pub fn new() -> Self {
100        Self::default()
101    }
102}
103
104#[async_trait]
105impl OAuthTokenStore for InMemoryTokenStore {
106    async fn get(&self, user_id: &str, provider: &str) -> Option<OAuthToken> {
107        self.tokens
108            .lock()
109            .unwrap()
110            .get(&(user_id.to_string(), provider.to_string()))
111            .cloned()
112    }
113
114    async fn set(&self, user_id: &str, provider: &str, token: OAuthToken) {
115        self.tokens
116            .lock()
117            .unwrap()
118            .insert((user_id.to_string(), provider.to_string()), token);
119    }
120
121    async fn delete(&self, user_id: &str, provider: &str) {
122        self.tokens
123            .lock()
124            .unwrap()
125            .remove(&(user_id.to_string(), provider.to_string()));
126    }
127}
128
129// ── Config ────────────────────────────────────────────────────────────────────
130
131/// Which OAuth 2.0 grant type to use.
132#[derive(Debug, Clone)]
133pub enum OAuthFlow {
134    /// Authorization Code + PKCE (RFC 7636) — user-delegated access.
135    AuthorizationCodePkce {
136        /// Authorization endpoint URL.
137        auth_url: String,
138        /// Token endpoint URL.
139        token_url: String,
140        /// Redirect URI registered with the OAuth provider.
141        redirect_uri: String,
142    },
143    /// Client Credentials (RFC 6749 §4.4) — service-to-service.
144    ClientCredentials {
145        /// Token endpoint URL.
146        token_url: String,
147    },
148    /// Refresh-only — no interactive flow; start with a pre-existing token.
149    RefreshOnly {
150        /// Token endpoint URL.
151        token_url: String,
152    },
153}
154
155/// OAuth 2.0 application configuration.
156#[derive(Debug, Clone)]
157pub struct OAuthConfig {
158    /// Human-readable provider name (e.g. `"google"`, `"github"`).
159    pub provider: String,
160    /// OAuth client ID.
161    pub client_id: String,
162    /// OAuth client secret (omit for public clients).
163    pub client_secret: Option<String>,
164    /// Requested permission scopes.
165    pub scopes: Vec<String>,
166    /// Grant flow.
167    pub flow: OAuthFlow,
168    /// HTTP request timeout (default: 30 s).
169    pub timeout: Duration,
170}
171
172impl OAuthConfig {
173    /// Build a client-credentials config.
174    pub fn client_credentials(
175        token_url: impl Into<String>,
176        client_id: impl Into<String>,
177        client_secret: impl Into<String>,
178        scopes: &[&str],
179    ) -> Self {
180        Self {
181            provider: "custom".to_string(),
182            client_id: client_id.into(),
183            client_secret: Some(client_secret.into()),
184            scopes: scopes.iter().map(|s| s.to_string()).collect(),
185            flow: OAuthFlow::ClientCredentials {
186                token_url: token_url.into(),
187            },
188            timeout: Duration::from_secs(30),
189        }
190    }
191
192    /// Build an Authorization Code + PKCE config.
193    pub fn authorization_code_pkce(
194        provider: impl Into<String>,
195        auth_url: impl Into<String>,
196        token_url: impl Into<String>,
197        redirect_uri: impl Into<String>,
198        client_id: impl Into<String>,
199        scopes: &[&str],
200    ) -> Self {
201        Self {
202            provider: provider.into(),
203            client_id: client_id.into(),
204            client_secret: None,
205            scopes: scopes.iter().map(|s| s.to_string()).collect(),
206            flow: OAuthFlow::AuthorizationCodePkce {
207                auth_url: auth_url.into(),
208                token_url: token_url.into(),
209                redirect_uri: redirect_uri.into(),
210            },
211            timeout: Duration::from_secs(30),
212        }
213    }
214
215    /// Override the HTTP timeout.
216    pub fn with_timeout(mut self, timeout: Duration) -> Self {
217        self.timeout = timeout;
218        self
219    }
220
221    /// Override the provider name.
222    pub fn with_provider(mut self, provider: impl Into<String>) -> Self {
223        self.provider = provider.into();
224        self
225    }
226}
227
228// ── PKCE helpers ──────────────────────────────────────────────────────────────
229
230/// A PKCE challenge pair.
231#[derive(Debug, Clone)]
232pub struct PkceChallenge {
233    /// The `code_verifier` to include in the token exchange request.
234    pub verifier: String,
235    /// The `code_challenge` to include in the authorization URL.
236    pub challenge: String,
237}
238
239impl PkceChallenge {
240    /// Generate a fresh PKCE challenge using SHA-256.
241    pub fn new() -> Self {
242        use sha2::{Digest, Sha256};
243
244        // 32 cryptographically random bytes → base64url verifier
245        let mut raw = [0u8; 32];
246        getrandom::getrandom(&mut raw).expect("CSPRNG unavailable");
247        let verifier = base64_url_encode(&raw);
248
249        // SHA-256(verifier) → base64url challenge
250        let mut hasher = Sha256::new();
251        hasher.update(verifier.as_bytes());
252        let digest = hasher.finalize();
253        let challenge = base64_url_encode(&digest);
254
255        Self {
256            verifier,
257            challenge,
258        }
259    }
260
261    /// Build the authorization URL with PKCE parameters appended.
262    pub fn authorization_url(
263        &self,
264        auth_url: &str,
265        client_id: &str,
266        redirect_uri: &str,
267        scopes: &[String],
268        state: &str,
269    ) -> String {
270        let scope = scopes.join(" ");
271        format!(
272            "{auth_url}?response_type=code\
273             &client_id={client_id}\
274             &redirect_uri={redirect_uri}\
275             &scope={scope}\
276             &state={state}\
277             &code_challenge={}\
278             &code_challenge_method=S256",
279            self.challenge
280        )
281    }
282}
283
284impl Default for PkceChallenge {
285    fn default() -> Self {
286        Self::new()
287    }
288}
289
290fn base64_url_encode(data: &[u8]) -> String {
291    use std::fmt::Write;
292    // RFC 4648 base64url without padding
293    const CHARS: &[u8] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_";
294    let mut out = String::with_capacity((data.len() * 4).div_ceil(3));
295    for chunk in data.chunks(3) {
296        let b0 = chunk[0] as usize;
297        let b1 = if chunk.len() > 1 {
298            chunk[1] as usize
299        } else {
300            0
301        };
302        let b2 = if chunk.len() > 2 {
303            chunk[2] as usize
304        } else {
305            0
306        };
307        let _ = write!(out, "{}", CHARS[(b0 >> 2) & 63] as char);
308        let _ = write!(out, "{}", CHARS[((b0 << 4) | (b1 >> 4)) & 63] as char);
309        if chunk.len() > 1 {
310            let _ = write!(out, "{}", CHARS[((b1 << 2) | (b2 >> 6)) & 63] as char);
311        }
312        if chunk.len() > 2 {
313            let _ = write!(out, "{}", CHARS[b2 & 63] as char);
314        }
315    }
316    out
317}
318
319// ── OAuthClient ───────────────────────────────────────────────────────────────
320
321/// OAuth 2.0 client that manages tokens on behalf of users.
322///
323/// Wrap this inside a tool implementation to produce a fresh Bearer token
324/// for every API call, automatically refreshing when the stored token expires.
325pub struct OAuthClient<S: OAuthTokenStore> {
326    config: OAuthConfig,
327    store: S,
328    http: Client,
329}
330
331impl<S: OAuthTokenStore> OAuthClient<S> {
332    /// Build a client from config and token store.
333    pub fn new(config: OAuthConfig, store: S) -> anyhow::Result<Self> {
334        let http = Client::builder().timeout(config.timeout).build()?;
335        Ok(Self {
336            config,
337            store,
338            http,
339        })
340    }
341
342    /// Return a valid access token for `user_id`, refreshing or fetching as needed.
343    ///
344    /// - If a non-expired token is in the store, it is returned immediately.
345    /// - If the token is expired and a refresh token is available, it is refreshed.
346    /// - If no token exists and the flow is `ClientCredentials`, a new token is fetched.
347    /// - Otherwise returns `Err` — the caller must initiate an interactive auth flow.
348    pub async fn access_token(&self, user_id: &str) -> anyhow::Result<String> {
349        // 1. Check store
350        if let Some(token) = self.store.get(user_id, &self.config.provider).await {
351            if !token.is_expired() {
352                return Ok(token.access_token.clone());
353            }
354            // Try to refresh
355            if let Some(refresh_token) = &token.refresh_token
356                && let Ok(refreshed) = self.refresh_token(refresh_token).await
357            {
358                self.store
359                    .set(user_id, &self.config.provider, refreshed.clone())
360                    .await;
361                return Ok(refreshed.access_token);
362                // Refresh failed — fall through to re-auth
363            }
364        }
365
366        // 2. Client credentials can fetch without user interaction
367        if let OAuthFlow::ClientCredentials { .. } = &self.config.flow {
368            let token = self.fetch_client_credentials().await?;
369            self.store
370                .set(user_id, &self.config.provider, token.clone())
371                .await;
372            return Ok(token.access_token);
373        }
374
375        anyhow::bail!(
376            "No valid token for user '{}' on provider '{}'. \
377             Initiate an authorization flow first via OAuthClient::authorization_url().",
378            user_id,
379            self.config.provider
380        )
381    }
382
383    /// Store a token that was obtained through an external interactive flow.
384    pub async fn store_token(&self, user_id: &str, token: OAuthToken) {
385        self.store.set(user_id, &self.config.provider, token).await;
386    }
387
388    /// Delete the stored token for a user (e.g. on sign-out or revocation).
389    pub async fn revoke(&self, user_id: &str) {
390        self.store.delete(user_id, &self.config.provider).await;
391    }
392
393    /// Exchange an authorization code for tokens (Authorization Code + PKCE).
394    ///
395    /// Call this after the user is redirected back to your `redirect_uri` with
396    /// a `code` parameter.
397    pub async fn exchange_code(&self, code: &str, verifier: &str) -> anyhow::Result<OAuthToken> {
398        let token_url = match &self.config.flow {
399            OAuthFlow::AuthorizationCodePkce {
400                token_url,
401                redirect_uri,
402                ..
403            } => (token_url.clone(), Some(redirect_uri.clone())),
404            _ => anyhow::bail!("exchange_code requires AuthorizationCodePkce flow"),
405        };
406
407        let mut params = vec![
408            ("grant_type", "authorization_code".to_string()),
409            ("code", code.to_string()),
410            ("client_id", self.config.client_id.clone()),
411            ("code_verifier", verifier.to_string()),
412        ];
413        if let Some(uri) = token_url.1 {
414            params.push(("redirect_uri", uri));
415        }
416        if let Some(secret) = &self.config.client_secret {
417            params.push(("client_secret", secret.clone()));
418        }
419
420        self.post_token(&token_url.0, &params).await
421    }
422
423    /// Build a PKCE authorization URL for the user to visit.
424    ///
425    /// Returns `(url, pkce_challenge)` — store the `challenge.verifier` so you
426    /// can pass it to [`exchange_code`] when the callback arrives.
427    pub fn authorization_url(&self, state: &str) -> anyhow::Result<(String, PkceChallenge)> {
428        match &self.config.flow {
429            OAuthFlow::AuthorizationCodePkce {
430                auth_url,
431                redirect_uri,
432                ..
433            } => {
434                let pkce = PkceChallenge::new();
435                let url = pkce.authorization_url(
436                    auth_url,
437                    &self.config.client_id,
438                    redirect_uri,
439                    &self.config.scopes,
440                    state,
441                );
442                Ok((url, pkce))
443            }
444            _ => anyhow::bail!("authorization_url requires AuthorizationCodePkce flow"),
445        }
446    }
447
448    // ── Internal ─────────────────────────────────────────────────────────────
449
450    async fn fetch_client_credentials(&self) -> anyhow::Result<OAuthToken> {
451        let token_url = match &self.config.flow {
452            OAuthFlow::ClientCredentials { token_url } => token_url.clone(),
453            _ => anyhow::bail!("fetch_client_credentials called on non-ClientCredentials flow"),
454        };
455
456        let mut params = vec![
457            ("grant_type", "client_credentials".to_string()),
458            ("client_id", self.config.client_id.clone()),
459        ];
460        if !self.config.scopes.is_empty() {
461            params.push(("scope", self.config.scopes.join(" ")));
462        }
463        if let Some(secret) = &self.config.client_secret {
464            params.push(("client_secret", secret.clone()));
465        }
466
467        self.post_token(&token_url, &params).await
468    }
469
470    async fn refresh_token(&self, refresh_token: &str) -> anyhow::Result<OAuthToken> {
471        let token_url = match &self.config.flow {
472            OAuthFlow::AuthorizationCodePkce { token_url, .. } => token_url.clone(),
473            OAuthFlow::RefreshOnly { token_url } => token_url.clone(),
474            OAuthFlow::ClientCredentials { token_url } => token_url.clone(),
475        };
476
477        let mut params = vec![
478            ("grant_type", "refresh_token".to_string()),
479            ("refresh_token", refresh_token.to_string()),
480            ("client_id", self.config.client_id.clone()),
481        ];
482        if let Some(secret) = &self.config.client_secret {
483            params.push(("client_secret", secret.clone()));
484        }
485
486        self.post_token(&token_url, &params).await
487    }
488
489    async fn post_token(&self, url: &str, params: &[(&str, String)]) -> anyhow::Result<OAuthToken> {
490        let resp = self
491            .http
492            .post(url)
493            .form(params)
494            .send()
495            .await
496            .map_err(|e| anyhow::anyhow!("Token request failed: {e}"))?;
497
498        let status = resp.status();
499        let body = resp.text().await.unwrap_or_default();
500        if !status.is_success() {
501            anyhow::bail!("Token endpoint returned {status}: {body}");
502        }
503
504        let raw: TokenResponse =
505            serde_json::from_str(&body).map_err(|e| anyhow::anyhow!("Token parse error: {e}"))?;
506
507        let expires_at = raw.expires_in.map(|secs| {
508            SystemTime::now()
509                .duration_since(SystemTime::UNIX_EPOCH)
510                .unwrap_or_default()
511                .as_secs()
512                + secs
513        });
514
515        Ok(OAuthToken {
516            access_token: raw.access_token,
517            refresh_token: raw.refresh_token,
518            expires_at,
519            scope: raw.scope,
520            token_type: raw.token_type.unwrap_or_else(|| "Bearer".to_string()),
521        })
522    }
523}
524
525/// Raw OAuth token endpoint response.
526#[derive(Deserialize)]
527struct TokenResponse {
528    access_token: String,
529    refresh_token: Option<String>,
530    expires_in: Option<u64>,
531    scope: Option<String>,
532    token_type: Option<String>,
533}
534
535// ── Tests ─────────────────────────────────────────────────────────────────────
536
537#[cfg(test)]
538mod tests {
539    use super::*;
540
541    #[test]
542    fn pkce_challenge_base64url_no_padding() {
543        let pkce = PkceChallenge::new();
544        assert!(!pkce.verifier.contains('='));
545        assert!(!pkce.challenge.contains('='));
546        assert!(!pkce.verifier.contains('+'));
547        assert!(!pkce.challenge.contains('+'));
548        assert!(!pkce.verifier.contains('/'));
549        assert!(!pkce.challenge.contains('/'));
550    }
551
552    #[test]
553    fn pkce_authorization_url_contains_required_params() {
554        let pkce = PkceChallenge::new();
555        let url = pkce.authorization_url(
556            "https://auth.example.com/authorize",
557            "client-abc",
558            "https://myapp.example.com/callback",
559            &["openid".to_string(), "profile".to_string()],
560            "random-state",
561        );
562        assert!(url.contains("response_type=code"));
563        assert!(url.contains("client_id=client-abc"));
564        assert!(url.contains("code_challenge_method=S256"));
565        assert!(url.contains(&pkce.challenge));
566        assert!(url.contains("state=random-state"));
567    }
568
569    #[test]
570    fn token_not_expired_without_expiry() {
571        let t = OAuthToken {
572            access_token: "tok".to_string(),
573            refresh_token: None,
574            expires_at: None,
575            scope: None,
576            token_type: "Bearer".to_string(),
577        };
578        assert!(!t.is_expired());
579    }
580
581    #[test]
582    fn token_expired_in_past() {
583        let t = OAuthToken {
584            access_token: "tok".to_string(),
585            refresh_token: None,
586            expires_at: Some(1), // way in the past
587            scope: None,
588            token_type: "Bearer".to_string(),
589        };
590        assert!(t.is_expired());
591    }
592
593    #[test]
594    fn in_memory_store_operations() {
595        let rt = tokio::runtime::Builder::new_current_thread()
596            .build()
597            .unwrap();
598        rt.block_on(async {
599            let store = InMemoryTokenStore::new();
600            let token = OAuthToken {
601                access_token: "abc".to_string(),
602                refresh_token: None,
603                expires_at: None,
604                scope: None,
605                token_type: "Bearer".to_string(),
606            };
607            store.set("user1", "github", token.clone()).await;
608            let fetched = store.get("user1", "github").await.unwrap();
609            assert_eq!(fetched.access_token, "abc");
610
611            store.delete("user1", "github").await;
612            assert!(store.get("user1", "github").await.is_none());
613        });
614    }
615
616    #[test]
617    fn config_client_credentials_builder() {
618        let cfg = OAuthConfig::client_credentials(
619            "https://token.example.com",
620            "id",
621            "secret",
622            &["read", "write"],
623        );
624        assert_eq!(cfg.scopes, vec!["read", "write"]);
625        matches!(cfg.flow, OAuthFlow::ClientCredentials { .. });
626    }
627}