// agent-diva-core 0.3.0
//
// Core types and traits for agent-diva.
use crate::auth::profiles::{ProviderAuthProfile, ProviderAuthProfilesData};
use anyhow::{Context, Result};
use chrono::Utc;
use std::path::{Path, PathBuf};
use std::time::Duration;
use tokio::fs;
use tokio::time::{sleep, Instant};

/// Directory, relative to the configuration directory, that holds the auth store.
const AUTH_DIR: &str = "data/auth";
/// File name of the JSON document containing every provider auth profile.
const AUTH_FILENAME: &str = "profiles.json";
/// File name of the lock file that exists while the store is being accessed.
const LOCK_FILENAME: &str = "profiles.lock";
/// Pause between attempts to create the lock file, in milliseconds.
const LOCK_RETRY_MS: u64 = 50;
/// Maximum time to keep retrying for the lock, in milliseconds (ten seconds).
const LOCK_TIMEOUT_MS: u64 = 10 * 1_000;

/// File-backed store for provider authentication profiles.
///
/// The struct only records two paths: the JSON data file and the lock file that every
/// public operation creates exclusively before touching the data file.
#[derive(Debug, Clone)]
pub struct ProviderAuthStore {
    /// JSON file holding the serialized `ProviderAuthProfilesData`.
    path: PathBuf,
    /// Lock file created with `create_new` while an operation is in progress.
    lock_path: PathBuf,
}

impl ProviderAuthStore {
    pub fn new(config_dir: &Path) -> Self {
        let auth_dir = config_dir.join(AUTH_DIR);
        Self {
            path: auth_dir.join(AUTH_FILENAME),
            lock_path: auth_dir.join(LOCK_FILENAME),
        }
    }

    pub fn path(&self) -> &Path {
        &self.path
    }

    pub async fn load(&self) -> Result<ProviderAuthProfilesData> {
        let _lock = self.acquire_lock().await?;
        self.load_locked().await
    }

    pub async fn upsert_profile(
        &self,
        mut profile: ProviderAuthProfile,
        set_active: bool,
    ) -> Result<()> {
        let _lock = self.acquire_lock().await?;
        let mut data = self.load_locked().await?;

        profile.updated_at = Utc::now();
        if let Some(existing) = data.profiles.get(&profile.id) {
            profile.created_at = existing.created_at;
        }

        if set_active {
            data.active_profiles
                .insert(profile.provider.clone(), profile.id.clone());
        }
        data.profiles.insert(profile.id.clone(), profile);
        data.updated_at = Utc::now();
        self.save_locked(&data).await
    }

    pub async fn remove_profile(&self, profile_id: &str) -> Result<bool> {
        let _lock = self.acquire_lock().await?;
        let mut data = self.load_locked().await?;
        let removed = data.profiles.remove(profile_id).is_some();
        if removed {
            data.active_profiles.retain(|_, id| id != profile_id);
            data.updated_at = Utc::now();
            self.save_locked(&data).await?;
        }
        Ok(removed)
    }

    pub async fn set_active_profile(&self, provider: &str, profile_id: &str) -> Result<()> {
        let _lock = self.acquire_lock().await?;
        let mut data = self.load_locked().await?;
        if !data.profiles.contains_key(profile_id) {
            anyhow::bail!("Auth profile not found: {profile_id}");
        }
        data.active_profiles
            .insert(provider.to_string(), profile_id.to_string());
        data.updated_at = Utc::now();
        self.save_locked(&data).await
    }

    pub async fn clear_active_profile(&self, provider: &str) -> Result<()> {
        let _lock = self.acquire_lock().await?;
        let mut data = self.load_locked().await?;
        data.active_profiles.remove(provider);
        data.updated_at = Utc::now();
        self.save_locked(&data).await
    }

    pub async fn update_profile<F>(
        &self,
        profile_id: &str,
        mut updater: F,
    ) -> Result<ProviderAuthProfile>
    where
        F: FnMut(&mut ProviderAuthProfile) -> Result<()>,
    {
        let _lock = self.acquire_lock().await?;
        let mut data = self.load_locked().await?;
        let profile = data
            .profiles
            .get_mut(profile_id)
            .ok_or_else(|| anyhow::anyhow!("Auth profile not found: {profile_id}"))?;
        updater(profile)?;
        profile.updated_at = Utc::now();
        let updated = profile.clone();
        data.updated_at = Utc::now();
        self.save_locked(&data).await?;
        Ok(updated)
    }

