1use std::collections::HashMap;
7use std::path::Path;
8
9use crate::error::{AuthError, CloudError, Result};
10
11#[derive(Debug, Clone)]
13pub enum Credentials {
14 None,
16
17 ApiKey {
19 key: String,
21 },
22
23 AccessKey {
25 access_key: String,
27 secret_key: String,
29 session_token: Option<String>,
31 },
32
33 OAuth2 {
35 access_token: String,
37 refresh_token: Option<String>,
39 expires_at: Option<chrono::DateTime<chrono::Utc>>,
41 },
42
43 ServiceAccount {
45 key_json: String,
47 project_id: Option<String>,
49 },
50
51 SasToken {
53 token: String,
55 expires_at: Option<chrono::DateTime<chrono::Utc>>,
57 },
58
59 IamRole {
61 role_arn: String,
63 session_name: String,
65 },
66
67 Custom {
69 data: HashMap<String, String>,
71 },
72}
73
74impl Credentials {
75 #[must_use]
77 pub fn api_key(key: impl Into<String>) -> Self {
78 Self::ApiKey { key: key.into() }
79 }
80
81 #[must_use]
83 pub fn access_key(access_key: impl Into<String>, secret_key: impl Into<String>) -> Self {
84 Self::AccessKey {
85 access_key: access_key.into(),
86 secret_key: secret_key.into(),
87 session_token: None,
88 }
89 }
90
91 #[must_use]
93 pub fn access_key_with_session(
94 access_key: impl Into<String>,
95 secret_key: impl Into<String>,
96 session_token: impl Into<String>,
97 ) -> Self {
98 Self::AccessKey {
99 access_key: access_key.into(),
100 secret_key: secret_key.into(),
101 session_token: Some(session_token.into()),
102 }
103 }
104
105 #[must_use]
107 pub fn oauth2(access_token: impl Into<String>) -> Self {
108 Self::OAuth2 {
109 access_token: access_token.into(),
110 refresh_token: None,
111 expires_at: None,
112 }
113 }
114
115 #[must_use]
117 pub fn oauth2_with_refresh(
118 access_token: impl Into<String>,
119 refresh_token: impl Into<String>,
120 ) -> Self {
121 Self::OAuth2 {
122 access_token: access_token.into(),
123 refresh_token: Some(refresh_token.into()),
124 expires_at: None,
125 }
126 }
127
128 pub fn service_account_from_json(json: impl Into<String>) -> Result<Self> {
130 let json_str = json.into();
131
132 let parsed: serde_json::Value = serde_json::from_str(&json_str).map_err(|e| {
134 CloudError::Auth(AuthError::ServiceAccountKey {
135 message: format!("Invalid JSON: {e}"),
136 })
137 })?;
138
139 let project_id = parsed
141 .get("project_id")
142 .and_then(|v| v.as_str())
143 .map(|s| s.to_string());
144
145 Ok(Self::ServiceAccount {
146 key_json: json_str,
147 project_id,
148 })
149 }
150
151 pub fn service_account_from_file(path: impl AsRef<Path>) -> Result<Self> {
153 let content = std::fs::read_to_string(path.as_ref()).map_err(|e| {
154 CloudError::Auth(AuthError::ServiceAccountKey {
155 message: format!("Failed to read service account key file: {e}"),
156 })
157 })?;
158
159 Self::service_account_from_json(content)
160 }
161
162 #[must_use]
164 pub fn sas_token(token: impl Into<String>) -> Self {
165 Self::SasToken {
166 token: token.into(),
167 expires_at: None,
168 }
169 }
170
171 #[must_use]
173 pub fn iam_role(role_arn: impl Into<String>, session_name: impl Into<String>) -> Self {
174 Self::IamRole {
175 role_arn: role_arn.into(),
176 session_name: session_name.into(),
177 }
178 }
179
180 #[must_use]
182 pub fn is_expired(&self) -> bool {
183 let now = chrono::Utc::now();
184
185 match self {
186 Self::OAuth2 {
187 expires_at: Some(expiry),
188 ..
189 } => *expiry <= now,
190 Self::SasToken {
191 expires_at: Some(expiry),
192 ..
193 } => *expiry <= now,
194 _ => false,
195 }
196 }
197
198 #[must_use]
200 pub fn needs_refresh(&self) -> bool {
201 let now = chrono::Utc::now();
202 let buffer = chrono::Duration::minutes(5); match self {
205 Self::OAuth2 {
206 expires_at: Some(expiry),
207 ..
208 } => *expiry <= now + buffer,
209 Self::SasToken {
210 expires_at: Some(expiry),
211 ..
212 } => *expiry <= now + buffer,
213 _ => false,
214 }
215 }
216}
217
/// An asynchronous source of [`Credentials`].
///
/// Implementors must be `Send + Sync` so they can be shared across tasks.
#[cfg(feature = "async")]
#[async_trait::async_trait]
pub trait CredentialProvider: Send + Sync {
    /// Obtains credentials from this provider.
    ///
    /// # Errors
    ///
    /// Returns an error when the provider cannot produce credentials.
    async fn load(&self) -> Result<Credentials>;

