use crate::autograd::function::Function;
use crate::autograd::graph::{ComputationGraph, GraphNode};
use crate::error::{RusTorchError, RusTorchResult};
use crate::tensor::Tensor;
use num_traits::Float;
use std::collections::{HashMap, VecDeque};
use std::sync::{Arc, RwLock, Weak};
use std::time::Instant;
/// An operation that a [`DynamicNode`] applies to the outputs of its input nodes.
///
/// Unit variants need no configuration; struct-like variants carry the
/// parameters the operation was declared with.
#[derive(Debug, Clone, PartialEq)]
pub enum DynamicOp {
    /// Matrix product of exactly two inputs.
    MatMul,
    /// Sum of exactly two inputs.
    Add,
    /// Product of exactly two inputs, via the tensor `*` operator.
    Mul,
    /// Element-wise `max(x, 0)` of a single input.
    ReLU,
    /// Element-wise logistic function `1 / (1 + exp(-x))` of a single input.
    Sigmoid,
    /// 2-D convolution of an input with a weight tensor.
    /// Each pair is presumably `(height, width)` — TODO confirm.
    Conv2d { kernel_size: (usize, usize), stride: (usize, usize), padding: (usize, usize) },
    /// Affine map `input x weight^T (+ bias)` between the declared feature sizes.
    Linear { in_features: usize, out_features: usize },
    /// Batch normalisation; `num_features` is presumably the channel count.
    BatchNorm { num_features: usize },
    /// Dropout with parameter `p` (presumably the drop probability).
    Dropout { p: f64 },
    /// Reinterpretation of a single input with a new shape.
    Reshape { shape: Vec<usize> },
    /// A user-named operation; the execution context tags leaf nodes as `Custom("leaf")`.
    Custom(String),
}
/// A single node in a dynamically built computation graph.
///
/// A node records the operation it performs and shared handles to the nodes
/// feeding it. Its cached result and profiling data sit behind `RwLock`s so
/// they can be updated through a shared `Arc`.
pub struct DynamicNode<T: Float + Send + Sync + 'static> {
    /// Identifier assigned by the owning context.
    pub id: usize,
    /// Operation this node applies to its inputs.
    pub op: DynamicOp,
    /// Upstream nodes whose cached outputs are this node's operands, in order.
    pub inputs: Vec<Arc<DynamicNode<T>>>,
    /// Most recently stored output (computed, or supplied for leaves).
    pub cached_output: RwLock<Option<Tensor<T>>>,
    /// `true` when `cached_output` must be recomputed before use.
    pub dirty: RwLock<bool>,
    /// Wall-clock duration of the last evaluation, if any.
    pub execution_time: RwLock<Option<std::time::Duration>>,
    /// Byte size of the last evaluated output's elements, if any.
    pub memory_usage: RwLock<Option<usize>>,
}

impl<T: Float + Send + Sync + 'static> DynamicNode<T> {
    /// Builds a node and wraps it in an `Arc`.
    ///
    /// The node starts dirty, with no cached output and no profiling data.
    pub fn new(op: DynamicOp, inputs: Vec<Arc<DynamicNode<T>>>, id: usize) -> Arc<Self> {
        let node = Self {
            id,
            op,
            inputs,
            cached_output: RwLock::new(None),
            dirty: RwLock::new(true),
            execution_time: RwLock::new(None),
            memory_usage: RwLock::new(None),
        };
        Arc::new(node)
    }

    /// Overwrites the dirty flag; the write lock is released before returning.
    fn store_dirty_flag(&self, value: bool) {
        let mut flag = self.dirty.write().unwrap();
        *flag = value;
    }

    /// Flags the node for recomputation and throws away its cached output.
    pub fn mark_dirty(&self) {
        self.store_dirty_flag(true);
        let _ = self.cached_output.write().unwrap().take();
    }

    /// Reports whether the node has been flagged for recomputation.
    pub fn is_dirty(&self) -> bool {
        let flag = self.dirty.read().unwrap();
        *flag
    }

    /// Returns a clone of the cached output, or `None` when nothing is cached.
    pub fn get_cached_output(&self) -> Option<Tensor<T>> {
        self.cached_output.read().unwrap().as_ref().cloned()
    }

    /// Stores `output` as the node's result and clears the dirty flag.
    pub fn set_cached_output(&self, output: Tensor<T>) {
        let _ = self.cached_output.write().unwrap().replace(output);
        self.store_dirty_flag(false);
    }
}
/// Owns a dynamically built graph of [`DynamicNode`]s and evaluates it on demand.
///
/// Nodes are registered with `add_leaf` / `add_operation` and evaluated lazily
/// by `execute`, which reuses cached node outputs while they are still valid.
pub struct DynamicExecutionContext<T: Float + Send + Sync + 'static> {
    // NOTE(review): `graph` and `compiled_ops` are initialised but never read
    // in this module.
    graph: Arc<RwLock<ComputationGraph<T>>>,
    /// Every registered node, keyed by its id.
    dynamic_nodes: HashMap<usize, Arc<DynamicNode<T>>>,
    /// Topological order computed by the most recent `execute` call.
    execution_order: RwLock<Option<Vec<usize>>>,
    compiled_ops: HashMap<Vec<DynamicOp>, Arc<dyn Function<T>>>,
    /// Id assigned to the next registered node.
    next_node_id: usize,
    /// Node visits served from cache. `stats.cache_hit_rate` is derived from
    /// this together with `stats.total_ops` (the misses).
    cache_hits: usize,
    stats: DynamicExecutionStats,
}

