use crate::errors::{GraphError, GraphResult};
use crate::graph::traits::{GraphBase, GraphOps, GraphQuery};
use crate::graph::Graph;
use smallvec::SmallVec;
use std::collections::HashMap;
use std::path::Path;
/// Kind of operator stored as the node payload of a `Graph<OperatorType, WeightTensor>`.
///
/// The fields are plain descriptive metadata; nothing in this file validates them
/// against the weights attached to the node.
#[derive(Debug, Clone, PartialEq)]
pub enum OperatorType {
/// Attention operator.
Attention {
/// Number of attention heads.
num_heads: usize,
/// Hidden dimension of the operator.
hidden_dim: usize,
},
/// Feed-forward (MLP) operator.
MLP {
/// Hidden dimension of the operator.
hidden_dim: usize,
/// Activation name as free-form text (e.g. `"silu"`, `"relu"`).
activation: String,
},
/// Normalization operator.
Norm {
/// Normalization flavour as free-form text (e.g. `"rmsnorm"`, `"layernorm"`).
norm_type: String,
/// Epsilon value associated with the normalization.
eps: f64,
},
/// Embedding table.
Embedding {
/// Vocabulary size.
vocab_size: usize,
/// Embedding dimension.
embed_dim: usize,
},
/// Linear (projection) operator.
Linear {
/// Input feature count.
in_features: usize,
/// Output feature count.
out_features: usize,
},
/// Residual connection; carries no parameters.
Residual,
/// Any operator that does not fit the variants above, identified by name.
Custom {
/// Free-form operator name.
name: String,
},
}
/// Named, dense `f64` tensor used as the edge payload of the model graph.
///
/// `repr(align(64))` aligns (and pads) the struct itself to 64 bytes. The element
/// buffer lives behind `data` on the heap and is not affected by this attribute.
#[repr(align(64))]
#[derive(Clone, Debug)]
pub struct WeightTensor {
/// Flat element storage; the constructors require its length to equal the product of `shape`.
pub data: Box<[f64]>,
/// Extent of each dimension.
pub shape: SmallVec<[usize; 4]>,
/// Per-dimension step, in elements, used by `get`/`set` to turn indices into an offset into `data`.
pub strides: SmallVec<[usize; 4]>,
/// Tensor name; used as the lookup key when saving and when comparing weights.
pub name: String,
}
impl WeightTensor {
    /// Creates a tensor with row-major strides derived from `shape`.
    ///
    /// # Panics
    /// Panics if `data.len()` differs from the product of `shape`.
    pub fn new(name: String, data: Vec<f64>, shape: Vec<usize>) -> Self {
        let expected_len = shape.iter().product::<usize>();
        assert_eq!(
            data.len(),
            expected_len,
            "Data length {} mismatch with shape {:?} (expected {})",
            data.len(),
            shape,
            expected_len
        );
        let strides = compute_strides(&shape);
        Self {
            data: data.into_boxed_slice(),
            shape: shape.into(),
            strides: strides.into(),
            name,
        }
    }

    /// Creates a tensor with caller-supplied strides (in elements).
    ///
    /// # Panics
    /// Panics if `data.len()` differs from the product of `shape`, or if `strides`
    /// does not provide exactly one entry per dimension of `shape`.
    pub fn with_strides(
        name: String,
        data: Vec<f64>,
        shape: Vec<usize>,
        strides: Vec<usize>,
    ) -> Self {
        let expected_len = shape.iter().product::<usize>();
        assert_eq!(
            data.len(),
            expected_len,
            "Data length {} mismatch with shape {:?} (expected {})",
            data.len(),
            shape,
            expected_len
        );
        // Without this check `get`/`set` would zip indices against a shorter stride
        // list and silently ignore the trailing indices.
        assert_eq!(
            strides.len(),
            shape.len(),
            "Stride count {} mismatch with shape {:?} (expected {} strides)",
            strides.len(),
            shape,
            shape.len()
        );
        Self {
            data: data.into_boxed_slice(),
            shape: shape.into(),
            strides: strides.into(),
            name,
        }
    }

