Skip to main content

axonml_serialize/
state_dict.rs

//! State Dictionary - Model parameter storage
//!
//! # File
//! `crates/axonml-serialize/src/state_dict.rs`
//!
//! # Author
//! Andrew Jewell Sr - AutomataNexus
//!
//! # Updated
//! March 8, 2026
//!
//! # Disclaimer
//! Use at own risk. This software is provided "as is", without warranty of any
//! kind, express or implied. The author and AutomataNexus shall not be held
//! liable for any damages arising from the use of this software.
16
17use axonml_core::Result;
18use axonml_nn::Module;
19use axonml_tensor::Tensor;
20use serde::{Deserialize, Serialize};
21use std::collections::HashMap;
22
// =============================================================================
// TensorData
// =============================================================================
26
27/// Serializable tensor data.
28#[derive(Debug, Clone, Serialize, Deserialize)]
29pub struct TensorData {
30    /// Shape of the tensor.
31    pub shape: Vec<usize>,
32    /// Flattened f32 values.
33    pub values: Vec<f32>,
34}
35
36impl TensorData {
37    /// Create `TensorData` from a Tensor.
38    #[must_use]
39    pub fn from_tensor(tensor: &Tensor<f32>) -> Self {
40        Self {
41            shape: tensor.shape().to_vec(),
42            values: tensor.to_vec(),
43        }
44    }
45
46    /// Convert `TensorData` back to a Tensor.
47    pub fn to_tensor(&self) -> Result<Tensor<f32>> {
48        Tensor::from_vec(self.values.clone(), &self.shape)
49    }
50
51    /// Get the number of elements.
52    #[must_use]
53    pub fn numel(&self) -> usize {
54        self.values.len()
55    }
56
57    /// Get the shape.
58    #[must_use]
59    pub fn shape(&self) -> &[usize] {
60        &self.shape
61    }
62}
63
// =============================================================================
// StateDictEntry
// =============================================================================
67
68/// An entry in the state dictionary.
69#[derive(Debug, Clone, Serialize, Deserialize)]
70pub struct StateDictEntry {
71    /// The tensor data.
72    pub data: TensorData,
73    /// Whether this parameter requires gradients.
74    pub requires_grad: bool,
75    /// Optional metadata.
76    #[serde(default)]
77    pub metadata: HashMap<String, String>,
78}
79
80impl StateDictEntry {
81    /// Create a new entry from tensor data.
82    #[must_use]
83    pub fn new(data: TensorData, requires_grad: bool) -> Self {
84        Self {
85            data,
86            requires_grad,
87            metadata: HashMap::new(),
88        }
89    }
90
91    /// Add metadata to the entry.
92    #[must_use]
93    pub fn with_metadata(mut self, key: &str, value: &str) -> Self {
94        self.metadata.insert(key.to_string(), value.to_string());
95        self
96    }
97}
98
// =============================================================================
// StateDict
// =============================================================================
102
103/// State dictionary for storing model parameters.
104///
105/// This is similar to `PyTorch`'s `state_dict`, mapping parameter names to tensors.
106#[derive(Debug, Clone, Default, Serialize, Deserialize)]
107pub struct StateDict {
108    entries: HashMap<String, StateDictEntry>,
109    #[serde(default)]
110    metadata: HashMap<String, String>,
111}
112
113impl StateDict {
114    /// Create an empty state dictionary.
115    #[must_use]
116    pub fn new() -> Self {
117        Self::default()
118    }
119
120    /// Create a state dictionary from a module.
121    ///
122    /// Uses `named_parameters()` to get fully-qualified parameter names
123    /// (e.g., "encoder_lstm1.weight_ih") that are unique across sub-modules.
124    /// Falls back to `parameters()` if `named_parameters()` returns empty.
125    pub fn from_module<M: Module>(module: &M) -> Self {
126        let mut state_dict = Self::new();
127
128        let named = module.named_parameters();
129        if named.is_empty() {
130            // Fallback for modules that don't implement named_parameters.
131            // Always use indexed keys (param_0, param_1, ...) to avoid HashMap
132            // collisions when multiple sub-modules share parameter names like
133            // "weight" and "bias".
134            for (i, param) in module.parameters().iter().enumerate() {
135                let name = format!("param_{i}");
136                let tensor_data = TensorData::from_tensor(&param.data());
137                let entry = StateDictEntry::new(tensor_data, param.requires_grad());
138                state_dict.entries.insert(name, entry);
139            }
140        } else {
141            for (name, param) in named {
142                let tensor_data = TensorData::from_tensor(&param.data());
143                let entry = StateDictEntry::new(tensor_data, param.requires_grad());
144                state_dict.entries.insert(name, entry);
145            }
146        }
147
148        state_dict
149    }
150
151    /// Insert a tensor into the state dictionary.
152    pub fn insert(&mut self, name: String, data: TensorData) {
153        let entry = StateDictEntry::new(data, true);
154        self.entries.insert(name, entry);
155    }
156
157    /// Insert an entry into the state dictionary.
158    pub fn insert_entry(&mut self, name: String, entry: StateDictEntry) {
159        self.entries.insert(name, entry);
160    }
161
162    /// Get an entry by name.
163    #[must_use]
164    pub fn get(&self, name: &str) -> Option<&StateDictEntry> {
165        self.entries.get(name)
166    }
167
168    /// Get a mutable entry by name.
169    pub fn get_mut(&mut self, name: &str) -> Option<&mut StateDictEntry> {
170        self.entries.get_mut(name)
171    }
172
173    /// Check if the state dictionary contains a key.
174    #[must_use]
175    pub fn contains(&self, name: &str) -> bool {
176        self.entries.contains_key(name)
177    }
178
179    /// Get the number of entries.
180    #[must_use]
181    pub fn len(&self) -> usize {
182        self.entries.len()
183    }
184
185    /// Check if the state dictionary is empty.
186    #[must_use]
187    pub fn is_empty(&self) -> bool {
188        self.entries.is_empty()
189    }
190
191    /// Get all keys.
192    pub fn keys(&self) -> impl Iterator<Item = &String> {
193        self.entries.keys()
194    }
195
196    /// Get all entries.
197    pub fn entries(&self) -> impl Iterator<Item = (&String, &StateDictEntry)> {
198        self.entries.iter()
199    }
200
201    /// Remove an entry.
202    pub fn remove(&mut self, name: &str) -> Option<StateDictEntry> {
203        self.entries.remove(name)
204    }
205
206    /// Merge another state dictionary into this one.
207    pub fn merge(&mut self, other: StateDict) {
208        for (name, entry) in other.entries {
209            self.entries.insert(name, entry);
210        }
211    }
212
213    /// Get a subset of entries matching a prefix.
214    #[must_use]
215    pub fn filter_prefix(&self, prefix: &str) -> StateDict {
216        let mut filtered = StateDict::new();
217        for (name, entry) in &self.entries {
218            if name.starts_with(prefix) {
219                filtered.entries.insert(name.clone(), entry.clone());
220            }
221        }
222        filtered
223    }
224
225    /// Strip a prefix from all keys.
226    #[must_use]
227    pub fn strip_prefix(&self, prefix: &str) -> StateDict {
228        let mut stripped = StateDict::new();
229        for (name, entry) in &self.entries {
230            let new_name = name.strip_prefix(prefix).unwrap_or(name).to_string();
231            stripped.entries.insert(new_name, entry.clone());
232        }
233        stripped
234    }
235
236    /// Add a prefix to all keys.
237    #[must_use]
238    pub fn add_prefix(&self, prefix: &str) -> StateDict {
239        let mut prefixed = StateDict::new();
240        for (name, entry) in &self.entries {
241            let new_name = format!("{prefix}{name}");
242            prefixed.entries.insert(new_name, entry.clone());
243        }
244        prefixed
245    }
246
247    /// Set metadata on the state dictionary.
248    pub fn set_metadata(&mut self, key: &str, value: &str) {
249        self.metadata.insert(key.to_string(), value.to_string());
250    }
251
252    /// Get metadata from the state dictionary.
253    #[must_use]
254    pub fn get_metadata(&self, key: &str) -> Option<&String> {
255        self.metadata.get(key)
256    }
257
258    /// Get total number of parameters (elements across all tensors).
259    #[must_use]
260    pub fn total_params(&self) -> usize {
261        self.entries.values().map(|e| e.data.numel()).sum()
262    }
263
264    /// Get total size in bytes.
265    #[must_use]
266    pub fn size_bytes(&self) -> usize {
267        self.total_params() * std::mem::size_of::<f32>()
268    }
269
270    /// Print a summary of the state dictionary.
271    #[must_use]
272    pub fn summary(&self) -> String {
273        let mut lines = Vec::new();
274        lines.push(format!("StateDict with {} entries:", self.len()));
275        lines.push(format!("  Total parameters: {}", self.total_params()));
276        lines.push(format!("  Size: {} bytes", self.size_bytes()));
277        lines.push("  Entries:".to_string());
278
279        for (name, entry) in &self.entries {
280            lines.push(format!(
281                "    {} - shape: {:?}, numel: {}",
282                name,
283                entry.data.shape,
284                entry.data.numel()
285            ));
286        }
287
288        lines.join("\n")
289    }
290}
291
// =============================================================================
// Tests
// =============================================================================
295
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a zero-filled `TensorData` whose value count matches `shape`.
    fn zeros(shape: &[usize]) -> TensorData {
        TensorData {
            shape: shape.to_vec(),
            values: vec![0.0; shape.iter().product()],
        }
    }

    /// Build a one-element `TensorData` holding `value`.
    fn scalar(value: f32) -> TensorData {
        TensorData {
            shape: vec![1],
            values: vec![value],
        }
    }

    #[test]
    fn test_tensor_data_roundtrip() {
        let original = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3]).unwrap();
        let data = TensorData::from_tensor(&original);
        let restored = data.to_tensor().unwrap();

        assert_eq!(original.shape(), restored.shape());
        assert_eq!(original.to_vec(), restored.to_vec());
    }

    #[test]
    fn test_state_dict_operations() {
        let mut state_dict = StateDict::new();
        assert!(state_dict.is_empty());

        state_dict.insert("linear.weight".to_string(), zeros(&[10, 5]));
        state_dict.insert("linear.bias".to_string(), zeros(&[5]));

        assert!(!state_dict.is_empty());
        assert_eq!(state_dict.len(), 2);
        assert_eq!(state_dict.total_params(), 55);
        assert_eq!(state_dict.size_bytes(), 55 * std::mem::size_of::<f32>());
        assert!(state_dict.contains("linear.weight"));
        assert!(state_dict.contains("linear.bias"));
        assert!(!state_dict.contains("linear.missing"));
        assert!(state_dict.get("linear.missing").is_none());

        // `insert` always marks the new entry as requiring gradients.
        assert!(state_dict.get("linear.weight").unwrap().requires_grad);
    }

    #[test]
    fn test_state_dict_remove() {
        let mut state_dict = StateDict::new();
        state_dict.insert("a".to_string(), scalar(1.0));

        let removed = state_dict.remove("a");
        assert!(removed.is_some());
        assert!(state_dict.is_empty());
        assert!(state_dict.remove("a").is_none());
    }

    #[test]
    fn test_state_dict_filter_prefix() {
        let mut state_dict = StateDict::new();

        state_dict.insert("encoder.layer1.weight".to_string(), zeros(&[10]));
        state_dict.insert("encoder.layer1.bias".to_string(), zeros(&[10]));
        state_dict.insert("decoder.layer1.weight".to_string(), zeros(&[10]));

        let encoder_dict = state_dict.filter_prefix("encoder.");
        assert_eq!(encoder_dict.len(), 2);
        assert!(encoder_dict.contains("encoder.layer1.weight"));
        assert!(encoder_dict.contains("encoder.layer1.bias"));
        assert!(!encoder_dict.contains("decoder.layer1.weight"));

        // The source dictionary is left untouched.
        assert_eq!(state_dict.len(), 3);
    }

    #[test]
    fn test_state_dict_strip_prefix() {
        let mut state_dict = StateDict::new();

        state_dict.insert("model.linear.weight".to_string(), zeros(&[10]));
        state_dict.insert("other.bias".to_string(), zeros(&[10]));

        let stripped = state_dict.strip_prefix("model.");
        assert_eq!(stripped.len(), 2);
        assert!(stripped.contains("linear.weight"));
        assert!(!stripped.contains("model.linear.weight"));
        // Names without the prefix are carried over unchanged.
        assert!(stripped.contains("other.bias"));
    }

    #[test]
    fn test_state_dict_add_prefix() {
        let mut state_dict = StateDict::new();
        state_dict.insert("linear.weight".to_string(), zeros(&[10]));

        let prefixed = state_dict.add_prefix("model.");
        assert_eq!(prefixed.len(), 1);
        assert!(prefixed.contains("model.linear.weight"));
        assert!(!prefixed.contains("linear.weight"));
    }

    #[test]
    fn test_state_dict_merge() {
        let mut dict1 = StateDict::new();
        dict1.insert("a".to_string(), scalar(1.0));

        let mut dict2 = StateDict::new();
        dict2.insert("b".to_string(), scalar(2.0));

        dict1.merge(dict2);
        assert_eq!(dict1.len(), 2);
        assert!(dict1.contains("a"));
        assert!(dict1.contains("b"));
    }

    #[test]
    fn test_state_dict_merge_overwrites_existing() {
        let mut dict1 = StateDict::new();
        dict1.insert("a".to_string(), scalar(1.0));

        let mut dict2 = StateDict::new();
        dict2.insert("a".to_string(), scalar(2.0));

        dict1.merge(dict2);
        assert_eq!(dict1.len(), 1);
        assert_eq!(dict1.get("a").unwrap().data.values, vec![2.0]);
    }

    #[test]
    fn test_state_dict_metadata() {
        let mut state_dict = StateDict::new();
        assert!(state_dict.get_metadata("version").is_none());

        state_dict.set_metadata("version", "1");
        assert_eq!(
            state_dict.get_metadata("version").map(String::as_str),
            Some("1")
        );

        state_dict.set_metadata("version", "2");
        assert_eq!(
            state_dict.get_metadata("version").map(String::as_str),
            Some("2")
        );
    }

    #[test]
    fn test_state_dict_entry_with_metadata() {
        let entry = StateDictEntry::new(scalar(1.0), false).with_metadata("dtype", "f32");

        assert!(!entry.requires_grad);
        assert_eq!(entry.metadata.get("dtype").map(String::as_str), Some("f32"));
    }

    #[test]
    fn test_state_dict_summary() {
        let mut state_dict = StateDict::new();
        state_dict.insert("weight".to_string(), zeros(&[10, 5]));

        let summary = state_dict.summary();
        assert!(summary.contains("1 entries"));
        assert!(summary.contains("50"));
        assert!(summary.contains("weight"));
    }
}