use crate::{error::PyResult, py_result, tensor::PyTensor};
use pyo3::prelude::*;
use pyo3::types::PyAny;
use pyo3::wrap_pyfunction;
use pyo3::PyRefMut;
use std::cell::RefCell;
use std::sync::{Arc, Mutex};
/// Thread-local bookkeeping for the autograd mode flags exposed to Python.
pub struct AutogradState {
    // True while gradient tracking is on (the default).
    enabled: bool,
    // True while anomaly detection is on (off by default).
    anomaly_detection: bool,
}

impl AutogradState {
    /// Fresh state: gradients enabled, anomaly detection off.
    fn new() -> Self {
        AutogradState {
            enabled: true,
            anomaly_detection: false,
        }
    }

    /// Turn gradient tracking on or off.
    fn set_enabled(&mut self, on: bool) {
        self.enabled = on;
    }

    /// Current gradient-tracking flag.
    fn is_enabled(&self) -> bool {
        self.enabled
    }

    /// Turn anomaly detection on or off.
    fn set_anomaly_detection(&mut self, on: bool) {
        self.anomaly_detection = on;
    }

    /// Current anomaly-detection flag.
    fn is_anomaly_detection_enabled(&self) -> bool {
        self.anomaly_detection
    }
}

thread_local! {
    /// Per-thread autograd state; grad mode is thread-local, matching PyTorch.
    static AUTOGRAD_STATE: RefCell<AutogradState> = RefCell::new(AutogradState::new());
}
/// Context manager that disables gradient tracking on the current thread,
/// mirroring `torch.no_grad()`.
#[pyclass(name = "no_grad")]
pub struct PyNoGrad {
    // Grad-mode flag captured on `__enter__`, restored on `__exit__`.
    prev_state: bool,
}

#[pymethods]
impl PyNoGrad {
    #[new]
    fn new() -> Self {
        // Snapshot the current mode so constructing without `with` is harmless.
        let prev_state = AUTOGRAD_STATE.with(|state| state.borrow().is_enabled());
        Self { prev_state }
    }

    /// Enter the context: remember the current grad mode, then disable it.
    fn __enter__(mut slf: PyRefMut<'_, Self>) -> PyRefMut<'_, Self> {
        AUTOGRAD_STATE.with(|state| {
            // Re-capture here (not only in `new`) so a reused manager restores correctly.
            slf.prev_state = state.borrow().is_enabled();
            state.borrow_mut().set_enabled(false);
        });
        slf
    }

    /// Exit the context: restore the grad mode captured on entry.
    /// Always returns `false` so Python exceptions are never suppressed.
    // Exception arguments are required by the context-manager protocol but unused;
    // their names are kept (not underscore-prefixed) because PyO3 uses Rust
    // parameter names as the Python keyword-argument names.
    #[allow(unused_variables)]
    fn __exit__(
        slf: PyRefMut<'_, Self>,
        exc_type: Option<Py<PyAny>>,
        exc_val: Option<Py<PyAny>>,
        exc_tb: Option<Py<PyAny>>,
    ) -> PyResult<bool> {
        AUTOGRAD_STATE.with(|state| {
            state.borrow_mut().set_enabled(slf.prev_state);
        });
        Ok(false)
    }

    /// `True` when gradient tracking is currently *disabled*
    /// (i.e. a no-grad region is active on this thread).
    #[staticmethod]
    fn is_enabled() -> bool {
        !AUTOGRAD_STATE.with(|state| state.borrow().is_enabled())
    }
}
/// Context manager that enables gradient tracking on the current thread,
/// mirroring `torch.enable_grad()`.
#[pyclass(name = "enable_grad")]
pub struct PyEnableGrad {
    // Grad-mode flag captured on `__enter__`, restored on `__exit__`.
    prev_state: bool,
}

#[pymethods]
impl PyEnableGrad {
    #[new]
    fn new() -> Self {
        // Snapshot the current mode so constructing without `with` is harmless.
        let prev_state = AUTOGRAD_STATE.with(|state| state.borrow().is_enabled());
        Self { prev_state }
    }

