Skip to main content

openclaw_core/auth/
mod.rs

1//! Authentication profiles and credential management.
2//!
3//! Manages OAuth tokens, API keys, and session credentials
4//! for channels and providers.
5
6use chrono::{DateTime, Utc};
7use serde::{Deserialize, Serialize};
8use std::collections::HashMap;
9use std::path::Path;
10use thiserror::Error;
11
12use crate::secrets::{ApiKey, CredentialError, CredentialStore};
13
14/// Authentication errors.
15#[derive(Error, Debug)]
16pub enum AuthError {
17    /// Credential error.
18    #[error("Credential error: {0}")]
19    Credential(#[from] CredentialError),
20
21    /// IO error.
22    #[error("IO error: {0}")]
23    Io(#[from] std::io::Error),
24
25    /// Serialization error.
26    #[error("Serialization error: {0}")]
27    Serialization(#[from] serde_json::Error),
28
29    /// Token expired.
30    #[error("Token expired")]
31    TokenExpired,
32
33    /// Authentication failed.
34    #[error("Authentication failed: {0}")]
35    AuthFailed(String),
36
37    /// Profile not found.
38    #[error("Profile not found: {0}")]
39    ProfileNotFound(String),
40}
41
/// Authentication profile for a channel or provider.
///
/// A profile describes a set of credentials but does not itself hold the
/// secret: API keys and OAuth tokens are kept in the credential store, keyed
/// by the profile's `id`. Profiles are serialized to JSON with serde and
/// persisted in `profiles.json`, so the field names below are part of that
/// on-disk format.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AuthProfile {
    /// Profile identifier; also the key under which its secrets are stored.
    pub id: String,

    /// Kind of credential this profile uses.
    pub profile_type: ProfileType,

    /// Channel or provider this profile is for (e.g. `"anthropic"`).
    pub target: String,

    /// Account identifier (e.g., bot username, phone number), if known.
    pub account_id: Option<String>,

    /// When the profile was created (UTC).
    pub created_at: DateTime<Utc>,

    /// When the profile was last used (UTC), or `None` if it never was.
    pub last_used: Option<DateTime<Utc>>,

    /// Whether this profile is active, i.e. eligible to be returned as the
    /// active profile for its target.
    pub active: bool,

    /// Additional free-form metadata. Falls back to an empty map when the
    /// field is missing from serialized input.
    #[serde(default)]
    pub metadata: HashMap<String, serde_json::Value>,
}
70
71impl AuthProfile {
72    /// Create a new auth profile.
73    #[must_use]
74    pub fn new(
75        id: impl Into<String>,
76        profile_type: ProfileType,
77        target: impl Into<String>,
78    ) -> Self {
79        Self {
80            id: id.into(),
81            profile_type,
82            target: target.into(),
83            account_id: None,
84            created_at: Utc::now(),
85            last_used: None,
86            active: true,
87            metadata: HashMap::new(),
88        }
89    }
90
91    /// Mark the profile as used.
92    pub fn mark_used(&mut self) {
93        self.last_used = Some(Utc::now());
94    }
95}
96
97/// Type of authentication profile.
98#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
99#[serde(rename_all = "snake_case")]
100pub enum ProfileType {
101    /// API key authentication.
102    ApiKey,
103    /// OAuth 2.0 token.
104    OAuth,
105    /// Bot token (e.g., Telegram, Discord).
106    BotToken,
107    /// Session-based auth (e.g., `WhatsApp`, Signal).
108    Session,
109    /// Certificate-based auth.
110    Certificate,
111}
112
113/// OAuth token data.
114#[derive(Debug, Clone, Serialize, Deserialize)]
115pub struct OAuthToken {
116    /// Access token.
117    pub access_token: String,
118
119    /// Refresh token (if available).
120    pub refresh_token: Option<String>,
121
122    /// Token expiration time.
123    pub expires_at: Option<DateTime<Utc>>,
124
125    /// Token type (usually "Bearer").
126    pub token_type: String,
127
128    /// Scopes granted.
129    pub scopes: Vec<String>,
130}
131
132impl OAuthToken {
133    /// Check if the token is expired.
134    #[must_use]
135    pub fn is_expired(&self) -> bool {
136        self.expires_at.is_some_and(|exp| exp < Utc::now())
137    }
138
139    /// Check if the token needs refresh (expires within 5 minutes).
140    #[must_use]
141    pub fn needs_refresh(&self) -> bool {
142        self.expires_at
143            .is_some_and(|exp| exp < Utc::now() + chrono::Duration::minutes(5))
144    }
145}
146
147/// Credential store for authentication data.
148///
149/// Wraps `CredentialStore` with profile management.
150pub struct AuthCredentialStore {
151    inner: CredentialStore,
152    profiles_path: std::path::PathBuf,
153    profiles: HashMap<String, AuthProfile>,
154}
155
156impl AuthCredentialStore {
157    /// Create a new auth credential store.
158    ///
159    /// # Arguments
160    ///
161    /// * `encryption_key` - 32-byte encryption key
162    /// * `base_path` - Base directory for credentials
163    ///
164    /// # Errors
165    ///
166    /// Returns error if profile loading fails.
167    pub fn new(encryption_key: [u8; 32], base_path: &Path) -> Result<Self, AuthError> {
168        let store_path = base_path.join("secrets");
169        let profiles_path = base_path.join("profiles.json");
170
171        let inner = CredentialStore::new(encryption_key, store_path);
172
173        let profiles = if profiles_path.exists() {
174            let content = std::fs::read_to_string(&profiles_path)?;
175            serde_json::from_str(&content)?
176        } else {
177            HashMap::new()
178        };
179
180        Ok(Self {
181            inner,
182            profiles_path,
183            profiles,
184        })
185    }
186
187    /// Store an API key for a profile.
188    ///
189    /// # Errors
190    ///
191    /// Returns error if storage fails.
192    pub fn store_api_key(&mut self, profile_id: &str, key: &ApiKey) -> Result<(), AuthError> {
193        self.inner.store(profile_id, key)?;
194        self.save_profiles()?;
195        Ok(())
196    }
197
198    /// Load an API key for a profile.
199    ///
200    /// # Errors
201    ///
202    /// Returns error if profile not found or decryption fails.
203    pub fn load_api_key(&self, profile_id: &str) -> Result<ApiKey, AuthError> {
204        Ok(self.inner.load(profile_id)?)
205    }
206
207    /// Store an OAuth token for a profile.
208    ///
209    /// # Errors
210    ///
211    /// Returns error if storage fails.
212    pub fn store_oauth_token(
213        &mut self,
214        profile_id: &str,
215        token: &OAuthToken,
216    ) -> Result<(), AuthError> {
217        let token_json = serde_json::to_string(token)?;
218        let key = ApiKey::new(token_json);
219        self.inner.store(&format!("{profile_id}_oauth"), &key)?;
220        self.save_profiles()?;
221        Ok(())
222    }
223
224    /// Load an OAuth token for a profile.
225    ///
226    /// # Errors
227    ///
228    /// Returns error if profile not found, decryption fails, or token expired.
229    pub fn load_oauth_token(&self, profile_id: &str) -> Result<OAuthToken, AuthError> {
230        let key = self.inner.load(&format!("{profile_id}_oauth"))?;
231        let token: OAuthToken = serde_json::from_str(key.expose())?;
232
233        if token.is_expired() && token.refresh_token.is_none() {
234            return Err(AuthError::TokenExpired);
235        }
236
237        Ok(token)
238    }
239
240    /// Add or update an auth profile.
241    pub fn set_profile(&mut self, profile: AuthProfile) -> Result<(), AuthError> {
242        self.profiles.insert(profile.id.clone(), profile);
243        self.save_profiles()?;
244        Ok(())
245    }
246
247    /// Get an auth profile.
248    #[must_use]
249    pub fn get_profile(&self, profile_id: &str) -> Option<&AuthProfile> {
250        self.profiles.get(profile_id)
251    }
252
253    /// Get a mutable auth profile.
254    #[must_use]
255    pub fn get_profile_mut(&mut self, profile_id: &str) -> Option<&mut AuthProfile> {
256        self.profiles.get_mut(profile_id)
257    }
258
259    /// Remove an auth profile and its credentials.
260    ///
261    /// # Errors
262    ///
263    /// Returns error if deletion fails.
264    pub fn remove_profile(&mut self, profile_id: &str) -> Result<(), AuthError> {
265        self.profiles.remove(profile_id);
266        let _ = self.inner.delete(profile_id);
267        let _ = self.inner.delete(&format!("{profile_id}_oauth"));
268        self.save_profiles()?;
269        Ok(())
270    }
271
272    /// List all profiles.
273    #[must_use]
274    pub fn list_profiles(&self) -> Vec<&AuthProfile> {
275        self.profiles.values().collect()
276    }
277
278    /// List profiles for a specific target (channel or provider).
279    #[must_use]
280    pub fn profiles_for_target(&self, target: &str) -> Vec<&AuthProfile> {
281        self.profiles
282            .values()
283            .filter(|p| p.target == target)
284            .collect()
285    }
286
287    /// Get the active profile for a target.
288    #[must_use]
289    pub fn active_profile_for_target(&self, target: &str) -> Option<&AuthProfile> {
290        self.profiles
291            .values()
292            .find(|p| p.target == target && p.active)
293    }
294
295    /// Save profiles to disk.
296    fn save_profiles(&self) -> Result<(), AuthError> {
297        if let Some(parent) = self.profiles_path.parent() {
298            std::fs::create_dir_all(parent)?;
299        }
300        let content = serde_json::to_string_pretty(&self.profiles)?;
301        std::fs::write(&self.profiles_path, content)?;
302        Ok(())
303    }
304}
305
306/// Refresh an OAuth token using the refresh token.
307///
308/// This is a placeholder - actual implementation depends on the OAuth provider.
309pub async fn refresh_oauth_token(
310    token: &OAuthToken,
311    client_id: &str,
312    client_secret: &str,
313    token_url: &str,
314) -> Result<OAuthToken, AuthError> {
315    let refresh_token = token
316        .refresh_token
317        .as_ref()
318        .ok_or_else(|| AuthError::AuthFailed("No refresh token available".to_string()))?;
319
320    // Build refresh request
321    let client = reqwest::Client::new();
322    let response = client
323        .post(token_url)
324        .form(&[
325            ("grant_type", "refresh_token"),
326            ("refresh_token", refresh_token),
327            ("client_id", client_id),
328            ("client_secret", client_secret),
329        ])
330        .send()
331        .await
332        .map_err(|e| AuthError::AuthFailed(e.to_string()))?;
333
334    if !response.status().is_success() {
335        return Err(AuthError::AuthFailed(format!(
336            "Token refresh failed: {}",
337            response.status()
338        )));
339    }
340
341    #[derive(Deserialize)]
342    struct TokenResponse {
343        access_token: String,
344        refresh_token: Option<String>,
345        expires_in: Option<i64>,
346        token_type: Option<String>,
347    }
348
349    let token_response: TokenResponse = response
350        .json()
351        .await
352        .map_err(|e| AuthError::AuthFailed(e.to_string()))?;
353
354    let expires_at = token_response
355        .expires_in
356        .map(|secs| Utc::now() + chrono::Duration::seconds(secs));
357
358    Ok(OAuthToken {
359        access_token: token_response.access_token,
360        refresh_token: token_response
361            .refresh_token
362            .or_else(|| token.refresh_token.clone()),
363        expires_at,
364        token_type: token_response
365            .token_type
366            .unwrap_or_else(|| "Bearer".to_string()),
367        scopes: token.scopes.clone(),
368    })
369}
370
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::tempdir;

    /// Build a token expiring `expires_in` from now, with an optional refresh token.
    fn oauth_token(expires_in: chrono::Duration, refresh: Option<&str>) -> OAuthToken {
        OAuthToken {
            access_token: "test".to_string(),
            refresh_token: refresh.map(str::to_string),
            expires_at: Some(Utc::now() + expires_in),
            token_type: "Bearer".to_string(),
            scopes: vec![],
        }
    }

    /// Open a store rooted at `dir` with `key`, so tests can reopen it later.
    fn open_store(key: [u8; 32], dir: &Path) -> AuthCredentialStore {
        AuthCredentialStore::new(key, dir).unwrap()
    }

    #[test]
    fn test_auth_profile_creation() {
        let profile = AuthProfile::new("test-profile", ProfileType::ApiKey, "anthropic");

        assert_eq!(profile.id, "test-profile");
        assert_eq!(profile.target, "anthropic");
        assert!(profile.active);
    }

    #[test]
    fn test_oauth_token_expiry() {
        let token = oauth_token(-chrono::Duration::hours(1), None);

        assert!(token.is_expired());
        assert!(token.needs_refresh());
    }

    #[test]
    fn test_oauth_token_valid() {
        let token = oauth_token(chrono::Duration::hours(1), None);

        assert!(!token.is_expired());
        assert!(!token.needs_refresh());
    }

    #[test]
    fn test_oauth_token_needs_refresh_before_expiry() {
        let token = oauth_token(chrono::Duration::minutes(2), None);

        assert!(!token.is_expired());
        assert!(token.needs_refresh());
    }

    #[test]
    fn test_oauth_token_without_expiry_never_expires() {
        let token = OAuthToken {
            expires_at: None,
            ..oauth_token(chrono::Duration::hours(1), None)
        };

        assert!(!token.is_expired());
        assert!(!token.needs_refresh());
    }

    #[test]
    fn test_profile_type_serialized_names() {
        assert_eq!(
            serde_json::to_string(&ProfileType::ApiKey).unwrap(),
            "\"api_key\""
        );
        // serde's snake_case rule splits `OAuth` into `o_auth`; existing
        // profiles.json files depend on this spelling.
        assert_eq!(
            serde_json::to_string(&ProfileType::OAuth).unwrap(),
            "\"o_auth\""
        );
    }

    #[test]
    fn test_auth_credential_store() {
        let temp = tempdir().unwrap();
        let encryption_key: [u8; 32] = rand::random();

        let mut store = AuthCredentialStore::new(encryption_key, temp.path()).unwrap();

        // Create and store a profile
        let profile = AuthProfile::new("test", ProfileType::ApiKey, "anthropic");
        store.set_profile(profile).unwrap();

        // Store API key
        let key = ApiKey::new("sk-test-key".to_string());
        store.store_api_key("test", &key).unwrap();

        // Retrieve
        let loaded = store.load_api_key("test").unwrap();
        assert_eq!(loaded.expose(), "sk-test-key");

        // List profiles
        let profiles = store.list_profiles();
        assert_eq!(profiles.len(), 1);
    }

    #[test]
    fn test_profiles_persist_across_reopen() {
        let temp = tempdir().unwrap();
        let key: [u8; 32] = rand::random();

        {
            let mut store = open_store(key, temp.path());
            store
                .set_profile(AuthProfile::new("p1", ProfileType::BotToken, "telegram"))
                .unwrap();
        }

        let store = open_store(key, temp.path());
        let profile = store
            .get_profile("p1")
            .expect("profile should be reloaded from profiles.json");
        assert_eq!(profile.target, "telegram");
        assert_eq!(profile.profile_type, ProfileType::BotToken);
    }

    #[test]
    fn test_oauth_token_round_trip() {
        let temp = tempdir().unwrap();
        let mut store = open_store(rand::random(), temp.path());

        let token = OAuthToken {
            scopes: vec!["read".to_string()],
            ..oauth_token(chrono::Duration::hours(1), Some("refresh"))
        };
        store.store_oauth_token("o1", &token).unwrap();

        let loaded = store.load_oauth_token("o1").unwrap();
        assert_eq!(loaded.access_token, "test");
        assert_eq!(loaded.refresh_token.as_deref(), Some("refresh"));
        assert!(loaded.expires_at.is_some());
        assert_eq!(loaded.token_type, "Bearer");
        assert_eq!(loaded.scopes, vec!["read".to_string()]);
    }

    #[test]
    fn test_expired_oauth_token_without_refresh_is_rejected() {
        let temp = tempdir().unwrap();
        let mut store = open_store(rand::random(), temp.path());

        store
            .store_oauth_token("o1", &oauth_token(-chrono::Duration::hours(1), None))
            .unwrap();

        assert!(matches!(
            store.load_oauth_token("o1"),
            Err(AuthError::TokenExpired)
        ));
    }

    #[test]
    fn test_expired_oauth_token_with_refresh_is_returned() {
        let temp = tempdir().unwrap();
        let mut store = open_store(rand::random(), temp.path());

        store
            .store_oauth_token(
                "o1",
                &oauth_token(-chrono::Duration::hours(1), Some("refresh")),
            )
            .unwrap();

        let loaded = store.load_oauth_token("o1").unwrap();
        assert!(loaded.is_expired());
        assert_eq!(loaded.refresh_token.as_deref(), Some("refresh"));
    }

    #[test]
    fn test_remove_profile() {
        let temp = tempdir().unwrap();
        let mut store = open_store(rand::random(), temp.path());

        store
            .set_profile(AuthProfile::new("test", ProfileType::ApiKey, "anthropic"))
            .unwrap();
        store
            .store_api_key("test", &ApiKey::new("sk-test-key".to_string()))
            .unwrap();

        store.remove_profile("test").unwrap();

        assert!(store.get_profile("test").is_none());
        assert!(store.list_profiles().is_empty());
        assert!(store.load_api_key("test").is_err());

        // Removing an unknown profile is not an error.
        store.remove_profile("missing").unwrap();
    }

    #[test]
    fn test_profiles_for_target() {
        let temp = tempdir().unwrap();
        let encryption_key: [u8; 32] = rand::random();

        let mut store = AuthCredentialStore::new(encryption_key, temp.path()).unwrap();

        store
            .set_profile(AuthProfile::new("a1", ProfileType::ApiKey, "anthropic"))
            .unwrap();
        store
            .set_profile(AuthProfile::new("o1", ProfileType::ApiKey, "openai"))
            .unwrap();
        store
            .set_profile(AuthProfile::new("a2", ProfileType::ApiKey, "anthropic"))
            .unwrap();

        let anthropic_profiles = store.profiles_for_target("anthropic");
        assert_eq!(anthropic_profiles.len(), 2);
    }
}