Skip to main content

oxigdal_cloud/
auth.rs

//! Authentication strategies for cloud storage backends.
//!
//! This module provides credential types and credential providers for cloud
//! providers, covering OAuth 2.0 tokens, service accounts, API keys, SAS tokens, and IAM roles.
5
6use std::collections::HashMap;
7use std::path::Path;
8
9use crate::error::{AuthError, CloudError, Result};
10
11/// Authentication credentials
12#[derive(Debug, Clone)]
13pub enum Credentials {
14    /// No authentication
15    None,
16
17    /// API key authentication
18    ApiKey {
19        /// API key
20        key: String,
21    },
22
23    /// Access key and secret key (AWS-style)
24    AccessKey {
25        /// Access key ID
26        access_key: String,
27        /// Secret access key
28        secret_key: String,
29        /// Optional session token
30        session_token: Option<String>,
31    },
32
33    /// OAuth 2.0 token
34    OAuth2 {
35        /// Access token
36        access_token: String,
37        /// Optional refresh token
38        refresh_token: Option<String>,
39        /// Token expiration time
40        expires_at: Option<chrono::DateTime<chrono::Utc>>,
41    },
42
43    /// Service account key (GCP-style JSON)
44    ServiceAccount {
45        /// Service account key JSON
46        key_json: String,
47        /// Project ID
48        project_id: Option<String>,
49    },
50
51    /// Shared Access Signature token (Azure-style)
52    SasToken {
53        /// SAS token
54        token: String,
55        /// Token expiration time
56        expires_at: Option<chrono::DateTime<chrono::Utc>>,
57    },
58
59    /// IAM role credentials
60    IamRole {
61        /// Role ARN
62        role_arn: String,
63        /// Session name
64        session_name: String,
65    },
66
67    /// Custom credentials with arbitrary key-value pairs
68    Custom {
69        /// Credential data
70        data: HashMap<String, String>,
71    },
72}
73
74impl Credentials {
75    /// Creates API key credentials
76    #[must_use]
77    pub fn api_key(key: impl Into<String>) -> Self {
78        Self::ApiKey { key: key.into() }
79    }
80
81    /// Creates access key credentials
82    #[must_use]
83    pub fn access_key(access_key: impl Into<String>, secret_key: impl Into<String>) -> Self {
84        Self::AccessKey {
85            access_key: access_key.into(),
86            secret_key: secret_key.into(),
87            session_token: None,
88        }
89    }
90
91    /// Creates access key credentials with session token
92    #[must_use]
93    pub fn access_key_with_session(
94        access_key: impl Into<String>,
95        secret_key: impl Into<String>,
96        session_token: impl Into<String>,
97    ) -> Self {
98        Self::AccessKey {
99            access_key: access_key.into(),
100            secret_key: secret_key.into(),
101            session_token: Some(session_token.into()),
102        }
103    }
104
105    /// Creates OAuth 2.0 credentials
106    #[must_use]
107    pub fn oauth2(access_token: impl Into<String>) -> Self {
108        Self::OAuth2 {
109            access_token: access_token.into(),
110            refresh_token: None,
111            expires_at: None,
112        }
113    }
114
115    /// Creates OAuth 2.0 credentials with refresh token
116    #[must_use]
117    pub fn oauth2_with_refresh(
118        access_token: impl Into<String>,
119        refresh_token: impl Into<String>,
120    ) -> Self {
121        Self::OAuth2 {
122            access_token: access_token.into(),
123            refresh_token: Some(refresh_token.into()),
124            expires_at: None,
125        }
126    }
127
128    /// Creates service account credentials from JSON
129    pub fn service_account_from_json(json: impl Into<String>) -> Result<Self> {
130        let json_str = json.into();
131
132        // Try to parse JSON to validate
133        let parsed: serde_json::Value = serde_json::from_str(&json_str).map_err(|e| {
134            CloudError::Auth(AuthError::ServiceAccountKey {
135                message: format!("Invalid JSON: {e}"),
136            })
137        })?;
138
139        // Extract project ID if available
140        let project_id = parsed
141            .get("project_id")
142            .and_then(|v| v.as_str())
143            .map(|s| s.to_string());
144
145        Ok(Self::ServiceAccount {
146            key_json: json_str,
147            project_id,
148        })
149    }
150
151    /// Creates service account credentials from file
152    pub fn service_account_from_file(path: impl AsRef<Path>) -> Result<Self> {
153        let content = std::fs::read_to_string(path.as_ref()).map_err(|e| {
154            CloudError::Auth(AuthError::ServiceAccountKey {
155                message: format!("Failed to read service account key file: {e}"),
156            })
157        })?;
158
159        Self::service_account_from_json(content)
160    }
161
162    /// Creates SAS token credentials
163    #[must_use]
164    pub fn sas_token(token: impl Into<String>) -> Self {
165        Self::SasToken {
166            token: token.into(),
167            expires_at: None,
168        }
169    }
170
171    /// Creates IAM role credentials
172    #[must_use]
173    pub fn iam_role(role_arn: impl Into<String>, session_name: impl Into<String>) -> Self {
174        Self::IamRole {
175            role_arn: role_arn.into(),
176            session_name: session_name.into(),
177        }
178    }
179
180    /// Checks if credentials are expired
181    #[must_use]
182    pub fn is_expired(&self) -> bool {
183        let now = chrono::Utc::now();
184
185        match self {
186            Self::OAuth2 {
187                expires_at: Some(expiry),
188                ..
189            } => *expiry <= now,
190            Self::SasToken {
191                expires_at: Some(expiry),
192                ..
193            } => *expiry <= now,
194            _ => false,
195        }
196    }
197
198    /// Returns true if credentials need refresh
199    #[must_use]
200    pub fn needs_refresh(&self) -> bool {
201        let now = chrono::Utc::now();
202        let buffer = chrono::Duration::minutes(5); // Refresh 5 minutes before expiry
203
204        match self {
205            Self::OAuth2 {
206                expires_at: Some(expiry),
207                ..
208            } => *expiry <= now + buffer,
209            Self::SasToken {
210                expires_at: Some(expiry),
211                ..
212            } => *expiry <= now + buffer,
213            _ => false,
214        }
215    }
216}
217
/// A source of [`Credentials`] that can be queried at runtime.
///
/// Implementors must provide [`CredentialProvider::load`]. The provided
/// [`CredentialProvider::refresh`] simply performs a fresh `load`.
#[cfg(feature = "async")]
#[async_trait::async_trait]
pub trait CredentialProvider: Send + Sync {
    /// Produces a set of credentials.
    ///
    /// # Errors
    ///
    /// Returns an error when this provider cannot supply credentials.
    async fn load(&self) -> Result<Credentials>;

