Skip to main content

colab_cli/server/
storage.rs

1use std::path::PathBuf;
2use std::sync::Mutex;
3
4use chrono::{DateTime, Utc};
5use serde::{Deserialize, Serialize};
6use uuid::Uuid;
7
8use crate::client::api::{Shape, Variant};
9use crate::error::{ColabError, Result};
10
/// One server record as persisted in the JSON file managed by `ServerStorage`.
///
/// Serialized and deserialized with serde.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StoredServer {
    /// Identifier of the record. `upsert`, `get` and `remove` match on this
    /// field, and records are written to disk sorted by it.
    pub id: Uuid,
    /// Display name of the server — presumably user-facing; TODO confirm.
    pub label: String,
    /// Server variant as defined by `crate::client::api::Variant`.
    pub variant: Variant,
    /// Accelerator name, if any (the tests in this file use `"T4"`).
    pub accelerator: Option<String>,
    /// Server shape. Falls back to `Shape::default()` when the key is absent
    /// from the stored JSON.
    // NOTE(review): presumably so files written before this field existed
    // still load — confirm.
    #[serde(default)]
    pub shape: Shape,
    /// Endpoint string. `get_by_endpoint` and `reconcile` match on this field.
    pub endpoint: String,
    /// URL of the proxy for this server — exact format not visible here.
    pub proxy_url: String,
    /// Token paired with `proxy_url`.
    // NOTE(review): looks like a credential; it is written to the JSON file
    // as-is and is included in the derived `Debug` output.
    pub proxy_token: String,
    /// Presumably the expiry time of `proxy_token` — TODO confirm.
    pub token_expires_at: DateTime<Utc>,
    /// Assignment timestamp. `upsert` keeps the already-stored value when it
    /// updates an existing record, so this reflects the first insert.
    pub date_assigned: DateTime<Utc>,
}
25
/// Store of [`StoredServer`] records backed by a single JSON file.
pub struct ServerStorage {
    /// Location of the JSON file holding the serialized list of servers.
    path: PathBuf,
    // memo so a single command's list() calls don't re-parse the file
    // `None` until the file has been read once; replaced with the sorted
    // contents after every successful write to disk.
    cache: Mutex<Option<Vec<StoredServer>>>,
}
31
32impl ServerStorage {
33    pub fn new(path: PathBuf) -> Self {
34        Self {
35            path,
36            cache: Mutex::new(None),
37        }
38    }
39
40    pub fn list(&self) -> Result<Vec<StoredServer>> {
41        if let Some(cached) = self
42            .cache
43            .lock()
44            .expect("server storage cache poisoned")
45            .as_ref()
46        {
47            return Ok(cached.clone());
48        }
49        let fresh = self.read_from_disk()?;
50        *self.cache.lock().expect("server storage cache poisoned") = Some(fresh.clone());
51        Ok(fresh)
52    }
53
54    fn read_from_disk(&self) -> Result<Vec<StoredServer>> {
55        match std::fs::read_to_string(&self.path) {
56            Ok(json) => Ok(serde_json::from_str(&json)?),
57            Err(e) if e.kind() == std::io::ErrorKind::NotFound => Ok(vec![]),
58            Err(e) => Err(ColabError::Io(e)),
59        }
60    }
61
62    pub fn get(&self, id: Uuid) -> Result<Option<StoredServer>> {
63        Ok(self.list()?.into_iter().find(|s| s.id == id))
64    }
65
66    pub fn get_by_endpoint(&self, endpoint: &str) -> Result<Option<StoredServer>> {
67        Ok(self.list()?.into_iter().find(|s| s.endpoint == endpoint))
68    }
69
70    pub fn upsert(&self, server: StoredServer) -> Result<()> {
71        let mut servers = self.list()?;
72        let pos = servers.iter().position(|s| s.id == server.id);
73        match pos {
74            Some(i) => {
75                let original_date = servers[i].date_assigned;
76                servers[i] = StoredServer {
77                    date_assigned: original_date,
78                    ..server
79                };
80            }
81            None => servers.push(server),
82        }
83        self.write(&servers)
84    }
85
86    pub fn remove(&self, id: Uuid) -> Result<bool> {
87        let mut servers = self.list()?;
88        let len_before = servers.len();
89        servers.retain(|s| s.id != id);
90        if servers.len() == len_before {
91            return Ok(false);
92        }
93        self.write(&servers)?;
94        Ok(true)
95    }
96
97    pub fn reconcile(
98        &self,
99        live_endpoints: &std::collections::HashSet<String>,
100    ) -> Result<Vec<StoredServer>> {
101        let servers = self.list()?;
102        let (keep, removed): (Vec<_>, Vec<_>) = servers
103            .into_iter()
104            .partition(|s| live_endpoints.contains(&s.endpoint));
105        if !removed.is_empty() {
106            self.write(&keep)?;
107        }
108        Ok(removed)
109    }
110
111    fn write(&self, servers: &[StoredServer]) -> Result<()> {
112        let mut sorted = servers.to_vec();
113        sorted.sort_by_key(|s| s.id);
114        let json = serde_json::to_string_pretty(&sorted)?;
115        let tmp = self.path.with_extension("json.tmp");
116        std::fs::write(&tmp, &json)?;
117        std::fs::rename(&tmp, &self.path)?;
118        // keep the cache in sync with what we just wrote
119        *self.cache.lock().expect("server storage cache poisoned") = Some(sorted);
120        Ok(())
121    }
122}
123
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::tempdir;

    /// Builds a fully populated record with the given identity fields.
    fn sample(id: Uuid, label: &str, endpoint: &str) -> StoredServer {
        StoredServer {
            id,
            label: label.into(),
            variant: Variant::Gpu,
            accelerator: Some("T4".into()),
            shape: Shape::HighMem,
            endpoint: endpoint.into(),
            proxy_url: "https://p.example".into(),
            proxy_token: "tok".into(),
            token_expires_at: Utc::now(),
            date_assigned: Utc::now(),
        }
    }

    /// Opens a second, cold `ServerStorage` on the same file.
    ///
    /// Reading through the instance that performed the write would be served
    /// from its in-memory cache and prove nothing about what reached disk.
    fn reopen(storage: &ServerStorage) -> ServerStorage {
        ServerStorage::new(storage.path.clone())
    }

    #[test]
    fn upsert_insert_then_update_preserves_date_assigned() {
        let dir = tempdir().unwrap();
        let storage = ServerStorage::new(dir.path().join("servers.json"));
        let id = Uuid::new_v4();

        let first = sample(id, "a", "ep-1");
        let original_date = first.date_assigned;
        storage.upsert(first).unwrap();

        let mut second = sample(id, "renamed", "ep-1");
        second.date_assigned = Utc::now() + chrono::Duration::hours(1);
        storage.upsert(second).unwrap();

        let listed = reopen(&storage).list().unwrap();
        assert_eq!(listed.len(), 1);
        assert_eq!(listed[0].label, "renamed");
        assert_eq!(listed[0].date_assigned, original_date);
    }

    #[test]
    fn remove_reports_existence() {
        let dir = tempdir().unwrap();
        let storage = ServerStorage::new(dir.path().join("servers.json"));
        let id = Uuid::new_v4();
        storage.upsert(sample(id, "a", "ep")).unwrap();
        assert!(storage.remove(id).unwrap());
        assert!(!storage.remove(id).unwrap());
        assert!(reopen(&storage).list().unwrap().is_empty());
    }

    #[test]
    fn reconcile_drops_stale_servers() {
        let dir = tempdir().unwrap();
        let storage = ServerStorage::new(dir.path().join("servers.json"));
        storage
            .upsert(sample(Uuid::new_v4(), "alive", "live-ep"))
            .unwrap();
        storage
            .upsert(sample(Uuid::new_v4(), "stale", "dead-ep"))
            .unwrap();

        let mut live = std::collections::HashSet::new();
        live.insert("live-ep".to_string());
        let removed = storage.reconcile(&live).unwrap();
        assert_eq!(removed.len(), 1);
        assert_eq!(removed[0].endpoint, "dead-ep");

        let remaining = reopen(&storage).list().unwrap();
        assert_eq!(remaining.len(), 1);
        assert_eq!(remaining[0].endpoint, "live-ep");
    }

    #[test]
    fn list_returns_empty_when_file_missing() {
        let dir = tempdir().unwrap();
        let storage = ServerStorage::new(dir.path().join("missing.json"));
        assert!(storage.list().unwrap().is_empty());
    }

    #[test]
    fn list_reflects_writes_made_after_cache_is_warm() {
        let dir = tempdir().unwrap();
        let storage = ServerStorage::new(dir.path().join("servers.json"));
        let id = Uuid::new_v4();

        // Warm the cache with the empty state, then mutate through the same instance.
        assert!(storage.list().unwrap().is_empty());
        storage.upsert(sample(id, "a", "ep")).unwrap();
        assert_eq!(storage.list().unwrap().len(), 1);
        assert!(storage.remove(id).unwrap());
        assert!(storage.list().unwrap().is_empty());
    }

    #[test]
    fn get_by_endpoint_finds_matching_server() {
        let dir = tempdir().unwrap();
        let storage = ServerStorage::new(dir.path().join("servers.json"));
        let id = Uuid::new_v4();
        storage.upsert(sample(id, "a", "ep-1")).unwrap();

        let found = reopen(&storage).get_by_endpoint("ep-1").unwrap().unwrap();
        assert_eq!(found.id, id);
        assert!(storage.get_by_endpoint("nope").unwrap().is_none());
    }

    #[test]
    fn shape_round_trips_through_json() {
        let dir = tempdir().unwrap();
        let storage = ServerStorage::new(dir.path().join("servers.json"));
        let id = Uuid::new_v4();
        let mut s = sample(id, "hm", "ep");
        s.shape = Shape::HighMem;
        storage.upsert(s).unwrap();

        // Read through a cold instance so the value really comes from the JSON file.
        let loaded = reopen(&storage).get(id).unwrap().unwrap();
        assert_eq!(loaded.shape, Shape::HighMem);
    }
}
214}