shunt-proxy 0.1.90

A local proxy that pools multiple Claude accounts behind a single endpoint, routing requests across them to make the most of each account's rate limits
Documentation
use anyhow::{bail, Context, Result};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::path::{Path, PathBuf};

use crate::credential::{deserialize_credential_map, Credential};
use crate::provider::Provider;

pub const APP_NAME: &str = "shunt";

/// Path of the user-editable TOML config: `<config dir>/shunt/config.toml`.
///
/// Uses the platform config directory, or the current directory when that
/// cannot be determined.
pub fn config_path() -> PathBuf {
    let mut path = dirs::config_dir().unwrap_or_else(|| PathBuf::from("."));
    path.push(APP_NAME);
    path.push("config.toml");
    path
}

/// Path of the secret credentials store: `<config dir>/shunt/credentials.json`.
///
/// Uses the platform config directory, or the current directory when that
/// cannot be determined.
pub fn credentials_path() -> PathBuf {
    let mut path = match dirs::config_dir() {
        Some(dir) => dir,
        None => PathBuf::from("."),
    };
    path.push(APP_NAME);
    path.push("credentials.json");
    path
}

/// Path of the persisted runtime state: `<local data dir>/shunt/state.json`.
///
/// Uses the platform local-data directory, or the current directory when that
/// cannot be determined.
pub fn state_path() -> PathBuf {
    let mut path = dirs::data_local_dir().unwrap_or_else(|| PathBuf::from("."));
    path.push(APP_NAME);
    path.push("state.json");
    path
}

/// Path of the proxy log file: `<local data dir>/shunt/proxy.log`.
///
/// Uses the platform local-data directory, or the current directory when that
/// cannot be determined.
pub fn log_path() -> PathBuf {
    let mut path = match dirs::data_local_dir() {
        Some(dir) => dir,
        None => PathBuf::from("."),
    };
    path.push(APP_NAME);
    path.push("proxy.log");
    path
}

/// Path of the PID file: `<local data dir>/shunt/shunt.pid`.
///
/// Uses the platform local-data directory, or the current directory when that
/// cannot be determined.
pub fn pid_path() -> PathBuf {
    let mut path = dirs::data_local_dir().unwrap_or_else(|| PathBuf::from("."));
    path.push(APP_NAME);
    path.push("shunt.pid");
    path
}

// ---------------------------------------------------------------------------
// Credentials store  (separate file from config — never commit this)
// ---------------------------------------------------------------------------

/// Per-account credentials, persisted as JSON at `credentials_path()`.
///
/// This file holds secrets and is kept separate from `config.toml`; it should
/// never be committed.
#[derive(Debug, Default, Serialize, Deserialize)]
pub struct CredentialsStore {
    /// Credentials keyed by account name (the `name` of a `[[accounts]]` entry).
    /// Deserialized through `deserialize_credential_map`; a missing key yields
    /// an empty map.
    #[serde(deserialize_with = "deserialize_credential_map", default)]
    pub accounts: HashMap<String, Credential>,
}

impl CredentialsStore {
    /// Read the credential store from `credentials_path()`.
    ///
    /// Best-effort: a missing, unreadable, or unparsable file yields an empty
    /// store instead of an error.
    // NOTE(review): a corrupted file is silently treated as empty and will be
    // overwritten by the next `save()`; consider surfacing a warning.
    pub fn load() -> Self {
        let p = credentials_path();
        if !p.exists() {
            return Self::default();
        }
        match std::fs::read_to_string(&p) {
            Ok(text) => serde_json::from_str(&text).unwrap_or_default(),
            Err(_) => Self::default(),
        }
    }

