// agent_diva_core/auth/oauth_common.rs

1use base64::Engine;
2use sha2::{Digest, Sha256};
3use std::collections::BTreeMap;
4
5use crate::auth::profiles::ProviderTokenSet;
6
/// Per-request OAuth `state` value together with a PKCE (RFC 7636) verifier/challenge pair.
///
/// Values of this type are produced by `generate_pkce_state`. Because `Debug` is derived,
/// formatting with `{:?}` prints `code_verifier` in clear text.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct PkceState {
    /// Value compared against the `state` returned on the redirect.
    pub state: String,
    /// The PKCE `code_verifier`; a secret for the lifetime of the authorization request.
    pub code_verifier: String,
    /// The PKCE `code_challenge` derived from `code_verifier`.
    pub code_challenge: String,
}
13
14#[derive(Debug, Clone, PartialEq, Eq)]
15pub struct OAuthProfileState {
16    pub token_set: ProviderTokenSet,
17    pub account_id: Option<String>,
18    pub metadata: BTreeMap<String, String>,
19}
20
21#[async_trait::async_trait]
22pub trait OAuthTokenManager: Send + Sync {
23    async fn refresh_oauth_state(&self, refresh_token: &str) -> anyhow::Result<OAuthProfileState>;
24
25    fn extract_account_id(&self, access_token: &str) -> Option<String>;
26}
27
28pub fn generate_pkce_state() -> PkceState {
29    let state = uuid::Uuid::new_v4().to_string();
30    let code_verifier = format!("{}{}", uuid::Uuid::new_v4(), uuid::Uuid::new_v4());
31    let digest = Sha256::digest(code_verifier.as_bytes());
32    let code_challenge = base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(digest);
33    PkceState {
34        state,
35        code_verifier,
36        code_challenge,
37    }
38}
39
40pub fn parse_code_from_redirect(
41    input: &str,
42    expected_state: Option<&str>,
43) -> anyhow::Result<String> {
44    let trimmed = input.trim();
45    if trimmed.is_empty() {
46        anyhow::bail!("OAuth redirect does not contain authorization code");
47    }
48    let query = trimmed
49        .split_once('?')
50        .map(|(_, query)| query)
51        .unwrap_or(trimmed);
52    let params = parse_query_params(query);
53    let callback_like = trimmed.contains('?')
54        || params.contains_key("code")
55        || params.contains_key("state")
56        || params.contains_key("error");
57
58    if let Some(error) = params.get("error") {
59        anyhow::bail!("OAuth redirect returned error: {error}");
60    }
61
62    if let Some(state) = expected_state {
63        if let Some(returned) = params.get("state") {
64            if returned != state {
65                anyhow::bail!("OAuth state mismatch");
66            }
67        } else if callback_like {
68            anyhow::bail!("OAuth state mismatch");
69        }
70    }
71
72    if let Some(code) = params.get("code").cloned() {
73        return Ok(code);
74    }
75    if !callback_like {
76        return Ok(trimmed.to_string());
77    }
78    anyhow::bail!("OAuth redirect does not contain authorization code")
79}
80
81pub fn parse_query_params(input: &str) -> BTreeMap<String, String> {
82    input
83        .split('&')
84        .filter_map(|entry| entry.split_once('='))
85        .map(|(key, value)| {
86            (
87                urlencoding::decode(key)
88                    .unwrap_or_else(|_| key.into())
89                    .to_string(),
90                urlencoding::decode(value)
91                    .unwrap_or_else(|_| value.into())
92                    .to_string(),
93            )
94        })
95        .collect()
96}
97
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn pkce_state_contains_all_fields() {
        let state = generate_pkce_state();
        assert!(!state.state.is_empty());
        assert!(!state.code_verifier.is_empty());
        assert!(!state.code_challenge.is_empty());
    }

    #[test]
    fn pkce_verifier_meets_rfc7636_and_challenge_is_s256() {
        let pkce = generate_pkce_state();
        let len = pkce.code_verifier.len();
        assert!((43..=128).contains(&len), "verifier length {len} out of range");
        // RFC 7636 §4.1: verifier uses only unreserved characters.
        assert!(pkce
            .code_verifier
            .chars()
            .all(|c| c.is_ascii_alphanumeric() || "-._~".contains(c)));
        let expected = base64::engine::general_purpose::URL_SAFE_NO_PAD
            .encode(Sha256::digest(pkce.code_verifier.as_bytes()));
        assert_eq!(pkce.code_challenge, expected);
    }

    #[test]
    fn parse_redirect_code_roundtrip() {
        let parsed =
            parse_code_from_redirect("/auth/callback?code=abc&state=expected", Some("expected"))
                .unwrap();
        assert_eq!(parsed, "abc");
    }

    #[test]
    fn parse_redirect_code_rejects_bad_state() {
        let err = parse_code_from_redirect("/auth/callback?code=abc&state=wrong", Some("expected"))
            .unwrap_err();
        assert!(err.to_string().contains("state mismatch"));
    }

    #[test]
    fn parse_redirect_rejects_missing_state_on_callback() {
        let err =
            parse_code_from_redirect("/auth/callback?code=abc", Some("expected")).unwrap_err();
        assert!(err.to_string().contains("state mismatch"));
    }

    #[test]
    fn parse_redirect_accepts_bare_code() {
        let parsed = parse_code_from_redirect("  abc  ", Some("expected")).unwrap();
        assert_eq!(parsed, "abc");
    }

    #[test]
    fn parse_redirect_without_expected_state() {
        let parsed = parse_code_from_redirect("/auth/callback?code=abc", None).unwrap();
        assert_eq!(parsed, "abc");
    }

    #[test]
    fn parse_redirect_percent_decodes_code() {
        let parsed =
            parse_code_from_redirect("/auth/callback?code=a%2Fb&state=expected", Some("expected"))
                .unwrap();
        assert_eq!(parsed, "a/b");
    }

    #[test]
    fn parse_redirect_surfaces_provider_error() {
        let err = parse_code_from_redirect(
            "/auth/callback?error=access_denied&state=expected",
            Some("expected"),
        )
        .unwrap_err();
        assert!(err.to_string().contains("access_denied"));
    }

    #[test]
    fn parse_redirect_rejects_empty_input_and_missing_code() {
        for input in ["", "   ", "/auth/callback?state=expected"] {
            let err = parse_code_from_redirect(input, Some("expected")).unwrap_err();
            assert!(
                err.to_string().contains("does not contain authorization code"),
                "unexpected error for {input:?}: {err}"
            );
        }
    }

    #[test]
    fn query_params_decode_and_skip_entries_without_equals() {
        let params = parse_query_params("a=1&flag&b=x%20y");
        assert_eq!(params.len(), 2);
        assert_eq!(params.get("a").map(String::as_str), Some("1"));
        assert_eq!(params.get("b").map(String::as_str), Some("x y"));
    }
}
124}