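//! Automatic differentiation engine: the [`AutodiffEngine`] and its
//! [`AutodiffConfig`], a process-wide engine accessor ([`get_engine`]), and
//! RAII/macro helpers for toggling gradient tracking.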
use super::graph::ComputationGraph;
use super::tape::GradientTape;
use super::variable::{GraphRef, Variable};
use crate::errors::{tensor_op_error, Result};
use crate::tensor::Tensor;
use std::collections::HashMap;
use std::sync::{Arc, Mutex, OnceLock};
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum GradientMode {
Forward,
Reverse,
Mixed,
}
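/// Configuration for [`AutodiffEngine`].
///
/// A sketch of overriding a single field while keeping the defaults (this mirrors
/// the pattern used in the tests at the bottom of this module):
///
/// ```ignore
/// let config = AutodiffConfig {
///     detect_anomalies: true,
///     ..Default::default()
/// };
/// let engine = AutodiffEngine::new(config);
/// ```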
#[derive(Debug, Clone)]
pub struct AutodiffConfig {
pub mode: GradientMode,
pub enabled: bool,
pub detect_anomalies: bool,
pub retain_graph: bool,
pub max_cache_size: usize,
pub optimize_graph: bool,
pub gradient_checkpointing: bool,
}
impl Default for AutodiffConfig {
fn default() -> Self {
Self {
mode: GradientMode::Reverse,
enabled: true,
detect_anomalies: false,
retain_graph: false,
max_cache_size: 10000,
optimize_graph: true,
gradient_checkpointing: false,
}
}
}
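/// Core autodiff engine that owns the computation graph, gradient tape,
/// operation cache, and runtime statistics.
///
/// An end-to-end sketch based on this module's tests; the `?` error handling is
/// illustrative and assumes a surrounding `Result` context:
///
/// ```ignore
/// let engine = AutodiffEngine::default();
/// let a = engine.variable(Tensor::scalar(2.0)?, true);
/// let b = engine.variable(Tensor::scalar(3.0)?, true);
/// let c = a.mul(&b)?;
/// engine.backward(&c, None)?;
/// assert_eq!(engine.get_grad(&a)?.unwrap().to_scalar()?, 3.0); // dc/da = b
/// assert_eq!(engine.get_grad(&b)?.unwrap().to_scalar()?, 2.0); // dc/db = a
/// ```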
#[derive(Debug)]
pub struct AutodiffEngine {
config: AutodiffConfig,
graph: GraphRef,
tape: Arc<Mutex<GradientTape>>,
#[allow(dead_code)]
operation_cache: Arc<Mutex<HashMap<String, CompiledOperation>>>,
stats: Arc<Mutex<AutodiffStats>>,
}
#[derive(Debug, Clone)]
pub struct CompiledOperation {
pub id: String,
pub forward_fn: fn(&[&Tensor]) -> Result<Tensor>,
pub backward_fn: fn(&Tensor, &[&Tensor]) -> Result<Vec<Tensor>>,
pub metadata: OperationMetadata,
}
#[derive(Debug, Clone)]
pub struct OperationMetadata {
pub op_type: String,
pub input_shapes: Vec<Vec<usize>>,
pub output_shape: Vec<usize>,
pub num_parameters: usize,
pub estimated_flops: usize,
}
#[derive(Debug, Default, Clone)]
pub struct AutodiffStats {
pub forward_passes: u64,
pub backward_passes: u64,
pub total_operations: u64,
pub cache_hits: u64,
pub cache_misses: u64,
pub forward_time_us: u64,
pub backward_time_us: u64,
pub peak_memory_usage: usize,
pub current_memory_usage: usize,
}
impl Default for AutodiffEngine {
fn default() -> Self {
Self::new(AutodiffConfig::default())
}
}
impl AutodiffEngine {
pub fn new(config: AutodiffConfig) -> Self {
let graph = Arc::new(Mutex::new(ComputationGraph::new()));
let tape = Arc::new(Mutex::new(GradientTape::new()));
let operation_cache = Arc::new(Mutex::new(HashMap::new()));
let stats = Arc::new(Mutex::new(AutodiffStats::default()));
Self {
config,
graph,
tape,
operation_cache,
stats,
}
}
pub fn enable_grad(&mut self) {
self.config.enabled = true;
}
pub fn disable_grad(&mut self) {
self.config.enabled = false;
}
pub fn is_grad_enabled(&self) -> bool {
self.config.enabled
}
pub fn set_mode(&mut self, mode: GradientMode) {
self.config.mode = mode;
}
pub fn mode(&self) -> GradientMode {
self.config.mode
}
pub fn enable_anomaly_detection(&mut self) {
self.config.detect_anomalies = true;
}
pub fn disable_anomaly_detection(&mut self) {
self.config.detect_anomalies = false;
}
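    /// Registers `tensor` as a new node (with no producing operation) in the
    /// computation graph and returns a [`Variable`] handle to it.
    ///
    /// Illustrative usage (shapes follow the tests below):
    ///
    /// ```ignore
    /// let var = engine.variable(Tensor::ones(&[2, 3])?, true);
    /// assert!(var.requires_grad());
    /// ```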
pub fn variable(&self, tensor: Tensor, requires_grad: bool) -> Variable {
Variable::from_graph(
self.graph.clone(),
{
let mut graph = self.graph.lock().expect("lock should not be poisoned");
graph.add_node(tensor, requires_grad, None)
},
requires_grad,
)
}
pub fn variable_with_name(
&self,
tensor: Tensor,
requires_grad: bool,
name: String,
) -> Variable {
Variable::from_graph(
self.graph.clone(),
{
let mut graph = self.graph.lock().expect("lock should not be poisoned");
graph.add_node(tensor, requires_grad, Some(name))
},
requires_grad,
)
}
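    /// Runs a backward pass from `output`, dispatching on the configured
    /// [`GradientMode`]. `grad_output` is forwarded to the graph as the upstream
    /// gradient; passing `None` defers to the graph's default seeding.
    ///
    /// Sketch (see `test_gradient_computation` below):
    ///
    /// ```ignore
    /// let c = a.mul(&b)?;
    /// engine.backward(&c, None)?;
    /// let grad_a = engine.get_grad(&a)?;
    /// ```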
pub fn backward(&self, output: &Variable, grad_output: Option<Tensor>) -> Result<()> {
let start_time = std::time::Instant::now();
match self.config.mode {
GradientMode::Forward => self.forward_mode_backward(output, grad_output),
GradientMode::Reverse => self.reverse_mode_backward(output, grad_output),
GradientMode::Mixed => self.mixed_mode_backward(output, grad_output),
}?;
let mut stats = self.stats.lock().expect("lock should not be poisoned");
stats.backward_passes += 1;
stats.backward_time_us += start_time.elapsed().as_micros() as u64;
Ok(())
}
    fn forward_mode_backward(&self, output: &Variable, grad_output: Option<Tensor>) -> Result<()> {
        // Forward mode currently delegates to the same graph traversal as reverse mode.
        let mut graph = self.graph.lock().expect("lock should not be poisoned");
        graph.backward(output.node_id(), grad_output)
    }
    fn reverse_mode_backward(&self, output: &Variable, grad_output: Option<Tensor>) -> Result<()> {
        let mut graph = self.graph.lock().expect("lock should not be poisoned");
        graph.backward(output.node_id(), grad_output)
    }
    fn mixed_mode_backward(&self, output: &Variable, grad_output: Option<Tensor>) -> Result<()> {
        // Heuristic: small graphs (fewer than 100 nodes) take the forward-mode path,
        // larger graphs take reverse mode. Release the lock before dispatching so the
        // selected path can re-acquire it.
        let num_nodes = {
            let graph = self.graph.lock().expect("lock should not be poisoned");
            graph.num_nodes()
        };
        if num_nodes < 100 {
            self.forward_mode_backward(output, grad_output)
        } else {
            self.reverse_mode_backward(output, grad_output)
        }
    }
pub fn zero_grad(&self) {
let mut graph = self.graph.lock().expect("lock should not be poisoned");
graph.zero_grad();
}
pub fn get_grad(&self, variable: &Variable) -> Result<Option<Tensor>> {
let graph = self.graph.lock().expect("lock should not be poisoned");
Ok(graph.get_gradient(variable.node_id()).cloned())
}
pub fn clear_graph(&self) {
let mut graph = self.graph.lock().expect("lock should not be poisoned");
*graph = ComputationGraph::new();
let mut tape = self.tape.lock().expect("lock should not be poisoned");
tape.clear();
}
pub fn stats(&self) -> AutodiffStats {
let stats = self.stats.lock().expect("lock should not be poisoned");
stats.clone()
}
pub fn reset_stats(&self) {
let mut stats = self.stats.lock().expect("lock should not be poisoned");
*stats = AutodiffStats::default();
}
pub fn graph(&self) -> GraphRef {
self.graph.clone()
}
pub fn optimize_graph(&self) -> Result<()> {
if !self.config.optimize_graph {
return Ok(());
}
let mut graph = self.graph.lock().expect("lock should not be poisoned");
self.eliminate_dead_nodes(&mut graph)?;
self.fuse_operations(&mut graph)?;
self.optimize_memory_layout(&mut graph)?;
Ok(())
}
    fn eliminate_dead_nodes(&self, _graph: &mut ComputationGraph) -> Result<()> {
        // Currently a no-op; placeholder for dead-node elimination.
        Ok(())
    }
    fn fuse_operations(&self, _graph: &mut ComputationGraph) -> Result<()> {
        // Currently a no-op; placeholder for operation fusion.
        Ok(())
    }
    fn optimize_memory_layout(&self, _graph: &mut ComputationGraph) -> Result<()> {
        // Currently a no-op; placeholder for memory-layout optimization.
        Ok(())
    }
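    /// Runs `f` with gradient tracking disabled on this engine, restoring the
    /// previous setting afterwards. Note that the flag is only restored on a
    /// normal return, not if `f` panics.
    ///
    /// ```ignore
    /// let mut engine = AutodiffEngine::default();
    /// let y = engine.no_grad(|| {
    ///     // inference-only work goes here
    ///     42
    /// });
    /// assert_eq!(y, 42);
    /// ```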
pub fn no_grad<F, R>(&mut self, f: F) -> R
where
F: FnOnce() -> R,
{
let was_enabled = self.config.enabled;
self.config.enabled = false;
let result = f();
self.config.enabled = was_enabled;
result
}
pub fn with_grad<F, R>(&mut self, f: F) -> R
where
F: FnOnce() -> R,
{
let was_enabled = self.config.enabled;
self.config.enabled = true;
let result = f();
self.config.enabled = was_enabled;
result
}
pub fn check_anomalies(&self, variable: &Variable) -> Result<()> {
if !self.config.detect_anomalies {
return Ok(());
}
if let Some(grad) = self.get_grad(variable)? {
let grad_values = grad.to_vec_f32()?;
for &value in &grad_values {
if value.is_nan() {
return Err(tensor_op_error(
"AutodiffEngine::check_anomalies",
"NaN detected in gradient",
));
}
if value.is_infinite() {
return Err(tensor_op_error(
"AutodiffEngine::check_anomalies",
"Infinite value detected in gradient",
));
}
}
}
Ok(())
}
pub fn enable_checkpointing(&mut self) {
self.config.gradient_checkpointing = true;
}
pub fn disable_checkpointing(&mut self) {
self.config.gradient_checkpointing = false;
}
pub fn is_checkpointing_enabled(&self) -> bool {
self.config.gradient_checkpointing
}
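    /// Exports the current computation graph in Graphviz DOT format.
    ///
    /// Illustrative usage; rendering with the external Graphviz `dot` tool is optional:
    ///
    /// ```ignore
    /// let dot = engine.export_graph()?;
    /// std::fs::write("graph.dot", dot)?;
    /// // outside Rust: dot -Tpng graph.dot -o graph.png
    /// ```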
pub fn export_graph(&self) -> Result<String> {
let graph = self.graph.lock().expect("lock should not be poisoned");
let graph_export = graph.export_graph();
let mut dot = String::from("digraph G {\n");
dot.push_str(" rankdir=TB;\n");
for node in &graph_export.nodes {
let node_label = if let Some(ref name) = node.name {
name.clone()
} else {
format!("node_{}", node.id)
};
let op_label = if let Some(ref op) = node.operation {
format!("{:?}", op)
} else {
"Variable".to_string()
};
dot.push_str(&format!(
" {} [label=\"{}\\n{}\\n{:?}\"];\n",
node.id, node_label, op_label, node.shape
));
for parent_id in &node.parents {
dot.push_str(&format!(" {} -> {};\n", parent_id, node.id));
}
}
dot.push_str("}\n");
Ok(dot)
}
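    /// Summarizes the memory held by the computation graph: every node's value
    /// tensor plus any stored gradient.
    ///
    /// ```ignore
    /// let info = engine.memory_info()?;
    /// println!("{} bytes across {} tensors", info.total_memory_bytes, info.num_tensors);
    /// ```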
pub fn memory_info(&self) -> Result<MemoryInfo> {
let graph = self.graph.lock().expect("lock should not be poisoned");
let mut total_memory = 0;
let mut num_tensors = 0;
for node in graph.export_graph().nodes {
total_memory += node.value.memory_usage();
num_tensors += 1;
if let Some(ref grad) = node.gradient {
total_memory += grad.memory_usage();
num_tensors += 1;
}
}
Ok(MemoryInfo {
total_memory_bytes: total_memory,
num_tensors,
num_nodes: graph.num_nodes(),
})
}
}
#[derive(Debug, Clone)]
pub struct MemoryInfo {
pub total_memory_bytes: usize,
pub num_tensors: usize,
pub num_nodes: usize,
}
static GLOBAL_ENGINE: OnceLock<Arc<Mutex<AutodiffEngine>>> = OnceLock::new();
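/// Initializes the process-wide engine with `config`. If the global engine has
/// already been created (by an earlier call or by [`get_engine`]), this is a
/// no-op and the existing engine is kept.
///
/// ```ignore
/// init_engine(AutodiffConfig {
///     detect_anomalies: true,
///     ..Default::default()
/// });
/// let engine = get_engine();
/// ```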
pub fn init_engine(config: AutodiffConfig) {
let _ = GLOBAL_ENGINE.set(Arc::new(Mutex::new(AutodiffEngine::new(config))));
}
pub fn get_engine() -> Arc<Mutex<AutodiffEngine>> {
GLOBAL_ENGINE
.get_or_init(|| Arc::new(Mutex::new(AutodiffEngine::new(AutodiffConfig::default()))))
.clone()
}
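/// RAII guard that toggles gradient tracking on the *global* engine returned by
/// [`get_engine`] and restores the previous state when dropped.
///
/// ```ignore
/// {
///     let _ctx = GradContext::disable();
///     // gradient tracking on the global engine is off here
/// }
/// // previous state is restored once `_ctx` is dropped
/// ```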
pub struct GradContext {
previous_state: bool,
}
impl GradContext {
    pub fn enable() -> Self {
        let engine = get_engine();
        // Hold the lock once so reading the previous state and updating it
        // happen atomically with respect to other threads.
        let mut guard = engine.lock().expect("Lock poisoned");
        let previous_state = guard.is_grad_enabled();
        guard.enable_grad();
        Self { previous_state }
    }
    pub fn disable() -> Self {
        let engine = get_engine();
        let mut guard = engine.lock().expect("Lock poisoned");
        let previous_state = guard.is_grad_enabled();
        guard.disable_grad();
        Self { previous_state }
    }
}
impl Drop for GradContext {
fn drop(&mut self) {
let engine = get_engine();
if self.previous_state {
engine.lock().expect("Lock poisoned").enable_grad();
} else {
engine.lock().expect("Lock poisoned").disable_grad();
}
}
}
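/// Expands to a block that disables gradient tracking on the global engine for
/// the enclosed statements via `GradContext::disable`. The statement shown is a
/// placeholder; any statements can appear inside.
///
/// ```ignore
/// no_grad! {
///     let prediction = model.forward(&input);
/// }
/// ```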
#[macro_export]
macro_rules! no_grad {
    ($($body:tt)*) => {{
        // The guard is dropped at the end of this block, restoring the previous state.
        let _ctx = $crate::autodiff::engine::GradContext::disable();
        $($body)*
    }};
}
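/// Counterpart to `no_grad!`: expands to a block that enables gradient tracking
/// on the global engine for the enclosed statements via `GradContext::enable`.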
#[macro_export]
macro_rules! with_grad {
    ($($body:tt)*) => {{
        // The guard is dropped at the end of this block, restoring the previous state.
        let _ctx = $crate::autodiff::engine::GradContext::enable();
        $($body)*
    }};
}
#[cfg(test)]
mod tests {
use super::*;
use crate::tensor::Tensor;
#[test]
fn test_engine_creation() {
let config = AutodiffConfig::default();
let engine = AutodiffEngine::new(config);
assert!(engine.is_grad_enabled());
assert_eq!(engine.mode(), GradientMode::Reverse);
}
#[test]
fn test_variable_creation() {
let engine = AutodiffEngine::default();
let tensor = Tensor::ones(&[2, 3]).expect("Failed to create ones tensor");
let var = engine.variable(tensor, true);
assert!(var.requires_grad());
assert_eq!(var.shape().expect("operation failed in test"), vec![2, 3]);
}
#[test]
fn test_gradient_computation() {
let engine = AutodiffEngine::default();
let a = engine.variable(Tensor::scalar(2.0).expect("tensor operation failed"), true);
let b = engine.variable(Tensor::scalar(3.0).expect("tensor operation failed"), true);
let c = a.mul(&b).expect("Multiplication failed");
engine.backward(&c, None).expect("operation failed in test");
let grad_a = engine
.get_grad(&a)
.expect("operation failed in test")
.expect("operation failed in test");
let grad_b = engine
.get_grad(&b)
.expect("operation failed in test")
.expect("operation failed in test");
assert_eq!(grad_a.to_scalar().expect("operation failed in test"), 3.0);
assert_eq!(grad_b.to_scalar().expect("operation failed in test"), 2.0);
}
#[test]
    fn test_grad_context() {
        // The locally constructed engine is independent of the global engine that
        // GradContext toggles; it only confirms the default-enabled state.
        let engine = AutodiffEngine::default();
        assert!(engine.is_grad_enabled());
        {
            let _ctx = GradContext::disable();
            assert!(!get_engine().lock().expect("Lock poisoned").is_grad_enabled());
        }
        assert!(get_engine().lock().expect("Lock poisoned").is_grad_enabled());
    }
#[test]
fn test_engine_stats() {
let engine = AutodiffEngine::default();
let stats = engine.stats();
assert_eq!(stats.forward_passes, 0);
assert_eq!(stats.backward_passes, 0);
}
#[test]
fn test_memory_info() {
let engine = AutodiffEngine::default();
let tensor = Tensor::ones(&[100, 100]).expect("Failed to create ones tensor");
let _var = engine.variable(tensor, true);
let memory_info = engine.memory_info().expect("operation failed in test");
assert!(memory_info.total_memory_bytes > 0);
assert!(memory_info.num_tensors > 0);
assert!(memory_info.num_nodes > 0);
}
#[test]
fn test_anomaly_detection() {
let config = AutodiffConfig {
detect_anomalies: true,
..Default::default()
};
let engine = AutodiffEngine::new(config);
let var = engine.variable(Tensor::scalar(1.0).expect("tensor operation failed"), true);
let result = engine.check_anomalies(&var);
assert!(result.is_ok());
}
#[test]
fn test_graph_export() {
let engine = AutodiffEngine::default();
let a = engine.variable(Tensor::scalar(2.0).expect("tensor operation failed"), true);
let b = engine.variable(Tensor::scalar(3.0).expect("tensor operation failed"), true);
let _c = a.mul(&b).expect("Multiplication failed");
let dot_graph = engine.export_graph().expect("operation failed in test");
assert!(dot_graph.contains("digraph G"));
assert!(dot_graph.contains("->"));
}
}