use crate::error::{RusTorchError, RusTorchResult};
use crate::hybrid_f32_experimental;
use ndarray::{Array, IxDyn};
use std::ops::{Index, IndexMut};
use std::sync::Arc;
/// Two-component index: a pair of `usize` coordinates.
///
/// Derives `PartialEq`, `Eq` and `Hash` so that indices can be compared in
/// assertions and used as map/set keys.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct Index2D(pub usize, pub usize);
/// Three-component index: a triple of `usize` coordinates.
///
/// Derives `PartialEq`, `Eq` and `Hash` so that indices can be compared in
/// assertions and used as map/set keys.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct Index3D(pub usize, pub usize, pub usize);
/// Device tag carried by a tensor.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub enum DeviceState {
    /// Plain host-side tensor; every constructor in this file starts here.
    CPU,
    /// Tagged for the Metal device with the given id.
    Metal { device_id: usize },
    /// Tagged for the CoreML device with the given id.
    CoreML { device_id: usize },
    /// NOTE(review): presumably "host and device copies agree" — confirm
    /// against the code that sets this variant.
    Synchronized,
}
/// Placeholder descriptor for a Metal-side buffer.
///
/// Only records the device id and a size; no device memory is referenced here.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct MetalBuffer {
    /// Device the buffer is associated with.
    _device_id: usize,
    /// Size recorded at creation time.
    _size: usize,
}
impl MetalBuffer {
    /// Creates a descriptor that records `device_id` and `size`.
    pub fn new(device_id: usize, size: usize) -> Self {
        let (_device_id, _size) = (device_id, size);
        Self { _device_id, _size }
    }
}
/// Placeholder descriptor for a CoreML-side buffer.
///
/// Only records the device id and a shape; no device memory is referenced here.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct CoreMLBuffer {
    /// Device the buffer is associated with.
    _device_id: usize,
    /// Shape recorded at creation time.
    _shape: Vec<usize>,
}
impl CoreMLBuffer {
    /// Creates a descriptor that records `device_id` and takes ownership of `shape`.
    pub fn new(device_id: usize, shape: Vec<usize>) -> Self {
        let (_device_id, _shape) = (device_id, shape);
        Self { _device_id, _shape }
    }
}
/// `f32` tensor backed by a dynamically-dimensioned `ndarray` array, with
/// optional Metal / CoreML buffer descriptors attached.
#[derive(Debug)]
pub struct F32Tensor {
    /// Host-side element storage.
    pub data: Array<f32, IxDyn>,
    /// Metal buffer descriptor, set by `to_metal`.
    pub metal_buffer: Option<Arc<MetalBuffer>>,
    /// CoreML buffer descriptor, set by `to_coreml`.
    pub coreml_buffer: Option<Arc<CoreMLBuffer>>,
    /// Device tag for this tensor.
    pub device_state: DeviceState,
    /// Gradient-tracking flag; binary ops OR the flags of their operands.
    pub requires_grad: bool,
    /// Cached shape. Kept separately from `data`, so code that mutates `data`
    /// directly must keep the two consistent.
    shape: Vec<usize>,
}
/// Field-by-field clone: the element data is copied, the buffer descriptors
/// are shared through their `Arc`s.
impl Clone for F32Tensor {
    fn clone(&self) -> Self {
        let Self {
            data,
            metal_buffer,
            coreml_buffer,
            device_state,
            requires_grad,
            shape,
        } = self;
        Self {
            data: data.clone(),
            metal_buffer: metal_buffer.clone(),
            coreml_buffer: coreml_buffer.clone(),
            device_state: device_state.clone(),
            requires_grad: *requires_grad,
            shape: shape.clone(),
        }
    }
}
use std::ops::{Add, AddAssign, Div, DivAssign, Mul, MulAssign, Neg, Sub, SubAssign};
impl Add<F32Tensor> for F32Tensor {
type Output = RusTorchResult<F32Tensor>;
fn add(self, rhs: F32Tensor) -> Self::Output {
(&self).add(&rhs)
}
}
impl Add<&F32Tensor> for F32Tensor {
type Output = RusTorchResult<F32Tensor>;
fn add(self, rhs: &F32Tensor) -> Self::Output {
(&self).add(rhs)
}
}
impl Add for &F32Tensor {
type Output = RusTorchResult<F32Tensor>;
fn add(self, rhs: &F32Tensor) -> Self::Output {
self.add(rhs)
}
}
/// `tensor + scalar`; wraps the scalar in a one-element tensor and adds it.
impl Add<f32> for F32Tensor {
    type Output = RusTorchResult<F32Tensor>;

    fn add(self, rhs: f32) -> Self::Output {
        let addend = F32Tensor::from_scalar(rhs)?;
        F32Tensor::add(&self, &addend)
    }
}
/// `&tensor + scalar`; wraps the scalar in a one-element tensor and adds it.
impl Add<f32> for &F32Tensor {
    type Output = RusTorchResult<F32Tensor>;

    fn add(self, rhs: f32) -> Self::Output {
        F32Tensor::from_scalar(rhs).and_then(|addend| F32Tensor::add(self, &addend))
    }
}
impl Sub<F32Tensor> for F32Tensor {
type Output = RusTorchResult<F32Tensor>;
fn sub(self, rhs: F32Tensor) -> Self::Output {
(&self).sub(&rhs)
}
}
impl Sub<&F32Tensor> for F32Tensor {
type Output = RusTorchResult<F32Tensor>;
fn sub(self, rhs: &F32Tensor) -> Self::Output {
(&self).sub(rhs)
}
}
impl Sub for &F32Tensor {
type Output = RusTorchResult<F32Tensor>;
fn sub(self, rhs: &F32Tensor) -> Self::Output {
self.sub(rhs)
}
}
/// `tensor - scalar`; wraps the scalar in a one-element tensor and subtracts it.
impl Sub<f32> for F32Tensor {
    type Output = RusTorchResult<F32Tensor>;

    fn sub(self, rhs: f32) -> Self::Output {
        let subtrahend = F32Tensor::from_scalar(rhs)?;
        F32Tensor::sub(&self, &subtrahend)
    }
}
/// `&tensor - scalar`; wraps the scalar in a one-element tensor and subtracts it.
impl Sub<f32> for &F32Tensor {
    type Output = RusTorchResult<F32Tensor>;

