Skip to main content

aether_auth/
credential.rs

1use async_trait::async_trait;
2use oauth2::basic::BasicClient;
3use oauth2::reqwest::redirect::Policy;
4use oauth2::{ClientId, RefreshToken, TokenResponse, TokenUrl};
5use serde::{Deserialize, Serialize};
6use std::time::Duration;
7
8use crate::OAuthError;
9
/// Safety margin before the recorded expiry at which an access token is already treated as
/// needing a refresh (consumed by `OAuthCredential::needs_refresh`).
const TOKEN_EXPIRY_GRACE_PERIOD: Duration = Duration::from_secs(60);
11
12/// Credential for an OAuth provider.
13#[derive(Debug, Clone, Serialize, Deserialize)]
14pub struct OAuthCredential {
15    pub client_id: String,
16    pub access_token: String,
17    pub refresh_token: Option<String>,
18    /// Unix timestamp in milliseconds when the token expires.
19    pub expires_at: Option<u64>,
20    /// Scopes the authorization server granted with this token.
21    #[serde(default)]
22    pub granted_scopes: Vec<String>,
23}
24
25impl OAuthCredential {
26    /// Build an `OAuthCredential` from an `OAuth2` token response.
27    pub fn from_token_response<T: TokenResponse>(client_id: String, token_response: &T) -> Self {
28        Self {
29            client_id,
30            access_token: token_response.access_token().secret().clone(),
31            refresh_token: token_response.refresh_token().map(|token| token.secret().clone()),
32            expires_at: expires_at_from_duration(token_response.expires_in()),
33            granted_scopes: token_response
34                .scopes()
35                .map(|scopes| scopes.iter().map(|scope| scope.to_string()).collect())
36                .unwrap_or_default(),
37        }
38    }
39
40    /// Whether the access token is expired or expiring within the refresh skew.
41    pub fn needs_refresh(&self) -> bool {
42        self.expires_at.is_some_and(|at| {
43            current_unix_time_millis() >= at.saturating_sub(duration_millis(TOKEN_EXPIRY_GRACE_PERIOD))
44        })
45    }
46
47    /// Time remaining before the access token expires, if known and still in the future.
48    pub fn expires_in(&self) -> Option<Duration> {
49        self.expires_at.and_then(|expires_at| {
50            let now = current_unix_time_millis();
51            (expires_at > now).then(|| Duration::from_millis(expires_at - now))
52        })
53    }
54
55    /// Exchange the refresh token for a new access token.
56    ///
57    /// Preserves the existing refresh token if the response doesn't include a rotated one.
58    /// Returns `NoCredentials` if the credential has no refresh token to exchange.
59    pub async fn refresh(self, token_url: &TokenUrl) -> Result<Self, OAuthError> {
60        let old_refresh_token = self.refresh_token.clone().ok_or_else(|| {
61            OAuthError::NoCredentials(
62                "OAuth credential expired and no refresh token is available. Re-run OAuth login.".to_string(),
63            )
64        })?;
65
66        let oauth_client = BasicClient::new(ClientId::new(self.client_id.clone())).set_token_uri(token_url.clone());
67        let http_client = oauth_http_client()?;
68        let token_response = oauth_client
69            .exchange_refresh_token(&RefreshToken::new(old_refresh_token.clone()))
70            .request_async(&http_client)
71            .await
72            .map_err(|e| OAuthError::TokenExchange(e.to_string()))?;
73
74        let mut refreshed = Self::from_token_response(self.client_id, &token_response);
75        if refreshed.refresh_token.is_none() {
76            refreshed.refresh_token = Some(old_refresh_token);
77        }
78        Ok(refreshed)
79    }
80}
81
82/// Trait for loading and saving OAuth credentials, keyed by provider ID or credential key.
83///
84/// Implementations include [`OsKeyringStore`](crate::OsKeyringStore) (OS keychain, feature `keyring`)
85/// and the in-memory [`FakeOAuthCredentialStore`](crate::FakeOAuthCredentialStore) for tests.
86#[async_trait]
87pub trait OAuthCredentialStorage: Send + Sync {
88    async fn load_credential(&self, key: &str) -> Result<Option<OAuthCredential>, OAuthError>;
89
90    async fn save_credential(&self, key: &str, credential: OAuthCredential) -> Result<(), OAuthError>;
91
92    async fn delete_credential(&self, key: &str) -> Result<(), OAuthError>;
93
94    fn has_credential(&self, key: &str) -> bool;
95}
96
97fn expires_at_from_duration(duration: Option<Duration>) -> Option<u64> {
98    duration.map(|duration| current_unix_time_millis().saturating_add(duration_millis(duration)))
99}
100
101pub fn oauth_http_client() -> Result<oauth2::reqwest::Client, OAuthError> {
102    oauth2::reqwest::Client::builder()
103        .redirect(Policy::none())
104        .build()
105        .map_err(|e| OAuthError::TokenExchange(format!("failed to build HTTP client: {e}")))
106}
107
/// Milliseconds since the Unix epoch according to the system clock.
///
/// A clock reading before the epoch yields 0; a value too large for `u64` saturates to `u64::MAX`.
fn current_unix_time_millis() -> u64 {
    let since_epoch = match std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH) {
        Ok(elapsed) => elapsed,
        Err(_) => Duration::ZERO,
    };
    u64::try_from(since_epoch.as_millis()).unwrap_or(u64::MAX)
}
112
/// Whole milliseconds in `duration`, saturating at `u64::MAX` instead of truncating.
fn duration_millis(duration: Duration) -> u64 {
    match u64::try_from(duration.as_millis()) {
        Ok(millis) => millis,
        Err(_) => u64::MAX,
    }
}
116
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn needs_refresh_is_false_when_no_expiry() {
        assert!(!build_credential(None).needs_refresh());
    }

    #[test]
    fn needs_refresh_is_false_when_far_in_future() {
        assert!(!build_credential(Some(u64::MAX)).needs_refresh());
    }

    #[test]
    fn needs_refresh_is_true_when_past() {
        assert!(build_credential(Some(0)).needs_refresh());
    }

    #[test]
    fn needs_refresh_is_true_when_within_skew() {
        // Derived from the constant so the test tracks any change to the grace period.
        let just_inside = TOKEN_EXPIRY_GRACE_PERIOD - Duration::from_millis(1);
        let cred = build_credential(expires_at_from_duration(Some(just_inside)));
        assert!(cred.needs_refresh());
    }

    #[test]
    fn needs_refresh_is_false_outside_skew() {
        // Far enough past the grace period that test-execution jitter cannot flip the result.
        let well_outside = TOKEN_EXPIRY_GRACE_PERIOD * 5;
        let cred = build_credential(expires_at_from_duration(Some(well_outside)));
        assert!(!cred.needs_refresh());
    }

    #[test]
    fn expires_in_is_none_when_no_expiry() {
        assert!(build_credential(None).expires_in().is_none());
    }

    #[test]
    fn expires_in_is_none_when_already_past() {
        assert!(build_credential(Some(0)).expires_in().is_none());
    }

    #[test]
    fn expires_in_returns_remaining_duration_when_future() {
        let cred = build_credential(expires_at_from_duration(Some(Duration::from_hours(1))));
        let remaining = cred.expires_in().expect("expires_in should be Some for future expiry");
        assert!(remaining > Duration::from_mins(58));
        assert!(remaining <= Duration::from_hours(1));
    }

    fn build_credential(expires_at: Option<u64>) -> OAuthCredential {
        OAuthCredential {
            client_id: "client".to_string(),
            access_token: "access".to_string(),
            refresh_token: None,
            expires_at,
            granted_scopes: Vec::new(),
        }
    }
}