    /// Enter the context: remember the current grad mode, then enable it.
    fn __enter__(mut slf: PyRefMut<'_, Self>) -> PyRefMut<'_, Self> {
        AUTOGRAD_STATE.with(|state| {
            // Re-capture here (not only in `new`) so a reused manager restores correctly.
            slf.prev_state = state.borrow().is_enabled();
            state.borrow_mut().set_enabled(true);
        });
        slf
    }

    /// Exit the context: restore the grad mode captured on entry.
    /// Always returns `false` so Python exceptions are never suppressed.
    // Exception arguments are required by the context-manager protocol but unused;
    // names are preserved because PyO3 maps Rust parameter names to Python kwargs.
    #[allow(unused_variables)]
    fn __exit__(
        slf: PyRefMut<'_, Self>,
        exc_type: Option<Py<PyAny>>,
        exc_val: Option<Py<PyAny>>,
        exc_tb: Option<Py<PyAny>>,
    ) -> PyResult<bool> {
        AUTOGRAD_STATE.with(|state| {
            state.borrow_mut().set_enabled(slf.prev_state);
        });
        Ok(false)
    }

    /// `True` when gradient tracking is currently enabled on this thread.
    #[staticmethod]
    fn is_enabled() -> bool {
        AUTOGRAD_STATE.with(|state| state.borrow().is_enabled())
    }
}
/// Context manager that sets gradient tracking to an explicit mode,
/// mirroring `torch.set_grad_enabled(mode)`.
#[pyclass(name = "set_grad_enabled")]
pub struct PySetGradEnabled {
    // The mode to apply while the context is active.
    mode: bool,
    // Grad-mode flag captured on `__enter__`, restored on `__exit__`.
    prev_state: bool,
}

#[pymethods]
impl PySetGradEnabled {
    #[new]
    fn new(mode: bool) -> Self {
        // Snapshot the current mode so constructing without `with` is harmless.
        let prev_state = AUTOGRAD_STATE.with(|state| state.borrow().is_enabled());
        Self { mode, prev_state }
    }

    /// Enter the context: remember the current grad mode, then apply `self.mode`.
    fn __enter__(mut slf: PyRefMut<'_, Self>) -> PyRefMut<'_, Self> {
        AUTOGRAD_STATE.with(|state| {
            // Re-capture here (not only in `new`) so a reused manager restores correctly.
            slf.prev_state = state.borrow().is_enabled();
            state.borrow_mut().set_enabled(slf.mode);
        });
        slf
    }

    /// Exit the context: restore the grad mode captured on entry.
    /// Always returns `false` so Python exceptions are never suppressed.
    // Exception arguments are required by the context-manager protocol but unused;
    // names are preserved because PyO3 maps Rust parameter names to Python kwargs.
    #[allow(unused_variables)]
    fn __exit__(
        slf: PyRefMut<'_, Self>,
        exc_type: Option<Py<PyAny>>,
        exc_val: Option<Py<PyAny>>,
        exc_tb: Option<Py<PyAny>>,
    ) -> PyResult<bool> {
        AUTOGRAD_STATE.with(|state| {
            state.borrow_mut().set_enabled(slf.prev_state);
        });
        Ok(false)
    }
}
/// Context manager that toggles autograd anomaly detection,
/// mirroring `torch.autograd.detect_anomaly()`.
#[pyclass(name = "detect_anomaly")]
pub struct PyDetectAnomaly {
    // The anomaly-detection mode to apply while the context is active.
    mode: bool,
    // Flag captured on `__enter__`, restored on `__exit__`.
    prev_state: bool,
}

#[pymethods]
impl PyDetectAnomaly {
    #[new]
    fn new(mode: bool) -> Self {
        // Snapshot the current mode so constructing without `with` is harmless.
        let prev_state = AUTOGRAD_STATE.with(|state| state.borrow().is_anomaly_detection_enabled());
        Self { mode, prev_state }
    }

