use chrono::Utc;
use std::sync::Arc;
use crate::auth::config::AuthConfig;
use crate::auth::credential::{AuthCredential, AuthCredentialType, OAuth2Auth};
use crate::auth::exchanger::ExchangerRegistry;
use crate::auth::handler::AuthHandler;
use crate::auth::provider::AuthProviderRegistry;
use crate::auth::refresher::RefresherRegistry;
use crate::auth::scheme::AuthScheme;
use crate::auth::service::CredentialService;
use crate::error::{Error, Result};
/// Outcome of [`CredentialManager::resolve`].
#[derive(Debug, Clone)]
pub enum ResolveOutcome {
    /// A credential that resolution judged ready to use (raw, cached,
    /// refreshed, exchanged, or provider-supplied).
    Ready(AuthCredential),
    /// The user must complete an interactive consent flow first; carries a
    /// copy of the manager's configuration.
    NeedsUserConsent(AuthConfig),
    /// No step could produce a credential; the string explains why.
    Misconfigured(String),
}
/// Information returned when an OAuth2 consent flow is started.
#[derive(Debug, Clone)]
pub struct ConsentRequest {
    /// Authorization URL the user should be sent to.
    pub auth_uri: String,
    /// Identifier of the in-flight flow (the OAuth2 `state` value); pass it
    /// back when completing the consent.
    pub flow_id: String,
}
/// Prefix for the storage keys of in-flight consent records.
const PENDING_CONSENT_PREFIX: &str = "__pending_consent:";

