1use anyhow::{bail, Context, Result};
2use serde::{Deserialize, Serialize};
3use std::collections::HashMap;
4use std::path::{Path, PathBuf};
5
6use crate::oauth::OAuthCredential;
7
8pub const APP_NAME: &str = "shunt";
9
10pub fn config_path() -> PathBuf {
11 dirs::config_dir()
12 .unwrap_or_else(|| PathBuf::from("."))
13 .join(APP_NAME)
14 .join("config.toml")
15}
16
17pub fn credentials_path() -> PathBuf {
18 dirs::config_dir()
19 .unwrap_or_else(|| PathBuf::from("."))
20 .join(APP_NAME)
21 .join("credentials.json")
22}
23
24pub fn state_path() -> PathBuf {
25 dirs::data_local_dir()
26 .unwrap_or_else(|| PathBuf::from("."))
27 .join(APP_NAME)
28 .join("state.json")
29}
30
31pub fn log_path() -> PathBuf {
32 dirs::data_local_dir()
33 .unwrap_or_else(|| PathBuf::from("."))
34 .join(APP_NAME)
35 .join("proxy.log")
36}
37
38pub fn pid_path() -> PathBuf {
39 dirs::data_local_dir()
40 .unwrap_or_else(|| PathBuf::from("."))
41 .join(APP_NAME)
42 .join("shunt.pid")
43}
44
45#[derive(Debug, Default, Serialize, Deserialize)]
50pub struct CredentialsStore {
51 pub accounts: HashMap<String, OAuthCredential>,
52}
53
54impl CredentialsStore {
55 pub fn load() -> Self {
56 let p = credentials_path();
57 if !p.exists() {
58 return Self::default();
59 }
60 match std::fs::read_to_string(&p) {
61 Ok(text) => serde_json::from_str(&text).unwrap_or_default(),
62 Err(_) => Self::default(),
63 }
64 }
65
66 pub fn save(&self) -> Result<()> {
67 let p = credentials_path();
68 if let Some(parent) = p.parent() {
69 std::fs::create_dir_all(parent)?;
70 }
71 let tmp = p.with_extension("tmp");
72 std::fs::write(&tmp, serde_json::to_string_pretty(self)?)?;
73 std::fs::rename(&tmp, &p)?;
74 #[cfg(unix)]
75 {
76 use std::os::unix::fs::PermissionsExt;
77 std::fs::set_permissions(&p, std::fs::Permissions::from_mode(0o600))?;
78 }
79 Ok(())
80 }
81}
82
83#[derive(Debug, Deserialize)]
88struct RawConfig {
89 #[serde(default)]
90 server: RawServer,
91 #[serde(default)]
92 accounts: Vec<RawAccount>,
93}
94
95#[derive(Debug, Deserialize)]
96struct RawServer {
97 #[serde(default = "default_host")]
98 host: String,
99 #[serde(default = "default_port")]
100 port: u16,
101 #[serde(default = "default_log_level")]
102 log_level: String,
103 upstream_url: Option<String>,
104 remote_key: Option<String>,
105}
106
107impl Default for RawServer {
108 fn default() -> Self {
109 Self {
110 host: default_host(),
111 port: default_port(),
112 log_level: default_log_level(),
113 upstream_url: None,
114 remote_key: None,
115 }
116 }
117}
118
119#[derive(Debug, Deserialize)]
120struct RawAccount {
121 name: String,
122 #[serde(default = "default_plan_type")]
123 plan_type: String,
124}
125
126fn default_host() -> String { "127.0.0.1".into() }
127fn default_port() -> u16 { 8082 }
128fn default_log_level() -> String { "info".into() }
129fn default_plan_type() -> String { "pro".into() }
130
/// Resolved `[server]` settings.
///
/// `Debug` is implemented by hand so that `remote_key` is never printed. Formatting this
/// struct with `{:?}` (for example in a log line) only reveals whether a key is set.
#[derive(Clone)]
pub struct ServerConfig {
    /// Host from `[server].host`.
    pub host: String,
    /// Port from `[server].port`.
    pub port: u16,
    /// Log level from `[server].log_level`.
    pub log_level: String,
    /// Resolved upstream base URL.
    pub upstream_url: String,
    /// Optional secret key from `[server].remote_key`; redacted in `Debug` output.
    pub remote_key: Option<String>,
}