    /// Produces refreshed credentials.
    ///
    /// `_credentials` is the currently held set; the default implementation
    /// ignores it and delegates to [`CredentialProvider::load`].
    ///
    /// # Errors
    ///
    /// Propagates any error from [`CredentialProvider::load`].
    async fn refresh(&self, _credentials: &Credentials) -> Result<Credentials> {
        Self::load(self).await
    }
}
231
232/// Environment variable credential provider
233pub struct EnvCredentialProvider {
234    /// Credential type
235    credential_type: CredentialType,
236}
237
/// Credential families understood by [`EnvCredentialProvider`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum CredentialType {
    /// AWS access key credentials (`AWS_ACCESS_KEY_ID`, `AWS_SECRET_ACCESS_KEY`,
    /// optional `AWS_SESSION_TOKEN`)
    Aws,
    /// Azure storage credentials (`AZURE_STORAGE_ACCOUNT` plus either
    /// `AZURE_STORAGE_KEY` or `AZURE_STORAGE_SAS_TOKEN`)
    Azure,
    /// GCP service account key file named by `GOOGLE_APPLICATION_CREDENTIALS`
    Gcp,
    /// Generic API key from `API_KEY`, falling back to `APIKEY`
    ApiKey,
}
250
251impl EnvCredentialProvider {
252    /// Creates a new environment variable credential provider
253    #[must_use]
254    pub const fn new(credential_type: CredentialType) -> Self {
255        Self { credential_type }
256    }
257
258    /// Loads AWS credentials from environment variables
259    fn load_aws() -> Result<Credentials> {
260        let access_key = std::env::var("AWS_ACCESS_KEY_ID").map_err(|_| {
261            CloudError::Auth(AuthError::CredentialsNotFound {
262                message: "AWS_ACCESS_KEY_ID not found".to_string(),
263            })
264        })?;
265
266        let secret_key = std::env::var("AWS_SECRET_ACCESS_KEY").map_err(|_| {
267            CloudError::Auth(AuthError::CredentialsNotFound {
268                message: "AWS_SECRET_ACCESS_KEY not found".to_string(),
269            })
270        })?;
271
272        let session_token = std::env::var("AWS_SESSION_TOKEN").ok();
273
274        Ok(Credentials::AccessKey {
275            access_key,
276            secret_key,
277            session_token,
278        })
279    }
280
281    /// Loads Azure credentials from environment variables
282    fn load_azure() -> Result<Credentials> {
283        let account_name = std::env::var("AZURE_STORAGE_ACCOUNT").map_err(|_| {
284            CloudError::Auth(AuthError::CredentialsNotFound {
285                message: "AZURE_STORAGE_ACCOUNT not found".to_string(),
286            })
287        })?;
288
289        // Try account key first, then SAS token
290        if let Ok(account_key) = std::env::var("AZURE_STORAGE_KEY") {
291            let mut data = HashMap::new();
292            data.insert("account_name".to_string(), account_name);
293            data.insert("account_key".to_string(), account_key);
294
295            Ok(Credentials::Custom { data })
296        } else if let Ok(sas_token) = std::env::var("AZURE_STORAGE_SAS_TOKEN") {
297            Ok(Credentials::SasToken {
298                token: sas_token,
299                expires_at: None,
300            })
301        } else {
302            Err(CloudError::Auth(AuthError::CredentialsNotFound {
303                message: "Neither AZURE_STORAGE_KEY nor AZURE_STORAGE_SAS_TOKEN found".to_string(),
304            }))
305        }
306    }
307
308    /// Loads GCP credentials from environment variables
309    fn load_gcp() -> Result<Credentials> {
310        let key_file = std::env::var("GOOGLE_APPLICATION_CREDENTIALS").map_err(|_| {
311            CloudError::Auth(AuthError::CredentialsNotFound {
312                message: "GOOGLE_APPLICATION_CREDENTIALS not found".to_string(),
313            })
314        })?;
315
316        Credentials::service_account_from_file(&key_file)
317    }
318
319    /// Loads API key from environment variables
320    fn load_api_key() -> Result<Credentials> {
321        let key = std::env::var("API_KEY")
322            .or_else(|_| std::env::var("APIKEY"))
323            .map_err(|_| {
324                CloudError::Auth(AuthError::CredentialsNotFound {
325                    message: "API_KEY or APIKEY not found".to_string(),
326                })
327            })?;
328
329        Ok(Credentials::ApiKey { key })
330    }
331}
332
#[cfg(feature = "async")]
#[async_trait::async_trait]
impl CredentialProvider for EnvCredentialProvider {
    /// Picks the environment loader for the configured [`CredentialType`] and runs it.
    ///
    /// The loaders read environment variables synchronously. The GCP loader
    /// also reads the key file synchronously.
    async fn load(&self) -> Result<Credentials> {
        let loader: fn() -> Result<Credentials> = match self.credential_type {
            CredentialType::Aws => Self::load_aws,
            CredentialType::Azure => Self::load_azure,
            CredentialType::Gcp => Self::load_gcp,
            CredentialType::ApiKey => Self::load_api_key,
        };
        loader()
    }
}
345
/// Credential provider that loads a service account key from a JSON file.
#[derive(Debug, Clone)]
pub struct FileCredentialProvider {
    /// Path to the service account key file.
    path: std::path::PathBuf,
}
351
352impl FileCredentialProvider {
353    /// Creates a new file credential provider
354    #[must_use]
355    pub fn new(path: impl AsRef<Path>) -> Self {
356        Self {
357            path: path.as_ref().to_path_buf(),
358        }
359    }
360}
361
#[cfg(feature = "async")]
#[async_trait::async_trait]
impl CredentialProvider for FileCredentialProvider {
    /// Parses the configured file as a service account key.
    ///
    /// NOTE(review): the read goes through
    /// [`Credentials::service_account_from_file`], which uses synchronous
    /// `std::fs` I/O. It therefore blocks the thread polling this future for
    /// the duration of the read.
    async fn load(&self) -> Result<Credentials> {
        let key_path = self.path.as_path();
        Credentials::service_account_from_file(key_path)
    }
}
369
/// Credential provider that tries several providers in order and returns the
/// first credentials that load successfully.
///
/// Only available with the `async` feature, because it stores
/// [`CredentialProvider`] trait objects and that trait is gated on `async`.
#[cfg(feature = "async")]
pub struct ChainCredentialProvider {
    /// Providers, queried in insertion order.
    providers: Vec<Box<dyn CredentialProvider>>,
}

