use anyhow::Context;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::io::Read;
use std::path::PathBuf;
use std::time::Duration;
/// A credential stored for a single provider in the auth file.
///
/// Serialized as an internally tagged JSON object whose `"type"` field is either
/// `"api_key"` or `"oauth"`.
#[derive(Clone, Debug, Deserialize, Serialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum AuthCredential {
    /// A static API key (`"type": "api_key"`).
    ApiKey { key: String },
    /// OAuth tokens (`"type": "oauth"`). Field names are written in camelCase, so
    /// `enterprise_url` is stored as `enterpriseUrl`.
    #[serde(rename_all = "camelCase")]
    Oauth {
        /// The access token.
        access: String,
        /// Refresh token handed to the provider by `refresh_oauth_token`; may be absent.
        refresh: Option<String>,
        /// Expiry in milliseconds since the Unix epoch; `None` is treated as "never expires".
        expires: Option<i64>,
        enterprise_url: Option<String>,
    },
}
/// Every stored credential, keyed by provider name.
///
/// Deserialized straight from the auth file's top-level JSON object.
#[derive(Clone, Debug, Default, Deserialize)]
pub struct AuthStorage(HashMap<String, AuthCredential>);
impl AuthStorage {
    /// Loads the credential store from the default location ([`AuthStorage::path`]).
    ///
    /// # Errors
    ///
    /// Fails if the home directory cannot be determined, or if the file exists but cannot be
    /// read or parsed.
    pub fn load() -> anyhow::Result<Self> {
        Self::load_from(Self::path()?)
    }

    /// Loads the credential store from `path`.
    ///
    /// A missing file, or one containing only whitespace, yields an empty store. Blank files
    /// occur in practice: the locking helper in this module creates the auth file (empty)
    /// before the first credentials are written, and this reader does not take that lock.
    ///
    /// # Errors
    ///
    /// Fails if the file cannot be read, or its contents are not a JSON object mapping provider
    /// names to credentials.
    pub fn load_from(path: std::path::PathBuf) -> anyhow::Result<Self> {
        match read_json_file(&path)? {
            Some(content) if !content.trim().is_empty() => serde_json::from_str(&content)
                .with_context(|| format!("Failed to parse {}", path.display())),
            _ => Ok(Self::default()),
        }
    }

    /// Returns the default auth file location: `<home>/.rab/agent/auth.json`.
    ///
    /// # Errors
    ///
    /// Fails if the user's home directory cannot be determined.
    pub fn path() -> anyhow::Result<PathBuf> {
        let dirs = directories::BaseDirs::new().context("Could not determine home directory")?;
        Ok(dirs.home_dir().join(".rab").join("agent").join("auth.json"))
    }

    /// Returns the API key stored for `provider`.
    ///
    /// Returns `None` if there is no entry, or if the entry is an OAuth credential.
    pub fn api_key(&self, provider: &str) -> Option<String> {
        match self.0.get(provider)? {
            AuthCredential::ApiKey { key } => Some(key.clone()),
            AuthCredential::Oauth { .. } => None,
        }
    }

    /// Returns the OAuth access token stored for `provider` if it has not expired.
    ///
    /// Returns `None` if there is no entry, the entry is an API key, or the token's `expires`
    /// timestamp has passed. A token without `expires` is always returned.
    pub fn oauth_token(&self, provider: &str) -> Option<String> {
        match self.0.get(provider)? {
            AuthCredential::Oauth {
                access, expires, ..
            } => (!is_expired(*expires)).then(|| access.clone()),
            AuthCredential::ApiKey { .. } => None,
        }
    }

    /// Returns a copy of the full OAuth credential stored for `provider`, expired or not.
    ///
    /// Returns `None` if there is no entry or the entry is an API key.
    pub fn oauth_credential(&self, provider: &str) -> Option<AuthCredential> {
        // Filter before cloning so non-OAuth entries are never copied.
        self.0
            .get(provider)
            .filter(|cred| matches!(cred, AuthCredential::Oauth { .. }))
            .cloned()
    }

