Skip to main content

shunt/
config.rs

1use anyhow::{bail, Context, Result};
2use serde::{Deserialize, Serialize};
3use std::collections::HashMap;
4use std::path::{Path, PathBuf};
5
6use crate::oauth::OAuthCredential;
7
8pub const APP_NAME: &str = "shunt";
9
10pub fn config_path() -> PathBuf {
11    dirs::config_dir()
12        .unwrap_or_else(|| PathBuf::from("."))
13        .join(APP_NAME)
14        .join("config.toml")
15}
16
17pub fn credentials_path() -> PathBuf {
18    dirs::config_dir()
19        .unwrap_or_else(|| PathBuf::from("."))
20        .join(APP_NAME)
21        .join("credentials.json")
22}
23
24pub fn state_path() -> PathBuf {
25    dirs::data_local_dir()
26        .unwrap_or_else(|| PathBuf::from("."))
27        .join(APP_NAME)
28        .join("state.json")
29}
30
31pub fn log_path() -> PathBuf {
32    dirs::data_local_dir()
33        .unwrap_or_else(|| PathBuf::from("."))
34        .join(APP_NAME)
35        .join("proxy.log")
36}
37
38pub fn pid_path() -> PathBuf {
39    dirs::data_local_dir()
40        .unwrap_or_else(|| PathBuf::from("."))
41        .join(APP_NAME)
42        .join("shunt.pid")
43}
44
45// ---------------------------------------------------------------------------
46// Credentials store  (separate file from config — never commit this)
47// ---------------------------------------------------------------------------
48
49#[derive(Debug, Default, Serialize, Deserialize)]
50pub struct CredentialsStore {
51    pub accounts: HashMap<String, OAuthCredential>,
52}
53
54impl CredentialsStore {
55    pub fn load() -> Self {
56        let p = credentials_path();
57        if !p.exists() {
58            return Self::default();
59        }
60        match std::fs::read_to_string(&p) {
61            Ok(text) => serde_json::from_str(&text).unwrap_or_default(),
62            Err(_) => Self::default(),
63        }
64    }
65
66    pub fn save(&self) -> Result<()> {
67        let p = credentials_path();
68        if let Some(parent) = p.parent() {
69            std::fs::create_dir_all(parent)?;
70        }
71        let tmp = p.with_extension("tmp");
72        std::fs::write(&tmp, serde_json::to_string_pretty(self)?)?;
73        std::fs::rename(&tmp, &p)?;
74        #[cfg(unix)]
75        {
76            use std::os::unix::fs::PermissionsExt;
77            std::fs::set_permissions(&p, std::fs::Permissions::from_mode(0o600))?;
78        }
79        Ok(())
80    }
81}
82
83// ---------------------------------------------------------------------------
84// Raw TOML config types
85// ---------------------------------------------------------------------------
86
87#[derive(Debug, Deserialize)]
88struct RawConfig {
89    #[serde(default)]
90    server: RawServer,
91    #[serde(default)]
92    accounts: Vec<RawAccount>,
93}
94
95#[derive(Debug, Deserialize)]
96struct RawServer {
97    #[serde(default = "default_host")]
98    host: String,
99    #[serde(default = "default_port")]
100    port: u16,
101    #[serde(default = "default_log_level")]
102    log_level: String,
103    upstream_url: Option<String>,
104    remote_key: Option<String>,
105}
106
107impl Default for RawServer {
108    fn default() -> Self {
109        Self {
110            host: default_host(),
111            port: default_port(),
112            log_level: default_log_level(),
113            upstream_url: None,
114            remote_key: None,
115        }
116    }
117}
118
119#[derive(Debug, Deserialize)]
120struct RawAccount {
121    name: String,
122    #[serde(default = "default_plan_type")]
123    plan_type: String,
124}
125
/// Default for `[server].host`.
fn default_host() -> String {
    String::from("127.0.0.1")
}

/// Default for `[server].port`.
fn default_port() -> u16 {
    8082
}

/// Default for `[server].log_level`.
fn default_log_level() -> String {
    String::from("info")
}

