Skip to main content

entrenar/storage/
mod.rs

1//! Experiment Storage Module (ENT-001)
2//!
3//! Provides the `ExperimentStorage` trait and backends for persisting
4//! experiment tracking data.
5//!
6//! # Backends
7//!
8//! - `TruenoBackend`: Production backend using trueno-db (feature: "monitor")
9//! - `InMemoryStorage`: In-memory backend for testing and fuzzing
10//!
11//! # Example
12//!
13//! ```
14//! use entrenar::storage::{ExperimentStorage, InMemoryStorage, RunStatus};
15//!
16//! let mut storage = InMemoryStorage::new();
17//! let exp_id = storage.create_experiment("my-experiment", None).expect("create experiment");
18//! let run_id = storage.create_run(&exp_id).expect("create run");
19//! storage.start_run(&run_id).expect("start run");
20//! storage.log_metric(&run_id, "loss", 0, 0.5).expect("log metric");
21//! storage.complete_run(&run_id, RunStatus::Success).expect("complete run");
22//! ```
23
24pub mod cloud;
25pub mod memory;
26pub mod preflight;
27pub mod registry;
28pub mod sqlite;
29#[cfg(feature = "monitor")]
30pub mod trueno;
31
32pub use cloud::{
33    ArtifactBackend, ArtifactMetadata, AzureConfig, BackendConfig, CloudError, GCSConfig,
34    InMemoryBackend, LocalBackend, MockS3Backend, S3Config,
35};
36pub use memory::InMemoryStorage;
37pub use preflight::{
38    CheckMetadata, CheckResult, CheckType, Preflight, PreflightCheck, PreflightContext,
39    PreflightError, PreflightResults,
40};
41pub use registry::{
42    Comparison, InMemoryRegistry, MetricRequirement, ModelRegistry, ModelStage, ModelVersion,
43    PolicyCheckResult, PromotionPolicy, RegistryError, StageTransition, VersionComparison,
44};
45pub use sqlite::{
46    ArtifactRef, Experiment, FilterOp, ParamFilter, ParameterValue, Run, SqliteBackend,
47};
48#[cfg(feature = "monitor")]
49pub use trueno::TruenoBackend;
50
51use chrono::{DateTime, Utc};
52use serde::{Deserialize, Serialize};
53use thiserror::Error;
54
/// Errors returned by experiment storage operations.
///
/// The `Display` text of each variant comes from its `#[error]` attribute,
/// which interpolates the variant's payload into the message.
#[derive(Debug, Error)]
pub enum StorageError {
    /// An underlying I/O operation failed.
    ///
    /// `#[from]` generates `From<std::io::Error>`, so `?` converts an
    /// `std::io::Error` into this variant automatically.
    #[error("IO error: {0}")]
    Io(#[from] std::io::Error),

    /// No experiment matched the lookup.
    ///
    /// NOTE(review): the payload is presumably the experiment ID that was
    /// requested — confirm against the backends.
    #[error("Experiment not found: {0}")]
    ExperimentNotFound(String),

    /// No run matched the lookup.
    ///
    /// NOTE(review): the payload is presumably the run ID that was
    /// requested — confirm against the backends.
    #[error("Run not found: {0}")]
    RunNotFound(String),

    /// A state transition was rejected; the payload describes the transition.
    #[error("Invalid state transition: {0}")]
    InvalidState(String),

    /// A backend reported a failure; the payload carries its message.
    #[error("Storage backend error: {0}")]
    Backend(String),
}
73
/// Result type for storage operations.
///
/// Alias for `std::result::Result<T, StorageError>`. Inside this module the
/// alias shadows the prelude `Result`, so the two-parameter form has to be
/// written out as `std::result::Result<T, E>`.
pub type Result<T> = std::result::Result<T, StorageError>;
76
/// Status of a run.
///
/// As documented on [`ExperimentStorage`], a run is created in `Pending`,
/// `start_run` moves it from `Pending` to `Running`, and `complete_run`
/// records a final status (`Success`, `Failed`, or `Cancelled`).
///
/// The type is `Copy`, so it is passed and returned by value.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum RunStatus {
    /// Run is created but not yet started
    Pending,
    /// Run is currently executing
    Running,
    /// Run completed successfully
    Success,
    /// Run failed with an error
    Failed,
    /// Run was cancelled
    Cancelled,
}
91
/// A single metric data point: one value recorded at one training step.
///
/// The metric name is not stored here; `ExperimentStorage::get_metrics`
/// takes the key as an argument and returns the matching points.
///
/// Derives `PartialEq` but not `Eq` because `value` is an `f64`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct MetricPoint {
    /// Training step
    pub step: u64,
    /// Metric value
    pub value: f64,
    /// Timestamp when recorded (UTC)
    pub timestamp: DateTime<Utc>,
}
102
103impl MetricPoint {
104    /// Create a new metric point with current timestamp
105    pub fn new(step: u64, value: f64) -> Self {
106        Self { step, value, timestamp: Utc::now() }
107    }
108
109    /// Create a metric point with specific timestamp
110    pub fn with_timestamp(step: u64, value: f64, timestamp: DateTime<Utc>) -> Self {
111        Self { step, value, timestamp }
112    }
113}
114
/// Trait for experiment storage backends
///
/// This trait abstracts over different storage implementations, allowing
/// for production use with TruenoDB and testing with in-memory storage.
///
/// Every method returns this module's [`Result`], so failures are reported
/// as a [`StorageError`]. The trait itself does not fix which variant a
/// backend uses for a given failure.
///
/// # Thread Safety
///
/// All implementations must be `Send + Sync` to support concurrent access
/// from multiple training threads.
///
/// Mutating methods take `&mut self`, so a single backend value shared
/// between threads still needs external synchronization (for example a
/// `Mutex`) to call them.
pub trait ExperimentStorage: Send + Sync {
    /// Create a new experiment
    ///
    /// # Arguments
    ///
    /// * `name` - Human-readable experiment name
    /// * `config` - Optional JSON configuration for the experiment
    ///
    /// # Returns
    ///
    /// Unique experiment ID
    ///
    /// # Errors
    ///
    /// Returns a [`StorageError`] if the backend cannot create the experiment.
    fn create_experiment(
        &mut self,
        name: &str,
        config: Option<serde_json::Value>,
    ) -> Result<String>;

