1use axonml_core::Result;
18use axonml_nn::Module;
19use axonml_tensor::Tensor;
20use serde::{Deserialize, Serialize};
21use std::collections::HashMap;
22
23#[derive(Debug, Clone, Serialize, Deserialize)]
29pub struct TensorData {
30 pub shape: Vec<usize>,
32 pub values: Vec<f32>,
34}
35
36impl TensorData {
37 #[must_use]
39 pub fn from_tensor(tensor: &Tensor<f32>) -> Self {
40 Self {
41 shape: tensor.shape().to_vec(),
42 values: tensor.to_vec(),
43 }
44 }
45
46 pub fn to_tensor(&self) -> Result<Tensor<f32>> {
48 Tensor::from_vec(self.values.clone(), &self.shape)
49 }
50
51 #[must_use]
53 pub fn numel(&self) -> usize {
54 self.values.len()
55 }
56
57 #[must_use]
59 pub fn shape(&self) -> &[usize] {
60 &self.shape
61 }
62}
63
#[derive(Debug, Clone, Serialize, Deserialize)]
/// One named tensor stored in a state dictionary.
///
/// Holds the tensor snapshot, a gradient flag and optional per-entry string metadata.
pub struct StateDictEntry {
    /// Snapshot of the tensor's shape and values.
    pub data: TensorData,
    /// Gradient-tracking flag recorded for this tensor.
    ///
    /// `StateDict::from_module` copies it from the parameter; `StateDict::insert` sets `true`.
    pub requires_grad: bool,
    /// Free-form key/value annotations for this entry.
    ///
    /// Deserializes to an empty map when the field is absent.
    #[serde(default)]
    pub metadata: HashMap<String, String>,
}
79
80impl StateDictEntry {
81 #[must_use]
83 pub fn new(data: TensorData, requires_grad: bool) -> Self {
84 Self {
85 data,
86 requires_grad,
87 metadata: HashMap::new(),
88 }
89 }
90
91 #[must_use]
93 pub fn with_metadata(mut self, key: &str, value: &str) -> Self {
94 self.metadata.insert(key.to_string(), value.to_string());
95 self
96 }
97}
98
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
/// A name-keyed collection of tensor entries plus dictionary-level string metadata.
///
/// Entries are stored in a `HashMap`, so iteration order over names is unspecified.
pub struct StateDict {
    /// Tensor entries keyed by parameter name.
    entries: HashMap<String, StateDictEntry>,
    /// Dictionary-level key/value annotations.
    ///
    /// Deserializes to an empty map when the field is absent.
    #[serde(default)]
    metadata: HashMap<String, String>,
}
112
113impl StateDict {
114 #[must_use]
116 pub fn new() -> Self {
117 Self::default()
118 }
119
120 pub fn from_module<M: Module>(module: &M) -> Self {
126 let mut state_dict = Self::new();
127
128 let named = module.named_parameters();
129 if named.is_empty() {
130 for (i, param) in module.parameters().iter().enumerate() {
135 let name = format!("param_{i}");
136 let tensor_data = TensorData::from_tensor(¶m.data());
137 let entry = StateDictEntry::new(tensor_data, param.requires_grad());
138 state_dict.entries.insert(name, entry);
139 }
140 } else {
141 for (name, param) in named {
142 let tensor_data = TensorData::from_tensor(¶m.data());
143 let entry = StateDictEntry::new(tensor_data, param.requires_grad());
144 state_dict.entries.insert(name, entry);
145 }
146 }
147
148 state_dict
149 }
150
151 pub fn insert(&mut self, name: String, data: TensorData) {
153 let entry = StateDictEntry::new(data, true);
154 self.entries.insert(name, entry);
155 }
156
157 pub fn insert_entry(&mut self, name: String, entry: StateDictEntry) {
159 self.entries.insert(name, entry);
160 }
161
162 #[must_use]
164 pub fn get(&self, name: &str) -> Option<&StateDictEntry> {
165 self.entries.get(name)
166 }
167
168 pub fn get_mut(&mut self, name: &str) -> Option<&mut StateDictEntry> {
170 self.entries.get_mut(name)
171 }
172
173 #[must_use]
175 pub fn contains(&self, name: &str) -> bool {
176 self.entries.contains_key(name)
177 }
178
179 #[must_use]
181 pub fn len(&self) -> usize {
182 self.entries.len()
183 }
184
185 #[must_use]
187 pub fn is_empty(&self) -> bool {
188 self.entries.is_empty()
189 }
190
191 pub fn keys(&self) -> impl Iterator<Item = &String> {
193 self.entries.keys()
194 }
195
196 pub fn entries(&self) -> impl Iterator<Item = (&String, &StateDictEntry)> {
198 self.entries.iter()
199 }
200
201 pub fn remove(&mut self, name: &str) -> Option<StateDictEntry> {
203 self.entries.remove(name)
204 }
205
206 pub fn merge(&mut self, other: StateDict) {
208 for (name, entry) in other.entries {
209 self.entries.insert(name, entry);
210 }
211 }
212
213 #[must_use]
215 pub fn filter_prefix(&self, prefix: &str) -> StateDict {
216 let mut filtered = StateDict::new();
217 for (name, entry) in &self.entries {
218 if name.starts_with(prefix) {
219 filtered.entries.insert(name.clone(), entry.clone());
220 }
221 }
222 filtered
223 }
224
225 #[must_use]
227 pub fn strip_prefix(&self, prefix: &str) -> StateDict {
228 let mut stripped = StateDict::new();
229 for (name, entry) in &self.entries {
230 let new_name = name.strip_prefix(prefix).unwrap_or(name).to_string();
231 stripped.entries.insert(new_name, entry.clone());
232 }
233 stripped
234 }
235
236 #[must_use]
238 pub fn add_prefix(&self, prefix: &str) -> StateDict {
239 let mut prefixed = StateDict::new();
240 for (name, entry) in &self.entries {
241 let new_name = format!("{prefix}{name}");
242 prefixed.entries.insert(new_name, entry.clone());
243 }
244 prefixed
245 }
246
247 pub fn set_metadata(&mut self, key: &str, value: &str) {
249 self.metadata.insert(key.to_string(), value.to_string());
250 }
251
252 #[must_use]
254 pub fn get_metadata(&self, key: &str) -> Option<&String> {
255 self.metadata.get(key)
256 }
257
258 #[must_use]
260 pub fn total_params(&self) -> usize {
261 self.entries.values().map(|e| e.data.numel()).sum()
262 }
263
264 #[must_use]
266 pub fn size_bytes(&self) -> usize {
267 self.total_params() * std::mem::size_of::<f32>()
268 }
269
270 #[must_use]
272 pub fn summary(&self) -> String {
273 let mut lines = Vec::new();
274 lines.push(format!("StateDict with {} entries:", self.len()));
275 lines.push(format!(" Total parameters: {}", self.total_params()));
276 lines.push(format!(" Size: {} bytes", self.size_bytes()));
277 lines.push(" Entries:".to_string());
278
279 for (name, entry) in &self.entries {
280 lines.push(format!(
281 " {} - shape: {:?}, numel: {}",
282 name,
283 entry.data.shape,
284 entry.data.numel()
285 ));
286 }
287
288 lines.join("\n")
289 }
290}
291
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `TensorData` with the given shape whose elements all equal `fill`.
    fn filled(shape: &[usize], fill: f32) -> TensorData {
        let numel: usize = shape.iter().product();
        TensorData {
            shape: shape.to_vec(),
            values: vec![fill; numel],
        }
    }

    #[test]
    fn test_tensor_data_roundtrip() {
        let source = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3]).unwrap();
        let restored = TensorData::from_tensor(&source).to_tensor().unwrap();

        assert_eq!(source.shape(), restored.shape());
        assert_eq!(source.to_vec(), restored.to_vec());
    }

    #[test]
    fn test_state_dict_operations() {
        let mut dict = StateDict::new();
        dict.insert("linear.weight".to_string(), filled(&[10, 5], 0.0));
        dict.insert("linear.bias".to_string(), filled(&[5], 0.0));

        assert_eq!(dict.len(), 2);
        assert_eq!(dict.total_params(), 55);
        for key in ["linear.weight", "linear.bias"] {
            assert!(dict.contains(key));
        }
    }

    #[test]
    fn test_state_dict_filter_prefix() {
        let mut dict = StateDict::new();
        for key in [
            "encoder.layer1.weight",
            "encoder.layer1.bias",
            "decoder.layer1.weight",
        ] {
            dict.insert(key.to_string(), filled(&[10], 0.0));
        }

        let encoder_only = dict.filter_prefix("encoder.");
        assert_eq!(encoder_only.len(), 2);
        assert!(encoder_only.contains("encoder.layer1.weight"));
    }

    #[test]
    fn test_state_dict_strip_prefix() {
        let mut dict = StateDict::new();
        dict.insert("model.linear.weight".to_string(), filled(&[10], 0.0));

        assert!(dict.strip_prefix("model.").contains("linear.weight"));
    }

    #[test]
    fn test_state_dict_merge() {
        let mut left = StateDict::new();
        left.insert("a".to_string(), filled(&[1], 1.0));

        let mut right = StateDict::new();
        right.insert("b".to_string(), filled(&[1], 2.0));

        left.merge(right);
        assert_eq!(left.len(), 2);
        assert!(left.contains("a"));
        assert!(left.contains("b"));
    }

    #[test]
    fn test_state_dict_summary() {
        let mut dict = StateDict::new();
        dict.insert("weight".to_string(), filled(&[10, 5], 0.0));

        let report = dict.summary();
        assert!(report.contains("1 entries"));
        assert!(report.contains("50"));
    }
}