//! Reverse-mode automatic differentiation: the tape that records
//! operations, the backward pass that replays it, gradient storage,
//! and the error type shared across the module.

pub mod context;

mod backward;
mod backward_ops;
mod gradient_store;
mod tape;

pub use backward::{backward, backward_with_grad};
pub use backward_ops::*;
pub use gradient_store::*;
pub use tape::*;

use std::fmt;

use crate::tensor::GradId;
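
/// Errors raised by the autograd engine during recording, backward
/// passes, and gradient lookup.
///
/// # Examples
///
/// A minimal sketch of handling a failed backward pass. It assumes
/// `backward` returns `Result<_, AutogradError>`; check the actual
/// signature in `backward.rs` before relying on this shape (`loss` is a
/// hypothetical tensor).
///
/// ```ignore
/// match backward(&loss) {
///     Ok(grads) => { /* read gradients out of `grads` */ }
///     Err(AutogradError::NoGraph) => {
///         eprintln!("loss was never recorded on a tape");
///     }
///     Err(e) => eprintln!("backward failed: {e}"),
/// }
/// ```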
#[derive(Debug, Clone)]
pub enum AutogradError {
VersionMismatch {
grad_id: GradId,
expected: usize,
found: usize,
},
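    /// A computed gradient's shape does not match the shape expected for
    /// the tensor identified by `grad_id`.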
ShapeMismatch {
grad_id: GradId,
expected: Vec<usize>,
found: Vec<usize>,
},
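    /// `backward` was called on a tensor that has no computational graph
    /// attached.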
NoGraph,
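    /// No gradient for `grad_id` was present in the `GradientStore`.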
MissingGrad { grad_id: GradId },
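    /// A `state_dict` operation failed for the entry at `key`; `message`
    /// carries the underlying cause.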
StateError { key: String, message: String },
}
impl fmt::Display for AutogradError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
AutogradError::VersionMismatch {
grad_id,
expected,
found,
} => write!(
f,
"autograd: tensor (GradId({})) was mutated in-place after being \
recorded on the tape (expected version {}, found {})",
grad_id.0, expected, found,
),
AutogradError::ShapeMismatch {
grad_id,
expected,
found,
} => write!(
f,
"autograd: gradient shape mismatch for GradId({}): \
expected {:?}, got {:?}",
grad_id.0, expected, found,
),
AutogradError::NoGraph => write!(
f,
"autograd: backward() called on a tensor with no computational graph",
),
AutogradError::MissingGrad { grad_id } => write!(
f,
"autograd: gradient for GradId({}) not found in GradientStore",
grad_id.0,
),
AutogradError::StateError { key, message } => write!(
f,
"state_dict: key \"{}\": {}",
key, message,
),
}
}
}
impl std::error::Error for AutogradError {}
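
// Sanity tests pinning the user-facing messages formatted above. They only
// exercise variants that need no `GradId`, so they stay independent of how
// `GradId` is constructed; the key/message strings are illustrative.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn no_graph_message() {
        let err = AutogradError::NoGraph;
        assert_eq!(
            err.to_string(),
            "autograd: backward() called on a tensor with no computational graph"
        );
    }

    #[test]
    fn state_error_message() {
        let err = AutogradError::StateError {
            key: "layers.0.weight".to_string(),
            message: "missing entry".to_string(),
        };
        assert_eq!(
            err.to_string(),
            "state_dict: key \"layers.0.weight\": missing entry"
        );
    }
}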