    /// Create a new run within an experiment
    ///
    /// The run starts in `Pending` status.
    ///
    /// # Arguments
    ///
    /// * `experiment_id` - ID of the parent experiment
    ///
    /// # Returns
    ///
    /// Unique run ID
    ///
    /// # Errors
    ///
    /// Returns a [`StorageError`] on failure.
    ///
    /// NOTE(review): an unknown `experiment_id` is presumably reported as
    /// [`StorageError::ExperimentNotFound`] — confirm against the backends.
    fn create_run(&mut self, experiment_id: &str) -> Result<String>;

    /// Start a run, transitioning from Pending to Running
    ///
    /// # Arguments
    ///
    /// * `run_id` - ID of the run to start
    ///
    /// # Errors
    ///
    /// Returns a [`StorageError`] on failure.
    ///
    /// NOTE(review): starting a run that is not `Pending` is presumably
    /// reported as [`StorageError::InvalidState`] — confirm against the
    /// backends.
    fn start_run(&mut self, run_id: &str) -> Result<()>;

    /// Complete a run with the given status
    ///
    /// # Arguments
    ///
    /// * `run_id` - ID of the run
    /// * `status` - Final status (Success, Failed, or Cancelled)
    ///
    /// # Errors
    ///
    /// Returns a [`StorageError`] on failure.
    ///
    /// NOTE(review): the signature accepts any [`RunStatus`]; whether
    /// `Pending` or `Running` is rejected here is up to each backend —
    /// confirm.
    fn complete_run(&mut self, run_id: &str, status: RunStatus) -> Result<()>;

    /// Log a metric value for a run
    ///
    /// # Arguments
    ///
    /// * `run_id` - ID of the run
    /// * `key` - Metric name (e.g., "loss", "accuracy")
    /// * `step` - Training step or epoch number
    /// * `value` - Metric value
    ///
    /// # Errors
    ///
    /// Returns a [`StorageError`] if the backend cannot record the metric.
    fn log_metric(&mut self, run_id: &str, key: &str, step: u64, value: f64) -> Result<()>;

    /// Log an artifact for a run
    ///
    /// # Arguments
    ///
    /// * `run_id` - ID of the run
    /// * `key` - Artifact name (e.g., "model.safetensors")
    /// * `data` - Artifact data bytes
    ///
    /// # Returns
    ///
    /// Content-addressable hash of the artifact
    ///
    /// # Errors
    ///
    /// Returns a [`StorageError`] if the backend cannot store the artifact.
    fn log_artifact(&mut self, run_id: &str, key: &str, data: &[u8]) -> Result<String>;

    /// Get metrics for a specific run and key
    ///
    /// # Arguments
    ///
    /// * `run_id` - ID of the run
    /// * `key` - Metric name to retrieve
    ///
    /// # Returns
    ///
    /// Vector of metric points, ordered by step
    ///
    /// # Errors
    ///
    /// Returns a [`StorageError`] on failure.
    ///
    /// NOTE(review): whether a key with no logged points yields an empty
    /// vector or an error is not specified here — confirm against the
    /// backends.
    fn get_metrics(&self, run_id: &str, key: &str) -> Result<Vec<MetricPoint>>;

