use crate::error::{RusTorchError, RusTorchResult};
use crate::tensor::Tensor;
use num_traits::Float;
/// In-place element-wise arithmetic and memory-related helpers for tensors.
///
/// The arithmetic methods mutate `self` directly instead of allocating a result.
///
/// NOTE(review): despite the trait name, `slice_view` and `detach` hand back an owned
/// `Tensor<T>` by value; whether that result shares storage with `self` is up to the
/// implementation.
pub trait ZeroCopyOps<T: Float> {
    /// Adds `other` into `self`, element by element.
    ///
    /// # Errors
    ///
    /// Returns `Err` when `other` cannot be combined with `self`; the `Tensor<T>`
    /// implementation in this module reports a shape mismatch.
    fn inplace_add(&mut self, other: &Tensor<T>) -> RusTorchResult<()>;

    /// Subtracts `other` from `self`, element by element.
    ///
    /// # Errors
    ///
    /// Same conditions as [`ZeroCopyOps::inplace_add`].
    fn inplace_sub(&mut self, other: &Tensor<T>) -> RusTorchResult<()>;

    /// Multiplies `self` by `other`, element by element.
    ///
    /// # Errors
    ///
    /// Same conditions as [`ZeroCopyOps::inplace_add`].
    fn inplace_mul(&mut self, other: &Tensor<T>) -> RusTorchResult<()>;

    /// Multiplies every element of `self` by `scalar`. Infallible.
    fn inplace_mul_scalar(&mut self, scalar: T);

    /// Adds `scalar` to every element of `self`. Infallible.
    fn inplace_add_scalar(&mut self, scalar: T);

    /// Replaces every element `x` of `self` with `f(x)`.
    ///
    /// Returns a `Result` so implementations can report failure.
    fn inplace_apply<F: Fn(T) -> T + Send + Sync>(&mut self, f: F) -> RusTorchResult<()>;

    /// Extracts the region chosen by `ranges`, one half-open range per dimension,
    /// returning it as a new `Tensor<T>` by value.
    ///
    /// # Errors
    ///
    /// Returns `Err` when `ranges` does not fit this tensor's shape.
    fn slice_view(&self, ranges: &[std::ops::Range<usize>]) -> RusTorchResult<Tensor<T>>;

    /// Reports whether `self` and `other` appear to be backed by overlapping memory.
    fn shares_memory_with(&self, other: &Tensor<T>) -> bool;

    /// Produces a new tensor holding the same element values as `self`.
    fn detach(&self) -> Tensor<T>;
}
/// Borrowing iteration over a tensor's elements.
///
/// Both methods use return-position `impl Trait`, so each implementor chooses its own
/// concrete iterator type.
pub trait TensorIterOps<T: Float> {
/// Returns an iterator of shared references to the tensor's elements, valid for the
/// borrow `'a` of `self`.
///
/// The `T: 'a` bound is required so that `&'a T` is a well-formed type.
fn iter<'a>(&'a self) -> impl Iterator<Item = &'a T>
where
T: 'a;
/// Returns an iterator of mutable references to the tensor's elements, valid while
/// `self` is mutably borrowed for `'a`.
fn iter_mut<'a>(&'a mut self) -> impl Iterator<Item = &'a mut T>
where
T: 'a;
}
/// Applies `op` pairwise as `lhs[i] = op(lhs[i], rhs[i])` once the shapes are known
/// to match.
///
/// Shared implementation of `inplace_add`, `inplace_sub` and `inplace_mul`.
///
/// # Errors
///
/// Returns [`RusTorchError::ShapeMismatch`] (`expected` = shape of `lhs`,
/// `actual` = shape of `rhs`) when the shapes differ. `lhs` is left unmodified.
fn zip_inplace<T, F>(lhs: &mut Tensor<T>, rhs: &Tensor<T>, op: F) -> RusTorchResult<()>
where
    T: Float + Clone + 'static + ndarray::ScalarOperand,
    F: Fn(T, T) -> T,
{
    if lhs.shape() != rhs.shape() {
        return Err(RusTorchError::ShapeMismatch {
            expected: lhs.shape().to_vec(),
            actual: rhs.shape().to_vec(),
        });
    }
    // Both iterators walk the arrays in the same logical element order, so elements
    // are paired by position even if the two buffers use different memory layouts.
    for (a, &b) in lhs.data.iter_mut().zip(rhs.data.iter()) {
        *a = op(*a, b);
    }
    Ok(())
}

