use std::collections::BTreeMap;
use std::fs;
use std::path::{Path, PathBuf};
use crate::error::{Result, ToriiError};
/// Environment variable that, when set to a non-empty value, takes precedence over the
/// cloud API key stored on disk (see `load`).
const CLOUD_ENV_VAR: &str = "TORII_API_KEY";

/// File name of the credentials store, used both for the global store and for the
/// repository-local `<repo>/.torii/` store.
const FILE_NAME: &str = "auth.toml";
/// Credentials for the torii cloud API (the `[cloud]` section of `auth.toml`).
///
/// `Debug` is implemented by hand so the secret `key` never ends up in logs, panic
/// messages or `{:?}` output; only the endpoint is shown.
#[derive(Clone, Default)]
pub struct ApiKey {
    /// Secret API key.
    pub key: String,
    /// Base URL of the API this key is used against.
    pub endpoint: String,
}

impl std::fmt::Debug for ApiKey {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("ApiKey")
            .field("key", &"<redacted>")
            .field("endpoint", &self.endpoint)
            .finish()
    }
}
/// In-memory form of an `auth.toml` credentials file (global or repository-local).
///
/// `Debug` is implemented by hand: provider names are shown, but the cloud key, access
/// tokens and refresh tokens are replaced with `<redacted>` so dumping the store cannot
/// leak credentials.
#[derive(Clone, Default)]
pub struct AuthStore {
    /// `[cloud]` section: torii cloud API credentials, if any.
    pub cloud: Option<ApiKey>,
    /// `[tokens]` section: provider name -> access token.
    pub tokens: BTreeMap<String, String>,
    /// `[token_expires]` section: provider name -> expiry timestamp
    /// (`set_token_with_refresh` writes RFC 3339 UTC).
    pub expirations: BTreeMap<String, String>,
    /// `[token_refresh]` section: provider name -> OAuth refresh token.
    pub refresh_tokens: BTreeMap<String, String>,
}

