entrenar/storage/sqlite/
types.rs

//! Type definitions for the SQLite storage backend.
//!
//! Contains parameter values, parameter-search filters, and experiment/run/artifact metadata structures.
4
5use crate::storage::RunStatus;
6use chrono::{DateTime, Utc};
7use serde::{Deserialize, Serialize};
8use std::collections::HashMap;
9
10/// Parameter value types for log_param
11#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
12#[serde(tag = "type", content = "value")]
13pub enum ParameterValue {
14    String(String),
15    Int(i64),
16    Float(f64),
17    Bool(bool),
18    List(Vec<ParameterValue>),
19    Dict(HashMap<String, ParameterValue>),
20}
21
22impl ParameterValue {
23    /// Get type name for storage
24    pub fn type_name(&self) -> &'static str {
25        match self {
26            ParameterValue::String(_) => "string",
27            ParameterValue::Int(_) => "int",
28            ParameterValue::Float(_) => "float",
29            ParameterValue::Bool(_) => "bool",
30            ParameterValue::List(_) => "list",
31            ParameterValue::Dict(_) => "dict",
32        }
33    }
34
35    /// Serialize to JSON string for storage
36    pub fn to_json(&self) -> String {
37        serde_json::to_string(self).unwrap_or_default()
38    }
39
40    /// Deserialize from JSON string
41    pub fn from_json(s: &str) -> Option<Self> {
42        serde_json::from_str(s).ok()
43    }
44}
45
/// Comparison operator applied by a [`ParamFilter`] when searching runs.
///
/// The variant docs describe the intended comparison. The actual evaluation
/// against stored values is performed by the storage backend that consumes
/// the filter.
// NOTE(review): `Contains`/`StartsWith` presumably apply only to string values,
// and the ordering operators only to numeric ones. Confirm in the query builder.
//
// Fieldless enum: `Copy` lets it be passed by value; `Eq` + `Hash` let it be
// used as a `HashMap`/`HashSet` key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum FilterOp {
    /// Equal to.
    Eq,
    /// Not equal to.
    Ne,
    /// Greater than.
    Gt,
    /// Less than.
    Lt,
    /// Greater than or equal to.
    Gte,
    /// Less than or equal to.
    Lte,
    /// Stored value contains the filter operand.
    Contains,
    /// Stored value starts with the filter operand.
    StartsWith,
}
58
59/// Parameter filter for searching runs
60#[derive(Debug, Clone)]
61pub struct ParamFilter {
62    pub key: String,
63    pub op: FilterOp,
64    pub value: ParameterValue,
65}
66
67/// Experiment metadata
68#[derive(Debug, Clone, Serialize, Deserialize)]
69pub struct Experiment {
70    pub id: String,
71    pub name: String,
72    pub description: Option<String>,
73    pub config: Option<serde_json::Value>,
74    pub tags: HashMap<String, String>,
75    pub created_at: DateTime<Utc>,
76    pub updated_at: DateTime<Utc>,
77}
78
79/// Run metadata
80#[derive(Debug, Clone, Serialize, Deserialize)]
81pub struct Run {
82    pub id: String,
83    pub experiment_id: String,
84    pub status: RunStatus,
85    pub start_time: DateTime<Utc>,
86    pub end_time: Option<DateTime<Utc>>,
87    pub params: HashMap<String, ParameterValue>,
88    pub tags: HashMap<String, String>,
89}
90
91/// Artifact reference
92#[derive(Debug, Clone, Serialize, Deserialize)]
93pub struct ArtifactRef {
94    pub id: String,
95    pub run_id: String,
96    pub path: String,
97    pub size_bytes: u64,
98    pub sha256: String,
99    pub created_at: DateTime<Utc>,
100}
101
#[cfg(test)]
mod tests {
    use super::*;

    // -------------------------------------------------------------------------
    // ParameterValue Tests
    // -------------------------------------------------------------------------