    /// Produces replacement credentials for the ones currently held.
    ///
    /// The default implementation ignores the current credentials and performs
    /// a fresh [`CredentialProvider::load`]. Providers that can do better (for
    /// example, by using a refresh token) should override it.
    ///
    /// # Errors
    ///
    /// Propagates any error from the reload.
    async fn refresh(&self, _credentials: &Credentials) -> Result<Credentials> {
        Self::load(self).await
    }
}
231
/// Loads credentials from process environment variables.
///
/// Which variables are read is selected by the configured [`CredentialType`].
#[derive(Debug, Clone, Copy)]
pub struct EnvCredentialProvider {
    /// Selects the family of environment variables to read.
    credential_type: CredentialType,
}

/// The kind of credentials an [`EnvCredentialProvider`] reads from the environment.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum CredentialType {
    /// `AWS_ACCESS_KEY_ID` and `AWS_SECRET_ACCESS_KEY`, plus optional `AWS_SESSION_TOKEN`.
    Aws,
    /// `AZURE_STORAGE_ACCOUNT` plus either `AZURE_STORAGE_KEY` or `AZURE_STORAGE_SAS_TOKEN`.
    Azure,
    /// A service-account key file whose path is in `GOOGLE_APPLICATION_CREDENTIALS`.
    Gcp,
    /// `API_KEY`, falling back to `APIKEY`.
    ApiKey,
}
250
251impl EnvCredentialProvider {
252 #[must_use]
254 pub const fn new(credential_type: CredentialType) -> Self {
255 Self { credential_type }
256 }
257
258 fn load_aws() -> Result<Credentials> {
260 let access_key = std::env::var("AWS_ACCESS_KEY_ID").map_err(|_| {
261 CloudError::Auth(AuthError::CredentialsNotFound {
262 message: "AWS_ACCESS_KEY_ID not found".to_string(),
263 })
264 })?;
265
266 let secret_key = std::env::var("AWS_SECRET_ACCESS_KEY").map_err(|_| {
267 CloudError::Auth(AuthError::CredentialsNotFound {
268 message: "AWS_SECRET_ACCESS_KEY not found".to_string(),
269 })
270 })?;
271
272 let session_token = std::env::var("AWS_SESSION_TOKEN").ok();
273
274 Ok(Credentials::AccessKey {
275 access_key,
276 secret_key,
277 session_token,
278 })
279 }
280
281 fn load_azure() -> Result<Credentials> {
283 let account_name = std::env::var("AZURE_STORAGE_ACCOUNT").map_err(|_| {
284 CloudError::Auth(AuthError::CredentialsNotFound {
285 message: "AZURE_STORAGE_ACCOUNT not found".to_string(),
286 })
287 })?;
288
289 if let Ok(account_key) = std::env::var("AZURE_STORAGE_KEY") {
291 let mut data = HashMap::new();
292 data.insert("account_name".to_string(), account_name);
293 data.insert("account_key".to_string(), account_key);
294
295 Ok(Credentials::Custom { data })
296 } else if let Ok(sas_token) = std::env::var("AZURE_STORAGE_SAS_TOKEN") {
297 Ok(Credentials::SasToken {
298 token: sas_token,
299 expires_at: None,
300 })
301 } else {
302 Err(CloudError::Auth(AuthError::CredentialsNotFound {
303 message: "Neither AZURE_STORAGE_KEY nor AZURE_STORAGE_SAS_TOKEN found".to_string(),
304 }))
305 }
306 }
307
308 fn load_gcp() -> Result<Credentials> {
310 let key_file = std::env::var("GOOGLE_APPLICATION_CREDENTIALS").map_err(|_| {
311 CloudError::Auth(AuthError::CredentialsNotFound {
312 message: "GOOGLE_APPLICATION_CREDENTIALS not found".to_string(),
313 })
314 })?;
315
316 Credentials::service_account_from_file(&key_file)
317 }
318
319 fn load_api_key() -> Result<Credentials> {
321 let key = std::env::var("API_KEY")
322 .or_else(|_| std::env::var("APIKEY"))
323 .map_err(|_| {
324 CloudError::Auth(AuthError::CredentialsNotFound {
325 message: "API_KEY or APIKEY not found".to_string(),
326 })
327 })?;
328
329 Ok(Credentials::ApiKey { key })
330 }
331}
332
#[cfg(feature = "async")]
#[async_trait::async_trait]
impl CredentialProvider for EnvCredentialProvider {
    /// Loads credentials using the loader that matches the configured
    /// [`CredentialType`].
    ///
    /// The loaders run synchronously: they read environment variables, and for
    /// GCP also a key file.
    async fn load(&self) -> Result<Credentials> {
        let loader: fn() -> Result<Credentials> = match self.credential_type {
            CredentialType::Aws => Self::load_aws,
            CredentialType::Azure => Self::load_azure,
            CredentialType::Gcp => Self::load_gcp,
            CredentialType::ApiKey => Self::load_api_key,
        };

        loader()
    }
}
345
/// Loads service-account credentials from a key file on disk.
#[derive(Debug, Clone)]
pub struct FileCredentialProvider {
    /// Path of the service-account key file.
    path: std::path::PathBuf,
}

