use anyhow::{bail, Context, Result};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::path::{Path, PathBuf};
use crate::oauth::OAuthCredential;
use crate::provider::Provider;
/// Application name; used as the per-application subdirectory under the
/// platform config and local-data directories.
pub const APP_NAME: &str = "shunt";
/// Location of the TOML config file: `<config dir>/shunt/config.toml`.
///
/// Falls back to the current working directory when the platform config
/// directory cannot be determined.
pub fn config_path() -> PathBuf {
    let mut path = dirs::config_dir().unwrap_or_else(|| PathBuf::from("."));
    path.push(APP_NAME);
    path.push("config.toml");
    path
}
/// Location of the stored credentials: `<config dir>/shunt/credentials.json`.
///
/// Falls back to the current working directory when the platform config
/// directory cannot be determined.
pub fn credentials_path() -> PathBuf {
    let mut path = dirs::config_dir().unwrap_or_else(|| PathBuf::from("."));
    path.push(APP_NAME);
    path.push("credentials.json");
    path
}
/// Location of `state.json` under `<local data dir>/shunt/`.
///
/// Falls back to the current working directory when the platform local-data
/// directory cannot be determined.
pub fn state_path() -> PathBuf {
    let mut path = dirs::data_local_dir().unwrap_or_else(|| PathBuf::from("."));
    path.push(APP_NAME);
    path.push("state.json");
    path
}
/// Location of the proxy log file: `<local data dir>/shunt/proxy.log`.
///
/// Falls back to the current working directory when the platform local-data
/// directory cannot be determined.
pub fn log_path() -> PathBuf {
    let mut path = dirs::data_local_dir().unwrap_or_else(|| PathBuf::from("."));
    path.push(APP_NAME);
    path.push("proxy.log");
    path
}
/// Location of the PID file: `<local data dir>/shunt/shunt.pid`.
///
/// Falls back to the current working directory when the platform local-data
/// directory cannot be determined.
pub fn pid_path() -> PathBuf {
    let mut path = dirs::data_local_dir().unwrap_or_else(|| PathBuf::from("."));
    path.push(APP_NAME);
    path.push("shunt.pid");
    path
}
#[derive(Debug, Default, Serialize, Deserialize)]
pub struct CredentialsStore {
pub accounts: HashMap<String, OAuthCredential>,
}
impl CredentialsStore {
pub fn load() -> Self {
let p = credentials_path();
if !p.exists() {
return Self::default();
}
match std::fs::read_to_string(&p) {
Ok(text) => serde_json::from_str(&text).unwrap_or_default(),
Err(_) => Self::default(),
}
}
pub fn save(&self) -> Result<()> {
let p = credentials_path();
if let Some(parent) = p.parent() {
std::fs::create_dir_all(parent)?;
}
let tmp = p.with_extension("tmp");
std::fs::write(&tmp, serde_json::to_string_pretty(self)?)?;
std::fs::rename(&tmp, &p)?;
#[cfg(unix)]
{
use std::os::unix::fs::PermissionsExt;
std::fs::set_permissions(&p, std::fs::Permissions::from_mode(0o600))?;
}
#[cfg(windows)]
{
if let Some(path_str) = p.to_str() {
let username = std::env::var("USERNAME").unwrap_or_default();
if !username.is_empty() {
let _ = std::process::Command::new("icacls")
.arg(path_str)
.arg("/inheritance:r")
.arg("/grant:r")
.arg(format!("{username}:F"))
.status();
}
}
}
Ok(())
}
}
/// Top-level shape of `config.toml` as deserialized from TOML.
///
/// Both sections may be omitted. A missing `[server]` table falls back to
/// `RawServer::default()`, and a missing `[[accounts]]` array yields an empty
/// list, which `load_config` then rejects.
#[derive(Debug, Deserialize)]
struct RawConfig {
/// The `[server]` table.
#[serde(default)]
server: RawServer,
/// The `[[accounts]]` array of tables, in file order.
#[serde(default)]
accounts: Vec<RawAccount>,
}
/// The `[server]` table exactly as written in `config.toml`.
///
/// Durations use the units in each field's name (minutes or seconds).
/// `load_config` applies fallbacks and unit conversion when it builds
/// `ServerConfig` from this.
#[derive(Debug, Deserialize)]
struct RawServer {
    /// Defaults to `"127.0.0.1"` when omitted.
    #[serde(default = "default_host")]
    host: String,
    /// Defaults to `8082` when omitted.
    #[serde(default = "default_port")]
    port: u16,
    /// Defaults to `"info"` when omitted.
    #[serde(default = "default_log_level")]
    log_level: String,
    /// Optional upstream override; `load_config` falls back to the
    /// `SHUNT_UPSTREAM_URL` env var, then to the primary provider's default.
    upstream_url: Option<String>,
    /// Optional; passed through unchanged.
    remote_key: Option<String>,
    /// Optional relay override; `load_config` falls back to the
    /// `SHUNT_RELAY_URL` env var, then to a built-in URL.
    relay_url: Option<String>,
    /// Optional; passed through unchanged.
    custom_domain: Option<String>,
    /// Minutes; `load_config` uses 10 when omitted.
    sticky_ttl_minutes: Option<u64>,
    /// Minutes; `load_config` uses 30 when omitted.
    expiry_soon_minutes: Option<u64>,
    /// Seconds; `load_config` uses 600 when omitted.
    request_timeout_secs: Option<u64>,
}
impl Default for RawServer {
    /// Used when the `[server]` table is absent. It matches the per-field
    /// serde defaults, so an absent table and an empty one are equivalent.
    fn default() -> Self {
        let host = default_host();
        let port = default_port();
        let log_level = default_log_level();
        Self {
            host,
            port,
            log_level,
            upstream_url: None,
            remote_key: None,
            relay_url: None,
            custom_domain: None,
            sticky_ttl_minutes: None,
            expiry_soon_minutes: None,
            request_timeout_secs: None,
        }
    }
}
/// One `[[accounts]]` entry as written in `config.toml`.
#[derive(Debug, Deserialize)]
struct RawAccount {
/// Account identifier. `load_config` also uses it as the key for looking up
/// the account's credential in `CredentialsStore`.
name: String,
/// Plan tier; defaults to `"pro"` when omitted.
#[serde(default = "default_plan_type")]
plan_type: String,
/// Provider name, parsed with `Provider::from_str` in `load_config`.
/// `Provider::default()` is used when it is omitted.
#[serde(default)]
provider: Option<String>,
}
/// Serde default for `[server].host`.
fn default_host() -> String {
    String::from("127.0.0.1")
}

