1mod backward;
27pub mod checkpoint;
28mod context;
29#[cfg(feature = "cuda")]
30pub mod cuda_backward;
31#[cfg(feature = "cuda")]
32pub mod cuda_forward;
33#[cfg(feature = "cuda")]
34pub mod cuda_optim;
35pub mod cuda_tensor;
36pub mod cuda_training;
37pub mod graph_opt;
38pub(crate) mod ops;
39pub mod precision;
40mod tensor;
41#[cfg(feature = "gpu")]
42pub mod wgpu_backward;
43#[cfg(feature = "gpu")]
44pub mod wgpu_block;
45#[cfg(feature = "gpu")]
46pub mod wgpu_cross_entropy;
47#[cfg(feature = "gpu")]
48pub mod wgpu_training;
49
50#[cfg(test)]
51mod tests;
52
53pub use backward::BackwardOp;
54pub use checkpoint::{
55 checkpoint, checkpoint_if, estimate_memory_savings, estimate_policy_tradeoff,
56 optimal_checkpoints, BinomialCheckpointing, CheckpointConfig, CheckpointManager,
57 CheckpointPolicy, CheckpointedSegment, CustomPolicy, MemoryBudget, OperationInfo,
58 PolicyCheckpointManager, SaveAll, SaveMatmuls, SaveNothing, SaveUnbatchedMatmuls,
59};
60pub use context::Context;
61pub use cuda_training::{cuda_training_available, CudaTrainer};
62pub use graph_opt::{
63 traced_binary_op, CommonSubexprElimination, ComputeGraph, ConstantFolding, DeadCodeElimination,
64 GraphOptimizer, NodeId, OpType, OptimizationPass, OptimizationReport, ShapeError, ShapeTracker,
65 TracedTensor, TracedValue,
66};
67pub use ops::*;
68pub use precision::{
69 bf16_to_f32, bf16_truncate, f32_to_bf16, f32_to_fp16, fp16_to_f32, gemm_bf16_reference,
70 GradScaler, MixedPrecisionConfig, Precision,
71};
72pub use tensor::Tensor;
73
74pub fn backward(tensor: &mut Tensor, grad_output: Option<ndarray::Array1<f32>>) {
76 if let Some(grad) = grad_output {
77 tensor.set_grad(grad);
78 } else {
79 let ones = ndarray::Array1::ones(tensor.data().len());
81 tensor.set_grad(ones);
82 }
83
84 if let Some(op) = tensor.backward_op() {
85 op.backward();
86 }
87}