use crate::autograd::Variable;
use crate::python::common::{conversions, memory, to_py_err, validation};
use crate::python::tensor::PyTensor;
use numpy::PyArray1;
use pyo3::exceptions::*;
use pyo3::prelude::*;
/// Python-facing handle to an `f32` autograd `Variable`, exported through PyO3.
#[pyclass]
pub struct PyVariable {
    /// The wrapped autograd variable (crate-visible so other binding code can reach it).
    pub(crate) variable: Variable<f32>,
}
#[pymethods]
impl PyVariable {
#[new]
pub fn new(tensor: &PyTensor, requires_grad: Option<bool>) -> PyResult<Self> {
let requires_grad = requires_grad.unwrap_or(false);
let variable = Variable::new(tensor.tensor.clone(), requires_grad);
Ok(PyVariable { variable })
}
#[staticmethod]
pub fn from_data(
data: Vec<f32>,
shape: Vec<usize>,
requires_grad: Option<bool>,
) -> PyResult<Self> {
use crate::python::common::validation::validate_dimensions;
validate_dimensions(&shape)?;
let tensor = PyTensor::new(data, shape)?;
Self::new(&tensor, requires_grad)
}
#[staticmethod]
pub fn zeros(shape: Vec<usize>, requires_grad: Option<bool>) -> PyResult<Self> {
let tensor = PyTensor::zeros(shape)?;
Self::new(&tensor, requires_grad)
}
#[staticmethod]
pub fn ones(shape: Vec<usize>, requires_grad: Option<bool>) -> PyResult<Self> {
let tensor = PyTensor::ones(shape)?;
Self::new(&tensor, requires_grad)
}
#[staticmethod]
pub fn randn(shape: Vec<usize>, requires_grad: Option<bool>) -> PyResult<Self> {
let tensor = PyTensor::randn(shape)?;
Self::new(&tensor, requires_grad)
}
pub fn data(&self) -> PyResult<PyTensor> {
use crate::python::common::memory::safe_read;
safe_read(
&self.variable.data(),
|tensor: &crate::tensor::Tensor<f32>| PyTensor {
tensor: tensor.clone(),
},
)
}
pub fn grad(&self) -> Option<PyTensor> {
use crate::python::common::memory::safe_read;
match safe_read(
&self.variable.grad(),
|grad_opt: &Option<crate::tensor::Tensor<f32>>| grad_opt.clone(),
) {
Ok(Some(grad)) => Some(PyTensor { tensor: grad }),
_ => None,
}
}
pub fn requires_grad(&self) -> bool {
self.variable.requires_grad()
}
pub fn requires_grad_(&mut self, _requires_grad: bool) -> PyResult<()> {
Err(PyRuntimeError::new_err(
"Cannot change requires_grad after Variable creation",
))
}
pub fn shape(&self) -> Vec<usize> {
use crate::python::common::memory::safe_read;
safe_read(
&self.variable.data(),
|tensor: &crate::tensor::Tensor<f32>| tensor.shape().to_vec(),
)
.unwrap_or_default()
}
pub fn numel(&self) -> usize {
use crate::python::common::memory::safe_read;
safe_read(
&self.variable.data(),
|tensor: &crate::tensor::Tensor<f32>| tensor.shape().iter().product(),
)
.unwrap_or(0)
}
pub fn numpy<'py>(&self, py: Python<'py>) -> Bound<'py, PyArray1<f32>> {
use crate::python::common::{conversions::vec_to_pyarray, memory::safe_read};
let data = safe_read(
&self.variable.data(),
|tensor: &crate::tensor::Tensor<f32>| tensor.data.as_slice().unwrap_or(&[]).to_vec(),
)
.unwrap_or_default();
vec_to_pyarray(data, py)
}
pub fn backward(
&mut self,
gradient: Option<&PyTensor>,
_retain_graph: Option<bool>,
) -> PyResult<()> {
use crate::python::common::to_py_err;
match gradient {
Some(grad) => {
self.variable.backward_with_grad(Some(grad.tensor.clone()));
Ok(())
}
None => {
self.variable.backward();
Ok(())
}
}
}
pub fn zero_grad(&mut self) -> PyResult<()> {
self.variable.zero_grad();
Ok(())
}
pub fn detach(&self) -> PyResult<PyVariable> {
use crate::python::common::memory::safe_read;
let tensor_data = safe_read(
&self.variable.data(),
|tensor: &crate::tensor::Tensor<f32>| tensor.clone(),
)?;
let detached_var = Variable::new(tensor_data, false);
Ok(PyVariable {
variable: detached_var,
})
}
pub fn clone(&self) -> PyResult<PyVariable> {
Ok(PyVariable {
variable: self.variable.clone(),
})
}
pub fn reshape(&self, shape: Vec<usize>) -> PyResult<PyVariable> {
use crate::python::common::{
memory::safe_read, to_py_err, validation::validate_dimensions,
};
validate_dimensions(&shape)?;
let current_elements = self.numel();
let new_elements: usize = shape.iter().product();
if current_elements != new_elements {
return Err(PyValueError::new_err(format!(
"Cannot reshape Variable with {} elements to shape with {} elements",
current_elements, new_elements
)));
}
let result_tensor = safe_read(
&self.variable.data(),
|tensor: &crate::tensor::Tensor<f32>| tensor.reshape(&shape),
)?
.map_err(to_py_err)?;
let result_var = Variable::new(result_tensor, self.variable.requires_grad());
Ok(PyVariable {
variable: result_var,
})
}
pub fn transpose(&self) -> PyResult<PyVariable> {
use crate::python::common::{memory::safe_read, to_py_err};
let result_tensor = safe_read(
&self.variable.data(),
|tensor: &crate::tensor::Tensor<f32>| tensor.transpose(),
)?
.map_err(to_py_err)?;
let result_var = Variable::new(result_tensor, self.variable.requires_grad());
Ok(PyVariable {
variable: result_var,
})
}
pub fn pow(&self, exponent: f32) -> PyResult<PyVariable> {
use crate::python::common::memory::safe_read;
let result_tensor = safe_read(
&self.variable.data(),
|tensor_data: &crate::tensor::Tensor<f32>| {
let result_data = tensor_data.data.mapv(|x| x.powf(exponent));
crate::tensor::Tensor::from_ndarray(result_data)
},
)?;
let result_var = Variable::new(result_tensor, self.variable.requires_grad());
Ok(PyVariable {
variable: result_var,
})
}
pub fn exp(&self) -> PyResult<PyVariable> {
use crate::python::common::memory::safe_read;
let result_tensor = safe_read(
&self.variable.data(),
|tensor_data: &crate::tensor::Tensor<f32>| {
let result_data = tensor_data.data.mapv(|x| x.exp());
crate::tensor::Tensor::from_ndarray(result_data)
},
)?;
let result_var = Variable::new(result_tensor, self.variable.requires_grad());
Ok(PyVariable {
variable: result_var,
})
}
pub fn log(&self) -> PyResult<PyVariable> {
use crate::python::common::memory::safe_read;
let result_tensor = safe_read(
&self.variable.data(),
|tensor_data: &crate::tensor::Tensor<f32>| {
let result_data = tensor_data
.data
.mapv(|x| if x <= 0.0 { f32::NAN } else { x.ln() });
crate::tensor::Tensor::from_ndarray(result_data)
},
)?;
let result_var = Variable::new(result_tensor, self.variable.requires_grad());
Ok(PyVariable {
variable: result_var,
})
}
pub fn sin(&self) -> PyResult<PyVariable> {
use crate::python::common::memory::safe_read;
let result_tensor = safe_read(
&self.variable.data(),
|tensor_data: &crate::tensor::Tensor<f32>| {
let result_data = tensor_data.data.mapv(|x| x.sin());
crate::tensor::Tensor::from_ndarray(result_data)
},
)?;
let result_var = Variable::new(result_tensor, self.variable.requires_grad());
Ok(PyVariable {
variable: result_var,
})
}
pub fn cos(&self) -> PyResult<PyVariable> {
use crate::python::common::memory::safe_read;
let result_tensor = safe_read(
&self.variable.data(),
|tensor_data: &crate::tensor::Tensor<f32>| {
let result_data = tensor_data.data.mapv(|x| x.cos());
crate::tensor::Tensor::from_ndarray(result_data)
},
)?;
let result_var = Variable::new(result_tensor, self.variable.requires_grad());
Ok(PyVariable {
variable: result_var,
})
}
pub fn sqrt(&self) -> PyResult<PyVariable> {
use crate::python::common::memory::safe_read;
let result_tensor = safe_read(
&self.variable.data(),
|tensor_data: &crate::tensor::Tensor<f32>| {
let result_data = tensor_data
.data
.mapv(|x| if x < 0.0 { f32::NAN } else { x.sqrt() });
crate::tensor::Tensor::from_ndarray(result_data)
},
)?;
let result_var = Variable::new(result_tensor, self.variable.requires_grad());
Ok(PyVariable {
variable: result_var,
})
}
pub fn __add__(&self, _other: &PyVariable) -> PyResult<PyVariable> {
Err(PyNotImplementedError::new_err(
"Variable arithmetic operations require full autograd implementation",
))
}
pub fn __sub__(&self, _other: &PyVariable) -> PyResult<PyVariable> {
Err(PyNotImplementedError::new_err(
"Variable arithmetic operations require full autograd implementation",
))
}
pub fn __mul__(&self, _other: &PyVariable) -> PyResult<PyVariable> {
Err(PyNotImplementedError::new_err(
"Variable arithmetic operations require full autograd implementation",
))
}
pub fn __matmul__(&self, _other: &PyVariable) -> PyResult<PyVariable> {
Err(PyNotImplementedError::new_err(
"Variable matrix operations require full autograd implementation",
))
}
pub fn dot(&self, other: &PyVariable) -> PyResult<PyVariable> {
self.__matmul__(other)
}
pub fn sum(&self) -> PyResult<PyVariable> {
Err(PyNotImplementedError::new_err(
"Variable reduction operations require full autograd implementation",
))
}
pub fn mean(&self) -> PyResult<PyVariable> {
Err(PyNotImplementedError::new_err(
"Variable reduction operations require full autograd implementation",
))
}
pub fn __repr__(&self) -> String {
let grad_str = if self.requires_grad() {
"requires_grad=True"
} else {
"requires_grad=False"
};
let shape = self.shape();
let grad_fn_str = if self.variable.grad_fn().is_some() {
"<BackwardFunction>"
} else {
"None"
};
format!(
"PyVariable(shape={:?}, {}, grad_fn={})",
shape, grad_str, grad_fn_str
)
}
pub fn __str__(&self) -> String {
self.__repr__()
}
pub fn __copy__(&self) -> PyResult<Self> {
self.clone()
}
pub fn __deepcopy__(&self, _memo: &Bound<'_, pyo3::types::PyDict>) -> PyResult<Self> {
self.clone()
}
}
/// Reports whether gradient tracking is enabled.
///
/// There is no global grad-mode state yet, so this always returns `true`.
#[pyfunction]
pub fn is_grad_enabled() -> bool {
    true
}
/// Placeholder for disabling gradient tracking.
///
/// Currently a no-op that always succeeds. It does not change any state.
/// NOTE(review): unlike PyTorch's `torch.no_grad`, this is a plain function,
/// not a context manager.
#[pyfunction]
pub fn no_grad() -> PyResult<()> {
    Ok(())
}
/// Placeholder for re-enabling gradient tracking.
///
/// Currently a no-op that always succeeds. It does not change any state.
#[pyfunction]
pub fn enable_grad() -> PyResult<()> {
    Ok(())
}
/// Placeholder for setting whether gradient tracking is enabled.
///
/// Currently a no-op that always succeeds. `mode` is accepted for API
/// compatibility but is not stored, so `is_grad_enabled` is unaffected.
#[pyfunction]
pub fn set_grad_enabled(mode: bool) -> PyResult<()> {
    // Discard explicitly to silence `unused_variables`. Renaming to `_mode` would
    // change the Python-visible keyword argument name, because PyO3 derives it
    // from the Rust parameter name.
    let _ = mode;
    Ok(())
}