impl FileCredentialProvider {
    /// Creates a provider for the key file at `path`.
    ///
    /// The path is only stored here; the file is read when credentials are loaded.
    #[must_use]
    pub fn new(path: impl AsRef<Path>) -> Self {
        Self {
            path: path.as_ref().to_path_buf(),
        }
    }
}
361
#[cfg(feature = "async")]
#[async_trait::async_trait]
impl CredentialProvider for FileCredentialProvider {
    /// Reads and parses the service-account key file at the configured path.
    ///
    /// NOTE(review): the file is read with blocking `std::fs` I/O inside an async
    /// method. Consider offloading it (e.g. to a blocking thread pool) if this
    /// runs on an async executor's worker threads.
    async fn load(&self) -> Result<Credentials> {
        let key_path: &Path = &self.path;
        Credentials::service_account_from_file(key_path)
    }
}
369
/// A provider that consults a sequence of other providers in order.
///
/// Gated on the `async` feature like [`CredentialProvider`] itself. Otherwise
/// `dyn CredentialProvider` would not resolve when the feature is disabled.
#[cfg(feature = "async")]
pub struct ChainCredentialProvider {
    /// Providers consulted in insertion order.
    providers: Vec<Box<dyn CredentialProvider>>,
}

#[cfg(feature = "async")]
impl ChainCredentialProvider {
    /// Creates an empty chain.
    #[must_use]
    pub fn new() -> Self {
        Self {
            providers: Vec::new(),
        }
    }

