Skip to main content

cc_switch/daemon/
state.rs

1use anyhow::{Context, Result};
2use serde::{Deserialize, Serialize};
3use std::fs::{File, OpenOptions};
4use std::io::Write;
5use std::path::{Path, PathBuf};
6
7#[derive(Serialize, Deserialize, Clone, Debug, PartialEq, Eq)]
8pub struct ProxyEntry {
9    pub provider: String,
10    pub upstream: String,
11    pub proxy_port: u16,
12    #[serde(default, skip_serializing_if = "Option::is_none")]
13    pub api_port: Option<u16>,
14    pub data_dir: PathBuf,
15    pub started_at: String,
16    pub restart_count: u32,
17}
18
19/// Version of the `cc-switch` binary that built this crate. Recorded into the
20/// daemon state at start time so a newer CLI can detect a stale running daemon.
21pub const CURRENT_VERSION: &str = env!("CARGO_PKG_VERSION");
22
23#[derive(Serialize, Deserialize, Clone, Debug, PartialEq, Eq)]
24pub struct DaemonState {
25    pub schema_version: u32,
26    /// `cc-switch` version that started the daemon. Empty for state files
27    /// written before version tracking existed (treated as a mismatch).
28    #[serde(default)]
29    pub version: String,
30    pub pid: u32,
31    pub started_at: String,
32    pub stopped_at: Option<String>,
33    pub data_root: PathBuf,
34    #[serde(default, skip_serializing_if = "Option::is_none")]
35    pub agg_port: Option<u16>,
36    pub proxies: Vec<ProxyEntry>,
37}
38
39impl DaemonState {
40    /// Load state from disk. Returns `Ok(None)` when the file does not exist;
41    /// returns `Err` with the path on corrupt JSON or other IO errors.
42    pub fn load(path: &Path) -> Result<Option<DaemonState>> {
43        let raw = match std::fs::read_to_string(path) {
44            Ok(contents) => contents,
45            Err(err) if err.kind() == std::io::ErrorKind::NotFound => return Ok(None),
46            Err(err) => {
47                return Err(err)
48                    .with_context(|| format!("failed to read daemon state at {}", path.display()));
49            }
50        };
51        let state: DaemonState = serde_json::from_str(&raw)
52            .with_context(|| format!("failed to parse daemon state at {}", path.display()))?;
53        Ok(Some(state))
54    }
55
56    /// Save state atomically: write to `<path>.tmp` (mode 0600 on Unix),
57    /// fsync, then rename over `path`.
58    pub fn save(&self, path: &Path) -> Result<()> {
59        let tmp_path = PathBuf::from(format!("{}.tmp", path.display()));
60        let json = serde_json::to_string_pretty(self)
61            .context("failed to serialize daemon state to JSON")?;
62        write_tmp_then_rename(&tmp_path, path, json.as_bytes())
63    }
64
65    /// True when the running daemon was started by a binary whose version
66    /// differs from this binary (or predates version tracking). Used to warn
67    /// the user that the daemon should be restarted.
68    pub fn version_mismatch(&self) -> bool {
69        self.version != CURRENT_VERSION
70    }
71
72    /// Exact-match lookup. No URL normalization.
73    pub fn find_proxy(&self, provider: &str, upstream: &str) -> Option<&ProxyEntry> {
74        self.proxies
75            .iter()
76            .find(|entry| entry.provider == provider && entry.upstream == upstream)
77    }
78}
79
80fn write_tmp_then_rename(tmp_path: &Path, final_path: &Path, bytes: &[u8]) -> Result<()> {
81    {
82        let mut file = open_tmp_for_write(tmp_path)?;
83        file.write_all(bytes)
84            .with_context(|| format!("failed to write daemon state to {}", tmp_path.display()))?;
85        file.sync_all()
86            .with_context(|| format!("failed to fsync daemon state at {}", tmp_path.display()))?;
87    }
88    std::fs::rename(tmp_path, final_path).with_context(|| {
89        format!(
90            "failed to rename {} -> {}",
91            tmp_path.display(),
92            final_path.display()
93        )
94    })
95}
96
97#[cfg(unix)]
98fn open_tmp_for_write(tmp_path: &Path) -> Result<File> {
99    use std::os::unix::fs::OpenOptionsExt;
100    OpenOptions::new()
101        .write(true)
102        .create(true)
103        .truncate(true)
104        .mode(0o600)
105        .open(tmp_path)
106        .with_context(|| format!("failed to open {} for write", tmp_path.display()))
107}
108
#[cfg(not(unix))]
/// Open `tmp_path` for writing, creating it if needed and truncating any
/// existing contents. Non-Unix targets get no explicit permission mode.
fn open_tmp_for_write(tmp_path: &Path) -> Result<File> {
    let mut options = OpenOptions::new();
    options.write(true).create(true).truncate(true);
    options
        .open(tmp_path)
        .with_context(|| format!("failed to open {} for write", tmp_path.display()))
}
118
#[cfg(test)]
mod tests {
    use super::{DaemonState, ProxyEntry, CURRENT_VERSION};
    use std::path::{Path, PathBuf};
    use tempfile::TempDir;

    /// Timestamp shared by all fixtures.
    const STAMP: &str = "2026-05-28T00:00:00Z";

    /// Proxy fixture: fixed metadata, caller-chosen identity and port.
    fn sample_proxy(provider: &str, upstream: &str, proxy_port: u16) -> ProxyEntry {
        ProxyEntry {
            provider: provider.to_string(),
            upstream: upstream.to_string(),
            proxy_port,
            api_port: Some(9000),
            data_dir: PathBuf::from("/tmp/ccs"),
            started_at: STAMP.to_string(),
            restart_count: 0,
        }
    }

