use crate::formats::pytorch::TensorData;
use std::collections::{HashMap, HashSet};
use super::errors::ParsingError;
use super::types::{LayerType, ModelGraph};
/// Structural validator for parsed model graphs.
///
/// This is a unit struct; all functionality is exposed as associated functions.
pub struct ModelValidator;

impl ModelValidator {
    /// Validates the structure of `graph`.
    ///
    /// Checks run in this order, stopping at the first failure:
    /// 1. every layer named in `graph.connections` (as a source or a target) exists
    ///    in `graph.layers`;
    /// 2. the connection graph contains no cycle;
    /// 3. connected layers whose shapes are both known have compatible
    ///    output/input shapes.
    ///
    /// References are checked first so that the cycle search only walks layers that
    /// actually exist. This means a dangling reference is reported as a
    /// `MissingConnection` instead of surfacing as a cycle.
    ///
    /// # Errors
    ///
    /// * [`ParsingError::MissingConnection`] if a connection names an unknown layer.
    /// * [`ParsingError::CircularDependency`] if the connections form a cycle.
    /// * [`ParsingError::IncompatibleDimensions`] if two connected layers have
    ///   mismatched shapes (see `shapes_compatible`).
    ///
    /// When several problems of the same kind exist, which one is reported depends
    /// on map iteration order and is therefore unspecified.
    pub fn validate_graph(graph: &ModelGraph) -> Result<(), ParsingError> {
        Self::check_references(graph)?;
        Self::check_cycles(graph)?;
        Self::check_layer_compatibility(graph)?;
        Ok(())
    }

    /// Ensures every source and target named in `graph.connections` is a key of
    /// `graph.layers`.
    fn check_references(graph: &ModelGraph) -> Result<(), ParsingError> {
        for (from_layer, to_layers) in &graph.connections {
            if !graph.layers.contains_key(from_layer) {
                return Err(ParsingError::MissingConnection(format!(
                    "Source layer '{}' not found",
                    from_layer
                )));
            }
            if let Some(to_layer) = to_layers
                .iter()
                .find(|name| !graph.layers.contains_key(*name))
            {
                return Err(ParsingError::MissingConnection(format!(
                    "Target layer '{}' not found",
                    to_layer
                )));
            }
        }
        Ok(())
    }

    /// Detects cycles in `graph.connections` with an iterative depth-first search.
    ///
    /// The search starts from every key of `graph.layers`. It uses an explicit stack
    /// rather than recursion, so very long layer chains cannot overflow the call
    /// stack. Neighbours are visited in the order stored in `graph.connections`.
    ///
    /// # Errors
    ///
    /// Returns [`ParsingError::CircularDependency`] describing the first cycle found,
    /// formatted as `"Cycle detected: a -> b -> ... -> a"`.
    fn check_cycles(graph: &ModelGraph) -> Result<(), ParsingError> {
        // Nodes whose entire reachable subgraph has been explored without a cycle.
        let mut finished: HashSet<&str> = HashSet::new();
        // Nodes currently on the DFS path; mirrors `stack` for O(1) membership tests.
        let mut on_path: HashSet<&str> = HashSet::new();
        // DFS path: (node, index of the next neighbour of `node` to examine).
        let mut stack: Vec<(&str, usize)> = Vec::new();

        for root in graph.layers.keys() {
            let root = root.as_str();
            if finished.contains(root) {
                continue;
            }
            stack.push((root, 0));
            on_path.insert(root);

            while let Some(frame) = stack.last_mut() {
                let node = frame.0;
                let neighbors = graph
                    .connections
                    .get(node)
                    .map(Vec::as_slice)
                    .unwrap_or(&[]);

                match neighbors.get(frame.1) {
                    Some(neighbor) => {
                        frame.1 += 1;
                        let neighbor = neighbor.as_str();
                        if on_path.contains(neighbor) {
                            // Back edge: the cycle runs from `neighbor`'s position
                            // on the path up to `node`, then closes on `neighbor`.
                            let start = stack
                                .iter()
                                .position(|&(n, _)| n == neighbor)
                                .expect("on_path only contains nodes that are on the stack");
                            let mut cycle: Vec<&str> =
                                stack[start..].iter().map(|&(n, _)| n).collect();
                            cycle.push(neighbor);
                            return Err(ParsingError::CircularDependency(format!(
                                "Cycle detected: {}",
                                cycle.join(" -> ")
                            )));
                        }
                        if !finished.contains(neighbor) {
                            on_path.insert(neighbor);
                            stack.push((neighbor, 0));
                        }
                    }
                    None => {
                        // All neighbours explored: retire this node.
                        stack.pop();
                        on_path.remove(node);
                        finished.insert(node);
                    }
                }
            }
        }
        Ok(())
    }

