// entrenar/server/state.rs

//! Server application state
//!
//! Shared state for the experiment-tracking server, backed by thread-safe
//! in-memory storage.

use crate::server::{ExperimentResponse, Result, RunResponse, ServerConfig, ServerError};
use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::{Arc, RwLock};
use std::time::Instant;

12/// Experiment data
13#[derive(Debug, Clone, Serialize, Deserialize)]
14pub struct Experiment {
15    pub id: String,
16    pub name: String,
17    pub description: Option<String>,
18    pub created_at: DateTime<Utc>,
19    pub tags: HashMap<String, String>,
20}
21
22impl From<Experiment> for ExperimentResponse {
23    fn from(exp: Experiment) -> Self {
24        Self {
25            id: exp.id,
26            name: exp.name,
27            description: exp.description,
28            created_at: exp.created_at.to_rfc3339(),
29            tags: exp.tags,
30        }
31    }
32}
33
34/// Run status
35#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
36pub enum RunStatus {
37    Running,
38    Completed,
39    Failed,
40    Killed,
41}
42
43impl std::fmt::Display for RunStatus {
44    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
45        match self {
46            RunStatus::Running => write!(f, "running"),
47            RunStatus::Completed => write!(f, "completed"),
48            RunStatus::Failed => write!(f, "failed"),
49            RunStatus::Killed => write!(f, "killed"),
50        }
51    }
52}
53
54impl std::str::FromStr for RunStatus {
55    type Err = ServerError;
56
57    fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
58        match s.to_lowercase().as_str() {
59            "running" => Ok(RunStatus::Running),
60            "completed" => Ok(RunStatus::Completed),
61            "failed" => Ok(RunStatus::Failed),
62            "killed" => Ok(RunStatus::Killed),
63            _ => Err(ServerError::Validation(format!("Invalid status: {s}"))),
64        }
65    }
66}
67
68/// Run data
69#[derive(Debug, Clone, Serialize, Deserialize)]
70pub struct Run {
71    pub id: String,
72    pub experiment_id: String,
73    pub name: Option<String>,
74    pub status: RunStatus,
75    pub start_time: DateTime<Utc>,
76    pub end_time: Option<DateTime<Utc>>,
77    pub params: HashMap<String, serde_json::Value>,
78    pub metrics: HashMap<String, f64>,
79    pub tags: HashMap<String, String>,
80}
81
82impl From<Run> for RunResponse {
83    fn from(run: Run) -> Self {
84        Self {
85            id: run.id,
86            experiment_id: run.experiment_id,
87            name: run.name,
88            status: run.status.to_string(),
89            start_time: run.start_time.to_rfc3339(),
90            end_time: run.end_time.map(|t| t.to_rfc3339()),
91            params: run.params,
92            metrics: run.metrics,
93            tags: run.tags,
94        }
95    }
96}
97
98/// In-memory storage for experiments and runs
99#[derive(Debug, Default)]
100pub struct InMemoryStorage {
101    experiments: RwLock<HashMap<String, Experiment>>,
102    runs: RwLock<HashMap<String, Run>>,
103    counter: RwLock<u64>,
104}
105
106impl InMemoryStorage {
107    pub fn new() -> Self {
108        Self::default()
109    }
110
111    /// Generate a unique ID
112    pub fn generate_id(&self, prefix: &str) -> String {
113        let mut counter = self.counter.write().expect("counter RwLock must not be poisoned");
114        *counter += 1;
115        format!("{}-{:08x}", prefix, *counter)
116    }
117
118    /// Create a new experiment
119    pub fn create_experiment(
120        &self,
121        name: &str,
122        description: Option<String>,
123        tags: Option<HashMap<String, String>>,
124    ) -> Result<Experiment> {
125        let id = self.generate_id("exp");
126        let experiment = Experiment {
127            id: id.clone(),
128            name: name.to_string(),
129            description,
130            created_at: Utc::now(),
131            tags: tags.unwrap_or_default(),
132        };
133
134        let mut experiments = self
135            .experiments
136            .write()
137            .map_err(|e| ServerError::Internal(format!("Lock error: {e}")))?;
138        experiments.insert(id, experiment.clone());
139
140        Ok(experiment)
141    }
142
143    /// Get an experiment by ID
144    pub fn get_experiment(&self, id: &str) -> Result<Experiment> {
145        let experiments = self
146            .experiments
147            .read()
148            .map_err(|e| ServerError::Internal(format!("Lock error: {e}")))?;
149
150        experiments
151            .get(id)
152            .cloned()
153            .ok_or_else(|| ServerError::NotFound(format!("Experiment not found: {id}")))
154    }
155
156    /// List all experiments
157    pub fn list_experiments(&self) -> Result<Vec<Experiment>> {
158        let experiments = self
159            .experiments
160            .read()
161            .map_err(|e| ServerError::Internal(format!("Lock error: {e}")))?;
162
163        Ok(experiments.values().cloned().collect())
164    }
165
166    /// Create a new run
167    pub fn create_run(
168        &self,
169        experiment_id: &str,
170        name: Option<String>,
171        tags: Option<HashMap<String, String>>,
172    ) -> Result<Run> {
173        // Verify experiment exists
174        self.get_experiment(experiment_id)?;
175
176        let id = self.generate_id("run");
177        let run = Run {
178            id: id.clone(),
179            experiment_id: experiment_id.to_string(),
180            name,
181            status: RunStatus::Running,
182            start_time: Utc::now(),
183            end_time: None,
184            params: HashMap::new(),
185            metrics: HashMap::new(),
186            tags: tags.unwrap_or_default(),
187        };
188
189        let mut runs =
190            self.runs.write().map_err(|e| ServerError::Internal(format!("Lock error: {e}")))?;
191        runs.insert(id, run.clone());
192
193        Ok(run)
194    }
195
196    /// Get a run by ID
197    pub fn get_run(&self, id: &str) -> Result<Run> {
198        let runs =
199            self.runs.read().map_err(|e| ServerError::Internal(format!("Lock error: {e}")))?;
200
201        runs.get(id).cloned().ok_or_else(|| ServerError::NotFound(format!("Run not found: {id}")))
202    }
203
204    /// Update run status
205    pub fn update_run(
206        &self,
207        id: &str,
208        status: Option<RunStatus>,
209        end_time: Option<DateTime<Utc>>,
210    ) -> Result<Run> {
211        let mut runs =
212            self.runs.write().map_err(|e| ServerError::Internal(format!("Lock error: {e}")))?;
213
214        let run = runs
215            .get_mut(id)
216            .ok_or_else(|| ServerError::NotFound(format!("Run not found: {id}")))?;
217
218        if let Some(s) = status {
219            run.status = s;
220        }
221        if let Some(t) = end_time {
222            run.end_time = Some(t);
223        }
224
225        Ok(run.clone())
226    }
227
228    /// Log parameters for a run
229    pub fn log_params(
230        &self,
231        run_id: &str,
232        params: HashMap<String, serde_json::Value>,
233    ) -> Result<()> {
234        let mut runs =
235            self.runs.write().map_err(|e| ServerError::Internal(format!("Lock error: {e}")))?;
236
237        let run = runs
238            .get_mut(run_id)
239            .ok_or_else(|| ServerError::NotFound(format!("Run not found: {run_id}")))?;
240
241        run.params.extend(params);
242        Ok(())
243    }
244
245    /// Log metrics for a run
246    pub fn log_metrics(&self, run_id: &str, metrics: HashMap<String, f64>) -> Result<()> {
247        let mut runs =
248            self.runs.write().map_err(|e| ServerError::Internal(format!("Lock error: {e}")))?;
249
250        let run = runs
251            .get_mut(run_id)
252            .ok_or_else(|| ServerError::NotFound(format!("Run not found: {run_id}")))?;
253
254        run.metrics.extend(metrics);
255        Ok(())
256    }
257
258    /// Count experiments
259    pub fn experiments_count(&self) -> usize {
260        self.experiments.read().map(|e| e.len()).unwrap_or(0)
261    }
262
263    /// Count runs
264    pub fn runs_count(&self) -> usize {
265        self.runs.read().map(|r| r.len()).unwrap_or(0)
266    }
267}
268
269/// Application state shared across handlers
270#[derive(Clone)]
271pub struct AppState {
272    pub storage: Arc<InMemoryStorage>,
273    pub config: ServerConfig,
274    pub start_time: Instant,
275}
276
277impl AppState {
278    pub fn new(config: ServerConfig) -> Self {
279        Self { storage: Arc::new(InMemoryStorage::new()), config, start_time: Instant::now() }
280    }
281
282    /// Get uptime in seconds
283    pub fn uptime_secs(&self) -> u64 {
284        self.start_time.elapsed().as_secs()
285    }
286
287    /// Create with an explicit start time (for deterministic testing)
288    #[cfg(test)]
289    pub fn with_start_time(config: ServerConfig, start_time: Instant) -> Self {
290        Self { storage: Arc::new(InMemoryStorage::new()), config, start_time }
291    }
292}
293
// =============================================================================
// Tests
// =============================================================================