    /// Number of dimensions.
    pub fn ndim(&self) -> usize {
        self.shape.len()
    }

    /// Total number of stored elements.
    pub fn numel(&self) -> usize {
        self.data.len()
    }

    /// Extent of each dimension.
    pub fn shape(&self) -> &[usize] {
        &self.shape
    }

    /// Per-dimension strides, in elements.
    pub fn strides(&self) -> &[usize] {
        &self.strides
    }

    /// Read-only view of the flat element buffer.
    pub fn data(&self) -> &[f64] {
        &self.data
    }

    /// Mutable view of the flat element buffer.
    pub fn as_slice_mut(&mut self) -> &mut [f64] {
        &mut self.data
    }

    /// Changes the shape in place and recomputes row-major strides.
    ///
    /// # Errors
    /// Returns [`TensorReshapeError`] (leaving the tensor untouched) when the product
    /// of `new_shape` differs from the number of stored elements.
    pub fn reshape_mut(&mut self, new_shape: Vec<usize>) -> Result<(), TensorReshapeError> {
        let new_size = new_shape.iter().product::<usize>();
        if new_size != self.data.len() {
            return Err(TensorReshapeError {
                expected: self.data.len(),
                got: new_size,
            });
        }
        self.shape = new_shape.into();
        self.strides = compute_strides(&self.shape).into();
        Ok(())
    }

    /// Euclidean (L2) norm over all elements.
    pub fn l2_norm(&self) -> f64 {
        self.data.iter().map(|x| x * x).sum::<f64>().sqrt()
    }

    /// Euclidean distance between the element buffers of `self` and `other`.
    ///
    /// Returns `f64::MAX` as a sentinel when the shapes differ.
    pub fn l2_diff(&self, other: &Self) -> f64 {
        if self.shape != other.shape {
            return f64::MAX;
        }
        self.data
            .iter()
            .zip(other.data.iter())
            .map(|(a, b)| (a - b).powi(2))
            .sum::<f64>()
            .sqrt()
    }

    /// Maps a multi-dimensional index to a flat offset into `data`.
    ///
    /// Returns `None` when the index has the wrong rank, any coordinate is out of
    /// range for its dimension, or the offset computation would overflow `usize`.
    fn offset_of(&self, indices: &[usize]) -> Option<usize> {
        if indices.len() != self.shape.len() {
            return None;
        }
        let out_of_range = indices
            .iter()
            .zip(self.shape.iter())
            .any(|(&idx, &dim)| idx >= dim);
        if out_of_range {
            return None;
        }
        indices
            .iter()
            .zip(self.strides.iter())
            .try_fold(0usize, |acc, (&idx, &stride)| {
                idx.checked_mul(stride)
                    .and_then(|term| acc.checked_add(term))
            })
    }

    /// Reads the element at `indices`, or `None` if the index is invalid or the
    /// resulting offset falls outside the buffer.
    pub fn get(&self, indices: &[usize]) -> Option<f64> {
        self.offset_of(indices)
            .and_then(|offset| self.data.get(offset).copied())
    }

    /// Writes `value` at `indices`.
    ///
    /// Returns `true` on success and `false` (without modifying anything) if the
    /// index is invalid or the resulting offset falls outside the buffer.
    pub fn set(&mut self, indices: &[usize], value: f64) -> bool {
        let offset = match self.offset_of(indices) {
            Some(offset) => offset,
            None => return false,
        };
        match self.data.get_mut(offset) {
            Some(slot) => {
                *slot = value;
                true
            }
            None => false,
        }
    }

    /// Copies the data and shape into a `crate::tensor::DenseTensor`.
    #[cfg(feature = "tensor")]
    pub fn to_dense_tensor(&self) -> crate::tensor::DenseTensor {
        crate::tensor::DenseTensor::new(self.data.to_vec(), self.shape.to_vec())
    }
}
/// Error produced when a reshape would change the number of elements.
#[derive(Debug, Clone)]
pub struct TensorReshapeError {
    /// Number of elements currently stored in the tensor.
    pub expected: usize,
    /// Number of elements implied by the requested shape.
    pub got: usize,
}

