// colab_cli/server/storage.rs
use std::path::PathBuf;
use std::sync::Mutex;

use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use uuid::Uuid;

use crate::client::api::{Shape, Variant};
use crate::error::{ColabError, Result};

11#[derive(Debug, Clone, Serialize, Deserialize)]
12pub struct StoredServer {
13 pub id: Uuid,
14 pub label: String,
15 pub variant: Variant,
16 pub accelerator: Option<String>,
17 #[serde(default)]
18 pub shape: Shape,
19 pub endpoint: String,
20 pub proxy_url: String,
21 pub proxy_token: String,
22 pub token_expires_at: DateTime<Utc>,
23 pub date_assigned: DateTime<Utc>,
24}
25
26pub struct ServerStorage {
27 path: PathBuf,
28 cache: Mutex<Option<Vec<StoredServer>>>,
30}
31
32impl ServerStorage {
33 pub fn new(path: PathBuf) -> Self {
34 Self {
35 path,
36 cache: Mutex::new(None),
37 }
38 }
39
40 pub fn list(&self) -> Result<Vec<StoredServer>> {
41 if let Some(cached) = self
42 .cache
43 .lock()
44 .expect("server storage cache poisoned")
45 .as_ref()
46 {
47 return Ok(cached.clone());
48 }
49 let fresh = self.read_from_disk()?;
50 *self.cache.lock().expect("server storage cache poisoned") = Some(fresh.clone());
51 Ok(fresh)
52 }
53
54 fn read_from_disk(&self) -> Result<Vec<StoredServer>> {
55 match std::fs::read_to_string(&self.path) {
56 Ok(json) => Ok(serde_json::from_str(&json)?),
57 Err(e) if e.kind() == std::io::ErrorKind::NotFound => Ok(vec![]),
58 Err(e) => Err(ColabError::Io(e)),
59 }
60 }
61
62 pub fn get(&self, id: Uuid) -> Result<Option<StoredServer>> {
63 Ok(self.list()?.into_iter().find(|s| s.id == id))
64 }
65
66 pub fn get_by_endpoint(&self, endpoint: &str) -> Result<Option<StoredServer>> {
67 Ok(self.list()?.into_iter().find(|s| s.endpoint == endpoint))
68 }
69
70 pub fn upsert(&self, server: StoredServer) -> Result<()> {
71 let mut servers = self.list()?;
72 let pos = servers.iter().position(|s| s.id == server.id);
73 match pos {
74 Some(i) => {
75 let original_date = servers[i].date_assigned;
76 servers[i] = StoredServer {
77 date_assigned: original_date,
78 ..server
79 };
80 }
81 None => servers.push(server),
82 }
83 self.write(&servers)
84 }
85
86 pub fn remove(&self, id: Uuid) -> Result<bool> {
87 let mut servers = self.list()?;
88 let len_before = servers.len();
89 servers.retain(|s| s.id != id);
90 if servers.len() == len_before {
91 return Ok(false);
92 }
93 self.write(&servers)?;
94 Ok(true)
95 }
96
97 pub fn reconcile(
98 &self,
99 live_endpoints: &std::collections::HashSet<String>,
100 ) -> Result<Vec<StoredServer>> {
101 let servers = self.list()?;
102 let (keep, removed): (Vec<_>, Vec<_>) = servers
103 .into_iter()
104 .partition(|s| live_endpoints.contains(&s.endpoint));
105 if !removed.is_empty() {
106 self.write(&keep)?;
107 }
108 Ok(removed)
109 }
110
111 fn write(&self, servers: &[StoredServer]) -> Result<()> {
112 let mut sorted = servers.to_vec();
113 sorted.sort_by_key(|s| s.id);
114 let json = serde_json::to_string_pretty(&sorted)?;
115 let tmp = self.path.with_extension("json.tmp");
116 std::fs::write(&tmp, &json)?;
117 std::fs::rename(&tmp, &self.path)?;
118 *self.cache.lock().expect("server storage cache poisoned") = Some(sorted);
120 Ok(())
121 }
122}
123
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::tempdir;

    /// Builds a GPU/T4/high-mem record with fixed proxy details and "now"
    /// timestamps.
    fn sample(id: Uuid, label: &str, endpoint: &str) -> StoredServer {
        StoredServer {
            id,
            label: label.into(),
            variant: Variant::Gpu,
            accelerator: Some("T4".into()),
            shape: Shape::HighMem,
            endpoint: endpoint.into(),
            proxy_url: "https://p.example".into(),
            proxy_token: "tok".into(),
            token_expires_at: Utc::now(),
            date_assigned: Utc::now(),
        }
    }

    #[test]
    fn upsert_insert_then_update_preserves_date_assigned() {
        let dir = tempdir().unwrap();
        let storage = ServerStorage::new(dir.path().join("servers.json"));
        let id = Uuid::new_v4();

        let first = sample(id, "a", "ep-1");
        let original_date = first.date_assigned;
        storage.upsert(first).unwrap();

        let mut second = sample(id, "renamed", "ep-1");
        second.date_assigned = Utc::now() + chrono::Duration::hours(1);
        storage.upsert(second).unwrap();

        let listed = storage.list().unwrap();
        assert_eq!(listed.len(), 1);
        assert_eq!(listed[0].label, "renamed");
        assert_eq!(listed[0].date_assigned, original_date);
    }

    #[test]
    fn remove_reports_existence() {
        let dir = tempdir().unwrap();
        let storage = ServerStorage::new(dir.path().join("servers.json"));
        let id = Uuid::new_v4();
        storage.upsert(sample(id, "a", "ep")).unwrap();
        assert!(storage.remove(id).unwrap());
        assert!(!storage.remove(id).unwrap());
    }

    #[test]
    fn reconcile_drops_stale_servers() {
        let dir = tempdir().unwrap();
        let storage = ServerStorage::new(dir.path().join("servers.json"));
        storage
            .upsert(sample(Uuid::new_v4(), "alive", "live-ep"))
            .unwrap();
        storage
            .upsert(sample(Uuid::new_v4(), "stale", "dead-ep"))
            .unwrap();

        let mut live = std::collections::HashSet::new();
        live.insert("live-ep".to_string());
        let removed = storage.reconcile(&live).unwrap();
        assert_eq!(removed.len(), 1);
        assert_eq!(removed[0].endpoint, "dead-ep");

        let remaining = storage.list().unwrap();
        assert_eq!(remaining.len(), 1);
        assert_eq!(remaining[0].endpoint, "live-ep");
    }

    #[test]
    fn list_returns_empty_when_file_missing() {
        let dir = tempdir().unwrap();
        let storage = ServerStorage::new(dir.path().join("missing.json"));
        assert!(storage.list().unwrap().is_empty());
    }

    #[test]
    fn shape_round_trips_through_json() {
        let dir = tempdir().unwrap();
        let path = dir.path().join("servers.json");
        let id = Uuid::new_v4();
        let mut s = sample(id, "hm", "ep");
        s.shape = Shape::HighMem;
        ServerStorage::new(path.clone()).upsert(s).unwrap();

        // A fresh instance starts with an empty cache, so this read really
        // deserializes the JSON file instead of returning the in-memory copy.
        let loaded = ServerStorage::new(path).get(id).unwrap().unwrap();
        assert_eq!(loaded.shape, Shape::HighMem);
    }

    #[test]
    fn get_by_endpoint_finds_matching_server() {
        let dir = tempdir().unwrap();
        let storage = ServerStorage::new(dir.path().join("servers.json"));
        let id = Uuid::new_v4();
        storage.upsert(sample(id, "a", "ep-a")).unwrap();
        storage.upsert(sample(Uuid::new_v4(), "b", "ep-b")).unwrap();

        let found = storage.get_by_endpoint("ep-a").unwrap().map(|s| s.id);
        assert_eq!(found, Some(id));
        assert!(storage.get_by_endpoint("missing").unwrap().is_none());
    }
}