    /// Appends `provider` to the end of the chain and returns the chain.
    #[must_use]
    pub fn with_provider(mut self, provider: Box<dyn CredentialProvider>) -> Self {
        self.providers.push(provider);
        self
    }
}

#[cfg(feature = "async")]
impl Default for ChainCredentialProvider {
    /// Equivalent to [`ChainCredentialProvider::new`]: an empty chain.
    fn default() -> Self {
        Self::new()
    }
}
398
#[cfg(feature = "async")]
#[async_trait::async_trait]
impl CredentialProvider for ChainCredentialProvider {
    /// Returns the credentials of the first provider whose `load` succeeds.
    ///
    /// Providers are tried in the order they were added. An error from one
    /// provider is discarded and the next provider is tried.
    ///
    /// # Errors
    ///
    /// Returns `AuthError::CredentialsNotFound` if the chain has no providers, or
    /// if every provider failed. In the latter case the message reports how many
    /// providers were tried.
    async fn load(&self) -> Result<Credentials> {
        if self.providers.is_empty() {
            return Err(CloudError::Auth(AuthError::CredentialsNotFound {
                message: "No credential providers configured".to_string(),
            }));
        }

        for provider in &self.providers {
            if let Ok(credentials) = provider.load().await {
                return Ok(credentials);
            }
        }

        Err(CloudError::Auth(AuthError::CredentialsNotFound {
            message: format!(
                "No credential provider succeeded ({} tried)",
                self.providers.len()
            ),
        }))
    }
}
414
#[cfg(test)]
// Tests fail on an unexpected enum variant via `panic!`; allow it here in case
// `clippy::panic` is enabled for the crate.
#[allow(clippy::panic)]
mod tests {
    use super::*;

    #[test]
    fn test_credentials_api_key() {
        let creds = Credentials::api_key("test-key");
        match creds {
            Credentials::ApiKey { key } => assert_eq!(key, "test-key"),
            _ => panic!("Expected ApiKey credentials"),
        }
    }

    #[test]
    fn test_credentials_access_key() {
        let creds = Credentials::access_key("access", "secret");
        match creds {
            Credentials::AccessKey {
                access_key,
                secret_key,
                session_token,
            } => {
                assert_eq!(access_key, "access");
                assert_eq!(secret_key, "secret");
                assert!(session_token.is_none());
            }
            _ => panic!("Expected AccessKey credentials"),
        }
    }

    #[test]
    fn test_credentials_access_key_with_session() {
        let creds = Credentials::access_key_with_session("access", "secret", "session");
        match creds {
            Credentials::AccessKey {
                access_key,
                secret_key,
                session_token,
            } => {
                assert_eq!(access_key, "access");
                assert_eq!(secret_key, "secret");
                assert_eq!(session_token.as_deref(), Some("session"));
            }
            _ => panic!("Expected AccessKey credentials"),
        }
    }

    #[test]
    fn test_credentials_oauth2() {
        let creds = Credentials::oauth2("token");
        match creds {
            Credentials::OAuth2 { access_token, .. } => assert_eq!(access_token, "token"),
            _ => panic!("Expected OAuth2 credentials"),
        }
    }

    #[test]
    fn test_credentials_oauth2_with_refresh() {
        let creds = Credentials::oauth2_with_refresh("access", "refresh");
        match creds {
            Credentials::OAuth2 {
                access_token,
                refresh_token,
                expires_at,
            } => {
                assert_eq!(access_token, "access");
                assert_eq!(refresh_token.as_deref(), Some("refresh"));
                assert!(expires_at.is_none());
            }
            _ => panic!("Expected OAuth2 credentials"),
        }
    }