    /// Returns every stored credential, keyed by provider name.
    pub fn all_credentials(&self) -> &HashMap<String, AuthCredential> {
        &self.0
    }
}
fn with_exclusive_lock<T>(path: &PathBuf, f: impl FnOnce() -> T) -> T {
use fs2::FileExt;
if let Some(parent) = path.parent() {
let _ = std::fs::create_dir_all(parent);
}
let file = std::fs::OpenOptions::new()
.create(true)
.truncate(false)
.write(true)
.read(true)
.open(path)
.expect("Failed to open auth file");
let mut attempts = 0;
loop {
match file.try_lock_exclusive() {
Ok(()) => break,
Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => {
attempts += 1;
if attempts >= 200 {
break; }
if attempts > 5
&& let Ok(metadata) = path.metadata()
&& let Ok(modified) = metadata.modified()
&& let Ok(elapsed) = modified.elapsed()
&& elapsed > Duration::from_secs(10)
{
let _ = file.unlock();
continue;
}
std::thread::sleep(Duration::from_millis(50));
}
Err(e) => panic!("Failed to lock auth file: {}", e),
}
}
let result = f();
let _ = file.unlock();
result
}
fn read_json_file(path: &PathBuf) -> anyhow::Result<Option<String>> {
if !path.exists() {
return Ok(None);
}
let mut s = String::new();
let mut file =
std::fs::File::open(path).with_context(|| format!("Failed to open {}", path.display()))?;
file.read_to_string(&mut s)
.with_context(|| format!("Failed to read {}", path.display()))?;
Ok(Some(s))
}
fn modify_auth_file(
path: &PathBuf,
f: impl FnOnce(HashMap<String, AuthCredential>) -> (HashMap<String, AuthCredential>, bool),
) -> anyhow::Result<()> {
with_exclusive_lock(path, || {
let auth: HashMap<String, AuthCredential> = match read_json_file(path) {
Ok(Some(c)) => serde_json::from_str(&c).unwrap_or_default(),
_ => HashMap::new(),
};
let (result, changed) = f(auth);
if changed {
if let Some(parent) = path.parent() {
let _ = std::fs::create_dir_all(parent);
}
if let Ok(content) = serde_json::to_string_pretty(&result) {
let _ = std::fs::write(path, &content);
}
}
});
Ok(())
}
/// Reports whether an expiry timestamp, in milliseconds since the Unix epoch, has been reached.
///
/// `None` means no expiry is known, and such a credential is never considered expired. If the
/// system clock reads earlier than the epoch, "now" is taken as 0.
fn is_expired(expires: Option<i64>) -> bool {
    let Some(deadline_ms) = expires else {
        return false;
    };
    let now_ms = std::time::UNIX_EPOCH
        .elapsed()
        .unwrap_or_default()
        .as_millis() as i64;
    now_ms >= deadline_ms
}
/// Stores `api_key` as the API-key credential for `provider`, replacing any existing entry.
///
/// # Errors
///
/// Fails if the auth file path cannot be determined, or if updating the file reports an error.
pub fn login(provider: &str, api_key: &str) -> anyhow::Result<()> {
    let auth_path = AuthStorage::path()?;
    let provider_name = provider.to_owned();
    let credential = AuthCredential::ApiKey {
        key: api_key.to_owned(),
    };
    modify_auth_file(&auth_path, move |mut entries| {
        entries.insert(provider_name, credential);
        (entries, true)
    })
}
/// Stores a copy of `cred` as the credential for `provider`, replacing any existing entry.
///
/// The credential is stored as given; its variant is not checked here.
///
/// # Errors
///
/// Fails if the auth file path cannot be determined, or if updating the file reports an error.
pub fn login_oauth(provider: &str, cred: &AuthCredential) -> anyhow::Result<()> {
    let auth_path = AuthStorage::path()?;
    let (provider_name, stored) = (provider.to_owned(), cred.clone());
    modify_auth_file(&auth_path, move |mut entries| {
        entries.insert(provider_name, stored);
        (entries, true)
    })
}
pub fn logout(provider: Option<&str>) -> anyhow::Result<bool> {
let path = AuthStorage::path()?;
if !path.exists() {
return Ok(false);
}
let result = with_exclusive_lock(&path, || -> bool {
let auth: HashMap<String, AuthCredential> = match read_json_file(&path) {
Ok(Some(c)) => serde_json::from_str(&c).unwrap_or_default(),
_ => return false,
};
let (new_auth, removed) = match provider {
Some(prov) => {
let mut a = auth;
let removed = a.remove(prov).is_some();
(a, removed)
}
None => {
let removed = !auth.is_empty();
(HashMap::new(), removed)
}
};
if removed {
if let Some(parent) = path.parent() {
let _ = std::fs::create_dir_all(parent);
}
if let Ok(content) = serde_json::to_string_pretty(&new_auth) {
let _ = std::fs::write(&path, &content);
}
}
removed
});
Ok(result)
}
pub fn list_logged_in() -> anyhow::Result<Vec<String>> {
let path = AuthStorage::path()?;
let content = read_json_file(&path)?;
match content {
Some(c) => {
let auth: HashMap<String, AuthCredential> = serde_json::from_str(&c)
.with_context(|| format!("Failed to parse {}", path.display()))?;
Ok(auth.keys().cloned().collect())
}
None => Ok(Vec::new()),
}
}
/// Reads the credential stored for `provider` straight from the auth file.
///
/// Returns `Ok(None)` when the file is missing or has no entry for `provider`.
///
/// # Errors
///
/// Fails if the auth file path cannot be determined, or if the file exists but cannot be read
/// or parsed.
pub fn read_credential(provider: &str) -> anyhow::Result<Option<AuthCredential>> {
    let mut storage = AuthStorage::load()?;
    // The store is discarded afterwards, so take the entry out instead of cloning it.
    Ok(storage.0.remove(provider))
}
/// Replaces the credential for `provider` with whatever `f` returns, under the auth file lock.
///
/// `f` receives the current credential (`None` if absent). Returning `Some` stores the new
/// credential; returning `None` deletes the entry. The file is always rewritten.
///
/// # Errors
///
/// Fails if the auth file path cannot be determined, or if updating the file reports an error.
pub fn modify_credential(
    provider: &str,
    f: impl FnOnce(Option<AuthCredential>) -> Option<AuthCredential>,
) -> anyhow::Result<()> {
    let auth_path = AuthStorage::path()?;
    let key = provider.to_owned();
    modify_auth_file(&auth_path, move |mut entries| {
        // Taking the entry out hands `f` ownership without a clone. Re-inserting only on
        // `Some` yields the same final map as "insert on Some, remove on None".
        let previous = entries.remove(&key);
        if let Some(next) = f(previous) {
            entries.insert(key, next);
        }
        (entries, true)
    })
}
/// Returns a usable OAuth access token for `provider`, refreshing it first when needed.
///
/// The stored access token is returned unchanged unless it has expired or expires within the
/// next five minutes. Tokens without an `expires` timestamp are always returned unchanged.
/// Otherwise the provider's OAuth implementation (`crate::provider::oauth::get`) is asked for
/// new credentials. These are written back to the auth file before the new access token is
/// returned.
///
/// Returns `None` in any of these cases:
/// - nothing is stored for `provider`, or the stored credential is not OAuth;
/// - no OAuth implementation is registered for `provider`;
/// - the refresh fails;
/// - the refreshed credentials cannot be saved.
///
/// Error details are discarded.
///
/// NOTE(review): the auth file is read and written with blocking I/O, and the file lock may
/// `std::thread::sleep`, all from this async fn. Consider `spawn_blocking` if that shows up
/// as executor stalls.
pub async fn refresh_oauth_token(provider: &str) -> Option<String> {
    // Tokens this close to expiry (in milliseconds; 5 minutes) are refreshed proactively.
    const EXPIRY_BUFFER_MS: i64 = 300_000;

    let credential = read_credential(provider).ok()??;
    let AuthCredential::Oauth {
        access,
        refresh,
        expires,
        enterprise_url,
    } = credential
    else {
        return None;
    };

    // Moving the deadline earlier by the buffer turns `is_expired` into "expires soon".
    // `None` (no expiry) stays `None`, so such tokens are always reused.
    if !is_expired(expires.map(|exp| exp.saturating_sub(EXPIRY_BUFFER_MS))) {
        return Some(access);
    }

    let oauth_provider = crate::provider::oauth::get(provider)?;
    let current = crate::provider::oauth::OAuthCredentials {
        access,
        refresh: refresh.unwrap_or_default(),
        expires: expires.unwrap_or(0),
        enterprise_url,
        extra: std::collections::HashMap::new(),
    };
    let refreshed = oauth_provider.refresh_token(&current).await.ok()?;

    let new_access = refreshed.access.clone();
    modify_credential(provider, |_| {
        Some(AuthCredential::Oauth {
            access: refreshed.access,
            refresh: Some(refreshed.refresh),
            expires: Some(refreshed.expires),
            enterprise_url: refreshed.enterprise_url,
        })
    })
    .ok()?;
    Some(new_access)
}