1use std::fs;
9use std::path::{Path, PathBuf};
10
11use crate::cache::atomic_write;
12use crate::error::{AppError, Result};
13use crate::vendor::VendorId;
14
15fn state_dir() -> Result<PathBuf> {
16 let base = directories::BaseDirs::new()
17 .ok_or_else(|| AppError::Other("could not resolve XDG cache dir".into()))?;
18 Ok(base.cache_dir().join("ai-usagebar"))
19}
20
21fn state_path() -> Result<PathBuf> {
22 Ok(state_dir()?.join("active_vendor"))
23}
24
25pub fn read() -> Option<VendorId> {
28 read_from(&state_path().ok()?)
29}
30
31pub fn read_from(path: &Path) -> Option<VendorId> {
35 let raw = fs::read_to_string(path).ok()?;
36 parse_slug(raw.trim())
37}
38
39pub fn write(vendor: VendorId) -> Result<()> {
41 write_to(&state_path()?, vendor)
42}
43
44pub fn write_to(path: &Path, vendor: VendorId) -> Result<()> {
47 atomic_write(path, vendor.slug().as_bytes())
48}
49
50pub fn cycle(enabled: &[VendorId], start: VendorId, delta: i32) -> Result<VendorId> {
54 cycle_at(&state_path()?, enabled, start, delta)
55}
56
57pub fn cycle_at(
61 path: &Path,
62 enabled: &[VendorId],
63 start: VendorId,
64 delta: i32,
65) -> Result<VendorId> {
66 if enabled.is_empty() {
67 return Err(AppError::Other("no enabled vendors to cycle".into()));
68 }
69 let current = read_from(path)
70 .filter(|v| enabled.contains(v))
71 .unwrap_or(start);
72 let cur_idx = enabled.iter().position(|v| *v == current).unwrap_or(0);
73 let n = enabled.len() as i32;
74 let next_idx = ((cur_idx as i32 + delta).rem_euclid(n)) as usize;
75 let next = enabled[next_idx];
76 write_to(path, next)?;
77 Ok(next)
78}
79
80fn parse_slug(s: &str) -> Option<VendorId> {
81 match s {
82 "anthropic" => Some(VendorId::Anthropic),
83 "openai" => Some(VendorId::Openai),
84 "zai" => Some(VendorId::Zai),
85 "openrouter" => Some(VendorId::Openrouter),
86 "deepseek" => Some(VendorId::Deepseek),
87 _ => None,
88 }
89}
90
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// Every vendor except `Deepseek`, in cycling order.
    const ALL_FOUR: [VendorId; 4] = [
        VendorId::Anthropic,
        VendorId::Openai,
        VendorId::Zai,
        VendorId::Openrouter,
    ];

    /// Creates a fresh temp dir and returns it with the state-file path inside
    /// it. Callers must keep the `TempDir` alive (bind it to a named variable,
    /// not `_`) for as long as the path is used.
    fn scratch_state() -> (TempDir, PathBuf) {
        let dir = TempDir::new().unwrap();
        let file = dir.path().join("active_vendor");
        (dir, file)
    }

    #[test]
    fn parse_slug_round_trip() {
        for vendor in VendorId::all() {
            let slug = vendor.slug();
            assert_eq!(parse_slug(slug), Some(*vendor));
        }
    }

    #[test]
    fn parse_slug_unknown_returns_none() {
        for junk in ["not-a-vendor", ""] {
            assert!(parse_slug(junk).is_none());
        }
    }

    #[test]
    fn read_from_missing_or_garbage_returns_none() {
        let (_dir, file) = scratch_state();
        // Nothing has been written yet.
        assert!(read_from(&file).is_none());

        write_to(&file, VendorId::Zai).unwrap();
        assert_eq!(read_from(&file), Some(VendorId::Zai));

        // Overwrite with an unrecognised slug.
        fs::write(&file, "not-a-vendor").unwrap();
        assert!(read_from(&file).is_none());
    }

    #[test]
    fn cycle_at_persists_state_across_calls() {
        let (_dir, file) = scratch_state();
        // Each call resumes from the previously persisted vendor, not `start`.
        for expected in [VendorId::Openai, VendorId::Zai] {
            let got = cycle_at(&file, &ALL_FOUR, VendorId::Anthropic, 1).unwrap();
            assert_eq!(got, expected);
            assert_eq!(read_from(&file), Some(expected));
        }
    }

    #[test]
    fn cycle_at_wraps_forward_and_backward() {
        let (_dir, file) = scratch_state();
        write_to(&file, VendorId::Anthropic).unwrap();

        let backward = cycle_at(&file, &ALL_FOUR, VendorId::Anthropic, -1).unwrap();
        assert_eq!(backward, VendorId::Openrouter);

        let forward = cycle_at(&file, &ALL_FOUR, VendorId::Anthropic, 1).unwrap();
        assert_eq!(forward, VendorId::Anthropic);
    }

    #[test]
    fn cycle_at_ignores_persisted_vendor_not_in_enabled_set() {
        let (_dir, file) = scratch_state();
        write_to(&file, VendorId::Deepseek).unwrap();

        // Deepseek is not enabled, so cycling starts from `start` (Openai) and wraps.
        let enabled = [VendorId::Anthropic, VendorId::Openai];
        let next = cycle_at(&file, &enabled, VendorId::Openai, 1).unwrap();
        assert_eq!(next, VendorId::Anthropic);
    }

    #[test]
    fn cycle_at_errors_on_empty_enabled() {
        let (_dir, file) = scratch_state();
        let outcome = cycle_at(&file, &[], VendorId::Anthropic, 1);
        assert!(matches!(outcome, Err(AppError::Other(_))));
    }
}