1use crate::Tensor;
4use serde::{Deserialize, Deserializer, Serialize};
5use std::collections::HashMap;
6
7fn deserialize_bool_lenient<'de, D>(deserializer: D) -> Result<bool, D::Error>
10where
11 D: Deserializer<'de>,
12{
13 #[derive(Deserialize)]
14 #[serde(untagged)]
15 enum BoolOrString {
16 Bool(bool),
17 Str(String),
18 }
19
20 match BoolOrString::deserialize(deserializer)? {
21 BoolOrString::Bool(b) => Ok(b),
22 BoolOrString::Str(s) => match s.to_lowercase().as_str() {
23 "true" => Ok(true),
24 "false" => Ok(false),
25 other => {
26 Err(serde::de::Error::custom(format!("expected 'true' or 'false', got '{other}'")))
27 }
28 },
29 }
30}
31
32#[derive(Debug, Clone, Serialize, Deserialize)]
34pub struct ModelMetadata {
35 pub name: String,
37
38 pub architecture: String,
40
41 pub version: String,
43
44 pub training_config: Option<HashMap<String, serde_json::Value>>,
46
47 pub custom: HashMap<String, serde_json::Value>,
49}
50
51impl ModelMetadata {
52 pub fn new(name: impl Into<String>, architecture: impl Into<String>) -> Self {
54 Self {
55 name: name.into(),
56 architecture: architecture.into(),
57 version: "0.1.0".to_string(),
58 training_config: None,
59 custom: HashMap::new(),
60 }
61 }
62
63 pub fn with_custom(mut self, key: impl Into<String>, value: serde_json::Value) -> Self {
65 self.custom.insert(key.into(), value);
66 self
67 }
68}
69
/// Serialized description of a single model parameter inside a [`ModelState`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ParameterInfo {
    /// Parameter name, used for lookup after loading.
    pub name: String,

    /// Parameter dimensions. Their product is the number of values this parameter occupies in
    /// the flat `ModelState::data` buffer.
    pub shape: Vec<usize>,

    /// Element type name. This module only ever writes `"f32"`.
    /// NOTE(review): the value is not checked when loading a state.
    pub dtype: String,

    /// Whether the restored tensor should track gradients.
    ///
    /// Deserialization accepts a boolean or a case-insensitive `"true"`/`"false"` string.
    #[serde(deserialize_with = "deserialize_bool_lenient")]
    pub requires_grad: bool,
}
86
/// Serializable snapshot of a [`Model`]: metadata, parameter descriptors and a flat value buffer.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelState {
    /// Model-level metadata.
    pub metadata: ModelMetadata,

    /// Parameter descriptors. Their order defines the layout of `data`.
    pub parameters: Vec<ParameterInfo>,

    /// Values of all parameters, concatenated in the order of `parameters`.
    pub data: Vec<f32>,
}
99
/// In-memory model: metadata plus an ordered list of named tensors.
pub struct Model {
    /// Model-level metadata.
    pub metadata: ModelMetadata,