#[cfg(feature = "async")]
impl ChainCredentialProvider {
    /// Creates an empty chain.
    #[must_use]
    pub fn new() -> Self {
        Self {
            providers: Vec::new(),
        }
    }

    /// Appends `provider` to the end of the chain.
    #[must_use]
    pub fn with_provider(mut self, provider: Box<dyn CredentialProvider>) -> Self {
        self.providers.push(provider);
        self
    }
}

#[cfg(feature = "async")]
impl Default for ChainCredentialProvider {
    fn default() -> Self {
        Self::new()
    }
}

#[cfg(feature = "async")]
#[async_trait::async_trait]
impl CredentialProvider for ChainCredentialProvider {
    /// Returns the credentials from the first provider whose `load` succeeds.
    ///
    /// # Errors
    ///
    /// Returns [`AuthError::CredentialsNotFound`] when the chain is empty or
    /// every provider fails. The message lists each provider's error in chain
    /// order, so callers can see why each source was rejected.
    async fn load(&self) -> Result<Credentials> {
        let mut failures = Vec::with_capacity(self.providers.len());

        for provider in &self.providers {
            match provider.load().await {
                Ok(credentials) => return Ok(credentials),
                Err(err) => failures.push(err.to_string()),
            }
        }

        let message = if failures.is_empty() {
            "No credential provider succeeded: chain is empty".to_string()
        } else {
            format!("No credential provider succeeded: {}", failures.join("; "))
        };

        Err(CloudError::Auth(AuthError::CredentialsNotFound { message }))
    }
}
414
#[cfg(test)]
#[allow(clippy::panic)] // `panic!` reports an unexpected enum variant in match-based assertions.
mod tests {
    use super::*;