impl std::fmt::Display for TensorReshapeError {
    /// Renders the mismatch as a single human-readable line.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self { expected, got } = self;
        f.write_fmt(format_args!(
            "Reshape error: expected {} elements, got {}",
            expected, got
        ))
    }
}

impl std::error::Error for TensorReshapeError {}
/// Returns row-major strides (in elements) for `shape`.
///
/// The last dimension always gets stride 1 and every earlier dimension gets the
/// product of all extents to its right. An empty shape yields an empty vector.
fn compute_strides(shape: &[usize]) -> Vec<usize> {
    let mut strides = vec![1usize; shape.len()];
    let mut running = 1usize;
    // Walk right-to-left: the slot for dimension `i` is paired with the extent of
    // dimension `i + 1`, so the leading extent is never multiplied in.
    for (slot, &extent) in strides.iter_mut().rev().skip(1).zip(shape.iter().rev()) {
        running *= extent;
        *slot = running;
    }
    strides
}
/// Result of `ModelSwitch::validate_topology`.
#[derive(Debug, Clone)]
pub struct TopologyReport {
/// Overall verdict; always `false` for an empty graph.
pub is_valid: bool,
/// Number of nodes in the inspected graph.
pub node_count: usize,
/// Number of edges in the inspected graph.
pub edge_count: usize,
/// Number of connected components found (0 for an empty graph).
pub connected_components: usize,
/// Whether a topological sort of the graph succeeded.
pub is_dag: bool,
/// Human-readable descriptions of every problem detected.
pub issues: Vec<String>,
}
/// Result of `ModelSwitch::verify_weights`, comparing tensors by name.
#[derive(Debug, Clone)]
pub struct WeightDiff {
/// Largest L2 difference found; `f64::MAX` when a shape mismatch was seen.
pub max_l2_diff: f64,
/// Mean L2 difference as computed by `verify_weights`.
pub avg_l2_diff: f64,
/// Number of tensors accounted for in the comparison.
pub tensor_count: usize,
/// L2 difference per tensor name; `f64::MAX` marks a tensor that could not be
/// compared (shape mismatch, or present in only one of the two graphs).
pub per_tensor_diff: HashMap<String, f64>,
}
/// Stateless namespace for the model graph helpers defined in its `impl` block
/// (safetensors load/save, topology validation, weight comparison).
pub struct ModelSwitch;
impl ModelSwitch {
    /// Loads every tensor of a safetensors file into a directed graph.
    ///
    /// Each tensor becomes one node, whose operator kind is guessed from the tensor
    /// name, plus a self-loop edge on that node holding the weights converted to
    /// `f64`. Supported element types are `F32`, `F64` and `F16`.
    ///
    /// Payload bytes are always decoded as little-endian, so the result does not
    /// depend on host endianness or on the alignment of the file buffer.
    ///
    /// # Errors
    /// * `GraphError::IoError` if the file cannot be opened or read.
    /// * `GraphError::InvalidFormat` if deserialization fails, a tensor uses an
    ///   unsupported dtype, or the decoded element count does not match its shape.
    /// * Any error returned by the graph when adding nodes or edges.
    #[cfg(feature = "safetensors")]
    pub fn load_from_safetensors<P: AsRef<Path>>(
        path: P,
    ) -> GraphResult<Graph<OperatorType, WeightTensor>> {
        use safetensors::SafeTensors;
        use std::fs::File;
        use std::io::Read;

        let mut file = File::open(path.as_ref())
            .map_err(|e| GraphError::IoError(format!("Failed to open file: {}", e)))?;
        let mut buffer = Vec::new();
        file.read_to_end(&mut buffer)
            .map_err(|e| GraphError::IoError(format!("Failed to read file: {}", e)))?;
        let safetensors = SafeTensors::deserialize(&buffer).map_err(|e| {
            GraphError::InvalidFormat(format!("Failed to deserialize safetensors: {}", e))
        })?;

        let mut graph = Graph::<OperatorType, WeightTensor>::directed();
        for (name, tensor_view) in safetensors.tensors() {
            let shape = tensor_view.shape().to_vec();
            let dtype = tensor_view.dtype();
            let bytes = tensor_view.data();
            let data: Vec<f64> = match dtype {
                safetensors::Dtype::F32 => bytes
                    .chunks_exact(4)
                    .map(|chunk| {
                        f32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]) as f64
                    })
                    .collect(),
                safetensors::Dtype::F64 => bytes
                    .chunks_exact(8)
                    .map(|chunk| {
                        let mut raw = [0u8; 8];
                        raw.copy_from_slice(chunk);
                        f64::from_le_bytes(raw)
                    })
                    .collect(),
                safetensors::Dtype::F16 => bytes
                    .chunks_exact(2)
                    .map(|chunk| {
                        let bits = u16::from_le_bytes([chunk[0], chunk[1]]);
                        half::f16::from_bits(bits).to_f32() as f64
                    })
                    .collect(),
                _ => {
                    return Err(GraphError::InvalidFormat(format!(
                        "Unsupported dtype: {:?}",
                        dtype
                    )));
                }
            };

