Skip to main content

adk_rs/auth/
manager.rs

1//! [`CredentialManager`] — the orchestrator that resolves a tool's auth needs
2//! into a usable credential. Ports Python ADK's 8-step workflow.
3
4use chrono::Utc;
5use std::sync::Arc;
6
7use crate::auth::config::AuthConfig;
8use crate::auth::credential::{AuthCredential, AuthCredentialType, OAuth2Auth};
9use crate::auth::exchanger::ExchangerRegistry;
10use crate::auth::handler::AuthHandler;
11use crate::auth::provider::AuthProviderRegistry;
12use crate::auth::refresher::RefresherRegistry;
13use crate::auth::scheme::AuthScheme;
14use crate::auth::service::CredentialService;
15use crate::error::{Error, Result};
16
17/// Outcome of a [`CredentialManager::resolve`] call.
18#[derive(Debug, Clone)]
19pub enum ResolveOutcome {
20    /// A usable credential is ready. Hand to the tool.
21    Ready(AuthCredential),
22    /// Interactive consent is required. The runner should emit
23    /// `adk_request_credential` and pause the tool call.
24    NeedsUserConsent(AuthConfig),
25    /// Configuration error — the tool can't be invoked.
26    Misconfigured(String),
27}
28
/// Returned by [`CredentialManager::begin_consent`].
///
/// Redirect the user to [`auth_uri`](Self::auth_uri). When the provider sends
/// the user back with `code` and `state`, pass those together with
/// [`flow_id`](Self::flow_id) to [`CredentialManager::complete_consent`],
/// which checks that `state` matches before exchanging `code`.
#[derive(Debug, Clone)]
pub struct ConsentRequest {
    /// Authorization URL to send the user to (it carries `state`,
    /// `code_challenge`, `client_id`, scopes, and so on).
    pub auth_uri: String,
    /// Opaque identifier of this consent flow. Persist it next to your own
    /// UI/session state and hand it back to `complete_consent`.
    pub flow_id: String,
}
42
/// Storage-key prefix for in-flight consent records.
///
/// `CredentialManager::begin_consent` saves an OAuth2 credential carrying the
/// CSRF `state` and the PKCE `code_verifier` in the [`CredentialService`]
/// under `__pending_consent:<flow_id>` (app `__adk`, user `__pending`).
/// This lets `CredentialManager::complete_consent` recover them after the
/// provider redirect, which may be handled by a different process.
const PENDING_CONSENT_PREFIX: &str = "__pending_consent:";
48
49fn pending_consent_key(flow_id: &str) -> String {
50    format!("{PENDING_CONSENT_PREFIX}{flow_id}")
51}
52
53/// Resolves [`AuthConfig`] into a ready [`AuthCredential`] per the 8-step
54/// workflow:
55///
56/// 1. validate config
57/// 2. return immediately if `is_ready` and not expired
58/// 3. try cache: `credential_service.load(app, user, key)`
59/// 4. (preprocessor-stored) auth response (handled at runner layer)
60/// 5. authorization-code flow with no exchanged credential → `NeedsUserConsent`
61/// 6. exchange (service-account / authorization-code → access token)
62/// 7. refresh if expired
63/// 8. save back to credential service
64#[derive(Debug)]
65pub struct CredentialManager {
66    config: AuthConfig,
67    exchangers: Arc<ExchangerRegistry>,
68    refreshers: Arc<RefresherRegistry>,
69    providers: Arc<AuthProviderRegistry>,
70}
71
72impl CredentialManager {
73    /// Construct with default exchangers + refreshers.
74    #[must_use]
75    pub fn new(config: AuthConfig) -> Self {
76        Self {
77            config,
78            exchangers: Arc::new(ExchangerRegistry::with_defaults()),
79            refreshers: Arc::new(RefresherRegistry::with_defaults()),
80            providers: Arc::new(AuthProviderRegistry::new()),
81        }
82    }
83
84    /// Construct with explicit registries (override for tests / custom providers).
85    #[must_use]
86    pub fn with_registries(
87        config: AuthConfig,
88        exchangers: Arc<ExchangerRegistry>,
89        refreshers: Arc<RefresherRegistry>,
90        providers: Arc<AuthProviderRegistry>,
91    ) -> Self {
92        Self {
93            config,
94            exchangers,
95            refreshers,
96            providers,
97        }
98    }
99
100    /// The cache key this manager resolves to.
101    #[must_use]
102    pub fn credential_key(&self) -> String {
103        self.config.resolve_credential_key()
104    }
105
106    /// Borrowed view of the wrapped config.
107    #[must_use]
108    pub fn config(&self) -> &AuthConfig {
109        &self.config
110    }
111
112    /// Run the resolution workflow.
113    pub async fn resolve(
114        &self,
115        app: &str,
116        user: &str,
117        credentials: Option<&dyn CredentialService>,
118    ) -> Result<ResolveOutcome> {
119        let raw = self
120            .config
121            .raw_auth_credential
122            .as_ref()
123            .ok_or_else(|| Error::config("AuthConfig.raw_auth_credential is required"))?;
124
125        // Step 2: already-ready and not expired? hand back.
126        let now = Utc::now().timestamp();
127        if raw.is_ready() && !raw.is_expired(now) {
128            return Ok(ResolveOutcome::Ready(raw.clone()));
129        }
130
131        let key = self.config.resolve_credential_key();
132
133        // Step 3: try cache.
134        if let Some(svc) = credentials {
135            if let Some(cached) = svc.load(app, user, &key).await? {
136                if cached.is_ready() && !cached.is_expired(now) {
137                    return Ok(ResolveOutcome::Ready(cached));
138                }
139                // Cached but expired — fall through to refresh.
140                if let Some(r) = self.refreshers.get(cached.auth_type) {
141                    if let Some(refreshed) = r.refresh(&self.config, &cached).await? {
142                        svc.save(app, user, &key, &refreshed).await?;
143                        return Ok(ResolveOutcome::Ready(refreshed));
144                    }
145                }
146            }
147        }
148
149        // Step 5: authorization-code flow with no consent yet → bubble out.
150        if matches!(
151            raw.auth_type,
152            AuthCredentialType::OAuth2 | AuthCredentialType::OpenIdConnect
153        ) && raw
154            .oauth2
155            .as_ref()
156            .is_some_and(|o| o.auth_code.is_none() && o.access_token.is_none())
157        {
158            return Ok(ResolveOutcome::NeedsUserConsent(self.config.clone()));
159        }
160
161        // Step 6: exchange.
162        if let Some(ex) = self.exchangers.get(raw.auth_type) {
163            if let Some(exchanged) = ex.exchange(&self.config, raw).await? {
164                if let Some(svc) = credentials {
165                    svc.save(app, user, &key, &exchanged).await?;
166                }
167                return Ok(ResolveOutcome::Ready(exchanged));
168            }
169        }
170
171        // Step 6b: custom provider escape hatch.
172        if let Some(prov) = self.providers.get(self.config.auth_scheme.kind()) {
173            if let Some(c) = prov.get_auth_credential(&self.config).await? {
174                if let Some(svc) = credentials {
175                    svc.save(app, user, &key, &c).await?;
176                }
177                return Ok(ResolveOutcome::Ready(c));
178            }
179        }
180
181        Ok(ResolveOutcome::Misconfigured(format!(
182            "no exchanger registered for {:?}; credential not ready",
183            raw.auth_type
184        )))
185    }
186
187    /// Start an OAuth 2.0 authorization-code consent flow.
188    ///
189    /// Generates a fresh CSRF `state` + PKCE verifier via the `oauth2` crate,
190    /// persists them in `credentials` keyed by an opaque `flow_id`, and
191    /// returns the URL the caller should redirect the user to. After the
192    /// provider redirects back, call [`Self::complete_consent`] with the
193    /// `flow_id`, the inbound `state`, and the inbound authorization `code` —
194    /// it will reject any mismatched state, perform the token exchange, and
195    /// save the resolved credential under the regular cache key.
196    ///
197    /// **Requires** a `credentials` service: the verifier and state must
198    /// outlive the HTTP redirect, so transient `None` storage isn't an
199    /// option here.
200    pub async fn begin_consent(
201        &self,
202        credentials: &dyn CredentialService,
203    ) -> Result<ConsentRequest> {
204        let raw = self
205            .config
206            .raw_auth_credential
207            .as_ref()
208            .ok_or_else(|| Error::config("AuthConfig.raw_auth_credential is required"))?;
209        let oauth2 = raw
210            .oauth2
211            .as_ref()
212            .ok_or_else(|| Error::config("begin_consent requires an OAuth2 credential"))?;
213        if !matches!(
214            self.config.auth_scheme,
215            AuthScheme::OAuth2 { .. } | AuthScheme::OpenIdConnect { .. }
216        ) {
217            return Err(Error::config(
218                "begin_consent requires an OAuth2 / OpenIdConnect scheme",
219            ));
220        }
221
222        let mut populated = oauth2.clone();
223        attach_flow_endpoints(&mut populated, &self.config.auth_scheme);
224        let handler = AuthHandler::from_oauth2(&populated)?;
225        let (auth_uri, state, verifier) = handler.authorize_url(&populated.scopes);
226
227        // Persist the in-flight verifier + state so `complete_consent` can
228        // validate the inbound callback. `flow_id` is what the caller hands
229        // back; we use `state` itself as the flow id since it's already
230        // a cryptographically-random opaque token from the `oauth2` crate.
231        let flow_id = state.clone();
232        let pending = AuthCredential::oauth2(OAuth2Auth {
233            client_id: populated.client_id.clone(),
234            client_secret: populated.client_secret.clone(),
235            auth_uri: populated.auth_uri.clone(),
236            token_uri: populated.token_uri.clone(),
237            redirect_uri: populated.redirect_uri.clone(),
238            state: Some(state),
239            code_verifier: Some(verifier),
240            scopes: populated.scopes.clone(),
241            ..OAuth2Auth::default()
242        });
243        // App / user are not yet known at this point — store under a
244        // process-wide bucket. Callers that need multi-tenant isolation can
245        // override `begin_consent_for` (see below).
246        credentials
247            .save(
248                "__adk",
249                "__pending",
250                &pending_consent_key(&flow_id),
251                &pending,
252            )
253            .await?;
254
255        Ok(ConsentRequest { auth_uri, flow_id })
256    }
257
258    /// Complete an OAuth 2.0 authorization-code consent flow.
259    ///
260    /// `callback_state` and `callback_code` are the `state` and `code` query
261    /// params received at the provider's redirect_uri. Validates that
262    /// `callback_state == flow_id` (the persisted state), exchanges the code
263    /// for an access token using the PKCE verifier persisted by
264    /// `begin_consent`, and writes the resolved credential under the regular
265    /// cache key for `(app, user)`. Returns the exchanged credential.
266    pub async fn complete_consent(
267        &self,
268        app: &str,
269        user: &str,
270        flow_id: &str,
271        callback_state: &str,
272        callback_code: &str,
273        credentials: &dyn CredentialService,
274    ) -> Result<AuthCredential> {
275        // Constant-time-ish equality (don't reveal mismatch length via early
276        // bail on prefix).
277        if !constant_time_eq(callback_state.as_bytes(), flow_id.as_bytes()) {
278            return Err(Error::other(
279                "OAuth2 callback `state` does not match the flow id (possible CSRF)",
280            ));
281        }
282
283        let pending_key = pending_consent_key(flow_id);
284        let pending = credentials
285            .load("__adk", "__pending", &pending_key)
286            .await?
287            .ok_or_else(|| {
288                Error::other(format!(
289                    "no pending consent for flow_id {flow_id:?} (expired or already used)"
290                ))
291            })?;
292        let pending_oauth2 = pending
293            .oauth2
294            .as_ref()
295            .ok_or_else(|| Error::other("pending consent payload is not OAuth2"))?;
296        let verifier = pending_oauth2
297            .code_verifier
298            .as_deref()
299            .ok_or_else(|| Error::other("pending consent has no PKCE verifier"))?;
300        let stored_state = pending_oauth2.state.as_deref().unwrap_or("");
301        if !constant_time_eq(stored_state.as_bytes(), flow_id.as_bytes()) {
302            return Err(Error::other(
303                "pending consent state mismatch (possible replay)",
304            ));
305        }
306
307        let handler = AuthHandler::from_oauth2(pending_oauth2)?;
308        let tok = handler.exchange_code(callback_code, verifier).await?;
309        let mut new = pending_oauth2.clone();
310        // Clear the one-shot fields so the resolved credential isn't
311        // re-exchanged by mistake.
312        new.state = None;
313        new.code_verifier = None;
314        new.auth_code = None;
315        tok.apply_to(&mut new);
316        let exchanged = AuthCredential::oauth2(new);
317
318        // Cache under the regular key.
319        let cache_key = self.config.resolve_credential_key();
320        credentials.save(app, user, &cache_key, &exchanged).await?;
321
322        // Single-use: remove the pending entry so a leaked redirect can't be
323        // replayed.
324        let _ = credentials.delete("__adk", "__pending", &pending_key).await;
325
326        Ok(exchanged)
327    }
328}
329
330/// Fill in `auth_uri` / `token_uri` from the scheme's authorization-code flow
331/// if they aren't already set. Mirrors `exchanger::attach_flow_endpoints`.
332fn attach_flow_endpoints(oauth2: &mut OAuth2Auth, scheme: &AuthScheme) {
333    if let AuthScheme::OAuth2 { flows, .. } = scheme {
334        if let Some(ac) = flows.authorization_code.as_ref() {
335            if oauth2.auth_uri.is_none() {
336                oauth2.auth_uri.clone_from(&ac.authorization_url);
337            }
338            if oauth2.token_uri.is_none() {
339                oauth2.token_uri = Some(ac.token_url.clone());
340            }
341        }
342    }
343}
344
/// Byte-slice equality that does not stop at the first differing byte.
///
/// Used to compare CSRF `state` / flow ids. Only the *contents* are compared
/// without early exit. Slices of different length return `false`
/// immediately, so the length is not hidden. That is acceptable for random
/// state tokens, whose length is not secret. This is best-effort: the
/// compiler is not formally prevented from optimizing the loop.
fn constant_time_eq(a: &[u8], b: &[u8]) -> bool {
    if a.len() != b.len() {
        return false;
    }
    // OR together every differing bit instead of returning at the first
    // mismatch.
    let diff = a.iter().zip(b).fold(0u8, |acc, (x, y)| acc | (x ^ y));
    diff == 0
}
357
#[cfg(test)]
mod tests {
    use super::*;
    use crate::auth::credential::{AuthCredential, OAuth2Auth};
    use crate::auth::scheme::{ApiKeyLocation, AuthScheme, OAuthFlow, OAuthFlows};
    use crate::auth::service::InMemoryCredentialService;

