entrenar/storage/mod.rs
1//! Experiment Storage Module (ENT-001)
2//!
3//! Provides the `ExperimentStorage` trait and backends for persisting
4//! experiment tracking data.
5//!
6//! # Backends
7//!
8//! - `TruenoBackend`: Production backend using trueno-db (feature: "monitor")
9//! - `InMemoryStorage`: In-memory backend for testing and fuzzing
10//!
11//! # Example
12//!
13//! ```
14//! use entrenar::storage::{ExperimentStorage, InMemoryStorage, RunStatus};
15//!
16//! let mut storage = InMemoryStorage::new();
17//! let exp_id = storage.create_experiment("my-experiment", None).expect("create experiment");
18//! let run_id = storage.create_run(&exp_id).expect("create run");
19//! storage.start_run(&run_id).expect("start run");
20//! storage.log_metric(&run_id, "loss", 0, 0.5).expect("log metric");
21//! storage.complete_run(&run_id, RunStatus::Success).expect("complete run");
22//! ```
23
24pub mod cloud;
25pub mod memory;
26pub mod preflight;
27pub mod registry;
28pub mod sqlite;
29#[cfg(feature = "monitor")]
30pub mod trueno;
31
32pub use cloud::{
33 ArtifactBackend, ArtifactMetadata, AzureConfig, BackendConfig, CloudError, GCSConfig,
34 InMemoryBackend, LocalBackend, MockS3Backend, S3Config,
35};
36pub use memory::InMemoryStorage;
37pub use preflight::{
38 CheckMetadata, CheckResult, CheckType, Preflight, PreflightCheck, PreflightContext,
39 PreflightError, PreflightResults,
40};
41pub use registry::{
42 Comparison, InMemoryRegistry, MetricRequirement, ModelRegistry, ModelStage, ModelVersion,
43 PolicyCheckResult, PromotionPolicy, RegistryError, StageTransition, VersionComparison,
44};
45pub use sqlite::{
46 ArtifactRef, Experiment, FilterOp, ParamFilter, ParameterValue, Run, SqliteBackend,
47};
48#[cfg(feature = "monitor")]
49pub use trueno::TruenoBackend;
50
51use chrono::{DateTime, Utc};
52use serde::{Deserialize, Serialize};
53use thiserror::Error;
54
55/// Storage errors
56#[derive(Debug, Error)]
57pub enum StorageError {
58 #[error("IO error: {0}")]
59 Io(#[from] std::io::Error),
60
61 #[error("Experiment not found: {0}")]
62 ExperimentNotFound(String),
63
64 #[error("Run not found: {0}")]
65 RunNotFound(String),
66
67 #[error("Invalid state transition: {0}")]
68 InvalidState(String),
69
70 #[error("Storage backend error: {0}")]
71 Backend(String),
72}
73
74/// Result type for storage operations
75pub type Result<T> = std::result::Result<T, StorageError>;
76
77/// Status of a run
78#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
79pub enum RunStatus {
80 /// Run is created but not yet started
81 Pending,
82 /// Run is currently executing
83 Running,
84 /// Run completed successfully
85 Success,
86 /// Run failed with an error
87 Failed,
88 /// Run was cancelled
89 Cancelled,
90}
91
92/// A single metric data point
93#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
94pub struct MetricPoint {
95 /// Training step
96 pub step: u64,
97 /// Metric value
98 pub value: f64,
99 /// Timestamp when recorded
100 pub timestamp: DateTime<Utc>,
101}
102
103impl MetricPoint {
104 /// Create a new metric point with current timestamp
105 pub fn new(step: u64, value: f64) -> Self {
106 Self { step, value, timestamp: Utc::now() }
107 }
108
109 /// Create a metric point with specific timestamp
110 pub fn with_timestamp(step: u64, value: f64, timestamp: DateTime<Utc>) -> Self {
111 Self { step, value, timestamp }
112 }
113}
114
115/// Trait for experiment storage backends
116///
117/// This trait abstracts over different storage implementations, allowing
118/// for production use with TruenoDB and testing with in-memory storage.
119///
120/// # Thread Safety
121///
122/// All implementations must be `Send + Sync` to support concurrent access
123/// from multiple training threads.
124pub trait ExperimentStorage: Send + Sync {
125 /// Create a new experiment
126 ///
127 /// # Arguments
128 ///
129 /// * `name` - Human-readable experiment name
130 /// * `config` - Optional JSON configuration for the experiment
131 ///
132 /// # Returns
133 ///
134 /// Unique experiment ID
135 fn create_experiment(
136 &mut self,
137 name: &str,
138 config: Option<serde_json::Value>,
139 ) -> Result<String>;
140
141 /// Create a new run within an experiment
142 ///
143 /// The run starts in `Pending` status.
144 ///
145 /// # Arguments
146 ///
147 /// * `experiment_id` - ID of the parent experiment
148 ///
149 /// # Returns
150 ///
151 /// Unique run ID
152 fn create_run(&mut self, experiment_id: &str) -> Result<String>;
153
154 /// Start a run, transitioning from Pending to Running
155 ///
156 /// # Arguments
157 ///
158 /// * `run_id` - ID of the run to start
159 fn start_run(&mut self, run_id: &str) -> Result<()>;
160
161 /// Complete a run with the given status
162 ///
163 /// # Arguments
164 ///
165 /// * `run_id` - ID of the run
166 /// * `status` - Final status (Success, Failed, or Cancelled)
167 fn complete_run(&mut self, run_id: &str, status: RunStatus) -> Result<()>;
168
169 /// Log a metric value for a run
170 ///
171 /// # Arguments
172 ///
173 /// * `run_id` - ID of the run
174 /// * `key` - Metric name (e.g., "loss", "accuracy")
175 /// * `step` - Training step or epoch number
176 /// * `value` - Metric value
177 fn log_metric(&mut self, run_id: &str, key: &str, step: u64, value: f64) -> Result<()>;
178
179 /// Log an artifact for a run
180 ///
181 /// # Arguments
182 ///
183 /// * `run_id` - ID of the run
184 /// * `key` - Artifact name (e.g., "model.safetensors")
185 /// * `data` - Artifact data bytes
186 ///
187 /// # Returns
188 ///
189 /// Content-addressable hash of the artifact
190 fn log_artifact(&mut self, run_id: &str, key: &str, data: &[u8]) -> Result<String>;
191
192 /// Get metrics for a specific run and key
193 ///
194 /// # Arguments
195 ///
196 /// * `run_id` - ID of the run
197 /// * `key` - Metric name to retrieve
198 ///
199 /// # Returns
200 ///
201 /// Vector of metric points, ordered by step
202 fn get_metrics(&self, run_id: &str, key: &str) -> Result<Vec<MetricPoint>>;
203
204 /// Get the current status of a run
205 fn get_run_status(&self, run_id: &str) -> Result<RunStatus>;
206
207 /// Set renacer span ID for distributed tracing
208 fn set_span_id(&mut self, run_id: &str, span_id: &str) -> Result<()>;
209
210 /// Get renacer span ID for a run
211 fn get_span_id(&self, run_id: &str) -> Result<Option<String>>;
212}
213
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_metric_point_new() {
        let before = Utc::now();
        let point = MetricPoint::new(10, 0.5);
        let after = Utc::now();
        assert_eq!(point.step, 10);
        assert!((point.value - 0.5).abs() < f64::EPSILON);
        // `new` stamps the point with the current time.
        assert!(point.timestamp >= before && point.timestamp <= after);
    }

    #[test]
    fn test_metric_point_with_timestamp() {
        let ts = Utc::now();
        let point = MetricPoint::with_timestamp(5, 0.3, ts);
        assert_eq!(point.step, 5);
        assert!((point.value - 0.3).abs() < f64::EPSILON);
        assert_eq!(point.timestamp, ts);
    }

    #[test]
    fn test_metric_point_serde_roundtrip() {
        let point = MetricPoint::with_timestamp(7, 1.25, Utc::now());
        let json = serde_json::to_string(&point).expect("serialize metric point");
        let back: MetricPoint = serde_json::from_str(&json).expect("deserialize metric point");
        assert_eq!(back, point);
    }

    #[test]
    fn test_run_status_variants() {
        // Every variant must equal itself and differ from every other variant.
        let all = [
            RunStatus::Pending,
            RunStatus::Running,
            RunStatus::Success,
            RunStatus::Failed,
            RunStatus::Cancelled,
        ];
        for (i, a) in all.iter().enumerate() {
            for (j, b) in all.iter().enumerate() {
                assert_eq!(i == j, a == b, "{a:?} vs {b:?}");
            }
        }
    }

    #[test]
    fn test_run_status_serde_uses_variant_names() {
        let json = serde_json::to_string(&RunStatus::Running).expect("serialize status");
        assert_eq!(json, "\"Running\"");
        let back: RunStatus = serde_json::from_str("\"Cancelled\"").expect("deserialize status");
        assert_eq!(back, RunStatus::Cancelled);
    }

    #[test]
    fn test_storage_error_display() {
        let err = StorageError::ExperimentNotFound("exp-1".to_string());
        assert!(err.to_string().contains("exp-1"));

        let err = StorageError::RunNotFound("run-1".to_string());
        assert!(err.to_string().contains("run-1"));

        let err = StorageError::InvalidState("cannot start".to_string());
        assert!(err.to_string().contains("cannot start"));

        let err = StorageError::Backend("disk full".to_string());
        assert!(err.to_string().contains("disk full"));
    }

    #[test]
    fn test_storage_error_from_io() {
        // `#[from]` on the `Io` variant must allow `io::Error -> StorageError`.
        let io = std::io::Error::new(std::io::ErrorKind::NotFound, "missing file");
        let err: StorageError = io.into();
        assert!(matches!(err, StorageError::Io(_)));
        assert!(err.to_string().contains("missing file"));
    }
}