1use crate::errors::{GraphError, GraphResult};
41use crate::graph::traits::{GraphBase, GraphOps, GraphQuery};
42use crate::graph::Graph;
43use smallvec::SmallVec;
44use std::collections::HashMap;
45use std::path::Path;
46
/// Kind of neural-network operator a graph node represents.
///
/// Variants carry the hyper-parameters relevant to that operator type.
/// `Custom` is the fallback for anything that does not match a known
/// category (see `ModelSwitch::infer_operator_from_name`).
#[derive(Debug, Clone, PartialEq)]
pub enum OperatorType {
    /// Multi-head attention block.
    Attention {
        /// Number of attention heads.
        num_heads: usize,
        /// Hidden/model dimensionality.
        hidden_dim: usize,
    },
    /// Feed-forward / MLP block.
    MLP {
        /// Inner (expanded) hidden dimension.
        hidden_dim: usize,
        /// Activation function name, e.g. "silu" or "relu".
        activation: String,
    },
    /// Normalization layer.
    Norm {
        /// Normalization flavor, e.g. "rmsnorm" or "layernorm".
        norm_type: String,
        /// Numerical-stability epsilon.
        eps: f64,
    },
    /// Token embedding table.
    Embedding {
        /// Vocabulary size (number of rows).
        vocab_size: usize,
        /// Embedding dimensionality (number of columns).
        embed_dim: usize,
    },
    /// Dense / fully-connected projection.
    Linear {
        in_features: usize,
        out_features: usize,
    },
    /// Residual (skip) connection; carries no parameters.
    Residual,
    /// Escape hatch for operators not covered above.
    Custom {
        /// Original (unrecognized) operator/tensor name.
        name: String,
    },
}
93
/// A named, dense weight tensor stored as flat `f64` data with explicit
/// per-dimension strides (row-major by default; see `compute_strides`).
///
/// `#[repr(align(64))]` aligns the struct to 64 bytes — presumably a
/// cache-line-alignment optimization; the unit tests only assert that
/// `size_of::<WeightTensor>() >= 64`. TODO confirm the intent.
#[repr(align(64))]
#[derive(Clone, Debug)]
pub struct WeightTensor {
    /// Flattened element storage; length equals the product of `shape`.
    pub data: Box<[f64]>,
    /// Extent of each dimension (inline up to 4 dims before heap spill).
    pub shape: SmallVec<[usize; 4]>,
    /// Element strides per dimension; `with_strides` allows custom ones.
    pub strides: SmallVec<[usize; 4]>,
    /// Identifier, e.g. the safetensors tensor name.
    pub name: String,
}
114
115impl WeightTensor {
116 pub fn new(name: String, data: Vec<f64>, shape: Vec<usize>) -> Self {
126 let expected_len = shape.iter().product::<usize>();
127 assert_eq!(
128 data.len(),
129 expected_len,
130 "Data length {} mismatch with shape {:?} (expected {})",
131 data.len(),
132 shape,
133 expected_len
134 );
135
136 let strides = compute_strides(&shape);
137 Self {
138 data: data.into_boxed_slice(),
139 shape: shape.into(),
140 strides: strides.into(),
141 name,
142 }
143 }
144
145 pub fn with_strides(
153 name: String,
154 data: Vec<f64>,
155 shape: Vec<usize>,
156 strides: Vec<usize>,
157 ) -> Self {
158 let expected_len = shape.iter().product::<usize>();
159 assert_eq!(
160 data.len(),
161 expected_len,
162 "Data length {} mismatch with shape {:?}",
163 data.len(),
164 shape
165 );
166
167 Self {
168 data: data.into_boxed_slice(),
169 shape: shape.into(),
170 strides: strides.into(),
171 name,
172 }
173 }
174
175 pub fn ndim(&self) -> usize {
177 self.shape.len()
178 }
179
180 pub fn numel(&self) -> usize {
182 self.data.len()
183 }
184
185 pub fn shape(&self) -> &[usize] {
187 &self.shape
188 }
189
190 pub fn strides(&self) -> &[usize] {
192 &self.strides
193 }
194
195 pub fn data(&self) -> &[f64] {
197 &self.data
198 }
199
200 pub fn as_slice_mut(&mut self) -> &mut [f64] {
202 &mut self.data
203 }
204
205 pub fn reshape_mut(&mut self, new_shape: Vec<usize>) -> Result<(), TensorReshapeError> {
213 let new_size = new_shape.iter().product::<usize>();
214 if new_size != self.data.len() {
215 return Err(TensorReshapeError {
216 expected: self.data.len(),
217 got: new_size,
218 });
219 }
220 self.shape = new_shape.into();
221 self.strides = compute_strides(&self.shape).into();
222 Ok(())
223 }
224
225 pub fn l2_norm(&self) -> f64 {
227 self.data.iter().map(|x| x * x).sum::<f64>().sqrt()
228 }
229
230 pub fn l2_diff(&self, other: &Self) -> f64 {
232 if self.shape != other.shape {
233 return f64::MAX;
234 }
235 self.data
236 .iter()
237 .zip(other.data.iter())
238 .map(|(a, b)| (a - b).powi(2))
239 .sum::<f64>()
240 .sqrt()
241 }
242
243 pub fn get(&self, indices: &[usize]) -> Option<f64> {
251 if indices.len() != self.shape.len() {
252 return None;
253 }
254
255 for (&idx, &dim) in indices.iter().zip(self.shape.iter()) {
256 if idx >= dim {
257 return None;
258 }
259 }
260
261 let offset = indices
262 .iter()
263 .zip(self.strides.iter())
264 .map(|(&idx, &stride)| idx * stride)
265 .sum::<usize>();
266
267 self.data.get(offset).copied()
268 }
269
270 pub fn set(&mut self, indices: &[usize], value: f64) -> bool {
279 if indices.len() != self.shape.len() {
280 return false;
281 }
282
283 for (&idx, &dim) in indices.iter().zip(self.shape.iter()) {
284 if idx >= dim {
285 return false;
286 }
287 }
288
289 let offset = indices
290 .iter()
291 .zip(self.strides.iter())
292 .map(|(&idx, &stride)| idx * stride)
293 .sum::<usize>();
294
295 if let Some(elem) = self.data.get_mut(offset) {
296 *elem = value;
297 true
298 } else {
299 false
300 }
301 }
302
303 #[cfg(feature = "tensor")]
305 pub fn to_dense_tensor(&self) -> crate::tensor::DenseTensor {
306 crate::tensor::DenseTensor::new(self.data.to_vec(), self.shape.to_vec())
307 }
308}
309
/// Error returned by `WeightTensor::reshape_mut` when the requested
/// shape's element count does not match the stored data length.
#[derive(Debug, Clone)]
pub struct TensorReshapeError {
    /// Element count the tensor actually holds.
    pub expected: usize,
    /// Element count implied by the rejected shape.
    pub got: usize,
}

