Skip to main content

entrenar/io/
model.rs

1//! Model structure for serialization
2
3use crate::Tensor;
4use serde::{Deserialize, Deserializer, Serialize};
5use std::collections::HashMap;
6
7/// Deserialize a bool from either a YAML boolean (`true`) or a quoted string (`"true"`).
8/// This supports CB-950 compliance where all truthy values must be quoted in YAML.
9fn deserialize_bool_lenient<'de, D>(deserializer: D) -> Result<bool, D::Error>
10where
11    D: Deserializer<'de>,
12{
13    #[derive(Deserialize)]
14    #[serde(untagged)]
15    enum BoolOrString {
16        Bool(bool),
17        Str(String),
18    }
19
20    match BoolOrString::deserialize(deserializer)? {
21        BoolOrString::Bool(b) => Ok(b),
22        BoolOrString::Str(s) => match s.to_lowercase().as_str() {
23            "true" => Ok(true),
24            "false" => Ok(false),
25            other => {
26                Err(serde::de::Error::custom(format!("expected 'true' or 'false', got '{other}'")))
27            }
28        },
29    }
30}
31
32/// Model metadata containing architecture and training information
33#[derive(Debug, Clone, Serialize, Deserialize)]
34pub struct ModelMetadata {
35    /// Model name/identifier
36    pub name: String,
37
38    /// Model architecture type (e.g., "transformer", "linear", "custom")
39    pub architecture: String,
40
41    /// Model version
42    pub version: String,
43
44    /// Training configuration used
45    pub training_config: Option<HashMap<String, serde_json::Value>>,
46
47    /// Custom metadata fields
48    pub custom: HashMap<String, serde_json::Value>,
49}
50
51impl ModelMetadata {
52    /// Create new metadata with minimal fields
53    pub fn new(name: impl Into<String>, architecture: impl Into<String>) -> Self {
54        Self {
55            name: name.into(),
56            architecture: architecture.into(),
57            version: "0.1.0".to_string(),
58            training_config: None,
59            custom: HashMap::new(),
60        }
61    }
62
63    /// Add custom metadata field
64    pub fn with_custom(mut self, key: impl Into<String>, value: serde_json::Value) -> Self {
65        self.custom.insert(key.into(), value);
66        self
67    }
68}
69
70/// Information about a model parameter
71#[derive(Debug, Clone, Serialize, Deserialize)]
72pub struct ParameterInfo {
73    /// Parameter name (e.g., "layer1.weight", "bias")
74    pub name: String,
75
76    /// Parameter shape
77    pub shape: Vec<usize>,
78
79    /// Data type (e.g., "f32", "i8")
80    pub dtype: String,
81
82    /// Whether this parameter requires gradients
83    #[serde(deserialize_with = "deserialize_bool_lenient")]
84    pub requires_grad: bool,
85}
86
87/// Serializable model state
88#[derive(Debug, Clone, Serialize, Deserialize)]
89pub struct ModelState {
90    /// Model metadata
91    pub metadata: ModelMetadata,
92
93    /// Parameter information
94    pub parameters: Vec<ParameterInfo>,
95
96    /// Flattened parameter data
97    pub data: Vec<f32>,
98}
99
100/// High-level model abstraction for I/O
101pub struct Model {
102    /// Model metadata
103    pub metadata: ModelMetadata,
104
105    /// Model parameters
106    pub parameters: Vec<(String, Tensor)>,
107}
108
109impl Model {
110    /// Create a new model
111    pub fn new(metadata: ModelMetadata, parameters: Vec<(String, Tensor)>) -> Self {
112        Self { metadata, parameters }
113    }
114
115    /// Get parameter by name
116    pub fn get_parameter(&self, name: &str) -> Option<&Tensor> {
117        self.parameters.iter().find(|(n, _)| n == name).map(|(_, t)| t)
118    }
119
120    /// Get mutable parameter by name
121    pub fn get_parameter_mut(&mut self, name: &str) -> Option<&mut Tensor> {
122        self.parameters.iter_mut().find(|(n, _)| n == name).map(|(_, t)| t)
123    }
124
125    /// Convert model to serializable state
126    pub fn to_state(&self) -> ModelState {
127        let mut data = Vec::new();
128        let parameters: Vec<ParameterInfo> = self
129            .parameters
130            .iter()
131            .map(|(name, tensor)| {
132                let shape = vec![tensor.len()];
133                let param_data = tensor.data();
134                data.extend_from_slice(
135                    param_data.as_slice().expect("tensor data must be contiguous"),
136                );
137
138                ParameterInfo {
139                    name: name.clone(),
140                    shape,
141                    dtype: "f32".to_string(),
142                    requires_grad: tensor.requires_grad(),
143                }
144            })
145            .collect();
146
147        ModelState { metadata: self.metadata.clone(), parameters, data }
148    }
149
150    /// Create model from serializable state
151    pub fn from_state(state: ModelState) -> Self {
152        let mut data_offset = 0;
153        let parameters: Vec<(String, Tensor)> = state
154            .parameters
155            .into_iter()
156            .map(|param_info| {
157                let size: usize = param_info.shape.iter().product();
158                let param_data = state.data[data_offset..data_offset + size].to_vec();
159                data_offset += size;
160
161                let tensor = Tensor::from_vec(param_data, param_info.requires_grad);
162                (param_info.name, tensor)
163            })
164            .collect();
165
166        Self { metadata: state.metadata, parameters }
167    }
168}
169
#[cfg(test)]
mod tests {
    use super::*;

