aether-auth 0.1.5

OAuth credential storage and authorization flows for the Aether AI agent framework
Documentation
use crate::credential::{OAuthCredential, OAuthCredentialStorage};
use crate::error::OAuthError;
use age::scrypt::Identity;
use age::secrecy::SecretString;
use age::{Decryptor, Encryptor};
use async_trait::async_trait;
use dirs::home_dir;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::env::var;
use std::io::{Read, Write};
use std::iter::once;
use std::path::{Path, PathBuf};
use std::sync::Mutex;

/// Environment variable read for the store passphrase when settings do not name a
/// custom one (see `EncryptedFileOAuthCredentialStorage::from_settings`).
const DEFAULT_PASSWORD_ENV: &str = "AETHER_CREDENTIALS_PASSWORD";

/// Plaintext (pre-encryption) JSON document held inside the encrypted credential file.
#[derive(Clone, Debug, Deserialize, Serialize)]
struct CredentialStore {
    /// Stored credentials, indexed by the key callers pass to the storage methods.
    credentials: HashMap<String, OAuthCredential>,
}

/// OAuth credential storage that keeps every credential in a single passphrase-encrypted
/// (age, scrypt recipient) JSON file.
///
/// The whole store is decrypted on each read and re-encrypted on each write.
pub struct EncryptedFileOAuthCredentialStorage {
    /// Location of the encrypted credential file.
    path: PathBuf,
    /// Passphrase used to encrypt and decrypt the file.
    passphrase: String,
    /// Serializes read-modify-write cycles performed through this instance.
    write_guard: Mutex<()>,
}

// Hand-written rather than derived so the passphrase can never leak through `{:?}`
// output (a derived `Debug` would print it verbatim).
impl std::fmt::Debug for EncryptedFileOAuthCredentialStorage {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("EncryptedFileOAuthCredentialStorage")
            .field("path", &self.path)
            .field("passphrase", &"<redacted>")
            .finish_non_exhaustive()
    }
}

impl EncryptedFileOAuthCredentialStorage {
    pub fn new(path: PathBuf, passphrase: String) -> Self {
        Self { path, passphrase, write_guard: Mutex::new(()) }
    }

    pub fn from_settings(path: Option<PathBuf>, password_env: Option<&str>) -> Result<Self, OAuthError> {
        let path = path.map_or_else(default_path, Ok)?;
        let env_var = password_env.unwrap_or(DEFAULT_PASSWORD_ENV);
        let passphrase = var(env_var).ok().filter(|pass| !pass.is_empty()).ok_or_else(|| {
            OAuthError::CredentialStore(format!(
                "Encrypted file credential store requires a passphrase. \
                     Set the {env_var} environment variable or configure a custom `passwordEnv` in settings."
            ))
        })?;

        Ok(Self::new(path, passphrase))
    }

    fn encrypt(plaintext: &[u8], passphrase: &str) -> Result<Vec<u8>, OAuthError> {
        let fail = |e| OAuthError::CredentialStore(format!("Encryption failed: {e}"));

        let mut ciphertext = Vec::new();
        let mut writer = Encryptor::with_user_passphrase(SecretString::from(passphrase))
            .wrap_output(&mut ciphertext)
            .map_err(fail)?;
        writer.write_all(plaintext).map_err(fail)?;
        writer.finish().map_err(fail)?;

        Ok(ciphertext)
    }

    fn decrypt(ciphertext: &[u8], passphrase: &str) -> Result<Vec<u8>, OAuthError> {
        let decryptor = Decryptor::new(ciphertext)
            .map_err(|e| OAuthError::CredentialStore(format!("Invalid encrypted file: {e}")))?;

        let mut reader =
            decryptor.decrypt(once(&Identity::new(SecretString::from(passphrase)) as &dyn age::Identity)).map_err(
                |e| OAuthError::CredentialStore(format!("Decryption failed — wrong passphrase or corrupted file: {e}")),
            )?;

        let mut plaintext = Vec::new();
        reader
            .read_to_end(&mut plaintext)
            .map_err(|e| OAuthError::CredentialStore(format!("Decryption failed: {e}")))?;

        Ok(plaintext)
    }

    fn load(&self) -> Result<CredentialStore, OAuthError> {
        if !self.path.exists() {
            return Ok(CredentialStore { credentials: HashMap::new() });
        }

        let bytes = std::fs::read(&self.path)?;
        if bytes.is_empty() {
            return Ok(CredentialStore { credentials: HashMap::new() });
        }

        let plaintext = Self::decrypt(&bytes, &self.passphrase)?;
        serde_json::from_slice(&plaintext)
            .map_err(|e| OAuthError::CredentialStore(format!("Invalid credential data: {e}")))
    }