/// Counters accumulated by a [`DynamicExecutionContext`] across `execute` calls.
#[derive(Debug, Default, Clone)]
pub struct DynamicExecutionStats {
    /// Node evaluations actually performed (cache misses).
    pub total_ops: usize,
    /// Fraction of visited nodes whose cached output was reused, in `[0, 1]`.
    pub cache_hit_rate: f64,
    /// Wall-clock time accumulated by successful `execute` calls.
    pub total_execution_time: std::time::Duration,
    /// NOTE(review): not updated anywhere in this module.
    pub memory_allocations: usize,
    /// NOTE(review): not updated anywhere in this module.
    pub jit_compilations: usize,
}

impl<T: Float + Send + Sync + 'static + ndarray::ScalarOperand + num_traits::FromPrimitive>
    DynamicExecutionContext<T>
{
    /// Creates an empty context with no nodes and zeroed statistics.
    pub fn new() -> Self {
        DynamicExecutionContext {
            graph: Arc::new(RwLock::new(ComputationGraph::new())),
            dynamic_nodes: HashMap::new(),
            execution_order: RwLock::new(None),
            compiled_ops: HashMap::new(),
            next_node_id: 0,
            cache_hits: 0,
            stats: DynamicExecutionStats::default(),
        }
    }

    /// Hands out the next unused node id.
    fn allocate_node_id(&mut self) -> usize {
        let id = self.next_node_id;
        self.next_node_id += 1;
        id
    }

    /// True for nodes created by `add_leaf`; their cached output is their data.
    fn is_leaf(node: &DynamicNode<T>) -> bool {
        node.inputs.is_empty() && matches!(&node.op, DynamicOp::Custom(name) if name == "leaf")
    }

    /// Registers a node applying `op` to the nodes identified by `input_ids`
    /// (in operand order) and returns the new node's id.
    ///
    /// # Errors
    /// Fails if any id in `input_ids` is not a registered node. In that case
    /// the context, including its id counter, is left unchanged.
    pub fn add_operation(&mut self, op: DynamicOp, input_ids: Vec<usize>) -> RusTorchResult<usize> {
        // Resolve all inputs before allocating an id so a failed call burns no id.
        let mut input_nodes = Vec::with_capacity(input_ids.len());
        for id in &input_ids {
            match self.dynamic_nodes.get(id) {
                Some(node) => input_nodes.push(Arc::clone(node)),
                None => return Err(RusTorchError::tensor_op("Some input nodes not found")),
            }
        }
        let node_id = self.allocate_node_id();
        let dynamic_node = DynamicNode::new(op, input_nodes, node_id);
        self.dynamic_nodes.insert(node_id, dynamic_node);
        // A cached order may not cover the new node.
        *self.execution_order.write().unwrap() = None;
        Ok(node_id)
    }

    /// Registers a leaf node whose (clean) cached output is `tensor`.
    pub fn add_leaf(&mut self, tensor: Tensor<T>) -> RusTorchResult<usize> {
        let node_id = self.allocate_node_id();
        let dynamic_node = DynamicNode::new(DynamicOp::Custom("leaf".to_string()), vec![], node_id);
        dynamic_node.set_cached_output(tensor);
        self.dynamic_nodes.insert(node_id, dynamic_node);
        Ok(node_id)
    }

    /// Looks up a registered node by id.
    pub fn get_dynamic_node(&self, id: &usize) -> Option<&Arc<DynamicNode<T>>> {
        self.dynamic_nodes.get(id)
    }

    /// Evaluates `output_node_id` and every node it depends on, and returns its output.
    ///
    /// Nodes are visited in topological order. A node is recomputed in three
    /// cases: it is dirty, it has no cached output, or one of its inputs was
    /// recomputed during this call. The last case keeps downstream caches from
    /// serving stale values. Every other visit reuses the cache and counts as a
    /// cache hit; leaves count as hits too.
    ///
    /// Replacing a leaf's tensor directly via `DynamicNode::set_cached_output`
    /// is not detected. Call `clear_cache` afterwards so dependents are recomputed.
    ///
    /// # Errors
    /// Fails if the output node is unknown, if a cycle is found, or if
    /// evaluating any node fails.
    pub fn execute(&mut self, output_node_id: usize) -> RusTorchResult<Tensor<T>> {
        let start_time = Instant::now();
        if !self.dynamic_nodes.contains_key(&output_node_id) {
            return Err(RusTorchError::tensor_op("Output node not found"));
        }
        self.build_execution_order(output_node_id)?;
        let execution_order = self
            .execution_order
            .read()
            .unwrap()
            .clone()
            .ok_or_else(|| RusTorchError::tensor_op("Failed to build execution order"))?;

        // Ids of nodes recomputed during this call; their consumers must follow.
        let mut recomputed: std::collections::HashSet<usize> = std::collections::HashSet::new();
        for node_id in execution_order {
            let node = match self.dynamic_nodes.get(&node_id) {
                Some(node) => Arc::clone(node),
                None => continue,
            };
            let input_changed = node.inputs.iter().any(|input| recomputed.contains(&input.id));
            let has_cache = node.cached_output.read().unwrap().is_some();
            if input_changed || node.is_dirty() || !has_cache {
                let output = self.execute_node(&node)?;
                node.set_cached_output(output);
                recomputed.insert(node_id);
                self.stats.total_ops += 1;
            } else {
                self.cache_hits += 1;
            }
            self.refresh_cache_hit_rate();
        }
        self.stats.total_execution_time += start_time.elapsed();

        self.dynamic_nodes
            .get(&output_node_id)
            .and_then(|node| node.get_cached_output())
            .ok_or_else(|| RusTorchError::tensor_op("Output node has no result"))
    }

    /// Recomputes `stats.cache_hit_rate` as hits / (hits + misses).
    fn refresh_cache_hit_rate(&mut self) {
        let lookups = self.cache_hits + self.stats.total_ops;
        self.stats.cache_hit_rate = if lookups == 0 {
            0.0
        } else {
            self.cache_hits as f64 / lookups as f64
        };
    }

    /// Evaluates a single node from its inputs' cached outputs.
    ///
    /// On success this also records the node's execution time and the byte
    /// size of its output's elements. It does not store the output in the cache.
    ///
    /// # Errors
    /// Fails if an input has no cached output, if the number of inputs is
    /// wrong for the operation, if the underlying tensor operation fails, or if
    /// the operation is not implemented.
    pub fn execute_node(&self, node: &DynamicNode<T>) -> RusTorchResult<Tensor<T>> {
        let start_time = Instant::now();
        let mut input_tensors = Vec::with_capacity(node.inputs.len());
        for input_node in &node.inputs {
            if let Some(tensor) = input_node.get_cached_output() {
                input_tensors.push(tensor);
            } else {
                return Err(RusTorchError::tensor_op(format!(
                    "Input node {} has no cached output",
                    input_node.id
                )));
            }
        }
        let output = match &node.op {
            DynamicOp::Add => {
                if input_tensors.len() != 2 {
                    return Err(RusTorchError::tensor_op("Add requires 2 inputs"));
                }
                &input_tensors[0] + &input_tensors[1]
            }
            DynamicOp::Mul => {
                if input_tensors.len() != 2 {
                    return Err(RusTorchError::tensor_op("Mul requires 2 inputs"));
                }
                &input_tensors[0] * &input_tensors[1]
            }
            DynamicOp::MatMul => {
                if input_tensors.len() != 2 {
                    return Err(RusTorchError::tensor_op("MatMul requires 2 inputs"));
                }
                input_tensors[0].matmul(&input_tensors[1])?
            }
            DynamicOp::ReLU => {
                if input_tensors.len() != 1 {
                    return Err(RusTorchError::tensor_op("ReLU requires 1 input"));
                }
                let input_data = &input_tensors[0].data;
                let relu_data: Vec<T> = input_data
                    .iter()
                    .map(|&x| if x > T::zero() { x } else { T::zero() })
                    .collect();
                Tensor::from_vec(relu_data, input_tensors[0].shape().to_vec())
            }
            DynamicOp::Sigmoid => {
                if input_tensors.len() != 1 {
                    return Err(RusTorchError::tensor_op("Sigmoid requires 1 input"));
                }
                let input_data = &input_tensors[0].data;
                let sigmoid_data: Vec<T> = input_data
                    .iter()
                    .map(|&x| T::one() / (T::one() + (-x).exp()))
                    .collect();
                Tensor::from_vec(sigmoid_data, input_tensors[0].shape().to_vec())
            }
            DynamicOp::Reshape { shape } => {
                if input_tensors.len() != 1 {
                    return Err(RusTorchError::tensor_op("Reshape requires 1 input"));
                }
                input_tensors[0].reshape(shape)?
            }
            DynamicOp::Linear {
                in_features: _,
                out_features: _,
            } => {
                if input_tensors.len() < 2 || input_tensors.len() > 3 {
                    return Err(RusTorchError::tensor_op(
                        "Linear requires 2-3 inputs (input, weight, [bias])",
                    ));
                }
                self.execute_linear(&input_tensors)?
            }
            DynamicOp::Conv2d {
                kernel_size,
                stride,
                padding,
            } => {
                if input_tensors.len() != 2 {
                    return Err(RusTorchError::tensor_op(
                        "Conv2d requires 2 inputs (input, weight)",
                    ));
                }
                self.execute_conv2d(&input_tensors, *kernel_size, *stride, *padding)?
            }
            DynamicOp::Custom(name) if name == "leaf" => {
                // A leaf's cached output is its data; there is nothing to recompute it from.
                return Err(RusTorchError::tensor_op(format!(
                    "Leaf node {} has no data; leaf outputs cannot be recomputed",
                    node.id
                )));
            }
            _ => {
                return Err(RusTorchError::tensor_op(format!(
                    "Operation {:?} not implemented yet",
                    node.op
                )));
            }
        };
        let execution_time = start_time.elapsed();
        *node.execution_time.write().unwrap() = Some(execution_time);
        let memory_usage = output.data.len() * std::mem::size_of::<T>();
        *node.memory_usage.write().unwrap() = Some(memory_usage);
        Ok(output)
    }

    /// Computes `input x weight^T`, plus `bias` when a third input is given.
    fn execute_linear(&self, inputs: &[Tensor<T>]) -> RusTorchResult<Tensor<T>> {
        let input = &inputs[0];
        let weight = &inputs[1];
        let bias = inputs.get(2);
        let mut output = input.matmul(&weight.transpose()?)?;
        if let Some(bias_tensor) = bias {
            output = &output + bias_tensor;
        }
        Ok(output)
    }

    /// Direct 2-D convolution (cross-correlation) with zero padding.
    ///
    /// `inputs[0]` must be 4-D `[batch, in_channels, height, width]`. `inputs[1]`
    /// must be 4-D `[out_channels, in_channels, kernel_h, kernel_w]`, and its
    /// spatial size must equal `kernel_size`. The result is
    /// `[batch, out_channels, out_h, out_w]`, where each spatial axis has
    /// `out = (in + 2 * padding - kernel) / stride + 1`.
    ///
    /// # Errors
    /// Fails on wrong ranks, a channel or kernel mismatch, a zero stride, or a
    /// kernel larger than the padded input.
    fn execute_conv2d(
        &self,
        inputs: &[Tensor<T>],
        kernel_size: (usize, usize),
        stride: (usize, usize),
        padding: (usize, usize),
    ) -> RusTorchResult<Tensor<T>> {
        let input = &inputs[0];
        let weight = &inputs[1];
        let input_shape = input.shape();
        let weight_shape = weight.shape();
        if input_shape.len() != 4 || weight_shape.len() != 4 {
            return Err(RusTorchError::tensor_op(format!(
                "Conv2d expects 4-D input and weight, got {:?} and {:?}",
                input_shape, weight_shape
            )));
        }
        let (batch_size, in_channels, in_h, in_w) =
            (input_shape[0], input_shape[1], input_shape[2], input_shape[3]);
        let (out_channels, weight_in_channels, kernel_h, kernel_w) =
            (weight_shape[0], weight_shape[1], weight_shape[2], weight_shape[3]);
        if weight_in_channels != in_channels {
            return Err(RusTorchError::tensor_op(format!(
                "Conv2d weight expects {} input channels but input has {}",
                weight_in_channels, in_channels
            )));
        }
        if (kernel_h, kernel_w) != kernel_size {
            return Err(RusTorchError::tensor_op(format!(
                "Conv2d weight kernel {:?} does not match declared kernel_size {:?}",
                (kernel_h, kernel_w),
                kernel_size
            )));
        }
        if stride.0 == 0 || stride.1 == 0 {
            return Err(RusTorchError::tensor_op("Conv2d stride must be non-zero"));
        }
        let padded_h = in_h + 2 * padding.0;
        let padded_w = in_w + 2 * padding.1;
        if kernel_h == 0 || kernel_w == 0 || kernel_h > padded_h || kernel_w > padded_w {
            return Err(RusTorchError::tensor_op(format!(
                "Conv2d kernel {:?} does not fit padded input {}x{}",
                kernel_size, padded_h, padded_w
            )));
        }
        let out_h = (padded_h - kernel_h) / stride.0 + 1;
        let out_w = (padded_w - kernel_w) / stride.1 + 1;

        // Flatten to plain buffers. This assumes `data.iter()` yields elements in
        // row-major order, which the element-wise arms of `execute_node` also assume.
        let x: Vec<T> = input.data.iter().copied().collect();
        let w: Vec<T> = weight.data.iter().copied().collect();

        let mut output = Vec::with_capacity(batch_size * out_channels * out_h * out_w);
        for n in 0..batch_size {
            for oc in 0..out_channels {
                for oy in 0..out_h {
                    for ox in 0..out_w {
                        let mut acc = T::zero();
                        for ic in 0..in_channels {
                            let x_plane = (n * in_channels + ic) * in_h * in_w;
                            let w_plane = (oc * in_channels + ic) * kernel_h * kernel_w;
                            for ky in 0..kernel_h {
                                // Row in padded coordinates; rows in the zero padding add nothing.
                                let py = oy * stride.0 + ky;
                                if py < padding.0 || py - padding.0 >= in_h {
                                    continue;
                                }
                                let iy = py - padding.0;
                                for kx in 0..kernel_w {
                                    let px = ox * stride.1 + kx;
                                    if px < padding.1 || px - padding.1 >= in_w {
                                        continue;
                                    }
                                    let ix = px - padding.1;
                                    acc = acc
                                        + x[x_plane + iy * in_w + ix]
                                            * w[w_plane + ky * kernel_w + kx];
                                }
                            }
                        }
                        output.push(acc);
                    }
                }
            }
        }
        Ok(Tensor::from_vec(
            output,
            vec![batch_size, out_channels, out_h, out_w],
        ))
    }

    /// Computes and stores the topological order ending at `output_node_id`.
    fn build_execution_order(&mut self, output_node_id: usize) -> RusTorchResult<()> {
        let mut visited = std::collections::HashSet::new();
        let mut temp_visited = std::collections::HashSet::new();
        let mut order = Vec::new();
        self.topological_sort(output_node_id, &mut visited, &mut temp_visited, &mut order)?;
        *self.execution_order.write().unwrap() = Some(order);
        Ok(())
    }

    /// Depth-first post-order traversal that appends inputs before their consumers.
    ///
    /// `temp_visited` holds the current DFS path and is used to detect cycles.
    fn topological_sort(
        &self,
        node_id: usize,
        visited: &mut std::collections::HashSet<usize>,
        temp_visited: &mut std::collections::HashSet<usize>,
        order: &mut Vec<usize>,
    ) -> RusTorchResult<()> {
        if temp_visited.contains(&node_id) {
            return Err(RusTorchError::tensor_op("Circular dependency detected"));
        }
        if visited.contains(&node_id) {
            return Ok(());
        }
        temp_visited.insert(node_id);
        if let Some(node) = self.dynamic_nodes.get(&node_id) {
            for input_node in &node.inputs {
                self.topological_sort(input_node.id, visited, temp_visited, order)?;
            }
        }
        temp_visited.remove(&node_id);
        visited.insert(node_id);
        order.push(node_id);
        Ok(())
    }

    /// Returns the statistics gathered so far.
    pub fn get_stats(&self) -> &DynamicExecutionStats {
        &self.stats
    }

    /// Invalidates every computed result so the next `execute` recomputes it.
    ///
    /// Leaf nodes are left untouched. Their cached output is their data and
    /// cannot be recomputed, so clearing it would make every dependent
    /// `execute` fail.
    pub fn clear_cache(&mut self) {
        for node in self.dynamic_nodes.values() {
            if !Self::is_leaf(node) {
                node.mark_dirty();
            }
        }
        *self.execution_order.write().unwrap() = None;
    }

    /// Builds an [`ExecutionPlan`] for `output_node_id` and its ancestors.
    ///
    /// Inputs are added before their consumers, and then the plan's memory
    /// and grouping passes are run. An unknown id yields an empty plan.
    pub fn create_execution_plan(&self, output_node_id: usize) -> RusTorchResult<ExecutionPlan<T>> {
        let mut plan = ExecutionPlan::new();
        let mut visited = std::collections::HashSet::new();
        self.build_execution_plan_recursive(output_node_id, &mut visited, &mut plan)?;
        plan.optimize_memory_usage();
        plan.optimize_execution_order();
        Ok(plan)
    }

    /// Post-order DFS that adds each reachable node to `plan` exactly once.
    fn build_execution_plan_recursive(
        &self,
        node_id: usize,
        visited: &mut std::collections::HashSet<usize>,
        plan: &mut ExecutionPlan<T>,
    ) -> RusTorchResult<()> {
        if visited.contains(&node_id) {
            return Ok(());
        }
        if let Some(node) = self.dynamic_nodes.get(&node_id) {
            for input_node in &node.inputs {
                self.build_execution_plan_recursive(input_node.id, visited, plan)?;
            }
            plan.add_operation(
                node_id,
                node.op.clone(),
                node.inputs.iter().map(|n| n.id).collect(),
            );
            visited.insert(node_id);
        }
        Ok(())
    }
}
/// A static schedule for evaluating part of a [`DynamicExecutionContext`] graph.
///
/// `T` is not stored; it only ties the plan to the element type of its graph.
#[derive(Clone)]
pub struct ExecutionPlan<T: Float + Send + Sync + 'static> {
    /// Scheduled operations, in the order they were added.
    pub operations: Vec<PlannedOperation>,
    /// Groups of indices into `operations`, filled by `optimize_execution_order`.
    pub parallel_groups: Vec<Vec<usize>>,
    /// Buffer lifetimes and the peak-memory figure, filled by `optimize_memory_usage`.
    pub memory_plan: MemoryPlan,
    _phantom: std::marker::PhantomData<T>,
}
/// One scheduled operation inside an [`ExecutionPlan`].
#[derive(Debug, Clone)]
pub struct PlannedOperation {
/// Id of the graph node this operation evaluates.
pub node_id: usize,
/// Operation to apply.
pub op: DynamicOp,
/// Node ids of the operation's inputs, in operand order.
pub input_ids: Vec<usize>,
/// Expected run time, if known; `ExecutionPlan::add_operation` leaves it `None`.
pub estimated_time: Option<std::time::Duration>,
/// Memory the output needs (presumably bytes — TODO confirm);
/// `ExecutionPlan::add_operation` sets it to 0.
pub memory_requirement: usize,
/// `ExecutionPlan::add_operation` sets this to `false`.
/// NOTE(review): not read anywhere in this module.
pub parallel_safe: bool,
}
/// Memory schedule attached to an [`ExecutionPlan`].
#[derive(Debug, Default, Clone)]
pub struct MemoryPlan {
/// Peak-memory figure reported by `ExecutionPlan::peak_memory_usage`.
pub peak_memory: usize,
/// Per-operation buffer records produced by `ExecutionPlan::optimize_memory_usage`.
pub allocations: Vec<MemoryAllocation>,
/// Presumably maps an allocation to the one whose buffer it reuses.
/// NOTE(review): never populated in this module.
pub reuse_map: HashMap<usize, usize>,
}
/// Lifetime record for one planned operation's output buffer.
#[derive(Debug, Clone)]
pub struct MemoryAllocation {
/// Node id of the operation producing the buffer.
pub operation_id: usize,
/// Copied from the operation's `memory_requirement`.
pub size: usize,
/// Index into the plan's `operations` of the last operation that reads the
/// buffer, or the producer's own index when nothing reads it.
pub lifetime_end: usize,
/// NOTE(review): always `None` in this module.
pub reuse_from: Option<usize>,
}
impl<T: Float + Send + Sync + 'static> ExecutionPlan<T> {
pub fn new() -> Self {
ExecutionPlan {
operations: Vec::new(),
memory_plan: MemoryPlan::default(),
parallel_groups: Vec::new(),
_phantom: std::marker::PhantomData,
}
}
pub fn add_operation(&mut self, node_id: usize, op: DynamicOp, input_ids: Vec<usize>) {
let planned_op = PlannedOperation {
node_id,
op,
input_ids,
estimated_time: None,
memory_requirement: 0,
parallel_safe: false,
};
self.operations.push(planned_op);
}
pub fn optimize_memory_usage(&mut self) {
let mut last_use = HashMap::new();
for (op_idx, op) in self.operations.iter().enumerate() {
for &input_id in &op.input_ids {
last_use.insert(input_id, op_idx);
}
}
for (op_idx, op) in self.operations.iter().enumerate() {
let allocation = MemoryAllocation {
operation_id: op.node_id,
size: op.memory_requirement,
lifetime_end: last_use.get(&op.node_id).copied().unwrap_or(op_idx),
reuse_from: None,
};
self.memory_plan.allocations.push(allocation);
}
}
pub fn optimize_execution_order(&mut self) {
let mut current_group = Vec::new();
for (idx, op) in self.operations.iter().enumerate() {
let has_dependency = current_group.iter().any(|&group_idx: &usize| {
op.input_ids.contains(&self.operations[group_idx].node_id)
});
if has_dependency {
if !current_group.is_empty() {
self.parallel_groups.push(current_group.clone());
current_group.clear();
}
}
current_group.push(idx);
}
if !current_group.is_empty() {
self.parallel_groups.push(current_group);
}
}
pub fn estimated_execution_time(&self) -> std::time::Duration {
let mut total_time = std::time::Duration::default();
for group in &self.parallel_groups {
let group_time = group
.iter()
.filter_map(|&idx| self.operations[idx].estimated_time)
.max()
.unwrap_or_default();
total_time += group_time;
}
total_time
}
pub fn peak_memory_usage(&self) -> usize {
self.memory_plan.peak_memory
}
}
/// Builds [`FusedOperation`]s from sequences of [`DynamicOp`]s and memoises them.
pub struct JitCompiler<T: Float + Send + Sync + 'static> {
    /// Counters describing compiler activity.
    compilation_stats: JitStats,
    /// Compiled functions keyed by the `Debug` rendering of their op sequence.
    compiled_cache: HashMap<String, Arc<dyn Function<T>>>,
}

