Skip to main content

shunt/
config.rs

1use anyhow::{bail, Context, Result};
2use serde::{Deserialize, Serialize};
3use std::collections::HashMap;
4use std::path::{Path, PathBuf};
5
6use crate::credential::{deserialize_credential_map, Credential};
7use crate::oauth::OAuthCredential;
8use crate::provider::Provider;
9
10pub const APP_NAME: &str = "shunt";
11
12pub fn config_path() -> PathBuf {
13    dirs::config_dir()
14        .unwrap_or_else(|| PathBuf::from("."))
15        .join(APP_NAME)
16        .join("config.toml")
17}
18
19pub fn credentials_path() -> PathBuf {
20    dirs::config_dir()
21        .unwrap_or_else(|| PathBuf::from("."))
22        .join(APP_NAME)
23        .join("credentials.json")
24}
25
26pub fn state_path() -> PathBuf {
27    dirs::data_local_dir()
28        .unwrap_or_else(|| PathBuf::from("."))
29        .join(APP_NAME)
30        .join("state.json")
31}
32
33pub fn log_path() -> PathBuf {
34    dirs::data_local_dir()
35        .unwrap_or_else(|| PathBuf::from("."))
36        .join(APP_NAME)
37        .join("proxy.log")
38}
39
40pub fn pid_path() -> PathBuf {
41    dirs::data_local_dir()
42        .unwrap_or_else(|| PathBuf::from("."))
43        .join(APP_NAME)
44        .join("shunt.pid")
45}
46
47// ---------------------------------------------------------------------------
48// Credentials store  (separate file from config — never commit this)
49// ---------------------------------------------------------------------------
50
51#[derive(Debug, Default, Serialize, Deserialize)]
52pub struct CredentialsStore {
53    #[serde(deserialize_with = "deserialize_credential_map", default)]
54    pub accounts: HashMap<String, Credential>,
55}
56
57impl CredentialsStore {
58    pub fn load() -> Self {
59        let p = credentials_path();
60        if !p.exists() {
61            return Self::default();
62        }
63        match std::fs::read_to_string(&p) {
64            Ok(text) => serde_json::from_str(&text).unwrap_or_default(),
65            Err(_) => Self::default(),
66        }
67    }
68
69    pub fn save(&self) -> Result<()> {
70        let p = credentials_path();
71        if let Some(parent) = p.parent() {
72            std::fs::create_dir_all(parent)?;
73        }
74        let tmp = p.with_extension("tmp");
75        std::fs::write(&tmp, serde_json::to_string_pretty(self)?)?;
76        std::fs::rename(&tmp, &p)?;
77        #[cfg(unix)]
78        {
79            use std::os::unix::fs::PermissionsExt;
80            std::fs::set_permissions(&p, std::fs::Permissions::from_mode(0o600))?;
81        }
82        // On Windows, restrict the file to the current user via icacls (best-effort).
83        #[cfg(windows)]
84        {
85            if let Some(path_str) = p.to_str() {
86                let username = std::env::var("USERNAME").unwrap_or_default();
87                if !username.is_empty() {
88                    let _ = std::process::Command::new("icacls")
89                        .arg(path_str)
90                        .arg("/inheritance:r")
91                        .arg("/grant:r")
92                        .arg(format!("{username}:F"))
93                        .status();
94                }
95            }
96        }
97        Ok(())
98    }
99}
100
101// ---------------------------------------------------------------------------
102// Raw TOML config types
103// ---------------------------------------------------------------------------
104
105#[derive(Debug, Deserialize)]
106struct RawConfig {
107    #[serde(default)]
108    server: RawServer,
109    #[serde(default)]
110    accounts: Vec<RawAccount>,
111}
112
113#[derive(Debug, Deserialize)]
114struct RawServer {
115    #[serde(default = "default_host")]
116    host: String,
117    #[serde(default = "default_port")]
118    port: u16,
119    #[serde(default = "default_control_port")]
120    control_port: u16,
121    #[serde(default = "default_log_level")]
122    log_level: String,
123    upstream_url: Option<String>,
124    remote_key: Option<String>,
125    relay_url: Option<String>,
126    pub custom_domain: Option<String>,
127    /// Conversation stickiness TTL in minutes (default: 10)
128    sticky_ttl_minutes: Option<u64>,
129    /// "use-it-or-lose-it" expiry window in minutes (default: 30)
130    expiry_soon_minutes: Option<u64>,
131    /// Upstream request timeout in seconds (default: 600)
132    request_timeout_secs: Option<u64>,
133}
134
135impl Default for RawServer {
136    fn default() -> Self {
137        Self {
138            host: default_host(),
139            port: default_port(),
140            control_port: default_control_port(),
141            log_level: default_log_level(),
142            upstream_url: None,
143            remote_key: None,
144            relay_url: None,
145            custom_domain: None,
146            sticky_ttl_minutes: None,
147            expiry_soon_minutes: None,
148            request_timeout_secs: None,
149        }
150    }
151}
152
153#[derive(Debug, Deserialize)]
154struct RawAccount {
155    name: String,
156    #[serde(default = "default_plan_type")]
157    plan_type: String,
158    /// "anthropic" (default) | "openai" / "codex" | "groq" | "mistral" | "local" | …
159    #[serde(default)]
160    provider: Option<String>,
161    /// Inline API key (use api_key_env for better security).
162    #[serde(default)]
163    api_key: Option<String>,
164    /// Name of an environment variable that holds the API key.
165    #[serde(default)]
166    api_key_env: Option<String>,
167    /// Per-account upstream URL override (required for Local provider).
168    #[serde(default)]
169    upstream_url: Option<String>,
170}
171
/// Serde default for `[server].host`: the IPv4 loopback address.
fn default_host() -> String {
    String::from("127.0.0.1")
}