    /// Authorization-code OAuth2 config whose raw credential carries a client
    /// id + secret but no auth code and no access token.
    fn oauth2_code_flow_config() -> AuthConfig {
        AuthConfig::new(AuthScheme::OAuth2 {
            flows: OAuthFlows {
                authorization_code: Some(OAuthFlow {
                    authorization_url: Some("https://p/authorize".into()),
                    token_url: "https://p/token".into(),
                    refresh_url: None,
                    scopes: Default::default(),
                }),
                ..OAuthFlows::default()
            },
            description: None,
        })
        .with_raw(AuthCredential::oauth2(OAuth2Auth {
            client_id: "abc".into(),
            client_secret: Some("xyz".into()),
            ..OAuth2Auth::default()
        }))
    }

    #[tokio::test]
    async fn api_key_resolves_immediately() {
        let cfg = AuthConfig::new(AuthScheme::ApiKey {
            location: ApiKeyLocation::Header,
            name: "X-API-Key".into(),
            description: None,
        })
        .with_raw(AuthCredential::api_key("secret"));
        let mgr = CredentialManager::new(cfg);
        let svc = InMemoryCredentialService::new();
        match mgr.resolve("a", "u", Some(&svc)).await.unwrap() {
            ResolveOutcome::Ready(c) => assert_eq!(c.api_key.as_deref(), Some("secret")),
            other => panic!("unexpected outcome: {other:?}"),
        }
    }