/// Counters reported by [`JitCompiler::get_stats`].
#[derive(Debug, Default)]
pub struct JitStats {
    /// Cache misses that produced a new fused operation.
    pub compilations: usize,
    /// Requests answered from the cache.
    pub cache_hits: usize,
    /// Total wall-clock time spent building fused operations.
    pub compilation_time: std::time::Duration,
    /// NOTE(review): never updated by `JitCompiler`.
    pub average_speedup: f64,
}

impl<T: Float + Send + Sync + 'static + ndarray::ScalarOperand + num_traits::FromPrimitive>
    JitCompiler<T>
{
    /// Creates a compiler with an empty cache and zeroed statistics.
    pub fn new() -> Self {
        Self {
            compilation_stats: JitStats::default(),
            compiled_cache: HashMap::new(),
        }
    }

    /// Returns a [`Function`] fusing `ops`, building it only on the first request.
    ///
    /// Repeated requests for an equal sequence return the same shared `Arc`.
    ///
    /// # Errors
    /// Propagates any error from building the fused operation.
    pub fn compile_operations(
        &mut self,
        ops: &[DynamicOp],
    ) -> RusTorchResult<Arc<dyn Function<T>>> {
        // `DynamicOp` is not `Hash` (it holds an `f64`), so its Debug text is the key.
        let cache_key = format!("{:?}", ops);
        if let Some(hit) = self.compiled_cache.get(&cache_key).cloned() {
            self.compilation_stats.cache_hits += 1;
            return Ok(hit);
        }

        let started = Instant::now();
        let fused_op = self.create_fused_operation(ops)?;
        let stats = &mut self.compilation_stats;
        stats.compilations += 1;
        stats.compilation_time += started.elapsed();

        let compiled: Arc<dyn Function<T>> = Arc::new(fused_op);
        self.compiled_cache.insert(cache_key, Arc::clone(&compiled));
        Ok(compiled)
    }

    /// Wraps an owned copy of `ops` in a [`FusedOperation`]; this step never fails today.
    fn create_fused_operation(&self, ops: &[DynamicOp]) -> RusTorchResult<FusedOperation<T>> {
        let owned_ops: Vec<DynamicOp> = ops.iter().cloned().collect();
        Ok(FusedOperation::new(owned_ops))
    }

    /// Returns the compiler's counters.
    pub fn get_stats(&self) -> &JitStats {
        let stats = &self.compilation_stats;
        stats
    }
}
/// A sequence of [`DynamicOp`]s packaged as one [`Function`] by [`JitCompiler`].
///
/// NOTE(review): the `Function` implementation for this type does not read
/// `operations`; it currently behaves as a pass-through placeholder.
pub struct FusedOperation<T: Float + Send + Sync + 'static> {
    _phantom: std::marker::PhantomData<T>,
    /// Operations in the order they were supplied.
    operations: Vec<DynamicOp>,
}

