pub mod debug;
pub mod execution;
pub mod metrics;
pub mod operations;
pub mod shape_inference;
pub mod type_checking;
pub use operations::{
execute_registered_operation, global_registry, is_operation_registered, register_operation,
CustomOperation, OperationRegistry,
};
pub use shape_inference::{ShapeInferenceContext, ShapeInfo};
pub use type_checking::TypeCheckingContext;
pub use execution::{interpret, interpret_with_inputs, ExecutionEnvironment, GraphInterpreter};
pub use metrics::{ExecutionMetrics, ExecutionTimer, MetricsCollector};
pub use debug::{utils, DebugExecutionEnvironment};
use crate::{FxGraph, TorshResult};
use std::collections::HashMap;
use torsh_core::dtype::DType;
/// Infers a shape for every node in `graph`, seeded by the shapes of the
/// named graph inputs.
///
/// # Errors
/// Propagates any failure reported by [`ShapeInferenceContext::infer_shapes`].
pub fn infer_graph_shapes(
    graph: &FxGraph,
    input_shapes: HashMap<String, ShapeInfo>,
) -> TorshResult<HashMap<petgraph::graph::NodeIndex, ShapeInfo>> {
    let mut ctx = ShapeInferenceContext::new();
    ctx.infer_shapes(graph, input_shapes)?;
    // The context owns the per-node results; hand back an owned copy.
    let inferred = ctx.get_all_shapes().clone();
    Ok(inferred)
}
/// Type-checks `graph`, seeding the analysis with the dtypes of the named
/// graph inputs, and returns the resolved dtype for every node.
///
/// # Errors
/// Propagates any failure reported by [`TypeCheckingContext::check_types`].
pub fn check_graph_types(
    graph: &FxGraph,
    input_types: HashMap<String, DType>,
) -> TorshResult<HashMap<petgraph::graph::NodeIndex, DType>> {
    let mut ctx = TypeCheckingContext::new();
    ctx.check_types(graph, input_types)?;
    // The context owns the per-node results; hand back an owned copy.
    let resolved = ctx.get_all_types().clone();
    Ok(resolved)
}
/// Runs the structural and executability checks on `graph` and, when both
/// pass, returns a human-readable validation report.
///
/// # Errors
/// Returns the first error raised by either validation pass.
pub fn validate_graph(graph: &FxGraph) -> TorshResult<String> {
    debug::utils::validate_graph_structure(graph)?;
    debug::utils::validate_graph_executability(graph)?;
    // Keep the original evaluation order: summary first, then description,
    // even though the report prints them the other way around.
    let exec_summary = debug::utils::generate_execution_summary(graph);
    let graph_desc = debug::utils::describe_graph(graph);
    let report = format!(
        "Graph Validation: PASSED\n\n{}\n\n{}",
        graph_desc, exec_summary
    );
    Ok(report)
}
/// Builds a human-readable overview of the interpreter system: its modules,
/// the built-in operations, any registered custom operations, and the
/// system's capabilities.
pub fn system_info() -> String {
    // Names of the operations shipped with the interpreter. Kept as a const
    // so the count and the listing below always agree.
    const BUILTIN_OPS: [&str; 16] = [
        "add",
        "sub",
        "mul",
        "div",
        "matmul",
        "relu",
        "sigmoid",
        "tanh",
        "gelu",
        "softmax",
        "layer_norm",
        "batch_norm",
        "conv2d",
        "linear",
        "linear_relu",
        "conv2d_relu",
    ];
    let registry = global_registry();
    let custom_count = registry.operation_count();
    let custom_ops = registry.list_operations();
    // An empty registry renders as the literal word "None".
    let custom_listing = if custom_ops.is_empty() {
        "None".to_string()
    } else {
        custom_ops.join(", ")
    };
    format!(
        "ToRSh FX Graph Interpreter System\n\
         ===================================\n\
         \n\
         Modules:\n\
         - Operations: Custom operation registry and management\n\
         - Shape Inference: Graph shape analysis and inference\n\
         - Type Checking: Graph type validation and checking\n\
         - Execution: Core graph interpretation engine\n\
         - Metrics: Performance monitoring and metrics collection\n\
         - Debug: Debugging capabilities and development tools\n\
         \n\
         Built-in Operations: {} available\n\
         {}\n\
         \n\
         Registered Custom Operations: {}\n\
         {}\n\
         \n\
         Capabilities:\n\
         - Graph execution with topological ordering\n\
         - Custom operation registration and execution\n\
         - Comprehensive shape inference\n\
         - Type checking and validation\n\
         - Performance monitoring and profiling\n\
         - Debug execution environments\n\
         - Graph structure validation\n\
         - Backward compatibility maintained",
        BUILTIN_OPS.len(),
        BUILTIN_OPS.join(", "),
        custom_count,
        custom_listing
    )
}
#[cfg(test)]
mod tests {
    use super::*;
    use torsh_core::device::DeviceType;
    use torsh_tensor::creation::zeros;