    fn sub(self, rhs: f32) -> Self::Output {
        F32Tensor::from_scalar(rhs).and_then(|subtrahend| F32Tensor::sub(self, &subtrahend))
    }
}
impl Mul<F32Tensor> for F32Tensor {
type Output = RusTorchResult<F32Tensor>;
fn mul(self, rhs: F32Tensor) -> Self::Output {
(&self).mul(&rhs)
}
}
impl Mul<&F32Tensor> for F32Tensor {
type Output = RusTorchResult<F32Tensor>;
fn mul(self, rhs: &F32Tensor) -> Self::Output {
(&self).mul(rhs)
}
}
impl Mul for &F32Tensor {
type Output = RusTorchResult<F32Tensor>;
fn mul(self, rhs: &F32Tensor) -> Self::Output {
self.mul(rhs)
}
}
impl Mul<f32> for F32Tensor {
type Output = RusTorchResult<F32Tensor>;
fn mul(self, rhs: f32) -> Self::Output {
self.mul_scalar(rhs)
}
}
impl Mul<f32> for &F32Tensor {
type Output = RusTorchResult<F32Tensor>;
fn mul(self, rhs: f32) -> Self::Output {
self.mul_scalar(rhs)
}
}
impl Div<F32Tensor> for F32Tensor {
type Output = RusTorchResult<F32Tensor>;
fn div(self, rhs: F32Tensor) -> Self::Output {
(&self).divide(&rhs)
}
}
impl Div<&F32Tensor> for F32Tensor {
type Output = RusTorchResult<F32Tensor>;
fn div(self, rhs: &F32Tensor) -> Self::Output {
(&self).divide(rhs)
}
}
impl Div for &F32Tensor {
type Output = RusTorchResult<F32Tensor>;
fn div(self, rhs: &F32Tensor) -> Self::Output {
self.divide(rhs)
}
}
impl Div<f32> for F32Tensor {
type Output = RusTorchResult<F32Tensor>;
fn div(self, rhs: f32) -> Self::Output {
if rhs == 0.0 {
return Err(crate::error::RusTorchError::InvalidParameters {
operation: "div".to_string(),
message: "Division by zero".to_string(),
});
}
self.mul_scalar(1.0 / rhs)
}
}
impl Div<f32> for &F32Tensor {
type Output = RusTorchResult<F32Tensor>;
fn div(self, rhs: f32) -> Self::Output {
if rhs == 0.0 {
return Err(crate::error::RusTorchError::InvalidParameters {
operation: "div".to_string(),
message: "Division by zero".to_string(),
});
}
self.mul_scalar(1.0 / rhs)
}
}
impl Neg for F32Tensor {
type Output = RusTorchResult<F32Tensor>;
fn neg(self) -> Self::Output {
self.mul_scalar(-1.0)
}
}
impl Neg for &F32Tensor {
type Output = RusTorchResult<F32Tensor>;
fn neg(self) -> Self::Output {
self.mul_scalar(-1.0)
}
}
impl F32Tensor {
/// Returns the elements as a flat slice.
///
/// When `ndarray` cannot expose the data as a slice (`as_slice()` returns
/// `None`) this yields an **empty** slice rather than an error; use
/// `as_slice_option` to tell the two cases apart.
pub fn as_slice(&self) -> &[f32] {
    match self.data.as_slice() {
        Some(elements) => elements,
        None => &[],
    }
}
/// Number of dimensions, taken from the cached shape.
pub fn ndim(&self) -> usize {
    let dims: &[usize] = &self.shape;
    dims.len()
}
/// `true` when the tensor holds no elements.
pub fn is_empty(&self) -> bool {
    let element_count = self.numel();
    element_count == 0
}
/// `true` when the tensor holds exactly one element (regardless of rank).
pub fn is_scalar(&self) -> bool {
    let element_count = self.numel();
    element_count == 1
}
pub fn is_grad_enabled(&self) -> bool {
self.requires_grad
}
/// Sets the `requires_grad` flag in place.
///
/// Despite the getter-like name this is a setter; read the flag with
/// `is_grad_enabled` or the public field.
pub fn requires_grad(&mut self, requires: bool) {
    let flag = &mut self.requires_grad;
    *flag = requires;
}
/// Creates a CPU tensor of the given shape filled with `0.0`.
pub fn zeros(shape: &[usize]) -> RusTorchResult<Self> {
    hybrid_f32_experimental!();
    Ok(Self {
        shape: shape.to_vec(),
        data: Array::zeros(IxDyn(shape)),
        device_state: DeviceState::CPU,
        metal_buffer: None,
        coreml_buffer: None,
        requires_grad: false,
    })
}
/// Creates a CPU tensor of the given shape with elements drawn from the
/// standard normal distribution using the thread-local RNG.
///
/// # Errors
/// `InvalidParameters` if `ndarray` rejects the shape.
pub fn randn(shape: &[usize]) -> RusTorchResult<Self> {
    hybrid_f32_experimental!();
    use rand::Rng;
    use rand_distr::StandardNormal;

    let mut rng = rand::thread_rng();
    let element_count = shape.iter().product::<usize>();
    let mut samples: Vec<f32> = Vec::with_capacity(element_count);
    for _ in 0..element_count {
        samples.push(rng.sample(StandardNormal));
    }

    match Array::from_shape_vec(IxDyn(shape), samples) {
        Ok(array) => Ok(Self {
            shape: shape.to_vec(),
            data: array,
            device_state: DeviceState::CPU,
            metal_buffer: None,
            coreml_buffer: None,
            requires_grad: false,
        }),
        Err(e) => Err(RusTorchError::InvalidParameters {
            operation: "randn".to_string(),
            message: format!("Shape error: {}", e),
        }),
    }
}
/// Wraps a single value in a CPU tensor of shape `[1]`.
///
/// Several binary operations in this type treat a `[1]`-shaped right-hand
/// side as a scalar operand.
pub fn from_scalar(value: f32) -> RusTorchResult<Self> {
    hybrid_f32_experimental!();
    let unit_shape = vec![1];
    Ok(Self {
        data: Array::from_elem(IxDyn(&unit_shape), value),
        shape: unit_shape,
        device_state: DeviceState::CPU,
        metal_buffer: None,
        coreml_buffer: None,
        requires_grad: false,
    })
}
/// Borrows the cached shape.
pub fn shape(&self) -> &[usize] {
    self.shape.as_slice()
}
pub fn to_cpu(&self) -> RusTorchResult<Self> {
Ok(self.clone())
}
/// Tags this tensor as living on Metal device `device_id` and attaches a
/// `MetalBuffer` descriptor sized to the current element count.
///
/// Only bookkeeping is visible here; no data transfer happens in this method.
pub fn to_metal(&mut self, device_id: usize) -> RusTorchResult<()> {
    hybrid_f32_experimental!();
    let descriptor = MetalBuffer::new(device_id, self.data.len());
    self.metal_buffer = Some(Arc::new(descriptor));
    self.device_state = DeviceState::Metal { device_id };
    Ok(())
}
/// Tags this tensor as living on CoreML device `device_id` and attaches a
/// `CoreMLBuffer` descriptor carrying a copy of the current shape.
///
/// Only bookkeeping is visible here; no data transfer happens in this method.
pub fn to_coreml(&mut self, device_id: usize) -> RusTorchResult<()> {
    hybrid_f32_experimental!();
    let descriptor = CoreMLBuffer::new(device_id, self.shape.clone());
    self.coreml_buffer = Some(Arc::new(descriptor));
    self.device_state = DeviceState::CoreML { device_id };
    Ok(())
}
pub fn device_state(&self) -> &DeviceState {
&self.device_state
}
/// Extracts the value of a one-element tensor.
///
/// # Errors
/// `InvalidParameters` when the tensor does not hold exactly one element.
pub fn unwrap(&self) -> RusTorchResult<f32> {
    let element_count = self.data.len();
    if element_count != 1 {
        return Err(RusTorchError::InvalidParameters {
            operation: "unwrap".to_string(),
            message: format!("Tensor has {} elements, expected 1", element_count),
        });
    }
    let only = self.data.iter().next().copied();
    Ok(only.unwrap_or(0.0))
}
/// Element-wise addition.
///
/// A right-hand side of shape `[1]` is treated as a scalar and added to
/// every element; otherwise both shapes must match exactly.
///
/// # Errors
/// Shape-mismatch error when the shapes differ and `other` is not `[1]`.
pub fn add(&self, other: &Self) -> RusTorchResult<Self> {
    let summed = if other.shape == [1] {
        let addend = *other.data.iter().next().unwrap();
        self.data.mapv(|x| x + addend)
    } else if self.shape == other.shape {
        &self.data + &other.data
    } else {
        return Err(RusTorchError::shape_mismatch(&self.shape, &other.shape));
    };
    Ok(Self {
        shape: self.shape.clone(),
        data: summed,
        device_state: DeviceState::CPU,
        metal_buffer: None,
        coreml_buffer: None,
        requires_grad: self.requires_grad || other.requires_grad,
    })
}
pub fn mul(&self, other: &Self) -> RusTorchResult<Self> {
let result_data = &self.data * &other.data;
Ok(Self {
data: result_data,
metal_buffer: None,
coreml_buffer: None,
device_state: DeviceState::CPU,
requires_grad: self.requires_grad || other.requires_grad,
shape: self.shape.clone(),
})
}
/// Matrix product of two 2-D tensors (`[m, k] x [k, n] -> [m, n]`),
/// computed with a straightforward triple loop.
///
/// # Errors
/// `InvalidParameters` when either operand is not 2-D or the inner
/// dimensions disagree.
pub fn matmul(&self, other: &Self) -> RusTorchResult<Self> {
    if self.shape.len() != 2 || other.shape.len() != 2 {
        return Err(RusTorchError::InvalidParameters {
            operation: "matmul".to_string(),
            message: "Only 2D tensors supported".to_string(),
        });
    }
    let (m, k) = (self.shape[0], self.shape[1]);
    let (k2, n) = (other.shape[0], other.shape[1]);
    if k != k2 {
        return Err(RusTorchError::InvalidParameters {
            operation: "matmul".to_string(),
            message: format!("Incompatible dimensions: {}x{} and {}x{}", m, k, k2, n),
        });
    }

    // Row-major output: entry (i, j) lands at offset i * n + j.
    let mut products: Vec<f32> = Vec::with_capacity(m * n);
    for i in 0..m {
        for j in 0..n {
            let dot = (0..k).fold(0.0f32, |acc, l| {
                acc + self.data[[i, l]] * other.data[[l, j]]
            });
            products.push(dot);
        }
    }

    let result_shape = vec![m, n];
    let array = Array::from_shape_vec(IxDyn(&result_shape), products).map_err(|e| {
        RusTorchError::InvalidParameters {
            operation: "matmul".to_string(),
            message: format!("Shape error: {}", e),
        }
    })?;
    Ok(Self {
        shape: result_shape,
        data: array,
        device_state: DeviceState::CPU,
        metal_buffer: None,
        coreml_buffer: None,
        requires_grad: self.requires_grad || other.requires_grad,
    })
}
/// Transposes a 2-D tensor.
///
/// The result is materialised in standard (row-major) layout. A plain
/// `reversed_axes().to_owned()` keeps the reversed strides, which makes
/// `as_slice()` return `None` on the result and breaks every method here
/// that relies on a flat row-major slice.
///
/// # Errors
/// `InvalidParameters` when the tensor is not 2-D.
pub fn transpose(&self) -> RusTorchResult<Self> {
    if self.shape.len() != 2 {
        return Err(RusTorchError::InvalidParameters {
            operation: "transpose".to_string(),
            message: "Only 2D tensors supported".to_string(),
        });
    }
    let transposed = self.data.t().as_standard_layout().into_owned();
    let new_shape = vec![self.shape[1], self.shape[0]];
    Ok(Self {
        data: transposed,
        metal_buffer: None,
        coreml_buffer: None,
        device_state: DeviceState::CPU,
        requires_grad: self.requires_grad,
        shape: new_shape,
    })
}
pub fn sub(&self, other: &Self) -> RusTorchResult<Self> {
let result_data = &self.data - &other.data;
Ok(Self {
data: result_data,
metal_buffer: None,
coreml_buffer: None,
device_state: DeviceState::CPU,
requires_grad: self.requires_grad || other.requires_grad,
shape: self.shape.clone(),
})
}
pub fn numel(&self) -> usize {
self.data.len()
}
pub fn gt(&self, other: &Self) -> RusTorchResult<Self> {
let result_data = self.data.mapv(|x| {
if x > other.data.iter().next().copied().unwrap_or(0.0) {
1.0
} else {
0.0
}
});
Ok(Self {
data: result_data,
metal_buffer: None,
coreml_buffer: None,
device_state: DeviceState::CPU,
requires_grad: false,
shape: self.shape.clone(),
})
}
pub fn le(&self, other: &Self) -> RusTorchResult<Self> {
let result_data = self.data.mapv(|x| {
if x <= other.data.iter().next().copied().unwrap_or(0.0) {
1.0
} else {
0.0
}
});
Ok(Self {
data: result_data,
metal_buffer: None,
coreml_buffer: None,
device_state: DeviceState::CPU,
requires_grad: false,
shape: self.shape.clone(),
})
}
/// Rectified linear unit: replaces every element `x` with `max(x, 0.0)`.
pub fn relu(&self) -> RusTorchResult<Self> {
    let rectified = self.data.mapv(|x| f32::max(x, 0.0));
    Ok(Self {
        shape: self.shape.clone(),
        data: rectified,
        device_state: DeviceState::CPU,
        metal_buffer: None,
        coreml_buffer: None,
        requires_grad: self.requires_grad,
    })
}
/// Logistic sigmoid applied element-wise: `1 / (1 + e^(-x))`.
pub fn sigmoid(&self) -> RusTorchResult<Self> {
    let squashed = self.data.mapv(|x| {
        let denominator = 1.0 + (-x).exp();
        1.0 / denominator
    });
    Ok(Self {
        shape: self.shape.clone(),
        data: squashed,
        device_state: DeviceState::CPU,
        metal_buffer: None,
        coreml_buffer: None,
        requires_grad: self.requires_grad,
    })
}
/// Hyperbolic tangent applied element-wise.
pub fn tanh(&self) -> RusTorchResult<Self> {
    let mapped = self.data.mapv(f32::tanh);
    Ok(Self {
        shape: self.shape.clone(),
        data: mapped,
        device_state: DeviceState::CPU,
        metal_buffer: None,
        coreml_buffer: None,
        requires_grad: self.requires_grad,
    })
}
/// Natural exponential applied element-wise.
pub fn exp(&self) -> RusTorchResult<Self> {
    let mapped = self.data.mapv(f32::exp);
    Ok(Self {
        shape: self.shape.clone(),
        data: mapped,
        device_state: DeviceState::CPU,
        metal_buffer: None,
        coreml_buffer: None,
        requires_grad: self.requires_grad,
    })
}
/// Natural logarithm applied element-wise.
///
/// No domain check is performed: non-positive inputs yield whatever
/// `f32::ln` returns for them.
pub fn log(&self) -> RusTorchResult<Self> {
    let mapped = self.data.mapv(f32::ln);
    Ok(Self {
        shape: self.shape.clone(),
        data: mapped,
        device_state: DeviceState::CPU,
        metal_buffer: None,
        coreml_buffer: None,
        requires_grad: self.requires_grad,
    })
}
/// Raises every element to the power `exponent` (`f32::powf`).
pub fn power(&self, exponent: f32) -> RusTorchResult<Self> {
    let raised = self.data.mapv(|base| f32::powf(base, exponent));
    Ok(Self {
        shape: self.shape.clone(),
        data: raised,
        device_state: DeviceState::CPU,
        metal_buffer: None,
        coreml_buffer: None,
        requires_grad: self.requires_grad,
    })
}
pub fn maximum(&self, other: &Self) -> RusTorchResult<Self> {
let result_data = self
.data
.mapv(|x| x.max(other.data.iter().next().copied().unwrap_or(0.0)));
Ok(Self {
data: result_data,
metal_buffer: None,
coreml_buffer: None,
device_state: DeviceState::CPU,
requires_grad: self.requires_grad || other.requires_grad,
shape: self.shape.clone(),
})
}
pub fn minimum(&self, other: &Self) -> RusTorchResult<Self> {
let result_data = self
.data
.mapv(|x| x.min(other.data.iter().next().copied().unwrap_or(0.0)));
Ok(Self {
data: result_data,
metal_buffer: None,
coreml_buffer: None,
device_state: DeviceState::CPU,
requires_grad: self.requires_grad || other.requires_grad,
shape: self.shape.clone(),
})
}
/// Limits every element to the range `[min, max]`.
///
/// Applied as `max(x, min)` followed by `min(.., max)`, so when `min > max`
/// the upper bound wins.
pub fn clamp(&self, min: f32, max: f32) -> RusTorchResult<Self> {
    let bounded = self.data.mapv(|x| {
        let raised = f32::max(x, min);
        f32::min(raised, max)
    });
    Ok(Self {
        shape: self.shape.clone(),
        data: bounded,
        device_state: DeviceState::CPU,
        metal_buffer: None,
        coreml_buffer: None,
        requires_grad: self.requires_grad,
    })
}
/// Flat (logical-order) index of the largest element, returned as a
/// one-element tensor holding the index as `f32`.
///
/// * An empty tensor yields index `0.0`.
/// * When several elements are equally maximal the last one wins
///   (`Iterator::max_by` semantics).
/// * NaN is ordered above every number, so a tensor containing NaN reports
///   the position of a NaN instead of panicking as the previous
///   `partial_cmp(..).unwrap()` did.
pub fn argmax(&self) -> RusTorchResult<Self> {
    let max_idx = self
        .data
        .iter()
        .enumerate()
        .max_by(|(_, a), (_, b)| {
            a.partial_cmp(b)
                // `partial_cmp` is `None` only when a NaN is involved; rank
                // the NaN side higher to obtain a total order.
                .unwrap_or_else(|| a.is_nan().cmp(&b.is_nan()))
        })
        .map(|(idx, _)| idx as f32)
        .unwrap_or(0.0);
    Self::from_scalar(max_idx)
}
/// Returns a copy of the tensor with shape `new_shape` (row-major order).
///
/// The data is first brought into standard layout, so tensors whose storage
/// is not row-major can be reshaped too; previously
/// `into_shape_with_order` rejected them with a layout error.
///
/// # Errors
/// `InvalidParameters` when the element counts differ or `ndarray` rejects
/// the new shape.
pub fn reshape(&self, new_shape: &[usize]) -> RusTorchResult<Self> {
    let new_size: usize = new_shape.iter().product();
    if new_size != self.data.len() {
        return Err(RusTorchError::InvalidParameters {
            operation: "reshape".to_string(),
            message: format!(
                "Cannot reshape tensor of size {} to size {}",
                self.data.len(),
                new_size
            ),
        });
    }
    let reshaped_data = self
        .data
        .as_standard_layout()
        .into_owned()
        .into_shape_with_order(IxDyn(new_shape))
        .map_err(|e| RusTorchError::InvalidParameters {
            operation: "reshape".to_string(),
            message: format!("Reshape error: {}", e),
        })?;
    Ok(Self {
        data: reshaped_data,
        metal_buffer: None,
        coreml_buffer: None,
        device_state: DeviceState::CPU,
        requires_grad: self.requires_grad,
        shape: new_shape.to_vec(),
    })
}
pub fn slice(&self, ranges: &[(usize, usize)]) -> RusTorchResult<Self> {
let mut result_data = Vec::new();
let shape = self.shape();
if ranges.len() == 1 && self.ndim() == 1 {
let (start, end) = ranges[0];
let slice_data = &self.data.as_slice().unwrap()[start..end];
result_data.extend_from_slice(slice_data);
F32Tensor::new(result_data, &[end - start])
} else {
Ok(self.clone())
}
}
/// Placeholder dtype conversion: `_dtype` is ignored and an unchanged clone
/// of the tensor is returned.
pub fn to_type(&self, _dtype: &str) -> RusTorchResult<Self> {
    let copy = self.clone();
    Ok(copy)
}
/// Element-wise division.
///
/// A right-hand side of shape `[1]` is treated as a scalar divisor;
/// otherwise both shapes must match exactly.
///
/// # Errors
/// * Tensor-op error "Division by zero" if any divisor element equals `0.0`.
/// * Shape-mismatch error when the shapes differ and `other` is not `[1]`.
pub fn divide(&self, other: &Self) -> RusTorchResult<Self> {
    let quotient = if other.shape == [1] {
        let divisor = *other.data.iter().next().unwrap();
        if divisor == 0.0 {
            return Err(RusTorchError::tensor_op("Division by zero"));
        }
        self.data.mapv(|x| x / divisor)
    } else {
        if self.shape != other.shape {
            return Err(RusTorchError::shape_mismatch(&self.shape, &other.shape));
        }
        if other.data.iter().any(|&divisor| divisor == 0.0) {
            return Err(RusTorchError::tensor_op("Division by zero"));
        }
        &self.data / &other.data
    };
    Ok(Self {
        shape: self.shape.clone(),
        data: quotient,
        device_state: DeviceState::CPU,
        metal_buffer: None,
        coreml_buffer: None,
        requires_grad: self.requires_grad || other.requires_grad,
    })
}
/// Element-wise subtraction with strict shape checking.
///
/// A right-hand side of shape `[1]` is treated as a scalar; otherwise both
/// shapes must match exactly.
///
/// # Errors
/// Shape-mismatch error when the shapes differ and `other` is not `[1]`.
pub fn subtract(&self, other: &Self) -> RusTorchResult<Self> {
    let difference = if other.shape == [1] {
        let subtrahend = *other.data.iter().next().unwrap();
        self.data.mapv(|x| x - subtrahend)
    } else if self.shape == other.shape {
        &self.data - &other.data
    } else {
        return Err(RusTorchError::shape_mismatch(&self.shape, &other.shape));
    };
    Ok(Self {
        shape: self.shape.clone(),
        data: difference,
        device_state: DeviceState::CPU,
        metal_buffer: None,
        coreml_buffer: None,
        requires_grad: self.requires_grad || other.requires_grad,
    })
}
/// Element-wise multiplication with strict shape checking.
///
/// A right-hand side of shape `[1]` is treated as a scalar; otherwise both
/// shapes must match exactly.
///
/// # Errors
/// Shape-mismatch error when the shapes differ and `other` is not `[1]`.
pub fn multiply(&self, other: &Self) -> RusTorchResult<Self> {
    let product = if other.shape == [1] {
        let factor = *other.data.iter().next().unwrap();
        self.data.mapv(|x| x * factor)
    } else if self.shape == other.shape {
        &self.data * &other.data
    } else {
        return Err(RusTorchError::shape_mismatch(&self.shape, &other.shape));
    };
    Ok(Self {
        shape: self.shape.clone(),
        data: product,
        device_state: DeviceState::CPU,
        metal_buffer: None,
        coreml_buffer: None,
        requires_grad: self.requires_grad || other.requires_grad,
    })
}
/// Adds `scalar` to every element.
pub fn add_scalar(&self, scalar: f32) -> RusTorchResult<Self> {
    let shifted = self.data.mapv(|element| element + scalar);
    Ok(Self {
        shape: self.shape.clone(),
        data: shifted,
        device_state: DeviceState::CPU,
        metal_buffer: None,
        coreml_buffer: None,
        requires_grad: self.requires_grad,
    })
}
/// Multiplies every element by `scalar`.
pub fn multiply_scalar(&self, scalar: f32) -> RusTorchResult<Self> {
    let scaled = self.data.mapv(|element| element * scalar);
    Ok(Self {
        shape: self.shape.clone(),
        data: scaled,
        device_state: DeviceState::CPU,
        metal_buffer: None,
        coreml_buffer: None,
        requires_grad: self.requires_grad,
    })
}
/// Arithmetic mean of all elements.
///
/// # Errors
/// Tensor-op error when the tensor is empty.
pub fn mean(&self) -> RusTorchResult<f32> {
    if self.data.is_empty() {
        return Err(RusTorchError::tensor_op(
            "Cannot calculate mean of empty tensor",
        ));
    }
    // Non-empty, so ndarray's `mean()` yields a value here.
    let average = self.data.mean().unwrap();
    Ok(average)
}
/// Smallest element, folded with `f32::min` starting from `+inf`.
///
/// # Errors
/// Tensor-op error when the tensor is empty.
pub fn min(&self) -> RusTorchResult<f32> {
    if self.data.is_empty() {
        return Err(RusTorchError::tensor_op(
            "Cannot calculate min of empty tensor",
        ));
    }
    let smallest = self
        .data
        .iter()
        .fold(f32::INFINITY, |lowest, &value| f32::min(lowest, value));
    Ok(smallest)
}
/// Largest element, folded with `f32::max` starting from `-inf`.
///
/// # Errors
/// Tensor-op error when the tensor is empty.
pub fn max(&self) -> RusTorchResult<f32> {
    if self.data.is_empty() {
        return Err(RusTorchError::tensor_op(
            "Cannot calculate max of empty tensor",
        ));
    }
    let largest = self
        .data
        .iter()
        .fold(f32::NEG_INFINITY, |highest, &value| f32::max(highest, value));
    Ok(largest)
}
/// Mean of all elements as a one-element tensor.
///
/// Unlike `mean`, an empty tensor does not produce an error: the mean then
/// defaults to `0.0`.
pub fn mean_tensor(&self) -> RusTorchResult<Self> {
    Self::from_scalar(self.data.mean().unwrap_or(0.0))
}
/// Sum of **all** elements as a one-element tensor.
///
/// NOTE(review): `_dim` is currently ignored — no per-axis reduction is
/// performed and the argument is not validated.
pub fn sum_dim(&self, _dim: usize) -> RusTorchResult<Self> {
    let total = self.data.sum();
    Self::from_scalar(total)
}
/// Builds a CPU tensor from row-major `data` and a `shape`.
///
/// # Errors
/// `InvalidParameters` when `data.len()` does not equal the product of
/// `shape`, or when `ndarray` rejects the shape.
pub fn from_vec(data: Vec<f32>, shape: &[usize]) -> RusTorchResult<Self> {
    let expected_size = shape.iter().product::<usize>();
    let actual_size = data.len();
    if actual_size != expected_size {
        return Err(RusTorchError::InvalidParameters {
            operation: "from_vec".to_string(),
            message: format!(
                "Data length {} doesn't match shape size {}",
                actual_size, expected_size
            ),
        });
    }
    match Array::from_shape_vec(IxDyn(shape), data) {
        Ok(array) => Ok(Self {
            shape: shape.to_vec(),
            data: array,
            device_state: DeviceState::CPU,
            metal_buffer: None,
            coreml_buffer: None,
            requires_grad: false,
        }),
        Err(e) => Err(RusTorchError::InvalidParameters {
            operation: "from_vec".to_string(),
            message: format!("Shape error: {}", e),
        }),
    }
}
/// Creates a CPU tensor of the given shape filled with `1.0`.
pub fn ones(shape: &[usize]) -> RusTorchResult<Self> {
    Ok(Self {
        shape: shape.to_vec(),
        data: Array::ones(IxDyn(shape)),
        device_state: DeviceState::CPU,
        metal_buffer: None,
        coreml_buffer: None,
        requires_grad: false,
    })
}
pub fn new(data: Vec<f32>, shape: &[usize]) -> RusTorchResult<Self> {
Self::from_vec(data, shape)
}
pub fn as_slice_option(&self) -> Option<&[f32]> {
self.data.as_slice()
}
pub fn scalar_value(&self) -> RusTorchResult<f32> {
self.unwrap()
}
pub fn try_add(&self, other: &F32Tensor) -> RusTorchResult<F32Tensor> {
self.add(other)
}
pub fn try_sub(&self, other: &F32Tensor) -> RusTorchResult<F32Tensor> {
self.sub(other)
}
pub fn try_mul(&self, other: &F32Tensor) -> RusTorchResult<F32Tensor> {
self.mul(other)
}
pub fn try_div(&self, other: &F32Tensor) -> RusTorchResult<F32Tensor> {
self.divide(other)
}
pub fn try_matmul(&self, other: &F32Tensor) -> RusTorchResult<F32Tensor> {
self.matmul(other)
}
pub fn try_mul_scalar(&self, scalar: f32) -> RusTorchResult<F32Tensor> {
if scalar.is_nan() || scalar.is_infinite() {
return Err(crate::error::RusTorchError::InvalidParameters {
operation: "try_mul_scalar".to_string(),
message: format!("Invalid scalar value: {}", scalar),
});
}
self.mul_scalar(scalar)
}
pub fn try_reshape(&self, new_shape: &[usize]) -> RusTorchResult<F32Tensor> {
let new_numel: usize = new_shape.iter().product();
if new_numel != self.numel() {
return Err(crate::error::RusTorchError::InvalidParameters {
operation: "try_reshape".to_string(),
message: format!(
"Cannot reshape tensor with {} elements to shape {:?} ({} elements)",
self.numel(),
new_shape,
new_numel
),
});
}
self.reshape(new_shape)
}
pub fn try_transpose(&self) -> RusTorchResult<F32Tensor> {
if self.ndim() != 2 {
return Err(crate::error::RusTorchError::InvalidOperation(format!(
"transpose requires 2D tensor, got {}D",
self.ndim()
)));
}
self.transpose()
}
pub fn try_slice(&self, ranges: &[(usize, usize)]) -> RusTorchResult<F32Tensor> {
if ranges.len() != self.ndim() {
return Err(crate::error::RusTorchError::InvalidParameters {
operation: "try_slice".to_string(),
message: format!(
"Expected {} slice ranges for {}D tensor, got {}",
self.ndim(),
self.ndim(),
ranges.len()
),
});
}
let shape = self.shape();
for (i, &(start, end)) in ranges.iter().enumerate() {
if start >= end || end > shape[i] {
return Err(crate::error::RusTorchError::InvalidParameters {
operation: "try_slice".to_string(),
message: format!(
"Invalid slice range for dimension {}: {}..{} (max: {})",
i, start, end, shape[i]
),
});
}
}
self.slice(ranges)
}
pub fn try_to_cpu(&self) -> RusTorchResult<F32Tensor> {
self.to_cpu()
}
pub fn try_to_metal(&mut self, device_id: usize) -> RusTorchResult<()> {
#[cfg(all(target_os = "macos", feature = "metal"))]
{
self.to_metal(device_id)
}
#[cfg(not(all(target_os = "macos", feature = "metal")))]
{
Err(crate::error::RusTorchError::BackendUnavailable {
backend: "Metal (macOS + metal feature required)".to_string(),
})
}
}
pub fn try_to_coreml(&mut self, device_id: usize) -> RusTorchResult<()> {
#[cfg(all(target_os = "macos", feature = "coreml"))]
{
self.to_coreml(device_id)
}
#[cfg(not(all(target_os = "macos", feature = "coreml")))]
{
Err(crate::error::RusTorchError::BackendUnavailable {
backend: "CoreML (macOS + coreml feature required)".to_string(),
})
}
}
pub fn try_to_type<T>(&self) -> RusTorchResult<Vec<T>>
where
T: From<f32> + Copy,
{
if self.numel() == 0 {
return Ok(Vec::new());
}
let data = self.data.as_slice().ok_or_else(|| {
crate::error::RusTorchError::InvalidOperation("Cannot access tensor data".to_string())
})?;
Ok(data.iter().map(|&x| T::from(x)).collect())
}
pub fn try_get(&self, indices: &[usize]) -> RusTorchResult<f32> {
if indices.len() != self.ndim() {
return Err(crate::error::RusTorchError::InvalidParameters {
operation: "try_get".to_string(),
message: format!(
"Expected {} indices for {}D tensor, got {}",
self.ndim(),
self.ndim(),
indices.len()
),
});
}
let shape = self.shape();
for (i, &idx) in indices.iter().enumerate() {
if idx >= shape[i] {
return Err(crate::error::RusTorchError::index_out_of_bounds(
&[idx],
&[shape[i]],
));
}
}
let mut flat_index = 0;
let mut stride = 1;
for i in (0..indices.len()).rev() {
flat_index += indices[i] * stride;
stride *= shape[i];
}
let data = self.data.as_slice().ok_or_else(|| {
crate::error::RusTorchError::InvalidOperation("Cannot access tensor data".to_string())
})?;
Ok(data[flat_index])
}
pub fn try_set(&mut self, indices: &[usize], value: f32) -> RusTorchResult<()> {
if value.is_nan() || value.is_infinite() {
return Err(crate::error::RusTorchError::InvalidParameters {
operation: "try_set".to_string(),
message: format!("Invalid value: {}", value),
});
}
if indices.len() != self.ndim() {
return Err(crate::error::RusTorchError::InvalidParameters {
operation: "try_set".to_string(),
message: format!(
"Expected {} indices for {}D tensor, got {}",
self.ndim(),
self.ndim(),
indices.len()
),
});
}
let shape = self.shape();
for (i, &idx) in indices.iter().enumerate() {
if idx >= shape[i] {
return Err(crate::error::RusTorchError::index_out_of_bounds(
&[idx],
&[shape[i]],
));
}
}
let mut flat_index = 0;
let mut stride = 1;
for i in (0..indices.len()).rev() {
flat_index += indices[i] * stride;
stride *= shape[i];
}
let data = self.data.as_slice_mut().ok_or_else(|| {
crate::error::RusTorchError::InvalidOperation(
"Cannot access tensor data for modification".to_string(),
)
})?;
data[flat_index] = value;
Ok(())
}
/// Sum of all elements (`0.0` for an empty tensor, per `ndarray`'s `sum`).
pub fn sum(&self) -> RusTorchResult<f32> {
    let total = self.data.sum();
    Ok(total)
}
/// Multiplies every element by `scalar`; no validation of `scalar` is done
/// (see `try_mul_scalar` for the checked variant).
pub fn mul_scalar(&self, scalar: f32) -> RusTorchResult<Self> {
    let scaled = self.data.mapv(|element| element * scalar);
    Ok(Self {
        shape: self.shape.clone(),
        data: scaled,
        device_state: DeviceState::CPU,
        metal_buffer: None,
        coreml_buffer: None,
        requires_grad: self.requires_grad,
    })
}
/// Inserts a dimension of size 1 at position `dim` (`0..=ndim`).
///
/// The data is first brought into standard layout so the reshape also
/// succeeds for tensors whose storage is not row-major.
///
/// # Errors
/// `InvalidParameters` when `dim` is greater than the current rank or
/// `ndarray` rejects the new shape.
pub fn unsqueeze(&self, dim: usize) -> RusTorchResult<Self> {
    let mut new_shape = self.shape.clone();
    if dim > new_shape.len() {
        return Err(RusTorchError::InvalidParameters {
            operation: "unsqueeze".to_string(),
            message: format!(
                "Dimension {} out of bounds for tensor with {} dimensions",
                dim,
                new_shape.len()
            ),
        });
    }
    new_shape.insert(dim, 1);
    let reshaped_data = self
        .data
        .as_standard_layout()
        .into_owned()
        .into_shape_with_order(IxDyn(&new_shape))
        .map_err(|e| RusTorchError::InvalidParameters {
            operation: "unsqueeze".to_string(),
            message: format!("Reshape error: {}", e),
        })?;
    Ok(Self {
        data: reshaped_data,
        metal_buffer: None,
        coreml_buffer: None,
        device_state: DeviceState::CPU,
        requires_grad: self.requires_grad,
        shape: new_shape,
    })
}
/// Expands size-1 dimensions to the sizes given in `new_shape`.
///
/// The rank must stay the same and every dimension must either already
/// match or be `1`. The expansion is done with `ndarray` broadcasting, which
/// repeats elements along the expanded axes. The previous implementation
/// tiled the whole flat buffer, which is only correct when the expanded
/// axes are the leading ones (e.g. `[3, 1] -> [3, 4]` came out wrong), and
/// it divided by zero for empty tensors.
///
/// # Errors
/// `InvalidParameters` when the rank differs, a dimension cannot be
/// expanded, or the data cannot be broadcast to `new_shape`.
pub fn expand(&self, new_shape: &[usize]) -> RusTorchResult<Self> {
    if new_shape.len() != self.shape.len() {
        return Err(RusTorchError::InvalidParameters {
            operation: "expand".to_string(),
            message: format!(
                "Cannot expand from {} dimensions to {} dimensions",
                self.shape.len(),
                new_shape.len()
            ),
        });
    }
    for (i, (&current, &target)) in self.shape.iter().zip(new_shape.iter()).enumerate() {
        if current != 1 && current != target {
            return Err(RusTorchError::InvalidParameters {
                operation: "expand".to_string(),
                message: format!(
                    "Cannot expand dimension {} from {} to {}",
                    i, current, target
                ),
            });
        }
    }
    let broadcast_view = self.data.broadcast(IxDyn(new_shape)).ok_or_else(|| {
        RusTorchError::InvalidParameters {
            operation: "expand".to_string(),
            message: format!(
                "Cannot broadcast shape {:?} to {:?}",
                self.shape, new_shape
            ),
        }
    })?;
    // Materialise the zero-stride broadcast view as an owned row-major array.
    let array = broadcast_view.as_standard_layout().into_owned();
    Ok(Self {
        data: array,
        metal_buffer: None,
        coreml_buffer: None,
        device_state: DeviceState::CPU,
        requires_grad: self.requires_grad,
        shape: new_shape.to_vec(),
    })
}
/// Swaps dimensions `dim1` and `dim2`.
///
/// The result is materialised in standard (row-major) layout. A bare
/// `swap_axes` only permutes strides, which makes `as_slice()` return
/// `None` on the result and breaks the methods that rely on a flat
/// row-major slice.
///
/// # Errors
/// `InvalidParameters` when either index is out of range.
pub fn transpose_dims(&self, dim1: usize, dim2: usize) -> RusTorchResult<Self> {
    if dim1 >= self.shape.len() || dim2 >= self.shape.len() {
        return Err(RusTorchError::InvalidParameters {
            operation: "transpose_dims".to_string(),
            message: format!(
                "Dimension indices {} and {} out of bounds for tensor with {} dimensions",
                dim1,
                dim2,
                self.shape.len()
            ),
        });
    }
    if dim1 == dim2 {
        return Ok(self.clone());
    }
    let mut new_shape = self.shape.clone();
    new_shape.swap(dim1, dim2);
    // Swap on a view, then copy once into row-major storage.
    let mut swapped = self.data.view();
    swapped.swap_axes(dim1, dim2);
    let transposed_data = swapped.as_standard_layout().into_owned();
    Ok(Self {
        data: transposed_data,
        metal_buffer: None,
        coreml_buffer: None,
        device_state: DeviceState::CPU,
        requires_grad: self.requires_grad,
        shape: new_shape,
    })
}
pub fn softmax(&self, dim: Option<usize>) -> RusTorchResult<Self> {
let softmax_dim = dim.unwrap_or(self.shape.len().saturating_sub(1));
if softmax_dim >= self.shape.len() {
return Err(RusTorchError::InvalidParameters {
operation: "softmax".to_string(),
message: format!(
"Dimension {} out of bounds for tensor with {} dimensions",
softmax_dim,
self.shape.len()
),
});
}
let max_val = self.data.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b));
let max_tensor = F32Tensor::from_scalar(max_val)?;
let shifted = self.sub(&max_tensor)?;
let exp_data = shifted.exp()?;
let sum_val = exp_data.data.sum();
let sum_tensor = F32Tensor::from_scalar(sum_val)?;
exp_data.divide(&sum_tensor)
}
/// QR decomposition of an `m x n` matrix using Householder reflections.
///
/// Returns `(q, r)` with `q` of shape `[m, m]` and `r` of shape `[m, n]`
/// such that `q * r` reproduces the input (up to rounding).
///
/// Two defects of the previous version are fixed:
/// * Reflections were applied to the Q accumulator from the left, which
///   yields `H_k ... H_1 = Q^T`; the returned "Q" was therefore the
///   transpose and `q * r` did not reproduce the input. Q is now
///   accumulated as `Q <- Q * H_k`.
/// * The input is read in logical row-major order instead of through
///   `as_slice().unwrap()`, which panicked for non-standard layouts.
///
/// # Errors
/// `InvalidParameters` when the tensor is not 2-D.
pub fn qr_decomposition(&self) -> RusTorchResult<(Self, Self)> {
    if self.ndim() != 2 {
        return Err(RusTorchError::InvalidParameters {
            operation: "qr_decomposition".to_string(),
            message: "QR decomposition requires 2D tensor".to_string(),
        });
    }
    let (m, n) = (self.shape[0], self.shape[1]);
    let min_dim = m.min(n);

    // Q starts as the m x m identity.
    let mut q_data = vec![0.0f32; m * m];
    for i in 0..m {
        q_data[i * m + i] = 1.0;
    }
    // R starts as a row-major copy of the input.
    let mut r_data: Vec<f32> = self.data.iter().copied().collect();

    for k in 0..min_dim {
        // Householder vector built from column k, rows k..m.
        let mut v: Vec<f32> = (k..m).map(|i| r_data[i * n + k]).collect();
        let norm = v.iter().map(|x| x * x).sum::<f32>().sqrt();
        if norm == 0.0 {
            continue;
        }
        // Add the norm with the sign of v[0] to avoid cancellation.
        v[0] += if v[0] >= 0.0 { norm } else { -norm };
        let v_norm = v.iter().map(|x| x * x).sum::<f32>().sqrt();
        if v_norm == 0.0 {
            continue;
        }
        for component in v.iter_mut() {
            *component /= v_norm;
        }

        // R <- H * R, with H = I - 2 v v^T acting on rows k..m.
        for j in k..n {
            let mut dot_product = 0.0;
            for i in k..m {
                dot_product += v[i - k] * r_data[i * n + j];
            }
            for i in k..m {
                r_data[i * n + j] -= 2.0 * v[i - k] * dot_product;
            }
        }

        // Q <- Q * H: H acts on columns k..m of every row of Q.
        for row in 0..m {
            let mut dot_product = 0.0;
            for i in k..m {
                dot_product += q_data[row * m + i] * v[i - k];
            }
            for i in k..m {
                q_data[row * m + i] -= 2.0 * dot_product * v[i - k];
            }
        }
    }

    let q = F32Tensor::from_vec(q_data, &[m, m])?;
    let r = F32Tensor::from_vec(r_data, &[m, n])?;
    Ok((q, r))
}
/// Cholesky decomposition of a square matrix: returns the lower-triangular
/// factor `L` of shape `[n, n]` with `L * L^T` equal to the input.
///
/// Only the lower triangle of the input is read; symmetry is not checked.
/// The input is read in logical row-major order, so tensors in
/// non-standard layout no longer panic on `as_slice().unwrap()`.
///
/// # Errors
/// `InvalidParameters` when the tensor is not a square 2-D matrix, or when
/// a pivot is non-positive or NaN ("Matrix is not positive definite").
pub fn cholesky_decomposition(&self) -> RusTorchResult<Self> {
    if self.ndim() != 2 || self.shape[0] != self.shape[1] {
        return Err(RusTorchError::InvalidParameters {
            operation: "cholesky_decomposition".to_string(),
            message: "Cholesky decomposition requires square matrix".to_string(),
        });
    }
    let n = self.shape[0];
    let a_data: Vec<f32> = self.data.iter().copied().collect();
    let mut l_data = vec![0.0f32; n * n];
    for i in 0..n {
        for j in 0..=i {
            // Dot product of the already-computed parts of rows i and j.
            let mut sum = 0.0;
            for k in 0..j {
                sum += l_data[i * n + k] * l_data[j * n + k];
            }
            if i == j {
                let val = a_data[j * n + j] - sum;
                // A NaN pivot would slip past a plain `<= 0.0` check.
                if val.is_nan() || val <= 0.0 {
                    return Err(RusTorchError::InvalidParameters {
                        operation: "cholesky_decomposition".to_string(),
                        message: "Matrix is not positive definite".to_string(),
                    });
                }
                l_data[j * n + j] = val.sqrt();
            } else {
                l_data[i * n + j] = (a_data[i * n + j] - sum) / l_data[j * n + j];
            }
        }
    }
    F32Tensor::from_vec(l_data, &[n, n])
}
/// Estimates the dominant singular triplet of an `m x n` matrix.
///
/// Runs up to 100 rounds of power iteration on `A^T A`, starting from a
/// random vector, to obtain `v`; then `sigma = |A v|` and `u = A v / sigma`.
/// Returns `(u, s, v)` with `u` of shape `[m, 1]`, `s` a one-element tensor
/// and `v` of shape `[n, 1]`. This is not a full decomposition, and the
/// result depends on the random start.
///
/// # Errors
/// `InvalidParameters` when the tensor is not 2-D.
pub fn svd(&self) -> RusTorchResult<(Self, Self, Self)> {
    if self.ndim() != 2 {
        return Err(RusTorchError::InvalidParameters {
            operation: "svd".to_string(),
            message: "SVD requires 2D tensor".to_string(),
        });
    }
    let rows = self.shape[0];
    let cols = self.shape[1];

    let transposed = self.transpose()?;
    let gram = transposed.matmul(self)?;

    let mut v = F32Tensor::randn(&[cols, 1])?;
    for _ in 0..100 {
        let candidate = gram.matmul(&v)?;
        let length = candidate.data.iter().map(|x| x * x).sum::<f32>().sqrt();
        if length == 0.0 {
            break;
        }
        v = candidate.mul_scalar(1.0 / length)?;
    }

    let projected = self.matmul(&v)?;
    let sigma = projected.data.iter().map(|x| x * x).sum::<f32>().sqrt();
    let u = match sigma > 1e-10 {
        true => projected.mul_scalar(1.0 / sigma)?,
        // Degenerate case: no meaningful left vector, return zeros.
        false => F32Tensor::zeros(&[rows, 1])?,
    };
    let s = F32Tensor::from_scalar(sigma)?;
    Ok((u, s, v))
}
/// Estimates one eigenpair of a square matrix by power iteration.
///
/// Starting from a random `n×1` vector, up to 100 iterations are run. Each
/// iteration evaluates the Rayleigh quotient `(vᵀAv) / (vᵀv)` for the current
/// vector and then replaces the vector by `A·v` scaled to unit Euclidean
/// norm. Iteration stops early when the norm of `A·v` drops below `1e-10`.
///
/// Returns `(eigenvalue, eigenvector)` where the eigenvalue is a scalar
/// tensor holding the last Rayleigh quotient.
///
/// # Errors
///
/// Returns `RusTorchError::InvalidParameters` when the tensor is not a square
/// 2-D matrix; errors from the underlying tensor operations are propagated.
pub fn eigen_decomposition(&self) -> RusTorchResult<(Self, Self)> {
    let is_square = self.ndim() == 2 && self.shape[0] == self.shape[1];
    if !is_square {
        return Err(RusTorchError::InvalidParameters {
            operation: "eigen_decomposition".to_string(),
            message: "Eigenvalue decomposition requires square matrix".to_string(),
        });
    }
    let size = self.shape[0];
    let mut vector = F32Tensor::randn(&[size, 1])?;
    let mut eigenvalue = 0.0;
    for _ in 0..100 {
        let image = self.matmul(&vector)?;
        let numerator = vector.transpose()?.matmul(&image)?;
        let denominator = vector.transpose()?.matmul(&vector)?;
        eigenvalue = numerator.unwrap()? / denominator.unwrap()?;
        let length = image.data.iter().map(|x| x * x).sum::<f32>().sqrt();
        if length < 1e-10 {
            break;
        }
        vector = image.mul_scalar(1.0 / length)?;
    }
    Ok((F32Tensor::from_scalar(eigenvalue)?, vector))
}
/// Computes an LU decomposition with partial (row) pivoting.
///
/// Returns `(P, L, U)` such that `P * A = L * U`, where `P` is a permutation
/// matrix, `L` is unit lower triangular and `U` is upper triangular.
///
/// # Errors
///
/// Returns `RusTorchError::InvalidParameters` when the tensor is not a square
/// 2-D matrix, or when a pivot that has to be divided by is smaller than
/// `1e-10` in magnitude ("Matrix is singular"). The last pivot is never
/// divided by and is therefore not checked.
pub fn lu_decomposition(&self) -> RusTorchResult<(Self, Self, Self)> {
    if self.ndim() != 2 || self.shape[0] != self.shape[1] {
        return Err(RusTorchError::InvalidParameters {
            operation: "lu_decomposition".to_string(),
            message: "LU decomposition requires square matrix".to_string(),
        });
    }
    let n = self.shape[0];
    // Logical row-major copy; does not panic on non-contiguous storage.
    let mut a_data: Vec<f32> = self.data.iter().copied().collect();
    let mut l_data = vec![0.0f32; n * n];
    let mut p_data = vec![0.0f32; n * n];
    for i in 0..n {
        p_data[i * n + i] = 1.0;
        l_data[i * n + i] = 1.0;
    }
    for k in 0..n {
        // Pick the row with the largest magnitude in column k as pivot row.
        let mut max_row = k;
        let mut max_val = a_data[k * n + k].abs();
        for i in (k + 1)..n {
            let candidate = a_data[i * n + k].abs();
            if candidate > max_val {
                max_val = candidate;
                max_row = i;
            }
        }
        if max_row != k {
            for j in 0..n {
                a_data.swap(k * n + j, max_row * n + j);
                p_data.swap(k * n + j, max_row * n + j);
            }
            // The multipliers stored in earlier columns belong to the rows
            // being exchanged, so they have to move with them; otherwise
            // `P * A = L * U` no longer holds.
            for j in 0..k {
                l_data.swap(k * n + j, max_row * n + j);
            }
        }
        if k + 1 == n {
            // No rows left below the pivot: nothing to eliminate.
            break;
        }
        let pivot = a_data[k * n + k];
        if pivot.abs() < 1e-10 {
            return Err(RusTorchError::InvalidParameters {
                operation: "lu_decomposition".to_string(),
                message: "Matrix is singular".to_string(),
            });
        }
        for i in (k + 1)..n {
            let factor = a_data[i * n + k] / pivot;
            l_data[i * n + k] = factor;
            // The eliminated entry is zero by construction; store it exactly
            // instead of leaving floating-point residue below the diagonal.
            a_data[i * n + k] = 0.0;
            for j in (k + 1)..n {
                a_data[i * n + j] -= factor * a_data[k * n + j];
            }
        }
    }
    let p = F32Tensor::from_vec(p_data, &[n, n])?;
    let l = F32Tensor::from_vec(l_data, &[n, n])?;
    let u = F32Tensor::from_vec(a_data, &[n, n])?;
    Ok((p, l, u))
}
/// Computes the determinant of a square matrix from its LU decomposition.
///
/// The result is the product of the diagonal of `U` multiplied by the sign
/// of the row permutation `P` returned by `lu_decomposition`.
///
/// # Errors
///
/// Returns `RusTorchError::InvalidParameters` when the tensor is not a square
/// 2-D matrix. Errors from `lu_decomposition` (which reports a singular
/// matrix as an error) are propagated unchanged.
pub fn determinant(&self) -> RusTorchResult<f32> {
    if self.ndim() != 2 || self.shape[0] != self.shape[1] {
        return Err(RusTorchError::InvalidParameters {
            operation: "determinant".to_string(),
            message: "Determinant requires square matrix".to_string(),
        });
    }
    let (p, _l, u) = self.lu_decomposition()?;
    let n = self.shape[0];
    let u_data: Vec<f32> = u.data.iter().copied().collect();
    let p_data: Vec<f32> = p.data.iter().copied().collect();

    let diagonal_product = (0..n).fold(1.0f32, |acc, i| acc * u_data[i * n + i]);

    // Row i of P has its single 1 in column permutation[i].
    let permutation: Vec<usize> = (0..n)
        .map(|i| (0..n).find(|&j| p_data[i * n + j] == 1.0).unwrap_or(i))
        .collect();

    // The sign of a permutation is the product of the signs of its cycles;
    // a cycle of even length is an odd number of transpositions. Counting
    // displaced rows instead would give +1 for a single row swap.
    let mut visited = vec![false; n];
    let mut sign = 1.0f32;
    for start in 0..n {
        let mut length = 0usize;
        let mut current = start;
        while !visited[current] {
            visited[current] = true;
            current = permutation[current];
            length += 1;
        }
        if length > 0 && length % 2 == 0 {
            sign = -sign;
        }
    }
    Ok(diagonal_product * sign)
}
/// Inverts a square matrix by Gauss-Jordan elimination with row pivoting.
///
/// The routine builds the augmented matrix `[A | I]`, reduces the left half
/// to the identity and returns the right half.
///
/// # Errors
///
/// Returns `RusTorchError::InvalidParameters` when the tensor is not a square
/// 2-D matrix, or when a pivot is smaller than `1e-10` in magnitude
/// ("Matrix is singular").
///
/// # Panics
///
/// Panics if the tensor data cannot be viewed as a contiguous slice.
pub fn inverse(&self) -> RusTorchResult<Self> {
    let is_square = self.ndim() == 2 && self.shape[0] == self.shape[1];
    if !is_square {
        return Err(RusTorchError::InvalidParameters {
            operation: "inverse".to_string(),
            message: "Matrix inverse requires square matrix".to_string(),
        });
    }
    let n = self.shape[0];
    let width = 2 * n;
    let source = self.data.as_slice().unwrap();

    // Left half: copy of A. Right half: identity.
    let mut tableau = vec![0.0f32; n * width];
    for row in 0..n {
        let base = row * width;
        tableau[base..base + n].copy_from_slice(&source[row * n..(row + 1) * n]);
        tableau[base + n + row] = 1.0;
    }

    for col in 0..n {
        // Select the row at or below `col` with the largest entry in this column.
        let mut pivot_row = col;
        let mut best = tableau[col * width + col].abs();
        for candidate in (col + 1)..n {
            let magnitude = tableau[candidate * width + col].abs();
            if magnitude > best {
                best = magnitude;
                pivot_row = candidate;
            }
        }
        if pivot_row != col {
            // `pivot_row` is always below `col`, so the two rows lie on
            // opposite sides of the split point.
            let (upper, lower) = tableau.split_at_mut(pivot_row * width);
            upper[col * width..(col + 1) * width].swap_with_slice(&mut lower[..width]);
        }

        let pivot = tableau[col * width + col];
        if pivot.abs() < 1e-10 {
            return Err(RusTorchError::InvalidParameters {
                operation: "inverse".to_string(),
                message: "Matrix is singular".to_string(),
            });
        }

        // Scale the pivot row so that the pivot becomes 1.
        for value in &mut tableau[col * width..(col + 1) * width] {
            *value /= pivot;
        }
        let pivot_line = tableau[col * width..(col + 1) * width].to_vec();

        // Clear this column in every other row.
        for other in (0..n).filter(|&row| row != col) {
            let factor = tableau[other * width + col];
            let target = &mut tableau[other * width..(other + 1) * width];
            for (cell, &pivot_value) in target.iter_mut().zip(&pivot_line) {
                *cell -= factor * pivot_value;
            }
        }
    }

    // The right half of the tableau now holds the inverse.
    let inverse_data: Vec<f32> = (0..n)
        .flat_map(|row| tableau[row * width + n..(row + 1) * width].iter().copied())
        .collect();
    F32Tensor::from_vec(inverse_data, &[n, n])
}
/// Computes the numerical rank of a 2-D tensor.
///
/// The rank is the number of pivots larger than `1e-6` in magnitude found
/// by Gaussian elimination with complete pivoting, carried out in `f64` on
/// a copy of the data. The computation deliberately does not go through
/// `svd()`: that routine yields a single singular value, which would cap
/// the reported rank at 1.
///
/// # Errors
///
/// Returns `RusTorchError::InvalidParameters` when the tensor is not 2-D.
pub fn rank(&self) -> RusTorchResult<usize> {
    if self.ndim() != 2 {
        return Err(RusTorchError::InvalidParameters {
            operation: "rank".to_string(),
            message: "Rank calculation requires 2D tensor".to_string(),
        });
    }
    const TOLERANCE: f64 = 1e-6;
    let (rows, cols) = (self.shape[0], self.shape[1]);
    // Logical row-major working copy.
    let mut work: Vec<f64> = self.data.iter().map(|&x| f64::from(x)).collect();
    let mut rank = 0usize;
    while rank < rows.min(cols) {
        // Complete pivoting: largest remaining entry of the trailing block.
        let mut best = 0.0f64;
        let mut best_pos = (rank, rank);
        for r in rank..rows {
            for c in rank..cols {
                let magnitude = work[r * cols + c].abs();
                if magnitude > best {
                    best = magnitude;
                    best_pos = (r, c);
                }
            }
        }
        if best <= TOLERANCE {
            // Everything that is left is numerically zero.
            break;
        }
        let (pivot_row, pivot_col) = best_pos;
        if pivot_row != rank {
            for c in 0..cols {
                work.swap(rank * cols + c, pivot_row * cols + c);
            }
        }
        if pivot_col != rank {
            for r in 0..rows {
                work.swap(r * cols + rank, r * cols + pivot_col);
            }
        }
        let pivot = work[rank * cols + rank];
        for r in (rank + 1)..rows {
            let factor = work[r * cols + rank] / pivot;
            if factor != 0.0 {
                for c in (rank + 1)..cols {
                    work[r * cols + c] -= factor * work[rank * cols + c];
                }
            }
            work[r * cols + rank] = 0.0;
        }
        rank += 1;
    }
    Ok(rank)
}
pub fn condition_number(&self) -> RusTorchResult<f32> {
if self.ndim() != 2 {
return Err(RusTorchError::InvalidParameters {
operation: "condition_number".to_string(),
message: "Condition number requires 2D tensor".to_string(),
});
}
let (_u, s, _v) = self.svd()?;
let s_data = s.data.as_slice().unwrap();
if s_data.is_empty() {
return Ok(f32::INFINITY);
}
let max_singular = s_data.iter().fold(0.0f32, |a, &b| a.max(b));
let min_singular = s_data
.iter()
.filter(|&&x| x > 1e-10)
.fold(f32::INFINITY, |a, &b| a.min(b));
if min_singular == f32::INFINITY || min_singular == 0.0 {
Ok(f32::INFINITY)
} else {
Ok(max_singular / min_singular)
}
}
/// Returns the Frobenius norm: the square root of the sum of the squares of
/// all elements. Works for a tensor of any shape.
pub fn frobenius_norm(&self) -> RusTorchResult<f32> {
    let squared_total: f32 = self.data.iter().map(|&value| value * value).sum();
    Ok(squared_total.sqrt())
}
/// Returns the sum of the main-diagonal elements of a square matrix.
///
/// # Errors
///
/// Returns `RusTorchError::InvalidParameters` when the tensor is not a square
/// 2-D matrix.
///
/// # Panics
///
/// Panics if the tensor data cannot be viewed as a contiguous slice.
pub fn trace(&self) -> RusTorchResult<f32> {
    let is_square = self.ndim() == 2 && self.shape[0] == self.shape[1];
    if !is_square {
        return Err(RusTorchError::InvalidParameters {
            operation: "trace".to_string(),
            message: "Trace requires square matrix".to_string(),
        });
    }
    let size = self.shape[0];
    let elements = self.data.as_slice().unwrap();
    // Diagonal element i sits at flat offset i * size + i.
    let total = (0..size).fold(0.0, |acc, i| acc + elements[i * size + i]);
    Ok(total)
}
}
/// Read access to an element by its flat offset into the contiguous data.
impl Index<usize> for F32Tensor {
    type Output = f32;