            // `WeightTensor::new` panics on a length mismatch; report it as a format
            // error instead so malformed input never aborts the caller.
            let expected_len = shape.iter().product::<usize>();
            if data.len() != expected_len {
                return Err(GraphError::InvalidFormat(format!(
                    "Tensor '{}' has {} elements but shape {:?} requires {}",
                    name,
                    data.len(),
                    shape,
                    expected_len
                )));
            }

            let weight_tensor = WeightTensor::new(name.to_string(), data, shape);
            let operator = Self::infer_operator_from_name(&name);
            let node = graph.add_node(operator)?;
            graph.add_edge(node, node, weight_tensor)?;
        }
        Ok(graph)
    }

    /// Writes the weight tensor of every edge in `graph` to a safetensors file.
    ///
    /// Values are narrowed from `f64` to `f32` and stored little-endian with dtype
    /// `F32`, so the conversion is lossy. Tensors are keyed by `WeightTensor::name`;
    /// when several edges share a name only the last one visited is written.
    ///
    /// # Errors
    /// * `GraphError::InvalidFormat` if a tensor view cannot be built.
    /// * `GraphError::IoError` if the file cannot be written.
    #[cfg(feature = "safetensors")]
    pub fn save_to_safetensors<P: AsRef<Path>>(
        graph: &Graph<OperatorType, WeightTensor>,
        path: P,
    ) -> GraphResult<()> {
        use safetensors::tensor::{Dtype, TensorView};
        use std::collections::BTreeMap;

        // The byte buffers must outlive the `TensorView`s that borrow them, so they
        // are collected first and the views are built in a second pass.
        let mut tensor_data: BTreeMap<String, (Vec<u8>, Vec<usize>)> = BTreeMap::new();
        for edge_ref in graph.edges() {
            let weight = edge_ref.data();
            let mut byte_data: Vec<u8> =
                Vec::with_capacity(weight.data.len() * std::mem::size_of::<f32>());
            for &value in weight.data.iter() {
                byte_data.extend_from_slice(&(value as f32).to_le_bytes());
            }
            tensor_data.insert(weight.name.clone(), (byte_data, weight.shape.to_vec()));
        }

        let mut tensors: BTreeMap<String, TensorView> = BTreeMap::new();
        for (name, (bytes, shape)) in &tensor_data {
            let tensor_view = TensorView::new(Dtype::F32, shape.clone(), bytes).map_err(|e| {
                GraphError::InvalidFormat(format!("Failed to create tensor view: {}", e))
            })?;
            tensors.insert(name.clone(), tensor_view);
        }

        let metadata: Option<std::collections::HashMap<String, String>> = None;
        safetensors::serialize_to_file(&tensors, &metadata, path.as_ref()).map_err(|e| {
            GraphError::IoError(format!("Failed to write safetensors file: {}", e))
        })?;
        Ok(())
    }

    /// Inspects the structure of `graph` and summarises it in a [`TopologyReport`].
    ///
    /// Issues are recorded for an empty graph, more than one connected component,
    /// a failed topological sort (cycles) and nodes without neighbors. The graph is
    /// considered valid when there are no issues at all, or when it is a single
    /// component without isolated nodes (cycles alone do not invalidate it). An
    /// empty graph is always reported as invalid.
    pub fn validate_topology(
        graph: &Graph<OperatorType, WeightTensor>,
    ) -> GraphResult<TopologyReport> {
        use crate::algorithms::community::connected_components;
        use crate::algorithms::traversal::topological_sort;

        let node_count = graph.node_count();
        let edge_count = graph.edge_count();
        let mut issues = Vec::new();

        if node_count == 0 {
            issues.push("Graph is empty".to_string());
            return Ok(TopologyReport {
                is_valid: false,
                node_count,
                edge_count,
                connected_components: 0,
                is_dag: true,
                issues,
            });
        }

        let components = connected_components(graph);
        if components.len() > 1 {
            issues.push(format!(
                "Graph has {} disconnected components",
                components.len()
            ));
        }

        let is_dag = topological_sort(graph).is_ok();
        if !is_dag {
            issues.push("Graph contains cycles (may be valid for recurrent models)".to_string());
        }

        let isolated_count = graph
            .nodes()
            .filter(|n| graph.neighbors(n.index()).count() == 0)
            .count();
        if isolated_count > 0 {
            issues.push(format!("Graph has {} isolated nodes", isolated_count));
        }

        let is_valid = issues.is_empty() || (components.len() == 1 && isolated_count == 0);
        Ok(TopologyReport {
            is_valid,
            node_count,
            edge_count,
            connected_components: components.len(),
            is_dag,
            issues,
        })
    }

    /// Compares the edge weights of two graphs, matching tensors by name.
    ///
    /// For every tensor present in both graphs with identical shape, the L2 distance
    /// is recorded. A tensor whose shape differs, or that exists in only one of the
    /// graphs, is recorded as `f64::MAX` and also forces `max_l2_diff` to `f64::MAX`,
    /// so a threshold check on the maximum cannot pass while tensors are missing.
    ///
    /// `avg_l2_diff` is the mean over the tensors that were actually compared
    /// (0.0 if none were); `tensor_count` counts every tensor accounted for,
    /// including the ones that could not be compared.
    pub fn verify_weights(
        original: &Graph<OperatorType, WeightTensor>,
        modified: &Graph<OperatorType, WeightTensor>,
    ) -> GraphResult<WeightDiff> {
        let mut per_tensor_diff: HashMap<String, f64> = HashMap::new();
        let mut max_l2_diff = 0.0f64;
        let mut total_diff = 0.0f64;
        let mut tensor_count = 0usize;
        let mut compared_count = 0usize;

        let original_weights: HashMap<String, &WeightTensor> = original
            .edges()
            .map(|e| (e.data().name.clone(), e.data()))
            .collect();

        for edge_ref in modified.edges() {
            let modified_weight = edge_ref.data();
            let l2_diff = match original_weights.get(&modified_weight.name) {
                Some(original_weight) if original_weight.shape == modified_weight.shape => {
                    let diff = original_weight.l2_diff(modified_weight);
                    total_diff += diff;
                    compared_count += 1;
                    diff
                }
                // Shape mismatch, or the tensor does not exist in `original`.
                _ => f64::MAX,
            };
            if l2_diff > max_l2_diff {
                max_l2_diff = l2_diff;
            }
            per_tensor_diff.insert(modified_weight.name.clone(), l2_diff);
            tensor_count += 1;
        }

        // Tensors that exist only in `original` were dropped by the modification.
        for name in original_weights.keys() {
            if !per_tensor_diff.contains_key(name) {
                per_tensor_diff.insert(name.clone(), f64::MAX);
                max_l2_diff = f64::MAX;
                tensor_count += 1;
            }
        }

        let avg_l2_diff = if compared_count > 0 {
            total_diff / compared_count as f64
        } else {
            0.0
        };

        Ok(WeightDiff {
            max_l2_diff,
            avg_l2_diff,
            tensor_count,
            per_tensor_diff,
        })
    }

    /// Guesses an operator kind from a tensor name using case-insensitive substring
    /// matching, checked in this order: attention, MLP, norm, embedding, linear.
    ///
    /// The numeric fields of the returned variant are fixed placeholder values; they
    /// are not derived from the tensor. Names matching none of the patterns become
    /// `OperatorType::Custom` carrying the original name.
    // Only called from the feature-gated `load_from_safetensors` (and from tests),
    // so it is unused when the `safetensors` feature is disabled.
    #[allow(dead_code)]
    fn infer_operator_from_name(name: &str) -> OperatorType {
        let name_lower = name.to_lowercase();
        if name_lower.contains("attention") || name_lower.contains("attn") {
            OperatorType::Attention {
                num_heads: 32,
                hidden_dim: 4096,
            }
        } else if name_lower.contains("mlp") || name_lower.contains("ffn") {
            OperatorType::MLP {
                hidden_dim: 11008,
                activation: "silu".to_string(),
            }
        } else if name_lower.contains("norm") || name_lower.contains("ln") {
            OperatorType::Norm {
                norm_type: "rmsnorm".to_string(),
                eps: 1e-6,
            }
        } else if name_lower.contains("embed") {
            OperatorType::Embedding {
                vocab_size: 32000,
                embed_dim: 4096,
            }
        } else if name_lower.contains("linear") || name_lower.contains("proj") {
            OperatorType::Linear {
                in_features: 4096,
                out_features: 4096,
            }
        } else {
            OperatorType::Custom {
                name: name.to_string(),
            }
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a per-process path in the system temp directory so test artifacts do
    /// not land in the working directory.
    #[cfg(feature = "safetensors")]
    fn temp_safetensors_path(stem: &str) -> std::path::PathBuf {
        std::env::temp_dir().join(format!("{}_{}.safetensors", stem, std::process::id()))
    }

    #[test]
    fn test_weight_tensor_l2_norm() {
        let tensor = WeightTensor::new(
            "test".to_string(),
            vec![1.0, 2.0, 3.0, 4.0],
            vec![2, 2],
        );
        // sqrt(1 + 4 + 9 + 16) = sqrt(30) ≈ 5.477
        let norm = tensor.l2_norm();
        assert!((norm - 5.477).abs() < 0.001);
    }

    #[test]
    fn test_weight_tensor_l2_diff() {
        let t1 = WeightTensor::new(
            "test1".to_string(),
            vec![1.0, 2.0, 3.0, 4.0],
            vec![2, 2],
        );
        let t2 = WeightTensor::new(
            "test2".to_string(),
            vec![1.1, 2.1, 3.1, 4.1],
            vec![2, 2],
        );
        let diff = t1.l2_diff(&t2);
        assert!(diff < 0.5);
    }

    #[test]
    fn test_weight_tensor_reshape_mut() {
        let mut tensor = WeightTensor::new(
            "test".to_string(),
            vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
            vec![2, 3],
        );
        tensor.reshape_mut(vec![3, 2]).unwrap();
        assert_eq!(tensor.shape(), &[3, 2]);
        assert_eq!(tensor.strides(), &[2, 1]);
        // A shape with a different element count must be rejected.
        let result = tensor.reshape_mut(vec![2, 2]);
        assert!(result.is_err());
    }

    #[test]
    fn test_weight_tensor_stride_access() {
        let tensor = WeightTensor::new(
            "test".to_string(),
            vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
            vec![2, 3],
        );
        assert_eq!(tensor.get(&[0, 0]), Some(1.0));
        assert_eq!(tensor.get(&[0, 1]), Some(2.0));
        assert_eq!(tensor.get(&[0, 2]), Some(3.0));
        assert_eq!(tensor.get(&[1, 0]), Some(4.0));
        assert_eq!(tensor.get(&[1, 1]), Some(5.0));
        assert_eq!(tensor.get(&[1, 2]), Some(6.0));
        // Out-of-range coordinates yield `None` instead of panicking.
        assert_eq!(tensor.get(&[2, 0]), None);
        assert_eq!(tensor.get(&[0, 3]), None);
    }

    #[test]
    fn test_weight_tensor_set() {
        let mut tensor = WeightTensor::new(
            "test".to_string(),
            vec![1.0, 2.0, 3.0, 4.0],
            vec![2, 2],
        );
        assert!(tensor.set(&[0, 1], 10.0));
        assert!(tensor.set(&[1, 0], 20.0));
        assert_eq!(tensor.get(&[0, 0]), Some(1.0));
        assert_eq!(tensor.get(&[0, 1]), Some(10.0));
        assert_eq!(tensor.get(&[1, 0]), Some(20.0));
        assert_eq!(tensor.get(&[1, 1]), Some(4.0));
        assert!(!tensor.set(&[2, 0], 100.0));
    }

    #[test]
    fn test_weight_tensor_ndim_and_numel() {
        let tensor = WeightTensor::new(
            "test".to_string(),
            vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
            vec![2, 3],
        );
        assert_eq!(tensor.ndim(), 2);
        assert_eq!(tensor.numel(), 6);
    }

    #[test]
    fn test_weight_tensor_struct_size() {
        use std::mem::size_of;
        // `repr(align(64))` pads the struct to a multiple of 64 bytes.
        assert!(size_of::<WeightTensor>() >= 64);
        let tensor = WeightTensor::new("test".to_string(), vec![1.0; 100], vec![10, 10]);
        assert_eq!(tensor.numel(), 100);
    }

    #[test]
    fn test_compute_strides() {
        assert_eq!(compute_strides(&[5]), vec![1]);
        assert_eq!(compute_strides(&[3, 4]), vec![4, 1]);
        assert_eq!(compute_strides(&[2, 3, 4]), vec![12, 4, 1]);
        assert_eq!(compute_strides(&[2, 3, 4, 5]), vec![60, 20, 5, 1]);
        let empty: &[usize] = &[];
        assert_eq!(compute_strides(empty), Vec::<usize>::new());
    }

    #[test]
    fn test_infer_operator_from_name() {
        assert!(matches!(
            ModelSwitch::infer_operator_from_name("model.layers.0.self_attn.q_proj"),
            OperatorType::Attention { .. }
        ));
        assert!(matches!(
            ModelSwitch::infer_operator_from_name("model.layers.0.mlp.gate_proj"),
            OperatorType::MLP { .. }
        ));
        assert!(matches!(
            ModelSwitch::infer_operator_from_name("model.norm.weight"),
            OperatorType::Norm { .. }
        ));
    }

    #[test]
    #[cfg(feature = "safetensors")]
    fn test_save_to_safetensors() {
        use std::fs;

        let mut graph = Graph::<OperatorType, WeightTensor>::directed();
        let embed_node = graph
            .add_node(OperatorType::Embedding {
                vocab_size: 1000,
                embed_dim: 128,
            })
            .unwrap();
        let attn_node = graph
            .add_node(OperatorType::Attention {
                num_heads: 8,
                hidden_dim: 256,
            })
            .unwrap();
        let mlp_node = graph
            .add_node(OperatorType::MLP {
                hidden_dim: 512,
                activation: "relu".to_string(),
            })
            .unwrap();
        let norm_node = graph
            .add_node(OperatorType::Norm {
                norm_type: "layernorm".to_string(),
                eps: 1e-5,
            })
            .unwrap();

        // Self-loop edges carry each operator's own weights.
        graph
            .add_edge(
                embed_node,
                embed_node,
                WeightTensor::new(
                    "model.embeddings.weight".to_string(),
                    vec![1.0; 1000 * 128],
                    vec![1000, 128],
                ),
            )
            .unwrap();
        graph
            .add_edge(
                attn_node,
                attn_node,
                WeightTensor::new(
                    "model.layers.0.attention.qkv.weight".to_string(),
                    vec![0.5; 256 * 3 * 256],
                    vec![256, 3, 256],
                ),
            )
            .unwrap();
        graph
            .add_edge(
                mlp_node,
                mlp_node,
                WeightTensor::new(
                    "model.layers.0.mlp.fc1.weight".to_string(),
                    vec![0.25; 256 * 512],
                    vec![256, 512],
                ),
            )
            .unwrap();
        graph
            .add_edge(
                norm_node,
                norm_node,
                WeightTensor::new(
                    "model.norm.weight".to_string(),
                    vec![1.0; 256],
                    vec![256],
                ),
            )
            .unwrap();

        // Edges between operators.
        graph
            .add_edge(
                embed_node,
                attn_node,
                WeightTensor::new(
                    "model.embed_to_attn.weight".to_string(),
                    vec![0.1; 128 * 256],
                    vec![128, 256],
                ),
            )
            .unwrap();
        graph
            .add_edge(
                attn_node,
                mlp_node,
                WeightTensor::new(
                    "model.attn_to_mlp.weight".to_string(),
                    vec![0.2; 256 * 256],
                    vec![256, 256],
                ),
            )
            .unwrap();
        graph
            .add_edge(
                mlp_node,
                norm_node,
                WeightTensor::new(
                    "model.mlp_to_norm.weight".to_string(),
                    vec![0.3; 512 * 256],
                    vec![512, 256],
                ),
            )
            .unwrap();

        let temp_path = temp_safetensors_path("test_save_to_safetensors_temp");
        let save_result = ModelSwitch::save_to_safetensors(&graph, &temp_path);
        assert!(
            save_result.is_ok(),
            "Failed to save to safetensors: {:?}",
            save_result
        );
        assert!(temp_path.exists(), "Safetensors file was not created");

        let loaded_graph = ModelSwitch::load_from_safetensors(&temp_path);
        // Remove the file before asserting on its contents so a failing assertion
        // below does not leave it behind.
        let _ = fs::remove_file(&temp_path);
        assert!(
            loaded_graph.is_ok(),
            "Failed to load from safetensors: {:?}",
            loaded_graph
        );
        let loaded_graph = loaded_graph.unwrap();
        assert_eq!(
            7,
            loaded_graph.edge_count(),
            "Edge count should match number of tensors"
        );

        let diff = ModelSwitch::verify_weights(&graph, &loaded_graph).unwrap();
        println!(
            "Save/Load round-trip weight diff: max={:.6e}, avg={:.6e}, count={}",
            diff.max_l2_diff, diff.avg_l2_diff, diff.tensor_count
        );
        assert!(
            diff.max_l2_diff < 1e-5,
            "Weight difference too large: max_l2_diff={}",
            diff.max_l2_diff
        );
    }

    #[test]
    #[cfg(feature = "safetensors")]
    fn test_save_load_round_trip() {
        use std::fs;

        let mut graph = Graph::<OperatorType, WeightTensor>::directed();
        let node = graph
            .add_node(OperatorType::Linear {
                in_features: 64,
                out_features: 64,
            })
            .unwrap();
        let original_data: Vec<f64> = (0..64 * 64).map(|i| (i as f64) * 0.01).collect();
        graph
            .add_edge(
                node,
                node,
                WeightTensor::new(
                    "test.linear.weight".to_string(),
                    original_data,
                    vec![64, 64],
                ),
            )
            .unwrap();

        let temp_path = temp_safetensors_path("test_round_trip_temp");
        let loaded = ModelSwitch::save_to_safetensors(&graph, &temp_path)
            .and_then(|_| ModelSwitch::load_from_safetensors(&temp_path));
        // Clean up before unwrapping so a failure does not leak the file.
        let _ = fs::remove_file(&temp_path);
        let loaded_graph = loaded.unwrap();

        assert_eq!(loaded_graph.edge_count(), 1);
        let diff = ModelSwitch::verify_weights(&graph, &loaded_graph).unwrap();
        println!(
            "Round-trip L2 diff: max={:.6e}, avg={:.6e}",
            diff.max_l2_diff, diff.avg_l2_diff
        );
        assert_eq!(diff.tensor_count, 1);
        // Weights are stored as f32: with 4096 values below 41 the accumulated
        // rounding error stays around 1e-4, far from the f64::MAX sentinel.
        assert!(
            diff.max_l2_diff < 1e-3,
            "Round-trip difference too large: max_l2_diff={}",
            diff.max_l2_diff
        );
    }
}