// agent-diva-core 0.3.0
//
// Core types and traits for agent-diva.
// This file: OAuth helpers (PKCE state generation, redirect/query parsing).
use base64::Engine;
use sha2::{Digest, Sha256};
use std::collections::BTreeMap;

use crate::auth::profiles::ProviderTokenSet;

/// PKCE values for one OAuth authorization-code flow, as produced by
/// [`generate_pkce_state`].
///
/// NOTE(review): the derived `Debug` prints `code_verifier` in clear text;
/// avoid logging this struct with `{:?}`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PkceState {
    /// Random value to send as the OAuth `state` parameter and compare against
    /// the `state` returned on the redirect (see [`parse_code_from_redirect`]).
    pub state: String,
    /// PKCE code verifier; `code_challenge` is derived from it.
    pub code_verifier: String,
    /// Unpadded base64url encoding of the SHA-256 digest of `code_verifier`
    /// (the PKCE `S256` method).
    pub code_challenge: String,
}

/// OAuth credentials and related data for a profile, as returned by
/// [`OAuthTokenManager::refresh_oauth_state`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct OAuthProfileState {
    /// Provider token set (type defined in `crate::auth::profiles`).
    pub token_set: ProviderTokenSet,
    /// Provider account identifier, when one could be determined.
    pub account_id: Option<String>,
    /// Additional string key/value data; presumably provider-specific and
    /// free-form — confirm against implementors.
    pub metadata: BTreeMap<String, String>,
}

/// Provider-specific OAuth token operations.
///
/// Implementors must be `Send + Sync` so a manager can be shared across
/// threads and async tasks.
#[async_trait::async_trait]
pub trait OAuthTokenManager: Send + Sync {
    /// Uses `refresh_token` to obtain updated OAuth state for a profile.
    ///
    /// NOTE(review): presumably performs a refresh-token grant against the
    /// provider's token endpoint; the transport is implementation-defined.
    ///
    /// # Errors
    ///
    /// Returns an error if the refresh cannot be completed; the specific
    /// failure modes are implementation-defined.
    async fn refresh_oauth_state(&self, refresh_token: &str) -> anyhow::Result<OAuthProfileState>;

    /// Derives the provider account id from `access_token`, returning `None`
    /// when it cannot be determined.
    ///
    /// NOTE(review): how the id is derived (e.g. decoding token claims) is up
    /// to each implementation — confirm per provider.
    fn extract_account_id(&self, access_token: &str) -> Option<String>;
}

/// Creates a fresh PKCE state bundle for an OAuth authorization-code flow.
///
/// - `state` is a random v4 UUID in hyphenated form.
/// - `code_verifier` is two hyphenated v4 UUIDs concatenated (72 characters,
///   all lowercase hex digits or `-`, which are unreserved characters in
///   RFC 7636 terms).
/// - `code_challenge` is the unpadded base64url SHA-256 digest of the verifier
///   (the `S256` challenge method).
pub fn generate_pkce_state() -> PkceState {
    let state = uuid::Uuid::new_v4().to_string();

    let code_verifier: String = (0..2)
        .map(|_| uuid::Uuid::new_v4().to_string())
        .collect();

    let code_challenge = {
        let verifier_hash = Sha256::digest(code_verifier.as_bytes());
        base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(verifier_hash)
    };

    PkceState {
        state,
        code_verifier,
        code_challenge,
    }
}

pub fn parse_code_from_redirect(
    input: &str,
    expected_state: Option<&str>,
) -> anyhow::Result<String> {
    let trimmed = input.trim();
    if trimmed.is_empty() {
        anyhow::bail!("OAuth redirect does not contain authorization code");
    }
    let query = trimmed
        .split_once('?')
        .map(|(_, query)| query)
        .unwrap_or(trimmed);
    let params = parse_query_params(query);
    let callback_like = trimmed.contains('?')
        || params.contains_key("code")
        || params.contains_key("state")
        || params.contains_key("error");

    if let Some(error) = params.get("error") {
        anyhow::bail!("OAuth redirect returned error: {error}");
    }

    if let Some(state) = expected_state {
        if let Some(returned) = params.get("state") {
            if returned != state {
                anyhow::bail!("OAuth state mismatch");
            }
        } else if callback_like {
            anyhow::bail!("OAuth state mismatch");
        }
    }

    if let Some(code) = params.get("code").cloned() {
        return Ok(code);
    }
    if !callback_like {
        return Ok(trimmed.to_string());
    }
    anyhow::bail!("OAuth redirect does not contain authorization code")
}