    #[test]
    fn test_credentials_api_key() {
        let creds = Credentials::api_key("test-key");
        match creds {
            Credentials::ApiKey { key } => assert_eq!(key, "test-key"),
            _ => panic!("Expected ApiKey credentials"),
        }
    }

    #[test]
    fn test_credentials_access_key() {
        let creds = Credentials::access_key("access", "secret");
        match creds {
            Credentials::AccessKey {
                access_key,
                secret_key,
                session_token,
            } => {
                assert_eq!(access_key, "access");
                assert_eq!(secret_key, "secret");
                assert!(session_token.is_none());
            }
            _ => panic!("Expected AccessKey credentials"),
        }
    }

    #[test]
    fn test_credentials_access_key_with_session() {
        let creds = Credentials::access_key_with_session("access", "secret", "session");
        match creds {
            Credentials::AccessKey {
                access_key,
                secret_key,
                session_token,
            } => {
                assert_eq!(access_key, "access");
                assert_eq!(secret_key, "secret");
                assert_eq!(session_token.as_deref(), Some("session"));
            }
            _ => panic!("Expected AccessKey credentials"),
        }
    }

    #[test]
    fn test_credentials_oauth2() {
        let creds = Credentials::oauth2("token");
        match creds {
            Credentials::OAuth2 { access_token, .. } => assert_eq!(access_token, "token"),
            _ => panic!("Expected OAuth2 credentials"),
        }
    }

    #[test]
    fn test_credentials_oauth2_with_refresh() {
        let creds = Credentials::oauth2_with_refresh("token", "refresh");
        match creds {
            Credentials::OAuth2 {
                access_token,
                refresh_token,
                expires_at,
            } => {
                assert_eq!(access_token, "token");
                assert_eq!(refresh_token.as_deref(), Some("refresh"));
                assert!(expires_at.is_none());
            }
            _ => panic!("Expected OAuth2 credentials"),
        }
    }