impl std::fmt::Display for TensorReshapeError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let TensorReshapeError { expected, got } = self;
        write!(
            f,
            "Reshape error: expected {} elements, got {}",
            expected, got
        )
    }
}

impl std::error::Error for TensorReshapeError {}
330
/// Row-major (C-order) strides for `shape`: the last dimension has
/// stride 1 and each earlier dimension's stride is the product of all
/// later extents. An empty shape yields an empty stride vector.
fn compute_strides(shape: &[usize]) -> Vec<usize> {
    // Walk the dimensions back-to-front, carrying the running product of
    // the extents seen so far, then flip into front-to-back order.
    let mut strides: Vec<usize> = shape
        .iter()
        .rev()
        .scan(1usize, |running, &dim| {
            let stride = *running;
            *running *= dim;
            Some(stride)
        })
        .collect();
    strides.reverse();
    strides
}
350
/// Result of `ModelSwitch::validate_topology`.
#[derive(Debug, Clone)]
pub struct TopologyReport {
    /// Overall verdict; see `validate_topology` for the exact rule
    /// (a cycle warning alone does not invalidate a connected graph).
    pub is_valid: bool,
    /// Number of nodes inspected.
    pub node_count: usize,
    /// Number of edges inspected.
    pub edge_count: usize,
    /// Component count reported by `algorithms::community::connected_components`
    /// (directedness semantics are defined there).
    pub connected_components: usize,
    /// True when a topological sort succeeded, i.e. the graph is acyclic.
    pub is_dag: bool,
    /// Human-readable descriptions of everything found suspicious.
    pub issues: Vec<String>,
}
367
/// Summary produced by `ModelSwitch::verify_weights` when comparing the
/// weight tensors of two graphs by name.
#[derive(Debug, Clone)]
pub struct WeightDiff {
    /// Largest per-tensor L2 difference observed.
    pub max_l2_diff: f64,
    /// Average L2 difference; see `verify_weights` for which tensors
    /// enter the average.
    pub avg_l2_diff: f64,
    /// Number of tensors considered.
    pub tensor_count: usize,
    /// L2 difference keyed by tensor name; `f64::MAX` marks an
    /// incomparable pair (missing counterpart or shape mismatch).
    pub per_tensor_diff: HashMap<String, f64>,
}
380
/// Stateless namespace for model-checkpoint operations: safetensors
/// load/save, topology validation, and weight diffing.
pub struct ModelSwitch;
383
384impl ModelSwitch {
385 #[cfg(feature = "safetensors")]
399 pub fn load_from_safetensors<P: AsRef<Path>>(path: P) -> GraphResult<Graph<OperatorType, WeightTensor>> {
400 use safetensors::SafeTensors;
401 use std::fs::File;
402 use std::io::Read;
403
404 let mut file = File::open(path.as_ref())
405 .map_err(|e| GraphError::IoError(format!("Failed to open file: {}", e)))?;
406 let mut buffer = Vec::new();
407 file.read_to_end(&mut buffer)
408 .map_err(|e| GraphError::IoError(format!("Failed to read file: {}", e)))?;
409
410 let safetensors = SafeTensors::deserialize(&buffer)
411 .map_err(|e| GraphError::InvalidFormat(format!("Failed to deserialize safetensors: {}", e)))?;
412
413 let mut graph = Graph::<OperatorType, WeightTensor>::directed();
414
415 for (name, tensor_view) in safetensors.tensors() {
419 let shape = tensor_view.shape().to_vec();
420 let dtype = tensor_view.dtype();
421
422 let data = match dtype {
424 safetensors::Dtype::F32 => {
425 let slice = tensor_view.data();
426 match bytemuck::try_cast_slice::<u8, f32>(slice) {
428 Ok(f32_data) => f32_data.iter().map(|&x| x as f64).collect(),
429 Err(_) => {
430 slice.chunks_exact(4)
431 .map(|chunk| {
432 let bytes: [u8; 4] = [chunk[0], chunk[1], chunk[2], chunk[3]];
433 f32::from_le_bytes(bytes) as f64
434 })
435 .collect()
436 }
437 }
438 }
439 safetensors::Dtype::F64 => {
440 let slice = tensor_view.data();
441 match bytemuck::try_cast_slice::<u8, f64>(slice) {
442 Ok(f64_data) => f64_data.to_vec(),
443 Err(_) => {
444 slice.chunks_exact(8)
445 .map(|chunk| {
446 let bytes: [u8; 8] = [chunk[0], chunk[1], chunk[2], chunk[3], chunk[4], chunk[5], chunk[6], chunk[7]];
447 f64::from_le_bytes(bytes)
448 })
449 .collect()
450 }
451 }
452 }
453 safetensors::Dtype::F16 => {
454 let slice = tensor_view.data();
455 let f16_data: Vec<half::f16> = slice
457 .chunks_exact(2)
458 .map(|chunk| half::f16::from_bits(u16::from_le_bytes([chunk[0], chunk[1]])))
459 .collect();
460 f16_data.iter().map(|x| x.to_f32() as f64).collect()
461 }
462 _ => {
463 return Err(GraphError::InvalidFormat(
464 format!("Unsupported dtype: {:?}", dtype)
465 ));
466 }
467 };
468
469 let weight_tensor = WeightTensor::new(name.to_string(), data, shape);
471
472 let operator = Self::infer_operator_from_name(&name);
474 let node = graph.add_node(operator)?;
475
476 graph.add_edge(node, node, weight_tensor)?;
480 }
481
482 Ok(graph)
483 }
484
485 #[cfg(feature = "safetensors")]
501 pub fn save_to_safetensors<P: AsRef<Path>>(
502 graph: &Graph<OperatorType, WeightTensor>,
503 path: P,
504 ) -> GraphResult<()> {
505 use std::collections::BTreeMap;
506 use safetensors::tensor::{TensorView, Dtype};
507
508 let mut tensor_data: BTreeMap<String, (Vec<u8>, Vec<usize>)> = BTreeMap::new();
510
511 for edge_ref in graph.edges() {
512 let weight = edge_ref.data();
513
514 let data_f32: Vec<f32> = weight.data.iter()
516 .map(|&x| x as f32)
517 .collect();
518
519 let byte_data: Vec<u8> = data_f32.iter()
520 .flat_map(|&x| x.to_le_bytes().to_vec())
521 .collect();
522
523 tensor_data.insert(
524 weight.name.clone(),
525 (byte_data, weight.shape.to_vec()),
526 );
527 }
528
529 let mut tensors: BTreeMap<String, TensorView> = BTreeMap::new();
531 for (name, (bytes, shape)) in &tensor_data {
532 let tensor_view = TensorView::new(
533 Dtype::F32,
534 shape.clone(),
535 bytes,
536 ).map_err(|e| GraphError::InvalidFormat(format!("Failed to create tensor view: {}", e)))?;
537
538 tensors.insert(name.clone(), tensor_view);
539 }
540
541 let metadata: Option<std::collections::HashMap<String, String>> = None;
543
544 safetensors::serialize_to_file(&tensors, &metadata, path.as_ref())
546 .map_err(|e| GraphError::IoError(format!("Failed to write safetensors file: {}", e)))?;
547
548 Ok(())
549 }
550
551 pub fn validate_topology(
561 graph: &Graph<OperatorType, WeightTensor>,
562 ) -> GraphResult<TopologyReport> {
563 use crate::algorithms::community::connected_components;
564 use crate::algorithms::traversal::topological_sort;
565
566 let node_count = graph.node_count();
567 let edge_count = graph.edge_count();
568 let mut issues = Vec::new();
569
570 if node_count == 0 {
572 issues.push("Graph is empty".to_string());
573 return Ok(TopologyReport {
574 is_valid: false,
575 node_count,
576 edge_count,
577 connected_components: 0,
578 is_dag: true,
579 issues,
580 });
581 }
582
583 let components = connected_components(graph);
585 if components.len() > 1 {
586 issues.push(format!("Graph has {} disconnected components", components.len()));
587 }
588
589 let is_dag = topological_sort(graph).is_ok();
591 if !is_dag {
592 issues.push("Graph contains cycles (may be valid for recurrent models)".to_string());
593 }
594
595 let isolated_count = graph
597 .nodes()
598 .filter(|n| graph.neighbors(n.index()).count() == 0)
599 .count();
600 if isolated_count > 0 {
601 issues.push(format!("Graph has {} isolated nodes", isolated_count));
602 }
603
604 let is_valid = issues.is_empty() || (components.len() == 1 && isolated_count == 0);
605
606 Ok(TopologyReport {
607 is_valid,
608 node_count,
609 edge_count,
610 connected_components: components.len(),
611 is_dag,
612 issues,
613 })
614 }
615
616 pub fn verify_weights(
627 original: &Graph<OperatorType, WeightTensor>,
628 modified: &Graph<OperatorType, WeightTensor>,
629 ) -> GraphResult<WeightDiff> {
630 let mut per_tensor_diff: HashMap<String, f64> = HashMap::new();
631 let mut max_l2_diff = 0.0f64;
632 let mut total_diff = 0.0f64;
633 let mut tensor_count = 0;
634
635 let original_weights: HashMap<String, &WeightTensor> = original.edges()
637 .map(|e| (e.data().name.clone(), e.data()))
638 .collect();
639
640 for edge_ref in modified.edges() {
642 let modified_weight = edge_ref.data();
643
644 if let Some(&original_weight) = original_weights.get(&modified_weight.name) {
645 if original_weight.shape != modified_weight.shape {
647 per_tensor_diff.insert(
648 modified_weight.name.clone(),
649 f64::MAX,
650 );
651 max_l2_diff = f64::MAX;
652 tensor_count += 1;
653 continue;
654 }
655
656 let l2_diff = original_weight.l2_diff(modified_weight);
658 per_tensor_diff.insert(modified_weight.name.clone(), l2_diff);
659
660 if l2_diff > max_l2_diff {
661 max_l2_diff = l2_diff;
662 }
663 total_diff += l2_diff;
664 tensor_count += 1;
665 } else {
666 per_tensor_diff.insert(
668 modified_weight.name.clone(),
669 f64::MAX,
670 );
671 tensor_count += 1;
672 }
673 }
674
675 for name in original_weights.keys() {
677 if !per_tensor_diff.contains_key(name) {
678 per_tensor_diff.insert(name.clone(), f64::MAX);
679 tensor_count += 1;
680 }
681 }
682
683 let avg_l2_diff = if tensor_count > 0 {
684 total_diff / tensor_count as f64
685 } else {
686 0.0
687 };
688
689 Ok(WeightDiff {
690 max_l2_diff,
691 avg_l2_diff,
692 tensor_count,
693 per_tensor_diff,
694 })
695 }
696
697 #[allow(dead_code)]
699 fn infer_operator_from_name(name: &str) -> OperatorType {
700 let name_lower = name.to_lowercase();
701
702 if name_lower.contains("attention") || name_lower.contains("attn") {
703 OperatorType::Attention {
704 num_heads: 32,
705 hidden_dim: 4096,
706 }
707 } else if name_lower.contains("mlp") || name_lower.contains("ffn") {
708 OperatorType::MLP {
709 hidden_dim: 11008,
710 activation: "silu".to_string(),
711 }
712 } else if name_lower.contains("norm") || name_lower.contains("ln") {
713 OperatorType::Norm {
714 norm_type: "rmsnorm".to_string(),
715 eps: 1e-6,
716 }
717 } else if name_lower.contains("embed") {
718 OperatorType::Embedding {
719 vocab_size: 32000,
720 embed_dim: 4096,
721 }
722 } else if name_lower.contains("linear") || name_lower.contains("proj") {
723 OperatorType::Linear {
724 in_features: 4096,
725 out_features: 4096,
726 }
727 } else {
728 OperatorType::Custom {
729 name: name.to_string(),
730 }
731 }
732 }
733}
734
#[cfg(test)]
mod tests {
    //! Unit tests for `WeightTensor` / `compute_strides` plus
    //! feature-gated safetensors save/load round-trip integration tests.
    use super::*;

