Skip to main content

axonml_serialize/
state_dict.rs

1//! State Dictionary - Model parameter storage
2//!
3//! Provides `StateDict` for storing and retrieving model parameters by name.
4
5use axonml_core::Result;
6use axonml_nn::Module;
7use axonml_tensor::Tensor;
8use serde::{Deserialize, Serialize};
9use std::collections::HashMap;
10
11// =============================================================================
12// TensorData
13// =============================================================================
14
15/// Serializable tensor data.
16#[derive(Debug, Clone, Serialize, Deserialize)]
17pub struct TensorData {
18    /// Shape of the tensor.
19    pub shape: Vec<usize>,
20    /// Flattened f32 values.
21    pub values: Vec<f32>,
22}
23
24impl TensorData {
25    /// Create `TensorData` from a Tensor.
26    #[must_use]
27    pub fn from_tensor(tensor: &Tensor<f32>) -> Self {
28        Self {
29            shape: tensor.shape().to_vec(),
30            values: tensor.to_vec(),
31        }
32    }
33
34    /// Convert `TensorData` back to a Tensor.
35    pub fn to_tensor(&self) -> Result<Tensor<f32>> {
36        Tensor::from_vec(self.values.clone(), &self.shape)
37    }
38
39    /// Get the number of elements.
40    #[must_use]
41    pub fn numel(&self) -> usize {
42        self.values.len()
43    }
44
45    /// Get the shape.
46    #[must_use]
47    pub fn shape(&self) -> &[usize] {
48        &self.shape
49    }
50}
51
52// =============================================================================
53// StateDictEntry
54// =============================================================================
55
56/// An entry in the state dictionary.
57#[derive(Debug, Clone, Serialize, Deserialize)]
58pub struct StateDictEntry {
59    /// The tensor data.
60    pub data: TensorData,
61    /// Whether this parameter requires gradients.
62    pub requires_grad: bool,
63    /// Optional metadata.
64    #[serde(default)]
65    pub metadata: HashMap<String, String>,
66}
67
68impl StateDictEntry {
69    /// Create a new entry from tensor data.
70    #[must_use]
71    pub fn new(data: TensorData, requires_grad: bool) -> Self {
72        Self {
73            data,
74            requires_grad,
75            metadata: HashMap::new(),
76        }
77    }
78
79    /// Add metadata to the entry.
80    #[must_use]
81    pub fn with_metadata(mut self, key: &str, value: &str) -> Self {
82        self.metadata.insert(key.to_string(), value.to_string());
83        self
84    }
85}
86
87// =============================================================================
88// StateDict
89// =============================================================================
90
91/// State dictionary for storing model parameters.
92///
93/// This is similar to `PyTorch`'s `state_dict`, mapping parameter names to tensors.
94#[derive(Debug, Clone, Default, Serialize, Deserialize)]
95pub struct StateDict {
96    entries: HashMap<String, StateDictEntry>,
97    #[serde(default)]
98    metadata: HashMap<String, String>,
99}
100
101impl StateDict {
102    /// Create an empty state dictionary.
103    #[must_use]
104    pub fn new() -> Self {
105        Self::default()
106    }
107
108    /// Create a state dictionary from a module.
109    ///
110    /// Uses `named_parameters()` to get fully-qualified parameter names
111    /// (e.g., "encoder_lstm1.weight_ih") that are unique across sub-modules.
112    /// Falls back to `parameters()` if `named_parameters()` returns empty.
113    pub fn from_module<M: Module>(module: &M) -> Self {
114        let mut state_dict = Self::new();
115
116        let named = module.named_parameters();
117        if !named.is_empty() {
118            for (name, param) in named {
119                let tensor_data = TensorData::from_tensor(&param.data());
120                let entry = StateDictEntry::new(tensor_data, param.requires_grad());
121                state_dict.entries.insert(name, entry);
122            }
123        } else {
124            // Fallback for modules that don't implement named_parameters
125            for param in module.parameters() {
126                let name = param.name().to_string();
127                let tensor_data = TensorData::from_tensor(&param.data());
128                let entry = StateDictEntry::new(tensor_data, param.requires_grad());
129                state_dict.entries.insert(name, entry);
130            }
131        }
132
133        state_dict
134    }
135
136    /// Insert a tensor into the state dictionary.
137    pub fn insert(&mut self, name: String, data: TensorData) {
138        let entry = StateDictEntry::new(data, true);
139        self.entries.insert(name, entry);
140    }
141
142    /// Insert an entry into the state dictionary.
143    pub fn insert_entry(&mut self, name: String, entry: StateDictEntry) {
144        self.entries.insert(name, entry);
145    }
146
147    /// Get an entry by name.
148    #[must_use]
149    pub fn get(&self, name: &str) -> Option<&StateDictEntry> {
150        self.entries.get(name)
151    }
152
153    /// Get a mutable entry by name.
154    pub fn get_mut(&mut self, name: &str) -> Option<&mut StateDictEntry> {
155        self.entries.get_mut(name)
156    }
157
158    /// Check if the state dictionary contains a key.
159    #[must_use]
160    pub fn contains(&self, name: &str) -> bool {
161        self.entries.contains_key(name)
162    }
163
164    /// Get the number of entries.
165    #[must_use]
166    pub fn len(&self) -> usize {
167        self.entries.len()
168    }
169
170    /// Check if the state dictionary is empty.
171    #[must_use]
172    pub fn is_empty(&self) -> bool {
173        self.entries.is_empty()
174    }
175
176    /// Get all keys.
177    pub fn keys(&self) -> impl Iterator<Item = &String> {
178        self.entries.keys()
179    }
180
181    /// Get all entries.
182    pub fn entries(&self) -> impl Iterator<Item = (&String, &StateDictEntry)> {
183        self.entries.iter()
184    }
185
186    /// Remove an entry.
187    pub fn remove(&mut self, name: &str) -> Option<StateDictEntry> {
188        self.entries.remove(name)
189    }
190
191    /// Merge another state dictionary into this one.
192    pub fn merge(&mut self, other: StateDict) {
193        for (name, entry) in other.entries {
194            self.entries.insert(name, entry);
195        }
196    }
197
198    /// Get a subset of entries matching a prefix.
199    #[must_use]
200    pub fn filter_prefix(&self, prefix: &str) -> StateDict {
201        let mut filtered = StateDict::new();
202        for (name, entry) in &self.entries {
203            if name.starts_with(prefix) {
204                filtered.entries.insert(name.clone(), entry.clone());
205            }
206        }
207        filtered
208    }
209
210    /// Strip a prefix from all keys.
211    #[must_use]
212    pub fn strip_prefix(&self, prefix: &str) -> StateDict {
213        let mut stripped = StateDict::new();
214        for (name, entry) in &self.entries {
215            let new_name = name.strip_prefix(prefix).unwrap_or(name).to_string();
216            stripped.entries.insert(new_name, entry.clone());
217        }
218        stripped
219    }
220
221    /// Add a prefix to all keys.
222    #[must_use]
223    pub fn add_prefix(&self, prefix: &str) -> StateDict {
224        let mut prefixed = StateDict::new();
225        for (name, entry) in &self.entries {
226            let new_name = format!("{prefix}{name}");
227            prefixed.entries.insert(new_name, entry.clone());
228        }
229        prefixed
230    }
231
232    /// Set metadata on the state dictionary.
233    pub fn set_metadata(&mut self, key: &str, value: &str) {
234        self.metadata.insert(key.to_string(), value.to_string());
235    }
236
237    /// Get metadata from the state dictionary.
238    #[must_use]
239    pub fn get_metadata(&self, key: &str) -> Option<&String> {
240        self.metadata.get(key)
241    }
242
243    /// Get total number of parameters (elements across all tensors).
244    #[must_use]
245    pub fn total_params(&self) -> usize {
246        self.entries.values().map(|e| e.data.numel()).sum()
247    }
248
249    /// Get total size in bytes.
250    #[must_use]
251    pub fn size_bytes(&self) -> usize {
252        self.total_params() * std::mem::size_of::<f32>()
253    }
254
255    /// Print a summary of the state dictionary.
256    #[must_use]
257    pub fn summary(&self) -> String {
258        let mut lines = Vec::new();
259        lines.push(format!("StateDict with {} entries:", self.len()));
260        lines.push(format!("  Total parameters: {}", self.total_params()));
261        lines.push(format!("  Size: {} bytes", self.size_bytes()));
262        lines.push("  Entries:".to_string());
263
264        for (name, entry) in &self.entries {
265            lines.push(format!(
266                "    {} - shape: {:?}, numel: {}",
267                name,
268                entry.data.shape,
269                entry.data.numel()
270            ));
271        }
272
273        lines.join("\n")
274    }
275}
276
277// =============================================================================
278// Tests
279// =============================================================================
280
#[cfg(test)]
mod tests {
    use super::*;