    #[test]
    fn test_credentials_sas_token() {
        let creds = Credentials::sas_token("token");
        match creds {
            Credentials::SasToken { token, .. } => assert_eq!(token, "token"),
            _ => panic!("Expected SasToken credentials"),
        }
    }

    #[test]
    fn test_credentials_iam_role() {
        let creds = Credentials::iam_role("arn:aws:iam::123:role/test", "session");
        match creds {
            Credentials::IamRole {
                role_arn,
                session_name,
            } => {
                assert_eq!(role_arn, "arn:aws:iam::123:role/test");
                assert_eq!(session_name, "session");
            }
            _ => panic!("Expected IamRole credentials"),
        }
    }

    #[test]
    fn test_credentials_service_account_from_json() {
        let json = r#"{"type":"service_account","project_id":"test-project"}"#;
        let creds = Credentials::service_account_from_json(json);
        assert!(creds.is_ok());

        match creds.ok() {
            Some(Credentials::ServiceAccount {
                project_id: Some(project_id),
                ..
            }) => {
                assert_eq!(project_id, "test-project");
            }
            _ => panic!("Expected ServiceAccount credentials with project_id"),
        }
    }

    #[test]
    fn test_credentials_service_account_without_project_id() {
        let json = r#"{"type":"service_account"}"#;
        match Credentials::service_account_from_json(json) {
            Ok(Credentials::ServiceAccount {
                key_json,
                project_id,
            }) => {
                assert_eq!(key_json, json);
                assert!(project_id.is_none());
            }
            _ => panic!("Expected ServiceAccount credentials without project_id"),
        }
    }

    #[test]
    fn test_credentials_service_account_invalid_json() {
        assert!(Credentials::service_account_from_json("not json").is_err());
        assert!(Credentials::service_account_from_json("").is_err());
    }

    #[test]
    fn test_credentials_service_account_missing_file() {
        let result =
            Credentials::service_account_from_file("/nonexistent/oxigdal-cloud/key.json");
        assert!(result.is_err());
    }

    #[test]
    fn test_credentials_is_expired() {
        let now = chrono::Utc::now();
        let past = now - chrono::Duration::hours(1);
        let future = now + chrono::Duration::hours(1);

        let expired = Credentials::OAuth2 {
            access_token: "token".to_string(),
            refresh_token: None,
            expires_at: Some(past),
        };
        assert!(expired.is_expired());

        let valid = Credentials::OAuth2 {
            access_token: "token".to_string(),
            refresh_token: None,
            expires_at: Some(future),
        };
        assert!(!valid.is_expired());
    }

    #[test]
    fn test_sas_token_expiry() {
        let now = chrono::Utc::now();

        let expired = Credentials::SasToken {
            token: "token".to_string(),
            expires_at: Some(now - chrono::Duration::hours(1)),
        };
        assert!(expired.is_expired());
        assert!(expired.needs_refresh());

        let valid = Credentials::SasToken {
            token: "token".to_string(),
            expires_at: Some(now + chrono::Duration::hours(1)),
        };
        assert!(!valid.is_expired());
        assert!(!valid.needs_refresh());
    }

    #[test]
    fn test_credentials_without_expiry_never_expire() {
        let all = [
            Credentials::None,
            Credentials::api_key("key"),
            Credentials::access_key("access", "secret"),
            Credentials::oauth2("token"),
            Credentials::sas_token("token"),
            Credentials::iam_role("arn", "session"),
        ];
        for creds in &all {
            assert!(!creds.is_expired());
            assert!(!creds.needs_refresh());
        }
    }

    #[test]
    fn test_credentials_needs_refresh() {
        let now = chrono::Utc::now();
        let soon = now + chrono::Duration::minutes(3); // Within 5-minute buffer
        let later = now + chrono::Duration::hours(1);

        let needs_refresh = Credentials::OAuth2 {
            access_token: "token".to_string(),
            refresh_token: None,
            expires_at: Some(soon),
        };
        assert!(needs_refresh.needs_refresh());

        let valid = Credentials::OAuth2 {
            access_token: "token".to_string(),
            refresh_token: None,
            expires_at: Some(later),
        };
        assert!(!valid.needs_refresh());
    }
}