/// `ZeroCopyOps` for the ndarray-backed `Tensor<T>`.
impl<T: Float + Clone + 'static + ndarray::ScalarOperand> ZeroCopyOps<T> for Tensor<T> {
    /// `self[i] += other[i]` for every element.
    ///
    /// # Errors
    ///
    /// [`RusTorchError::ShapeMismatch`] if the shapes differ (`self` is untouched).
    fn inplace_add(&mut self, other: &Tensor<T>) -> RusTorchResult<()> {
        zip_inplace(self, other, |a, b| a + b)
    }

    /// `self[i] -= other[i]` for every element.
    ///
    /// # Errors
    ///
    /// [`RusTorchError::ShapeMismatch`] if the shapes differ (`self` is untouched).
    fn inplace_sub(&mut self, other: &Tensor<T>) -> RusTorchResult<()> {
        zip_inplace(self, other, |a, b| a - b)
    }

    /// `self[i] *= other[i]` for every element.
    ///
    /// # Errors
    ///
    /// [`RusTorchError::ShapeMismatch`] if the shapes differ (`self` is untouched).
    fn inplace_mul(&mut self, other: &Tensor<T>) -> RusTorchResult<()> {
        zip_inplace(self, other, |a, b| a * b)
    }

    /// Multiplies every element by `scalar`.
    fn inplace_mul_scalar(&mut self, scalar: T) {
        self.data.mapv_inplace(|a| a * scalar);
    }

    /// Adds `scalar` to every element.
    fn inplace_add_scalar(&mut self, scalar: T) {
        self.data.mapv_inplace(|a| a + scalar);
    }

    /// Replaces every element `x` with `f(x)`. Never returns `Err`.
    fn inplace_apply<F>(&mut self, f: F) -> RusTorchResult<()>
    where
        F: Fn(T) -> T + Send + Sync,
    {
        self.data.mapv_inplace(f);
        Ok(())
    }

    /// Copies the sub-tensor selected by `ranges` (one half-open range per dimension)
    /// into a new tensor with a fresh standard-layout buffer.
    ///
    /// Despite the name, the result owns its data; it is not a view into `self`.
    /// Works for any rank and any memory layout. A reversed range (`start > end`)
    /// selects an empty extent along that dimension.
    ///
    /// # Errors
    ///
    /// [`RusTorchError::TensorOp`] if `ranges.len()` differs from the tensor's number
    /// of dimensions, or if any `range.end` exceeds that dimension's size.
    fn slice_view(&self, ranges: &[std::ops::Range<usize>]) -> RusTorchResult<Tensor<T>> {
        if ranges.len() != self.ndim() {
            return Err(RusTorchError::TensorOp {
                message: format!(
                    "Number of slice ranges {} does not match tensor dimensions {}",
                    ranges.len(),
                    self.ndim()
                ),
                source: None,
            });
        }
        let mut view = self.data.view();
        for (i, range) in ranges.iter().enumerate() {
            let dim_size = self.shape()[i];
            if range.end > dim_size {
                return Err(RusTorchError::TensorOp {
                    message: format!(
                        "Slice range {}..{} exceeds dimension {} size {}",
                        range.start, range.end, i, dim_size
                    ),
                    source: None,
                });
            }
            // Clamp a reversed range to an empty one. ndarray would otherwise panic
            // when `start` alone lies past the end of the axis.
            let start = range.start.min(range.end);
            view.slice_axis_inplace(ndarray::Axis(i), ndarray::Slice::from(start..range.end));
        }
        // `to_owned` copies the selected elements into a contiguous standard-layout array.
        Ok(Tensor::new(view.to_owned()))
    }

    /// Reports whether the memory occupied by `self`'s elements overlaps `other`'s.
    ///
    /// Empty tensors never report sharing.
    fn shares_memory_with(&self, other: &Tensor<T>) -> bool {
        // Byte range [start, end) spanned by a tensor's elements.
        let byte_span = |tensor: &Tensor<T>| -> (usize, usize) {
            match tensor.data.as_slice_memory_order() {
                // Contiguous in memory (in any axis order): exact buffer bounds, even
                // when the first logical element is not at the lowest address.
                Some(elements) => {
                    let span = elements.as_ptr_range();
                    (span.start as usize, span.end as usize)
                }
                // Non-contiguous: heuristic span starting at the first logical element.
                None => {
                    let start = tensor.data.as_ptr() as usize;
                    (start, start + tensor.data.len() * std::mem::size_of::<T>())
                }
            }
        };
        let (self_start, self_end) = byte_span(self);
        let (other_start, other_end) = byte_span(other);
        // Half-open ranges overlap iff each begins before the other ends.
        self_start < other_end && other_start < self_end
    }

    /// Returns a new tensor built by `Tensor::new` from a clone of this tensor's data.
    ///
    /// NOTE(review): presumably this drops any extra state `Tensor::new` does not
    /// carry over (e.g. autograd metadata). Confirm against `Tensor::new`.
    fn detach(&self) -> Tensor<T> {
        Tensor::new(self.data.clone())
    }
}
/// `TensorIterOps` for `Tensor<T>`: both methods delegate to the backing `data`
/// array's own element iterators.
impl<T: Float + Clone + 'static> TensorIterOps<T> for Tensor<T> {
/// Shared-reference iterator over the elements of `self.data`.
fn iter<'a>(&'a self) -> impl Iterator<Item = &'a T>
where
T: 'a,
{
self.data.iter()
}
/// Mutable-reference iterator over the elements of `self.data`. Writes through the
/// yielded references modify this tensor's elements directly.
fn iter_mut<'a>(&'a mut self) -> impl Iterator<Item = &'a mut T>
where
T: 'a,
{
self.data.iter_mut()
}
}