    async fn load_locked(&self) -> Result<ProviderAuthProfilesData> {
        if !self.path.exists() {
            return Ok(ProviderAuthProfilesData::default());
        }
        let raw = fs::read_to_string(&self.path)
            .await
            .with_context(|| format!("Failed to read auth store {}", self.path.display()))?;
        let data = serde_json::from_str(&raw)
            .with_context(|| format!("Failed to parse auth store {}", self.path.display()))?;
        Ok(data)
    }

    async fn save_locked(&self, data: &ProviderAuthProfilesData) -> Result<()> {
        if let Some(parent) = self.path.parent() {
            fs::create_dir_all(parent).await.with_context(|| {
                format!("Failed to create auth store directory {}", parent.display())
            })?;
        }
        let temp_path = self.path.with_extension("json.tmp");
        let payload = serde_json::to_vec_pretty(data)?;
        fs::write(&temp_path, payload)
            .await
            .with_context(|| format!("Failed to write auth temp file {}", temp_path.display()))?;
        fs::rename(&temp_path, &self.path).await.with_context(|| {
            format!(
                "Failed to move auth temp file {} into {}",
                temp_path.display(),
                self.path.display()
            )
        })?;
        Ok(())
    }

    async fn acquire_lock(&self) -> Result<LockGuard> {
        if let Some(parent) = self.lock_path.parent() {
            fs::create_dir_all(parent).await.with_context(|| {
                format!("Failed to create auth lock directory {}", parent.display())
            })?;
        }

        let deadline = Instant::now() + Duration::from_millis(LOCK_TIMEOUT_MS);
        loop {
            match fs::OpenOptions::new()
                .write(true)
                .create_new(true)
                .open(&self.lock_path)
                .await
            {
                Ok(_) => return Ok(LockGuard(self.lock_path.clone())),
                Err(err) if err.kind() == std::io::ErrorKind::AlreadyExists => {
                    if Instant::now() >= deadline {
                        anyhow::bail!("Timed out acquiring auth store lock");
                    }
                    sleep(Duration::from_millis(LOCK_RETRY_MS)).await;
                }
                Err(err) => {
                    return Err(err).with_context(|| {
                        format!("Failed to open auth lock {}", self.lock_path.display())
                    })
                }
            }
        }
    }
}

/// Owns the on-disk auth-store lock file; dropping the guard deletes that file.
struct LockGuard(PathBuf);

