use std::collections::HashSet;
use tensorlogic_ir::{EinsumGraph, OpType};
use crate::error::ExecutorError;
/// Outcome of a validation run: an overall verdict plus the collected diagnostics.
#[derive(Debug, Clone)]
pub struct ValidationResult {
    /// `true` until an error is recorded (directly or through [`ValidationResult::merge`]).
    pub is_valid: bool,
    /// Error messages, in the order they were recorded.
    pub errors: Vec<String>,
    /// Warning messages, in the order they were recorded. Warnings never affect `is_valid`.
    pub warnings: Vec<String>,
}

impl ValidationResult {
    /// Creates an empty result that is considered valid.
    pub fn new() -> Self {
        Self {
            is_valid: true,
            errors: Vec::new(),
            warnings: Vec::new(),
        }
    }

    /// Records an error message and marks the result as invalid.
    pub fn add_error(&mut self, error: impl Into<String>) {
        self.errors.push(error.into());
        self.is_valid = false;
    }

    /// Records a warning message; the validity flag is left untouched.
    pub fn add_warning(&mut self, warning: impl Into<String>) {
        self.warnings.push(warning.into());
    }

    /// Folds `other` into `self`.
    ///
    /// The merged result is valid only if both were valid; the messages of `other` are appended
    /// after the ones already held by `self`.
    pub fn merge(&mut self, other: ValidationResult) {
        let ValidationResult {
            is_valid,
            errors,
            warnings,
        } = other;
        self.is_valid = self.is_valid && is_valid;
        self.errors.extend(errors);
        self.warnings.extend(warnings);
    }

    /// Renders a multi-line, human-readable report.
    ///
    /// The first line states the verdict; an "Errors" and a "Warnings" section follow, each only
    /// when it has at least one entry.
    pub fn summary(&self) -> String {
        let headline = if self.is_valid {
            "✓ Graph is valid\n"
        } else {
            "✗ Graph validation failed\n"
        };
        let mut text = String::from(headline);

        let sections = [
            ("\nErrors:\n", &self.errors),
            ("\nWarnings:\n", &self.warnings),
        ];
        for (title, entries) in sections {
            if entries.is_empty() {
                continue;
            }
            text.push_str(title);
            for entry in entries {
                text.push_str(" - ");
                text.push_str(entry);
                text.push('\n');
            }
        }
        text
    }
}