    /// State fixture stamped with the current binary version.
    fn sample_state(proxies: Vec<ProxyEntry>) -> DaemonState {
        DaemonState {
            schema_version: 2,
            version: CURRENT_VERSION.to_string(),
            pid: 4242,
            started_at: STAMP.to_string(),
            stopped_at: None,
            data_root: PathBuf::from("/tmp/ccs"),
            agg_port: None,
            proxies,
        }
    }

    /// A fresh temp dir plus a path to `file_name` inside it. The returned
    /// `TempDir` must stay alive while the path is in use.
    fn scratch(file_name: &str) -> (TempDir, PathBuf) {
        let dir = TempDir::new().expect("create temp dir");
        let path = dir.path().join(file_name);
        (dir, path)
    }

    /// Load `path`, requiring both success and an existing file.
    fn reload(path: &Path) -> DaemonState {
        DaemonState::load(path)
            .expect("load should succeed")
            .expect("file exists")
    }

    #[test]
    fn load_save_round_trip() {
        let (_dir, path) = scratch("state.json");
        let original = sample_state(vec![
            sample_proxy("claude", "https://api.anthropic.com", 8080),
            sample_proxy("codex", "https://api.openai.com", 8081),
        ]);
        original.save(&path).expect("save");
        assert_eq!(reload(&path), original);
    }

    #[test]
    fn load_save_round_trip_with_none_ports() {
        // Regression: api_port/agg_port use skip_serializing_if, so when they
        // are None the keys are omitted on save. Without #[serde(default)] the
        // reload would fail with "missing field". Guard the None path explicitly.
        let (_dir, path) = scratch("state.json");
        let proxy = ProxyEntry {
            api_port: None,
            ..sample_proxy("claude", "https://api.anthropic.com", 8080)
        };
        let original = DaemonState {
            agg_port: None,
            ..sample_state(vec![proxy])
        };
        original.save(&path).expect("save");
        assert_eq!(reload(&path), original);
    }

    #[test]
    fn version_mismatch_detection() {
        let mut state = sample_state(Vec::new());
        assert!(
            !state.version_mismatch(),
            "sample_state stamps CURRENT_VERSION, so it must match"
        );
        // An older version and the pre-tracking empty string both mismatch.
        for stale in ["0.0.1-old", ""] {
            state.version = stale.to_owned();
            assert!(state.version_mismatch(), "version {stale:?} should mismatch");
        }
    }

    #[test]
    fn load_pre_version_state_defaults_version_empty() {
        // A state file written before version tracking has no `version` key.
        let legacy = r#"{
            "schema_version": 2,
            "pid": 100,
            "started_at": "2026-05-28T00:00:00Z",
            "stopped_at": null,
            "data_root": "/tmp/ccs",
            "proxies": []
        }"#;
        let (_dir, path) = scratch("state.json");
        std::fs::write(&path, legacy).expect("write fixture");
        let loaded = reload(&path);
        assert!(loaded.version.is_empty());
        assert!(loaded.version_mismatch());
    }

    #[test]
    fn load_missing_file_returns_none() {
        let (_dir, path) = scratch("does_not_exist.json");
        let loaded = DaemonState::load(&path).expect("a missing file is not an error");
        assert!(loaded.is_none());
    }

    #[test]
    fn load_corrupt_json_returns_err_with_path() {
        let (_dir, path) = scratch("corrupt.json");
        std::fs::write(&path, "{not json").expect("write fixture");
        let rendered = match DaemonState::load(&path) {
            Ok(state) => panic!("corrupt JSON unexpectedly loaded: {state:?}"),
            Err(err) => format!("{err:#}"),
        };
        let needle = path.to_string_lossy();
        assert!(
            rendered.contains(needle.as_ref()),
            "error message should contain path; got: {rendered}"
        );
    }

    #[test]
    fn find_proxy_exact_match() {
        let entry = sample_proxy("claude", "https://api.anthropic.com", 8080);
        let state = sample_state(vec![entry.clone()]);
        assert_eq!(
            state.find_proxy("claude", "https://api.anthropic.com"),
            Some(&entry)
        );
        // No normalization: a trailing slash or another provider misses.
        assert!(state
            .find_proxy("claude", "https://api.anthropic.com/")
            .is_none());
        assert!(state
            .find_proxy("codex", "https://api.anthropic.com")
            .is_none());
    }

    #[test]
    fn save_atomic_no_partial_file() {
        let (_dir, path) = scratch("state.json");
        sample_state(vec![sample_proxy("claude", "https://a.example", 8080)])
            .save(&path)
            .expect("first save");
        let second = sample_state(vec![sample_proxy("codex", "https://b.example", 8081)]);
        second.save(&path).expect("second save");

        assert_eq!(reload(&path), second);

        let tmp_path = PathBuf::from(format!("{}.tmp", path.display()));
        assert!(
            !tmp_path.exists(),
            "temp file {tmp_path:?} should be renamed away after save"
        );
    }

    #[cfg(unix)]
    #[test]
    fn save_sets_unix_0600_permissions() {
        use std::os::unix::fs::PermissionsExt;
        let (_dir, path) = scratch("state.json");
        sample_state(Vec::new()).save(&path).expect("save");
        let perms = std::fs::metadata(&path)
            .expect("stat state file")
            .permissions()
            .mode()
            & 0o777;
        assert_eq!(perms, 0o600, "expected 0600, got {perms:o}");
    }
}