1use anyhow::{bail, Context, Result};
2use serde::{Deserialize, Serialize};
3use std::collections::HashMap;
4use std::path::{Path, PathBuf};
5
6use crate::credential::{deserialize_credential_map, Credential};
7use crate::provider::Provider;
8
9pub const APP_NAME: &str = "shunt";
10
11pub fn config_path() -> PathBuf {
12 dirs::config_dir()
13 .unwrap_or_else(|| PathBuf::from("."))
14 .join(APP_NAME)
15 .join("config.toml")
16}
17
18pub fn credentials_path() -> PathBuf {
19 dirs::config_dir()
20 .unwrap_or_else(|| PathBuf::from("."))
21 .join(APP_NAME)
22 .join("credentials.json")
23}
24
25pub fn state_path() -> PathBuf {
26 dirs::data_local_dir()
27 .unwrap_or_else(|| PathBuf::from("."))
28 .join(APP_NAME)
29 .join("state.json")
30}
31
32pub fn log_path() -> PathBuf {
33 dirs::data_local_dir()
34 .unwrap_or_else(|| PathBuf::from("."))
35 .join(APP_NAME)
36 .join("proxy.log")
37}
38
39pub fn pid_path() -> PathBuf {
40 dirs::data_local_dir()
41 .unwrap_or_else(|| PathBuf::from("."))
42 .join(APP_NAME)
43 .join("shunt.pid")
44}
45
46#[derive(Debug, Default, Serialize, Deserialize)]
51pub struct CredentialsStore {
52 #[serde(deserialize_with = "deserialize_credential_map", default)]
53 pub accounts: HashMap<String, Credential>,
54}
55
56impl CredentialsStore {
57 pub fn load() -> Self {
58 let p = credentials_path();
59 if !p.exists() {
60 return Self::default();
61 }
62 match std::fs::read_to_string(&p) {
63 Ok(text) => serde_json::from_str(&text).unwrap_or_default(),
64 Err(_) => Self::default(),
65 }
66 }
67
68 pub fn save(&self) -> Result<()> {
69 let p = credentials_path();
70 if let Some(parent) = p.parent() {
71 std::fs::create_dir_all(parent)?;
72 }
73 let tmp = p.with_extension("tmp");
74 std::fs::write(&tmp, serde_json::to_string_pretty(self)?)?;
75 std::fs::rename(&tmp, &p)?;
76 #[cfg(unix)]
77 {
78 use std::os::unix::fs::PermissionsExt;
79 std::fs::set_permissions(&p, std::fs::Permissions::from_mode(0o600))?;
80 }
81 #[cfg(windows)]
83 {
84 if let Some(path_str) = p.to_str() {
85 let username = std::env::var("USERNAME").unwrap_or_default();
86 if !username.is_empty() {
87 let _ = std::process::Command::new("icacls")
88 .arg(path_str)
89 .arg("/inheritance:r")
90 .arg("/grant:r")
91 .arg(format!("{username}:F"))
92 .status();
93 }
94 }
95 }
96 Ok(())
97 }
98}
99
/// Top-level shape of `config.toml` as deserialized by `toml`, before
/// defaults and environment overrides are resolved by [`load_config`].
#[derive(Debug, Deserialize)]
struct RawConfig {
    /// The `[server]` table; `RawServer::default()` when absent.
    #[serde(default)]
    server: RawServer,
    /// The `[[accounts]]` array; may deserialize empty, but `load_config` rejects that.
    #[serde(default)]
    accounts: Vec<RawAccount>,
    /// The `[model_mapping]` table of string-to-string entries; empty when absent.
    #[serde(default)]
    model_mapping: HashMap<String, String>,
}
115
116#[derive(Debug, Deserialize)]
117struct RawServer {
118 #[serde(default = "default_host")]
119 host: String,
120 #[serde(default = "default_port")]
121 port: u16,
122 #[serde(default = "default_control_port")]
123 control_port: u16,
124 #[serde(default = "default_log_level")]
125 log_level: String,
126 upstream_url: Option<String>,
127 remote_key: Option<String>,
128 relay_url: Option<String>,
129 pub custom_domain: Option<String>,
130 sticky_ttl_minutes: Option<u64>,
132 expiry_soon_minutes: Option<u64>,
134 request_timeout_secs: Option<u64>,
136}
137
138impl Default for RawServer {
139 fn default() -> Self {
140 Self {
141 host: default_host(),
142 port: default_port(),
143 control_port: default_control_port(),
144 log_level: default_log_level(),
145 upstream_url: None,
146 remote_key: None,
147 relay_url: None,
148 custom_domain: None,
149 sticky_ttl_minutes: None,
150 expiry_soon_minutes: None,
151 request_timeout_secs: None,
152 }
153 }
154}
155
/// One `[[accounts]]` entry of `config.toml`, as written by the user.
#[derive(Debug, Deserialize)]
struct RawAccount {
    /// Account name; `load_config` also uses it as the key into the credentials store.
    name: String,
    /// Plan type string; `"pro"` when absent.
    #[serde(default = "default_plan_type")]
    plan_type: String,
    /// Provider name, parsed with `Provider::from_str`; the default provider is used when absent.
    #[serde(default)]
    provider: Option<String>,
    /// Inline API key; consulted only when no stored credential exists for `name`.
    #[serde(default)]
    api_key: Option<String>,
    /// Name of an environment variable holding the API key; consulted after `api_key`.
    #[serde(default)]
    api_key_env: Option<String>,
    /// Per-account upstream URL override.
    #[serde(default)]
    upstream_url: Option<String>,
    /// Optional `model` setting, passed through unchanged to `AccountConfig::model`.
    #[serde(default)]
    model: Option<String>,
}
178
// Providers for the serde `default = "..."` attributes on `RawServer` and
// `RawAccount`; the server ones are also reused by `RawServer::default`.