impl std::fmt::Debug for AuthStore {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Keep the provider names (useful when debugging) but hide every secret value.
        fn redacted(map: &BTreeMap<String, String>) -> BTreeMap<&str, &str> {
            map.keys().map(|k| (k.as_str(), "<redacted>")).collect()
        }
        f.debug_struct("AuthStore")
            .field("cloud", &self.cloud.as_ref().map(|_| "<redacted>"))
            .field("tokens", &redacted(&self.tokens))
            .field("expirations", &self.expirations)
            .field("refresh_tokens", &redacted(&self.refresh_tokens))
            .finish()
    }
}
/// Lower-case names of every provider accepted by `normalise_provider`.
pub const PROVIDERS: &[&str] = &[
    "github", "gitlab", "gitea", "forgejo", "codeberg", "bitbucket", "sourcehut", "azure",
    "cargo",
];
/// Base URL of the torii cloud API.
///
/// Uses `TORII_API_ENDPOINT` when it is set to a non-empty value, and
/// `https://api.gitorii.com` otherwise. An empty variable is treated like an unset one,
/// the same way this module treats an empty `TORII_API_KEY` or provider token variable.
/// Previously an empty variable produced an empty, unusable endpoint.
pub fn default_endpoint() -> String {
    std::env::var("TORII_API_ENDPOINT")
        .ok()
        .filter(|endpoint| !endpoint.is_empty())
        .unwrap_or_else(|| "https://api.gitorii.com".to_string())
}
/// Location of the global credentials file: `<platform config dir>/torii/auth.toml`.
///
/// Returns `None` when `dirs::config_dir()` cannot determine a config directory.
fn global_path() -> Option<PathBuf> {
    let config_dir = dirs::config_dir()?;
    Some(config_dir.join("torii").join(FILE_NAME))
}
/// Location of the repository-local credentials file: `<repo_path>/.torii/auth.toml`.
fn local_path<P: AsRef<Path>>(repo_path: P) -> PathBuf {
    let mut file = repo_path.as_ref().to_path_buf();
    file.push(".torii");
    file.push(FILE_NAME);
    file
}
/// Returns the cloud API credentials, if any.
///
/// A non-empty `TORII_API_KEY` environment variable wins; it is paired with
/// `default_endpoint()`. Otherwise the `[cloud]` entry of the global store is returned.
pub fn load() -> Option<ApiKey> {
    match std::env::var(CLOUD_ENV_VAR) {
        Ok(key) if !key.is_empty() => Some(ApiKey {
            key,
            endpoint: default_endpoint(),
        }),
        _ => load_global().cloud,
    }
}
/// Loads the global credentials store.
///
/// - No resolvable config directory: returns an empty store.
/// - No `auth.toml` yet: falls back to importing tokens from the legacy `config.toml`
///   (see `migrate_from_config_toml`), or an empty store when there is nothing to import.
/// - Unreadable file: returns an empty store (errors are not reported).
pub fn load_global() -> AuthStore {
    let path = match global_path() {
        Some(p) => p,
        None => return AuthStore::default(),
    };
    if path.exists() {
        fs::read_to_string(&path)
            .map(|text| parse(&text))
            .unwrap_or_default()
    } else {
        migrate_from_config_toml().unwrap_or_default()
    }
}
/// Loads only the repository-local store at `<repo_path>/.torii/auth.toml`.
///
/// Nothing from the global store is merged in. A missing or unreadable file yields an
/// empty store.
pub fn load_local_raw<P: AsRef<Path>>(repo_path: P) -> AuthStore {
    let path = local_path(repo_path);
    if !path.exists() {
        return AuthStore::default();
    }
    fs::read_to_string(&path).map_or_else(|_| AuthStore::default(), |text| parse(&text))
}
fn save_to(path: &Path, store: &AuthStore) -> Result<()> {
if let Some(parent) = path.parent() {
fs::create_dir_all(parent)
.map_err(|e| ToriiError::Fs(format!("create dir: {}", e)))?;
}
let mut out = String::new();
out.push_str("# torii credentials — managed by 'torii auth …'. Do not share.\n\n");
if let Some(cloud) = &store.cloud {
out.push_str("[cloud]\n");
out.push_str(&format!("key = \"{}\"\n", cloud.key));
out.push_str(&format!("endpoint = \"{}\"\n\n", cloud.endpoint));
}
if !store.tokens.is_empty() {
out.push_str("[tokens]\n");
for (k, v) in &store.tokens {
out.push_str(&format!("{} = \"{}\"\n", k, v));
}
out.push('\n');
}
if !store.expirations.is_empty() {
out.push_str("[token_expires]\n");
for (k, v) in &store.expirations {
out.push_str(&format!("{} = \"{}\"\n", k, v));
}
out.push('\n');
}
if !store.refresh_tokens.is_empty() {
out.push_str("[token_refresh]\n");
for (k, v) in &store.refresh_tokens {
out.push_str(&format!("{} = \"{}\"\n", k, v));
}
}
fs::write(path, out)
.map_err(|e| ToriiError::Fs(format!("write {}: {}", path.display(), e)))?;
restrict_permissions(path);
Ok(())
}
/// Writes `store` to the global credentials file.
///
/// # Errors
/// `ToriiError::InvalidConfig` if the config directory cannot be resolved, otherwise any
/// error from `save_to`.
pub fn save_global(store: &AuthStore) -> Result<()> {
    match global_path() {
        Some(path) => save_to(&path, store),
        None => Err(ToriiError::InvalidConfig(
            "could not resolve config dir".to_string(),
        )),
    }
}
/// Writes `store` to `<repo_path>/.torii/auth.toml`, creating `.torii/` if needed.
pub fn save_local<P: AsRef<Path>>(repo_path: P, store: &AuthStore) -> Result<()> {
    save_to(&local_path(repo_path), store)
}
/// Stores (or replaces) the cloud API credentials in the global store, keeping all
/// provider tokens already there.
pub fn save_cloud(key: &str, endpoint: &str) -> Result<()> {
    let cloud = ApiKey {
        key: key.to_string(),
        endpoint: endpoint.to_string(),
    };
    let mut store = load_global();
    store.cloud = Some(cloud);
    save_global(&store)
}
/// Removes the stored cloud API key from the global store.
///
/// If provider tokens remain, the file is rewritten without its `[cloud]` section.
/// Otherwise the global file is deleted outright (if present). That also discards any
/// expiry or refresh entries it held.
pub fn delete() -> Result<()> {
    let mut store = load_global();
    store.cloud = None;
    if !store.tokens.is_empty() {
        return save_global(&store);
    }
    match global_path() {
        Some(path) if path.exists() => fs::remove_file(&path)
            .map_err(|e| ToriiError::Fs(format!("remove {}: {}", path.display(), e))),
        _ => Ok(()),
    }
}
/// Lower-cases `name` and checks it against [`PROVIDERS`].
///
/// # Errors
/// `ToriiError::Usage` listing the known providers when `name` is not one of them.
pub fn normalise_provider(name: &str) -> Result<String> {
    let lowered = name.to_lowercase();
    if !PROVIDERS.contains(&lowered.as_str()) {
        return Err(ToriiError::Usage(format!(
            "unknown provider '{}'. Known: {}",
            name,
            PROVIDERS.join(", ")
        )));
    }
    Ok(lowered)
}
/// Stores `token` for `provider` without an expiry (clearing any previous one).
///
/// `local = Some(repo)` writes to the repository store, `None` to the global store.
pub fn set_token(provider: &str, token: &str, local: Option<&Path>) -> Result<()> {
    let no_expiry: Option<&str> = None;
    set_token_with_expiry(provider, token, no_expiry, local)
}
pub fn set_token_with_expiry(
provider: &str,
token: &str,
expires_at: Option<&str>,
local: Option<&Path>,
) -> Result<()> {
let provider = normalise_provider(provider)?;
let result = if let Some(repo) = local {
let mut store = load_local_raw(repo);
store.tokens.insert(provider.clone(), token.to_string());
apply_expiry(&mut store.expirations, &provider, expires_at);
save_local(repo, &store)
} else {
let mut store = load_global();
store.tokens.insert(provider.clone(), token.to_string());
apply_expiry(&mut store.expirations, &provider, expires_at);
save_global(&store)
};
invalidate_token_cache();
result
}
/// Stores an OAuth-style credential set for `provider` in the global store.
///
/// - `access_token` replaces any stored token for the provider.
/// - `refresh_token`: `Some` replaces the stored refresh token; `None` keeps the existing
///   one (if any).
/// - `expires_in_seconds` becomes an absolute RFC 3339 UTC timestamp (`now + expires_in`).
///   `None` clears any stored expiry.
///
/// `expires_in` normally comes from an OAuth server. It is clamped to the range
/// `chrono::Duration::seconds` accepts, and the date addition is checked. An absurd value
/// therefore cannot panic, and an expiry too far in the future to represent is stored as
/// "no expiry". Before this fix, values above `i64::MAX` also wrapped to a negative
/// duration through `as i64`.
///
/// The token cache is invalidated after a successful save.
///
/// # Errors
/// Fails if `provider` is unknown or the global store cannot be written.
pub fn set_token_with_refresh(
    provider: &str,
    access_token: &str,
    refresh_token: Option<&str>,
    expires_in_seconds: Option<u64>,
) -> Result<()> {
    let provider = normalise_provider(provider)?;
    let expires_at = expires_in_seconds.and_then(|s| {
        // `Duration::seconds` panics beyond ±(i64::MAX / 1000) seconds.
        let secs = i64::try_from(s).unwrap_or(i64::MAX).min(i64::MAX / 1000);
        chrono::Utc::now()
            .checked_add_signed(chrono::Duration::seconds(secs))
            .map(|when| when.to_rfc3339_opts(chrono::SecondsFormat::Secs, true))
    });
    let mut store = load_global();
    store.tokens.insert(provider.clone(), access_token.to_string());
    apply_expiry(&mut store.expirations, &provider, expires_at.as_deref());
    if let Some(r) = refresh_token {
        store.refresh_tokens.insert(provider.clone(), r.to_string());
    }
    save_global(&store)?;
    invalidate_token_cache();
    Ok(())
}
/// Refreshes the global OAuth access token for `provider` when it is about to expire.
///
/// Returns `Ok(false)` without network access in these cases: no refresh token is stored,
/// the stored expiry is missing or not valid RFC 3339, or the expiry is five minutes or
/// more away. Otherwise it exchanges the refresh token through
/// `crate::oauth::refresh_access_token`, stores the result with `set_token_with_refresh`,
/// and returns `Ok(true)`. The old refresh token is kept if the server does not issue a
/// new one.
pub fn refresh_if_needed(provider: &str) -> Result<bool> {
    let provider_lc = provider.to_lowercase();
    let store = load_global();

    let refresh = match store.refresh_tokens.get(&provider_lc) {
        Some(r) => r.clone(),
        None => return Ok(false),
    };

    let expiry = store
        .expirations
        .get(&provider_lc)
        .and_then(|s| chrono::DateTime::parse_from_rfc3339(s).ok());
    let due = match expiry {
        Some(when) => {
            let now = chrono::Utc::now();
            when.with_timezone(&chrono::Utc) - now < chrono::Duration::minutes(5)
        }
        None => false,
    };
    if !due {
        return Ok(false);
    }

    let (new_access, new_refresh, expires_in) =
        crate::oauth::refresh_access_token(&provider_lc, &refresh)?;
    set_token_with_refresh(
        &provider_lc,
        &new_access,
        new_refresh.as_deref().or(Some(&refresh)),
        expires_in,
    )?;
    Ok(true)
}
/// Records `expires_at` for `provider` in `map`, or clears the entry when `expires_at` is
/// `None` or empty.
fn apply_expiry(map: &mut BTreeMap<String, String>, provider: &str, expires_at: Option<&str>) {
    match expires_at.filter(|when| !when.is_empty()) {
        Some(when) => {
            map.insert(provider.to_owned(), when.to_owned());
        }
        None => {
            map.remove(provider);
        }
    }
}
pub fn token_expires_at(provider: &str) -> Option<String> {
let store = load_global();
store.expirations.get(&provider.to_lowercase()).cloned()
}
pub fn remove_token(provider: &str, local: Option<&Path>) -> Result<bool> {
let provider = normalise_provider(provider)?;
let removed = if let Some(repo) = local {
let mut store = load_local_raw(repo);
let r = store.tokens.remove(&provider).is_some();
store.expirations.remove(&provider);
save_local(repo, &store)?;
r
} else {
let mut store = load_global();
let r = store.tokens.remove(&provider).is_some();
store.expirations.remove(&provider);
save_global(&store)?;
r
};
invalidate_token_cache();
Ok(removed)
}
/// Where a resolved provider token came from (see `resolve_token`).
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum TokenSource {
    /// A provider-specific environment variable; carries that variable's name.
    EnvVar(&'static str),
    /// The generic `TORII_HTTPS_TOKEN` environment variable.
    EnvGeneric,
    /// The repository-local `.torii/auth.toml`.
    Local,
    /// The global `auth.toml`.
    Global,
    /// No source provided a non-empty token.
    Missing,
}
/// Outcome of resolving a provider token across environment variables and auth stores.
///
/// `Debug` is implemented by hand so the token value itself is never printed; only
/// whether one was found.
#[derive(Clone)]
pub struct ResolvedToken {
    /// Lower-cased provider name the lookup was for.
    // Retained for callers/diagnostics; may not be read elsewhere in the crate.
    #[allow(dead_code)]
    pub provider: String,
    /// The token; `resolve_token` sets this to `None` exactly when `source` is `Missing`.
    pub value: Option<String>,
    /// Where `value` was found.
    pub source: TokenSource,
}

impl std::fmt::Debug for ResolvedToken {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("ResolvedToken")
            .field("provider", &self.provider)
            .field("value", &self.value.as_ref().map(|_| "<redacted>"))
            .field("source", &self.source)
            .finish()
    }
}
/// Process-wide memo of resolved tokens, keyed by `"<provider>|<repo path>"`.
/// It is created lazily on first use.
fn token_cache() -> &'static std::sync::Mutex<std::collections::HashMap<String, ResolvedToken>> {
    use std::collections::HashMap;
    use std::sync::{Mutex, OnceLock};

    static CACHE: OnceLock<Mutex<HashMap<String, ResolvedToken>>> = OnceLock::new();
    CACHE.get_or_init(Default::default)
}
/// Forgets every cached token resolution, e.g. after the stores were modified.
///
/// A poisoned lock is left untouched; `resolve_token` skips the cache in that case too.
fn invalidate_token_cache() {
    let Ok(mut entries) = token_cache().lock() else {
        return;
    };
    entries.clear();
}
pub fn drop_token_cache() {
invalidate_token_cache();
}
/// Resolves the token for `provider` as seen from `repo_path`, memoising the result.
///
/// First it attempts a best-effort OAuth refresh (errors are ignored; a successful
/// refresh clears the cache). It then returns a cached resolution for
/// `(lower-cased provider, repo_path)` if present. Otherwise it resolves from scratch and
/// caches the result.
pub fn resolve_token<P: AsRef<Path>>(provider: &str, repo_path: P) -> ResolvedToken {
    let _ = refresh_if_needed(provider);

    let cache_key = format!("{}|{}", provider.to_lowercase(), repo_path.as_ref().display());
    let cached = token_cache()
        .lock()
        .ok()
        .and_then(|guard| guard.get(&cache_key).cloned());
    if let Some(hit) = cached {
        return hit;
    }

    let fresh = resolve_token_uncached(provider, repo_path);
    if let Ok(mut guard) = token_cache().lock() {
        guard.insert(cache_key, fresh.clone());
    }
    fresh
}
/// Resolves a provider token without consulting the cache.
///
/// Sources are tried in order, and the first non-empty value wins:
/// 1. the provider-specific environment variables from `env_vars_for`, in listed order;
/// 2. the generic `TORII_HTTPS_TOKEN` environment variable;
/// 3. the repository-local store;
/// 4. the global store.
///
/// Each store is read only if every earlier source came up empty.
fn resolve_token_uncached<P: AsRef<Path>>(provider: &str, repo_path: P) -> ResolvedToken {
    let provider_lc = provider.to_lowercase();
    // Unset, non-Unicode and empty variables all count as "not provided".
    let non_empty_env = |name: &str| std::env::var(name).ok().filter(|v| !v.is_empty());

    let found = env_vars_for(&provider_lc)
        .iter()
        .find_map(|&name| non_empty_env(name).map(|v| (v, TokenSource::EnvVar(name))))
        .or_else(|| non_empty_env("TORII_HTTPS_TOKEN").map(|v| (v, TokenSource::EnvGeneric)))
        .or_else(|| {
            load_local_raw(repo_path)
                .tokens
                .get(&provider_lc)
                .filter(|v| !v.is_empty())
                .cloned()
                .map(|v| (v, TokenSource::Local))
        })
        .or_else(|| {
            load_global()
                .tokens
                .get(&provider_lc)
                .filter(|v| !v.is_empty())
                .cloned()
                .map(|v| (v, TokenSource::Global))
        });

    match found {
        Some((value, source)) => ResolvedToken {
            provider: provider_lc,
            value: Some(value),
            source,
        },
        None => ResolvedToken {
            provider: provider_lc,
            value: None,
            source: TokenSource::Missing,
        },
    }
}
/// Environment variables that may hold a token for the (lower-case) `provider`, in
/// priority order. Unknown providers get an empty list.
fn env_vars_for(provider: &str) -> &'static [&'static str] {
    // provider -> candidate variables, highest priority first.
    const TABLE: &[(&str, &[&str])] = &[
        ("github", &["GITHUB_TOKEN", "GH_TOKEN"]),
        ("gitlab", &["GITLAB_TOKEN", "GL_TOKEN"]),
        ("gitea", &["GITEA_TOKEN"]),
        ("forgejo", &["FORGEJO_TOKEN"]),
        ("codeberg", &["CODEBERG_TOKEN"]),
        ("bitbucket", &["BITBUCKET_TOKEN"]),
        ("azure", &["AZURE_DEVOPS_TOKEN", "AZURE_DEVOPS_EXT_PAT", "AZDO_TOKEN"]),
        ("sourcehut", &["SOURCEHUT_TOKEN", "SRHT_TOKEN"]),
        ("cargo", &["CARGO_REGISTRY_TOKEN"]),
    ];
    TABLE
        .iter()
        .find(|(name, _)| *name == provider)
        .map(|(_, vars)| *vars)
        .unwrap_or(&[])
}
/// Parses the `auth.toml` text produced by `save_to` into an [`AuthStore`].
///
/// This is a minimal line-based reader, not a full TOML parser:
/// - blank lines and lines starting with `#` are skipped;
/// - `[name]` lines switch the current section (`cloud`, `tokens`, `token_expires`,
///   `token_refresh`); any other name falls back to top-level handling;
/// - other lines are split on the first `=`; lines without `=` are ignored;
/// - keys and values are trimmed, and all leading/trailing `"` characters are stripped
///   from the value (no escape sequences are interpreted).
///
/// `key`/`endpoint` are accepted both in `[cloud]` and at top level (the legacy layout
/// exercised by the `parse_legacy_top_level_cloud` test). `cloud` is only populated when
/// a non-empty `key` was seen; `endpoint` defaults to `default_endpoint()`. Empty values
/// are dropped from the token, expiry and refresh maps. Never fails: malformed lines are
/// skipped.
fn parse(text: &str) -> AuthStore {
// Which `[section]` the current line belongs to.
enum Section {
TopLevel,
Cloud,
Tokens,
TokenExpires,
TokenRefresh,
}
let mut section = Section::TopLevel;
let mut cloud_key = String::new();
let mut cloud_endpoint = default_endpoint();
// Tracks whether any `key` line was seen for the cloud credentials.
let mut have_cloud = false;
let mut tokens = BTreeMap::new();
let mut expirations = BTreeMap::new();
let mut refresh_tokens = BTreeMap::new();
for raw in text.lines() {
let line = raw.trim();
if line.is_empty() || line.starts_with('#') {
continue;
}
if line.starts_with('[') && line.ends_with(']') {
// '[' and ']' are single-byte ASCII, so this slice is on char boundaries.
let name = &line[1..line.len() - 1];
// NOTE(review): unknown section names map to TopLevel, so a `key`/`endpoint` inside an
// unrecognised section would be read as the cloud credential — confirm this is intended.
section = match name.trim() {
"cloud" => Section::Cloud,
"tokens" => Section::Tokens,
"token_expires" => Section::TokenExpires,
"token_refresh" => Section::TokenRefresh,
_ => Section::TopLevel, };
continue;
}
// Split on the first '=' only: values may themselves contain '=' (e.g. base64 padding).
let Some((k, v)) = line.split_once('=') else {
continue;
};
let k = k.trim();
// Strips surrounding quotes; no unescaping, mirroring how `save_to` writes values.
let v = v.trim().trim_matches('"').to_string();
match section {
// Legacy files kept the cloud key at top level, so both are handled alike.
Section::Cloud | Section::TopLevel => match k {
"key" => {
cloud_key = v;
have_cloud = true;
}
"endpoint" => {
cloud_endpoint = v;
}
_ => {}
},
Section::Tokens => {
if !v.is_empty() {
tokens.insert(k.to_string(), v);
}
}
Section::TokenExpires => {
if !v.is_empty() {
expirations.insert(k.to_string(), v);
}
}
Section::TokenRefresh => {
if !v.is_empty() {
refresh_tokens.insert(k.to_string(), v);
}
}
}
}
AuthStore {
// An empty `key = ""` counts as no cloud credentials.
cloud: if have_cloud && !cloud_key.is_empty() {
Some(ApiKey {
key: cloud_key,
endpoint: cloud_endpoint,
})
} else {
None
},
tokens,
expirations,
refresh_tokens,
}
}
/// Imports provider tokens from the legacy `<config dir>/torii/config.toml`.
///
/// Every non-empty `<provider>_token = "..."` entry inside the `[auth]` table becomes
/// `tokens[<provider>]`. When at least one token is found, the resulting store is saved
/// as the new global `auth.toml` (best effort; save errors are ignored) and returned.
/// Returns `None` when the legacy file is missing or unreadable, or holds no tokens.
fn migrate_from_config_toml() -> Option<AuthStore> {
    let legacy = dirs::config_dir()?.join("torii").join("config.toml");
    if !legacy.exists() {
        return None;
    }
    let contents = fs::read_to_string(&legacy).ok()?;

    let mut tokens = BTreeMap::new();
    let mut inside_auth_table = false;
    for line in contents.lines().map(str::trim) {
        if line.is_empty() || line.starts_with('#') {
            continue;
        }
        if line.starts_with('[') && line.ends_with(']') {
            let header = line.trim_start_matches('[').trim_end_matches(']').trim();
            inside_auth_table = header == "auth";
            continue;
        }
        if !inside_auth_table {
            continue;
        }
        let Some((raw_key, raw_value)) = line.split_once('=') else {
            continue;
        };
        let value = raw_value.trim().trim_matches('"');
        if value.is_empty() {
            continue;
        }
        if let Some(provider) = raw_key.trim().strip_suffix("_token") {
            tokens.insert(provider.to_string(), value.to_string());
        }
    }

    if tokens.is_empty() {
        return None;
    }
    let migrated = AuthStore {
        cloud: None,
        tokens,
        expirations: BTreeMap::new(),
        refresh_tokens: BTreeMap::new(),
    };
    let _ = save_global(&migrated);
    Some(migrated)
}
/// Best effort: limits `path` to owner read/write (`0o600`). Failures (including a
/// missing file) are ignored.
#[cfg(unix)]
fn restrict_permissions(path: &std::path::Path) {
    use std::os::unix::fs::PermissionsExt;
    let owner_rw_only = fs::Permissions::from_mode(0o600);
    let _ = fs::set_permissions(path, owner_rw_only);
}