impl Default for ValidationResult {
    /// Same as [`ValidationResult::new`]: valid, with no messages.
    fn default() -> Self {
        ValidationResult::new()
    }
}
/// Stateless checker for [`EinsumGraph`] structure.
///
/// The type has no fields; all checks read only the graph passed to each method.
pub struct GraphValidator;
impl GraphValidator {
pub fn new() -> Self {
GraphValidator
}
/// Runs every structural check on `graph` and returns the collected diagnostics.
///
/// A graph without nodes only produces a warning. The individual checks (tensor indices,
/// dependencies, operation arity, acyclicity) all append to the same report, so their call
/// order determines the order of the messages.
pub fn validate(&self, graph: &EinsumGraph) -> ValidationResult {
    let mut report = ValidationResult::new();

    let has_no_nodes = graph.nodes.is_empty();
    if has_no_nodes {
        report.add_warning("Graph has no computation nodes");
    }

    self.validate_tensor_indices(graph, &mut report);
    self.validate_dependencies(graph, &mut report);
    self.validate_operations(graph, &mut report);
    self.validate_dag(graph, &mut report);

    report
}
/// Reports every input index that is out of range for the node using it.
///
/// For the node at position `i`, an input index is accepted only when it is strictly below
/// `graph.tensors.len() + i`; each offending input adds one error to `result`.
fn validate_tensor_indices(&self, graph: &EinsumGraph, result: &mut ValidationResult) {
    let tensor_count = graph.tensors.len();
    for (node_idx, node) in graph.nodes.iter().enumerate() {
        // Exclusive upper bound for this node's inputs. Note that this is also the number
        // printed as "max valid" in the message below.
        let limit = tensor_count + node_idx;
        for &input_idx in node.inputs.iter().filter(|&&idx| idx >= limit) {
            result.add_error(format!(
                "Node {} references invalid tensor index {} (max valid: {})",
                node_idx, input_idx, limit
            ));
        }
    }
}
/// Reports inputs that point at the node itself or at a later node.
///
/// Input indices at or above `graph.tensors.len()` are interpreted as references to the node at
/// position `index - graph.tensors.len()`; such a reference is an error unless that position is
/// strictly before the referencing node.
fn validate_dependencies(&self, graph: &EinsumGraph, result: &mut ValidationResult) {
    let tensor_count = graph.tensors.len();
    for (node_idx, node) in graph.nodes.iter().enumerate() {
        // `checked_sub` yields `None` exactly for indices below the tensor count, i.e. for
        // inputs that are not node references at all.
        let referenced_nodes = node
            .inputs
            .iter()
            .filter_map(|&input_idx| input_idx.checked_sub(tensor_count));
        for dep_node_idx in referenced_nodes {
            if dep_node_idx < node_idx {
                continue;
            }
            result.add_error(format!(
                "Node {} has forward dependency on node {}",
                node_idx, dep_node_idx
            ));
        }
    }
}
/// Checks per-operation constraints on every node.
///
/// * einsum: the spec must be non-empty and there must be at least one input;
/// * element-wise unary and reduce: exactly one input;
/// * element-wise binary: exactly two inputs;
/// * reduce with an empty axis list is only a warning.
fn validate_operations(&self, graph: &EinsumGraph, result: &mut ValidationResult) {
    for (node_idx, node) in graph.nodes.iter().enumerate() {
        let input_count = node.inputs.len();
        match &node.op {
            OpType::Einsum { spec } => {
                if spec.is_empty() {
                    result.add_error(format!("Node {} has empty einsum spec", node_idx));
                }
                if input_count == 0 {
                    result.add_error(format!("Node {} einsum has no inputs", node_idx));
                }
            }
            OpType::ElemUnary { .. } => {
                if input_count != 1 {
                    result.add_error(format!(
                        "Node {} unary operation requires exactly 1 input, got {}",
                        node_idx, input_count
                    ));
                }
            }
            OpType::ElemBinary { .. } => {
                if input_count != 2 {
                    result.add_error(format!(
                        "Node {} binary operation requires exactly 2 inputs, got {}",
                        node_idx, input_count
                    ));
                }
            }
            OpType::Reduce { axes, .. } => {
                if input_count != 1 {
                    result.add_error(format!(
                        "Node {} reduce operation requires exactly 1 input, got {}",
                        node_idx, input_count
                    ));
                }
                // Not an error: the node is still well-formed, it just reduces over nothing.
                if axes.is_empty() {
                    result.add_warning(format!(
                        "Node {} reduce operation has no axes (identity operation)",
                        node_idx
                    ));
                }
            }
        }
    }
}
/// Adds a single error to `result` if the node-to-node references contain a cycle.
///
/// Every node not yet visited is used as a search root for `has_cycle_helper`; the scan stops
/// at the first cycle found, so at most one error is reported.
fn validate_dag(&self, graph: &EinsumGraph, result: &mut ValidationResult) {
    let node_count = graph.nodes.len();
    let tensor_count = graph.tensors.len();
    let mut visited = vec![false; node_count];
    let mut on_path = vec![false; node_count];

    // `any` short-circuits on the first root that reaches a cycle.
    let cyclic = (0..node_count).any(|start| {
        !visited[start]
            && has_cycle_helper(start, graph, tensor_count, &mut visited, &mut on_path)
    });

    if cyclic {
        result.add_error("Graph contains a cycle (not a DAG)");
    }
}
/// Validates `graph` and converts a failed validation into an error.
///
/// # Errors
///
/// Returns `ExecutorError::GraphValidationError` carrying all error messages joined with
/// `"; "` when validation reports at least one error. Warnings are not included.
pub fn validate_or_error(&self, graph: &EinsumGraph) -> Result<(), ExecutorError> {
    let report = self.validate(graph);
    if !report.is_valid {
        let message = report.errors.join("; ");
        return Err(ExecutorError::GraphValidationError(message));
    }
    Ok(())
}
/// Returns the positions of all nodes that the last node does not depend on.
///
/// The search starts at the last node in `graph.nodes` and follows input indices at or above
/// `graph.tensors.len()`, which are interpreted as references to the node at position
/// `index - graph.tensors.len()`. Every node not reached this way is returned. An empty graph
/// yields an empty set.
///
/// This method does not require a validated graph: references that point past the end of
/// `graph.nodes` are ignored instead of causing an out-of-bounds panic.
pub fn find_unreachable_nodes(&self, graph: &EinsumGraph) -> HashSet<usize> {
    let num_nodes = graph.nodes.len();
    let num_tensors = graph.tensors.len();

    // Dense flags: node positions are exactly 0..num_nodes.
    let mut reachable = vec![false; num_nodes];
    let mut to_visit: Vec<usize> = Vec::new();
    if num_nodes > 0 {
        to_visit.push(num_nodes - 1);
    }

    while let Some(node_idx) = to_visit.pop() {
        if reachable[node_idx] {
            continue;
        }
        reachable[node_idx] = true;

        for &input_idx in &graph.nodes[node_idx].inputs {
            // `None` means the input is below the tensor count, i.e. not a node reference.
            if let Some(dep_node_idx) = input_idx.checked_sub(num_tensors) {
                // Skip dangling references so that indexing `graph.nodes` and `reachable`
                // stays in bounds even on malformed graphs.
                if dep_node_idx < num_nodes && !reachable[dep_node_idx] {
                    to_visit.push(dep_node_idx);
                }
            }
        }
    }

    (0..num_nodes).filter(|&idx| !reachable[idx]).collect()
}
}
/// Depth-first search for a cycle among the nodes reachable from `node_idx`.
///
/// Input indices at or above `num_tensors` are interpreted as references to the node at
/// position `index - num_tensors`; references at or beyond `visited.len()` are ignored.
///
/// * `visited` — set to `true` for every node the search enters; nodes already marked are not
///   explored again.
/// * `rec_stack` — `true` for the nodes on the current search path. It is cleared again for
///   every node that is fully explored; when a cycle is found the path is left marked.
///
/// Returns `true` as soon as a reference to a node on the current path is found.
///
/// The traversal uses an explicit stack rather than recursion, so the depth of the dependency
/// chain is limited by heap memory and not by the thread's call stack.
fn has_cycle_helper(
    node_idx: usize,
    graph: &EinsumGraph,
    num_tensors: usize,
    visited: &mut [bool],
    rec_stack: &mut [bool],
) -> bool {
    // Each frame is (node position, index of the next input of that node to examine).
    let mut frames: Vec<(usize, usize)> = vec![(node_idx, 0)];
    visited[node_idx] = true;
    rec_stack[node_idx] = true;

    while let Some(frame) = frames.last_mut() {
        let (current, cursor) = *frame;
        match graph.nodes[current].inputs.get(cursor) {
            None => {
                // All inputs examined: the node leaves the current path.
                rec_stack[current] = false;
                frames.pop();
            }
            Some(&input_idx) => {
                // Advance first so the frame resumes at the next input when revisited.
                frame.1 += 1;

                if input_idx < num_tensors {
                    continue;
                }
                let dep_node_idx = input_idx - num_tensors;
                if dep_node_idx >= visited.len() {
                    continue;
                }

                if !visited[dep_node_idx] {
                    visited[dep_node_idx] = true;
                    rec_stack[dep_node_idx] = true;
                    frames.push((dep_node_idx, 0));
                } else if rec_stack[dep_node_idx] {
                    // Reference back into the current path: cycle.
                    return true;
                }
            }
        }
    }
    false
}
impl Default for GraphValidator {
fn default() -> Self {
Self::new()
}
}