    /// Get the current status of a run
    ///
    /// # Arguments
    ///
    /// * `run_id` - ID of the run
    ///
    /// # Errors
    ///
    /// Returns a [`StorageError`] if the status cannot be read.
    fn get_run_status(&self, run_id: &str) -> Result<RunStatus>;

    /// Set renacer span ID for distributed tracing
    ///
    /// # Arguments
    ///
    /// * `run_id` - ID of the run
    /// * `span_id` - Span ID to associate with the run
    ///
    /// # Errors
    ///
    /// Returns a [`StorageError`] if the span ID cannot be stored.
    fn set_span_id(&mut self, run_id: &str, span_id: &str) -> Result<()>;

    /// Get renacer span ID for a run
    ///
    /// # Arguments
    ///
    /// * `run_id` - ID of the run
    ///
    /// # Returns
    ///
    /// The span ID, if any. NOTE(review): `None` presumably means no span ID
    /// has been set for the run — confirm against the backends.
    ///
    /// # Errors
    ///
    /// Returns a [`StorageError`] if the span ID cannot be read.
    fn get_span_id(&self, run_id: &str) -> Result<Option<String>>;
}
213
#[cfg(test)]
mod tests {
    use super::*;

    /// Every `RunStatus` variant, used to check the derives exhaustively.
    const ALL_STATUSES: [RunStatus; 5] = [
        RunStatus::Pending,
        RunStatus::Running,
        RunStatus::Success,
        RunStatus::Failed,
        RunStatus::Cancelled,
    ];

    /// `MetricPoint::new` stores the step and value it was given.
    #[test]
    fn test_metric_point_new() {
        let point = MetricPoint::new(10, 0.5);
        assert_eq!(point.step, 10);
        assert!((point.value - 0.5).abs() < f64::EPSILON);
    }

    /// `MetricPoint::with_timestamp` stores all three fields verbatim.
    #[test]
    fn test_metric_point_with_timestamp() {
        let ts = Utc::now();
        let point = MetricPoint::with_timestamp(5, 0.3, ts);
        assert_eq!(point.step, 5);
        assert!((point.value - 0.3).abs() < f64::EPSILON);
        assert_eq!(point.timestamp, ts);

        // Identical inputs must compare equal through the derived `PartialEq`.
        assert_eq!(point, MetricPoint::with_timestamp(5, 0.3, ts));
    }

    /// Each status equals itself and differs from every other status.
    #[test]
    fn test_run_status_variants() {
        assert_ne!(RunStatus::Pending, RunStatus::Running);
        assert_ne!(RunStatus::Success, RunStatus::Failed);

        for (i, a) in ALL_STATUSES.iter().enumerate() {
            for (j, b) in ALL_STATUSES.iter().enumerate() {
                assert_eq!(a == b, i == j, "{:?} vs {:?}", a, b);
            }
        }
    }

    /// Every status survives a JSON round-trip through the serde derives.
    #[test]
    fn test_run_status_serde_round_trip() {
        for status in ALL_STATUSES {
            let json = serde_json::to_string(&status).expect("serialize status");
            let back: RunStatus = serde_json::from_str(&json).expect("deserialize status");
            assert_eq!(back, status);
        }
    }

    /// Each variant renders the exact message declared in its `#[error]` attribute.
    #[test]
    fn test_storage_error_display() {
        let err = StorageError::ExperimentNotFound("exp-1".to_string());
        assert!(err.to_string().contains("exp-1"));
        assert_eq!(err.to_string(), "Experiment not found: exp-1");

        let err = StorageError::RunNotFound("run-1".to_string());
        assert!(err.to_string().contains("run-1"));
        assert_eq!(err.to_string(), "Run not found: run-1");

        let err = StorageError::InvalidState("cannot start".to_string());
        assert!(err.to_string().contains("cannot start"));
        assert_eq!(err.to_string(), "Invalid state transition: cannot start");

        let err = StorageError::Backend("connection lost".to_string());
        assert_eq!(err.to_string(), "Storage backend error: connection lost");
    }

    /// `#[from]` converts an `std::io::Error` into `StorageError::Io`, including via `?`.
    #[test]
    fn test_storage_error_from_io() {
        fn failing_io() -> Result<()> {
            Err(std::io::Error::new(std::io::ErrorKind::NotFound, "missing file"))?;
            Ok(())
        }

        let err = failing_io().expect_err("io failure must propagate");
        assert!(matches!(err, StorageError::Io(_)));
        assert_eq!(err.to_string(), "IO error: missing file");
    }
}