impl std::fmt::Debug for ServerConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("ServerConfig")
            .field("host", &self.host)
            .field("port", &self.port)
            .field("log_level", &self.log_level)
            .field("upstream_url", &self.upstream_url)
            .field("remote_key", &self.remote_key.as_ref().map(|_| "<redacted>"))
            .finish()
    }
}
144
145#[derive(Debug, Clone)]
146pub struct AccountConfig {
147 pub name: String,
148 pub plan_type: String,
149 pub credential: Option<OAuthCredential>,
152}
153
154#[derive(Debug, Clone)]
155pub struct Config {
156 pub server: ServerConfig,
157 pub accounts: Vec<AccountConfig>,
158 pub config_file: PathBuf,
159}
160
161pub fn load_config(path: Option<&Path>) -> Result<Config> {
166 let p = path.map(PathBuf::from).unwrap_or_else(config_path);
167
168 if !p.exists() {
169 bail!(
170 "Config not found: {}\nRun `shunt setup` to get started.",
171 p.display()
172 );
173 }
174
175 let raw_text = std::fs::read_to_string(&p)
176 .with_context(|| format!("Failed to read config: {}", p.display()))?;
177
178 let raw: RawConfig = toml::from_str(&raw_text)
179 .with_context(|| format!("Failed to parse config: {}", p.display()))?;
180
181 let upstream_url = raw
182 .server
183 .upstream_url
184 .clone()
185 .or_else(|| std::env::var("SHUNT_UPSTREAM_URL").ok())
186 .unwrap_or_else(|| "https://api.anthropic.com".into());
187
188 let server = ServerConfig {
189 host: raw.server.host,
190 port: raw.server.port,
191 log_level: raw.server.log_level,
192 upstream_url,
193 remote_key: raw.server.remote_key,
194 };
195
196 if raw.accounts.is_empty() {
197 bail!("Config has no accounts. Run `shunt setup` to add one.");
198 }
199
200 let store = CredentialsStore::load();
201
202 let mut accounts = Vec::new();
203 for a in &raw.accounts {
204 let cred = if a.name == "main" || store.accounts.is_empty() {
206 store
208 .accounts
209 .get(&a.name)
210 .cloned()
211 .or_else(crate::oauth::read_claude_credentials)
212 } else {
213 store.accounts.get(&a.name).cloned()
214 };
215
216 accounts.push(AccountConfig {
217 name: a.name.clone(),
218 plan_type: a.plan_type.clone(),
219 credential: cred,
220 });
221 }
222
223 Ok(Config { server, accounts, config_file: p })
224}
225
/// Renders a starter `config.toml` with default server settings and one `[[accounts]]`
/// table per `(name, plan_type)` pair, in the given order.
///
/// Names and plan types are written as properly escaped TOML basic strings. A value
/// containing `"`, `\` or a control character therefore still yields valid TOML and cannot
/// inject extra keys.
pub fn config_template(accounts: &[(&str, &str)]) -> String {
    let mut out = String::from(
        "[server]\nhost = \"127.0.0.1\"\nport = 8082\nlog_level = \"info\"\n",
    );
    for (name, plan_type) in accounts {
        out.push_str("\n[[accounts]]\nname = ");
        out.push_str(&toml_basic_string(name));
        out.push_str("\nplan_type = ");
        out.push_str(&toml_basic_string(plan_type));
        out.push('\n');
    }
    out
}

/// Quotes `value` as a TOML basic string, escaping `"`, `\` and control characters.
fn toml_basic_string(value: &str) -> String {
    use std::fmt::Write as _;

    let mut quoted = String::with_capacity(value.len() + 2);
    quoted.push('"');
    for c in value.chars() {
        match c {
            '"' => quoted.push_str("\\\""),
            '\\' => quoted.push_str("\\\\"),
            '\n' => quoted.push_str("\\n"),
            '\r' => quoted.push_str("\\r"),
            '\t' => quoted.push_str("\\t"),
            '\u{8}' => quoted.push_str("\\b"),
            '\u{c}' => quoted.push_str("\\f"),
            c if c.is_control() => {
                write!(quoted, "\\u{:04X}", u32::from(c))
                    .expect("writing to a String cannot fail");
            }
            c => quoted.push(c),
        }
    }
    quoted.push('"');
    quoted
}