1use anyhow::{bail, Context, Result};
2use serde::{Deserialize, Serialize};
3use std::collections::HashMap;
4use std::path::{Path, PathBuf};
5
6use crate::credential::{deserialize_credential_map, Credential};
7use crate::provider::Provider;
8
/// Application name; the path helpers in this module use it as the per-application
/// directory name under the platform config and data directories.
pub const APP_NAME: &str = "shunt";
10
11pub fn config_path() -> PathBuf {
12 dirs::config_dir()
13 .unwrap_or_else(|| PathBuf::from("."))
14 .join(APP_NAME)
15 .join("config.toml")
16}
17
18pub fn credentials_path() -> PathBuf {
19 dirs::config_dir()
20 .unwrap_or_else(|| PathBuf::from("."))
21 .join(APP_NAME)
22 .join("credentials.json")
23}
24
25pub fn state_path() -> PathBuf {
26 dirs::data_local_dir()
27 .unwrap_or_else(|| PathBuf::from("."))
28 .join(APP_NAME)
29 .join("state.json")
30}
31
32pub fn log_path() -> PathBuf {
33 dirs::data_local_dir()
34 .unwrap_or_else(|| PathBuf::from("."))
35 .join(APP_NAME)
36 .join("proxy.log")
37}
38
39pub fn pid_path() -> PathBuf {
40 dirs::data_local_dir()
41 .unwrap_or_else(|| PathBuf::from("."))
42 .join(APP_NAME)
43 .join("shunt.pid")
44}
45
46#[derive(Debug, Default, Serialize, Deserialize)]
51pub struct CredentialsStore {
52 #[serde(deserialize_with = "deserialize_credential_map", default)]
53 pub accounts: HashMap<String, Credential>,
54}
55
56impl CredentialsStore {
57 pub fn load() -> Self {
58 let p = credentials_path();
59 if !p.exists() {
60 return Self::default();
61 }
62 match std::fs::read_to_string(&p) {
63 Ok(text) => serde_json::from_str(&text).unwrap_or_default(),
64 Err(_) => Self::default(),
65 }
66 }
67
68 pub fn save(&self) -> Result<()> {
69 let p = credentials_path();
70 if let Some(parent) = p.parent() {
71 std::fs::create_dir_all(parent)?;
72 }
73 let tmp = p.with_extension("tmp");
74 std::fs::write(&tmp, serde_json::to_string_pretty(self)?)?;
75 std::fs::rename(&tmp, &p)?;
76 #[cfg(unix)]
77 {
78 use std::os::unix::fs::PermissionsExt;
79 std::fs::set_permissions(&p, std::fs::Permissions::from_mode(0o600))?;
80 }
81 #[cfg(windows)]
83 {
84 if let Some(path_str) = p.to_str() {
85 let username = std::env::var("USERNAME").unwrap_or_default();
86 if !username.is_empty() {
87 let _ = std::process::Command::new("icacls")
88 .arg(path_str)
89 .arg("/inheritance:r")
90 .arg("/grant:r")
91 .arg(format!("{username}:F"))
92 .status();
93 }
94 }
95 }
96 Ok(())
97 }
98}
99
100#[derive(Debug, Deserialize)]
105struct RawConfig {
106 #[serde(default)]
107 server: RawServer,
108 #[serde(default)]
109 accounts: Vec<RawAccount>,
110 #[serde(default)]
113 model_mapping: HashMap<String, String>,
114}
115
116#[derive(Debug, Deserialize)]
117struct RawServer {
118 #[serde(default = "default_host")]
119 host: String,
120 #[serde(default = "default_port")]
121 port: u16,
122 #[serde(default = "default_control_port")]
123 control_port: u16,
124 #[serde(default = "default_log_level")]
125 log_level: String,
126 upstream_url: Option<String>,
127 remote_key: Option<String>,
128 relay_url: Option<String>,
129 pub custom_domain: Option<String>,
130 sticky_ttl_minutes: Option<u64>,
132 expiry_soon_minutes: Option<u64>,
134 request_timeout_secs: Option<u64>,
136 telemetry_url: Option<String>,
139 telemetry_token: Option<String>,
141 instance_name: Option<String>,
144}
145
146impl Default for RawServer {
147 fn default() -> Self {
148 Self {
149 host: default_host(),
150 port: default_port(),
151 control_port: default_control_port(),
152 log_level: default_log_level(),
153 upstream_url: None,
154 remote_key: None,
155 relay_url: None,
156 custom_domain: None,
157 sticky_ttl_minutes: None,
158 expiry_soon_minutes: None,
159 request_timeout_secs: None,
160 telemetry_url: None,
161 telemetry_token: None,
162 instance_name: None,
163 }
164 }
165}
166
167#[derive(Debug, Deserialize)]
168struct RawAccount {
169 name: String,
170 #[serde(default = "default_plan_type")]
171 plan_type: String,
172 #[serde(default)]
174 provider: Option<String>,
175 #[serde(default)]
177 api_key: Option<String>,
178 #[serde(default)]
180 api_key_env: Option<String>,
181 #[serde(default)]
183 upstream_url: Option<String>,
184 #[serde(default)]
187 model: Option<String>,
188}
189
/// Value used for `server.host` when the config file omits it.
fn default_host() -> String {
    String::from("127.0.0.1")
}
191
192pub fn default_instance_name() -> String {
193 hostname::get()
194 .ok()
195 .and_then(|h| h.into_string().ok())
196 .unwrap_or_else(|| "shunt".into())
197}
/// Value used for `server.port` when the config file omits it.
fn default_port() -> u16 {
    8082
}
/// Value used for `server.control_port` when the config file omits it.
fn default_control_port() -> u16 {
    19081
}
/// Value used for `server.log_level` when the config file omits it.
fn default_log_level() -> String {
    String::from("info")
}
/// Value used for an account's `plan_type` when the config file omits it.
fn default_plan_type() -> String {
    String::from("pro")
}
202
/// Fully resolved server settings, produced by `load_config` from the `[server]`
/// table plus environment fallbacks.
#[derive(Clone, Debug)]
pub struct ServerConfig {
    /// `server.host` from the config file (or its default).
    pub host: String,
    /// `server.port` from the config file (or its default).
    pub port: u16,
    /// `server.control_port` from the config file (or its default).
    pub control_port: u16,
    /// `server.log_level` from the config file (or its default).
    pub log_level: String,
    /// Upstream base URL: config value, else `SHUNT_UPSTREAM_URL`, else the
    /// default upstream of the first account's provider.
    pub upstream_url: String,
    /// `server.remote_key` from the config file, if set.
    pub remote_key: Option<String>,
    /// Relay URL: config value, else `SHUNT_RELAY_URL`, else a built-in URL.
    pub relay_url: String,
    /// `server.custom_domain` from the config file, if set.
    pub custom_domain: Option<String>,
    /// `server.sticky_ttl_minutes` expressed in milliseconds (10 minutes when unset).
    pub sticky_ttl_ms: u64,
    /// `server.expiry_soon_minutes` expressed in seconds (30 minutes when unset).
    pub expiry_soon_secs: u64,
    /// `server.request_timeout_secs` (600 when unset).
    pub request_timeout_secs: u64,
    /// Config value, else `SHUNT_TELEMETRY_URL`, else `None`.
    pub telemetry_url: Option<String>,
    /// Config value, else `SHUNT_TELEMETRY_TOKEN`, else `None`.
    pub telemetry_token: Option<String>,
    /// Config value, else `SHUNT_INSTANCE_NAME`, else `default_instance_name()`.
    pub instance_name: String,
}
234
235impl Default for ServerConfig {
236 fn default() -> Self {
237 Self {
238 host: "127.0.0.1".into(),
239 port: 8082,
240 control_port: 19081,
241 log_level: "info".into(),
242 upstream_url: "https://api.anthropic.com".into(),
243 remote_key: None,
244 relay_url: "https://relay.ramcharan.shop".into(),
245 custom_domain: None,
246 sticky_ttl_ms: 10 * 60 * 1000,
247 expiry_soon_secs: 30 * 60,
248 request_timeout_secs: 600,
249 telemetry_url: None,
250 telemetry_token: None,
251 instance_name: default_instance_name(),
252 }
253 }
254}
255
256#[derive(Debug, Clone)]
257pub struct AccountConfig {
258 pub name: String,
259 pub plan_type: String,
260 pub provider: Provider,
261 pub credential: Option<Credential>,
266 pub upstream_url: Option<String>,
270 pub model: Option<String>,
273}
274
275#[derive(Debug, Clone)]
276pub struct Config {
277 pub server: ServerConfig,
278 pub accounts: Vec<AccountConfig>,
279 pub config_file: PathBuf,
280 pub model_mapping: HashMap<String, String>,
283}
284
285pub fn load_config(path: Option<&Path>) -> Result<Config> {
290 let p = path.map(PathBuf::from).unwrap_or_else(config_path);
291
292 if !p.exists() {
293 bail!(
294 "Config not found: {}\nRun `shunt setup` to get started.",
295 p.display()
296 );
297 }
298
299 let raw_text = std::fs::read_to_string(&p)
300 .with_context(|| format!("Failed to read config: {}", p.display()))?;
301
302 let raw: RawConfig = toml::from_str(&raw_text)
303 .with_context(|| format!("Failed to parse config: {}", p.display()))?;
304
305 let default_upstream = raw.accounts.first()
309 .map(|a| a.provider.as_deref().map(Provider::from_str).unwrap_or_default())
310 .unwrap_or_default()
311 .default_upstream_url()
312 .to_owned();
313
314 let upstream_url = raw
315 .server
316 .upstream_url
317 .clone()
318 .or_else(|| std::env::var("SHUNT_UPSTREAM_URL").ok())
319 .unwrap_or(default_upstream);
320
321 let relay_url = raw
322 .server
323 .relay_url
324 .clone()
325 .or_else(|| std::env::var("SHUNT_RELAY_URL").ok())
326 .unwrap_or_else(|| "https://relay.ramcharan.shop".into());
327
328 let telemetry_url = raw.server.telemetry_url.clone()
329 .or_else(|| std::env::var("SHUNT_TELEMETRY_URL").ok());
330 let telemetry_token = raw.server.telemetry_token.clone()
331 .or_else(|| std::env::var("SHUNT_TELEMETRY_TOKEN").ok());
332 let instance_name = raw.server.instance_name.clone()
333 .or_else(|| std::env::var("SHUNT_INSTANCE_NAME").ok())
334 .unwrap_or_else(default_instance_name);
335
336 let server = ServerConfig {
337 host: raw.server.host,
338 port: raw.server.port,
339 control_port: raw.server.control_port,
340 log_level: raw.server.log_level,
341 upstream_url,
342 remote_key: raw.server.remote_key,
343 relay_url,
344 custom_domain: raw.server.custom_domain,
345 sticky_ttl_ms: raw.server.sticky_ttl_minutes.unwrap_or(10) * 60 * 1000,
346 expiry_soon_secs: raw.server.expiry_soon_minutes.unwrap_or(30) * 60,
347 request_timeout_secs: raw.server.request_timeout_secs.unwrap_or(600),
348 telemetry_url,
349 telemetry_token,
350 instance_name,
351 };
352
353 if raw.accounts.is_empty() {
354 bail!("Config has no accounts. Run `shunt setup` to add one.");
355 }
356
357 let store = CredentialsStore::load();
358
359 let primary_provider = raw.accounts.first()
362 .map(|a| a.provider.as_deref().map(Provider::from_str).unwrap_or_default())
363 .unwrap_or_default();
364
365 let mut accounts = Vec::new();
366 for a in &raw.accounts {
367 let provider = a.provider.as_deref().map(Provider::from_str).unwrap_or_default();
368
369 let cred: Option<Credential> = store.accounts.get(&a.name).cloned()
377 .or_else(|| {
378 a.api_key.as_deref().map(|k| Credential::Apikey { key: k.to_owned() })
380 })
381 .or_else(|| {
382 a.api_key_env.as_deref()
384 .and_then(|var| std::env::var(var).ok())
385 .map(|k| Credential::Apikey { key: k })
386 })
387 .or_else(|| {
388 provider.read_local_credentials()
391 });
392
393 let acct_upstream = a.upstream_url.clone().or_else(|| {
397 if provider != primary_provider {
398 Some(provider.default_upstream_url().to_owned())
399 } else {
400 None
401 }
402 });
403
404 accounts.push(AccountConfig {
405 name: a.name.clone(),
406 plan_type: a.plan_type.clone(),
407 provider,
408 credential: cred,
409 upstream_url: acct_upstream,
410 model: a.model.clone(),
411 });
412 }
413
414 Ok(Config { server, accounts, config_file: p, model_mapping: raw.model_mapping })
415}
416
/// Renders a starter config file.
///
/// The output begins with a fixed `[server]` table and is followed by one
/// `[[accounts]]` entry per `(name, plan_type)` pair, in the given order. Both
/// values are escaped as TOML basic strings, so names containing quotes,
/// backslashes, or control characters still yield a parseable file and cannot
/// inject extra keys.
pub fn config_template(accounts: &[(&str, &str)]) -> String {
    let mut out = String::from(
        "[server]\nhost = \"127.0.0.1\"\nport = 8082\ncontrol_port = 19081\nlog_level = \"info\"\n",
    );
    for (name, plan_type) in accounts {
        out.push_str("\n[[accounts]]\nname = \"");
        push_toml_escaped(&mut out, name);
        out.push_str("\"\nplan_type = \"");
        push_toml_escaped(&mut out, plan_type);
        out.push_str("\"\n");
    }
    out
}

/// Appends `raw` to `out`, escaped for use inside a TOML basic (double-quoted) string.
fn push_toml_escaped(out: &mut String, raw: &str) {
    use std::fmt::Write;

    for ch in raw.chars() {
        match ch {
            '"' => out.push_str("\\\""),
            '\\' => out.push_str("\\\\"),
            '\u{0008}' => out.push_str("\\b"),
            '\t' => out.push_str("\\t"),
            '\n' => out.push_str("\\n"),
            '\u{000C}' => out.push_str("\\f"),
            '\r' => out.push_str("\\r"),
            c if c.is_control() => {
                // Remaining control characters have no short escape; writing to a
                // `String` cannot fail, so the result is safely ignored.
                let _ = write!(out, "\\u{:04X}", c as u32);
            }
            c => out.push(c),
        }
    }
}