/// Serde default for `[server].port`.
fn default_port() -> u16 {
    8082
}

/// Serde default for `[server].log_level`.
fn default_log_level() -> String {
    String::from("info")
}

/// Serde default for `[[accounts]].plan_type`.
fn default_plan_type() -> String {
    String::from("pro")
}
/// Server settings after `load_config` has applied defaults, env-var
/// fallbacks and unit conversion.
#[derive(Debug, Clone)]
pub struct ServerConfig {
/// From `[server].host` (default `"127.0.0.1"`).
pub host: String,
/// From `[server].port` (default `8082`).
pub port: u16,
/// From `[server].log_level` (default `"info"`); presumably a log filter
/// string — TODO confirm against the logging setup.
pub log_level: String,
/// `[server].upstream_url`, else `SHUNT_UPSTREAM_URL`, else the first
/// account's provider default.
pub upstream_url: String,
/// Optional; taken from `[server].remote_key` unchanged.
pub remote_key: Option<String>,
/// `[server].relay_url`, else `SHUNT_RELAY_URL`, else a built-in URL.
pub relay_url: String,
/// Optional; taken from `[server].custom_domain` unchanged.
pub custom_domain: Option<String>,
/// Milliseconds, converted from `[server].sticky_ttl_minutes` (default 10 min).
pub sticky_ttl_ms: u64,
/// Seconds, converted from `[server].expiry_soon_minutes` (default 30 min).
pub expiry_soon_secs: u64,
/// Seconds, from `[server].request_timeout_secs` (default 600).
pub request_timeout_secs: u64,
}
impl Default for ServerConfig {
fn default() -> Self {
Self {
host: "127.0.0.1".into(),
port: 8082,
log_level: "info".into(),
upstream_url: "https://api.anthropic.com".into(),
remote_key: None,
relay_url: "https://relay.ramcharan.shop".into(),
custom_domain: None,
sticky_ttl_ms: 10 * 60 * 1000,
expiry_soon_secs: 30 * 60,
request_timeout_secs: 600,
}
}
}
/// A configured account, resolved against stored or local credentials.
#[derive(Debug, Clone)]
pub struct AccountConfig {
/// Account name from the config; also the key into `CredentialsStore`.
pub name: String,
/// Plan tier from the config (default `"pro"`).
pub plan_type: String,
/// Provider parsed from the account's `provider` field (or the default).
pub provider: Provider,
/// The `CredentialsStore` entry for `name`, else the provider's local
/// credentials; `None` if neither was found.
pub credential: Option<OAuthCredential>,
/// `load_config` sets this to the provider's default upstream only when the
/// account's provider differs from the first account's. `None` presumably
/// means the server-wide `upstream_url` applies — confirm at the call site.
pub upstream_url: Option<String>,
}
/// Fully resolved configuration, as produced by `load_config`.
#[derive(Debug, Clone)]
pub struct Config {
/// Server-wide settings.
pub server: ServerConfig,
/// Accounts in config-file order; `load_config` never returns this empty.
pub accounts: Vec<AccountConfig>,
/// Path the configuration was loaded from.
pub config_file: PathBuf,
}
pub fn load_config(path: Option<&Path>) -> Result<Config> {
let p = path.map(PathBuf::from).unwrap_or_else(config_path);
if !p.exists() {
bail!(
"Config not found: {}\nRun `shunt setup` to get started.",
p.display()
);
}
let raw_text = std::fs::read_to_string(&p)
.with_context(|| format!("Failed to read config: {}", p.display()))?;
let raw: RawConfig = toml::from_str(&raw_text)
.with_context(|| format!("Failed to parse config: {}", p.display()))?;
let default_upstream = raw.accounts.first()
.map(|a| a.provider.as_deref().map(Provider::from_str).unwrap_or_default())
.unwrap_or_default()
.default_upstream_url()
.to_owned();
let upstream_url = raw
.server
.upstream_url
.clone()
.or_else(|| std::env::var("SHUNT_UPSTREAM_URL").ok())
.unwrap_or(default_upstream);
let relay_url = raw
.server
.relay_url
.clone()
.or_else(|| std::env::var("SHUNT_RELAY_URL").ok())
.unwrap_or_else(|| "https://relay.ramcharan.shop".into());
let server = ServerConfig {
host: raw.server.host,
port: raw.server.port,
log_level: raw.server.log_level,
upstream_url,
remote_key: raw.server.remote_key,
relay_url,
custom_domain: raw.server.custom_domain,
sticky_ttl_ms: raw.server.sticky_ttl_minutes.unwrap_or(10) * 60 * 1000,
expiry_soon_secs: raw.server.expiry_soon_minutes.unwrap_or(30) * 60,
request_timeout_secs: raw.server.request_timeout_secs.unwrap_or(600),
};
if raw.accounts.is_empty() {
bail!("Config has no accounts. Run `shunt setup` to add one.");
}
let store = CredentialsStore::load();
let primary_provider = raw.accounts.first()
.map(|a| a.provider.as_deref().map(Provider::from_str).unwrap_or_default())
.unwrap_or_default();
let mut accounts = Vec::new();
for a in &raw.accounts {
let provider = a.provider.as_deref().map(Provider::from_str).unwrap_or_default();
let cred = store
.accounts
.get(&a.name)
.cloned()
.or_else(|| provider.read_local_credentials());
let acct_upstream = if provider != primary_provider {
Some(provider.default_upstream_url().to_owned())
} else {
None
};
accounts.push(AccountConfig {
name: a.name.clone(),
plan_type: a.plan_type.clone(),
provider,
credential: cred,
upstream_url: acct_upstream,
});
}
Ok(Config { server, accounts, config_file: p })
}
/// Render a starter `config.toml`: the default `[server]` table followed by
/// one `[[accounts]]` entry per `(name, plan_type)` pair, in order.
///
/// Values are written as escaped TOML basic strings. Names containing quotes,
/// backslashes or control characters therefore still produce valid TOML.
pub fn config_template(accounts: &[(&str, &str)]) -> String {
    use std::fmt::Write as _;

    let mut out = String::from(
        "[server]\nhost = \"127.0.0.1\"\nport = 8082\nlog_level = \"info\"\n",
    );
    for (name, plan_type) in accounts {
        // Writing into a `String` cannot fail.
        let _ = write!(
            out,
            "\n[[accounts]]\nname = {}\nplan_type = {}\n",
            toml_basic_string(name),
            toml_basic_string(plan_type)
        );
    }
    out
}

/// Quote `s` as a TOML basic string (`"..."`). Escapes the quotation mark,
/// backslash, and every control character TOML forbids inside such a string
/// (U+0000–U+001F except tab, and U+007F).
fn toml_basic_string(s: &str) -> String {
    let mut quoted = String::with_capacity(s.len() + 2);
    quoted.push('"');
    for c in s.chars() {
        match c {
            '"' => quoted.push_str("\\\""),
            '\\' => quoted.push_str("\\\\"),
            '\n' => quoted.push_str("\\n"),
            '\r' => quoted.push_str("\\r"),
            '\u{8}' => quoted.push_str("\\b"),
            '\u{c}' => quoted.push_str("\\f"),
            // Tab may appear literally; any other control char needs \uXXXX.
            c if c != '\t' && (c <= '\u{1f}' || c == '\u{7f}') => {
                quoted.push_str(&format!("\\u{:04X}", c as u32));
            }
            c => quoted.push(c),
        }
    }
    quoted.push('"');
    quoted
}