#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    /// Storage pre-populated with one experiment and one run inside it.
    fn storage_with_run() -> (InMemoryStorage, Experiment, Run) {
        let storage = InMemoryStorage::new();
        let exp = storage.create_experiment("test", None, None).expect("operation should succeed");
        let run = storage.create_run(&exp.id, None, None).expect("operation should succeed");
        (storage, exp, run)
    }

    #[test]
    fn test_in_memory_storage_new() {
        let storage = InMemoryStorage::new();
        assert_eq!(storage.experiments_count(), 0);
        assert_eq!(storage.runs_count(), 0);
        assert!(storage.list_experiments().expect("operation should succeed").is_empty());
    }

    #[test]
    fn test_generate_id() {
        let storage = InMemoryStorage::new();
        let id1 = storage.generate_id("test");
        let id2 = storage.generate_id("test");
        // One counter, starting at 1, shared by every prefix; 8 hex digits.
        assert_eq!(id1, "test-00000001");
        assert_eq!(id2, "test-00000002");
        assert_ne!(id1, id2);
        assert_eq!(storage.generate_id("other"), "other-00000003");
    }

    #[test]
    fn test_create_experiment() {
        let storage = InMemoryStorage::new();
        let mut tags = HashMap::new();
        tags.insert("team".to_string(), "ml".to_string());
        let exp = storage
            .create_experiment("my-exp", Some("desc".into()), Some(tags.clone()))
            .expect("operation should succeed");
        assert!(exp.id.starts_with("exp-"));
        assert_eq!(exp.name, "my-exp");
        assert_eq!(exp.description.as_deref(), Some("desc"));
        assert_eq!(exp.tags, tags);
        assert_eq!(storage.experiments_count(), 1);
    }

    #[test]
    fn test_create_experiment_defaults() {
        let storage = InMemoryStorage::new();
        let exp = storage.create_experiment("bare", None, None).expect("operation should succeed");
        assert!(exp.description.is_none());
        assert!(exp.tags.is_empty());
    }

    #[test]
    fn test_get_experiment() {
        let storage = InMemoryStorage::new();
        let exp = storage.create_experiment("test", None, None).expect("operation should succeed");
        let retrieved = storage.get_experiment(&exp.id).expect("operation should succeed");
        assert_eq!(retrieved.id, exp.id);
        assert_eq!(retrieved.name, "test");
    }

    #[test]
    fn test_get_experiment_not_found() {
        let storage = InMemoryStorage::new();
        let result = storage.get_experiment("nonexistent");
        assert!(matches!(result, Err(ServerError::NotFound(_))));
    }

    #[test]
    fn test_list_experiments() {
        let storage = InMemoryStorage::new();
        storage.create_experiment("exp1", None, None).expect("operation should succeed");
        storage.create_experiment("exp2", None, None).expect("operation should succeed");
        let list = storage.list_experiments().expect("operation should succeed");
        // Compare as a sorted set so the test does not depend on listing order.
        let mut names: Vec<&str> = list.iter().map(|e| e.name.as_str()).collect();
        names.sort_unstable();
        assert_eq!(names, ["exp1", "exp2"]);
    }

    #[test]
    fn test_create_run() {
        let storage = InMemoryStorage::new();
        let exp = storage.create_experiment("test", None, None).expect("operation should succeed");
        let run = storage
            .create_run(&exp.id, Some("run-1".into()), None)
            .expect("operation should succeed");
        assert!(run.id.starts_with("run-"));
        assert_eq!(run.experiment_id, exp.id);
        assert_eq!(run.name.as_deref(), Some("run-1"));
        assert_eq!(run.status, RunStatus::Running);
        assert!(run.end_time.is_none());
        assert!(run.params.is_empty());
        assert!(run.metrics.is_empty());
        assert_eq!(storage.runs_count(), 1);
    }

    #[test]
    fn test_create_run_invalid_experiment() {
        let storage = InMemoryStorage::new();
        let result = storage.create_run("nonexistent", None, None);
        assert!(matches!(result, Err(ServerError::NotFound(_))));
        assert_eq!(storage.runs_count(), 0);
    }

    #[test]
    fn test_get_run_not_found() {
        let storage = InMemoryStorage::new();
        assert!(matches!(storage.get_run("nonexistent"), Err(ServerError::NotFound(_))));
    }

    #[test]
    fn test_update_run() {
        let (storage, _exp, run) = storage_with_run();
        let updated = storage
            .update_run(&run.id, Some(RunStatus::Completed), None)
            .expect("operation should succeed");
        assert_eq!(updated.status, RunStatus::Completed);
        assert!(updated.end_time.is_none());
    }

    #[test]
    fn test_update_run_end_time_and_none_keeps_fields() {
        let (storage, _exp, run) = storage_with_run();
        let end = Utc::now();
        storage
            .update_run(&run.id, Some(RunStatus::Failed), Some(end))
            .expect("operation should succeed");
        // `None` arguments leave the stored values untouched.
        let unchanged = storage.update_run(&run.id, None, None).expect("operation should succeed");
        assert_eq!(unchanged.status, RunStatus::Failed);
        assert_eq!(unchanged.end_time, Some(end));
        let stored = storage.get_run(&run.id).expect("operation should succeed");
        assert_eq!(stored.end_time, Some(end));
    }

    #[test]
    fn test_update_run_not_found() {
        let storage = InMemoryStorage::new();
        let result = storage.update_run("nonexistent", Some(RunStatus::Killed), None);
        assert!(matches!(result, Err(ServerError::NotFound(_))));
    }

    #[test]
    fn test_log_params() {
        let (storage, _exp, run) = storage_with_run();
        let mut params = HashMap::new();
        params.insert("lr".to_string(), serde_json::json!(0.001));
        storage.log_params(&run.id, params).expect("operation should succeed");

        let updated = storage.get_run(&run.id).expect("operation should succeed");
        assert_eq!(updated.params.get("lr"), Some(&serde_json::json!(0.001)));
    }

    #[test]
    fn test_log_metrics() {
        let (storage, _exp, run) = storage_with_run();
        let mut metrics = HashMap::new();
        metrics.insert("loss".to_string(), 0.5);
        storage.log_metrics(&run.id, metrics).expect("operation should succeed");

        // Logging the same key again replaces the earlier value.
        let mut later = HashMap::new();
        later.insert("loss".to_string(), 0.25);
        storage.log_metrics(&run.id, later).expect("operation should succeed");

        let updated = storage.get_run(&run.id).expect("operation should succeed");
        assert_eq!(updated.metrics.get("loss"), Some(&0.25));
    }

    #[test]
    fn test_log_to_unknown_run_fails() {
        let storage = InMemoryStorage::new();
        assert!(matches!(
            storage.log_params("nonexistent", HashMap::new()),
            Err(ServerError::NotFound(_))
        ));
        assert!(matches!(
            storage.log_metrics("nonexistent", HashMap::new()),
            Err(ServerError::NotFound(_))
        ));
    }

    #[test]
    fn test_run_status_from_str() {
        assert_eq!(
            "running".parse::<RunStatus>().expect("parsing should succeed"),
            RunStatus::Running
        );
        assert_eq!(
            "completed".parse::<RunStatus>().expect("parsing should succeed"),
            RunStatus::Completed
        );
        assert_eq!(
            "failed".parse::<RunStatus>().expect("parsing should succeed"),
            RunStatus::Failed
        );
        assert_eq!(
            "killed".parse::<RunStatus>().expect("parsing should succeed"),
            RunStatus::Killed
        );
        // Parsing is case-insensitive.
        assert_eq!(
            "RUNNING".parse::<RunStatus>().expect("parsing should succeed"),
            RunStatus::Running
        );
        assert!(matches!("invalid".parse::<RunStatus>(), Err(ServerError::Validation(_))));
    }

    #[test]
    fn test_run_status_display() {
        assert_eq!(RunStatus::Running.to_string(), "running");
        assert_eq!(RunStatus::Completed.to_string(), "completed");
        assert_eq!(RunStatus::Failed.to_string(), "failed");
        assert_eq!(RunStatus::Killed.to_string(), "killed");
    }

    #[test]
    fn test_run_status_display_parse_round_trip() {
        for status in
            [RunStatus::Running, RunStatus::Completed, RunStatus::Failed, RunStatus::Killed]
        {
            let parsed: RunStatus = status.to_string().parse().expect("parsing should succeed");
            assert_eq!(parsed, status);
        }
    }

    #[test]
    fn test_app_state_new() {
        let config = ServerConfig::default();
        let state = AppState::new(config);
        assert_eq!(state.storage.experiments_count(), 0);
        assert_eq!(state.storage.runs_count(), 0);
        // Clones share one storage instance.
        let clone = state.clone();
        assert!(Arc::ptr_eq(&state.storage, &clone.storage));
    }

    #[test]
    fn test_app_state_uptime_deterministic() {
        // Backdate the start so the lower bound holds without depending on how
        // long the test takes. Skip if the platform's monotonic clock cannot
        // represent an instant that far in the past.
        let config = ServerConfig::default();
        if let Some(start) = Instant::now().checked_sub(Duration::from_secs(5)) {
            let state = AppState::with_start_time(config, start);
            assert!(state.uptime_secs() >= 5);
        }
    }

    #[test]
    fn test_experiment_to_response() {
        let created_at = Utc::now();
        let exp = Experiment {
            id: "exp-1".to_string(),
            name: "test".to_string(),
            description: Some("d".to_string()),
            created_at,
            tags: HashMap::new(),
        };
        let resp: ExperimentResponse = exp.into();
        assert_eq!(resp.id, "exp-1");
        assert_eq!(resp.name, "test");
        assert_eq!(resp.description.as_deref(), Some("d"));
        assert_eq!(resp.created_at, created_at.to_rfc3339());
    }

    #[test]
    fn test_run_to_response() {
        let start = Utc::now();
        let run = Run {
            id: "run-1".to_string(),
            experiment_id: "exp-1".to_string(),
            name: None,
            status: RunStatus::Running,
            start_time: start,
            end_time: Some(start),
            params: HashMap::new(),
            metrics: HashMap::new(),
            tags: HashMap::new(),
        };
        let resp: RunResponse = run.into();
        assert_eq!(resp.id, "run-1");
        assert_eq!(resp.experiment_id, "exp-1");
        assert_eq!(resp.status, "running");
        assert_eq!(resp.start_time, start.to_rfc3339());
        assert_eq!(resp.end_time, Some(start.to_rfc3339()));
    }
}