1use chrono::{DateTime, Utc};
7use serde::{Deserialize, Serialize};
8use std::collections::HashMap;
9use std::path::Path;
10use thiserror::Error;
11
12use crate::secrets::{ApiKey, CredentialError, CredentialStore};
13
14#[derive(Error, Debug)]
16pub enum AuthError {
17 #[error("Credential error: {0}")]
19 Credential(#[from] CredentialError),
20
21 #[error("IO error: {0}")]
23 Io(#[from] std::io::Error),
24
25 #[error("Serialization error: {0}")]
27 Serialization(#[from] serde_json::Error),
28
29 #[error("Token expired")]
31 TokenExpired,
32
33 #[error("Authentication failed: {0}")]
35 AuthFailed(String),
36
37 #[error("Profile not found: {0}")]
39 ProfileNotFound(String),
40}
41
42#[derive(Debug, Clone, Serialize, Deserialize)]
44pub struct AuthProfile {
45 pub id: String,
47
48 pub profile_type: ProfileType,
50
51 pub target: String,
53
54 pub account_id: Option<String>,
56
57 pub created_at: DateTime<Utc>,
59
60 pub last_used: Option<DateTime<Utc>>,
62
63 pub active: bool,
65
66 #[serde(default)]
68 pub metadata: HashMap<String, serde_json::Value>,
69}
70
71impl AuthProfile {
72 #[must_use]
74 pub fn new(
75 id: impl Into<String>,
76 profile_type: ProfileType,
77 target: impl Into<String>,
78 ) -> Self {
79 Self {
80 id: id.into(),
81 profile_type,
82 target: target.into(),
83 account_id: None,
84 created_at: Utc::now(),
85 last_used: None,
86 active: true,
87 metadata: HashMap::new(),
88 }
89 }
90
91 pub fn mark_used(&mut self) {
93 self.last_used = Some(Utc::now());
94 }
95}
96
97#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
99#[serde(rename_all = "snake_case")]
100pub enum ProfileType {
101 ApiKey,
103 OAuth,
105 BotToken,
107 Session,
109 Certificate,
111}
112
113#[derive(Debug, Clone, Serialize, Deserialize)]
115pub struct OAuthToken {
116 pub access_token: String,
118
119 pub refresh_token: Option<String>,
121
122 pub expires_at: Option<DateTime<Utc>>,
124
125 pub token_type: String,
127
128 pub scopes: Vec<String>,
130}
131
132impl OAuthToken {
133 #[must_use]
135 pub fn is_expired(&self) -> bool {
136 self.expires_at.is_some_and(|exp| exp < Utc::now())
137 }
138
139 #[must_use]
141 pub fn needs_refresh(&self) -> bool {
142 self.expires_at
143 .is_some_and(|exp| exp < Utc::now() + chrono::Duration::minutes(5))
144 }
145}
146
147pub struct AuthCredentialStore {
151 inner: CredentialStore,
152 profiles_path: std::path::PathBuf,
153 profiles: HashMap<String, AuthProfile>,
154}
155
156impl AuthCredentialStore {
157 pub fn new(encryption_key: [u8; 32], base_path: &Path) -> Result<Self, AuthError> {
168 let store_path = base_path.join("secrets");
169 let profiles_path = base_path.join("profiles.json");
170
171 let inner = CredentialStore::new(encryption_key, store_path);
172
173 let profiles = if profiles_path.exists() {
174 let content = std::fs::read_to_string(&profiles_path)?;
175 serde_json::from_str(&content)?
176 } else {
177 HashMap::new()
178 };
179
180 Ok(Self {
181 inner,
182 profiles_path,
183 profiles,
184 })
185 }
186
187 pub fn store_api_key(&mut self, profile_id: &str, key: &ApiKey) -> Result<(), AuthError> {
193 self.inner.store(profile_id, key)?;
194 self.save_profiles()?;
195 Ok(())
196 }
197
198 pub fn load_api_key(&self, profile_id: &str) -> Result<ApiKey, AuthError> {
204 Ok(self.inner.load(profile_id)?)
205 }
206
207 pub fn store_oauth_token(
213 &mut self,
214 profile_id: &str,
215 token: &OAuthToken,
216 ) -> Result<(), AuthError> {
217 let token_json = serde_json::to_string(token)?;
218 let key = ApiKey::new(token_json);
219 self.inner.store(&format!("{profile_id}_oauth"), &key)?;
220 self.save_profiles()?;
221 Ok(())
222 }
223
224 pub fn load_oauth_token(&self, profile_id: &str) -> Result<OAuthToken, AuthError> {
230 let key = self.inner.load(&format!("{profile_id}_oauth"))?;
231 let token: OAuthToken = serde_json::from_str(key.expose())?;
232
233 if token.is_expired() && token.refresh_token.is_none() {
234 return Err(AuthError::TokenExpired);
235 }
236
237 Ok(token)
238 }
239
240 pub fn set_profile(&mut self, profile: AuthProfile) -> Result<(), AuthError> {
242 self.profiles.insert(profile.id.clone(), profile);
243 self.save_profiles()?;
244 Ok(())
245 }
246
247 #[must_use]
249 pub fn get_profile(&self, profile_id: &str) -> Option<&AuthProfile> {
250 self.profiles.get(profile_id)
251 }
252
253 #[must_use]
255 pub fn get_profile_mut(&mut self, profile_id: &str) -> Option<&mut AuthProfile> {
256 self.profiles.get_mut(profile_id)
257 }
258
259 pub fn remove_profile(&mut self, profile_id: &str) -> Result<(), AuthError> {
265 self.profiles.remove(profile_id);
266 let _ = self.inner.delete(profile_id);
267 let _ = self.inner.delete(&format!("{profile_id}_oauth"));
268 self.save_profiles()?;
269 Ok(())
270 }
271
272 #[must_use]
274 pub fn list_profiles(&self) -> Vec<&AuthProfile> {
275 self.profiles.values().collect()
276 }
277
278 #[must_use]
280 pub fn profiles_for_target(&self, target: &str) -> Vec<&AuthProfile> {
281 self.profiles
282 .values()
283 .filter(|p| p.target == target)
284 .collect()
285 }
286
287 #[must_use]
289 pub fn active_profile_for_target(&self, target: &str) -> Option<&AuthProfile> {
290 self.profiles
291 .values()
292 .find(|p| p.target == target && p.active)
293 }
294
295 fn save_profiles(&self) -> Result<(), AuthError> {
297 if let Some(parent) = self.profiles_path.parent() {
298 std::fs::create_dir_all(parent)?;
299 }
300 let content = serde_json::to_string_pretty(&self.profiles)?;
301 std::fs::write(&self.profiles_path, content)?;
302 Ok(())
303 }
304}
305
306pub async fn refresh_oauth_token(
310 token: &OAuthToken,
311 client_id: &str,
312 client_secret: &str,
313 token_url: &str,
314) -> Result<OAuthToken, AuthError> {
315 let refresh_token = token
316 .refresh_token
317 .as_ref()
318 .ok_or_else(|| AuthError::AuthFailed("No refresh token available".to_string()))?;
319
320 let client = reqwest::Client::new();
322 let response = client
323 .post(token_url)
324 .form(&[
325 ("grant_type", "refresh_token"),
326 ("refresh_token", refresh_token),
327 ("client_id", client_id),
328 ("client_secret", client_secret),
329 ])
330 .send()
331 .await
332 .map_err(|e| AuthError::AuthFailed(e.to_string()))?;
333
334 if !response.status().is_success() {
335 return Err(AuthError::AuthFailed(format!(
336 "Token refresh failed: {}",
337 response.status()
338 )));
339 }
340
341 #[derive(Deserialize)]
342 struct TokenResponse {
343 access_token: String,
344 refresh_token: Option<String>,
345 expires_in: Option<i64>,
346 token_type: Option<String>,
347 }
348
349 let token_response: TokenResponse = response
350 .json()
351 .await
352 .map_err(|e| AuthError::AuthFailed(e.to_string()))?;
353
354 let expires_at = token_response
355 .expires_in
356 .map(|secs| Utc::now() + chrono::Duration::seconds(secs));
357
358 Ok(OAuthToken {
359 access_token: token_response.access_token,
360 refresh_token: token_response
361 .refresh_token
362 .or_else(|| token.refresh_token.clone()),
363 expires_at,
364 token_type: token_response
365 .token_type
366 .unwrap_or_else(|| "Bearer".to_string()),
367 scopes: token.scopes.clone(),
368 })
369}
370
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::{tempdir, TempDir};

    /// Opens a store in a fresh temporary directory with a random encryption key.
    ///
    /// The `TempDir` guard is returned too, so the directory is not deleted while the
    /// test still uses the store.
    fn fresh_store() -> (TempDir, AuthCredentialStore) {
        let dir = tempdir().unwrap();
        let encryption_key: [u8; 32] = rand::random();
        let store = AuthCredentialStore::new(encryption_key, dir.path()).unwrap();
        (dir, store)
    }

    /// A bearer token without a refresh token whose expiry lies `offset` from now.
    fn bearer_token_expiring_in(offset: chrono::Duration) -> OAuthToken {
        OAuthToken {
            access_token: "test".to_string(),
            refresh_token: None,
            expires_at: Some(Utc::now() + offset),
            token_type: "Bearer".to_string(),
            scopes: Vec::new(),
        }
    }

    #[test]
    fn test_auth_profile_creation() {
        let AuthProfile {
            id, target, active, ..
        } = AuthProfile::new("test-profile", ProfileType::ApiKey, "anthropic");

        assert_eq!(id, "test-profile");
        assert_eq!(target, "anthropic");
        assert!(active);
    }

    #[test]
    fn test_oauth_token_expiry() {
        let token = bearer_token_expiring_in(chrono::Duration::hours(-1));

        assert!(token.is_expired());
        assert!(token.needs_refresh());
    }

    #[test]
    fn test_oauth_token_valid() {
        let token = bearer_token_expiring_in(chrono::Duration::hours(1));

        assert!(!token.is_expired());
        assert!(!token.needs_refresh());
    }

    #[test]
    fn test_auth_credential_store() {
        let (_dir, mut store) = fresh_store();

        store
            .set_profile(AuthProfile::new("test", ProfileType::ApiKey, "anthropic"))
            .unwrap();
        store
            .store_api_key("test", &ApiKey::new("sk-test-key".to_string()))
            .unwrap();

        assert_eq!(store.load_api_key("test").unwrap().expose(), "sk-test-key");
        assert_eq!(store.list_profiles().len(), 1);
    }

    #[test]
    fn test_profiles_for_target() {
        let (_dir, mut store) = fresh_store();

        for (id, target) in [("a1", "anthropic"), ("o1", "openai"), ("a2", "anthropic")] {
            store
                .set_profile(AuthProfile::new(id, ProfileType::ApiKey, target))
                .unwrap();
        }

        assert_eq!(store.profiles_for_target("anthropic").len(), 2);
    }
}