/// Default for `[[accounts]].plan_type`.
fn default_plan_type() -> String {
    String::from("pro")
}
130
131// ---------------------------------------------------------------------------
132// Resolved config types
133// ---------------------------------------------------------------------------
134
/// Resolved server settings with all defaults applied.
#[derive(Clone, Debug)]
pub struct ServerConfig {
    /// Server host, from `[server].host`.
    pub host: String,
    /// Server port, from `[server].port`.
    pub port: u16,
    /// Log level name, from `[server].log_level`.
    pub log_level: String,
    /// Upstream base URL requests are forwarded to.
    pub upstream_url: String,
    /// When set, remote requests must supply this value as `x-api-key`.
    pub remote_key: Option<String>,
}
144
145#[derive(Debug, Clone)]
146pub struct AccountConfig {
147    pub name: String,
148    pub plan_type: String,
149    /// `None` when the account is in config but has no credential yet.
150    /// These accounts are shown in status but skipped during proxying.
151    pub credential: Option<OAuthCredential>,
152}
153
154#[derive(Debug, Clone)]
155pub struct Config {
156    pub server: ServerConfig,
157    pub accounts: Vec<AccountConfig>,
158    pub config_file: PathBuf,
159}
160
161// ---------------------------------------------------------------------------
162// Loading
163// ---------------------------------------------------------------------------
164
165pub fn load_config(path: Option<&Path>) -> Result<Config> {
166    let p = path.map(PathBuf::from).unwrap_or_else(config_path);
167
168    if !p.exists() {
169        bail!(
170            "Config not found: {}\nRun `shunt setup` to get started.",
171            p.display()
172        );
173    }
174
175    let raw_text = std::fs::read_to_string(&p)
176        .with_context(|| format!("Failed to read config: {}", p.display()))?;
177
178    let raw: RawConfig = toml::from_str(&raw_text)
179        .with_context(|| format!("Failed to parse config: {}", p.display()))?;
180
181    let upstream_url = raw
182        .server
183        .upstream_url
184        .clone()
185        .or_else(|| std::env::var("SHUNT_UPSTREAM_URL").ok())
186        .unwrap_or_else(|| "https://api.anthropic.com".into());
187
188    let server = ServerConfig {
189        host: raw.server.host,
190        port: raw.server.port,
191        log_level: raw.server.log_level,
192        upstream_url,
193        remote_key: raw.server.remote_key,
194    };
195
196    if raw.accounts.is_empty() {
197        bail!("Config has no accounts. Run `shunt setup` to add one.");
198    }
199
200    let store = CredentialsStore::load();
201
202    let mut accounts = Vec::new();
203    for a in &raw.accounts {
204        // Resolve credential: env var → credentials store → auto-import primary
205        let cred = if a.name == "main" || store.accounts.is_empty() {
206            // Try the canonical Claude Code credentials as fallback for the primary account
207            store
208                .accounts
209                .get(&a.name)
210                .cloned()
211                .or_else(crate::oauth::read_claude_credentials)
212        } else {
213            store.accounts.get(&a.name).cloned()
214        };
215
216        accounts.push(AccountConfig {
217            name: a.name.clone(),
218            plan_type: a.plan_type.clone(),
219            credential: cred,
220        });
221    }
222
223    Ok(Config { server, accounts, config_file: p })
224}
225
226// ---------------------------------------------------------------------------
227// Config file template
228// ---------------------------------------------------------------------------
229
/// Renders a starter `config.toml`.
///
/// The output contains the default `[server]` settings followed by one
/// `[[accounts]]` table per `(name, plan_type)` pair, in order.
///
/// Names and plan types are emitted as escaped TOML basic strings. Values
/// containing quotes, backslashes or control characters therefore still yield
/// a valid TOML document. Ordinary values are rendered exactly as plain
/// `"value"`.
pub fn config_template(accounts: &[(&str, &str)]) -> String {
    /// Quotes `s` as a TOML basic string, escaping characters TOML forbids raw.
    fn toml_string(s: &str) -> String {
        let mut quoted = String::with_capacity(s.len() + 2);
        quoted.push('"');
        for c in s.chars() {
            match c {
                '"' => quoted.push_str("\\\""),
                '\\' => quoted.push_str("\\\\"),
                '\n' => quoted.push_str("\\n"),
                '\r' => quoted.push_str("\\r"),
                '\t' => quoted.push_str("\\t"),
                // Remaining control characters (incl. DEL) need a \uXXXX escape.
                c if c.is_control() => {
                    quoted.push_str(&format!("\\u{:04X}", u32::from(c)));
                }
                c => quoted.push(c),
            }
        }
        quoted.push('"');
        quoted
    }

    let mut out = String::from(
        "[server]\nhost = \"127.0.0.1\"\nport = 8082\nlog_level = \"info\"\n",
    );
    for (name, plan_type) in accounts {
        out.push_str(&format!(
            "\n[[accounts]]\nname = {}\nplan_type = {}\n",
            toml_string(name),
            toml_string(plan_type)
        ));
    }
    out
}