1use axonml_core::Result;
6use axonml_nn::Module;
7use axonml_tensor::Tensor;
8use serde::{Deserialize, Serialize};
9use std::collections::HashMap;
10
11#[derive(Debug, Clone, Serialize, Deserialize)]
17pub struct TensorData {
18 pub shape: Vec<usize>,
20 pub values: Vec<f32>,
22}
23
24impl TensorData {
25 #[must_use]
27 pub fn from_tensor(tensor: &Tensor<f32>) -> Self {
28 Self {
29 shape: tensor.shape().to_vec(),
30 values: tensor.to_vec(),
31 }
32 }
33
34 pub fn to_tensor(&self) -> Result<Tensor<f32>> {
36 Tensor::from_vec(self.values.clone(), &self.shape)
37 }
38
39 #[must_use]
41 pub fn numel(&self) -> usize {
42 self.values.len()
43 }
44
45 #[must_use]
47 pub fn shape(&self) -> &[usize] {
48 &self.shape
49 }
50}
51
52#[derive(Debug, Clone, Serialize, Deserialize)]
58pub struct StateDictEntry {
59 pub data: TensorData,
61 pub requires_grad: bool,
63 #[serde(default)]
65 pub metadata: HashMap<String, String>,
66}
67
68impl StateDictEntry {
69 #[must_use]
71 pub fn new(data: TensorData, requires_grad: bool) -> Self {
72 Self {
73 data,
74 requires_grad,
75 metadata: HashMap::new(),
76 }
77 }
78
79 #[must_use]
81 pub fn with_metadata(mut self, key: &str, value: &str) -> Self {
82 self.metadata.insert(key.to_string(), value.to_string());
83 self
84 }
85}
86
87#[derive(Debug, Clone, Default, Serialize, Deserialize)]
95pub struct StateDict {
96 entries: HashMap<String, StateDictEntry>,
97 #[serde(default)]
98 metadata: HashMap<String, String>,
99}
100
101impl StateDict {
102 #[must_use]
104 pub fn new() -> Self {
105 Self::default()
106 }
107
108 pub fn from_module<M: Module>(module: &M) -> Self {
114 let mut state_dict = Self::new();
115
116 let named = module.named_parameters();
117 if !named.is_empty() {
118 for (name, param) in named {
119 let tensor_data = TensorData::from_tensor(¶m.data());
120 let entry = StateDictEntry::new(tensor_data, param.requires_grad());
121 state_dict.entries.insert(name, entry);
122 }
123 } else {
124 for param in module.parameters() {
126 let name = param.name().to_string();
127 let tensor_data = TensorData::from_tensor(¶m.data());
128 let entry = StateDictEntry::new(tensor_data, param.requires_grad());
129 state_dict.entries.insert(name, entry);
130 }
131 }
132
133 state_dict
134 }
135
136 pub fn insert(&mut self, name: String, data: TensorData) {
138 let entry = StateDictEntry::new(data, true);
139 self.entries.insert(name, entry);
140 }
141
142 pub fn insert_entry(&mut self, name: String, entry: StateDictEntry) {
144 self.entries.insert(name, entry);
145 }
146
147 #[must_use]
149 pub fn get(&self, name: &str) -> Option<&StateDictEntry> {
150 self.entries.get(name)
151 }
152
153 pub fn get_mut(&mut self, name: &str) -> Option<&mut StateDictEntry> {
155 self.entries.get_mut(name)
156 }
157
158 #[must_use]
160 pub fn contains(&self, name: &str) -> bool {
161 self.entries.contains_key(name)
162 }
163
164 #[must_use]
166 pub fn len(&self) -> usize {
167 self.entries.len()
168 }
169
170 #[must_use]
172 pub fn is_empty(&self) -> bool {
173 self.entries.is_empty()
174 }
175
176 pub fn keys(&self) -> impl Iterator<Item = &String> {
178 self.entries.keys()
179 }
180
181 pub fn entries(&self) -> impl Iterator<Item = (&String, &StateDictEntry)> {
183 self.entries.iter()
184 }
185
186 pub fn remove(&mut self, name: &str) -> Option<StateDictEntry> {
188 self.entries.remove(name)
189 }
190
191 pub fn merge(&mut self, other: StateDict) {
193 for (name, entry) in other.entries {
194 self.entries.insert(name, entry);
195 }
196 }
197
198 #[must_use]
200 pub fn filter_prefix(&self, prefix: &str) -> StateDict {
201 let mut filtered = StateDict::new();
202 for (name, entry) in &self.entries {
203 if name.starts_with(prefix) {
204 filtered.entries.insert(name.clone(), entry.clone());
205 }
206 }
207 filtered
208 }
209
210 #[must_use]
212 pub fn strip_prefix(&self, prefix: &str) -> StateDict {
213 let mut stripped = StateDict::new();
214 for (name, entry) in &self.entries {
215 let new_name = name.strip_prefix(prefix).unwrap_or(name).to_string();
216 stripped.entries.insert(new_name, entry.clone());
217 }
218 stripped
219 }
220
221 #[must_use]
223 pub fn add_prefix(&self, prefix: &str) -> StateDict {
224 let mut prefixed = StateDict::new();
225 for (name, entry) in &self.entries {
226 let new_name = format!("{prefix}{name}");
227 prefixed.entries.insert(new_name, entry.clone());
228 }
229 prefixed
230 }
231
232 pub fn set_metadata(&mut self, key: &str, value: &str) {
234 self.metadata.insert(key.to_string(), value.to_string());
235 }
236
237 #[must_use]
239 pub fn get_metadata(&self, key: &str) -> Option<&String> {
240 self.metadata.get(key)
241 }
242
243 #[must_use]
245 pub fn total_params(&self) -> usize {
246 self.entries.values().map(|e| e.data.numel()).sum()
247 }
248
249 #[must_use]
251 pub fn size_bytes(&self) -> usize {
252 self.total_params() * std::mem::size_of::<f32>()
253 }
254
255 #[must_use]
257 pub fn summary(&self) -> String {
258 let mut lines = Vec::new();
259 lines.push(format!("StateDict with {} entries:", self.len()));
260 lines.push(format!(" Total parameters: {}", self.total_params()));
261 lines.push(format!(" Size: {} bytes", self.size_bytes()));
262 lines.push(" Entries:".to_string());
263
264 for (name, entry) in &self.entries {
265 lines.push(format!(
266 " {} - shape: {:?}, numel: {}",
267 name,
268 entry.data.shape,
269 entry.data.numel()
270 ));
271 }
272
273 lines.join("\n")
274 }
275}
276
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `TensorData` of the given shape whose values all equal `fill`.
    fn filled(shape: &[usize], fill: f32) -> TensorData {
        let numel: usize = shape.iter().product();
        TensorData {
            shape: shape.to_vec(),
            values: vec![fill; numel],
        }
    }

    /// A tensor converted to `TensorData` and back keeps its shape and values.
    #[test]
    fn test_tensor_data_roundtrip() {
        let source = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3]).unwrap();
        let snapshot = TensorData::from_tensor(&source);
        let rebuilt = snapshot.to_tensor().unwrap();

        assert_eq!(source.shape(), rebuilt.shape());
        assert_eq!(source.to_vec(), rebuilt.to_vec());
    }

    /// Inserted entries are counted, summed and discoverable by name.
    #[test]
    fn test_state_dict_operations() {
        let mut dict = StateDict::new();

        dict.insert("linear.weight".to_string(), filled(&[10, 5], 0.0));
        dict.insert("linear.bias".to_string(), filled(&[5], 0.0));

        assert_eq!(dict.len(), 2);
        assert_eq!(dict.total_params(), 55);
        assert!(dict.contains("linear.weight"));
        assert!(dict.contains("linear.bias"));
    }

    /// `filter_prefix` keeps only the entries whose names start with the prefix.
    #[test]
    fn test_state_dict_filter_prefix() {
        let mut dict = StateDict::new();

        for name in [
            "encoder.layer1.weight",
            "encoder.layer1.bias",
            "decoder.layer1.weight",
        ] {
            dict.insert(name.to_string(), filled(&[10], 0.0));
        }

        let encoder_only = dict.filter_prefix("encoder.");
        assert_eq!(encoder_only.len(), 2);
        assert!(encoder_only.contains("encoder.layer1.weight"));
    }

    /// `strip_prefix` exposes the entry under its name without the prefix.
    #[test]
    fn test_state_dict_strip_prefix() {
        let mut dict = StateDict::new();
        dict.insert("model.linear.weight".to_string(), filled(&[10], 0.0));

        let without_prefix = dict.strip_prefix("model.");
        assert!(without_prefix.contains("linear.weight"));
    }

    /// `merge` brings the other dictionary's entries into the receiver.
    #[test]
    fn test_state_dict_merge() {
        let mut left = StateDict::new();
        left.insert("a".to_string(), filled(&[1], 1.0));

        let mut right = StateDict::new();
        right.insert("b".to_string(), filled(&[1], 2.0));

        left.merge(right);
        assert_eq!(left.len(), 2);
        assert!(left.contains("a"));
        assert!(left.contains("b"));
    }

    /// `summary` mentions the entry count and the parameter count.
    #[test]
    fn test_state_dict_summary() {
        let mut dict = StateDict::new();
        dict.insert("weight".to_string(), filled(&[10, 5], 0.0));

        let text = dict.summary();
        assert!(text.contains("1 entries"));
        assert!(text.contains("50"));
    }
}