// god_graph/transformer/optimization/switch.rs
//! Model Switch: Bidirectional lossless conversion between Safetensors and GodGraph
//!
//! This module implements the Model Switch tool for converting between
//! HuggingFace Safetensors format and GodGraph graph structure.
//!
//! ## Features
//!
//! - Safetensors → GodGraph loading
//! - GodGraph → Safetensors exporting
//! - Topology integrity validation
//! - Weight precision verification (lossless check)
//!
//! ## Example
//!
//! ```no_run
//! # #[cfg(feature = "safetensors")]
//! use god_graph::transformer::optimization::ModelSwitch;
//!
//! # #[cfg(feature = "safetensors")]
//! # fn main() -> Result<(), Box<dyn std::error::Error>> {
//! // Load from Safetensors
//! let graph = ModelSwitch::load_from_safetensors("model.safetensors")?;
//!
//! // Validate topology
//! let report = ModelSwitch::validate_topology(&graph)?;
//! println!("Topology valid: {}", report.is_valid);
//!
//! // Verify weights against original
//! let diff = ModelSwitch::verify_weights(&graph, &graph)?;
//! println!("Max L2 difference: {}", diff.max_l2_diff);
//!
//! // Save to Safetensors
//! ModelSwitch::save_to_safetensors(&graph, "optimized.safetensors")?;
//! # Ok(())
//! # }
//! # #[cfg(not(feature = "safetensors"))]
//! # fn main() {}
//! ```

use crate::errors::{GraphError, GraphResult};
use crate::graph::traits::{GraphBase, GraphOps, GraphQuery};
use crate::graph::Graph;
use smallvec::SmallVec;
use std::collections::HashMap;
use std::path::Path;

