agent_diva_core/auth/
oauth_common.rs1use base64::Engine;
2use sha2::{Digest, Sha256};
3use std::collections::BTreeMap;
4
5use crate::auth::profiles::ProviderTokenSet;
6
/// Values produced for one PKCE authorization attempt.
///
/// NOTE(review): `code_verifier` is the PKCE secret, and the derived `Debug`
/// prints it verbatim — avoid `{:?}`-logging this struct.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct PkceState {
    /// Random `state` value; a hyphenated UUID v4 when built by `generate_pkce_state`.
    pub state: String,
    /// PKCE verifier; two concatenated hyphenated UUID v4 strings when built by
    /// `generate_pkce_state`.
    pub code_verifier: String,
    /// Unpadded base64url SHA-256 digest of `code_verifier` when built by
    /// `generate_pkce_state`.
    pub code_challenge: String,
}
13
14#[derive(Debug, Clone, PartialEq, Eq)]
15pub struct OAuthProfileState {
16 pub token_set: ProviderTokenSet,
17 pub account_id: Option<String>,
18 pub metadata: BTreeMap<String, String>,
19}
20
21#[async_trait::async_trait]
22pub trait OAuthTokenManager: Send + Sync {
23 async fn refresh_oauth_state(&self, refresh_token: &str) -> anyhow::Result<OAuthProfileState>;
24
25 fn extract_account_id(&self, access_token: &str) -> Option<String>;
26}
27
28pub fn generate_pkce_state() -> PkceState {
29 let state = uuid::Uuid::new_v4().to_string();
30 let code_verifier = format!("{}{}", uuid::Uuid::new_v4(), uuid::Uuid::new_v4());
31 let digest = Sha256::digest(code_verifier.as_bytes());
32 let code_challenge = base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(digest);
33 PkceState {
34 state,
35 code_verifier,
36 code_challenge,
37 }
38}
39
40pub fn parse_code_from_redirect(
41 input: &str,
42 expected_state: Option<&str>,
43) -> anyhow::Result<String> {
44 let trimmed = input.trim();
45 if trimmed.is_empty() {
46 anyhow::bail!("OAuth redirect does not contain authorization code");
47 }
48 let query = trimmed
49 .split_once('?')
50 .map(|(_, query)| query)
51 .unwrap_or(trimmed);
52 let params = parse_query_params(query);
53 let callback_like = trimmed.contains('?')
54 || params.contains_key("code")
55 || params.contains_key("state")
56 || params.contains_key("error");
57
58 if let Some(error) = params.get("error") {
59 anyhow::bail!("OAuth redirect returned error: {error}");
60 }
61
62 if let Some(state) = expected_state {
63 if let Some(returned) = params.get("state") {
64 if returned != state {
65 anyhow::bail!("OAuth state mismatch");
66 }
67 } else if callback_like {
68 anyhow::bail!("OAuth state mismatch");
69 }
70 }
71
72 if let Some(code) = params.get("code").cloned() {
73 return Ok(code);
74 }
75 if !callback_like {
76 return Ok(trimmed.to_string());
77 }
78 anyhow::bail!("OAuth redirect does not contain authorization code")
79}
80
81pub fn parse_query_params(input: &str) -> BTreeMap<String, String> {
82 input
83 .split('&')
84 .filter_map(|entry| entry.split_once('='))
85 .map(|(key, value)| {
86 (
87 urlencoding::decode(key)
88 .unwrap_or_else(|_| key.into())
89 .to_string(),
90 urlencoding::decode(value)
91 .unwrap_or_else(|_| value.into())
92 .to_string(),
93 )
94 })
95 .collect()
96}
97
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn pkce_state_contains_all_fields() {
        let state = generate_pkce_state();
        assert!(!state.state.is_empty());
        assert!(!state.code_verifier.is_empty());
        assert!(!state.code_challenge.is_empty());
    }

    #[test]
    fn pkce_challenge_is_s256_of_verifier() {
        use base64::Engine as _;
        use sha2::Digest as _;

        let pkce = generate_pkce_state();
        let expected = base64::engine::general_purpose::URL_SAFE_NO_PAD
            .encode(sha2::Sha256::digest(pkce.code_verifier.as_bytes()));
        assert_eq!(pkce.code_challenge, expected);
    }

    #[test]
    fn pkce_verifier_meets_rfc7636_constraints() {
        let pkce = generate_pkce_state();
        let len = pkce.code_verifier.len();
        // RFC 7636 §4.1: 43..=128 characters from the unreserved set.
        assert!((43..=128).contains(&len), "verifier length {len} out of range");
        assert!(pkce
            .code_verifier
            .chars()
            .all(|c| c.is_ascii_alphanumeric() || matches!(c, '-' | '.' | '_' | '~')));
    }

    #[test]
    fn pkce_states_are_unique() {
        let a = generate_pkce_state();
        let b = generate_pkce_state();
        assert_ne!(a.state, b.state);
        assert_ne!(a.code_verifier, b.code_verifier);
    }

    #[test]
    fn parse_redirect_code_roundtrip() {
        let parsed =
            parse_code_from_redirect("/auth/callback?code=abc&state=expected", Some("expected"))
                .unwrap();
        assert_eq!(parsed, "abc");
    }

    #[test]
    fn parse_redirect_code_from_full_url_is_percent_decoded() {
        let parsed = parse_code_from_redirect(
            "http://localhost:1455/auth/callback?code=a%2Fb&state=expected",
            Some("expected"),
        )
        .unwrap();
        assert_eq!(parsed, "a/b");
    }

    #[test]
    fn parse_redirect_code_rejects_bad_state() {
        let err = parse_code_from_redirect("/auth/callback?code=abc&state=wrong", Some("expected"))
            .unwrap_err();
        assert!(err.to_string().contains("state mismatch"));
    }

    #[test]
    fn parse_redirect_code_rejects_missing_state() {
        let err =
            parse_code_from_redirect("/auth/callback?code=abc", Some("expected")).unwrap_err();
        assert!(err.to_string().contains("state mismatch"));
    }

    #[test]
    fn parse_redirect_code_accepts_bare_code() {
        let parsed = parse_code_from_redirect("  abc123  ", Some("expected")).unwrap();
        assert_eq!(parsed, "abc123");
    }

    #[test]
    fn parse_redirect_code_surfaces_provider_error() {
        let err = parse_code_from_redirect(
            "/auth/callback?error=access_denied&state=expected",
            Some("expected"),
        )
        .unwrap_err();
        assert!(err.to_string().contains("access_denied"));
    }

    #[test]
    fn parse_redirect_code_requires_code() {
        let err = parse_code_from_redirect("/auth/callback?state=expected", Some("expected"))
            .unwrap_err();
        assert!(err.to_string().contains("authorization code"));
    }

    #[test]
    fn parse_redirect_code_rejects_blank_input() {
        assert!(parse_code_from_redirect("   ", None).is_err());
    }

    #[test]
    fn parse_query_params_decodes_and_skips_bare_keys() {
        let params = parse_query_params("a=b%20c&x%26y=z&flag");
        assert_eq!(params.len(), 2);
        assert_eq!(params.get("a").map(String::as_str), Some("b c"));
        assert_eq!(params.get("x&y").map(String::as_str), Some("z"));
    }

    #[test]
    fn parse_query_params_last_duplicate_wins() {
        let params = parse_query_params("k=1&k=2");
        assert_eq!(params.get("k").map(String::as_str), Some("2"));
    }
}