pub fn parse_query_params(input: &str) -> BTreeMap<String, String> {
    input
        .split('&')
        .filter_map(|entry| entry.split_once('='))
        .map(|(key, value)| {
            (
                urlencoding::decode(key)
                    .unwrap_or_else(|_| key.into())
                    .to_string(),
                urlencoding::decode(value)
                    .unwrap_or_else(|_| value.into())
                    .to_string(),
            )
        })
        .collect()
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn pkce_state_contains_all_fields() {
        let state = generate_pkce_state();
        assert!(!state.state.is_empty());
        assert!(!state.code_verifier.is_empty());
        assert!(!state.code_challenge.is_empty());
    }

    #[test]
    fn pkce_challenge_is_s256_of_verifier() {
        let pkce = generate_pkce_state();
        // RFC 7636 §4.1: the verifier must be 43..=128 characters long.
        assert!((43..=128).contains(&pkce.code_verifier.len()));
        let expected = base64::engine::general_purpose::URL_SAFE_NO_PAD
            .encode(Sha256::digest(pkce.code_verifier.as_bytes()));
        assert_eq!(pkce.code_challenge, expected);
    }

    #[test]
    fn pkce_states_are_unique() {
        let first = generate_pkce_state();
        let second = generate_pkce_state();
        assert_ne!(first.state, second.state);
        assert_ne!(first.code_verifier, second.code_verifier);
    }

    #[test]
    fn parse_redirect_code_roundtrip() {
        let parsed =
            parse_code_from_redirect("/auth/callback?code=abc&state=expected", Some("expected"))
                .unwrap();
        assert_eq!(parsed, "abc");
    }

    #[test]
    fn parse_redirect_code_rejects_bad_state() {
        let err = parse_code_from_redirect("/auth/callback?code=abc&state=wrong", Some("expected"))
            .unwrap_err();
        assert!(err.to_string().contains("state mismatch"));
    }

    #[test]
    fn parse_redirect_rejects_missing_state() {
        let err = parse_code_from_redirect("/auth/callback?code=abc", Some("expected"))
            .unwrap_err();
        assert!(err.to_string().contains("state mismatch"));
    }

    #[test]
    fn parse_redirect_accepts_bare_code() {
        let parsed = parse_code_from_redirect("  abc123  ", Some("expected")).unwrap();
        assert_eq!(parsed, "abc123");
    }

    #[test]
    fn parse_redirect_surfaces_provider_error() {
        let err = parse_code_from_redirect(
            "/auth/callback?error=access_denied&state=expected",
            Some("expected"),
        )
        .unwrap_err();
        assert!(err.to_string().contains("access_denied"));
    }

    #[test]
    fn parse_redirect_requires_code_in_callback() {
        let err = parse_code_from_redirect("/auth/callback?state=expected", Some("expected"))
            .unwrap_err();
        assert!(err
            .to_string()
            .contains("does not contain authorization code"));
    }

    #[test]
    fn parse_redirect_rejects_empty_input() {
        assert!(parse_code_from_redirect("   ", None).is_err());
    }

    #[test]
    fn parse_query_params_decodes_percent_escapes() {
        let params = parse_query_params("a=hello%20world&flag&b=%2F");
        assert_eq!(params.get("a").map(String::as_str), Some("hello world"));
        assert_eq!(params.get("b").map(String::as_str), Some("/"));
        assert!(!params.contains_key("flag"));
    }
}