/// Operator types for LLM computation graph nodes.
///
/// Each variant carries the configuration parameters needed to describe the
/// corresponding operator; unknown operators fall back to [`OperatorType::Custom`].
#[derive(Debug, Clone, PartialEq)]
pub enum OperatorType {
    /// Multi-head attention operator
    Attention {
        /// Number of attention heads
        num_heads: usize,
        /// Hidden dimension size
        hidden_dim: usize,
    },
    /// Feed-forward network (MLP)
    MLP {
        /// Hidden layer dimension
        hidden_dim: usize,
        /// Activation function name (e.g., "gelu", "silu")
        activation: String,
    },
    /// Layer normalization
    Norm {
        /// Normalization type (e.g., "layer", "rms")
        norm_type: String,
        /// Epsilon value for numerical stability
        eps: f64,
    },
    /// Embedding lookup
    Embedding {
        /// Vocabulary size
        vocab_size: usize,
        /// Embedding dimension
        embed_dim: usize,
    },
    /// Linear projection
    Linear {
        /// Input feature dimension
        in_features: usize,
        /// Output feature dimension
        out_features: usize,
    },
    /// Residual connection (identity)
    Residual,
    /// Custom operator
    Custom {
        /// Custom operator name
        name: String,
    },
}
94/// 64-byte aligned weight tensor with stride support for efficient N-dimensional access
95///
96/// This struct provides:
97/// - Inline storage with Box<[f64]> to avoid Vec reallocation
98/// - SmallVec<[usize; 4]> for shape and strides (avoids heap allocation for ≤4D tensors)
99/// - 64-byte alignment to prevent false sharing in multi-threaded scenarios
100/// - Stride-based indexing for non-contiguous memory access patterns
101/// - In-place reshape without data reallocation
102#[repr(align(64))]
103#[derive(Clone, Debug)]
104pub struct WeightTensor {
105    /// Tensor data stored in Box to avoid Vec reallocation overhead
106    pub data: Box<[f64]>,
107    /// Tensor shape (dimensions) with small array optimization
108    pub shape: SmallVec<[usize; 4]>,
109    /// Strides for each dimension (C-order by default)
110    pub strides: SmallVec<[usize; 4]>,
111    /// Tensor name/identifier for weight mapping
112    pub name: String,
113}
114
115impl WeightTensor {
116    /// Create a new weight tensor with automatic stride computation
117    ///
118    /// # Arguments
119    /// * `data` - Tensor data in row-major (C-order) format
120    /// * `shape` - Tensor dimensions
121    /// * `name` - Tensor identifier
122    ///
123    /// # Panics
124    /// Panics if data length doesn't match the product of shape dimensions
125    pub fn new(name: String, data: Vec<f64>, shape: Vec<usize>) -> Self {
126        let expected_len = shape.iter().product::<usize>();
127        assert_eq!(
128            data.len(),
129            expected_len,
130            "Data length {} mismatch with shape {:?} (expected {})",
131            data.len(),
132            shape,
133            expected_len
134        );
135
136        let strides = compute_strides(&shape);
137        Self {
138            data: data.into_boxed_slice(),
139            shape: shape.into(),
140            strides: strides.into(),
141            name,
142        }
143    }
144
145    /// Create a weight tensor from pre-computed strides
146    ///
147    /// # Arguments
148    /// * `data` - Tensor data
149    /// * `shape` - Tensor dimensions
150    /// * `strides` - Stride for each dimension
151    /// * `name` - Tensor identifier
152    pub fn with_strides(
153        name: String,
154        data: Vec<f64>,
155        shape: Vec<usize>,
156        strides: Vec<usize>,
157    ) -> Self {
158        let expected_len = shape.iter().product::<usize>();
159        assert_eq!(
160            data.len(),
161            expected_len,
162            "Data length {} mismatch with shape {:?}",
163            data.len(),
164            shape
165        );
166
167        Self {
168            data: data.into_boxed_slice(),
169            shape: shape.into(),
170            strides: strides.into(),
171            name,
172        }
173    }
174
175    /// Get the number of dimensions
176    pub fn ndim(&self) -> usize {
177        self.shape.len()
178    }
179
180    /// Get the total number of elements
181    pub fn numel(&self) -> usize {
182        self.data.len()
183    }
184
185    /// Get the shape as a slice
186    pub fn shape(&self) -> &[usize] {
187        &self.shape
188    }
189
190    /// Get the strides as a slice
191    pub fn strides(&self) -> &[usize] {
192        &self.strides
193    }
194
195    /// Get immutable access to the underlying data
196    pub fn data(&self) -> &[f64] {
197        &self.data
198    }
199
200    /// Get mutable access to the underlying data for in-place operations
201    pub fn as_slice_mut(&mut self) -> &mut [f64] {
202        &mut self.data
203    }
204
205    /// Reshape the tensor in-place without reallocating data
206    ///
207    /// # Arguments
208    /// * `new_shape` - New tensor dimensions
209    ///
210    /// # Returns
211    /// Ok if successful, Err if the new shape doesn't match the data size
212    pub fn reshape_mut(&mut self, new_shape: Vec<usize>) -> Result<(), TensorReshapeError> {
213        let new_size = new_shape.iter().product::<usize>();
214        if new_size != self.data.len() {
215            return Err(TensorReshapeError {
216                expected: self.data.len(),
217                got: new_size,
218            });
219        }
220        self.shape = new_shape.into();
221        self.strides = compute_strides(&self.shape).into();
222        Ok(())
223    }
224
225    /// Calculate L2 norm of the tensor
226    pub fn l2_norm(&self) -> f64 {
227        self.data.iter().map(|x| x * x).sum::<f64>().sqrt()
228    }
229
230    /// Calculate L2 difference with another tensor
231    pub fn l2_diff(&self, other: &Self) -> f64 {
232        if self.shape != other.shape {
233            return f64::MAX;
234        }
235        self.data
236            .iter()
237            .zip(other.data.iter())
238            .map(|(a, b)| (a - b).powi(2))
239            .sum::<f64>()
240            .sqrt()
241    }
242
243    /// Get element at multi-dimensional index using stride-based access
244    ///
245    /// # Arguments
246    /// * `indices` - Index for each dimension
247    ///
248    /// # Returns
249    /// Some(value) if indices are valid, None otherwise
250    pub fn get(&self, indices: &[usize]) -> Option<f64> {
251        if indices.len() != self.shape.len() {
252            return None;
253        }
254
255        for (&idx, &dim) in indices.iter().zip(self.shape.iter()) {
256            if idx >= dim {
257                return None;
258            }
259        }
260
261        let offset = indices
262            .iter()
263            .zip(self.strides.iter())
264            .map(|(&idx, &stride)| idx * stride)
265            .sum::<usize>();
266
267        self.data.get(offset).copied()
268    }
269
270    /// Set element at multi-dimensional index using stride-based access
271    ///
272    /// # Arguments
273    /// * `indices` - Index for each dimension
274    /// * `value` - Value to set
275    ///
276    /// # Returns
277    /// true if successful, false if indices are invalid
278    pub fn set(&mut self, indices: &[usize], value: f64) -> bool {
279        if indices.len() != self.shape.len() {
280            return false;
281        }
282
283        for (&idx, &dim) in indices.iter().zip(self.shape.iter()) {
284            if idx >= dim {
285                return false;
286            }
287        }
288
289        let offset = indices
290            .iter()
291            .zip(self.strides.iter())
292            .map(|(&idx, &stride)| idx * stride)
293            .sum::<usize>();
294
295        if let Some(elem) = self.data.get_mut(offset) {
296            *elem = value;
297            true
298        } else {
299            false
300        }
301    }
302
303    /// Convert to DenseTensor for compatibility with existing tensor operations
304    #[cfg(feature = "tensor")]
305    pub fn to_dense_tensor(&self) -> crate::tensor::DenseTensor {
306        crate::tensor::DenseTensor::new(self.data.to_vec(), self.shape.to_vec())
307    }
308}
309
/// Error returned by [`WeightTensor::reshape_mut`] when the requested shape
/// does not describe the same number of elements as the stored data.
#[derive(Debug, Clone)]
pub struct TensorReshapeError {
    /// Expected number of elements (the tensor's current element count)
    pub expected: usize,
    /// Number of elements implied by the requested shape
    pub got: usize,
}