    fn update(&self, mutate: impl FnOnce(&mut CredentialStore)) -> Result<(), OAuthError> {
        let _guard = self
            .write_guard
            .lock()
            .map_err(|_| OAuthError::CredentialStore("Failed to acquire write lock on credential store".to_string()))?;

        let mut store = self.load()?;
        mutate(&mut store);

        let plaintext = serde_json::to_vec(&store)
            .map_err(|e| OAuthError::CredentialStore(format!("Failed to serialize credentials: {e}")))?;

        let ciphertext = Self::encrypt(&plaintext, &self.passphrase)?;
        write_atomic(&self.path, &ciphertext)
    }
}

/// Returns the default location of the encrypted credential file:
/// `<home>/.aether/credentials.enc`.
///
/// # Errors
///
/// Returns [`OAuthError::CredentialStore`] when the home directory cannot be determined.
fn default_path() -> Result<PathBuf, OAuthError> {
    let Some(home) = home_dir() else {
        return Err(OAuthError::CredentialStore(
            "Could not determine the home directory for the encrypted credential file".to_string(),
        ));
    };

    Ok(home.join(".aether").join("credentials.enc"))
}

fn write_atomic(path: &Path, data: &[u8]) -> Result<(), OAuthError> {
    if let Some(parent) = path.parent() {
        std::fs::create_dir_all(parent)?;
    }

    let temp_path = path.with_extension("tmp");

    {
        let mut file = std::fs::File::create(&temp_path)?;
        file.write_all(data)?;
        file.sync_all()?;
    }

    std::fs::rename(&temp_path, path)?;

    Ok(())
}

/// [`OAuthCredentialStorage`] backed by the encrypted credential file.
///
/// NOTE(review): each call reads and decrypts the whole file synchronously on the calling
/// task, and writes also re-encrypt it. Passphrase (scrypt) key derivation may be slow, so
/// these futures can block the executor thread — consider `spawn_blocking` at call sites.
#[async_trait]
impl OAuthCredentialStorage for EncryptedFileOAuthCredentialStorage {
    async fn load_credential(&self, key: &str) -> Result<Option<OAuthCredential>, OAuthError> {
        // The decrypted store is freshly owned, so move the entry out instead of cloning it.
        self.load().map(|mut store| store.credentials.remove(key))
    }

    async fn save_credential(&self, key: &str, credential: OAuthCredential) -> Result<(), OAuthError> {
        let owned_key = key.to_owned();
        self.update(move |store| {
            // Any previous credential under this key is replaced.
            let _previous = store.credentials.insert(owned_key, credential);
        })
    }

    async fn delete_credential(&self, key: &str) -> Result<(), OAuthError> {
        self.update(|store| {
            // Removing a key that is not present is not an error.
            let _removed = store.credentials.remove(key);
        })
    }