    /// Zero-filled fixture whose value count equals the product of `shape`.
    fn zeros(shape: &[usize]) -> TensorData {
        TensorData {
            shape: shape.to_vec(),
            values: vec![0.0; shape.iter().product()],
        }
    }

    /// One-element fixture holding `value`.
    fn single(value: f32) -> TensorData {
        TensorData {
            shape: vec![1],
            values: vec![value],
        }
    }

    // Tensor -> TensorData -> Tensor must preserve shape and values.
    #[test]
    fn test_tensor_data_roundtrip() {
        let original = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3]).unwrap();
        let restored = TensorData::from_tensor(&original).to_tensor().unwrap();

        assert_eq!(original.shape(), restored.shape());
        assert_eq!(original.to_vec(), restored.to_vec());
    }

    // Basic insert / len / total_params / contains bookkeeping.
    #[test]
    fn test_state_dict_operations() {
        let mut state_dict = StateDict::new();

        state_dict.insert("linear.weight".to_string(), zeros(&[10, 5]));
        state_dict.insert("linear.bias".to_string(), zeros(&[5]));

        assert_eq!(state_dict.len(), 2);
        assert_eq!(state_dict.total_params(), 55);
        assert!(state_dict.contains("linear.weight"));
        assert!(state_dict.contains("linear.bias"));
    }

    // Only keys under the requested prefix survive filtering.
    #[test]
    fn test_state_dict_filter_prefix() {
        let mut state_dict = StateDict::new();

        for key in [
            "encoder.layer1.weight",
            "encoder.layer1.bias",
            "decoder.layer1.weight",
        ] {
            state_dict.insert(key.to_string(), zeros(&[10]));
        }

        let encoder_dict = state_dict.filter_prefix("encoder.");
        assert_eq!(encoder_dict.len(), 2);
        assert!(encoder_dict.contains("encoder.layer1.weight"));
    }

    // Stripping removes the leading prefix from matching keys.
    #[test]
    fn test_state_dict_strip_prefix() {
        let mut state_dict = StateDict::new();
        state_dict.insert("model.linear.weight".to_string(), zeros(&[10]));

        let stripped = state_dict.strip_prefix("model.");
        assert!(stripped.contains("linear.weight"));
    }

    // Merging brings the other dictionary's entries into the receiver.
    #[test]
    fn test_state_dict_merge() {
        let mut dict1 = StateDict::new();
        dict1.insert("a".to_string(), single(1.0));

        let mut dict2 = StateDict::new();
        dict2.insert("b".to_string(), single(2.0));

        dict1.merge(dict2);
        assert_eq!(dict1.len(), 2);
        assert!(dict1.contains("a"));
        assert!(dict1.contains("b"));
    }

    // The summary mentions the entry count and the parameter count.
    #[test]
    fn test_state_dict_summary() {
        let mut state_dict = StateDict::new();
        state_dict.insert("weight".to_string(), zeros(&[10, 5]));

        let summary = state_dict.summary();
        assert!(summary.contains("1 entries"));
        assert!(summary.contains("50"));
    }
}