pub(crate) mod grad_fn;
mod graph;
mod ops;
mod tensor;
pub use grad_fn::GradFn;
pub use graph::ComputationGraph;
pub use tensor::{Tensor, TensorId};
use std::cell::RefCell;
thread_local! {
    /// Per-thread flag returned by `is_grad_enabled` and temporarily set to
    /// `false` by `no_grad`. It starts out `true` on every thread.
    static GRAD_ENABLED: RefCell<bool> = const { RefCell::new(true) };
}
thread_local! {
    /// Per-thread computation graph, reached through `with_graph` and `clear_graph`.
    /// Each thread gets its own instance, built with `ComputationGraph::new()`
    /// the first time that thread touches it.
    static GRAPH: RefCell<ComputationGraph> = RefCell::new(ComputationGraph::new());
}
pub fn no_grad<F, R>(f: F) -> R
where
F: FnOnce() -> R,
{
GRAD_ENABLED.with(|enabled| {
let prev = *enabled.borrow();
*enabled.borrow_mut() = false;
let result = f();
*enabled.borrow_mut() = prev;
result
})
}
/// Returns this thread's current grad-enabled flag.
///
/// The flag is `true` by default and reads `false` while inside `no_grad`.
#[must_use]
pub fn is_grad_enabled() -> bool {
    GRAD_ENABLED.with(|flag| {
        let current = flag.borrow();
        *current
    })
}
/// Runs `f` with mutable access to this thread's computation graph and
/// returns its result.
///
/// # Panics
///
/// Panics if the graph is already borrowed. This happens if `f` (directly or
/// indirectly) calls `with_graph`, `get_grad`, `clear_grad` or `clear_graph`,
/// because `RefCell::borrow_mut` refuses a second borrow.
pub(crate) fn with_graph<F, R>(f: F) -> R
where
    F: FnOnce(&mut ComputationGraph) -> R,
{
    GRAPH.with(|cell| {
        // The mutable borrow is held for the whole of `f` and released afterwards.
        let mut graph = cell.borrow_mut();
        f(&mut graph)
    })
}
/// Calls `clear()` on this thread's computation graph.
pub fn clear_graph() {
    GRAPH.with(|cell| {
        let mut graph = cell.borrow_mut();
        graph.clear();
    });
}
#[must_use]
pub fn get_grad(id: TensorId) -> Option<Tensor> {
with_graph(|graph| graph.get_grad(id))
}
/// Clears the gradient entry for tensor `id` in this thread's graph.
///
/// Delegates to `ComputationGraph::clear_grad`.
pub fn clear_grad(id: TensorId) {
    with_graph(|graph| {
        graph.clear_grad(id);
    });
}
#[cfg(test)]
#[path = "tests_tensor_contract.rs"]
mod tests_tensor_contract;
#[cfg(test)]
mod tests {
    use super::*;

    /// The flag reads `false` inside `no_grad` and `true` again afterwards.
    #[test]
    fn test_no_grad_context() {
        assert!(is_grad_enabled());
        assert!(!no_grad(is_grad_enabled));
        assert!(is_grad_enabled());
    }

    /// Nested `no_grad` keeps the flag `false` at every level inside the outer
    /// call, and the flag is `true` again once the outer call returns.
    #[test]
    fn test_nested_no_grad() {
        assert!(is_grad_enabled());
        no_grad(|| {
            assert!(!is_grad_enabled());
            assert!(!no_grad(is_grad_enabled));
            assert!(!is_grad_enabled());
        });
        assert!(is_grad_enabled());
    }
}