    /// Checks every connection whose source declares an `output_shape` and whose
    /// target declares an `input_shape`, using `shapes_compatible`.
    ///
    /// Connections where either shape is absent are not checked.
    ///
    /// # Errors
    ///
    /// Returns [`ParsingError::IncompatibleDimensions`] naming the first offending
    /// pair.
    fn check_layer_compatibility(graph: &ModelGraph) -> Result<(), ParsingError> {
        for (from_layer, to_layers) in &graph.connections {
            // Unknown layers are reported by `check_references`; skip them here so
            // this check stays safe even when called on its own.
            let from_info = match graph.layers.get(from_layer) {
                Some(info) => info,
                None => continue,
            };
            for to_layer in to_layers {
                let to_info = match graph.layers.get(to_layer) {
                    Some(info) => info,
                    None => continue,
                };
                if let (Some(output_shape), Some(input_shape)) =
                    (&from_info.output_shape, &to_info.input_shape)
                {
                    if !Self::shapes_compatible(output_shape, input_shape) {
                        return Err(ParsingError::IncompatibleDimensions {
                            layer1: from_layer.clone(),
                            layer2: to_layer.clone(),
                        });
                    }
                }
            }
        }
        Ok(())
    }

    /// Returns whether a layer producing `output_shape` may feed a layer expecting
    /// `input_shape`.
    ///
    /// Only rank-1 inputs are actually checked:
    /// * rank-1 output into a rank-1 input requires the single dimensions to be equal;
    /// * rank-N (N > 1) output into a rank-1 input requires the output's total
    ///   element count to equal the input dimension (implicit flatten).
    ///
    /// Every other combination, including empty shapes, is accepted.
    fn shapes_compatible(output_shape: &[usize], input_shape: &[usize]) -> bool {
        match (output_shape, input_shape) {
            ([out], [inp]) => out == inp,
            (out, [inp]) if out.len() > 1 => {
                // The dimensions come from parsed model files, so their product may
                // not fit in `usize`. A product that overflows is larger than any
                // `usize`, so it cannot match. A zero dimension makes the true
                // product 0 regardless of the other dimensions.
                let flattened = if out.contains(&0) {
                    Some(0)
                } else {
                    out.iter().try_fold(1usize, |acc, &d| acc.checked_mul(d))
                };
                flattened == Some(*inp)
            }
            _ => true,
        }
    }

    /// Checks that every connection in `desc` refers to a layer declared in
    /// `desc.layers`.
    ///
    /// `_layers` is currently unused. It is kept so that existing call sites
    /// continue to compile.
    ///
    /// # Errors
    ///
    /// Returns [`ParsingError::MissingConnection`] for the first connection whose
    /// source or target name is not declared.
    pub fn validate_layer_references(
        desc: &super::formats::ArchitectureDescription,
        _layers: &HashMap<String, super::types::LayerInfo>,
    ) -> Result<(), ParsingError> {
        // Borrow the names rather than cloning them; the set only lives for this call.
        let layer_names: HashSet<&str> = desc.layers.iter().map(|l| l.name.as_str()).collect();
        for connection in &desc.connections {
            if !layer_names.contains(connection.from.as_str()) {
                return Err(ParsingError::MissingConnection(format!(
                    "Connection references unknown source layer: {}",
                    connection.from
                )));
            }
            if !layer_names.contains(connection.to.as_str()) {
                return Err(ParsingError::MissingConnection(format!(
                    "Connection references unknown target layer: {}",
                    connection.to
                )));
            }
        }
        Ok(())
    }
}