1use anyhow::{Context, Result};
4use serde::{Deserialize, Serialize};
5use std::path::{Path, PathBuf};
6
7#[derive(Debug, Default, Serialize, Deserialize, PartialEq, Eq)]
8struct StateFile {
9 #[serde(default)]
10 current_context: Option<String>,
11}
12
/// In-memory view of the persisted selection stored in a `state.toml` file.
///
/// Derives `Debug` so callers can log it or use it in assertions.
#[derive(Debug)]
pub struct State {
    /// Location of the backing `state.toml` file.
    path: PathBuf,
    /// Currently selected context, or `None` when none is set.
    current_context: Option<String>,
}
17
18impl State {
19 pub fn load(base: &Path) -> Result<Self> {
24 let path = base.join("state.toml");
25 if path.exists() {
26 let raw = std::fs::read_to_string(&path)
27 .with_context(|| format!("reading {}", path.display()))?;
28 let file: StateFile = toml::from_str(&raw)
29 .with_context(|| format!("invalid state.toml at {}", path.display()))?;
30 return Ok(Self {
31 path,
32 current_context: file.current_context,
33 });
34 }
35
36 let mut state = Self {
37 path,
38 current_context: None,
39 };
40 if let Some(legacy) = legacy_context_from_config(base)? {
41 state.current_context = Some(legacy);
42 state.save()?;
43 }
44 Ok(state)
45 }
46
47 pub fn current_context(&self) -> Option<&str> {
48 self.current_context.as_deref()
49 }
50
51 pub fn set_current_context(&mut self, name: Option<String>) -> Result<()> {
52 if self.current_context == name {
53 return Ok(());
54 }
55 self.current_context = name;
56 self.save()
57 }
58
59 fn save(&self) -> Result<()> {
60 let parent = self.path.parent().ok_or_else(|| {
61 anyhow::anyhow!("path has no parent directory: {}", self.path.display())
62 })?;
63 std::fs::create_dir_all(parent)?;
64 let file = StateFile {
65 current_context: self.current_context.clone(),
66 };
67 std::fs::write(&self.path, toml::to_string_pretty(&file)?)?;
68 Ok(())
69 }
70}
71
72fn legacy_context_from_config(base: &Path) -> Result<Option<String>> {
73 let config_path = base.join("config.toml");
74 let raw = match std::fs::read_to_string(&config_path) {
75 Ok(s) => s,
76 Err(e) if e.kind() == std::io::ErrorKind::NotFound => return Ok(None),
77 Err(e) => return Err(e).context(format!("reading {}", config_path.display())),
78 };
79 let table: toml::Value = toml::from_str(&raw)
80 .with_context(|| format!("invalid config.toml at {}", config_path.display()))?;
81 Ok(table
82 .get("current_context")
83 .or_else(|| table.get("default_context"))
84 .and_then(|v| v.as_str())
85 .map(str::to_string))
86}
87
#[cfg(test)]
mod tests {
    use super::*;

    /// Writes `contents` to `<dir>/<name>`, panicking on any I/O failure.
    fn write_file(dir: &Path, name: &str, contents: &str) {
        std::fs::write(dir.join(name), contents).unwrap();
    }

    #[test]
    fn empty_when_file_missing_and_no_legacy_config() {
        let tmp = tempfile::tempdir().unwrap();
        let state = State::load(tmp.path()).unwrap();
        assert_eq!(state.current_context(), None);
        assert!(!tmp.path().join("state.toml").exists());
    }

    #[test]
    fn set_persists_current_context() {
        let tmp = tempfile::tempdir().unwrap();
        let mut state = State::load(tmp.path()).unwrap();
        state.set_current_context(Some("work".into())).unwrap();
        let loaded = State::load(tmp.path()).unwrap();
        assert_eq!(loaded.current_context(), Some("work"));
    }

    #[test]
    fn clear_removes_current_context() {
        let tmp = tempfile::tempdir().unwrap();
        let mut state = State::load(tmp.path()).unwrap();
        state.set_current_context(Some("work".into())).unwrap();
        state.set_current_context(None).unwrap();
        let loaded = State::load(tmp.path()).unwrap();
        assert_eq!(loaded.current_context(), None);
    }

    #[test]
    fn migrates_legacy_default_context_from_config() {
        let tmp = tempfile::tempdir().unwrap();
        write_file(
            tmp.path(),
            "config.toml",
            r#"
roots = []

default_context = "work"
"#,
        );
        let state = State::load(tmp.path()).unwrap();
        assert_eq!(state.current_context(), Some("work"));
        assert!(tmp.path().join("state.toml").exists());
    }

    #[test]
    fn legacy_current_context_preferred_over_default_context() {
        let tmp = tempfile::tempdir().unwrap();
        write_file(
            tmp.path(),
            "config.toml",
            "current_context = \"home\"\ndefault_context = \"work\"\n",
        );
        let state = State::load(tmp.path()).unwrap();
        assert_eq!(state.current_context(), Some("home"));
    }

    #[test]
    fn legacy_config_without_context_writes_nothing() {
        let tmp = tempfile::tempdir().unwrap();
        write_file(tmp.path(), "config.toml", "roots = []\n");
        let state = State::load(tmp.path()).unwrap();
        assert_eq!(state.current_context(), None);
        assert!(!tmp.path().join("state.toml").exists());
    }

    #[test]
    fn existing_state_file_takes_precedence_over_legacy_config() {
        let tmp = tempfile::tempdir().unwrap();
        write_file(tmp.path(), "state.toml", "current_context = \"home\"\n");
        write_file(tmp.path(), "config.toml", "default_context = \"work\"\n");
        let state = State::load(tmp.path()).unwrap();
        assert_eq!(state.current_context(), Some("home"));
    }

    #[test]
    fn malformed_state_file_is_an_error() {
        let tmp = tempfile::tempdir().unwrap();
        write_file(tmp.path(), "state.toml", "current_context = [\n");
        assert!(State::load(tmp.path()).is_err());
    }
}