    /// Enter the context: remember the current flag, then apply `self.mode`.
    fn __enter__(mut slf: PyRefMut<'_, Self>) -> PyRefMut<'_, Self> {
        AUTOGRAD_STATE.with(|state| {
            // Re-capture here (not only in `new`) so a reused manager restores correctly.
            slf.prev_state = state.borrow().is_anomaly_detection_enabled();
            state.borrow_mut().set_anomaly_detection(slf.mode);
        });
        slf
    }

    /// Exit the context: restore the flag captured on entry.
    /// Always returns `false` so Python exceptions are never suppressed.
    // Exception arguments are required by the context-manager protocol but unused;
    // names are preserved because PyO3 maps Rust parameter names to Python kwargs.
    #[allow(unused_variables)]
    fn __exit__(
        slf: PyRefMut<'_, Self>,
        exc_type: Option<Py<PyAny>>,
        exc_val: Option<Py<PyAny>>,
        exc_tb: Option<Py<PyAny>>,
    ) -> PyResult<bool> {
        AUTOGRAD_STATE.with(|state| {
            state.borrow_mut().set_anomaly_detection(slf.prev_state);
        });
        Ok(false)
    }

    /// `True` when anomaly detection is currently enabled on this thread.
    #[staticmethod]
    fn is_enabled() -> bool {
        AUTOGRAD_STATE.with(|state| state.borrow().is_anomaly_detection_enabled())
    }
}
/// Placeholder for `torch.autograd.Function`; custom forward/backward hooks
/// are not implemented yet.
#[pyclass(name = "Function")]
pub struct PyFunction;

#[pymethods]
impl PyFunction {
    /// Stub `apply`: passes the first input through unchanged.
    ///
    /// Raises `ValueError` when called with no inputs. Takes the first tensor
    /// by value from the owned `Vec` instead of cloning it.
    #[staticmethod]
    fn apply(inputs: Vec<PyTensor>) -> PyResult<PyTensor> {
        inputs.into_iter().next().ok_or_else(|| {
            PyErr::new::<pyo3::exceptions::PyValueError, _>(
                "Function.apply requires at least one input",
            )
        })
    }
}
/// Rust-side implementations backing the module-level autograd functions.
pub struct AutogradUtils;

impl AutogradUtils {
    /// Computes gradients of `outputs` with respect to `inputs`, in the spirit
    /// of `torch.autograd.grad`.
    ///
    /// Only a single output is currently supported. The remaining options
    /// (`grad_outputs`, `retain_graph`, `create_graph`, `only_inputs`,
    /// `allow_unused`) are accepted for API compatibility but ignored for now.
    ///
    /// Returns one `Option<PyTensor>` per input: `Some(grad)` when a gradient
    /// was accumulated, `None` otherwise. Errors with `NotImplementedError`
    /// unless exactly one output is given; propagates backward-pass failures.
    pub fn grad(
        outputs: Vec<PyTensor>,
        inputs: Vec<PyTensor>,
        _grad_outputs: Option<Vec<Option<PyTensor>>>,
        _retain_graph: Option<bool>,
        _create_graph: Option<bool>,
        _only_inputs: Option<bool>,
        _allow_unused: Option<bool>,
    ) -> PyResult<Vec<Option<PyTensor>>> {
        // Multi-output grad requires seeding each backward pass; not wired up yet.
        let output = match outputs.as_slice() {
            [single] => single,
            _ => {
                return Err(PyErr::new::<pyo3::exceptions::PyNotImplementedError, _>(
                    "Multiple outputs not yet supported",
                ))
            }
        };
        py_result!(output.tensor.backward())?;
        // Collect each input's accumulated gradient (None when nothing flowed to it).
        Ok(inputs
            .into_iter()
            .map(|input| input.tensor.grad().map(|g| PyTensor { tensor: g }))
            .collect())
    }
}
use pyo3::types::{PyModule, PyModuleMethods};
/// Registers the autograd context managers and functions on `m`.
///
/// The `Python` token is unused (all additions go through `m`) but kept in the
/// signature for consistency with the other `register_*` module hooks.
pub fn register_autograd_module(_py: Python<'_>, m: &Bound<'_, PyModule>) -> PyResult<()> {
    m.add_class::<PyNoGrad>()?;
    m.add_class::<PyEnableGrad>()?;
    m.add_class::<PySetGradEnabled>()?;
    m.add_class::<PyDetectAnomaly>()?;
    m.add_class::<PyFunction>()?;

    /// Computes gradients of `outputs` with respect to `inputs`.
    #[pyfunction]
    fn grad(
        outputs: Vec<PyTensor>,
        inputs: Vec<PyTensor>,
        grad_outputs: Option<Vec<Option<PyTensor>>>,
        retain_graph: Option<bool>,
        create_graph: Option<bool>,
        only_inputs: Option<bool>,
        allow_unused: Option<bool>,
    ) -> PyResult<Vec<Option<PyTensor>>> {
        AutogradUtils::grad(
            outputs,
            inputs,
            grad_outputs,
            retain_graph,
            create_graph,
            only_inputs,
            allow_unused,
        )
    }
    m.add_function(wrap_pyfunction!(grad, m)?)?;

    /// Runs the backward pass for each tensor in `tensors`.
    // The extra options are accepted for torch API parity but not implemented;
    // parameter names are kept (not underscore-prefixed) because PyO3 uses Rust
    // parameter names as the Python keyword-argument names.
    #[allow(unused_variables)]
    #[pyfunction]
    fn backward(
        tensors: Vec<PyTensor>,
        grad_tensors: Option<Vec<Option<PyTensor>>>,
        retain_graph: Option<bool>,
        create_graph: Option<bool>,
        inputs: Option<Vec<PyTensor>>,
    ) -> PyResult<()> {
        for tensor in tensors {
            py_result!(tensor.tensor.backward())?;
        }
        Ok(())
    }
    m.add_function(wrap_pyfunction!(backward, m)?)?;

    /// Returns `True` when gradient tracking is enabled on this thread.
    #[pyfunction]
    fn is_grad_enabled() -> bool {
        AUTOGRAD_STATE.with(|state| state.borrow().is_enabled())
    }

    /// Sets the thread-local gradient-tracking mode.
    #[pyfunction]
    fn set_grad_enabled(mode: bool) {
        AUTOGRAD_STATE.with(|state| {
            state.borrow_mut().set_enabled(mode);
        });
    }

    /// Creates a `detect_anomaly` context manager; `mode` defaults to `True`.
    #[pyfunction]
    fn detect_anomaly(mode: Option<bool>) -> PyResult<PyDetectAnomaly> {
        let mode = mode.unwrap_or(true);
        Ok(PyDetectAnomaly::new(mode))
    }

    /// Returns `True` when anomaly detection is enabled on this thread.
    #[pyfunction]
    fn is_anomaly_detection_enabled() -> bool {
        AUTOGRAD_STATE.with(|state| state.borrow().is_anomaly_detection_enabled())
    }

    /// Sets the thread-local anomaly-detection mode.
    #[pyfunction]
    fn set_anomaly_detection(mode: bool) {
        AUTOGRAD_STATE.with(|state| {
            state.borrow_mut().set_anomaly_detection(mode);
        });
    }

    m.add_function(wrap_pyfunction!(is_grad_enabled, m)?)?;
    m.add_function(wrap_pyfunction!(set_grad_enabled, m)?)?;
    m.add_function(wrap_pyfunction!(detect_anomaly, m)?)?;
    m.add_function(wrap_pyfunction!(is_anomaly_detection_enabled, m)?)?;
    m.add_function(wrap_pyfunction!(set_anomaly_detection, m)?)?;
    Ok(())
}