use crate::models::auth::ProviderAuth;
use crate::oauth::error::{OAuthError, OAuthResult};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::path::{Path, PathBuf};
/// File name of the credential store inside the config directory.
const AUTH_FILE_NAME: &str = "auth.toml";
/// Profile whose credentials act as the fallback for every other profile.
const ALL_PROFILE: &str = "all";
/// Serialized contents of `auth.toml`.
///
/// The map is flattened into the top level of the document, so each profile
/// name appears directly as a top-level key (there is no wrapping `profiles`
/// key). Inner maps go from provider name to its stored credentials.
#[derive(Debug, Default, Clone, Serialize, Deserialize)]
pub struct AuthFile {
    /// profile name -> (provider name -> credentials)
    #[serde(flatten)]
    pub profiles: HashMap<String, HashMap<String, ProviderAuth>>,
}
/// In-memory view of the credential store plus the path it is persisted to.
///
/// Mutating methods (`set`, `remove`, `update_oauth_tokens`) write the whole
/// store back to `auth_path` after changing it.
#[derive(Debug, Clone)]
pub struct AuthManager {
    /// Location of `auth.toml` on disk.
    auth_path: PathBuf,
    /// Credentials currently loaded in memory.
    auth_file: AuthFile,
}
impl AuthManager {
pub fn new(config_dir: &Path) -> OAuthResult<Self> {
let auth_path = config_dir.join(AUTH_FILE_NAME);
let auth_file = if auth_path.is_file() {
let content = std::fs::read_to_string(&auth_path)?;
toml::from_str(&content)?
} else {
AuthFile::default()
};
Ok(Self {
auth_path,
auth_file,
})
}
pub fn from_default_dir() -> OAuthResult<Self> {
let config_dir = get_default_config_dir()?;
Self::new(&config_dir)
}
pub fn get(&self, profile: &str, provider: &str) -> Option<&ProviderAuth> {
if let Some(providers) = self.auth_file.profiles.get(profile)
&& let Some(auth) = providers.get(provider)
{
return Some(auth);
}
if profile != ALL_PROFILE
&& let Some(providers) = self.auth_file.profiles.get(ALL_PROFILE)
&& let Some(auth) = providers.get(provider)
{
return Some(auth);
}
None
}
pub fn set(&mut self, profile: &str, provider: &str, auth: ProviderAuth) -> OAuthResult<()> {
self.auth_file
.profiles
.entry(profile.to_string())
.or_default()
.insert(provider.to_string(), auth);
self.save()
}
pub fn remove(&mut self, profile: &str, provider: &str) -> OAuthResult<bool> {
let removed = if let Some(providers) = self.auth_file.profiles.get_mut(profile) {
let removed = providers.remove(provider).is_some();
if providers.is_empty() {
self.auth_file.profiles.remove(profile);
}
removed
} else {
false
};
if removed {
self.save()?;
}
Ok(removed)
}
pub fn list(&self) -> &HashMap<String, HashMap<String, ProviderAuth>> {
&self.auth_file.profiles
}
pub fn list_for_profile(&self, profile: &str) -> HashMap<String, &ProviderAuth> {
let mut result = HashMap::new();
if let Some(all_providers) = self.auth_file.profiles.get(ALL_PROFILE) {
for (provider, auth) in all_providers {
result.insert(provider.clone(), auth);
}
}
if profile != ALL_PROFILE
&& let Some(profile_providers) = self.auth_file.profiles.get(profile)
{
for (provider, auth) in profile_providers {
result.insert(provider.clone(), auth);
}
}
result
}
pub fn has_credentials(&self) -> bool {
self.auth_file
.profiles
.values()
.any(|providers| !providers.is_empty())
}
pub fn auth_path(&self) -> &Path {
&self.auth_path
}
pub fn update_oauth_tokens(
&mut self,
profile: &str,
provider: &str,
access: &str,
refresh: &str,
expires: i64,
) -> OAuthResult<()> {
let auth = ProviderAuth::oauth(access, refresh, expires);
self.set(profile, provider, auth)
}
fn save(&self) -> OAuthResult<()> {
if let Some(parent) = self.auth_path.parent() {
std::fs::create_dir_all(parent)?;
}
let content = toml::to_string_pretty(&self.auth_file)?;
let temp_path = self.auth_path.with_extension("toml.tmp");
std::fs::write(&temp_path, &content)?;
#[cfg(unix)]
{
use std::os::unix::fs::PermissionsExt;
let permissions = std::fs::Permissions::from_mode(0o600);
std::fs::set_permissions(&temp_path, permissions)?;
}
std::fs::rename(&temp_path, &self.auth_path)?;
Ok(())
}
}
/// Returns the default configuration directory, `~/.stakpak`.
///
/// # Errors
/// Returns `OAuthError::IoError` wrapping a `NotFound` I/O error when the
/// user's home directory cannot be determined.
pub fn get_default_config_dir() -> OAuthResult<PathBuf> {
    let Some(home_dir) = dirs::home_dir() else {
        let missing_home = std::io::Error::new(
            std::io::ErrorKind::NotFound,
            "Could not determine home directory",
        );
        return Err(OAuthError::IoError(missing_home));
    };
    Ok(home_dir.join(".stakpak"))
}
/// Returns where the credential store lives inside `config_dir`.
pub fn get_auth_file_path(config_dir: &Path) -> PathBuf {
    let mut auth_path = config_dir.to_path_buf();
    auth_path.push(AUTH_FILE_NAME);
    auth_path
}
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// Builds a manager over a fresh temp directory. The `TempDir` is returned
    /// so it stays alive (and on disk) for the duration of the test.
    fn create_test_auth_manager() -> (AuthManager, TempDir) {
        let temp_dir = TempDir::new().unwrap();
        let manager = AuthManager::new(temp_dir.path()).unwrap();
        (manager, temp_dir)
    }

    /// A directory without `auth.toml` yields an empty store.
    #[test]
    fn test_new_empty() {
        let (manager, _temp) = create_test_auth_manager();
        assert!(!manager.has_credentials());
        assert!(manager.list().is_empty());
    }

    /// A value stored with `set` is returned unchanged by `get`.
    #[test]
    fn test_set_and_get() {
        let (mut manager, _temp) = create_test_auth_manager();
        let auth = ProviderAuth::api_key("sk-test-key");
        manager.set("default", "anthropic", auth.clone()).unwrap();
        let retrieved = manager.get("default", "anthropic");
        assert!(retrieved.is_some());
        assert_eq!(retrieved.unwrap(), &auth);
    }

    /// Credentials under `"all"` are visible from every profile.
    #[test]
    fn test_profile_inheritance() {
        let (mut manager, _temp) = create_test_auth_manager();
        let all_auth = ProviderAuth::api_key("sk-all-key");
        manager.set("all", "anthropic", all_auth.clone()).unwrap();
        assert_eq!(manager.get("default", "anthropic"), Some(&all_auth));
        assert_eq!(manager.get("work", "anthropic"), Some(&all_auth));
        assert_eq!(manager.get("all", "anthropic"), Some(&all_auth));
    }

    /// Profile-specific credentials take precedence over `"all"`.
    #[test]
    fn test_profile_override() {
        let (mut manager, _temp) = create_test_auth_manager();
        let all_auth = ProviderAuth::api_key("sk-all-key");
        manager.set("all", "anthropic", all_auth.clone()).unwrap();
        let work_auth = ProviderAuth::api_key("sk-work-key");
        manager.set("work", "anthropic", work_auth.clone()).unwrap();
        assert_eq!(manager.get("work", "anthropic"), Some(&work_auth));
        assert_eq!(manager.get("default", "anthropic"), Some(&all_auth));
    }

    /// Removing an existing entry reports `true` and makes it unreachable.
    #[test]
    fn test_remove() {
        let (mut manager, _temp) = create_test_auth_manager();
        let auth = ProviderAuth::api_key("sk-test-key");
        manager.set("default", "anthropic", auth).unwrap();
        assert!(manager.get("default", "anthropic").is_some());
        let removed = manager.remove("default", "anthropic").unwrap();
        assert!(removed);
        assert!(manager.get("default", "anthropic").is_none());
    }

    /// Removing a missing entry reports `false` without error.
    #[test]
    fn test_remove_nonexistent() {
        let (mut manager, _temp) = create_test_auth_manager();
        let removed = manager.remove("default", "anthropic").unwrap();
        assert!(!removed);
    }

    /// Removing the last provider of a profile drops the profile itself.
    #[test]
    fn test_remove_drops_empty_profile() {
        let (mut manager, _temp) = create_test_auth_manager();
        manager
            .set("default", "anthropic", ProviderAuth::api_key("sk-test-key"))
            .unwrap();
        assert!(manager.remove("default", "anthropic").unwrap());
        assert!(!manager.list().contains_key("default"));
        assert!(!manager.has_credentials());
    }

    /// `remove` on a non-`"all"` profile never deletes the shared fallback.
    #[test]
    fn test_remove_does_not_touch_all_fallback() {
        let (mut manager, _temp) = create_test_auth_manager();
        let all_auth = ProviderAuth::api_key("sk-all-key");
        manager.set("all", "anthropic", all_auth.clone()).unwrap();
        assert!(!manager.remove("work", "anthropic").unwrap());
        assert_eq!(manager.get("work", "anthropic"), Some(&all_auth));
    }

    /// `list_for_profile` merges `"all"` with the profile, profile winning.
    #[test]
    fn test_list_for_profile() {
        let (mut manager, _temp) = create_test_auth_manager();
        let all_anthropic = ProviderAuth::api_key("sk-all-anthropic");
        let all_openai = ProviderAuth::api_key("sk-all-openai");
        let work_anthropic = ProviderAuth::api_key("sk-work-anthropic");
        manager
            .set("all", "anthropic", all_anthropic.clone())
            .unwrap();
        manager.set("all", "openai", all_openai.clone()).unwrap();
        manager
            .set("work", "anthropic", work_anthropic.clone())
            .unwrap();
        let work_creds = manager.list_for_profile("work");
        assert_eq!(work_creds.len(), 2);
        assert_eq!(work_creds.get("anthropic"), Some(&&work_anthropic));
        assert_eq!(work_creds.get("openai"), Some(&&all_openai));
        let default_creds = manager.list_for_profile("default");
        assert_eq!(default_creds.len(), 2);
        assert_eq!(default_creds.get("anthropic"), Some(&&all_anthropic));
        assert_eq!(default_creds.get("openai"), Some(&&all_openai));
    }

    /// Credentials written by one manager are read back by a new one.
    #[test]
    fn test_persistence() {
        let temp_dir = TempDir::new().unwrap();
        {
            let mut manager = AuthManager::new(temp_dir.path()).unwrap();
            let auth = ProviderAuth::api_key("sk-test-key");
            manager.set("default", "anthropic", auth).unwrap();
        }
        {
            let manager = AuthManager::new(temp_dir.path()).unwrap();
            let retrieved = manager.get("default", "anthropic");
            assert!(retrieved.is_some());
            assert_eq!(retrieved.unwrap().api_key_value(), Some("sk-test-key"));
        }
    }

    /// A successful save leaves only `auth.toml`, with no temp file behind.
    #[test]
    fn test_save_leaves_no_temp_file() {
        let (mut manager, _temp) = create_test_auth_manager();
        manager
            .set("default", "anthropic", ProviderAuth::api_key("sk-test-key"))
            .unwrap();
        assert!(manager.auth_path().is_file());
        assert!(!manager.auth_path().with_extension("toml.tmp").exists());
    }

    /// OAuth credentials round-trip their access and refresh tokens.
    #[test]
    fn test_oauth_tokens() {
        let (mut manager, _temp) = create_test_auth_manager();
        let expires = chrono::Utc::now().timestamp_millis() + 3600000;
        let auth = ProviderAuth::oauth("access-token", "refresh-token", expires);
        manager.set("default", "anthropic", auth).unwrap();
        let retrieved = manager.get("default", "anthropic").unwrap();
        assert!(retrieved.is_oauth());
        assert_eq!(retrieved.access_token(), Some("access-token"));
        assert_eq!(retrieved.refresh_token(), Some("refresh-token"));
    }

    /// `update_oauth_tokens` replaces previously stored tokens.
    #[test]
    fn test_update_oauth_tokens() {
        let (mut manager, _temp) = create_test_auth_manager();
        manager
            .set(
                "default",
                "anthropic",
                ProviderAuth::oauth("old-access", "old-refresh", 0),
            )
            .unwrap();
        let new_expires = chrono::Utc::now().timestamp_millis() + 3600000;
        manager
            .update_oauth_tokens(
                "default",
                "anthropic",
                "new-access",
                "new-refresh",
                new_expires,
            )
            .unwrap();
        let retrieved = manager.get("default", "anthropic").unwrap();
        assert_eq!(retrieved.access_token(), Some("new-access"));
        assert_eq!(retrieved.refresh_token(), Some("new-refresh"));
    }

    /// The saved credential file is readable and writable by the owner only.
    #[cfg(unix)]
    #[test]
    fn test_file_permissions() {
        use std::os::unix::fs::PermissionsExt;
        let temp_dir = TempDir::new().unwrap();
        let mut manager = AuthManager::new(temp_dir.path()).unwrap();
        let auth = ProviderAuth::api_key("sk-test-key");
        manager.set("default", "anthropic", auth).unwrap();
        let metadata = std::fs::metadata(manager.auth_path()).unwrap();
        let mode = metadata.permissions().mode();
        assert_eq!(mode & 0o777, 0o600);
    }

    /// A stale, world-readable temp file from an earlier crash must not
    /// loosen the permissions of the final `auth.toml`.
    #[cfg(unix)]
    #[test]
    fn test_stale_temp_file_permissions_are_tightened() {
        use std::os::unix::fs::PermissionsExt;
        let temp_dir = TempDir::new().unwrap();
        let stale = temp_dir.path().join("auth.toml.tmp");
        std::fs::write(&stale, "stale").unwrap();
        std::fs::set_permissions(&stale, std::fs::Permissions::from_mode(0o644)).unwrap();
        let mut manager = AuthManager::new(temp_dir.path()).unwrap();
        manager
            .set("default", "anthropic", ProviderAuth::api_key("sk-test-key"))
            .unwrap();
        let mode = std::fs::metadata(manager.auth_path())
            .unwrap()
            .permissions()
            .mode();
        assert_eq!(mode & 0o777, 0o600);
        assert!(!stale.exists());
    }
}