    /// Persist the store to `credentials_path()`, creating the directory if
    /// needed.
    ///
    /// The JSON is written to a sibling `.tmp` file and then renamed over the
    /// real file, so readers never observe a half-written store. On Unix the
    /// temp file is restricted to `0o600` *before* any secret is written to
    /// it. The rename keeps those permissions, so the credentials are never
    /// readable by other users, even briefly. On Windows the final file is
    /// restricted to the current user via `icacls` (best-effort).
    ///
    /// # Errors
    /// Fails if the directory cannot be created, serialization fails, or the
    /// temp file cannot be written, restricted, or renamed into place.
    pub fn save(&self) -> Result<()> {
        let p = credentials_path();
        if let Some(parent) = p.parent() {
            std::fs::create_dir_all(parent)?;
        }
        let json = serde_json::to_string_pretty(self)?;
        let tmp = p.with_extension("tmp");

        #[cfg(unix)]
        {
            use std::io::Write;
            use std::os::unix::fs::{OpenOptionsExt, PermissionsExt};

            let mut file = std::fs::OpenOptions::new()
                .write(true)
                .create(true)
                .truncate(true)
                .mode(0o600)
                .open(&tmp)?;
            // `mode` only applies when the file is newly created; a stale tmp
            // left by an interrupted save may still carry looser permissions.
            file.set_permissions(std::fs::Permissions::from_mode(0o600))?;
            file.write_all(json.as_bytes())?;
        }
        #[cfg(not(unix))]
        {
            std::fs::write(&tmp, &json)?;
        }

        std::fs::rename(&tmp, &p)?;

        // On Windows, restrict the file to the current user via icacls (best-effort).
        #[cfg(windows)]
        {
            if let Some(path_str) = p.to_str() {
                let username = std::env::var("USERNAME").unwrap_or_default();
                if !username.is_empty() {
                    let _ = std::process::Command::new("icacls")
                        .arg(path_str)
                        .arg("/inheritance:r")
                        .arg("/grant:r")
                        .arg(format!("{username}:F"))
                        .status();
                }
            }
        }
        Ok(())
    }
}

// ---------------------------------------------------------------------------
// Raw TOML config types
// ---------------------------------------------------------------------------

/// Top-level shape of `config.toml` as parsed by `toml`, before defaults,
/// environment variables, and credentials are resolved into [`Config`].
#[derive(Debug, Deserialize)]
struct RawConfig {
    /// The `[server]` table; `RawServer::default()` when the table is absent.
    #[serde(default)]
    server: RawServer,
    /// `[[accounts]]` entries in file order. The first entry's provider is
    /// treated as the primary provider by `load_config`; an empty list parses
    /// but is rejected there.
    #[serde(default)]
    accounts: Vec<RawAccount>,
    /// Global model-name mapping: `"claude-sonnet-4-6" = "llama-3.3-70b-versatile"`
    /// Applied when routing Anthropic-format requests to non-Anthropic providers.
    #[serde(default)]
    model_mapping: HashMap<String, String>,
}

/// The `[server]` table of `config.toml`, exactly as written by the user.
///
/// Fields left as `None` are resolved to environment overrides or built-in
/// defaults by `load_config`.
#[derive(Debug, Deserialize)]
struct RawServer {
    /// Host for the proxy endpoint (default `"127.0.0.1"`).
    #[serde(default = "default_host")]
    host: String,
    /// Proxy port (default 8082).
    #[serde(default = "default_port")]
    port: u16,
    /// Control-plane port (default 19081).
    #[serde(default = "default_control_port")]
    control_port: u16,
    /// Log level name (default `"info"`).
    #[serde(default = "default_log_level")]
    log_level: String,
    /// Upstream base URL; when absent, `SHUNT_UPSTREAM_URL` and then the
    /// primary provider's default URL are used.
    upstream_url: Option<String>,
    /// Value remote requests must supply as `x-api-key`, if any.
    remote_key: Option<String>,
    /// Relay URL; when absent, `SHUNT_RELAY_URL` and then the built-in relay
    /// are used.
    relay_url: Option<String>,
    /// Custom domain for permanent online sharing.
    custom_domain: Option<String>,
    /// Conversation stickiness TTL in minutes (default: 10)
    sticky_ttl_minutes: Option<u64>,
    /// "use-it-or-lose-it" expiry window in minutes (default: 30)
    expiry_soon_minutes: Option<u64>,
    /// Upstream request timeout in seconds (default: 600)
    request_timeout_secs: Option<u64>,
}