    /// # Panics
    ///
    /// Panics if the data is not available as a contiguous slice or if
    /// `index` is past its end.
    fn index(&self, index: usize) -> &Self::Output {
        let elements = self.data.as_slice().unwrap();
        &elements[index]
    }
}
/// Write access to an element by its flat offset into the contiguous data.
impl IndexMut<usize> for F32Tensor {
    /// # Panics
    ///
    /// Panics if the data is not available as a contiguous slice or if
    /// `index` is past its end.
    fn index_mut(&mut self, index: usize) -> &mut Self::Output {
        let elements = self.data.as_slice_mut().unwrap();
        &mut elements[index]
    }
}
impl Index<Index2D> for F32Tensor {
type Output = f32;
fn index(&self, index: Index2D) -> &Self::Output {
let flat_index = index.0 * self.shape[1] + index.1;
&self.data.as_slice().unwrap()[flat_index]
}
}
impl IndexMut<Index2D> for F32Tensor {
fn index_mut(&mut self, index: Index2D) -> &mut Self::Output {
let flat_index = index.0 * self.shape[1] + index.1;
&mut self.data.as_slice_mut().unwrap()[flat_index]
}
}
impl F32Tensor {
/// Sums all elements and returns the result as a scalar tensor.
///
/// A unified GPU context is created and a device is selected and initialised
/// for a "reduction" workload of `numel()` elements. With `axis == None` the
/// value comes from `execute_gpu_reduction("sum")`; with `Some(_)` it comes
/// from `self.sum()`. The axis value itself is not used: both paths reduce
/// over every element.
///
/// # Errors
///
/// Propagates errors from device initialisation, the reduction and
/// `from_scalar`.
pub fn gpu_sum(&self, axis: Option<usize>) -> RusTorchResult<Self> {
    crate::hybrid_f32_experimental!();
    let mut gpu = crate::hybrid_f32::gpu::F32UnifiedGPUContext::new();
    let device = gpu.select_optimal_device("reduction", self.numel());
    gpu.initialize_device(device)?;
    let total = if axis.is_none() {
        self.execute_gpu_reduction("sum")?
    } else {
        self.sum()?
    };
    Self::from_scalar(total)
}
/// Averages all elements and returns the result as a scalar tensor.
///
/// A unified GPU context is created and a device is selected and initialised
/// for a "reduction" workload of `numel()` elements. With `axis == None` the
/// value comes from `execute_gpu_reduction("mean")`; with `Some(_)` it comes
/// from `self.mean()`. The axis value itself is not used: both paths reduce
/// over every element.
///
/// # Errors
///
/// Propagates errors from device initialisation, the reduction and
/// `from_scalar`.
pub fn gpu_mean(&self, axis: Option<usize>) -> RusTorchResult<Self> {
    crate::hybrid_f32_experimental!();
    let mut gpu = crate::hybrid_f32::gpu::F32UnifiedGPUContext::new();
    let device = gpu.select_optimal_device("reduction", self.numel());
    gpu.initialize_device(device)?;
    let average = if axis.is_none() {
        self.execute_gpu_reduction("mean")?
    } else {
        self.mean()?
    };
    Self::from_scalar(average)
}
/// Finds the smallest element and returns it as a scalar tensor.
///
/// A unified GPU context is created and a device is selected and initialised
/// for a "reduction" workload of `numel()` elements. With `axis == None` the
/// value comes from `execute_gpu_reduction("min")`; with `Some(_)` it comes
/// from `self.min()`. The axis value itself is not used: both paths reduce
/// over every element.
///
/// # Errors
///
/// Propagates errors from device initialisation, the reduction and
/// `from_scalar`.
pub fn gpu_min(&self, axis: Option<usize>) -> RusTorchResult<Self> {
    crate::hybrid_f32_experimental!();
    let mut gpu = crate::hybrid_f32::gpu::F32UnifiedGPUContext::new();
    let device = gpu.select_optimal_device("reduction", self.numel());
    gpu.initialize_device(device)?;
    let smallest = if axis.is_none() {
        self.execute_gpu_reduction("min")?
    } else {
        self.min()?
    };
    Self::from_scalar(smallest)
}
/// Finds the largest element and returns it as a scalar tensor.
///
/// A unified GPU context is created and a device is selected and initialised
/// for a "reduction" workload of `numel()` elements. With `axis == None` the
/// value comes from `execute_gpu_reduction("max")`; with `Some(_)` it comes
/// from `self.max()`. The axis value itself is not used: both paths reduce
/// over every element.
///
/// # Errors
///
/// Propagates errors from device initialisation, the reduction and
/// `from_scalar`.
pub fn gpu_max(&self, axis: Option<usize>) -> RusTorchResult<Self> {
    crate::hybrid_f32_experimental!();
    let mut gpu = crate::hybrid_f32::gpu::F32UnifiedGPUContext::new();
    let device = gpu.select_optimal_device("reduction", self.numel());
    gpu.initialize_device(device)?;
    let largest = if axis.is_none() {
        self.execute_gpu_reduction("max")?
    } else {
        self.max()?
    };
    Self::from_scalar(largest)
}
/// Computes the population standard deviation of all elements and returns
/// it as a scalar tensor.
///
/// A unified GPU context is created and a device is selected and initialised
/// for a "statistics" workload of `numel()` elements. With `axis == None`
/// the value comes from `execute_gpu_statistics("std")`; with `Some(_)` it is
/// computed inline as the square root of the mean squared deviation from
/// `self.mean()`. The axis value itself is not used: both paths cover every
/// element and divide by the element count.
///
/// # Errors
///
/// Returns a tensor-op error for an empty tensor; propagates errors from
/// device initialisation, `mean()` and `from_scalar`.
pub fn gpu_std(&self, axis: Option<usize>) -> RusTorchResult<Self> {
    crate::hybrid_f32_experimental!();
    if self.data.is_empty() {
        return Err(RusTorchError::tensor_op(
            "Cannot calculate std of empty tensor",
        ));
    }
    let mut gpu = crate::hybrid_f32::gpu::F32UnifiedGPUContext::new();
    let device = gpu.select_optimal_device("statistics", self.numel());
    gpu.initialize_device(device)?;
    let deviation = match axis {
        None => self.execute_gpu_statistics("std")?,
        Some(_) => {
            let mean_val = self.mean()?;
            let squared_sum = self
                .data
                .iter()
                .map(|&x| (x - mean_val).powi(2))
                .sum::<f32>();
            let variance = squared_sum / (self.data.len() as f32);
            variance.sqrt()
        }
    };
    Self::from_scalar(deviation)
}
/// Computes the population variance of all elements and returns it as a
/// scalar tensor.
///
/// A unified GPU context is created and a device is selected and initialised
/// for a "statistics" workload of `numel()` elements. With `axis == None`
/// the value comes from `execute_gpu_statistics("variance")`; with `Some(_)`
/// it is computed inline as the mean squared deviation from `self.mean()`.
/// The axis value itself is not used: both paths cover every element and
/// divide by the element count.
///
/// # Errors
///
/// Returns a tensor-op error for an empty tensor; propagates errors from
/// device initialisation, `mean()` and `from_scalar`.
pub fn gpu_var(&self, axis: Option<usize>) -> RusTorchResult<Self> {
    crate::hybrid_f32_experimental!();
    if self.data.is_empty() {
        return Err(RusTorchError::tensor_op(
            "Cannot calculate var of empty tensor",
        ));
    }
    let mut gpu = crate::hybrid_f32::gpu::F32UnifiedGPUContext::new();
    let device = gpu.select_optimal_device("statistics", self.numel());
    gpu.initialize_device(device)?;
    let spread = match axis {
        None => self.execute_gpu_statistics("variance")?,
        Some(_) => {
            let mean_val = self.mean()?;
            let squared_sum = self
                .data
                .iter()
                .map(|&x| (x - mean_val).powi(2))
                .sum::<f32>();
            squared_sum / (self.data.len() as f32)
        }
    };
    Self::from_scalar(spread)
}
/// Runs the named whole-tensor reduction and returns the scalar result.
///
/// Supported names are `"sum"`, `"mean"`, `"min"` and `"max"`. For these a
/// log line with the operation name and element count is printed to stdout,
/// and the value is then obtained from the tensor's own `sum()`, `mean()`,
/// `min()` or `max()` method.
///
/// # Errors
///
/// Returns a tensor-op error (without printing) for any other name, and
/// propagates errors from the delegated method.
fn execute_gpu_reduction(&self, operation: &str) -> RusTorchResult<f32> {
    let value = match operation {
        "sum" | "mean" | "min" | "max" => {
            println!(
                "🚀 GPU並列リダクション: {} (size={})",
                operation,
                self.numel()
            );
            match operation {
                "sum" => self.sum()?,
                "mean" => self.mean()?,
                "min" => self.min()?,
                _ => self.max()?,
            }
        }
        _ => {
            return Err(RusTorchError::tensor_op(&format!(
                "Unsupported reduction operation: {}",
                operation
            )))
        }
    };
    Ok(value)
}
/// Computes the named whole-tensor statistic and returns the scalar result.
///
/// Supported names are `"std"` and `"variance"`. For these a log line with
/// the operation name and element count is printed to stdout, then the
/// population variance (mean squared deviation from `self.mean()`, divided
/// by the element count) is computed; `"std"` returns its square root.
///
/// # Errors
///
/// Returns a tensor-op error (without printing) for any other name, and
/// propagates errors from `mean()`.
fn execute_gpu_statistics(&self, operation: &str) -> RusTorchResult<f32> {
    let take_root = match operation {
        "std" => true,
        "variance" => false,
        _ => {
            return Err(RusTorchError::tensor_op(&format!(
                "Unsupported statistics operation: {}",
                operation
            )))
        }
    };
    println!(
        "🧠 Neural Engine統計処理: {} (size={})",
        operation,
        self.numel()
    );
    let mean_val = self.mean()?;
    let squared_sum = self
        .data
        .iter()
        .map(|&x| (x - mean_val).powi(2))
        .sum::<f32>();
    let variance = squared_sum / (self.data.len() as f32);
    Ok(if take_root { variance.sqrt() } else { variance })
}
/// Python-style alias for addition: forwards to `add` with `other` and
/// returns its result unchanged.
pub fn __add__(&self, other: &Self) -> RusTorchResult<Self> {
// NOTE(review): which `add` this resolves to (an inherent method or one of
// the `std::ops::Add` impls) is decided by method lookup on `&Self` — confirm
// before changing the call form.
self.add(other)
}
/// Python-style alias for multiplication: forwards to `multiply` with
/// `other` and returns its result unchanged.
pub fn __mul__(&self, other: &Self) -> RusTorchResult<Self> {
// NOTE(review): presumably `multiply` is element-wise rather than a matrix
// product — confirm against its definition.
self.multiply(other)
}
}
impl Index<Index3D> for F32Tensor {
type Output = f32;
fn index(&self, index: Index3D) -> &Self::Output {
let flat_index =
index.0 * (self.shape[1] * self.shape[2]) + index.1 * self.shape[2] + index.2;
&self.data.as_slice().unwrap()[flat_index]
}
}
impl IndexMut<Index3D> for F32Tensor {
fn index_mut(&mut self, index: Index3D) -> &mut Self::Output {
let flat_index =
index.0 * (self.shape[1] * self.shape[2]) + index.1 * self.shape[2] + index.2;
&mut self.data.as_slice_mut().unwrap()[flat_index]
}
}