    // Float literals are chosen to be exactly representable (e.g. 2.5, 0.5).
    // Values like `3.14` trip clippy's deny-by-default `approx_constant` lint,
    // and inexact floats are not guaranteed to round-trip through serde_json's
    // default float parser.

    #[test]
    fn test_parameter_value_type_name() {
        assert_eq!(
            ParameterValue::String("test".to_string()).type_name(),
            "string"
        );
        assert_eq!(ParameterValue::Int(42).type_name(), "int");
        assert_eq!(ParameterValue::Float(2.5).type_name(), "float");
        assert_eq!(ParameterValue::Bool(true).type_name(), "bool");
        assert_eq!(ParameterValue::List(vec![]).type_name(), "list");
        assert_eq!(ParameterValue::Dict(HashMap::new()).type_name(), "dict");
    }

    #[test]
    fn test_parameter_value_json_roundtrip() {
        let values = [
            ParameterValue::String("hello".to_string()),
            ParameterValue::String(String::new()),
            ParameterValue::Int(42),
            ParameterValue::Int(i64::MIN),
            ParameterValue::Float(2.5),
            ParameterValue::Bool(true),
            ParameterValue::List(vec![]),
            ParameterValue::List(vec![ParameterValue::Int(1), ParameterValue::Int(2)]),
        ];

        for value in values {
            let json = value.to_json();
            let parsed = ParameterValue::from_json(&json).unwrap();
            assert_eq!(value, parsed);
        }
    }

    #[test]
    fn test_parameter_value_dict() {
        let mut dict = HashMap::new();
        dict.insert("nested".to_string(), ParameterValue::Int(42));
        let param = ParameterValue::Dict(dict);
        assert_eq!(param.type_name(), "dict");

        let json = param.to_json();
        let parsed = ParameterValue::from_json(&json).unwrap();
        assert_eq!(param, parsed);
    }

    #[test]
    fn test_parameter_value_nested_roundtrip() {
        let mut inner = HashMap::new();
        inner.insert("lr".to_string(), ParameterValue::Float(0.5));
        inner.insert(
            "layers".to_string(),
            ParameterValue::List(vec![ParameterValue::Int(64), ParameterValue::Int(32)]),
        );
        let param = ParameterValue::List(vec![
            ParameterValue::Dict(inner),
            ParameterValue::Bool(false),
        ]);

        let parsed = ParameterValue::from_json(&param.to_json()).unwrap();
        assert_eq!(param, parsed);
    }

    /// `to_json` output is what gets stored. Pin its layout so that an
    /// accidental change to the serde attributes (which would stop
    /// previously stored values from deserializing) fails loudly.
    #[test]
    fn test_parameter_value_json_format_is_stable() {
        assert_eq!(
            ParameterValue::Int(42).to_json(),
            r#"{"type":"Int","value":42}"#
        );
        assert_eq!(
            ParameterValue::String("hi".to_string()).to_json(),
            r#"{"type":"String","value":"hi"}"#
        );
        assert_eq!(
            ParameterValue::Bool(false).to_json(),
            r#"{"type":"Bool","value":false}"#
        );
        assert_eq!(
            ParameterValue::List(vec![]).to_json(),
            r#"{"type":"List","value":[]}"#
        );
    }

    /// serde_json writes non-finite floats as `null`, which cannot be read
    /// back as an `f64`. This pins that known limitation of the stored format.
    #[test]
    fn test_parameter_value_non_finite_float_is_lossy() {
        for x in [f64::NAN, f64::INFINITY, f64::NEG_INFINITY] {
            let json = ParameterValue::Float(x).to_json();
            assert_eq!(json, r#"{"type":"Float","value":null}"#);
            assert!(ParameterValue::from_json(&json).is_none());
        }
    }

    #[test]
    fn test_parameter_value_from_invalid_json() {
        let result = ParameterValue::from_json("invalid json {{{");
        assert!(result.is_none());
    }

    #[test]
    fn test_parameter_value_from_unknown_tag() {
        let result = ParameterValue::from_json(r#"{"type":"Complex","value":1}"#);
        assert!(result.is_none());
    }
}