Skip to main content

entrenar/autograd/
mod.rs

1//! Tape-based autograd engine
2//!
3//! Provides automatic differentiation using a computational graph with gradient tape.
4//!
5//! ## CUDA Acceleration (SPEC-FT-001 v3.0.0)
6//!
7//! When the `cuda` feature is enabled, use `CudaTensor` for GPU-accelerated training:
8//!
9//! ```ignore
10//! use entrenar::autograd::{CudaDevice, CudaTensor};
11//!
12//! let device = CudaDevice::default_device()?;
13//! let tensor = CudaTensor::from_vec(&device, vec![1.0, 2.0, 3.0], true)?;
14//! ```
15//!
16//! ## Gradient Checkpointing
17//!
18//! For memory-efficient training of large models, use the `checkpoint` module:
19//!
20//! ```ignore
21//! use entrenar::autograd::checkpoint::{checkpoint, CheckpointConfig};
22//!
23//! let output = checkpoint(|x| layer.forward(x), &input);
24//! ```
25
26mod backward;
27pub mod checkpoint;
28mod context;
29#[cfg(feature = "cuda")]
30pub mod cuda_backward;
31#[cfg(feature = "cuda")]
32pub mod cuda_forward;
33#[cfg(feature = "cuda")]
34pub mod cuda_optim;
35pub mod cuda_tensor;
36pub mod cuda_training;
37pub mod graph_opt;
38pub(crate) mod ops;
39pub mod precision;
40mod tensor;
41#[cfg(feature = "gpu")]
42pub mod wgpu_backward;
43#[cfg(feature = "gpu")]
44pub mod wgpu_block;
45#[cfg(feature = "gpu")]
46pub mod wgpu_cross_entropy;
47#[cfg(feature = "gpu")]
48pub mod wgpu_training;
49
50#[cfg(test)]
51mod tests;
52
53pub use backward::BackwardOp;
54pub use checkpoint::{
55    checkpoint, checkpoint_if, estimate_memory_savings, estimate_policy_tradeoff,
56    optimal_checkpoints, BinomialCheckpointing, CheckpointConfig, CheckpointManager,
57    CheckpointPolicy, CheckpointedSegment, CustomPolicy, MemoryBudget, OperationInfo,
58    PolicyCheckpointManager, SaveAll, SaveMatmuls, SaveNothing, SaveUnbatchedMatmuls,
59};
60pub use context::Context;
61pub use cuda_training::{cuda_training_available, CudaTrainer};
62pub use graph_opt::{
63    traced_binary_op, CommonSubexprElimination, ComputeGraph, ConstantFolding, DeadCodeElimination,
64    GraphOptimizer, NodeId, OpType, OptimizationPass, OptimizationReport, ShapeError, ShapeTracker,
65    TracedTensor, TracedValue,
66};
67pub use ops::*;
68pub use precision::{
69    bf16_to_f32, bf16_truncate, f32_to_bf16, f32_to_fp16, fp16_to_f32, gemm_bf16_reference,
70    GradScaler, MixedPrecisionConfig, Precision,
71};
72pub use tensor::Tensor;
73
74/// Perform backward pass on a tensor
75pub fn backward(tensor: &mut Tensor, grad_output: Option<ndarray::Array1<f32>>) {
76    if let Some(grad) = grad_output {
77        tensor.set_grad(grad);
78    } else {
79        // Initialize with ones for scalar loss
80        let ones = ndarray::Array1::ones(tensor.data().len());
81        tensor.set_grad(ones);
82    }
83
84    if let Some(op) = tensor.backward_op() {
85        op.backward();
86    }
87}