impl std::fmt::Display for TensorReshapeError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(
            f,
            "Reshape error: expected {} elements, got {}",
            self.expected, self.got
        )
    }
}

impl std::error::Error for TensorReshapeError {}
/// Compute strides for C-order (row-major) layout.
///
/// The last dimension has stride 1; each earlier dimension's stride is the
/// product of all later dimension sizes.
///
/// # Arguments
/// * `shape` - Tensor dimensions
///
/// # Returns
/// One stride per dimension (empty for a 0-dimensional shape).
fn compute_strides(shape: &[usize]) -> Vec<usize> {
    let mut strides = vec![1usize; shape.len()];
    let mut running = 1usize;
    // Walk dimensions from innermost to outermost, accumulating the product.
    for (stride, &dim) in strides.iter_mut().zip(shape.iter()).rev() {
        *stride = running;
        running *= dim;
    }
    strides
}
/// Topology validation report produced by [`ModelSwitch::validate_topology`].
#[derive(Debug, Clone)]
pub struct TopologyReport {
    /// Overall verdict: whether the topology is considered valid
    pub is_valid: bool,
    /// Total number of nodes in the graph
    pub node_count: usize,
    /// Total number of edges in the graph
    pub edge_count: usize,
    /// Number of connected components
    pub connected_components: usize,
    /// Whether the graph is a DAG (directed acyclic graph)
    pub is_dag: bool,
    /// Human-readable descriptions of every issue found
    pub issues: Vec<String>,
}
/// Weight difference report produced by [`ModelSwitch::verify_weights`].
#[derive(Debug, Clone)]
pub struct WeightDiff {
    /// Maximum L2 difference across all weights
    pub max_l2_diff: f64,
    /// Average L2 difference
    pub avg_l2_diff: f64,
    /// Number of tensors compared
    pub tensor_count: usize,
    /// Per-tensor L2 differences, keyed by tensor name
    pub per_tensor_diff: HashMap<String, f64>,
}
/// Model Switch: Bidirectional conversion between Safetensors and GodGraph.
///
/// Stateless namespace type — all functionality lives in associated functions.
pub struct ModelSwitch;
impl ModelSwitch {
    /// Load a model from Safetensors format into GodGraph.
    ///
    /// Every tensor in the file is widened to `f64`, wrapped in a
    /// [`WeightTensor`], and attached to a freshly created operator node.
    ///
    /// # Arguments
    /// * `path` - Path to the Safetensors file
    ///
    /// # Returns
    /// A GodGraph representation of the model.
    ///
    /// # Errors
    /// Returns an error if the file cannot be read, deserialized, or contains
    /// a dtype other than F16/F32/F64.
    #[cfg(feature = "safetensors")]
    pub fn load_from_safetensors<P: AsRef<Path>>(
        path: P,
    ) -> GraphResult<Graph<OperatorType, WeightTensor>> {
        use safetensors::SafeTensors;
        use std::fs::File;
        use std::io::Read;

        // Pull the entire file into memory; safetensors parses from a byte slice.
        let mut file = File::open(path.as_ref())
            .map_err(|e| GraphError::IoError(format!("Failed to open file: {}", e)))?;
        let mut bytes = Vec::new();
        file.read_to_end(&mut bytes)
            .map_err(|e| GraphError::IoError(format!("Failed to read file: {}", e)))?;

        let safetensors = SafeTensors::deserialize(&bytes)
            .map_err(|e| GraphError::InvalidFormat(format!("Failed to deserialize safetensors: {}", e)))?;

        let mut graph = Graph::<OperatorType, WeightTensor>::directed();

        // NOTE(review): simplified mapping — each tensor becomes one node with
        // its weight stored on a self-loop edge. A full implementation would
        // derive the real topology from the model config.
        for (name, view) in safetensors.tensors() {
            let shape = view.shape().to_vec();
            let dtype = view.dtype();
            let raw = view.data();

            // Widen every supported dtype to f64.
            let data: Vec<f64> = match dtype {
                safetensors::Dtype::F32 => match bytemuck::try_cast_slice::<u8, f32>(raw) {
                    Ok(typed) => typed.iter().map(|&v| v as f64).collect(),
                    // Unaligned buffer: decode little-endian bytes manually.
                    Err(_) => raw
                        .chunks_exact(4)
                        .map(|c| f32::from_le_bytes([c[0], c[1], c[2], c[3]]) as f64)
                        .collect(),
                },
                safetensors::Dtype::F64 => match bytemuck::try_cast_slice::<u8, f64>(raw) {
                    Ok(typed) => typed.to_vec(),
                    Err(_) => raw
                        .chunks_exact(8)
                        .map(|c| {
                            f64::from_le_bytes([
                                c[0], c[1], c[2], c[3], c[4], c[5], c[6], c[7],
                            ])
                        })
                        .collect(),
                },
                safetensors::Dtype::F16 => raw
                    .chunks_exact(2)
                    .map(|c| {
                        half::f16::from_bits(u16::from_le_bytes([c[0], c[1]])).to_f32() as f64
                    })
                    .collect(),
                _ => {
                    return Err(GraphError::InvalidFormat(format!(
                        "Unsupported dtype: {:?}",
                        dtype
                    )));
                }
            };

            let weight = WeightTensor::new(name.to_string(), data, shape);

            // Heuristically classify the operator from the tensor's name.
            let operator = Self::infer_operator_from_name(&name);
            let node = graph.add_node(operator)?;

            // Self-loop keeps the weight attached to its operator node.
            graph.add_edge(node, node, weight)?;
        }

        Ok(graph)
    }
485    /// Save a GodGraph to Safetensors format
486    ///
487    /// # Arguments
488    ///
489    /// * `graph` - The GodGraph to save
490    /// * `path` - Output path for the Safetensors file
491    ///
492    /// # Errors
493    ///
494    /// Returns an error if the file cannot be written
495    ///
496    /// # Note
497    ///
498    /// This is a simplified implementation that stores all weights as F32.
499    /// A full implementation would preserve the original dtype.
500    #[cfg(feature = "safetensors")]
501    pub fn save_to_safetensors<P: AsRef<Path>>(
502        graph: &Graph<OperatorType, WeightTensor>,
503        path: P,
504    ) -> GraphResult<()> {
505        use std::collections::BTreeMap;
506        use safetensors::tensor::{TensorView, Dtype};
507
508        // Collect all tensor data first (owned data)
509        let mut tensor_data: BTreeMap<String, (Vec<u8>, Vec<usize>)> = BTreeMap::new();
510        
511        for edge_ref in graph.edges() {
512            let weight = edge_ref.data();
513            
514            // Convert f64 data back to F32 for storage (most common dtype)
515            let data_f32: Vec<f32> = weight.data.iter()
516                .map(|&x| x as f32)
517                .collect();
518            
519            let byte_data: Vec<u8> = data_f32.iter()
520                .flat_map(|&x| x.to_le_bytes().to_vec())
521                .collect();
522            
523            tensor_data.insert(
524                weight.name.clone(),
525                (byte_data, weight.shape.to_vec()),
526            );
527        }
528
529        // Create TensorViews - these borrow from tensor_data
530        let mut tensors: BTreeMap<String, TensorView> = BTreeMap::new();
531        for (name, (bytes, shape)) in &tensor_data {
532            let tensor_view = TensorView::new(
533                Dtype::F32,
534                shape.clone(),
535                bytes,
536            ).map_err(|e| GraphError::InvalidFormat(format!("Failed to create tensor view: {}", e)))?;
537            
538            tensors.insert(name.clone(), tensor_view);
539        }
540
541        // Create metadata (empty for now)
542        let metadata: Option<std::collections::HashMap<String, String>> = None;
543
544        // Serialize to file - tensors borrow from tensor_data which lives long enough
545        safetensors::serialize_to_file(&tensors, &metadata, path.as_ref())
546            .map_err(|e| GraphError::IoError(format!("Failed to write safetensors file: {}", e)))?;
547
548        Ok(())
549    }
550
551    /// Validate the topology of a graph
552    ///
553    /// # Arguments
554    ///
555    /// * `graph` - The graph to validate
556    ///
557    /// # Returns
558    ///
559    /// A topology validation report
560    pub fn validate_topology(
561        graph: &Graph<OperatorType, WeightTensor>,
562    ) -> GraphResult<TopologyReport> {
563        use crate::algorithms::community::connected_components;
564        use crate::algorithms::traversal::topological_sort;
565
566        let node_count = graph.node_count();
567        let edge_count = graph.edge_count();
568        let mut issues = Vec::new();
569
570        // Check for empty graph
571        if node_count == 0 {
572            issues.push("Graph is empty".to_string());
573            return Ok(TopologyReport {
574                is_valid: false,
575                node_count,
576                edge_count,
577                connected_components: 0,
578                is_dag: true,
579                issues,
580            });
581        }
582
583        // Check connected components
584        let components = connected_components(graph);
585        if components.len() > 1 {
586            issues.push(format!("Graph has {} disconnected components", components.len()));
587        }
588
589        // Check if DAG (for feedforward models)
590        let is_dag = topological_sort(graph).is_ok();
591        if !is_dag {
592            issues.push("Graph contains cycles (may be valid for recurrent models)".to_string());
593        }
594
595        // Check for isolated nodes
596        let isolated_count = graph
597            .nodes()
598            .filter(|n| graph.neighbors(n.index()).count() == 0)
599            .count();
600        if isolated_count > 0 {
601            issues.push(format!("Graph has {} isolated nodes", isolated_count));
602        }
603
604        let is_valid = issues.is_empty() || (components.len() == 1 && isolated_count == 0);
605
606        Ok(TopologyReport {
607            is_valid,
608            node_count,
609            edge_count,
610            connected_components: components.len(),
611            is_dag,
612            issues,
613        })
614    }
615
616    /// Verify weights between two graphs
617    ///
618    /// # Arguments
619    ///
620    /// * `original` - The original graph
621    /// * `modified` - The modified graph to compare
622    ///
623    /// # Returns
624    ///
625    /// A weight difference report
626    pub fn verify_weights(
627        original: &Graph<OperatorType, WeightTensor>,
628        modified: &Graph<OperatorType, WeightTensor>,
629    ) -> GraphResult<WeightDiff> {
630        let mut per_tensor_diff: HashMap<String, f64> = HashMap::new();
631        let mut max_l2_diff = 0.0f64;
632        let mut total_diff = 0.0f64;
633        let mut tensor_count = 0;
634
635        // Build a map of original weights by name
636        let original_weights: HashMap<String, &WeightTensor> = original.edges()
637            .map(|e| (e.data().name.clone(), e.data()))
638            .collect();
639
640        // Compare weights edge by edge
641        for edge_ref in modified.edges() {
642            let modified_weight = edge_ref.data();
643            
644            if let Some(&original_weight) = original_weights.get(&modified_weight.name) {
645                // Compare shapes first
646                if original_weight.shape != modified_weight.shape {
647                    per_tensor_diff.insert(
648                        modified_weight.name.clone(),
649                        f64::MAX,
650                    );
651                    max_l2_diff = f64::MAX;
652                    tensor_count += 1;
653                    continue;
654                }
655
656                // Calculate L2 difference
657                let l2_diff = original_weight.l2_diff(modified_weight);
658                per_tensor_diff.insert(modified_weight.name.clone(), l2_diff);
659                
660                if l2_diff > max_l2_diff {
661                    max_l2_diff = l2_diff;
662                }
663                total_diff += l2_diff;
664                tensor_count += 1;
665            } else {
666                // Weight not found in original
667                per_tensor_diff.insert(
668                    modified_weight.name.clone(),
669                    f64::MAX,
670                );
671                tensor_count += 1;
672            }
673        }
674
675        // Check for missing weights in modified graph
676        for name in original_weights.keys() {
677            if !per_tensor_diff.contains_key(name) {
678                per_tensor_diff.insert(name.clone(), f64::MAX);
679                tensor_count += 1;
680            }
681        }
682
683        let avg_l2_diff = if tensor_count > 0 {
684            total_diff / tensor_count as f64
685        } else {
686            0.0
687        };
688
689        Ok(WeightDiff {
690            max_l2_diff,
691            avg_l2_diff,
692            tensor_count,
693            per_tensor_diff,
694        })
695    }
696
697    /// Infer operator type from tensor name
698    #[allow(dead_code)]
699    fn infer_operator_from_name(name: &str) -> OperatorType {
700        let name_lower = name.to_lowercase();
701        
702        if name_lower.contains("attention") || name_lower.contains("attn") {
703            OperatorType::Attention {
704                num_heads: 32,
705                hidden_dim: 4096,
706            }
707        } else if name_lower.contains("mlp") || name_lower.contains("ffn") {
708            OperatorType::MLP {
709                hidden_dim: 11008,
710                activation: "silu".to_string(),
711            }
712        } else if name_lower.contains("norm") || name_lower.contains("ln") {
713            OperatorType::Norm {
714                norm_type: "rmsnorm".to_string(),
715                eps: 1e-6,
716            }
717        } else if name_lower.contains("embed") {
718            OperatorType::Embedding {
719                vocab_size: 32000,
720                embed_dim: 4096,
721            }
722        } else if name_lower.contains("linear") || name_lower.contains("proj") {
723            OperatorType::Linear {
724                in_features: 4096,
725                out_features: 4096,
726            }
727        } else {
728            OperatorType::Custom {
729                name: name.to_string(),
730            }
731        }
732    }
733}
734
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_weight_tensor_l2_norm() {
        // sqrt(1 + 4 + 9 + 16) = sqrt(30) ≈ 5.477
        let tensor = WeightTensor::new("test".to_string(), vec![1.0, 2.0, 3.0, 4.0], vec![2, 2]);
        assert!((tensor.l2_norm() - 5.477).abs() < 0.001);
    }

    #[test]
    fn test_weight_tensor_l2_diff() {
        let t1 = WeightTensor::new("test1".to_string(), vec![1.0, 2.0, 3.0, 4.0], vec![2, 2]);
        let t2 = WeightTensor::new("test2".to_string(), vec![1.1, 2.1, 3.1, 4.1], vec![2, 2]);
        // Four elements each 0.1 apart: sqrt(4 * 0.01) = 0.2 < 0.5.
        assert!(t1.l2_diff(&t2) < 0.5);
    }

    #[test]
    fn test_weight_tensor_reshape_mut() {
        let mut tensor = WeightTensor::new(
            "test".to_string(),
            vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
            vec![2, 3],
        );

        // [2, 3] -> [3, 2] keeps the element count; strides are recomputed.
        tensor.reshape_mut(vec![3, 2]).unwrap();
        assert_eq!(tensor.shape(), &[3, 2]);
        assert_eq!(tensor.strides(), &[2, 1]);

        // An element-count mismatch must be rejected.
        assert!(tensor.reshape_mut(vec![2, 2]).is_err());
    }

    #[test]
    fn test_weight_tensor_stride_access() {
        let tensor = WeightTensor::new(
            "test".to_string(),
            vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
            vec![2, 3],
        );

        // Row-major layout: element (r, c) holds r * 3 + c + 1.
        for r in 0..2 {
            for c in 0..3 {
                assert_eq!(tensor.get(&[r, c]), Some((r * 3 + c) as f64 + 1.0));
            }
        }

        // Out-of-range indices yield None.
        assert_eq!(tensor.get(&[2, 0]), None);
        assert_eq!(tensor.get(&[0, 3]), None);
    }

    #[test]
    fn test_weight_tensor_set() {
        let mut tensor =
            WeightTensor::new("test".to_string(), vec![1.0, 2.0, 3.0, 4.0], vec![2, 2]);

        // In-bounds writes succeed...
        assert!(tensor.set(&[0, 1], 10.0));
        assert!(tensor.set(&[1, 0], 20.0));

        // ...and only touch the addressed elements.
        assert_eq!(tensor.get(&[0, 0]), Some(1.0));
        assert_eq!(tensor.get(&[0, 1]), Some(10.0));
        assert_eq!(tensor.get(&[1, 0]), Some(20.0));
        assert_eq!(tensor.get(&[1, 1]), Some(4.0));

        // Out-of-range writes are rejected.
        assert!(!tensor.set(&[2, 0], 100.0));
    }