impl<T: Float + Send + Sync + 'static> FusedOperation<T> {
    /// Stores `operations` as given, without validation.
    pub fn new(operations: Vec<DynamicOp>) -> Self {
        Self {
            _phantom: std::marker::PhantomData,
            operations,
        }
    }
}
impl<T: Float + Send + Sync + 'static + ndarray::ScalarOperand + num_traits::FromPrimitive>
    Function<T> for FusedOperation<T>
{
    /// Placeholder forward pass. It returns a clone of the first input, or a
    /// one-element zero tensor when there are no inputs.
    ///
    /// NOTE(review): `self.operations` is not applied here.
    fn forward(&self, inputs: &[&Tensor<T>]) -> Tensor<T> {
        match inputs.first() {
            Some(first) => (*first).clone(),
            None => Tensor::zeros(&[1]),
        }
    }

    /// Placeholder backward pass that hands `grad_output` unchanged to every input.
    fn backward(&self, grad_output: &Tensor<T>, inputs: &[&Tensor<T>]) -> Vec<Option<Tensor<T>>> {
        inputs.iter().map(|_| Some(grad_output.clone())).collect()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_dynamic_execution_context_creation() {
        let mut ctx = DynamicExecutionContext::<f32>::new();
        let input1 = Tensor::zeros(&[2, 3]);
        let input2 = Tensor::ones(&[2, 3]);
        let leaf1_id = ctx.add_leaf(input1).unwrap();
        let leaf2_id = ctx.add_leaf(input2).unwrap();
        let add_id = ctx
            .add_operation(DynamicOp::Add, vec![leaf1_id, leaf2_id])
            .unwrap();
        let result = ctx.execute(add_id).unwrap();
        assert_eq!(result.shape(), &[2, 3]);
    }

    #[test]
    fn test_add_operation_rejects_unknown_inputs() {
        let mut ctx = DynamicExecutionContext::<f32>::new();
        assert!(ctx.add_operation(DynamicOp::ReLU, vec![42]).is_err());
    }

    #[test]
    fn test_repeated_execute_reuses_cache() {
        let mut ctx = DynamicExecutionContext::<f32>::new();
        let a = ctx.add_leaf(Tensor::ones(&[2, 2])).unwrap();
        let b = ctx.add_leaf(Tensor::ones(&[2, 2])).unwrap();
        let sum = ctx.add_operation(DynamicOp::Add, vec![a, b]).unwrap();
        ctx.execute(sum).unwrap();
        ctx.execute(sum).unwrap();
        // Only the first call has to evaluate the `Add` node.
        assert_eq!(ctx.get_stats().total_ops, 1);
        let rate = ctx.get_stats().cache_hit_rate;
        assert!((0.0..=1.0).contains(&rate));
    }

    #[test]
    fn test_execution_plan() {
        let mut plan = ExecutionPlan::<f32>::new();
        plan.add_operation(0, DynamicOp::Add, vec![]);
        plan.add_operation(1, DynamicOp::ReLU, vec![0]);
        plan.optimize_execution_order();
        // ReLU reads Add's output, so the two cannot share a group.
        assert_eq!(plan.parallel_groups, vec![vec![0], vec![1]]);
    }

    #[test]
    fn test_jit_compiler() {
        let mut compiler = JitCompiler::<f32>::new();
        let ops = vec![DynamicOp::Add, DynamicOp::ReLU];
        let compiled = compiler.compile_operations(&ops).unwrap();
        let compiled_again = compiler.compile_operations(&ops).unwrap();
        // The second request is served from the cache and shares the allocation.
        assert!(Arc::ptr_eq(&compiled, &compiled_again));
        assert_eq!(compiler.get_stats().compilations, 1);
        assert_eq!(compiler.get_stats().cache_hits, 1);
    }

    #[test]
    fn test_relu_operation() {
        let mut ctx = DynamicExecutionContext::<f32>::new();
        let input_data = vec![-1.0, 0.0, 1.0, 2.0];
        let input = Tensor::from_vec(input_data, vec![4]);
        let leaf_id = ctx.add_leaf(input).unwrap();
        let relu_id = ctx.add_operation(DynamicOp::ReLU, vec![leaf_id]).unwrap();
        let result = ctx.execute(relu_id).unwrap();
        let expected = vec![0.0, 0.0, 1.0, 2.0];
        let slice = result.as_slice().expect("ReLU output should be contiguous");
        assert_eq!(slice.len(), expected.len());
        for (actual, expected) in slice.iter().zip(expected.iter()) {
            assert!((actual - expected).abs() < 1e-6);
        }
    }

    #[test]
    fn test_sigmoid_operation() {
        let mut ctx = DynamicExecutionContext::<f32>::new();
        let input = Tensor::from_vec(vec![0.0], vec![1]);
        let leaf_id = ctx.add_leaf(input).unwrap();
        let sigmoid_id = ctx
            .add_operation(DynamicOp::Sigmoid, vec![leaf_id])
            .unwrap();
        let result = ctx.execute(sigmoid_id).unwrap();
        let slice = result.as_slice().expect("Sigmoid output should be contiguous");
        assert_eq!(slice.len(), 1);
        assert!((slice[0] - 0.5).abs() < 1e-6);
    }

    #[test]
    fn test_linear_operation() {
        let mut ctx = DynamicExecutionContext::<f32>::new();
        let input = Tensor::ones(&[2, 3]);
        let weight = Tensor::ones(&[4, 3]);
        let bias = Tensor::zeros(&[4]);
        let input_id = ctx.add_leaf(input).unwrap();
        let weight_id = ctx.add_leaf(weight).unwrap();
        let bias_id = ctx.add_leaf(bias).unwrap();
        let linear_id = ctx
            .add_operation(
                DynamicOp::Linear {
                    in_features: 3,
                    out_features: 4,
                },
                vec![input_id, weight_id, bias_id],
            )
            .unwrap();
        let result = ctx.execute(linear_id).unwrap();
        assert_eq!(result.shape(), &[2, 4]);
        let slice = result.as_slice().expect("Linear output should be contiguous");
        assert_eq!(slice.len(), 8);
        // Each output is a dot product of three ones with three ones, plus a zero bias.
        for &value in slice {
            assert!((value - 3.0).abs() < 1e-6);
        }
    }
}