impl Drop for LockGuard {
    fn drop(&mut self) {
        let LockGuard(lock_path) = self;
        // Best effort: `drop` cannot report failures, so a removal error is discarded.
        std::fs::remove_file(lock_path).ok();
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::auth::profiles::{
        profile_id, ProviderAuthKind, ProviderAuthProfile, ProviderTokenSet,
    };
    use tempfile::tempdir;

    /// Builds the OAuth profile `openai-codex:default` shared by these tests.
    fn oauth_profile() -> ProviderAuthProfile {
        ProviderAuthProfile::new_oauth(
            "openai-codex",
            "default",
            ProviderTokenSet {
                access_token: "access".into(),
                refresh_token: Some("refresh".into()),
                id_token: None,
                expires_at: None,
                token_type: Some("Bearer".into()),
                scope: Some("openid".into()),
            },
        )
    }

    #[tokio::test]
    async fn missing_file_loads_empty_store() {
        let dir = tempdir().unwrap();
        let store = ProviderAuthStore::new(dir.path());
        let loaded = store.load().await.unwrap();
        assert!(loaded.profiles.is_empty());
        assert!(loaded.active_profiles.is_empty());
    }

    #[tokio::test]
    async fn upsert_load_remove_profile_roundtrip() {
        let dir = tempdir().unwrap();
        let store = ProviderAuthStore::new(dir.path());
        store.upsert_profile(oauth_profile(), true).await.unwrap();

        let loaded = store.load().await.unwrap();
        assert_eq!(
            loaded.active_profiles.get("openai-codex").unwrap(),
            "openai-codex:default"
        );
        assert!(loaded.profiles.contains_key("openai-codex:default"));

        let removed = store.remove_profile("openai-codex:default").await.unwrap();
        assert!(removed);
        let loaded = store.load().await.unwrap();
        assert!(!loaded.profiles.contains_key("openai-codex:default"));
        // Removing a profile also drops the active mapping that pointed at it.
        assert!(!loaded.active_profiles.contains_key("openai-codex"));
    }

    #[tokio::test]
    async fn upsert_preserves_created_at() {
        let dir = tempdir().unwrap();
        let store = ProviderAuthStore::new(dir.path());
        store.upsert_profile(oauth_profile(), false).await.unwrap();
        let first = store
            .load()
            .await
            .unwrap()
            .profiles
            .get("openai-codex:default")
            .unwrap()
            .created_at;

        store.upsert_profile(oauth_profile(), false).await.unwrap();
        let second = store
            .load()
            .await
            .unwrap()
            .profiles
            .get("openai-codex:default")
            .unwrap()
            .created_at;
        assert_eq!(first, second);
    }

    #[tokio::test]
    async fn remove_missing_profile_returns_false() {
        let dir = tempdir().unwrap();
        let store = ProviderAuthStore::new(dir.path());
        assert!(!store.remove_profile("openai-codex:default").await.unwrap());
        // Nothing was removed, so the store file must not have been written.
        assert!(!store.path().exists());
    }

    #[tokio::test]
    async fn set_and_clear_active_profile() {
        let dir = tempdir().unwrap();
        let store = ProviderAuthStore::new(dir.path());
        store.upsert_profile(oauth_profile(), false).await.unwrap();
        store
            .set_active_profile("openai-codex", &profile_id("openai-codex", "default"))
            .await
            .unwrap();
        assert_eq!(
            store
                .load()
                .await
                .unwrap()
                .active_profiles
                .get("openai-codex")
                .cloned(),
            Some("openai-codex:default".into())
        );
        store.clear_active_profile("openai-codex").await.unwrap();
        assert!(!store
            .load()
            .await
            .unwrap()
            .active_profiles
            .contains_key("openai-codex"));
    }

    #[tokio::test]
    async fn set_active_profile_rejects_unknown_id() {
        let dir = tempdir().unwrap();
        let store = ProviderAuthStore::new(dir.path());
        let err = store
            .set_active_profile("openai-codex", "openai-codex:missing")
            .await
            .unwrap_err()
            .to_string();
        assert!(err.contains("Auth profile not found"));
    }

    #[tokio::test]
    async fn update_profile_changes_token() {
        let dir = tempdir().unwrap();
        let store = ProviderAuthStore::new(dir.path());
        store.upsert_profile(oauth_profile(), true).await.unwrap();
        let updated = store
            .update_profile("openai-codex:default", |profile| {
                profile.kind = ProviderAuthKind::Token;
                profile.token_set = None;
                profile.token = Some("plain".into());
                Ok(())
            })
            .await
            .unwrap();
        assert_eq!(updated.token.as_deref(), Some("plain"));

        // The change must be persisted, not only reflected in the returned copy.
        let loaded = store.load().await.unwrap();
        assert_eq!(
            loaded
                .profiles
                .get("openai-codex:default")
                .unwrap()
                .token
                .as_deref(),
            Some("plain")
        );
    }

    #[tokio::test]
    async fn update_profile_rejects_unknown_id() {
        let dir = tempdir().unwrap();
        let store = ProviderAuthStore::new(dir.path());
        let err = store
            .update_profile("openai-codex:missing", |_| Ok(()))
            .await
            .unwrap_err()
            .to_string();
        assert!(err.contains("Auth profile not found"));
    }

    #[tokio::test]
    async fn damaged_file_returns_error() {
        let dir = tempdir().unwrap();
        let store = ProviderAuthStore::new(dir.path());
        std::fs::create_dir_all(store.path().parent().unwrap()).unwrap();
        std::fs::write(store.path(), "{not-json").unwrap();
        let err = store.load().await.unwrap_err().to_string();
        assert!(err.contains("Failed to parse auth store"));
    }
}