use crate::error::RusTorchError;
use pyo3::exceptions::*;
use pyo3::prelude::*;
use std::sync::{Arc, RwLock};
/// Converts a [`RusTorchError`] into the Python exception that is raised across the binding.
///
/// * `ShapeMismatch` becomes a `ValueError` listing the expected and actual shapes.
/// * `Device` becomes a `RuntimeError` naming the device and its message.
/// * `TensorOp` becomes a `RuntimeError` carrying only its `message` field.
/// * Every other variant becomes a `RuntimeError` with the error's `Display` text.
pub fn to_py_err(error: RusTorchError) -> PyErr {
    match error {
        RusTorchError::ShapeMismatch { expected, actual } => {
            let text = format!("Shape mismatch: expected {:?}, got {:?}", expected, actual);
            PyValueError::new_err(text)
        }
        RusTorchError::Device { device, message } => {
            let text = format!("Device error on {}: {}", device, message);
            PyRuntimeError::new_err(text)
        }
        RusTorchError::TensorOp { message, .. } => PyRuntimeError::new_err(message),
        other => {
            let text = other.to_string();
            PyRuntimeError::new_err(text)
        }
    }
}
/// Conversion contract between a Python-facing wrapper type and the Rust value `T` it wraps.
pub trait PyWrapper<T> {
    /// Builds the wrapper from an owned Rust value.
    fn from_rust(value: T) -> Self;
    /// Borrows the wrapped Rust value.
    // NOTE(review): returns a borrow despite the `to_` prefix, which by Rust naming convention
    // usually denotes an owned conversion; `as_rust` would match the convention.
    fn to_rust(&self) -> &T;
    /// Consumes the wrapper and returns the wrapped Rust value.
    fn into_rust(self) -> T;
}
/// Access to a wrapper's shared, lock-protected inner value held as `Arc<RwLock<T>>`.
///
/// Types that store the handle in a single field named `inner` can get an implementation
/// from the `impl_thread_safe_wrapper!` macro.
pub trait ThreadSafePyWrapper<T> {
    /// Builds the wrapper around an existing shared handle.
    fn from_arc_rwlock(value: Arc<RwLock<T>>) -> Self;
    /// Borrows the shared handle stored in the wrapper.
    fn as_arc_rwlock(&self) -> &Arc<RwLock<T>>;
    /// Returns a new owned handle to the shared lock.
    // NOTE(review): presumably another reference to the same allocation (as the macro's
    // `Arc::clone` implementation does), not a deep copy — confirm for hand-written impls.
    fn clone_arc_rwlock(&self) -> Arc<RwLock<T>>;
}
/// Argument validation helpers that report failures as Python `ValueError`s.
pub mod validation {
    use super::*;

    /// Largest number of elements a tensor shape may describe (one billion).
    pub const MAX_TENSOR_ELEMENTS: usize = 1_000_000_000;

    /// Validates a tensor shape.
    ///
    /// # Errors
    ///
    /// Returns a `ValueError` if `dims` is empty, contains a zero extent, or describes more than
    /// [`MAX_TENSOR_ELEMENTS`] elements (including shapes whose element count overflows `usize`).
    pub fn validate_dimensions(dims: &[usize]) -> PyResult<()> {
        if dims.is_empty() {
            return Err(PyValueError::new_err("Tensor dimensions cannot be empty"));
        }
        if dims.contains(&0) {
            return Err(PyValueError::new_err(
                "Tensor dimensions cannot contain zero",
            ));
        }
        // A plain `product()` panics on overflow in debug builds and silently wraps in release
        // builds, which would let huge shapes (e.g. [2^32, 2^32] -> 0) slip under the limit.
        // `checked_mul` yields `None` on overflow, which is treated as "too large".
        let total_elements = dims.iter().try_fold(1usize, |acc, &d| acc.checked_mul(d));
        match total_elements {
            Some(total) if total <= MAX_TENSOR_ELEMENTS => Ok(()),
            _ => Err(PyValueError::new_err("Tensor too large (>1B elements)")),
        }
    }

    /// Validates that a learning rate lies in `(0, 1]`.
    ///
    /// # Errors
    ///
    /// Returns a `ValueError` for values outside `(0, 1]`, including NaN.
    pub fn validate_learning_rate(lr: f64) -> PyResult<()> {
        // NaN compares false against every bound, so it must be rejected explicitly.
        if lr.is_nan() || lr <= 0.0 || lr > 1.0 {
            return Err(PyValueError::new_err("Learning rate must be in (0, 1]"));
        }
        Ok(())
    }

    /// Validates that a momentum/decay coefficient lies in `[0, 1)`.
    ///
    /// `name` is used in the error message (e.g. `"beta1"`).
    ///
    /// # Errors
    ///
    /// Returns a `ValueError` for values outside `[0, 1)`; NaN is rejected because the range
    /// check is false for it.
    pub fn validate_beta(beta: f64, name: &str) -> PyResult<()> {
        if !(0.0..1.0).contains(&beta) {
            return Err(PyValueError::new_err(format!("{} must be in [0, 1)", name)));
        }
        Ok(())
    }

