1use anyhow::{bail, Context, Result};
2
3fn validate_upstream_url(url: &str, allow_loopback: bool) -> Result<()> {
7 let parsed = url::Url::parse(url)
8 .with_context(|| format!("Invalid upstream URL: {url}"))?;
9 match parsed.scheme() {
10 "http" | "https" => {}
11 s => bail!("Upstream URL must use http or https, got scheme '{s}': {url}"),
12 }
13 if !allow_loopback {
14 if let Some(host) = parsed.host_str() {
15 let blocked = matches!(host, "localhost" | "127.0.0.1" | "::1" | "[::1]")
16 || host.starts_with("169.254.")
17 || host.starts_with("fd");
18 if blocked {
19 bail!("Upstream URL must not point to loopback or link-local addresses: {url}");
20 }
21 }
22 }
23 Ok(())
24}
25use serde::{Deserialize, Serialize};
26use std::collections::HashMap;
27use std::path::{Path, PathBuf};
28
29use crate::credential::{deserialize_credential_map, Credential};
30use crate::provider::Provider;
31
32pub const APP_NAME: &str = "shunt";
33
34pub fn config_path() -> PathBuf {
35 dirs::config_dir()
36 .unwrap_or_else(|| PathBuf::from("."))
37 .join(APP_NAME)
38 .join("config.toml")
39}
40
41pub fn credentials_path() -> PathBuf {
42 dirs::config_dir()
43 .unwrap_or_else(|| PathBuf::from("."))
44 .join(APP_NAME)
45 .join("credentials.json")
46}
47
48pub fn state_path() -> PathBuf {
49 dirs::data_local_dir()
50 .unwrap_or_else(|| PathBuf::from("."))
51 .join(APP_NAME)
52 .join("state.json")
53}
54
55pub fn log_path() -> PathBuf {
56 dirs::data_local_dir()
57 .unwrap_or_else(|| PathBuf::from("."))
58 .join(APP_NAME)
59 .join("proxy.log")
60}
61
62pub fn notify_log_path() -> PathBuf {
63 dirs::data_local_dir()
64 .unwrap_or_else(|| PathBuf::from("."))
65 .join(APP_NAME)
66 .join("notify.log")
67}
68
69pub fn pid_path() -> PathBuf {
70 dirs::data_local_dir()
71 .unwrap_or_else(|| PathBuf::from("."))
72 .join(APP_NAME)
73 .join("shunt.pid")
74}
75
76#[derive(Debug, Default, Serialize, Deserialize)]
81pub struct CredentialsStore {
82 #[serde(deserialize_with = "deserialize_credential_map", default)]
83 pub accounts: HashMap<String, Credential>,
84}
85
86impl CredentialsStore {
87 pub fn load() -> Self {
88 let p = credentials_path();
89 if !p.exists() {
90 return Self::default();
91 }
92 match std::fs::read_to_string(&p) {
93 Ok(text) => serde_json::from_str(&text).unwrap_or_default(),
94 Err(_) => Self::default(),
95 }
96 }
97
98 pub fn save(&self) -> Result<()> {
99 let p = credentials_path();
100 if let Some(parent) = p.parent() {
101 std::fs::create_dir_all(parent)?;
102 }
103 let tmp = p.with_extension("tmp");
104 std::fs::write(&tmp, serde_json::to_string_pretty(self)?)?;
105 #[cfg(unix)]
106 {
107 use std::os::unix::fs::PermissionsExt;
108 std::fs::set_permissions(&tmp, std::fs::Permissions::from_mode(0o600))?;
109 }
110 std::fs::rename(&tmp, &p)?;
111 #[cfg(windows)]
113 {
114 if let Some(path_str) = p.to_str() {
115 let username = std::env::var("USERNAME").unwrap_or_default();
116 if !username.is_empty() {
117 let _ = std::process::Command::new("icacls")
118 .arg(path_str)
119 .arg("/inheritance:r")
120 .arg("/grant:r")
121 .arg(format!("{username}:F"))
122 .status();
123 }
124 }
125 }
126 Ok(())
127 }
128}
129
/// On-disk shape of `config.toml` as deserialized by `toml`; every top-level table is optional.
#[derive(Debug, Deserialize)]
struct RawConfig {
    /// `[server]` table; `RawServer::default()` when the table is absent.
    #[serde(default)]
    server: RawServer,
    /// `[[accounts]]` entries in file order; `load_config` takes the first one's provider
    /// as the primary provider.
    #[serde(default)]
    accounts: Vec<RawAccount>,
    /// `[model_mapping]` table; `load_config` copies it unchanged into `Config::model_mapping`.
    #[serde(default)]
    model_mapping: HashMap<String, String>,
}
145
146#[derive(Debug, Deserialize)]
147struct RawServer {
148 #[serde(default = "default_host")]
149 host: String,
150 #[serde(default = "default_port")]
151 port: u16,
152 #[serde(default = "default_control_port")]
153 control_port: u16,
154 #[serde(default = "default_log_level")]
155 log_level: String,
156 upstream_url: Option<String>,
157 remote_key: Option<String>,
158 relay_url: Option<String>,
159 pub custom_domain: Option<String>,
160 sticky_ttl_minutes: Option<u64>,
162 expiry_soon_minutes: Option<u64>,
164 routing_strategy: Option<String>,
166 request_timeout_secs: Option<u64>,
168 rate_limit_rpm: Option<u32>,
170 trust_proxy_headers: Option<bool>,
174 telemetry_url: Option<String>,
177 telemetry_token: Option<String>,
179 instance_name: Option<String>,
182}
183
184impl Default for RawServer {
185 fn default() -> Self {
186 Self {
187 host: default_host(),
188 port: default_port(),
189 control_port: default_control_port(),
190 log_level: default_log_level(),
191 upstream_url: None,
192 remote_key: None,
193 relay_url: None,
194 custom_domain: None,
195 sticky_ttl_minutes: None,
196 expiry_soon_minutes: None,
197 routing_strategy: None,
198 request_timeout_secs: None,
199 rate_limit_rpm: None,
200 trust_proxy_headers: None,
201 telemetry_url: None,
202 telemetry_token: None,
203 instance_name: None,
204 }
205 }
206}
207
/// One `[[accounts]]` entry as written in `config.toml`.
#[derive(Debug, Deserialize)]
struct RawAccount {
    /// Account name; also the key `load_config` looks up in the credentials store.
    name: String,
    /// Plan label; `"pro"` when omitted.
    #[serde(default = "default_plan_type")]
    plan_type: String,
    /// Provider identifier passed to `Provider::from_str`; `Provider::default()` when omitted.
    #[serde(default)]
    provider: Option<String>,
    /// Inline API key; `load_config` logs a warning when it is used.
    #[serde(default)]
    api_key: Option<String>,
    /// Name of an environment variable that holds the API key.
    #[serde(default)]
    api_key_env: Option<String>,
    /// Per-account upstream URL; validated by `load_config`.
    #[serde(default)]
    upstream_url: Option<String>,
    /// Copied unchanged into `AccountConfig::model`.
    #[serde(default)]
    model: Option<String>,
}
230
/// Serde default for `server.host`.
fn default_host() -> String {
    String::from("127.0.0.1")
}
232
233pub fn default_instance_name() -> String {
234 hostname::get()
235 .ok()
236 .and_then(|h| h.into_string().ok())
237 .unwrap_or_else(|| "shunt".into())
238}
/// Serde default for `server.port`.
fn default_port() -> u16 {
    8082
}