826    #[test]
827    fn test_weight_tensor_ndim_and_numel() {
828        let tensor = WeightTensor::new(
829            "test".to_string(),
830            vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
831            vec![2, 3],
832        );
833        
834        assert_eq!(tensor.ndim(), 2);
835        assert_eq!(tensor.numel(), 6);
836    }
837
838    #[test]
839    fn test_weight_tensor_struct_size() {
840        // Test that WeightTensor has the expected size
841        use std::mem::size_of;
842        
843        // WeightTensor should be 64-byte aligned due to repr(align(64))
844        assert!(size_of::<WeightTensor>() >= 64);
845        
846        // Create a tensor and verify basic properties
847        let tensor = WeightTensor::new(
848            "test".to_string(),
849            vec![1.0; 100],
850            vec![10, 10],
851        );
852        
853        // Verify data length
854        assert_eq!(tensor.numel(), 100);
855    }
856
857    #[test]
858    fn test_compute_strides() {
859        // 1D tensor
860        assert_eq!(compute_strides(&[5]), vec![1]);
861        
862        // 2D tensor (row-major)
863        assert_eq!(compute_strides(&[3, 4]), vec![4, 1]);
864        
865        // 3D tensor
866        assert_eq!(compute_strides(&[2, 3, 4]), vec![12, 4, 1]);
867        
868        // 4D tensor
869        assert_eq!(compute_strides(&[2, 3, 4, 5]), vec![60, 20, 5, 1]);
870        
871        // Empty tensor
872        let empty: &[usize] = &[];
873        assert_eq!(compute_strides(empty), Vec::<usize>::new());
874    }
875
876    #[test]
877    fn test_infer_operator_from_name() {
878        assert!(matches!(
879            ModelSwitch::infer_operator_from_name("model.layers.0.self_attn.q_proj"),
880            OperatorType::Attention { .. }
881        ));
882        assert!(matches!(
883            ModelSwitch::infer_operator_from_name("model.layers.0.mlp.gate_proj"),
884            OperatorType::MLP { .. }
885        ));
886        assert!(matches!(
887            ModelSwitch::infer_operator_from_name("model.norm.weight"),
888            OperatorType::Norm { .. }
889        ));
890    }
891
892    #[test]
893    #[cfg(feature = "safetensors")]
894    fn test_save_to_safetensors() {
895        use std::fs;
896        use std::path::PathBuf;
897
898        // Create a test graph with multiple nodes and weights
899        let mut graph = Graph::<OperatorType, WeightTensor>::directed();
900
901        // Add nodes with different operator types
902        let embed_node = graph
903            .add_node(OperatorType::Embedding {
904                vocab_size: 1000,
905                embed_dim: 128,
906            })
907            .unwrap();
908
909        let attn_node = graph
910            .add_node(OperatorType::Attention {
911                num_heads: 8,
912                hidden_dim: 256,
913            })
914            .unwrap();
915
916        let mlp_node = graph
917            .add_node(OperatorType::MLP {
918                hidden_dim: 512,
919                activation: "relu".to_string(),
920            })
921            .unwrap();
922
923        let norm_node = graph
924            .add_node(OperatorType::Norm {
925                norm_type: "layernorm".to_string(),
926                eps: 1e-5,
927            })
928            .unwrap();
929
930        // Add edges with weight tensors
931        graph
932            .add_edge(
933                embed_node,
934                embed_node,
935                WeightTensor::new(
936                    "model.embeddings.weight".to_string(),
937                    vec![1.0; 1000 * 128],
938                    vec![1000, 128],
939                ),
940            )
941            .unwrap();
942
943        graph
944            .add_edge(
945                attn_node,
946                attn_node,
947                WeightTensor::new(
948                    "model.layers.0.attention.qkv.weight".to_string(),
949                    vec![0.5; 256 * 3 * 256],
950                    vec![256, 3, 256],
951                ),
952            )
953            .unwrap();
954
955        graph
956            .add_edge(
957                mlp_node,
958                mlp_node,
959                WeightTensor::new(
960                    "model.layers.0.mlp.fc1.weight".to_string(),
961                    vec![0.25; 256 * 512],
962                    vec![256, 512],
963                ),
964            )
965            .unwrap();
966
967        graph
968            .add_edge(
969                norm_node,
970                norm_node,
971                WeightTensor::new(
972                    "model.norm.weight".to_string(),
973                    vec![1.0; 256],
974                    vec![256],
975                ),
976            )
977            .unwrap();
978
979        // Add edges between nodes to create a proper graph structure
980        graph.add_edge(embed_node, attn_node, WeightTensor::new(
981            "model.embed_to_attn.weight".to_string(),
982            vec![0.1; 128 * 256],
983            vec![128, 256],
984        )).unwrap();
985
986        graph.add_edge(attn_node, mlp_node, WeightTensor::new(
987            "model.attn_to_mlp.weight".to_string(),
988            vec![0.2; 256 * 256],
989            vec![256, 256],
990        )).unwrap();
991
992        graph.add_edge(mlp_node, norm_node, WeightTensor::new(
993            "model.mlp_to_norm.weight".to_string(),
994            vec![0.3; 512 * 256],
995            vec![512, 256],
996        )).unwrap();
997
998        // Create a temporary file path
999        let temp_path = PathBuf::from("test_save_to_safetensors_temp.safetensors");
1000
1001        // Save to safetensors
1002        let save_result = ModelSwitch::save_to_safetensors(&graph, &temp_path);
1003        assert!(save_result.is_ok(), "Failed to save to safetensors: {:?}", save_result);
1004
1005        // Verify file was created
1006        assert!(temp_path.exists(), "Safetensors file was not created");
1007
1008        // Load back from safetensors
1009        let loaded_graph = ModelSwitch::load_from_safetensors(&temp_path);
1010        assert!(loaded_graph.is_ok(), "Failed to load from safetensors: {:?}", loaded_graph);
1011        let loaded_graph = loaded_graph.unwrap();
1012
1013        // Note: The current load_from_safetensors implementation creates one node per tensor
1014        // (with self-loop edges), so node/edge count will match the number of tensors
1015        // The important thing is that weight data is preserved
1016
1017        // Verify weight count (7 tensors in total)
1018        assert_eq!(
1019            7,
1020            loaded_graph.edge_count(),
1021            "Edge count should match number of tensors"
1022        );
1023
1024        // Verify weights using verify_weights - compare edge data only
1025        // Since node structure changes, we just verify the weight tensors are preserved
1026        let diff = ModelSwitch::verify_weights(&graph, &loaded_graph).unwrap();
1027        println!("Save/Load round-trip weight diff: max={:.6e}, avg={:.6e}, count={}", 
1028                 diff.max_l2_diff, diff.avg_l2_diff, diff.tensor_count);
1029        
1030        // Allow small floating point errors from F32 conversion
1031        assert!(
1032            diff.max_l2_diff < 1e-5,
1033            "Weight difference too large: max_l2_diff={}",
1034            diff.max_l2_diff
1035        );
1036
1037        // Clean up temporary file
1038        let _ = fs::remove_file(&temp_path);
1039    }
1040
1041    #[test]
1042    #[cfg(feature = "safetensors")]
1043    fn test_save_load_round_trip() {
1044        use std::fs;
1045        use std::path::PathBuf;
1046
1047        // Create a simple test graph
1048        let mut graph = Graph::<OperatorType, WeightTensor>::directed();
1049
1050        let node = graph
1051            .add_node(OperatorType::Linear {
1052                in_features: 64,
1053                out_features: 64,
1054            })
1055            .unwrap();
1056
1057        // Add weight tensor
1058        let original_data: Vec<f64> = (0..64 * 64).map(|i| (i as f64) * 0.01).collect();
1059        graph
1060            .add_edge(
1061                node,
1062                node,
1063                WeightTensor::new(
1064                    "test.linear.weight".to_string(),
1065                    original_data.clone(),
1066                    vec![64, 64],
1067                ),
1068            )
1069            .unwrap();
1070
1071        // Save and load back
1072        let temp_path = PathBuf::from("test_round_trip_temp.safetensors");
1073        
1074        ModelSwitch::save_to_safetensors(&graph, &temp_path).unwrap();
1075        let loaded_graph = ModelSwitch::load_from_safetensors(&temp_path).unwrap();
1076
1077        // Compare original and loaded weights
1078        let diff = ModelSwitch::verify_weights(&graph, &loaded_graph).unwrap();
1079        
1080        // The conversion F64 -> F32 -> F64 introduces small errors
1081        // For values in range [0, 64), F32 precision is ~1e-7 to 1e-6
1082        println!("Round-trip L2 diff: max={:.6e}, avg={:.6e}", diff.max_l2_diff, diff.avg_l2_diff);
1083        
1084        // Clean up
1085        let _ = fs::remove_file(&temp_path);
1086    }
1087}