impl Default for RawServer {
    fn default() -> Self {
        Self {
            host: default_host(),
            port: default_port(),
            control_port: default_control_port(),
            log_level: default_log_level(),
            upstream_url: None,
            remote_key: None,
            relay_url: None,
            custom_domain: None,
            sticky_ttl_minutes: None,
            expiry_soon_minutes: None,
            request_timeout_secs: None,
        }
    }
}

/// One `[[accounts]]` entry of `config.toml`, exactly as written by the user.
#[derive(Debug, Deserialize)]
struct RawAccount {
    /// Account name; also the key looked up in the credentials store.
    name: String,
    /// Plan label for the account (defaults to `"pro"`).
    #[serde(default = "default_plan_type")]
    plan_type: String,
    /// "anthropic" (default) | "openai" / "codex" | "groq" | "mistral" | "local" | …
    /// Parsed with `Provider::from_str`; absent means `Provider::default()`.
    #[serde(default)]
    provider: Option<String>,
    /// Inline API key (use api_key_env for better security).
    #[serde(default)]
    api_key: Option<String>,
    /// Name of an environment variable that holds the API key.
    #[serde(default)]
    api_key_env: Option<String>,
    /// Per-account upstream URL override (required for Local provider).
    #[serde(default)]
    upstream_url: Option<String>,
    /// Pin this account to a specific model, overriding global model_mapping
    /// and the provider's default_model(). Useful for mixing model tiers.
    #[serde(default)]
    model: Option<String>,
}

/// Default `[server] host`.
fn default_host() -> String {
    String::from("127.0.0.1")
}

/// Default `[server] port`.
fn default_port() -> u16 {
    8082
}

/// Default `[server] control_port`.
fn default_control_port() -> u16 {
    19081
}

/// Default `[server] log_level`.
fn default_log_level() -> String {
    String::from("info")
}

/// Default `[[accounts]] plan_type`.
fn default_plan_type() -> String {
    String::from("pro")
}

// ---------------------------------------------------------------------------
// Resolved config types
// ---------------------------------------------------------------------------

/// Fully-resolved `[server]` settings, with defaults and environment fallbacks
/// already applied by `load_config`.
#[derive(Debug, Clone)]
pub struct ServerConfig {
    /// Host for the proxy endpoint.
    pub host: String,
    /// Port for the proxy endpoint.
    pub port: u16,
    /// Port for the control plane (/status, /use, /health) — sees all accounts.
    pub control_port: u16,
    /// Log level name (e.g. `"info"`).
    pub log_level: String,
    /// Upstream base URL used by accounts that have no per-account override.
    pub upstream_url: String,
    /// When set, remote requests must supply this value as `x-api-key`.
    pub remote_key: Option<String>,
    /// Relay URL for `shunt push` / `shunt login`. Taken from the config file
    /// when set there; otherwise from SHUNT_RELAY_URL, then the built-in relay.
    pub relay_url: String,
    /// Custom domain for permanent online sharing (e.g. https://shunt.mysite.com).
    pub custom_domain: Option<String>,
    /// Conversation stickiness TTL in milliseconds.
    pub sticky_ttl_ms: u64,
    /// Accounts whose 5h window resets within this many seconds are preferred ("use-it-or-lose-it").
    pub expiry_soon_secs: u64,
    /// Upstream request timeout in seconds.
    pub request_timeout_secs: u64,
}

