1use axonml_core::Result;
6use axonml_nn::Module;
7use axonml_tensor::Tensor;
8use serde::{Deserialize, Serialize};
9use std::collections::HashMap;
10
/// Owned, serializable snapshot of a tensor's contents.
///
/// Stores the shape alongside a flat value buffer (as produced by
/// `Tensor::to_vec`) so tensors can round-trip through serde formats.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TensorData {
    /// Dimensions of the tensor, e.g. `[rows, cols]`.
    pub shape: Vec<usize>,
    /// Flattened element values; length should equal the product of `shape`.
    pub values: Vec<f32>,
}
23
24impl TensorData {
25 #[must_use] pub fn from_tensor(tensor: &Tensor<f32>) -> Self {
27 Self {
28 shape: tensor.shape().to_vec(),
29 values: tensor.to_vec(),
30 }
31 }
32
33 pub fn to_tensor(&self) -> Result<Tensor<f32>> {
35 Tensor::from_vec(self.values.clone(), &self.shape)
36 }
37
38 #[must_use] pub fn numel(&self) -> usize {
40 self.values.len()
41 }
42
43 #[must_use] pub fn shape(&self) -> &[usize] {
45 &self.shape
46 }
47}
48
/// A single named parameter stored in a `StateDict`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StateDictEntry {
    /// The parameter's values and shape.
    pub data: TensorData,
    /// Whether gradients were tracked for this parameter when it was captured.
    pub requires_grad: bool,
    /// Free-form key/value annotations; missing on the wire deserializes to empty.
    #[serde(default)]
    pub metadata: HashMap<String, String>,
}
64
65impl StateDictEntry {
66 #[must_use] pub fn new(data: TensorData, requires_grad: bool) -> Self {
68 Self {
69 data,
70 requires_grad,
71 metadata: HashMap::new(),
72 }
73 }
74
75 #[must_use] pub fn with_metadata(mut self, key: &str, value: &str) -> Self {
77 self.metadata.insert(key.to_string(), value.to_string());
78 self
79 }
80}
81
/// Named collection of tensor parameters, analogous to a PyTorch-style
/// state dict: maps parameter names to their serialized data.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct StateDict {
    /// Parameter name -> entry.
    entries: HashMap<String, StateDictEntry>,
    /// Dict-level annotations; missing on the wire deserializes to empty.
    #[serde(default)]
    metadata: HashMap<String, String>,
}
95
96impl StateDict {
97 #[must_use] pub fn new() -> Self {
99 Self::default()
100 }
101
102 pub fn from_module<M: Module>(module: &M) -> Self {
104 let mut state_dict = Self::new();
105
106 for param in module.parameters() {
107 let name = param.name().to_string();
108 let tensor_data = TensorData::from_tensor(¶m.data());
109 let entry = StateDictEntry::new(tensor_data, param.requires_grad());
110 state_dict.entries.insert(name, entry);
111 }
112
113 state_dict
114 }
115
116 pub fn insert(&mut self, name: String, data: TensorData) {
118 let entry = StateDictEntry::new(data, true);
119 self.entries.insert(name, entry);
120 }
121
122 pub fn insert_entry(&mut self, name: String, entry: StateDictEntry) {
124 self.entries.insert(name, entry);
125 }
126
127 #[must_use] pub fn get(&self, name: &str) -> Option<&StateDictEntry> {
129 self.entries.get(name)
130 }
131
132 pub fn get_mut(&mut self, name: &str) -> Option<&mut StateDictEntry> {
134 self.entries.get_mut(name)
135 }
136
137 #[must_use] pub fn contains(&self, name: &str) -> bool {
139 self.entries.contains_key(name)
140 }
141
142 #[must_use] pub fn len(&self) -> usize {
144 self.entries.len()
145 }
146
147 #[must_use] pub fn is_empty(&self) -> bool {
149 self.entries.is_empty()
150 }
151
152 pub fn keys(&self) -> impl Iterator<Item = &String> {
154 self.entries.keys()
155 }
156
157 pub fn entries(&self) -> impl Iterator<Item = (&String, &StateDictEntry)> {
159 self.entries.iter()
160 }
161
162 pub fn remove(&mut self, name: &str) -> Option<StateDictEntry> {
164 self.entries.remove(name)
165 }
166
167 pub fn merge(&mut self, other: StateDict) {
169 for (name, entry) in other.entries {
170 self.entries.insert(name, entry);
171 }
172 }
173
174 #[must_use] pub fn filter_prefix(&self, prefix: &str) -> StateDict {
176 let mut filtered = StateDict::new();
177 for (name, entry) in &self.entries {
178 if name.starts_with(prefix) {
179 filtered.entries.insert(name.clone(), entry.clone());
180 }
181 }
182 filtered
183 }
184
185 #[must_use] pub fn strip_prefix(&self, prefix: &str) -> StateDict {
187 let mut stripped = StateDict::new();
188 for (name, entry) in &self.entries {
189 let new_name = name.strip_prefix(prefix).unwrap_or(name).to_string();
190 stripped.entries.insert(new_name, entry.clone());
191 }
192 stripped
193 }
194
195 #[must_use] pub fn add_prefix(&self, prefix: &str) -> StateDict {
197 let mut prefixed = StateDict::new();
198 for (name, entry) in &self.entries {
199 let new_name = format!("{prefix}{name}");
200 prefixed.entries.insert(new_name, entry.clone());
201 }
202 prefixed
203 }
204
205 pub fn set_metadata(&mut self, key: &str, value: &str) {
207 self.metadata.insert(key.to_string(), value.to_string());
208 }
209
210 #[must_use] pub fn get_metadata(&self, key: &str) -> Option<&String> {
212 self.metadata.get(key)
213 }
214
215 #[must_use] pub fn total_params(&self) -> usize {
217 self.entries.values().map(|e| e.data.numel()).sum()
218 }
219
220 #[must_use] pub fn size_bytes(&self) -> usize {
222 self.total_params() * std::mem::size_of::<f32>()
223 }
224
225 #[must_use] pub fn summary(&self) -> String {
227 let mut lines = Vec::new();
228 lines.push(format!("StateDict with {} entries:", self.len()));
229 lines.push(format!(" Total parameters: {}", self.total_params()));
230 lines.push(format!(" Size: {} bytes", self.size_bytes()));
231 lines.push(" Entries:".to_string());
232
233 for (name, entry) in &self.entries {
234 lines.push(format!(
235 " {} - shape: {:?}, numel: {}",
236 name,
237 entry.data.shape,
238 entry.data.numel()
239 ));
240 }
241
242 lines.join("\n")
243 }
244}
245
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds zero-filled tensor data whose element count matches `shape`.
    fn zeros(shape: &[usize]) -> TensorData {
        let numel: usize = shape.iter().product();
        TensorData {
            shape: shape.to_vec(),
            values: vec![0.0; numel],
        }
    }

    #[test]
    fn test_tensor_data_roundtrip() {
        let tensor = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3]).unwrap();
        let snapshot = TensorData::from_tensor(&tensor);
        let rebuilt = snapshot.to_tensor().unwrap();

        assert_eq!(tensor.shape(), rebuilt.shape());
        assert_eq!(tensor.to_vec(), rebuilt.to_vec());
    }

    #[test]
    fn test_state_dict_operations() {
        let mut sd = StateDict::new();
        sd.insert("linear.weight".to_string(), zeros(&[10, 5]));
        sd.insert("linear.bias".to_string(), zeros(&[5]));

        assert_eq!(sd.len(), 2);
        assert_eq!(sd.total_params(), 55);
        assert!(sd.contains("linear.weight"));
        assert!(sd.contains("linear.bias"));
    }

    #[test]
    fn test_state_dict_filter_prefix() {
        let mut sd = StateDict::new();
        for name in [
            "encoder.layer1.weight",
            "encoder.layer1.bias",
            "decoder.layer1.weight",
        ] {
            sd.insert(name.to_string(), zeros(&[10]));
        }

        let encoder_only = sd.filter_prefix("encoder.");
        assert_eq!(encoder_only.len(), 2);
        assert!(encoder_only.contains("encoder.layer1.weight"));
    }

    #[test]
    fn test_state_dict_strip_prefix() {
        let mut sd = StateDict::new();
        sd.insert("model.linear.weight".to_string(), zeros(&[10]));

        let stripped = sd.strip_prefix("model.");
        assert!(stripped.contains("linear.weight"));
    }

    #[test]
    fn test_state_dict_merge() {
        let mut left = StateDict::new();
        left.insert(
            "a".to_string(),
            TensorData {
                shape: vec![1],
                values: vec![1.0],
            },
        );

        let mut right = StateDict::new();
        right.insert(
            "b".to_string(),
            TensorData {
                shape: vec![1],
                values: vec![2.0],
            },
        );

        left.merge(right);
        assert_eq!(left.len(), 2);
        assert!(left.contains("a"));
        assert!(left.contains("b"));
    }

    #[test]
    fn test_state_dict_summary() {
        let mut sd = StateDict::new();
        sd.insert("weight".to_string(), zeros(&[10, 5]));

        let report = sd.summary();
        assert!(report.contains("1 entries"));
        assert!(report.contains("50"));
    }
}