aether_auth/
credential.rs1use async_trait::async_trait;
2use oauth2::basic::BasicClient;
3use oauth2::reqwest::redirect::Policy;
4use oauth2::{ClientId, RefreshToken, TokenResponse, TokenUrl};
5use serde::{Deserialize, Serialize};
6use std::time::Duration;
7
8use crate::OAuthError;
9
/// Safety margin applied before an access token's recorded expiry.
///
/// A credential whose expiry falls within this window of "now" is already
/// reported as needing a refresh by `OAuthCredential::needs_refresh`.
const TOKEN_EXPIRY_GRACE_PERIOD: Duration = Duration::from_secs(60);
11
12#[derive(Debug, Clone, Serialize, Deserialize)]
14pub struct OAuthCredential {
15 pub client_id: String,
16 pub access_token: String,
17 pub refresh_token: Option<String>,
18 pub expires_at: Option<u64>,
20}
21
22impl OAuthCredential {
23 pub fn from_token_response<T: TokenResponse>(client_id: String, token_response: &T) -> Self {
25 Self {
26 client_id,
27 access_token: token_response.access_token().secret().clone(),
28 refresh_token: token_response.refresh_token().map(|token| token.secret().clone()),
29 expires_at: expires_at_from_duration(token_response.expires_in()),
30 }
31 }
32
33 pub fn needs_refresh(&self) -> bool {
35 self.expires_at.is_some_and(|at| {
36 current_unix_time_millis() >= at.saturating_sub(duration_millis(TOKEN_EXPIRY_GRACE_PERIOD))
37 })
38 }
39
40 pub fn expires_in(&self) -> Option<Duration> {
42 self.expires_at.and_then(|expires_at| {
43 let now = current_unix_time_millis();
44 (expires_at > now).then(|| Duration::from_millis(expires_at - now))
45 })
46 }
47
48 pub async fn refresh(self, token_url: &TokenUrl) -> Result<Self, OAuthError> {
53 let old_refresh_token = self.refresh_token.clone().ok_or_else(|| {
54 OAuthError::NoCredentials(
55 "OAuth credential expired and no refresh token is available. Re-run OAuth login.".to_string(),
56 )
57 })?;
58
59 let oauth_client = BasicClient::new(ClientId::new(self.client_id.clone())).set_token_uri(token_url.clone());
60 let http_client = oauth_http_client()?;
61 let token_response = oauth_client
62 .exchange_refresh_token(&RefreshToken::new(old_refresh_token.clone()))
63 .request_async(&http_client)
64 .await
65 .map_err(|e| OAuthError::TokenExchange(e.to_string()))?;
66
67 let refreshed = Self::from_token_response(self.client_id, &token_response);
68 Ok(Self { refresh_token: refreshed.refresh_token.or(Some(old_refresh_token)), ..refreshed })
69 }
70}
71
72#[async_trait]
77pub trait OAuthCredentialStorage: Send + Sync {
78 async fn load_credential(&self, key: &str) -> Result<Option<OAuthCredential>, OAuthError>;
79
80 async fn save_credential(&self, key: &str, credential: OAuthCredential) -> Result<(), OAuthError>;
81
82 async fn delete_credential(&self, key: &str) -> Result<(), OAuthError>;
83
84 fn has_credential(&self, key: &str) -> bool;
85}
86
87fn expires_at_from_duration(duration: Option<Duration>) -> Option<u64> {
88 duration.map(|duration| current_unix_time_millis().saturating_add(duration_millis(duration)))
89}
90
91pub fn oauth_http_client() -> Result<oauth2::reqwest::Client, OAuthError> {
92 oauth2::reqwest::Client::builder()
93 .redirect(Policy::none())
94 .build()
95 .map_err(|e| OAuthError::TokenExchange(format!("failed to build HTTP client: {e}")))
96}
97
/// Returns the current wall-clock time as milliseconds since the Unix epoch.
///
/// A clock set before the epoch yields `0`. A value too large for `u64`
/// saturates at `u64::MAX`.
fn current_unix_time_millis() -> u64 {
    std::time::SystemTime::now()
        .duration_since(std::time::UNIX_EPOCH)
        .map_or(0, |elapsed| u64::try_from(elapsed.as_millis()).unwrap_or(u64::MAX))
}
102
/// Converts `duration` to whole milliseconds, saturating at `u64::MAX`.
fn duration_millis(duration: Duration) -> u64 {
    // `as_millis` returns a u128; anything beyond the u64 range clamps to the max.
    match u64::try_from(duration.as_millis()) {
        Ok(millis) => millis,
        Err(_) => u64::MAX,
    }
}
106
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn needs_refresh_is_false_when_no_expiry() {
        assert!(!build_credential(None).needs_refresh());
    }

    #[test]
    fn needs_refresh_is_false_when_far_in_future() {
        assert!(!build_credential(Some(u64::MAX)).needs_refresh());
    }

    #[test]
    fn needs_refresh_is_false_when_outside_grace_period() {
        // Twice the grace period leaves ample margin for the clock advancing mid-test.
        let cred = build_credential(expires_at_from_duration(Some(TOKEN_EXPIRY_GRACE_PERIOD * 2)));
        assert!(!cred.needs_refresh());
    }

    #[test]
    fn needs_refresh_is_true_when_past() {
        assert!(build_credential(Some(0)).needs_refresh());
    }

    #[test]
    fn needs_refresh_is_true_when_within_skew() {
        // Just inside the grace window: not yet expired, but due for a refresh.
        let cred = build_credential(expires_at_from_duration(Some(
            TOKEN_EXPIRY_GRACE_PERIOD - Duration::from_millis(1),
        )));
        assert!(cred.needs_refresh());
    }

    #[test]
    fn expires_in_is_none_when_no_expiry() {
        assert!(build_credential(None).expires_in().is_none());
    }

    #[test]
    fn expires_in_is_none_when_already_past() {
        assert!(build_credential(Some(0)).expires_in().is_none());
    }

    #[test]
    fn expires_in_returns_remaining_duration_when_future() {
        let one_hour = Duration::from_secs(60 * 60);
        let cred = build_credential(expires_at_from_duration(Some(one_hour)));
        let remaining = cred.expires_in().expect("expires_in should be Some for future expiry");
        assert!(remaining > Duration::from_secs(58 * 60));
        assert!(remaining <= one_hour);
    }

    #[test]
    fn expires_at_from_duration_is_none_without_lifetime() {
        assert_eq!(expires_at_from_duration(None), None);
    }

    #[test]
    fn expires_at_from_duration_saturates_instead_of_overflowing() {
        assert_eq!(expires_at_from_duration(Some(Duration::MAX)), Some(u64::MAX));
    }

    #[test]
    fn duration_millis_truncates_and_saturates() {
        assert_eq!(duration_millis(Duration::from_micros(1_999)), 1);
        assert_eq!(duration_millis(Duration::from_secs(60)), 60_000);
        assert_eq!(duration_millis(Duration::MAX), u64::MAX);
    }

    /// Builds a credential with fixed identifiers, no refresh token and the given expiry.
    fn build_credential(expires_at: Option<u64>) -> OAuthCredential {
        OAuthCredential {
            client_id: "client".to_string(),
            access_token: "access".to_string(),
            refresh_token: None,
            expires_at,
        }
    }
}