pub const VERSION: &str = env!("CARGO_PKG_VERSION");
pub const VERSION_MAJOR: u32 = 0;
pub const VERSION_MINOR: u32 = 1;
pub const VERSION_PATCH: u32 = 0;
use torsh_core::Result;
pub type TorshResult<T> = Result<T>;
pub mod fx;
pub use fx::{Edge, FxGraph, GraphStats, MemoryEstimate, Node, SerializableGraph};
pub use benchmarking::{BenchmarkResult, GraphBenchmarkSuite, RegressionTester};
pub use checkpointing::{
create_checkpoint, load_checkpoint, save_checkpoint, CheckpointData, CheckpointFormat,
CheckpointManager, CheckpointMetadata, CheckpointOptions, ResumableInterpreter,
};
pub use codegen::{
CacheStats, CodeGenBackend, CodeGenerator, CompiledCode, CppCodeGen, LazyCompiler,
PythonCodeGen,
};
pub use custom_backends::{
execute_with_auto_backend, execute_with_backend, get_backend, list_available_backends,
register_backend_factory, BackendCapability, BackendContext, BackendExecutor, BackendFactory,
BackendInfo, BackendRegistry, BackendResult, BackendSelectionStrategy, CustomBackend,
};
pub use custom_operations::{
register_example_operations, CustomInt16AddOperation, CustomInt16MulOperation,
CustomInt16SubOperation, CustomTypeUnifyOperation, TypeConversionOperation,
};
pub use custom_types::{
global_extended_registry, register_extended_operation, CustomTypeUtils,
ExtendedCustomOperation, ExtendedOperationRegistry, ExtendedShapeInferenceContext,
ExtendedShapeInfo,
};
pub use distributed::{
create_execution_plan, execute_distributed, init_distributed, CollectiveOp,
CommunicationBackendType, DistributedConfig, DistributedExecutionPlan, DistributedExecutor,
DistributionStrategy, ReduceOp,
};
pub use dynamic_shapes::{
DynamicDim, DynamicShape, DynamicShapeInferenceContext, DynamicShapeInfo, ShapeConstraint,
};
pub use graph_analysis::{
calculate_graph_metrics, DetectedPattern, GraphDiff, GraphDifference, GraphLinter,
GraphMetrics, LintIssue, LintReport, LintSeverity, PatternDetector,
};
pub use graph_partitioning::{
DeviceInfo, DeviceType, GraphPartition, GraphPartitioner, PartitionedGraph,
PartitioningStrategy,
};
pub use heterogeneous_computing::{
DeviceCapability, ExecutionPlan, HeterogeneousExecutor, OperationSpecialization,
PlacementStrategy, SimpleDevice,
};
pub use memory_optimization::{
AdaptiveMemoryManager, AllocationStrategy, GraphMemoryLayout, MemoryAnalyzer,
MemoryMappedGraph, MemoryUsageReport,
};
pub use onnx_export::{export_to_onnx, OnnxExporter, OnnxModel};
pub use performance::{
CacheStatistics, GraphCache, GraphCompression, ParallelTraversal, PerformanceBottleneck,
PerformanceProfiler, PerformanceReport,
};
pub use torchscript_compat::{
TorchScriptExporter, TorchScriptGraph, TorchScriptImporter, TorchScriptModel,
};
pub use tracer::{Module, ModuleTracer, SymbolicTensor, TracingProxy};
pub use emerging_hardware::{
create_dna_backend, create_neuromorphic_backend, create_photonic_backend, AdaptationStrategy,
CompatibilityReport, EmergingHardware, EmergingHardwareBackend, EmergingHardwareResult,
ErrorCorrectionScheme, HardwareCapabilities, HardwareConstraint, HardwareSpecifications,
NeuromorphicProcessor, OptimizationObjective, PhotonicProcessor, PrecisionType,
QuantumInspiredProcessor, SpecializedOperation,
};
pub use interactive_editor::{
launch_interactive_editor, AutoSaveConfig, CollaborativeEdit, EditOperation, ExportFormat,
ImportFormat, InteractiveGraphEditor, PerformanceMetrics, UserSession, VisualizationConfig,
};
pub use neural_architecture_search::{
create_default_search_space, create_mobile_constraints, start_neural_architecture_search,
ArchitectureSearchSpace, CandidateArchitecture, HardwareConstraints, HardwarePlatform,
LayerType, NeuralArchitectureSearch, ObjectiveWeights, SearchResults, SearchStrategy,
};
pub use neuromorphic_optimization::{
create_loihi_optimizer, optimize_for_mobile_neuromorphic, EnergyEstimate, NeuromorphicHardware,
NeuromorphicOptimizationResult, NeuromorphicOptimizer, NeuronModel, OptimizationConfig,
SNNConversionParams, SpikeEncoding,
};
pub use python_integration::{
create_jax_integration, create_pytorch_integration, generate_python_api, graph_to_pytorch_code,
DeploymentPackage, GeneratedPythonCode, PyTorchModelMetadata, PythonBindingConfig,
PythonCodeGenOptions, PythonDeploymentTarget, PythonFramework, PythonIntegrationService,
TrainingInfo,
};
pub use quantization::{
apply_automatic_precision, prepare_graph_for_qat, quantize_graph_post_training,
select_automatic_precision, AutomaticPrecisionSelector, CalibrationData, PTQUtils,
PrecisionCriteria, PrecisionProfile, PrecisionRecommendation, PrecisionStrategy, QATUtils,
QuantizationAnnotation, QuantizationBenchmark, QuantizationContext, QuantizationParams,
QuantizationScheme,
};
pub use quantum_computing::{
create_local_quantum_backend, create_qaoa_circuit, create_qiskit_backend, create_vqe_circuit,
integrate_quantum_computing, CloudProvider, DataTransferType, ErrorMitigation,
HybridOptimizationStrategy, HybridWorkflow, NoiseModel, QuantumBackend, QuantumCircuit,
QuantumComputingBackend, QuantumExecutionResult, QuantumGate, QuantumPrecision, StateEncoding,
SynchronizationType,
};
pub mod checkpointing;
pub mod cloud_deployment;
pub mod codegen;
pub mod custom_backends;
pub mod custom_operations;
pub mod custom_types;
pub mod distributed;
pub mod dynamic_shapes;
pub mod emerging_hardware;
pub mod benchmarking;
pub mod graph_analysis;
pub mod graph_partitioning;
pub mod heterogeneous_computing;
pub mod interactive_editor;
pub mod interpreter;
pub mod memory_optimization;
pub mod model_zoo;
pub mod neural_architecture_search;
pub mod neuromorphic_optimization;
pub mod node;
pub mod onnx_export;
pub mod passes;
pub mod performance;
pub mod python_integration;
pub mod quantization;
pub mod quantum_computing;
pub mod subgraph_rewriter;
pub mod torchscript_compat;
pub mod tracer;
pub mod visualization;
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds the `input -> relu -> output` graph shared by the
    /// serialization round-trip tests (was duplicated verbatim in both).
    fn build_relu_graph() -> FxGraph {
        let mut graph = FxGraph::new();
        let input = graph.graph.add_node(Node::Input("x".to_string()));
        let relu = graph
            .graph
            .add_node(Node::Call("relu".to_string(), vec!["x".to_string()]));
        let output = graph.graph.add_node(Node::Output);
        graph.graph.add_edge(
            input,
            relu,
            Edge {
                name: "x".to_string(),
            },
        );
        graph.graph.add_edge(
            relu,
            output,
            Edge {
                name: "relu_out".to_string(),
            },
        );
        graph.inputs.push(input);
        graph.outputs.push(output);
        graph
    }

    /// JSON serialization must contain the node payloads and round-trip
    /// without losing nodes or edges.
    #[test]
    fn test_graph_serialization_json() {
        let graph = build_relu_graph();
        let json = graph.to_json().unwrap();
        assert!(json.contains("Input"));
        assert!(json.contains("relu"));
        let deserialized = FxGraph::from_json(&json).unwrap();
        assert_eq!(deserialized.node_count(), graph.node_count());
        assert_eq!(deserialized.edge_count(), graph.edge_count());
    }

    /// Binary serialization must be non-empty and round-trip without
    /// losing nodes or edges.
    #[test]
    fn test_graph_serialization_binary() {
        let graph = build_relu_graph();
        let binary = graph.to_binary().unwrap();
        assert!(!binary.is_empty());
        let deserialized = FxGraph::from_binary(&binary).unwrap();
        assert_eq!(deserialized.node_count(), graph.node_count());
        assert_eq!(deserialized.edge_count(), graph.edge_count());
    }

    /// `single_op` should produce a valid input -> call -> output chain.
    #[test]
    fn test_single_op_graph() {
        let graph = FxGraph::single_op("relu", vec!["input".to_string()]);
        // One input, one call, one output node; two connecting edges.
        assert_eq!(graph.node_count(), 3);
        assert_eq!(graph.edge_count(), 2);
        assert_eq!(graph.inputs().len(), 1);
        assert_eq!(graph.outputs().len(), 1);
        assert!(graph.validate().is_ok());
        let input_nodes = graph.input_nodes();
        let call_nodes = graph.call_nodes();
        let output_nodes = graph.output_nodes();
        assert_eq!(input_nodes.len(), 1);
        assert_eq!(call_nodes.len(), 1);
        assert_eq!(output_nodes.len(), 1);
        if let Node::Call(op_name, _) = &call_nodes[0].1 {
            assert_eq!(op_name, "relu");
        } else {
            panic!("Expected Call node");
        }
    }

    /// `sequential_ops` should chain all ops between one input and one output.
    #[test]
    fn test_sequential_ops_graph() {
        let ops = vec!["relu", "sigmoid", "tanh"];
        let graph = FxGraph::sequential_ops(&ops);
        // input + 3 calls + output nodes; 4 connecting edges.
        assert_eq!(graph.node_count(), 5);
        assert_eq!(graph.edge_count(), 4);
        assert_eq!(graph.inputs().len(), 1);
        assert_eq!(graph.outputs().len(), 1);
        assert!(graph.validate().is_ok());
        let call_nodes = graph.call_nodes();
        assert_eq!(call_nodes.len(), 3);
        let mut op_names = Vec::new();
        for (_, node) in call_nodes {
            if let Node::Call(op_name, _) = node {
                op_names.push(op_name.clone());
            }
        }
        assert!(op_names.contains(&"relu".to_string()));
        assert!(op_names.contains(&"sigmoid".to_string()));
        assert!(op_names.contains(&"tanh".to_string()));
    }

    /// An empty op list must yield an empty graph rather than a bare
    /// input/output pair.
    #[test]
    fn test_empty_sequential_ops() {
        let graph = FxGraph::sequential_ops(&[]);
        assert_eq!(graph.node_count(), 0);
        assert_eq!(graph.edge_count(), 0);
        assert_eq!(graph.inputs().len(), 0);
        assert_eq!(graph.outputs().len(), 0);
    }

    /// Smoke test across the convenience constructors, summary, and
    /// JSON round-trip.
    #[test]
    fn test_modular_architecture() {
        let graph = FxGraph::single_op("test_op", vec!["input".to_string()]);
        assert!(graph.node_count() > 0);
        assert!(graph.edge_count() > 0);
        assert!(graph.validate().is_ok());
        let summary = graph.summary();
        assert!(summary.contains("FX Graph Summary"));
        let debug_graph = FxGraph::debug_minimal();
        assert!(debug_graph.validate().is_ok());
        let json = graph.to_json().unwrap();
        assert!(!json.is_empty());
        let deserialized = FxGraph::from_json(&json).unwrap();
        assert_eq!(deserialized.node_count(), graph.node_count());
    }

    /// Validation must accept a well-formed graph and reject graphs with
    /// only an output or only an input.
    #[test]
    fn test_graph_validation() {
        let graph = FxGraph::single_op("relu", vec!["input".to_string()]);
        assert!(graph.validate().is_ok());
        // Output with no input is invalid.
        let mut invalid_graph = FxGraph::new();
        let output = invalid_graph.add_node(Node::Output);
        invalid_graph.add_output(output);
        assert!(invalid_graph.validate().is_err());
        // Input with no output is invalid.
        let mut invalid_graph2 = FxGraph::new();
        let input = invalid_graph2.add_node(Node::Input("x".to_string()));
        invalid_graph2.add_input(input);
        assert!(invalid_graph2.validate().is_err());
    }

    /// A multi-op graph should yield at least one recommendation.
    #[test]
    fn test_performance_recommendations() {
        let graph = FxGraph::sequential_ops(&["relu", "sigmoid", "tanh"]);
        let recommendations = graph.performance_recommendations();
        assert!(!recommendations.is_empty());
    }

    /// Operation name queries must deduplicate names but count occurrences.
    #[test]
    fn test_operation_analysis() {
        let graph = FxGraph::sequential_ops(&["relu", "sigmoid", "relu"]);
        let op_names = graph.get_operation_names();
        // "relu" appears twice but is reported once.
        assert_eq!(op_names.len(), 2);
        assert!(op_names.contains(&"relu".to_string()));
        assert!(op_names.contains(&"sigmoid".to_string()));
        assert!(graph.contains_operation("relu"));
        assert!(graph.contains_operation("sigmoid"));
        assert!(!graph.contains_operation("tanh"));
        let counts = graph.operation_counts();
        assert_eq!(counts.get("relu"), Some(&2));
        assert_eq!(counts.get("sigmoid"), Some(&1));
        assert_eq!(counts.get("tanh"), None);
    }
}
pub mod prelude {
pub use crate::fx::*;
pub use crate::{
benchmarking::*, checkpointing::*, codegen::*, custom_backends::*, distributed::*,
graph_analysis::*, tracer::*, Edge, FxGraph, GraphStats, MemoryEstimate, Node,
SerializableGraph,
};
}