    #[test]
    fn test_weight_tensor_l2_norm() {
        let tensor = WeightTensor::new(
            "test".to_string(),
            vec![1.0, 2.0, 3.0, 4.0],
            vec![2, 2],
        );
        let norm = tensor.l2_norm();
        // sqrt(1 + 4 + 9 + 16) = sqrt(30) ≈ 5.477
        assert!((norm - 5.477).abs() < 0.001);
    }

    #[test]
    fn test_weight_tensor_l2_diff() {
        let t1 = WeightTensor::new(
            "test1".to_string(),
            vec![1.0, 2.0, 3.0, 4.0],
            vec![2, 2],
        );
        let t2 = WeightTensor::new(
            "test2".to_string(),
            vec![1.1, 2.1, 3.1, 4.1],
            vec![2, 2],
        );
        // Each element differs by 0.1: sqrt(4 * 0.01) = 0.2
        let diff = t1.l2_diff(&t2);
        assert!(diff < 0.5);
    }

    #[test]
    fn test_weight_tensor_reshape_mut() {
        let mut tensor = WeightTensor::new(
            "test".to_string(),
            vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
            vec![2, 3],
        );

        // Same element count: reshape succeeds and strides are recomputed.
        tensor.reshape_mut(vec![3, 2]).unwrap();
        assert_eq!(tensor.shape(), &[3, 2]);
        assert_eq!(tensor.strides(), &[2, 1]);

        // 2*2 = 4 != 6 elements: reshape must fail.
        let result = tensor.reshape_mut(vec![2, 2]);
        assert!(result.is_err());
    }