    #[tokio::test]
    async fn oauth2_without_consent_returns_needs_user() {
        let mgr = CredentialManager::new(oauth2_code_flow_config());
        let svc = InMemoryCredentialService::new();
        match mgr.resolve("a", "u", Some(&svc)).await.unwrap() {
            ResolveOutcome::NeedsUserConsent(_) => {}
            other => panic!("unexpected outcome: {other:?}"),
        }
    }

    #[tokio::test]
    async fn cached_credential_is_returned_when_raw_not_ready() {
        // Raw credential carries client_id + secret but no access_token →
        // step 2 falls through; cache (step 3) hits and returns the cached
        // ready credential.
        let cfg = oauth2_code_flow_config().with_key("fixed");

        let cached = AuthCredential::oauth2(OAuth2Auth {
            client_id: "abc".into(),
            access_token: Some("CACHED_TOKEN".into()),
            ..OAuth2Auth::default()
        });
        let svc = InMemoryCredentialService::new();
        svc.save("a", "u", "fixed", &cached).await.unwrap();

        let mgr = CredentialManager::new(cfg);
        match mgr.resolve("a", "u", Some(&svc)).await.unwrap() {
            ResolveOutcome::Ready(c) => {
                assert_eq!(
                    c.oauth2.as_ref().and_then(|o| o.access_token.as_deref()),
                    Some("CACHED_TOKEN")
                );
            }
            other => panic!("unexpected outcome: {other:?}"),
        }
    }