impl Default for ServerConfig {
    fn default() -> Self {
        Self {
            host: "127.0.0.1".into(),
            port: 8082,
            control_port: 19081,
            log_level: "info".into(),
            upstream_url: "https://api.anthropic.com".into(),
            remote_key: None,
            relay_url: "https://relay.ramcharan.shop".into(),
            custom_domain: None,
            sticky_ttl_ms: 10 * 60 * 1000,
            expiry_soon_secs: 30 * 60,
            request_timeout_secs: 600,
        }
    }
}

/// One fully-resolved account produced by `load_config`.
#[derive(Debug, Clone)]
pub struct AccountConfig {
    /// Account name from the config file.
    pub name: String,
    /// Plan label from the config file (defaults to `"pro"`).
    pub plan_type: String,
    /// Provider parsed from the account's `provider` field (default when absent).
    pub provider: Provider,
    /// `None` when the account has no credential.
    /// OAuth accounts: None means reauth required (shown as auth_failed).
    /// ApiKey accounts: None means key not yet configured.
    /// Local accounts: None is normal (no auth required).
    pub credential: Option<Credential>,
    /// Override the upstream base URL for this account.
    /// `None` means use `config.server.upstream_url` (primary provider) or
    /// `provider.default_upstream_url()` (non-primary provider).
    pub upstream_url: Option<String>,
    /// Pin this account to a specific model name.
    /// Overrides both `model_mapping` and `provider.default_model()`.
    pub model: Option<String>,
}

/// Complete resolved configuration returned by `load_config`.
#[derive(Debug, Clone)]
pub struct Config {
    /// Resolved `[server]` settings.
    pub server: ServerConfig,
    /// Resolved accounts, in the order they appear in the config file.
    pub accounts: Vec<AccountConfig>,
    /// Path of the config file these settings were loaded from.
    pub config_file: PathBuf,
    /// Global model-name overrides: claude model → provider model.
    /// e.g. `"claude-sonnet-4-6" → "llama-3.3-70b-versatile"`
    pub model_mapping: HashMap<String, String>,
}

// ---------------------------------------------------------------------------
// Loading
// ---------------------------------------------------------------------------