    /// Named parameters in insertion order. Duplicate names are not rejected; lookups by name
    /// return the first match.
    pub parameters: Vec<(String, Tensor)>,
}
108
109impl Model {
110 pub fn new(metadata: ModelMetadata, parameters: Vec<(String, Tensor)>) -> Self {
112 Self { metadata, parameters }
113 }
114
115 pub fn get_parameter(&self, name: &str) -> Option<&Tensor> {
117 self.parameters.iter().find(|(n, _)| n == name).map(|(_, t)| t)
118 }
119
120 pub fn get_parameter_mut(&mut self, name: &str) -> Option<&mut Tensor> {
122 self.parameters.iter_mut().find(|(n, _)| n == name).map(|(_, t)| t)
123 }
124
125 pub fn to_state(&self) -> ModelState {
127 let mut data = Vec::new();
128 let parameters: Vec<ParameterInfo> = self
129 .parameters
130 .iter()
131 .map(|(name, tensor)| {
132 let shape = vec![tensor.len()];
133 let param_data = tensor.data();
134 data.extend_from_slice(
135 param_data.as_slice().expect("tensor data must be contiguous"),
136 );
137
138 ParameterInfo {
139 name: name.clone(),
140 shape,
141 dtype: "f32".to_string(),
142 requires_grad: tensor.requires_grad(),
143 }
144 })
145 .collect();
146
147 ModelState { metadata: self.metadata.clone(), parameters, data }
148 }
149
150 pub fn from_state(state: ModelState) -> Self {
152 let mut data_offset = 0;
153 let parameters: Vec<(String, Tensor)> = state
154 .parameters
155 .into_iter()
156 .map(|param_info| {
157 let size: usize = param_info.shape.iter().product();
158 let param_data = state.data[data_offset..data_offset + size].to_vec();
159 data_offset += size;
160
161 let tensor = Tensor::from_vec(param_data, param_info.requires_grad);
162 (param_info.name, tensor)
163 })
164 .collect();
165
166 Self { metadata: state.metadata, parameters }
167 }
168}
169
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_model_metadata_creation() {
        let meta = ModelMetadata::new("test-model", "linear");
        assert_eq!(meta.name, "test-model");
        assert_eq!(meta.architecture, "linear");
        assert_eq!(meta.version, "0.1.0");
    }

    #[test]
    fn test_model_with_custom_metadata() {
        let meta = ModelMetadata::new("test", "custom")
            .with_custom("layers", serde_json::json!(12))
            .with_custom("hidden_size", serde_json::json!(768));

        assert_eq!(meta.custom.len(), 2);
        assert_eq!(meta.custom.get("layers").expect("key should exist"), &serde_json::json!(12));
    }

    #[test]
    fn test_model_parameter_access() {
        let params = vec![
            ("weight".to_string(), Tensor::from_vec(vec![1.0, 2.0, 3.0], true)),
            ("bias".to_string(), Tensor::from_vec(vec![0.1], false)),
        ];

        let model = Model::new(ModelMetadata::new("test", "linear"), params);

        assert!(model.get_parameter("weight").is_some());
        assert!(model.get_parameter("bias").is_some());
        assert!(model.get_parameter("nonexistent").is_none());
    }

    #[test]
    fn test_model_state_round_trip() {
        let params = vec![
            ("weight".to_string(), Tensor::from_vec(vec![1.0, 2.0, 3.0], true)),
            ("bias".to_string(), Tensor::from_vec(vec![0.1], false)),
        ];

        let original = Model::new(ModelMetadata::new("test", "linear"), params);
        let state = original.to_state();
        let restored = Model::from_state(state);

        assert_eq!(original.metadata.name, restored.metadata.name);
        assert_eq!(original.parameters.len(), restored.parameters.len());

        // Check every parameter, not just the first, so the data-offset bookkeeping is
        // exercised and the gradient flag is covered too.
        for (name, orig_tensor) in &original.parameters {
            let rest_tensor = restored.get_parameter(name).expect("parameter should exist");
            assert_eq!(orig_tensor.data(), rest_tensor.data(), "data mismatch for {name}");
            assert_eq!(
                orig_tensor.requires_grad(),
                rest_tensor.requires_grad(),
                "requires_grad mismatch for {name}"
            );
        }
    }

    #[test]
    fn test_to_state_layout() {
        let params = vec![
            ("weight".to_string(), Tensor::from_vec(vec![1.0, 2.0, 3.0], true)),
            ("bias".to_string(), Tensor::from_vec(vec![0.5], false)),
        ];
        let state = Model::new(ModelMetadata::new("test", "linear"), params).to_state();

        assert_eq!(state.data, vec![1.0, 2.0, 3.0, 0.5]);
        assert_eq!(state.parameters.len(), 2);
        assert_eq!(state.parameters[0].shape, vec![3]);
        assert_eq!(state.parameters[1].shape, vec![1]);
        assert!(state.parameters[0].requires_grad);
        assert!(!state.parameters[1].requires_grad);
        assert!(state.parameters.iter().all(|p| p.dtype == "f32"));
    }

    #[test]
    fn test_parameter_info_lenient_requires_grad() {
        let parse = |raw: &str| {
            serde_json::from_str::<ParameterInfo>(&format!(
                r#"{{"name":"w","shape":[2],"dtype":"f32","requires_grad":{raw}}}"#
            ))
        };

        assert!(parse("true").expect("bool literal should parse").requires_grad);
        assert!(!parse("false").expect("bool literal should parse").requires_grad);
        assert!(parse(r#""TRUE""#).expect("uppercase string should parse").requires_grad);
        assert!(!parse(r#""False""#).expect("mixed-case string should parse").requires_grad);
        assert!(parse(r#""yes""#).is_err());
        assert!(parse("1").is_err());
    }

    #[test]
    fn test_model_get_parameter_mut() {
        let params = vec![("weight".to_string(), Tensor::from_vec(vec![1.0, 2.0], true))];
        let mut model = Model::new(ModelMetadata::new("test", "linear"), params);

        let tensor = model.get_parameter_mut("weight").expect("parameter should exist");
        assert!(tensor.requires_grad());

        assert!(model.get_parameter_mut("nonexistent").is_none());
    }

    #[test]
    fn test_parameter_info_clone() {
        let info = ParameterInfo {
            name: "layer1.weight".to_string(),
            shape: vec![10, 20],
            dtype: "f32".to_string(),
            requires_grad: true,
        };
        let cloned = info.clone();
        assert_eq!(info.name, cloned.name);
        assert_eq!(info.shape, cloned.shape);
    }

    #[test]
    fn test_model_state_fields() {
        let state = ModelState {
            metadata: ModelMetadata::new("test", "arch"),
            parameters: vec![ParameterInfo {
                name: "w".to_string(),
                shape: vec![5],
                dtype: "f32".to_string(),
                requires_grad: true,
            }],
            data: vec![1.0, 2.0, 3.0, 4.0, 5.0],
        };
        let cloned = state.clone();
        assert_eq!(state.parameters.len(), cloned.parameters.len());
        assert_eq!(state.data.len(), cloned.data.len());
    }

    #[test]
    fn test_model_metadata_clone() {
        let meta = ModelMetadata::new("model", "transformer");
        let cloned = meta.clone();
        assert_eq!(meta.name, cloned.name);
    }
}