    #[test]
    fn constant_time_eq_agrees_with_plain_equality() {
        assert!(constant_time_eq(b"", b""));
        assert!(constant_time_eq(b"state-123", b"state-123"));
        assert!(!constant_time_eq(b"state-123", b"state-124"));
        assert!(!constant_time_eq(b"state", b"state-123"));
    }

    #[test]
    fn pending_consent_key_is_prefixed() {
        assert_eq!(pending_consent_key("abc"), "__pending_consent:abc");
    }

    #[tokio::test]
    async fn complete_consent_rejects_mismatched_state() {
        let mgr = CredentialManager::new(oauth2_code_flow_config());
        let svc = InMemoryCredentialService::new();
        let result = mgr
            .complete_consent("a", "u", "flow-1", "flow-2", "code", &svc)
            .await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn complete_consent_rejects_unknown_flow() {
        let mgr = CredentialManager::new(oauth2_code_flow_config());
        let svc = InMemoryCredentialService::new();
        let result = mgr
            .complete_consent("a", "u", "flow-1", "flow-1", "code", &svc)
            .await;
        assert!(result.is_err());
        // Nothing may be cached for the user when the flow is unknown.
        let cached = svc.load("a", "u", &mgr.credential_key()).await.unwrap();
        assert!(cached.is_none());
    }
}
453            }
454            other => panic!("unexpected outcome: {other:?}"),
455        }
456    }
457}