    #[test]
    fn test_weight_tensor_stride_access() {
        let tensor = WeightTensor::new(
            "test".to_string(),
            vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
            vec![2, 3],
        );

        // Row-major layout: element (i, j) lives at offset i*3 + j.
        assert_eq!(tensor.get(&[0, 0]), Some(1.0));
        assert_eq!(tensor.get(&[0, 1]), Some(2.0));
        assert_eq!(tensor.get(&[0, 2]), Some(3.0));
        assert_eq!(tensor.get(&[1, 0]), Some(4.0));
        assert_eq!(tensor.get(&[1, 1]), Some(5.0));
        assert_eq!(tensor.get(&[1, 2]), Some(6.0));

        // Out-of-range coordinates return None instead of panicking.
        assert_eq!(tensor.get(&[2, 0]), None);
        assert_eq!(tensor.get(&[0, 3]), None);
    }

    #[test]
    fn test_weight_tensor_set() {
        let mut tensor = WeightTensor::new(
            "test".to_string(),
            vec![1.0, 2.0, 3.0, 4.0],
            vec![2, 2],
        );

        assert!(tensor.set(&[0, 1], 10.0));
        assert!(tensor.set(&[1, 0], 20.0));

        // Untouched elements keep their original values.
        assert_eq!(tensor.get(&[0, 0]), Some(1.0));
        assert_eq!(tensor.get(&[0, 1]), Some(10.0));
        assert_eq!(tensor.get(&[1, 0]), Some(20.0));
        assert_eq!(tensor.get(&[1, 1]), Some(4.0));

        // Out-of-range write is rejected.
        assert!(!tensor.set(&[2, 0], 100.0));
    }