fn default_host() -> String {
    String::from("127.0.0.1")
}

fn default_port() -> u16 {
    8082
}

fn default_control_port() -> u16 {
    19081
}

fn default_log_level() -> String {
    "info".to_owned()
}

fn default_plan_type() -> String {
    String::from("pro")
}
184
/// Fully resolved `[server]` settings, built by [`load_config`] (or [`Default`]).
#[derive(Debug, Clone)]
pub struct ServerConfig {
    /// Host string from `[server].host` (default `127.0.0.1`).
    pub host: String,
    /// From `[server].port` (default 8082).
    pub port: u16,
    /// From `[server].control_port` (default 19081).
    pub control_port: u16,
    /// From `[server].log_level` (default `"info"`).
    pub log_level: String,
    /// Global upstream base URL: `[server].upstream_url`, else `SHUNT_UPSTREAM_URL`,
    /// else the first account's provider default.
    pub upstream_url: String,
    /// `[server].remote_key`, if set.
    pub remote_key: Option<String>,
    /// `[server].relay_url`, else `SHUNT_RELAY_URL`, else the built-in relay URL.
    pub relay_url: String,
    /// `[server].custom_domain`, if set.
    pub custom_domain: Option<String>,
    /// In milliseconds (the config key `sticky_ttl_minutes` is in minutes; default 10 minutes).
    pub sticky_ttl_ms: u64,
    /// In seconds (the config key `expiry_soon_minutes` is in minutes; default 30 minutes).
    pub expiry_soon_secs: u64,
    /// In seconds (default 600).
    pub request_timeout_secs: u64,
}
210
211impl Default for ServerConfig {
212 fn default() -> Self {
213 Self {
214 host: "127.0.0.1".into(),
215 port: 8082,
216 control_port: 19081,
217 log_level: "info".into(),
218 upstream_url: "https://api.anthropic.com".into(),
219 remote_key: None,
220 relay_url: "https://relay.ramcharan.shop".into(),
221 custom_domain: None,
222 sticky_ttl_ms: 10 * 60 * 1000,
223 expiry_soon_secs: 30 * 60,
224 request_timeout_secs: 600,
225 }
226 }
227}
228
/// One resolved `[[accounts]]` entry produced by [`load_config`].
#[derive(Debug, Clone)]
pub struct AccountConfig {
    /// Account name; also the key used to look up stored credentials.
    pub name: String,
    /// Plan type string from config (`"pro"` when unset).
    pub plan_type: String,
    pub provider: Provider,
    /// First credential found, in order: stored credentials, inline `api_key`,
    /// the env var named by `api_key_env`, then the provider's local
    /// credentials. `None` when no source yields one.
    pub credential: Option<Credential>,
    /// Explicit `upstream_url`, else the provider's default when it differs from
    /// the first account's provider; otherwise `None`. NOTE(review): presumably
    /// callers then fall back to `ServerConfig::upstream_url` — confirm.
    pub upstream_url: Option<String>,
    /// The account's `model` setting, if any.
    pub model: Option<String>,
}
247
/// Complete configuration returned by [`load_config`].
#[derive(Debug, Clone)]
pub struct Config {
    pub server: ServerConfig,
    /// Resolved accounts in config-file order (never empty when built by `load_config`).
    pub accounts: Vec<AccountConfig>,
    /// Path of the config file this configuration was loaded from.
    pub config_file: PathBuf,
    /// The `[model_mapping]` table verbatim. NOTE(review): presumably maps
    /// requested model names to substitutes — confirm against the consumer.
    pub model_mapping: HashMap<String, String>,
}
257
258pub fn load_config(path: Option<&Path>) -> Result<Config> {
263 let p = path.map(PathBuf::from).unwrap_or_else(config_path);
264
265 if !p.exists() {
266 bail!(
267 "Config not found: {}\nRun `shunt setup` to get started.",
268 p.display()
269 );
270 }
271
272 let raw_text = std::fs::read_to_string(&p)
273 .with_context(|| format!("Failed to read config: {}", p.display()))?;
274
275 let raw: RawConfig = toml::from_str(&raw_text)
276 .with_context(|| format!("Failed to parse config: {}", p.display()))?;
277
278 let default_upstream = raw.accounts.first()
282 .map(|a| a.provider.as_deref().map(Provider::from_str).unwrap_or_default())
283 .unwrap_or_default()
284 .default_upstream_url()
285 .to_owned();
286
287 let upstream_url = raw
288 .server
289 .upstream_url
290 .clone()
291 .or_else(|| std::env::var("SHUNT_UPSTREAM_URL").ok())
292 .unwrap_or(default_upstream);
293
294 let relay_url = raw
295 .server
296 .relay_url
297 .clone()
298 .or_else(|| std::env::var("SHUNT_RELAY_URL").ok())
299 .unwrap_or_else(|| "https://relay.ramcharan.shop".into());
300
301 let server = ServerConfig {
302 host: raw.server.host,
303 port: raw.server.port,
304 control_port: raw.server.control_port,
305 log_level: raw.server.log_level,
306 upstream_url,
307 remote_key: raw.server.remote_key,
308 relay_url,
309 custom_domain: raw.server.custom_domain,
310 sticky_ttl_ms: raw.server.sticky_ttl_minutes.unwrap_or(10) * 60 * 1000,
311 expiry_soon_secs: raw.server.expiry_soon_minutes.unwrap_or(30) * 60,
312 request_timeout_secs: raw.server.request_timeout_secs.unwrap_or(600),
313 };
314
315 if raw.accounts.is_empty() {
316 bail!("Config has no accounts. Run `shunt setup` to add one.");
317 }
318
319 let store = CredentialsStore::load();
320
321 let primary_provider = raw.accounts.first()
324 .map(|a| a.provider.as_deref().map(Provider::from_str).unwrap_or_default())
325 .unwrap_or_default();
326
327 let mut accounts = Vec::new();
328 for a in &raw.accounts {
329 let provider = a.provider.as_deref().map(Provider::from_str).unwrap_or_default();
330
331 let cred: Option<Credential> = store.accounts.get(&a.name).cloned()
339 .or_else(|| {
340 a.api_key.as_deref().map(|k| Credential::Apikey { key: k.to_owned() })
342 })
343 .or_else(|| {
344 a.api_key_env.as_deref()
346 .and_then(|var| std::env::var(var).ok())
347 .map(|k| Credential::Apikey { key: k })
348 })
349 .or_else(|| {
350 provider.read_local_credentials()
353 });
354
355 let acct_upstream = a.upstream_url.clone().or_else(|| {
359 if provider != primary_provider {
360 Some(provider.default_upstream_url().to_owned())
361 } else {
362 None
363 }
364 });
365
366 accounts.push(AccountConfig {
367 name: a.name.clone(),
368 plan_type: a.plan_type.clone(),
369 provider,
370 credential: cred,
371 upstream_url: acct_upstream,
372 model: a.model.clone(),
373 });
374 }
375
376 Ok(Config { server, accounts, config_file: p, model_mapping: raw.model_mapping })
377}
378
/// Renders a starter `config.toml`.
///
/// The output is a `[server]` table with the built-in defaults, followed by one
/// `[[accounts]]` entry per `(name, plan_type)` pair in `accounts`.
///
/// Values are emitted as escaped TOML basic strings. Names or plan types
/// containing quotes, backslashes or control characters therefore still yield
/// a file that parses. Plain values render exactly as before.
pub fn config_template(accounts: &[(&str, &str)]) -> String {
    let mut out = String::from(
        "[server]\nhost = \"127.0.0.1\"\nport = 8082\ncontrol_port = 19081\nlog_level = \"info\"\n",
    );
    for (name, plan_type) in accounts {
        out.push_str(&format!(
            "\n[[accounts]]\nname = {}\nplan_type = {}\n",
            toml_basic_string(name),
            toml_basic_string(plan_type)
        ));
    }
    out
}

/// Quotes `s` as a TOML basic string (`"..."`).
///
/// Escapes `"`, `\` and control characters, as the TOML spec requires for basic
/// strings.
fn toml_basic_string(s: &str) -> String {
    let mut quoted = String::with_capacity(s.len() + 2);
    quoted.push('"');
    for c in s.chars() {
        match c {
            '"' => quoted.push_str("\\\""),
            '\\' => quoted.push_str("\\\\"),
            '\n' => quoted.push_str("\\n"),
            '\r' => quoted.push_str("\\r"),
            '\t' => quoted.push_str("\\t"),
            // Remaining control characters (including DEL) use the \uXXXX form.
            c if c.is_control() => quoted.push_str(&format!("\\u{:04X}", u32::from(c))),
            c => quoted.push(c),
        }
    }
    quoted.push('"');
    quoted
}
393}