pub fn load_config(path: Option<&Path>) -> Result<Config> {
    let p = path.map(PathBuf::from).unwrap_or_else(config_path);

    if !p.exists() {
        bail!(
            "Config not found: {}\nRun `shunt setup` to get started.",
            p.display()
        );
    }

    let raw_text = std::fs::read_to_string(&p)
        .with_context(|| format!("Failed to read config: {}", p.display()))?;

    let raw: RawConfig = toml::from_str(&raw_text)
        .with_context(|| format!("Failed to parse config: {}", p.display()))?;

    // Derive the default upstream URL from the first account's provider so that
    // an all-OpenAI config automatically points at api.openai.com without any
    // explicit `upstream_url` in the config file.
    let default_upstream = raw.accounts.first()
        .map(|a| a.provider.as_deref().map(Provider::from_str).unwrap_or_default())
        .unwrap_or_default()
        .default_upstream_url()
        .to_owned();

    let upstream_url = raw
        .server
        .upstream_url
        .clone()
        .or_else(|| std::env::var("SHUNT_UPSTREAM_URL").ok())
        .unwrap_or(default_upstream);

    let relay_url = raw
        .server
        .relay_url
        .clone()
        .or_else(|| std::env::var("SHUNT_RELAY_URL").ok())
        .unwrap_or_else(|| "https://relay.ramcharan.shop".into());

    let server = ServerConfig {
        host: raw.server.host,
        port: raw.server.port,
        control_port: raw.server.control_port,
        log_level: raw.server.log_level,
        upstream_url,
        remote_key: raw.server.remote_key,
        relay_url,
        custom_domain: raw.server.custom_domain,
        sticky_ttl_ms: raw.server.sticky_ttl_minutes.unwrap_or(10) * 60 * 1000,
        expiry_soon_secs: raw.server.expiry_soon_minutes.unwrap_or(30) * 60,
        request_timeout_secs: raw.server.request_timeout_secs.unwrap_or(600),
    };

    if raw.accounts.is_empty() {
        bail!("Config has no accounts. Run `shunt setup` to add one.");
    }

    let store = CredentialsStore::load();

    // Determine the primary provider (first account) so we know which accounts
    // use config.server.upstream_url and which need the provider's default URL.
    let primary_provider = raw.accounts.first()
        .map(|a| a.provider.as_deref().map(Provider::from_str).unwrap_or_default())
        .unwrap_or_default();

    let mut accounts = Vec::new();
    for a in &raw.accounts {
        let provider = a.provider.as_deref().map(Provider::from_str).unwrap_or_default();

        // Resolve credential.
        //
        // OAuth providers (Anthropic, OpenAI): credentials.json first, then
        // auto-import from the provider's local CLI tool.
        //
        // API-key providers: credentials.json first, then inline api_key field,
        // then api_key_env field, then the provider's well-known env var.
        let cred: Option<Credential> = store.accounts.get(&a.name).cloned()
            .or_else(|| {
                // Inline api_key from TOML (less secure, but convenient for testing).
                a.api_key.as_deref().map(|k| Credential::Apikey { key: k.to_owned() })
            })
            .or_else(|| {
                // api_key_env: name of env var holding the key.
                a.api_key_env.as_deref()
                    .and_then(|var| std::env::var(var).ok())
                    .map(|k| Credential::Apikey { key: k })
            })
            .or_else(|| {
                // Auto-import from provider's CLI tool (OAuth providers) or
                // well-known env var (API-key providers).
                provider.read_local_credentials()
            });

        // Upstream URL: per-account override from TOML takes priority, then
        // non-primary-provider accounts get the provider's default URL so
        // the forwarder knows where to send requests.
        let acct_upstream = a.upstream_url.clone().or_else(|| {
            if provider != primary_provider {
                Some(provider.default_upstream_url().to_owned())
            } else {
                None
            }
        });

        accounts.push(AccountConfig {
            name: a.name.clone(),
            plan_type: a.plan_type.clone(),
            provider,
            credential: cred,
            upstream_url: acct_upstream,
            model: a.model.clone(),
        });
    }

    Ok(Config { server, accounts, config_file: p, model_mapping: raw.model_mapping })
}

// ---------------------------------------------------------------------------
// Config file template
// ---------------------------------------------------------------------------

/// Render a starter `config.toml` containing the default `[server]` table and
/// one `[[accounts]]` entry per `(name, plan_type)` pair.
///
/// Names and plan types are emitted as properly escaped TOML basic strings.
/// Values containing quotes, backslashes or control characters therefore
/// still produce a parseable file.
pub fn config_template(accounts: &[(&str, &str)]) -> String {
    /// Quote `s` as a TOML basic string, escaping `"`, `\` and control characters.
    fn toml_quote(s: &str) -> String {
        use std::fmt::Write as _;

        let mut quoted = String::with_capacity(s.len() + 2);
        quoted.push('"');
        for c in s.chars() {
            match c {
                '"' => quoted.push_str("\\\""),
                '\\' => quoted.push_str("\\\\"),
                '\n' => quoted.push_str("\\n"),
                '\r' => quoted.push_str("\\r"),
                '\t' => quoted.push_str("\\t"),
                // Remaining control characters are not allowed raw in TOML strings.
                c if c.is_control() => {
                    let _ = write!(quoted, "\\u{:04X}", u32::from(c));
                }
                c => quoted.push(c),
            }
        }
        quoted.push('"');
        quoted
    }

    let mut out = String::from(
        "[server]\nhost = \"127.0.0.1\"\nport = 8082\ncontrol_port = 19081\nlog_level = \"info\"\n",
    );
    for (name, plan_type) in accounts {
        out.push_str("\n[[accounts]]\nname = ");
        out.push_str(&toml_quote(name));
        out.push_str("\nplan_type = ");
        out.push_str(&toml_quote(plan_type));
        out.push('\n');
    }
    out
}