Skip to main content

agent_diva_core/auth/
profiles.rs

1use chrono::{DateTime, Utc};
2use serde::{Deserialize, Serialize};
3use std::collections::BTreeMap;
4use std::time::Duration;
5
/// Schema version stamped into `ProviderAuthProfilesData::schema_version` by
/// that type's `Default` implementation.
pub const CURRENT_AUTH_SCHEMA_VERSION: u32 = 1;
7
/// Discriminates which kind of credential a `ProviderAuthProfile` carries.
///
/// NOTE(review): with `rename_all = "snake_case"` serde is expected to emit
/// the `OAuth` variant as `"o_auth"` (an underscore is inserted before each
/// interior capital letter), not `"oauth"` — confirm that this is the intended
/// wire format before relying on it from other tools.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum ProviderAuthKind {
    /// Assigned by `ProviderAuthProfile::new_oauth`, which stores a
    /// `ProviderTokenSet` in the profile's `token_set` field.
    OAuth,
    /// Assigned by `ProviderAuthProfile::new_token`, which stores a single
    /// string in the profile's `token` field.
    Token,
}
14
/// Group of tokens stored on a `ProviderAuthProfile` created through
/// `ProviderAuthProfile::new_oauth`.
///
/// Only `access_token` is mandatory when deserializing; every other field is
/// marked `#[serde(default)]` and therefore becomes `None` when absent.
///
/// NOTE(review): `Debug` is derived here, so formatting a `ProviderTokenSet`
/// directly prints the token values. The hand-written `Debug` impl of
/// `ProviderAuthProfile` leaves its `token_set` field out of the output.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct ProviderTokenSet {
    /// The access token; the only required field.
    pub access_token: String,
    /// Optional refresh token.
    #[serde(default)]
    pub refresh_token: Option<String>,
    /// Optional ID token.
    #[serde(default)]
    pub id_token: Option<String>,
    /// Optional expiry instant (UTC). `is_expiring_within` compares it against
    /// the current time and reports `false` when it is `None`.
    #[serde(default)]
    pub expires_at: Option<DateTime<Utc>>,
    /// Optional token type string — presumably the value reported by the
    /// provider (e.g. a bearer type); TODO confirm against the producers.
    #[serde(default)]
    pub token_type: Option<String>,
    /// Optional scope string; its format is not established in this file.
    #[serde(default)]
    pub scope: Option<String>,
}
29
30impl ProviderTokenSet {
31    pub fn is_expiring_within(&self, skew: Duration) -> bool {
32        match self.expires_at {
33            Some(expires_at) => {
34                let Ok(skew) = chrono::Duration::from_std(skew) else {
35                    return false;
36                };
37                expires_at <= Utc::now() + skew
38            }
39            None => false,
40        }
41    }
42}
43
/// A stored credential profile for a single provider.
///
/// `Debug` is intentionally not derived: the hand-written implementation in
/// this file omits the `token_set` and `token` fields from its output.
#[derive(Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct ProviderAuthProfile {
    /// Profile identifier; the constructors build it with
    /// `profile_id(provider, profile_name)`.
    pub id: String,
    /// Provider name this profile belongs to.
    pub provider: String,
    /// Name of the profile within its provider.
    pub profile_name: String,
    /// Which kind of credential the profile carries.
    pub kind: ProviderAuthKind,
    /// Optional account identifier; the constructors leave it as `None`.
    #[serde(default)]
    pub account_id: Option<String>,
    /// Token set; populated by `new_oauth`, `None` for `new_token`.
    #[serde(default)]
    pub token_set: Option<ProviderTokenSet>,
    /// Single token string; populated by `new_token`, `None` for `new_oauth`.
    #[serde(default)]
    pub token: Option<String>,
    /// Free-form string key/value pairs; the constructors start it empty.
    #[serde(default)]
    pub metadata: BTreeMap<String, String>,
    /// Creation timestamp (UTC); set to the construction time by the constructors.
    pub created_at: DateTime<Utc>,
    /// Last-update timestamp (UTC); initially equal to `created_at`.
    pub updated_at: DateTime<Utc>,
}
61
62impl std::fmt::Debug for ProviderAuthProfile {
63    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
64        f.debug_struct("ProviderAuthProfile")
65            .field("id", &self.id)
66            .field("provider", &self.provider)
67            .field("profile_name", &self.profile_name)
68            .field("kind", &self.kind)
69            .field("account_id", &self.account_id)
70            .field("metadata", &self.metadata)
71            .field("created_at", &self.created_at)
72            .field("updated_at", &self.updated_at)
73            .finish_non_exhaustive()
74    }
75}
76
77impl ProviderAuthProfile {
78    pub fn new_oauth(
79        provider: impl Into<String>,
80        profile_name: impl Into<String>,
81        token_set: ProviderTokenSet,
82    ) -> Self {
83        let provider = provider.into();
84        let profile_name = profile_name.into();
85        let now = Utc::now();
86        Self {
87            id: profile_id(&provider, &profile_name),
88            provider,
89            profile_name,
90            kind: ProviderAuthKind::OAuth,
91            account_id: None,
92            token_set: Some(token_set),
93            token: None,
94            metadata: BTreeMap::new(),
95            created_at: now,
96            updated_at: now,
97        }
98    }
99
100    pub fn new_token(
101        provider: impl Into<String>,
102        profile_name: impl Into<String>,
103        token: impl Into<String>,
104    ) -> Self {
105        let provider = provider.into();
106        let profile_name = profile_name.into();
107        let now = Utc::now();
108        Self {
109            id: profile_id(&provider, &profile_name),
110            provider,
111            profile_name,
112            kind: ProviderAuthKind::Token,
113            account_id: None,
114            token_set: None,
115            token: Some(token.into()),
116            metadata: BTreeMap::new(),
117            created_at: now,
118            updated_at: now,
119        }
120    }
121}
122
/// Top-level container holding a schema version, a last-update timestamp and
/// two maps: the active profile selections and the profiles themselves.
///
/// Both maps are marked `#[serde(default)]`, so they deserialize as empty when
/// missing from the input; `schema_version` and `updated_at` are required.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct ProviderAuthProfilesData {
    /// Schema version of this data; `Default` uses `CURRENT_AUTH_SCHEMA_VERSION`.
    pub schema_version: u32,
    /// Timestamp (UTC) of the last update; `Default` uses the current time.
    pub updated_at: DateTime<Utc>,
    /// Active profile selections — presumably provider name -> profile id;
    /// TODO confirm against the code that reads and writes this map.
    #[serde(default)]
    pub active_profiles: BTreeMap<String, String>,
    /// Stored profiles — presumably keyed by `ProviderAuthProfile::id`;
    /// TODO confirm against the code that populates this map.
    #[serde(default)]
    pub profiles: BTreeMap<String, ProviderAuthProfile>,
}
132
133impl Default for ProviderAuthProfilesData {
134    fn default() -> Self {
135        Self {
136            schema_version: CURRENT_AUTH_SCHEMA_VERSION,
137            updated_at: Utc::now(),
138            active_profiles: BTreeMap::new(),
139            profiles: BTreeMap::new(),
140        }
141    }
142}
143
/// Builds a profile identifier by joining `provider` and `profile_name` with
/// a single `:` separator.
///
/// Neither part is validated or escaped, so a `:` inside either input is
/// carried into the result unchanged.
pub fn profile_id(provider: &str, profile_name: &str) -> String {
    [provider, profile_name].join(":")
}