/// Serde default for `server.control_port`.
fn default_control_port() -> u16 {
    19081
}

/// Serde default for `server.log_level`.
fn default_log_level() -> String {
    String::from("info")
}

/// Serde default for `[[accounts]].plan_type`.
fn default_plan_type() -> String {
    String::from("pro")
}
243
/// Account-routing strategy selected by `server.routing_strategy`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum RoutingStrategy {
    /// Config name `reaper` (aliases `earliest-expiry`, `earliest_expiry`).
    Reaper,
    /// Config name `carousel` (aliases `round-robin`, `round_robin`).
    Carousel,
    /// Config name `cushion` (aliases `most-available`, `most_available`).
    Cushion,
    /// Config name `maximus`; the default.
    #[default]
    Maximus,
}

impl RoutingStrategy {
    /// Canonical config name of this strategy; accepted back by [`RoutingStrategy::from_str`].
    pub const fn as_str(&self) -> &'static str {
        match self {
            Self::Reaper => "reaper",
            Self::Carousel => "carousel",
            Self::Cushion => "cushion",
            Self::Maximus => "maximus",
        }
    }

    /// Parses a strategy name or one of its aliases.
    ///
    /// Surrounding whitespace and ASCII case are ignored, so `"Reaper"` and `" round-robin "`
    /// are accepted; previously such values were rejected and the caller silently fell back
    /// to the default. Returns `None` for unrecognized names.
    pub fn from_str(s: &str) -> Option<Self> {
        let name = s.trim().to_ascii_lowercase();
        match name.as_str() {
            "reaper" | "earliest-expiry" | "earliest_expiry" => Some(Self::Reaper),
            "carousel" | "round-robin" | "round_robin" => Some(Self::Carousel),
            "cushion" | "most-available" | "most_available" => Some(Self::Cushion),
            "maximus" => Some(Self::Maximus),
            _ => None,
        }
    }
}
294
295#[derive(Debug, Clone)]
296pub struct ServerConfig {
297 pub host: String,
298 pub port: u16,
299 pub control_port: u16,
301 pub log_level: String,
302 pub upstream_url: String,
303 pub remote_key: Option<String>,
305 pub relay_url: String,
307 pub custom_domain: Option<String>,
309 pub sticky_ttl_ms: u64,
311 pub expiry_soon_secs: u64,
313 pub routing_strategy: RoutingStrategy,
315 pub request_timeout_secs: u64,
317 pub rate_limit_rpm: u32,
319 pub trust_proxy_headers: bool,
321 pub telemetry_url: Option<String>,
323 pub telemetry_token: Option<String>,
325 pub instance_name: String,
327}
328
329impl Default for ServerConfig {
330 fn default() -> Self {
331 Self {
332 host: "127.0.0.1".into(),
333 port: 8082,
334 control_port: 19081,
335 log_level: "info".into(),
336 upstream_url: "https://api.anthropic.com".into(),
337 remote_key: None,
338 relay_url: "https://relay.ramcharan.shop".into(),
339 custom_domain: None,
340 sticky_ttl_ms: 10 * 60 * 1000,
341 expiry_soon_secs: 30 * 60,
342 routing_strategy: RoutingStrategy::Maximus,
343 request_timeout_secs: 600,
344 rate_limit_rpm: 0,
345 trust_proxy_headers: false,
346 telemetry_url: None,
347 telemetry_token: None,
348 instance_name: default_instance_name(),
349 }
350 }
351}
352
/// A resolved account, built by [`load_config`] from one `[[accounts]]` entry.
#[derive(Debug, Clone)]
pub struct AccountConfig {
    /// Account name from the config file.
    pub name: String,
    /// Plan label (config default `"pro"`).
    pub plan_type: String,
    /// Provider parsed from the config entry (`Provider::default()` when omitted).
    pub provider: Provider,
    /// First credential found, in this order: the credentials store, the inline `api_key`,
    /// the `api_key_env` variable, the provider's local credentials. `None` if none matched.
    pub credential: Option<Credential>,
    /// The entry's `upstream_url`. If that is unset and this account's provider differs
    /// from the primary (first) account's, this is that provider's default URL; otherwise
    /// `None`.
    pub upstream_url: Option<String>,
    /// The entry's `model` value, copied unchanged.
    pub model: Option<String>,
}
371
/// Complete runtime configuration returned by [`load_config`].
#[derive(Debug, Clone)]
pub struct Config {
    /// Resolved `[server]` settings.
    pub server: ServerConfig,
    /// Resolved accounts, in config-file order.
    pub accounts: Vec<AccountConfig>,
    /// Path the configuration was loaded from.
    pub config_file: PathBuf,
    /// The `[model_mapping]` table from the config, unchanged.
    pub model_mapping: HashMap<String, String>,
}
381
382pub fn load_config(path: Option<&Path>) -> Result<Config> {
387 let p = path.map(PathBuf::from).unwrap_or_else(config_path);
388
389 if !p.exists() {
390 bail!(
391 "Config not found: {}\nRun `shunt setup` to get started.",
392 p.display()
393 );
394 }
395
396 let raw_text = std::fs::read_to_string(&p)
397 .with_context(|| format!("Failed to read config: {}", p.display()))?;
398
399 let raw: RawConfig = toml::from_str(&raw_text)
400 .with_context(|| format!("Failed to parse config: {}", p.display()))?;
401
402 let primary_provider_derived = raw.accounts.first()
406 .map(|a| a.provider.as_deref().map(Provider::from_str).unwrap_or_default())
407 .unwrap_or_default();
408 let default_upstream = primary_provider_derived.default_upstream_url().to_owned();
409
410 let upstream_url = raw
411 .server
412 .upstream_url
413 .clone()
414 .or_else(|| std::env::var("SHUNT_UPSTREAM_URL").ok())
415 .unwrap_or(default_upstream);
416
417 let relay_url = raw
418 .server
419 .relay_url
420 .clone()
421 .or_else(|| std::env::var("SHUNT_RELAY_URL").ok())
422 .unwrap_or_else(|| "https://relay.ramcharan.shop".into());
423
424 let telemetry_url = raw.server.telemetry_url.clone()
425 .or_else(|| std::env::var("SHUNT_TELEMETRY_URL").ok());
426 let telemetry_token = raw.server.telemetry_token.clone()
427 .or_else(|| std::env::var("SHUNT_TELEMETRY_TOKEN").ok());
428 let instance_name = raw.server.instance_name.clone()
429 .or_else(|| std::env::var("SHUNT_INSTANCE_NAME").ok())
430 .unwrap_or_else(default_instance_name);
431
432 let server_url_is_local_derived = raw.server.upstream_url.is_none()
437 && std::env::var("SHUNT_UPSTREAM_URL").is_err()
438 && matches!(primary_provider_derived, Provider::Local);
439 validate_upstream_url(&upstream_url, server_url_is_local_derived)
440 .with_context(|| "server.upstream_url failed validation")?;
441
442 let server = ServerConfig {
443 host: raw.server.host,
444 port: raw.server.port,
445 control_port: raw.server.control_port,
446 log_level: raw.server.log_level,
447 upstream_url,
448 remote_key: raw.server.remote_key,
449 relay_url,
450 custom_domain: raw.server.custom_domain,
451 sticky_ttl_ms: raw.server.sticky_ttl_minutes.unwrap_or(10) * 60 * 1000,
452 expiry_soon_secs: raw.server.expiry_soon_minutes.unwrap_or(30) * 60,
453 routing_strategy: raw.server.routing_strategy.as_deref()
454 .and_then(RoutingStrategy::from_str)
455 .unwrap_or_default(),
456 request_timeout_secs: raw.server.request_timeout_secs.unwrap_or(600),
457 rate_limit_rpm: raw.server.rate_limit_rpm.unwrap_or(0),
458 trust_proxy_headers: raw.server.trust_proxy_headers.unwrap_or(false),
459 telemetry_url,
460 telemetry_token,
461 instance_name,
462 };
463
464 if raw.accounts.is_empty() {
465 bail!("Config has no accounts. Run `shunt setup` to add one.");
466 }
467
468 let store = CredentialsStore::load();
469
470 let primary_provider = primary_provider_derived;
472
473 let mut accounts = Vec::new();
474 for a in &raw.accounts {
475 let provider = a.provider.as_deref().map(Provider::from_str).unwrap_or_default();
476
477 let cred: Option<Credential> = store.accounts.get(&a.name).cloned()
485 .or_else(|| {
486 a.api_key.as_deref().map(|k| {
488 tracing::warn!(account = %a.name, "Inline api_key in config.toml is insecure — use api_key_env instead");
489 Credential::Apikey { key: k.to_owned() }
490 })
491 })
492 .or_else(|| {
493 a.api_key_env.as_deref()
495 .and_then(|var| std::env::var(var).ok())
496 .map(|k| Credential::Apikey { key: k })
497 })
498 .or_else(|| {
499 provider.read_local_credentials()
502 });
503
504 let is_local = matches!(provider, Provider::Local);
508 if let Some(ref url) = a.upstream_url {
509 validate_upstream_url(url, is_local)
511 .with_context(|| format!("account '{}' upstream_url failed validation", a.name))?;
512 }
513 let acct_upstream = a.upstream_url.clone().or_else(|| {
514 if provider != primary_provider {
515 Some(provider.default_upstream_url().to_owned())
516 } else {
517 None
518 }
519 });
520
521 accounts.push(AccountConfig {
522 name: a.name.clone(),
523 plan_type: a.plan_type.clone(),
524 provider,
525 credential: cred,
526 upstream_url: acct_upstream,
527 model: a.model.clone(),
528 });
529 }
530
531 Ok(Config { server, accounts, config_file: p, model_mapping: raw.model_mapping })
532}
533
/// Renders a starter `config.toml`.
///
/// The output has the default `[server]` settings and one `[[accounts]]` entry per
/// `(name, plan_type)` pair. Names and plan types are emitted as TOML basic strings with
/// `"`, `\` and control characters escaped, so any input yields a file that parses back to
/// the same values. Plain names produce exactly the same text as before.
pub fn config_template(accounts: &[(&str, &str)]) -> String {
    /// Quotes `s` as a TOML basic string (`"..."`), escaping where TOML requires it.
    fn toml_string(s: &str) -> String {
        let mut out = String::with_capacity(s.len() + 2);
        out.push('"');
        for c in s.chars() {
            match c {
                '"' => out.push_str("\\\""),
                '\\' => out.push_str("\\\\"),
                '\n' => out.push_str("\\n"),
                '\r' => out.push_str("\\r"),
                '\t' => out.push_str("\\t"),
                '\u{08}' => out.push_str("\\b"),
                '\u{0C}' => out.push_str("\\f"),
                // Remaining C0 controls and DEL may not appear raw in a basic string.
                c if c < '\u{20}' || c == '\u{7F}' => {
                    out.push_str(&format!("\\u{:04X}", u32::from(c)));
                }
                c => out.push(c),
            }
        }
        out.push('"');
        out
    }

    let mut out = String::from(
        "[server]\nhost = \"127.0.0.1\"\nport = 8082\ncontrol_port = 19081\nlog_level = \"info\"\n",
    );
    for (name, plan_type) in accounts {
        out.push_str(&format!(
            "\n[[accounts]]\nname = {}\nplan_type = {}\n",
            toml_string(name),
            toml_string(plan_type)
        ));
    }
    out
}