    /// Two-parameter fixture shared by several tests.
    fn sample_parameters() -> Vec<(String, Tensor)> {
        vec![
            ("weight".to_string(), Tensor::from_vec(vec![1.0, 2.0, 3.0], true)),
            ("bias".to_string(), Tensor::from_vec(vec![0.1], false)),
        ]
    }

    /// JSON for a `ParameterInfo` whose `requires_grad` is the given raw JSON literal.
    fn parameter_info_json(requires_grad: &str) -> String {
        format!(r#"{{"name":"w","shape":[2],"dtype":"f32","requires_grad":{requires_grad}}}"#)
    }

    #[test]
    fn test_model_metadata_creation() {
        let meta = ModelMetadata::new("test-model", "linear");
        assert_eq!(meta.name, "test-model");
        assert_eq!(meta.architecture, "linear");
        assert_eq!(meta.version, "0.1.0");
        assert!(meta.training_config.is_none());
        assert!(meta.custom.is_empty());
    }

    #[test]
    fn test_model_with_custom_metadata() {
        let meta = ModelMetadata::new("test", "custom")
            .with_custom("layers", serde_json::json!(12))
            .with_custom("hidden_size", serde_json::json!(768));

        assert_eq!(meta.custom.len(), 2);
        assert_eq!(meta.custom.get("layers").expect("key should exist"), &serde_json::json!(12));
        assert_eq!(
            meta.custom.get("hidden_size").expect("key should exist"),
            &serde_json::json!(768)
        );
    }

    #[test]
    fn test_model_parameter_access() {
        let model = Model::new(ModelMetadata::new("test", "linear"), sample_parameters());

        assert!(model.get_parameter("weight").is_some());
        assert!(model.get_parameter("bias").is_some());
        assert!(model.get_parameter("nonexistent").is_none());
    }

    #[test]
    fn test_model_state_round_trip() {
        let original = Model::new(ModelMetadata::new("test", "linear"), sample_parameters());
        let state = original.to_state();
        let restored = Model::from_state(state);

        assert_eq!(original.metadata.name, restored.metadata.name);
        assert_eq!(original.parameters.len(), restored.parameters.len());

        // Every parameter must come back with the same data and gradient flag.
        for (name, orig_tensor) in &original.parameters {
            let rest_tensor = restored.get_parameter(name).expect("parameter should exist");
            assert_eq!(orig_tensor.data(), rest_tensor.data());
            assert_eq!(orig_tensor.requires_grad(), rest_tensor.requires_grad());
        }
    }

    #[test]
    fn test_to_state_flattens_parameters_in_order() {
        let model = Model::new(ModelMetadata::new("test", "linear"), sample_parameters());
        let state = model.to_state();

        assert_eq!(state.data, vec![1.0, 2.0, 3.0, 0.1]);
        assert_eq!(state.parameters.len(), 2);

        assert_eq!(state.parameters[0].name, "weight");
        assert_eq!(state.parameters[0].shape, vec![3]);
        assert_eq!(state.parameters[0].dtype, "f32");
        assert!(state.parameters[0].requires_grad);

        assert_eq!(state.parameters[1].name, "bias");
        assert_eq!(state.parameters[1].shape, vec![1]);
        assert!(!state.parameters[1].requires_grad);
    }

    #[test]
    #[should_panic]
    fn test_from_state_panics_on_truncated_data() {
        // The shape asks for five values but only two are present.
        let state = ModelState {
            metadata: ModelMetadata::new("test", "arch"),
            parameters: vec![ParameterInfo {
                name: "w".to_string(),
                shape: vec![5],
                dtype: "f32".to_string(),
                requires_grad: true,
            }],
            data: vec![1.0, 2.0],
        };
        let _ = Model::from_state(state);
    }

    #[test]
    fn test_model_get_parameter_mut() {
        let params = vec![("weight".to_string(), Tensor::from_vec(vec![1.0, 2.0], true))];
        let mut model = Model::new(ModelMetadata::new("test", "linear"), params);

        let tensor = model.get_parameter_mut("weight").expect("parameter should exist");
        assert!(tensor.requires_grad());

        assert!(model.get_parameter_mut("nonexistent").is_none());
    }

    #[test]
    fn test_requires_grad_accepts_bool_and_quoted_string() {
        let cases = [
            ("true", true),
            ("false", false),
            ("\"true\"", true),
            ("\"false\"", false),
            ("\"TRUE\"", true),
            ("\"False\"", false),
        ];

        for (literal, expected) in cases {
            let info: ParameterInfo = serde_json::from_str(&parameter_info_json(literal))
                .expect("requires_grad should deserialize");
            assert_eq!(info.requires_grad, expected, "input literal: {literal}");
        }
    }

    #[test]
    fn test_requires_grad_rejects_other_values() {
        // A string that is not a boolean spelling.
        assert!(serde_json::from_str::<ParameterInfo>(&parameter_info_json("\"yes\"")).is_err());
        // A value that is neither a boolean nor a string.
        assert!(serde_json::from_str::<ParameterInfo>(&parameter_info_json("1")).is_err());
    }

    #[test]
    fn test_parameter_info_clone() {
        let info = ParameterInfo {
            name: "layer1.weight".to_string(),
            shape: vec![10, 20],
            dtype: "f32".to_string(),
            requires_grad: true,
        };
        let cloned = info.clone();
        assert_eq!(info.name, cloned.name);
        assert_eq!(info.shape, cloned.shape);
        assert_eq!(info.dtype, cloned.dtype);
        assert_eq!(info.requires_grad, cloned.requires_grad);
    }

    #[test]
    fn test_model_state_fields() {
        let state = ModelState {
            metadata: ModelMetadata::new("test", "arch"),
            parameters: vec![ParameterInfo {
                name: "w".to_string(),
                shape: vec![5],
                dtype: "f32".to_string(),
                requires_grad: true,
            }],
            data: vec![1.0, 2.0, 3.0, 4.0, 5.0],
        };
        let cloned = state.clone();
        assert_eq!(state.parameters.len(), cloned.parameters.len());
        assert_eq!(state.data.len(), cloned.data.len());
    }

    #[test]
    fn test_model_metadata_clone() {
        let meta = ModelMetadata::new("model", "transformer");
        let cloned = meta.clone();
        assert_eq!(meta.name, cloned.name);
    }
}