/// Returns the storage key under which the pending consent record for
/// `flow_id` is kept: the prefix followed directly by the flow id.
fn pending_consent_key(flow_id: &str) -> String {
    [PENDING_CONSENT_PREFIX, flow_id].concat()
}
/// Resolves credentials for a single [`AuthConfig`], optionally persisting
/// them through a `CredentialService`.
///
/// The registries are held behind `Arc` so they can be shared between
/// managers.
#[derive(Debug)]
pub struct CredentialManager {
    /// Auth configuration, including the raw credential and the auth scheme.
    config: AuthConfig,
    /// Exchangers, looked up by the raw credential's type.
    exchangers: Arc<ExchangerRegistry>,
    /// Refreshers, looked up by the cached credential's type.
    refreshers: Arc<RefresherRegistry>,
    /// Providers, looked up by the auth scheme's kind.
    providers: Arc<AuthProviderRegistry>,
}
impl CredentialManager {
#[must_use]
pub fn new(config: AuthConfig) -> Self {
Self {
config,
exchangers: Arc::new(ExchangerRegistry::with_defaults()),
refreshers: Arc::new(RefresherRegistry::with_defaults()),
providers: Arc::new(AuthProviderRegistry::new()),
}
}
#[must_use]
pub fn with_registries(
config: AuthConfig,
exchangers: Arc<ExchangerRegistry>,
refreshers: Arc<RefresherRegistry>,
providers: Arc<AuthProviderRegistry>,
) -> Self {
Self {
config,
exchangers,
refreshers,
providers,
}
}
#[must_use]
pub fn credential_key(&self) -> String {
self.config.resolve_credential_key()
}
#[must_use]
pub fn config(&self) -> &AuthConfig {
&self.config
}
pub async fn resolve(
&self,
app: &str,
user: &str,
credentials: Option<&dyn CredentialService>,
) -> Result<ResolveOutcome> {
let raw = self
.config
.raw_auth_credential
.as_ref()
.ok_or_else(|| Error::config("AuthConfig.raw_auth_credential is required"))?;
let now = Utc::now().timestamp();
if raw.is_ready() && !raw.is_expired(now) {
return Ok(ResolveOutcome::Ready(raw.clone()));
}
let key = self.config.resolve_credential_key();
if let Some(svc) = credentials {
if let Some(cached) = svc.load(app, user, &key).await? {
if cached.is_ready() && !cached.is_expired(now) {
return Ok(ResolveOutcome::Ready(cached));
}
if let Some(r) = self.refreshers.get(cached.auth_type) {
if let Some(refreshed) = r.refresh(&self.config, &cached).await? {
svc.save(app, user, &key, &refreshed).await?;
return Ok(ResolveOutcome::Ready(refreshed));
}
}
}
}
if matches!(
raw.auth_type,
AuthCredentialType::OAuth2 | AuthCredentialType::OpenIdConnect
) && raw
.oauth2
.as_ref()
.is_some_and(|o| o.auth_code.is_none() && o.access_token.is_none())
{
return Ok(ResolveOutcome::NeedsUserConsent(self.config.clone()));
}
if let Some(ex) = self.exchangers.get(raw.auth_type) {
if let Some(exchanged) = ex.exchange(&self.config, raw).await? {
if let Some(svc) = credentials {
svc.save(app, user, &key, &exchanged).await?;
}
return Ok(ResolveOutcome::Ready(exchanged));
}
}
if let Some(prov) = self.providers.get(self.config.auth_scheme.kind()) {
if let Some(c) = prov.get_auth_credential(&self.config).await? {
if let Some(svc) = credentials {
svc.save(app, user, &key, &c).await?;
}
return Ok(ResolveOutcome::Ready(c));
}
}
Ok(ResolveOutcome::Misconfigured(format!(
"no exchanger registered for {:?}; credential not ready",
raw.auth_type
)))
}
pub async fn begin_consent(
&self,
credentials: &dyn CredentialService,
) -> Result<ConsentRequest> {
let raw = self
.config
.raw_auth_credential
.as_ref()
.ok_or_else(|| Error::config("AuthConfig.raw_auth_credential is required"))?;
let oauth2 = raw
.oauth2
.as_ref()
.ok_or_else(|| Error::config("begin_consent requires an OAuth2 credential"))?;
if !matches!(
self.config.auth_scheme,
AuthScheme::OAuth2 { .. } | AuthScheme::OpenIdConnect { .. }
) {
return Err(Error::config(
"begin_consent requires an OAuth2 / OpenIdConnect scheme",
));
}
let mut populated = oauth2.clone();
attach_flow_endpoints(&mut populated, &self.config.auth_scheme);
let handler = AuthHandler::from_oauth2(&populated)?;
let (auth_uri, state, verifier) = handler.authorize_url(&populated.scopes);
let flow_id = state.clone();
let pending = AuthCredential::oauth2(OAuth2Auth {
client_id: populated.client_id.clone(),
client_secret: populated.client_secret.clone(),
auth_uri: populated.auth_uri.clone(),
token_uri: populated.token_uri.clone(),
redirect_uri: populated.redirect_uri.clone(),
state: Some(state),
code_verifier: Some(verifier),
scopes: populated.scopes.clone(),
..OAuth2Auth::default()
});
credentials
.save(
"__adk",
"__pending",
&pending_consent_key(&flow_id),
&pending,
)
.await?;
Ok(ConsentRequest { auth_uri, flow_id })
}
pub async fn complete_consent(
&self,
app: &str,
user: &str,
flow_id: &str,
callback_state: &str,
callback_code: &str,
credentials: &dyn CredentialService,
) -> Result<AuthCredential> {
if !constant_time_eq(callback_state.as_bytes(), flow_id.as_bytes()) {
return Err(Error::other(
"OAuth2 callback `state` does not match the flow id (possible CSRF)",
));
}
let pending_key = pending_consent_key(flow_id);
let pending = credentials
.load("__adk", "__pending", &pending_key)
.await?
.ok_or_else(|| {
Error::other(format!(
"no pending consent for flow_id {flow_id:?} (expired or already used)"
))
})?;
let pending_oauth2 = pending
.oauth2
.as_ref()
.ok_or_else(|| Error::other("pending consent payload is not OAuth2"))?;
let verifier = pending_oauth2
.code_verifier
.as_deref()
.ok_or_else(|| Error::other("pending consent has no PKCE verifier"))?;
let stored_state = pending_oauth2.state.as_deref().unwrap_or("");
if !constant_time_eq(stored_state.as_bytes(), flow_id.as_bytes()) {
return Err(Error::other(
"pending consent state mismatch (possible replay)",
));
}
let handler = AuthHandler::from_oauth2(pending_oauth2)?;
let tok = handler.exchange_code(callback_code, verifier).await?;
let mut new = pending_oauth2.clone();
new.state = None;
new.code_verifier = None;
new.auth_code = None;
tok.apply_to(&mut new);
let exchanged = AuthCredential::oauth2(new);
let cache_key = self.config.resolve_credential_key();
credentials.save(app, user, &cache_key, &exchanged).await?;
let _ = credentials.delete("__adk", "__pending", &pending_key).await;
Ok(exchanged)
}
}
/// Fills in `auth_uri` / `token_uri` on `oauth2` from the scheme's
/// authorization-code flow when they are not already set.
///
/// Values already present on the credential take precedence. Two cases leave
/// `oauth2` untouched:
/// - schemes other than `AuthScheme::OAuth2`;
/// - an OAuth2 scheme without an authorization-code flow.
fn attach_flow_endpoints(oauth2: &mut OAuth2Auth, scheme: &AuthScheme) {
    let AuthScheme::OAuth2 { flows, .. } = scheme else {
        return;
    };
    let Some(code_flow) = flows.authorization_code.as_ref() else {
        return;
    };
    if oauth2.auth_uri.is_none() {
        oauth2.auth_uri = code_flow.authorization_url.clone();
    }
    if oauth2.token_uri.is_none() {
        oauth2.token_uri = Some(code_flow.token_url.clone());
    }
}
/// Compares two byte strings for equality.
///
/// Inputs of different lengths return `false` immediately. Otherwise every
/// byte pair is XOR-ed into an accumulator, with no early exit on the first
/// difference. The intent is that timing does not reveal where the inputs
/// differ.
fn constant_time_eq(a: &[u8], b: &[u8]) -> bool {
    if a.len() != b.len() {
        return false;
    }
    let accumulated = a
        .iter()
        .zip(b)
        .fold(0u8, |acc, (lhs, rhs)| acc | (lhs ^ rhs));
    accumulated == 0
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::auth::credential::{AuthCredential, OAuth2Auth};
    use crate::auth::scheme::{ApiKeyLocation, AuthScheme, OAuthFlow, OAuthFlows};
    use crate::auth::service::InMemoryCredentialService;

    /// Builds an OAuth2 authorization-code config. Its raw credential has a
    /// client id and secret but neither an auth code nor an access token.
    fn oauth2_config() -> AuthConfig {
        AuthConfig::new(AuthScheme::OAuth2 {
            flows: OAuthFlows {
                authorization_code: Some(OAuthFlow {
                    authorization_url: Some("https://p/authorize".into()),
                    token_url: "https://p/token".into(),
                    refresh_url: None,
                    scopes: Default::default(),
                }),
                ..OAuthFlows::default()
            },
            description: None,
        })
        .with_raw(AuthCredential::oauth2(OAuth2Auth {
            client_id: "abc".into(),
            client_secret: Some("xyz".into()),
            ..OAuth2Auth::default()
        }))
    }

    #[tokio::test]
    async fn api_key_resolves_immediately() {
        let cfg = AuthConfig::new(AuthScheme::ApiKey {
            location: ApiKeyLocation::Header,
            name: "X-API-Key".into(),
            description: None,
        })
        .with_raw(AuthCredential::api_key("secret"));
        let mgr = CredentialManager::new(cfg);
        let svc = InMemoryCredentialService::new();
        match mgr.resolve("a", "u", Some(&svc)).await.unwrap() {
            ResolveOutcome::Ready(c) => assert_eq!(c.api_key.as_deref(), Some("secret")),
            other => panic!("unexpected outcome: {other:?}"),
        }
    }

    #[tokio::test]
    async fn oauth2_without_consent_returns_needs_user() {
        let mgr = CredentialManager::new(oauth2_config());
        let svc = InMemoryCredentialService::new();
        match mgr.resolve("a", "u", Some(&svc)).await.unwrap() {
            ResolveOutcome::NeedsUserConsent(_) => {}
            other => panic!("unexpected outcome: {other:?}"),
        }
    }

    #[tokio::test]
    async fn cached_credential_is_returned_when_raw_not_ready() {
        let cfg = oauth2_config().with_key("fixed");
        let cached = AuthCredential::oauth2(OAuth2Auth {
            client_id: "abc".into(),
            access_token: Some("CACHED_TOKEN".into()),
            ..OAuth2Auth::default()
        });
        let svc = InMemoryCredentialService::new();
        svc.save("a", "u", "fixed", &cached).await.unwrap();
        let mgr = CredentialManager::new(cfg);
        match mgr.resolve("a", "u", Some(&svc)).await.unwrap() {
            ResolveOutcome::Ready(c) => {
                assert_eq!(
                    c.oauth2.as_ref().and_then(|o| o.access_token.as_deref()),
                    Some("CACHED_TOKEN")
                );
            }
            other => panic!("unexpected outcome: {other:?}"),
        }
    }

    #[tokio::test]
    async fn complete_consent_rejects_state_mismatch() {
        let mgr = CredentialManager::new(oauth2_config());
        let svc = InMemoryCredentialService::new();
        let result = mgr
            .complete_consent("a", "u", "flow-1", "flow-2", "code", &svc)
            .await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn complete_consent_rejects_unknown_flow() {
        let mgr = CredentialManager::new(oauth2_config());
        let svc = InMemoryCredentialService::new();
        let result = mgr
            .complete_consent("a", "u", "unknown-flow", "unknown-flow", "code", &svc)
            .await;
        assert!(result.is_err());
    }

    #[test]
    fn constant_time_eq_compares_contents_and_length() {
        assert!(constant_time_eq(b"state", b"state"));
        assert!(!constant_time_eq(b"state", b"stale"));
        assert!(!constant_time_eq(b"state", b"stat"));
        assert!(constant_time_eq(b"", b""));
    }

    #[test]
    fn pending_consent_key_prefixes_flow_id() {
        assert_eq!(pending_consent_key("abc"), "__pending_consent:abc");
    }
}