    fn has_credential(&self, key: &str) -> bool {
        // Best effort: an unreadable or undecryptable store reports "no credential".
        matches!(self.load(), Ok(store) if store.credentials.contains_key(key))
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn save_then_load_round_trips() {
        let (_dir, store) = temp_store("correct-passphrase");

        store.save_credential("server-1", test_credential()).await.unwrap();
        let loaded = store.load_credential("server-1").await.unwrap().unwrap();

        assert_eq!(loaded.client_id, "client_1");
        assert_eq!(loaded.access_token, "tok_abc");
        assert_eq!(loaded.refresh_token.as_deref(), Some("ref_xyz"));
        assert_eq!(loaded.expires_at, Some(9_999_999_999_999));
        assert_eq!(loaded.granted_scopes, vec!["scope1"]);
    }

    #[tokio::test]
    async fn load_returns_none_for_missing_key() {
        let (_dir, store) = temp_store("pass");
        assert!(store.load_credential("nonexistent").await.unwrap().is_none());
    }

    #[tokio::test]
    async fn empty_file_is_treated_as_empty_store() {
        let (_dir, store) = temp_store("pass");
        std::fs::write(&store.path, b"").unwrap();

        assert!(store.load_credential("server-1").await.unwrap().is_none());

        store.save_credential("server-1", test_credential()).await.unwrap();
        assert!(store.has_credential("server-1"));
    }

    #[tokio::test]
    async fn delete_removes_credential() {
        let (_dir, store) = temp_store("pass");
        store.save_credential("server-1", test_credential()).await.unwrap();
        assert!(store.has_credential("server-1"));

        store.delete_credential("server-1").await.unwrap();
        assert!(!store.has_credential("server-1"));
    }

    #[tokio::test]
    async fn delete_missing_key_is_ok_and_keeps_others() {
        let (_dir, store) = temp_store("pass");
        store.save_credential("server-1", test_credential()).await.unwrap();

        store.delete_credential("other").await.unwrap();
        assert!(store.has_credential("server-1"));
    }

    #[tokio::test]
    async fn wrong_passphrase_fails_to_load() {
        let (_dir, store) = temp_store("correct-pass");
        store.save_credential("server-1", test_credential()).await.unwrap();

        let wrong_store = EncryptedFileOAuthCredentialStorage::new(store.path.clone(), "wrong-pass".to_string());

        let err = wrong_store.load_credential("server-1").await.unwrap_err();
        let msg = err.to_string();
        assert!(msg.contains("Decryption failed"), "Expected decryption error, got: {msg}");
        assert!(!wrong_store.has_credential("server-1"));
    }

    #[tokio::test]
    async fn wrong_passphrase_does_not_clobber_existing_store() {
        let (_dir, store) = temp_store("correct-pass");
        store.save_credential("server-1", test_credential()).await.unwrap();

        let wrong_store = EncryptedFileOAuthCredentialStorage::new(store.path.clone(), "wrong-pass".to_string());
        assert!(wrong_store.save_credential("server-2", test_credential()).await.is_err());

        assert!(store.has_credential("server-1"));
        assert!(!store.has_credential("server-2"));
    }

    #[tokio::test]
    async fn non_age_file_fails_to_load() {
        let (_dir, store) = temp_store("pass");
        std::fs::write(&store.path, b"definitely not an age file").unwrap();

        let msg = store.load_credential("server-1").await.unwrap_err().to_string();
        assert!(msg.contains("Invalid encrypted file"), "Expected header error, got: {msg}");
        assert!(!store.has_credential("server-1"));
    }

    #[tokio::test]
    async fn multiple_credentials_are_isolated() {
        let (_dir, store) = temp_store("pass");

        let cred_a = OAuthCredential {
            client_id: "a".to_string(),
            access_token: "token_a".to_string(),
            refresh_token: None,
            expires_at: None,
            granted_scopes: vec![],
        };
        let cred_b = OAuthCredential {
            client_id: "b".to_string(),
            access_token: "token_b".to_string(),
            refresh_token: None,
            expires_at: None,
            granted_scopes: vec![],
        };

        store.save_credential("server-a", cred_a).await.unwrap();
        store.save_credential("server-b", cred_b).await.unwrap();

        let loaded_a = store.load_credential("server-a").await.unwrap().unwrap();
        let loaded_b = store.load_credential("server-b").await.unwrap().unwrap();

        assert_eq!(loaded_a.access_token, "token_a");
        assert_eq!(loaded_b.access_token, "token_b");
    }

    #[tokio::test]
    async fn save_overwrites_existing_credential() {
        let (_dir, store) = temp_store("pass");
        let cred_v1 = OAuthCredential {
            client_id: "c".to_string(),
            access_token: "v1".to_string(),
            refresh_token: None,
            expires_at: None,
            granted_scopes: vec![],
        };
        let cred_v2 = OAuthCredential {
            client_id: "c".to_string(),
            access_token: "v2".to_string(),
            refresh_token: None,
            expires_at: None,
            granted_scopes: vec![],
        };

        store.save_credential("server", cred_v1).await.unwrap();
        store.save_credential("server", cred_v2).await.unwrap();

        let loaded = store.load_credential("server").await.unwrap().unwrap();
        assert_eq!(loaded.access_token, "v2");
    }

    #[test]
    fn encrypt_decrypt_round_trips() {
        let plaintext = b"hello, world!";
        let ciphertext = EncryptedFileOAuthCredentialStorage::encrypt(plaintext, "passphrase").unwrap();
        let decrypted = EncryptedFileOAuthCredentialStorage::decrypt(&ciphertext, "passphrase").unwrap();
        assert_eq!(decrypted, plaintext);
        assert!(EncryptedFileOAuthCredentialStorage::decrypt(&ciphertext, "other").is_err());
    }

    fn test_credential() -> OAuthCredential {
        OAuthCredential {
            client_id: "client_1".to_string(),
            access_token: "tok_abc".to_string(),
            refresh_token: Some("ref_xyz".to_string()),
            expires_at: Some(9_999_999_999_999),
            granted_scopes: vec!["scope1".to_string()],
        }
    }

    /// Creates a store backed by `creds.enc` inside a fresh temporary directory.
    ///
    /// The returned `TempDir` must stay alive (bind it to `_dir`, not `_`) while the store
    /// is used. Dropping it deletes the directory, so tests no longer leak temp dirs the
    /// way `TempDir::keep()` did.
    fn temp_store(passphrase: &str) -> (tempfile::TempDir, EncryptedFileOAuthCredentialStorage) {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("creds.enc");
        (dir, EncryptedFileOAuthCredentialStorage::new(path, passphrase.to_string()))
    }
}