axonml_serialize/state_dict.rs

//! State Dictionary - Model parameter storage
//!
//! Provides `StateDict` for storing and retrieving model parameters by name.

use axonml_core::Result;
use axonml_nn::Module;
use axonml_tensor::Tensor;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;

// =============================================================================
// TensorData
// =============================================================================

/// Serializable tensor data.
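///
/// A minimal round-trip sketch (crate and module paths here are assumed, so
/// the example is marked `ignore`):
///
/// ```ignore
/// use axonml_serialize::state_dict::TensorData;
/// use axonml_tensor::Tensor;
///
/// // Capture a tensor's shape and flattened values, then rebuild it.
/// let tensor = Tensor::from_vec(vec![1.0_f32, 2.0, 3.0, 4.0], &[2, 2]).unwrap();
/// let data = TensorData::from_tensor(&tensor);
/// assert_eq!(data.shape(), &[2, 2]);
///
/// let restored = data.to_tensor().unwrap();
/// assert_eq!(tensor.to_vec(), restored.to_vec());
/// ```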
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TensorData {
    /// Shape of the tensor.
    pub shape: Vec<usize>,
    /// Flattened f32 values.
    pub values: Vec<f32>,
}

impl TensorData {
    /// Create `TensorData` from a Tensor.
    #[must_use] pub fn from_tensor(tensor: &Tensor<f32>) -> Self {
        Self {
            shape: tensor.shape().to_vec(),
            values: tensor.to_vec(),
        }
    }

    /// Convert `TensorData` back to a Tensor.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by `Tensor::from_vec`.
    pub fn to_tensor(&self) -> Result<Tensor<f32>> {
        Tensor::from_vec(self.values.clone(), &self.shape)
    }

    /// Get the number of elements.
    #[must_use] pub fn numel(&self) -> usize {
        self.values.len()
    }

    /// Get the shape.
    #[must_use] pub fn shape(&self) -> &[usize] {
        &self.shape
    }
}

// =============================================================================
// StateDictEntry
// =============================================================================

/// An entry in the state dictionary.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StateDictEntry {
    /// The tensor data.
    pub data: TensorData,
    /// Whether this parameter requires gradients.
    pub requires_grad: bool,
    /// Optional metadata.
    #[serde(default)]
    pub metadata: HashMap<String, String>,
}

impl StateDictEntry {
    /// Create a new entry from tensor data.
    #[must_use] pub fn new(data: TensorData, requires_grad: bool) -> Self {
        Self {
            data,
            requires_grad,
            metadata: HashMap::new(),
        }
    }

    /// Add metadata to the entry.
    #[must_use] pub fn with_metadata(mut self, key: &str, value: &str) -> Self {
        self.metadata.insert(key.to_string(), value.to_string());
        self
    }
}

// =============================================================================
// StateDict
// =============================================================================

/// State dictionary for storing model parameters.
///
/// This is similar to `PyTorch`'s `state_dict`, mapping parameter names to tensors.
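///
/// A minimal usage sketch (crate and module paths here are assumed, so the
/// example is marked `ignore`):
///
/// ```ignore
/// use axonml_serialize::state_dict::{StateDict, TensorData};
///
/// let mut state_dict = StateDict::new();
/// state_dict.insert(
///     "linear.weight".to_string(),
///     TensorData { shape: vec![4, 2], values: vec![0.0; 8] },
/// );
///
/// assert!(state_dict.contains("linear.weight"));
/// assert_eq!(state_dict.total_params(), 8);
/// assert_eq!(state_dict.size_bytes(), 8 * 4); // f32 values are 4 bytes each
/// ```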
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct StateDict {
    entries: HashMap<String, StateDictEntry>,
    #[serde(default)]
    metadata: HashMap<String, String>,
}

impl StateDict {
    /// Create an empty state dictionary.
    #[must_use] pub fn new() -> Self {
        Self::default()
    }

    /// Create a state dictionary from a module.
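    ///
    /// A sketch of the intended use: any type implementing `axonml_nn::Module`
    /// can be snapshotted by parameter name (marked `ignore` because no
    /// concrete module type is assumed here).
    ///
    /// ```ignore
    /// use axonml_nn::Module;
    /// use axonml_serialize::state_dict::StateDict;
    ///
    /// // Generic helper: collect a module's named parameters for saving.
    /// fn snapshot<M: Module>(module: &M) -> StateDict {
    ///     StateDict::from_module(module)
    /// }
    /// ```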
    pub fn from_module<M: Module>(module: &M) -> Self {
        let mut state_dict = Self::new();

        for param in module.parameters() {
            let name = param.name().to_string();
            let tensor_data = TensorData::from_tensor(&param.data());
            let entry = StateDictEntry::new(tensor_data, param.requires_grad());
            state_dict.entries.insert(name, entry);
        }

        state_dict
    }

    /// Insert tensor data into the state dictionary.
    ///
    /// The entry is stored with `requires_grad` set to `true`.
    pub fn insert(&mut self, name: String, data: TensorData) {
        let entry = StateDictEntry::new(data, true);
        self.entries.insert(name, entry);
    }

    /// Insert an entry into the state dictionary.
    pub fn insert_entry(&mut self, name: String, entry: StateDictEntry) {
        self.entries.insert(name, entry);
    }

    /// Get an entry by name.
    #[must_use] pub fn get(&self, name: &str) -> Option<&StateDictEntry> {
        self.entries.get(name)
    }

    /// Get a mutable entry by name.
    pub fn get_mut(&mut self, name: &str) -> Option<&mut StateDictEntry> {
        self.entries.get_mut(name)
    }

    /// Check if the state dictionary contains a key.
    #[must_use] pub fn contains(&self, name: &str) -> bool {
        self.entries.contains_key(name)
    }

    /// Get the number of entries.
    #[must_use] pub fn len(&self) -> usize {
        self.entries.len()
    }

    /// Check if the state dictionary is empty.
    #[must_use] pub fn is_empty(&self) -> bool {
        self.entries.is_empty()
    }

    /// Get all keys.
    pub fn keys(&self) -> impl Iterator<Item = &String> {
        self.entries.keys()
    }

    /// Get all entries.
    pub fn entries(&self) -> impl Iterator<Item = (&String, &StateDictEntry)> {
        self.entries.iter()
    }

    /// Remove an entry.
    pub fn remove(&mut self, name: &str) -> Option<StateDictEntry> {
        self.entries.remove(name)
    }

    /// Merge another state dictionary into this one.
    ///
    /// Entries from `other` overwrite existing entries with the same name.
    pub fn merge(&mut self, other: StateDict) {
        for (name, entry) in other.entries {
            self.entries.insert(name, entry);
        }
    }

    /// Get a subset of entries matching a prefix.
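    ///
    /// A minimal sketch of selecting a submodule's parameters and re-rooting
    /// their keys (the key names are illustrative):
    ///
    /// ```ignore
    /// // Keep only encoder parameters, then drop the "encoder." prefix so the
    /// // keys line up with a standalone encoder module.
    /// let encoder_only = state_dict.filter_prefix("encoder.").strip_prefix("encoder.");
    /// assert!(encoder_only.keys().all(|k| !k.starts_with("encoder.")));
    /// ```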
    #[must_use] pub fn filter_prefix(&self, prefix: &str) -> StateDict {
        let mut filtered = StateDict::new();
        for (name, entry) in &self.entries {
            if name.starts_with(prefix) {
                filtered.entries.insert(name.clone(), entry.clone());
            }
        }
        filtered
    }

    /// Strip a prefix from all keys.
    ///
    /// Keys that do not start with the prefix are kept unchanged.
    #[must_use] pub fn strip_prefix(&self, prefix: &str) -> StateDict {
        let mut stripped = StateDict::new();
        for (name, entry) in &self.entries {
            let new_name = name.strip_prefix(prefix).unwrap_or(name).to_string();
            stripped.entries.insert(new_name, entry.clone());
        }
        stripped
    }

    /// Add a prefix to all keys.
    #[must_use] pub fn add_prefix(&self, prefix: &str) -> StateDict {
        let mut prefixed = StateDict::new();
        for (name, entry) in &self.entries {
            let new_name = format!("{prefix}{name}");
            prefixed.entries.insert(new_name, entry.clone());
        }
        prefixed
    }

    /// Set metadata on the state dictionary.
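    ///
    /// A minimal sketch (the key name is illustrative, not a required
    /// convention):
    ///
    /// ```ignore
    /// let mut state_dict = StateDict::new();
    /// state_dict.set_metadata("format_version", "1");
    /// assert_eq!(
    ///     state_dict.get_metadata("format_version").map(String::as_str),
    ///     Some("1"),
    /// );
    /// ```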
    pub fn set_metadata(&mut self, key: &str, value: &str) {
        self.metadata.insert(key.to_string(), value.to_string());
    }

    /// Get metadata from the state dictionary.
    #[must_use] pub fn get_metadata(&self, key: &str) -> Option<&String> {
        self.metadata.get(key)
    }

    /// Get total number of parameters (elements across all tensors).
    #[must_use] pub fn total_params(&self) -> usize {
        self.entries.values().map(|e| e.data.numel()).sum()
    }

    /// Get the total size of all tensor values in bytes (f32 storage).
    #[must_use] pub fn size_bytes(&self) -> usize {
        self.total_params() * std::mem::size_of::<f32>()
    }

    /// Render a multi-line summary of the state dictionary as a string.
    #[must_use] pub fn summary(&self) -> String {
        let mut lines = Vec::new();
        lines.push(format!("StateDict with {} entries:", self.len()));
        lines.push(format!("  Total parameters: {}", self.total_params()));
        lines.push(format!("  Size: {} bytes", self.size_bytes()));
        lines.push("  Entries:".to_string());

        for (name, entry) in &self.entries {
            lines.push(format!(
                "    {} - shape: {:?}, numel: {}",
                name,
                entry.data.shape,
                entry.data.numel()
            ));
        }

        lines.join("\n")
    }
}

// =============================================================================
// Tests
// =============================================================================

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_tensor_data_roundtrip() {
        let original = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3]).unwrap();
        let data = TensorData::from_tensor(&original);
        let restored = data.to_tensor().unwrap();

        assert_eq!(original.shape(), restored.shape());
        assert_eq!(original.to_vec(), restored.to_vec());
    }

    #[test]
    fn test_state_dict_operations() {
        let mut state_dict = StateDict::new();

        let data1 = TensorData {
            shape: vec![10, 5],
            values: vec![0.0; 50],
        };
        let data2 = TensorData {
            shape: vec![5],
            values: vec![0.0; 5],
        };

        state_dict.insert("linear.weight".to_string(), data1);
        state_dict.insert("linear.bias".to_string(), data2);

        assert_eq!(state_dict.len(), 2);
        assert_eq!(state_dict.total_params(), 55);
        assert!(state_dict.contains("linear.weight"));
        assert!(state_dict.contains("linear.bias"));
    }

    #[test]
    fn test_state_dict_filter_prefix() {
        let mut state_dict = StateDict::new();

        state_dict.insert(
            "encoder.layer1.weight".to_string(),
            TensorData {
                shape: vec![10],
                values: vec![0.0; 10],
            },
        );
        state_dict.insert(
            "encoder.layer1.bias".to_string(),
            TensorData {
                shape: vec![10],
                values: vec![0.0; 10],
            },
        );
        state_dict.insert(
            "decoder.layer1.weight".to_string(),
            TensorData {
                shape: vec![10],
                values: vec![0.0; 10],
            },
        );

        let encoder_dict = state_dict.filter_prefix("encoder.");
        assert_eq!(encoder_dict.len(), 2);
        assert!(encoder_dict.contains("encoder.layer1.weight"));
    }

    #[test]
    fn test_state_dict_strip_prefix() {
        let mut state_dict = StateDict::new();

        state_dict.insert(
            "model.linear.weight".to_string(),
            TensorData {
                shape: vec![10],
                values: vec![0.0; 10],
            },
        );

        let stripped = state_dict.strip_prefix("model.");
        assert!(stripped.contains("linear.weight"));
    }

    #[test]
    fn test_state_dict_merge() {
        let mut dict1 = StateDict::new();
        dict1.insert(
            "a".to_string(),
            TensorData {
                shape: vec![1],
                values: vec![1.0],
            },
        );

        let mut dict2 = StateDict::new();
        dict2.insert(
            "b".to_string(),
            TensorData {
                shape: vec![1],
                values: vec![2.0],
            },
        );

        dict1.merge(dict2);
        assert_eq!(dict1.len(), 2);
        assert!(dict1.contains("a"));
        assert!(dict1.contains("b"));
    }

    #[test]
    fn test_state_dict_summary() {
        let mut state_dict = StateDict::new();
        state_dict.insert(
            "weight".to_string(),
            TensorData {
                shape: vec![10, 5],
                values: vec![0.0; 50],
            },
        );

        let summary = state_dict.summary();
        assert!(summary.contains("1 entries"));
        assert!(summary.contains("50"));
    }
}