/// No-op on non-Unix targets.
#[cfg(not(unix))]
fn restrict_permissions(_path: &std::path::Path) {}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn parse_legacy_top_level_cloud() {
        let s = parse("key = \"gitorii_sk_abc\"");
        assert_eq!(s.cloud.as_ref().unwrap().key, "gitorii_sk_abc");
        assert!(s.tokens.is_empty());
    }

    #[test]
    fn parse_new_sectioned_cloud_only() {
        let s = parse("[cloud]\nkey = \"x\"\nendpoint = \"http://h\"\n");
        let c = s.cloud.unwrap();
        assert_eq!(c.key, "x");
        assert_eq!(c.endpoint, "http://h");
    }

    #[test]
    fn parse_tokens_only() {
        let s = parse("[tokens]\ngithub = \"ghp_x\"\ngitlab = \"glp_y\"\n");
        assert_eq!(s.tokens["github"], "ghp_x");
        assert_eq!(s.tokens["gitlab"], "glp_y");
        assert!(s.cloud.is_none());
    }

    #[test]
    fn parse_both_sections() {
        let s = parse("[cloud]\nkey = \"k\"\n[tokens]\ncargo = \"cio\"\n");
        assert_eq!(s.cloud.unwrap().key, "k");
        assert_eq!(s.tokens["cargo"], "cio");
    }

    #[test]
    fn parse_empty_tokens_are_dropped() {
        let s = parse("[tokens]\ngithub = \"\"\ngitlab = \"x\"\n");
        assert!(!s.tokens.contains_key("github"));
        assert!(s.tokens.contains_key("gitlab"));
    }

    // Previously uncovered: the expiry and refresh-token sections.
    #[test]
    fn parse_expiry_and_refresh_sections() {
        let s = parse(
            "[tokens]\ngithub = \"t\"\n\
             [token_expires]\ngithub = \"2030-01-01T00:00:00Z\"\n\
             [token_refresh]\ngithub = \"r\"\n",
        );
        assert_eq!(s.tokens["github"], "t");
        assert_eq!(s.expirations["github"], "2030-01-01T00:00:00Z");
        assert_eq!(s.refresh_tokens["github"], "r");
    }

    // Only the first '=' separates key from value, so padded tokens survive.
    #[test]
    fn parse_value_containing_equals() {
        let s = parse("[tokens]\ncargo = \"abc==\"\n");
        assert_eq!(s.tokens["cargo"], "abc==");
    }

    #[test]
    fn apply_expiry_sets_and_clears() {
        let mut m = BTreeMap::new();
        apply_expiry(&mut m, "github", Some("2030-01-01T00:00:00Z"));
        assert_eq!(m["github"], "2030-01-01T00:00:00Z");
        apply_expiry(&mut m, "github", Some(""));
        assert!(m.is_empty());
        apply_expiry(&mut m, "gitlab", Some("x"));
        apply_expiry(&mut m, "gitlab", None);
        assert!(m.is_empty());
    }

    #[test]
    fn every_provider_has_env_vars() {
        for p in PROVIDERS {
            assert!(!env_vars_for(p).is_empty(), "no env vars for {}", p);
        }
        assert!(env_vars_for("hackernews").is_empty());
    }

    #[test]
    fn normalise_provider_accepts_known() {
        assert_eq!(normalise_provider("GitHub").unwrap(), "github");
        assert_eq!(normalise_provider("cargo").unwrap(), "cargo");
    }

    #[test]
    fn normalise_provider_rejects_unknown() {
        assert!(normalise_provider("hackernews").is_err());
    }
}