    #[test]
    fn test_weight_tensor_ndim_and_numel() {
        let tensor = WeightTensor::new(
            "test".to_string(),
            vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
            vec![2, 3],
        );

        assert_eq!(tensor.ndim(), 2);
        assert_eq!(tensor.numel(), 6);
    }

    #[test]
    fn test_weight_tensor_struct_size() {
        use std::mem::size_of;

        // #[repr(align(64))] forces the struct size up to a multiple of 64.
        assert!(size_of::<WeightTensor>() >= 64);

        let tensor = WeightTensor::new(
            "test".to_string(),
            vec![1.0; 100],
            vec![10, 10],
        );

        assert_eq!(tensor.numel(), 100);
    }

    #[test]
    fn test_compute_strides() {
        assert_eq!(compute_strides(&[5]), vec![1]);

        assert_eq!(compute_strides(&[3, 4]), vec![4, 1]);

        assert_eq!(compute_strides(&[2, 3, 4]), vec![12, 4, 1]);

        assert_eq!(compute_strides(&[2, 3, 4, 5]), vec![60, 20, 5, 1]);

        // Zero-dimensional shape yields no strides at all.
        let empty: &[usize] = &[];
        assert_eq!(compute_strides(empty), Vec::<usize>::new());
    }

    #[test]
    fn test_infer_operator_from_name() {
        // "attn" wins over "proj" because the attention arm is checked first.
        assert!(matches!(
            ModelSwitch::infer_operator_from_name("model.layers.0.self_attn.q_proj"),
            OperatorType::Attention { .. }
        ));
        assert!(matches!(
            ModelSwitch::infer_operator_from_name("model.layers.0.mlp.gate_proj"),
            OperatorType::MLP { .. }
        ));
        assert!(matches!(
            ModelSwitch::infer_operator_from_name("model.norm.weight"),
            OperatorType::Norm { .. }
        ));
    }