    #[test]
    fn test_credentials_sas_token() {
        let creds = Credentials::sas_token("token");
        match creds {
            Credentials::SasToken { token, .. } => assert_eq!(token, "token"),
            _ => panic!("Expected SasToken credentials"),
        }
    }

    #[test]
    fn test_credentials_iam_role() {
        let creds = Credentials::iam_role("arn:aws:iam::123:role/test", "session");
        match creds {
            Credentials::IamRole {
                role_arn,
                session_name,
            } => {
                assert_eq!(role_arn, "arn:aws:iam::123:role/test");
                assert_eq!(session_name, "session");
            }
            _ => panic!("Expected IamRole credentials"),
        }
    }

    #[test]
    fn test_credentials_service_account_from_json() {
        let json = r#"{"type":"service_account","project_id":"test-project"}"#;
        let creds = Credentials::service_account_from_json(json);
        assert!(creds.is_ok());

        match creds.ok() {
            Some(Credentials::ServiceAccount {
                project_id: Some(project_id),
                ..
            }) => {
                assert_eq!(project_id, "test-project");
            }
            _ => panic!("Expected ServiceAccount credentials with project_id"),
        }
    }

    #[test]
    fn test_credentials_service_account_without_project_id() {
        let json = r#"{"type":"service_account"}"#;
        match Credentials::service_account_from_json(json).ok() {
            Some(Credentials::ServiceAccount {
                key_json,
                project_id: None,
            }) => assert_eq!(key_json, json),
            _ => panic!("Expected ServiceAccount credentials without project_id"),
        }
    }

    #[test]
    fn test_credentials_service_account_invalid_json() {
        assert!(Credentials::service_account_from_json("not json").is_err());
    }

    #[test]
    fn test_credentials_is_expired() {
        let now = chrono::Utc::now();
        let past = now - chrono::Duration::hours(1);
        let future = now + chrono::Duration::hours(1);

        let expired = Credentials::OAuth2 {
            access_token: "token".to_string(),
            refresh_token: None,
            expires_at: Some(past),
        };
        assert!(expired.is_expired());

        let valid = Credentials::OAuth2 {
            access_token: "token".to_string(),
            refresh_token: None,
            expires_at: Some(future),
        };
        assert!(!valid.is_expired());
    }

    #[test]
    fn test_credentials_needs_refresh() {
        let now = chrono::Utc::now();
        let soon = now + chrono::Duration::minutes(3);
        let later = now + chrono::Duration::hours(1);

        let needs_refresh = Credentials::OAuth2 {
            access_token: "token".to_string(),
            refresh_token: None,
            expires_at: Some(soon),
        };
        assert!(needs_refresh.needs_refresh());

        let valid = Credentials::OAuth2 {
            access_token: "token".to_string(),
            refresh_token: None,
            expires_at: Some(later),
        };
        assert!(!valid.needs_refresh());
    }

    #[test]
    fn test_sas_token_expiry() {
        let now = chrono::Utc::now();

        let expired = Credentials::SasToken {
            token: "token".to_string(),
            expires_at: Some(now - chrono::Duration::hours(1)),
        };
        assert!(expired.is_expired());
        assert!(expired.needs_refresh());

        let expiring = Credentials::SasToken {
            token: "token".to_string(),
            expires_at: Some(now + chrono::Duration::minutes(3)),
        };
        assert!(!expiring.is_expired());
        assert!(expiring.needs_refresh());

        let valid = Credentials::SasToken {
            token: "token".to_string(),
            expires_at: Some(now + chrono::Duration::hours(1)),
        };
        assert!(!valid.is_expired());
        assert!(!valid.needs_refresh());
    }

    #[test]
    fn test_credentials_without_expiry_never_expire() {
        for creds in [
            Credentials::None,
            Credentials::api_key("key"),
            Credentials::oauth2("token"),
            Credentials::sas_token("token"),
        ] {
            assert!(!creds.is_expired());
            assert!(!creds.needs_refresh());
        }
    }
}