    /// The system banner names the interpreter and mentions its built-ins.
    #[test]
    fn test_system_info() {
        let banner = system_info();
        assert!(banner.contains("ToRSh FX Graph Interpreter System"));
        assert!(banner.contains("Built-in Operations"));
    }

    /// Per-operation timings are recorded, retrievable, and surfaced in the
    /// generated report.
    #[test]
    fn test_execution_metrics() {
        let mut m = ExecutionMetrics::new();
        m.add_operation_time("add", 1.5);
        m.add_operation_time("mul", 2.3);
        assert_eq!(m.operation_count, 2);
        assert_eq!(m.get_operation_time("add"), 1.5);
        assert_eq!(m.get_operation_time("mul"), 2.3);
        assert!(m.generate_report().contains("Performance Report"));
    }

    /// Logged messages are stored verbatim and the summary is labelled.
    #[test]
    fn test_debug_environment() {
        let mut env = DebugExecutionEnvironment::new(DeviceType::Cpu, true);
        env.log("Test message".to_string());
        assert_eq!(env.get_log().len(), 1);
        assert_eq!(env.get_log()[0], "Test message");
        assert!(env.execution_summary().contains("Debug Environment"));
    }

    /// `ShapeInfo` retains the shape and dtype it was constructed with.
    #[test]
    fn test_shape_info() {
        use torsh_core::dtype::DType;
        use torsh_core::shape::Shape;
        let shape = Shape::new(vec![2, 3, 4]);
        let info = ShapeInfo::new(shape.clone(), DType::F32);
        assert_eq!(info.shape.dims(), shape.dims());
        assert_eq!(info.dtype, DType::F32);
    }

    /// A stored tensor is visible, counted, and retrievable by node index.
    #[test]
    fn test_execution_environment() {
        use petgraph::graph::NodeIndex;
        let mut env = ExecutionEnvironment::new(DeviceType::Cpu);
        let tensor = zeros(&[2, 2]).expect("zeros should succeed");
        let idx = NodeIndex::new(0);
        env.store(idx, tensor.clone());
        assert!(env.has_value(idx));
        assert_eq!(env.value_count(), 1);
        let stored = env
            .get(idx)
            .expect("element retrieval should succeed for valid index");
        assert_eq!(stored.shape().dims(), tensor.shape().dims());
    }

    /// Aggregate statistics reflect every run added to the collector.
    #[test]
    fn test_metrics_collector() {
        // Builds a single-run metrics record containing one timed operation.
        fn run(op: &str, total: f64) -> ExecutionMetrics {
            let mut m = ExecutionMetrics::new();
            m.add_operation_time(op, total);
            m.set_total_time(total);
            m
        }
        let mut collector = MetricsCollector::new();
        collector.add_run(run("add", 10.0));
        collector.add_run(run("mul", 20.0));
        assert_eq!(collector.run_count(), 2);
        assert_eq!(collector.average_execution_time(), 15.0);
        assert_eq!(collector.fastest_execution(), Some(10.0));
        assert_eq!(collector.slowest_execution(), Some(20.0));
    }
}