    // End-to-end: build a small 4-operator graph with 7 weight tensors,
    // save it, reload it, and check that all weights survive within f32
    // round-trip precision.
    #[test]
    #[cfg(feature = "safetensors")]
    fn test_save_to_safetensors() {
        use std::fs;
        use std::path::PathBuf;

        let mut graph = Graph::<OperatorType, WeightTensor>::directed();

        let embed_node = graph
            .add_node(OperatorType::Embedding {
                vocab_size: 1000,
                embed_dim: 128,
            })
            .unwrap();

        let attn_node = graph
            .add_node(OperatorType::Attention {
                num_heads: 8,
                hidden_dim: 256,
            })
            .unwrap();

        let mlp_node = graph
            .add_node(OperatorType::MLP {
                hidden_dim: 512,
                activation: "relu".to_string(),
            })
            .unwrap();

        let norm_node = graph
            .add_node(OperatorType::Norm {
                norm_type: "layernorm".to_string(),
                eps: 1e-5,
            })
            .unwrap();

        // Self-loop edges carry each operator's own weights.
        graph
            .add_edge(
                embed_node,
                embed_node,
                WeightTensor::new(
                    "model.embeddings.weight".to_string(),
                    vec![1.0; 1000 * 128],
                    vec![1000, 128],
                ),
            )
            .unwrap();

        graph
            .add_edge(
                attn_node,
                attn_node,
                WeightTensor::new(
                    "model.layers.0.attention.qkv.weight".to_string(),
                    vec![0.5; 256 * 3 * 256],
                    vec![256, 3, 256],
                ),
            )
            .unwrap();

        graph
            .add_edge(
                mlp_node,
                mlp_node,
                WeightTensor::new(
                    "model.layers.0.mlp.fc1.weight".to_string(),
                    vec![0.25; 256 * 512],
                    vec![256, 512],
                ),
            )
            .unwrap();

        graph
            .add_edge(
                norm_node,
                norm_node,
                WeightTensor::new(
                    "model.norm.weight".to_string(),
                    vec![1.0; 256],
                    vec![256],
                ),
            )
            .unwrap();

        // Inter-operator edges also carry (distinctly named) tensors.
        graph.add_edge(embed_node, attn_node, WeightTensor::new(
            "model.embed_to_attn.weight".to_string(),
            vec![0.1; 128 * 256],
            vec![128, 256],
        )).unwrap();

        graph.add_edge(attn_node, mlp_node, WeightTensor::new(
            "model.attn_to_mlp.weight".to_string(),
            vec![0.2; 256 * 256],
            vec![256, 256],
        )).unwrap();

        graph.add_edge(mlp_node, norm_node, WeightTensor::new(
            "model.mlp_to_norm.weight".to_string(),
            vec![0.3; 512 * 256],
            vec![512, 256],
        )).unwrap();

        let temp_path = PathBuf::from("test_save_to_safetensors_temp.safetensors");

        let save_result = ModelSwitch::save_to_safetensors(&graph, &temp_path);
        assert!(save_result.is_ok(), "Failed to save to safetensors: {:?}", save_result);

        assert!(temp_path.exists(), "Safetensors file was not created");

        let loaded_graph = ModelSwitch::load_from_safetensors(&temp_path);
        assert!(loaded_graph.is_ok(), "Failed to load from safetensors: {:?}", loaded_graph);
        let loaded_graph = loaded_graph.unwrap();

        // Loading creates one self-loop edge per stored tensor.
        assert_eq!(
            7,
            loaded_graph.edge_count(),
            "Edge count should match number of tensors"
        );

        let diff = ModelSwitch::verify_weights(&graph, &loaded_graph).unwrap();
        println!("Save/Load round-trip weight diff: max={:.6e}, avg={:.6e}, count={}",
            diff.max_l2_diff, diff.avg_l2_diff, diff.tensor_count);

        // Values only pass through an f64 -> f32 -> f64 narrowing.
        assert!(
            diff.max_l2_diff < 1e-5,
            "Weight difference too large: max_l2_diff={}",
            diff.max_l2_diff
        );

        let _ = fs::remove_file(&temp_path);
    }