    /// Validates that a numerical-stability epsilon is strictly positive.
    ///
    /// # Errors
    ///
    /// Returns a `ValueError` for zero, negative values, and NaN.
    pub fn validate_epsilon(eps: f64) -> PyResult<()> {
        // NaN compares false against every bound, so it must be rejected explicitly.
        if eps.is_nan() || eps <= 0.0 {
            return Err(PyValueError::new_err("Epsilon must be positive"));
        }
        Ok(())
    }
}
pub mod conversions {
use super::*;
use numpy::{IntoPyArray, PyArray1, PyReadonlyArray1, ToPyArray};
pub fn vec_to_pyarray<'py>(vec: Vec<f32>, py: Python<'py>) -> Bound<'py, PyArray1<f32>> {
vec.into_pyarray(py)
}
pub fn pyarray_to_vec(array: PyReadonlyArray1<f32>) -> Vec<f32> {
array.as_array().to_vec()
}
pub fn pylist_to_vec_usize(list: &Bound<'_, pyo3::types::PyList>) -> PyResult<Vec<usize>> {
let mut result = Vec::with_capacity(list.len());
for (i, item) in list.iter().enumerate() {
let value: usize = item
.extract()
.map_err(|_| PyTypeError::new_err(format!("Item {} is not a valid integer", i)))?;
result.push(value);
}
Ok(result)
}
pub fn pylist_to_vec_f32(list: &Bound<'_, pyo3::types::PyList>) -> PyResult<Vec<f32>> {
let mut result = Vec::with_capacity(list.len());
for (i, item) in list.iter().enumerate() {
let value: f32 = item
.extract()
.map_err(|_| PyTypeError::new_err(format!("Item {} is not a valid float", i)))?;
result.push(value);
}
Ok(result)
}
pub fn pylist_to_shape(list: &Bound<'_, pyo3::types::PyList>) -> PyResult<Vec<usize>> {
let shape = pylist_to_vec_usize(list)?;
crate::python::common::validation::validate_dimensions(&shape)?;
Ok(shape)
}
}
/// Non-blocking helpers for running closures against `Arc<RwLock<T>>` contents.
pub mod memory {
    use super::*;

    /// Runs `f` with shared access to the value behind `arc_lock` and returns its result.
    ///
    /// The lock is attempted without blocking (`try_read`). If it cannot be taken immediately
    /// (for example because a writer holds it, or it is poisoned), `f` is not called and a
    /// `RuntimeError` is returned instead.
    pub fn safe_read<T, F, R>(arc_lock: &Arc<RwLock<T>>, f: F) -> PyResult<R>
    where
        F: FnOnce(&T) -> R,
    {
        let reader = arc_lock
            .try_read()
            .map_err(|_| PyRuntimeError::new_err("Failed to acquire read lock"))?;
        // The guard stays alive until after `f` returns.
        Ok(f(&*reader))
    }

    /// Runs `f` with exclusive access to the value behind `arc_lock` and returns its result.
    ///
    /// The lock is attempted without blocking (`try_write`). If it cannot be taken immediately
    /// (for example because any reader or writer holds it, or it is poisoned), `f` is not
    /// called and a `RuntimeError` is returned instead.
    pub fn safe_write<T, F, R>(arc_lock: &Arc<RwLock<T>>, f: F) -> PyResult<R>
    where
        F: FnOnce(&mut T) -> R,
    {
        let mut writer = arc_lock
            .try_write()
            .map_err(|_| PyRuntimeError::new_err("Failed to acquire write lock"))?;
        // The guard stays alive until after `f` returns.
        Ok(f(&mut *writer))
    }
}
/// Generates a `#[pymethods]` block giving `$type` Python `__repr__`, `__copy__`, and
/// `__deepcopy__` methods.
///
/// * `__repr__` returns the stringified `$type` tokens followed by `(...)`; no field values
///   are shown.
/// * `__copy__` and `__deepcopy__` both return `self.clone()`, so `$type` must implement
///   `Clone`. The deep-copy `memo` dict is ignored, so the copy is only as deep as `Clone` makes it.
///
/// `$rust_type` is accepted but not used by the expansion. `pymethods`, `Bound`, and the
/// `pyo3` path are written unqualified and are resolved where the macro is invoked.
///
/// NOTE(review): if `$type` already has another `#[pymethods]` block, this presumably requires
/// PyO3's `multiple-pymethods` feature — confirm against the crate's Cargo features.
#[macro_export]
macro_rules! impl_py_common_methods {
    ($type:ty, $rust_type:ty) => {
        #[pymethods]
        impl $type {
            // Python `repr()`: type name only, contents elided.
            fn __repr__(&self) -> String {
                format!("{}(...)", stringify!($type))
            }
            // `copy.copy()` support via `Clone`.
            fn __copy__(&self) -> Self {
                self.clone()
            }
            // `copy.deepcopy()` support; `_memo` is unused, so this is the same as `__copy__`.
            fn __deepcopy__(&self, _memo: &Bound<'_, pyo3::types::PyDict>) -> Self {
                self.clone()
            }
        }
    };
}
/// Implements `ThreadSafePyWrapper<$rust_type>` for `$type`.
///
/// `$type` must be a struct whose only field is `inner: Arc<RwLock<$rust_type>>`, because
/// `from_arc_rwlock` constructs it as `Self { inner: value }`. The names `ThreadSafePyWrapper`,
/// `Arc`, and `RwLock` are written unqualified, so they must be in scope where the macro is
/// invoked.
#[macro_export]
macro_rules! impl_thread_safe_wrapper {
    ($type:ty, $rust_type:ty) => {
        impl ThreadSafePyWrapper<$rust_type> for $type {
            // Stores the given handle directly; no new lock is created.
            fn from_arc_rwlock(value: Arc<RwLock<$rust_type>>) -> Self {
                Self { inner: value }
            }
            // Borrows the stored handle.
            fn as_arc_rwlock(&self) -> &Arc<RwLock<$rust_type>> {
                &self.inner
            }
            // `Arc::clone` returns another handle to the same lock and value.
            fn clone_arc_rwlock(&self) -> Arc<RwLock<$rust_type>> {
                Arc::clone(&self.inner)
            }
        }
    };
}