/// Serde default for `[server].port`.
fn default_port() -> u16 {
    8082
}

/// Serde default for `[server].control_port`.
fn default_control_port() -> u16 {
    19081
}

/// Serde default for `[server].log_level`.
fn default_log_level() -> String {
    String::from("info")
}

/// Serde default for `[[accounts]].plan_type`.
fn default_plan_type() -> String {
    String::from("pro")
}
177
178// ---------------------------------------------------------------------------
179// Resolved config types
180// ---------------------------------------------------------------------------
181
/// Fully resolved `[server]` settings: defaults applied and minute-based
/// durations converted to the units documented on each field.
#[derive(Debug, Clone)]
pub struct ServerConfig {
    /// Host address from `[server].host`.
    pub host: String,
    /// Main port from `[server].port`.
    pub port: u16,
    /// Port for the control plane (`/status`, `/use`, `/health`); it sees every account.
    pub control_port: u16,
    /// Log level string, e.g. `"info"`.
    pub log_level: String,
    /// Upstream base URL used by accounts of the primary (first-listed) provider.
    pub upstream_url: String,
    /// If set, remote requests must present this value as `x-api-key`.
    pub remote_key: Option<String>,
    /// Relay used by `shunt push` / `shunt login`. A value in the config file
    /// wins; `SHUNT_RELAY_URL` is only consulted when the file leaves it unset.
    pub relay_url: String,
    /// Custom domain for permanent online sharing (e.g. https://shunt.mysite.com).
    pub custom_domain: Option<String>,
    /// How long a conversation stays pinned to one account, in milliseconds.
    pub sticky_ttl_ms: u64,
    /// Accounts whose 5h window resets within this many seconds are preferred
    /// ("use-it-or-lose-it").
    pub expiry_soon_secs: u64,
    /// Upstream request timeout, in seconds.
    pub request_timeout_secs: u64,
}
203
204impl Default for ServerConfig {
205    fn default() -> Self {
206        Self {
207            host: "127.0.0.1".into(),
208            port: 8082,
209            control_port: 19081,
210            log_level: "info".into(),
211            upstream_url: "https://api.anthropic.com".into(),
212            remote_key: None,
213            relay_url: "https://relay.ramcharan.shop".into(),
214            custom_domain: None,
215            sticky_ttl_ms: 10 * 60 * 1000,
216            expiry_soon_secs: 30 * 60,
217            request_timeout_secs: 600,
218        }
219    }
220}
221
222#[derive(Debug, Clone)]
223pub struct AccountConfig {
224    pub name: String,
225    pub plan_type: String,
226    pub provider: Provider,
227    /// `None` when the account has no credential.
228    /// OAuth accounts: None means reauth required (shown as auth_failed).
229    /// ApiKey accounts: None means key not yet configured.
230    /// Local accounts: None is normal (no auth required).
231    pub credential: Option<Credential>,
232    /// Override the upstream base URL for this account.
233    /// `None` means use `config.server.upstream_url` (primary provider) or
234    /// `provider.default_upstream_url()` (non-primary provider).
235    pub upstream_url: Option<String>,
236}
237
238#[derive(Debug, Clone)]
239pub struct Config {
240    pub server: ServerConfig,
241    pub accounts: Vec<AccountConfig>,
242    pub config_file: PathBuf,
243}
244
245// ---------------------------------------------------------------------------
246// Loading
247// ---------------------------------------------------------------------------
248
249pub fn load_config(path: Option<&Path>) -> Result<Config> {
250    let p = path.map(PathBuf::from).unwrap_or_else(config_path);
251
252    if !p.exists() {
253        bail!(
254            "Config not found: {}\nRun `shunt setup` to get started.",
255            p.display()
256        );
257    }
258
259    let raw_text = std::fs::read_to_string(&p)
260        .with_context(|| format!("Failed to read config: {}", p.display()))?;
261
262    let raw: RawConfig = toml::from_str(&raw_text)
263        .with_context(|| format!("Failed to parse config: {}", p.display()))?;
264
265    // Derive the default upstream URL from the first account's provider so that
266    // an all-OpenAI config automatically points at api.openai.com without any
267    // explicit `upstream_url` in the config file.
268    let default_upstream = raw.accounts.first()
269        .map(|a| a.provider.as_deref().map(Provider::from_str).unwrap_or_default())
270        .unwrap_or_default()
271        .default_upstream_url()
272        .to_owned();
273
274    let upstream_url = raw
275        .server
276        .upstream_url
277        .clone()
278        .or_else(|| std::env::var("SHUNT_UPSTREAM_URL").ok())
279        .unwrap_or(default_upstream);
280
281    let relay_url = raw
282        .server
283        .relay_url
284        .clone()
285        .or_else(|| std::env::var("SHUNT_RELAY_URL").ok())
286        .unwrap_or_else(|| "https://relay.ramcharan.shop".into());
287
288    let server = ServerConfig {
289        host: raw.server.host,
290        port: raw.server.port,
291        control_port: raw.server.control_port,
292        log_level: raw.server.log_level,
293        upstream_url,
294        remote_key: raw.server.remote_key,
295        relay_url,
296        custom_domain: raw.server.custom_domain,
297        sticky_ttl_ms: raw.server.sticky_ttl_minutes.unwrap_or(10) * 60 * 1000,
298        expiry_soon_secs: raw.server.expiry_soon_minutes.unwrap_or(30) * 60,
299        request_timeout_secs: raw.server.request_timeout_secs.unwrap_or(600),
300    };
301
302    if raw.accounts.is_empty() {
303        bail!("Config has no accounts. Run `shunt setup` to add one.");
304    }
305
306    let store = CredentialsStore::load();
307
308    // Determine the primary provider (first account) so we know which accounts
309    // use config.server.upstream_url and which need the provider's default URL.
310    let primary_provider = raw.accounts.first()
311        .map(|a| a.provider.as_deref().map(Provider::from_str).unwrap_or_default())
312        .unwrap_or_default();
313
314    let mut accounts = Vec::new();
315    for a in &raw.accounts {
316        let provider = a.provider.as_deref().map(Provider::from_str).unwrap_or_default();
317
318        // Resolve credential.
319        //
320        // OAuth providers (Anthropic, OpenAI): credentials.json first, then
321        // auto-import from the provider's local CLI tool.
322        //
323        // API-key providers: credentials.json first, then inline api_key field,
324        // then api_key_env field, then the provider's well-known env var.
325        let cred: Option<Credential> = store.accounts.get(&a.name).cloned()
326            .or_else(|| {
327                // Inline api_key from TOML (less secure, but convenient for testing).
328                a.api_key.as_deref().map(|k| Credential::Apikey { key: k.to_owned() })
329            })
330            .or_else(|| {
331                // api_key_env: name of env var holding the key.
332                a.api_key_env.as_deref()
333                    .and_then(|var| std::env::var(var).ok())
334                    .map(|k| Credential::Apikey { key: k })
335            })
336            .or_else(|| {
337                // Auto-import from provider's CLI tool (OAuth providers) or
338                // well-known env var (API-key providers).
339                provider.read_local_credentials()
340            });
341
342        // Upstream URL: per-account override from TOML takes priority, then
343        // non-primary-provider accounts get the provider's default URL so
344        // the forwarder knows where to send requests.
345        let acct_upstream = a.upstream_url.clone().or_else(|| {
346            if provider != primary_provider {
347                Some(provider.default_upstream_url().to_owned())
348            } else {
349                None
350            }
351        });
352
353        accounts.push(AccountConfig {
354            name: a.name.clone(),
355            plan_type: a.plan_type.clone(),
356            provider,
357            credential: cred,
358            upstream_url: acct_upstream,
359        });
360    }
361
362    Ok(Config { server, accounts, config_file: p })
363}
364
365// ---------------------------------------------------------------------------
366// Config file template
367// ---------------------------------------------------------------------------
368
/// Render a starter `config.toml`.
///
/// The output is a default `[server]` table followed by one `[[accounts]]`
/// entry per `(name, plan_type)` pair. Both values are emitted as escaped
/// TOML basic strings, so names containing quotes, backslashes or control
/// characters still produce a file that parses.
pub fn config_template(accounts: &[(&str, &str)]) -> String {
    /// Quote `s` as a TOML basic string (`"..."`), escaping quotes,
    /// backslashes and control characters.
    fn toml_quote(s: &str) -> String {
        let mut quoted = String::with_capacity(s.len() + 2);
        quoted.push('"');
        for c in s.chars() {
            match c {
                '"' => quoted.push_str("\\\""),
                '\\' => quoted.push_str("\\\\"),
                '\n' => quoted.push_str("\\n"),
                '\r' => quoted.push_str("\\r"),
                '\t' => quoted.push_str("\\t"),
                c if c.is_control() => quoted.push_str(&format!("\\u{:04X}", u32::from(c))),
                c => quoted.push(c),
            }
        }
        quoted.push('"');
        quoted
    }

    let mut out = String::from(
        "[server]\nhost = \"127.0.0.1\"\nport = 8082\ncontrol_port = 19081\nlog_level = \"info\"\n",
    );
    for (name, plan_type) in accounts {
        out.push_str("\n[[accounts]]\nname = ");
        out.push_str(&toml_quote(name));
        out.push_str("\nplan_type = ");
        out.push_str(&toml_quote(plan_type));
        out.push('\n');
    }
    out
}