    // Minimal single-tensor round trip with non-uniform data.
    #[test]
    #[cfg(feature = "safetensors")]
    fn test_save_load_round_trip() {
        use std::fs;
        use std::path::PathBuf;

        let mut graph = Graph::<OperatorType, WeightTensor>::directed();

        let node = graph
            .add_node(OperatorType::Linear {
                in_features: 64,
                out_features: 64,
            })
            .unwrap();

        let original_data: Vec<f64> = (0..64 * 64).map(|i| (i as f64) * 0.01).collect();
        graph
            .add_edge(
                node,
                node,
                WeightTensor::new(
                    "test.linear.weight".to_string(),
                    original_data.clone(),
                    vec![64, 64],
                ),
            )
            .unwrap();

        let temp_path = PathBuf::from("test_round_trip_temp.safetensors");

        ModelSwitch::save_to_safetensors(&graph, &temp_path).unwrap();
        let loaded_graph = ModelSwitch::load_from_safetensors(&temp_path).unwrap();

        let diff = ModelSwitch::verify_weights(&graph, &loaded_graph).unwrap();

        println!("Round-trip L2 diff: max={:.6e}, avg={:.6e}", diff.max_l2_diff, diff.avg_l2_diff);

        let _ = fs::remove_file(&temp_path);
    }
}