use crate::*;
impl TransformOps for CpuStorage {
fn apply_assign<F>(&mut self, op: F)
where
F: Fn(f32) -> f32,
{
let data = self.data().read().unwrap().iter()
.map(|a| op(*a))
.collect();
self.set_data(data);
}
fn elementwise_op_assign<F>(&mut self, other: &Self, op: F)
where
F: Fn(f32, f32) -> f32,
{
let total_elements = self.shape().iter().product();
let mut result = vec![0.0; total_elements];
let self_binding = self.data();
let self_data = self_binding.read().unwrap();
let other_binding = other.data();
let other_data = other_binding.read().unwrap();
let rank = self.shape().len();
let shape = self.shape();
let self_strides = self.stride();
let other_strides = other.stride();
let mut chunk_size = 1;
let mut contiguous_dims = 0;
for dim in (0..rank).rev() {
if self_strides[dim] == chunk_size && other_strides[dim] == chunk_size {
chunk_size *= shape[dim];
contiguous_dims += 1;
} else {
break;
}
}
let outer_dims = rank - contiguous_dims;
let mut indices = vec![0; outer_dims];
let chunks = total_elements / chunk_size;
for chunk_idx in 0..chunks {
let mut self_base_idx = 0;
let mut other_base_idx = 0;
for (dim, &idx) in indices.iter().enumerate() {
self_base_idx += idx * self_strides[dim];
other_base_idx += idx * other_strides[dim];
}
let result_start = chunk_idx * chunk_size;
for i in 0..chunk_size {
let self_val = self_data[self_base_idx + i];
let other_val = other_data[other_base_idx + i];
result[result_start + i] = op(self_val, other_val);
}
for dim in (0..outer_dims).rev() {
indices[dim] += 1;
if indices[dim] < shape[dim] {
break;
}
indices[dim] = 0;
}
}
self.set_data(result);
}
fn reshape(&mut self, new_shape: Vec<usize>) {
self.set_shape(new_shape);
}
fn scalar_op_assign<F>(&mut self, scalar: f32, op: F)
where
F: Fn(f32, f32) -> f32,
{
let data = self.data().read().unwrap().iter()
.map(|a| op(*a, scalar))
.collect();
self.set_data(data);
}
fn permute(&mut self, dims: &[usize]) {
let self_shape = self.shape();
let shape = dims.iter().map(|&i| self_shape[i]).collect();
let self_stride = self.stride();
let stride = dims.iter().map(|&i| self_stride[i]).collect();
self.set_shape(shape);
self.set_stride(stride);
}
fn flatten(&mut self) {
let shape: Vec<usize> = vec![self.shape().iter().product()];
let stride = vec![1];
self.set_shape(shape);
self.set_stride(stride);
}
fn squeeze(&mut self) {
let shape: Vec<usize> = self.shape().to_owned().iter().filter(|&&x| x != 1).cloned().collect();
let stride = Self::compute_strides(&shape);
self.set_shape(shape);
self.set_stride(stride);
}
fn unsqueeze(&mut self, dim: usize) {
let mut shape: Vec<usize> = self.shape().to_owned();
shape.insert(dim, 1);
let stride = Self::compute_strides(&shape);
self.set_shape(shape);
self.set_stride(stride);
}
fn apply<F>(&self, op: F) -> Self
where
F: Fn(f32) -> f32,
{
let data = self.data().read().unwrap().iter()
.map(|a| op(*a))
.collect();
Self::new(data, self.shape().clone())
}
fn elementwise_op<F>(&self, other: &Self, op: F) -> Self
where
F: Fn(f32, f32) -> f32,
{
let total_elements = self.shape().iter().product();
let mut result = vec![0.0; total_elements];
let self_binding = self.data();
let self_data = self_binding.read().unwrap();
let other_binding = other.data();
let other_data = other_binding.read().unwrap();
let rank = self.shape().len();
let shape = self.shape();
let self_strides = self.stride();
let other_strides = other.stride();
let mut chunk_size = 1;
let mut contiguous_dims = 0;
for dim in (0..rank).rev() {
if self_strides[dim] == chunk_size && other_strides[dim] == chunk_size {
chunk_size *= shape[dim];
contiguous_dims += 1;
} else {
break;
}
}
let outer_dims = rank - contiguous_dims;
let mut indices = vec![0; outer_dims];
let chunks = total_elements / chunk_size;
for chunk_idx in 0..chunks {
let mut self_base_idx = 0;
let mut other_base_idx = 0;
for (dim, &idx) in indices.iter().enumerate() {
self_base_idx += idx * self_strides[dim];
other_base_idx += idx * other_strides[dim];
}
let result_start = chunk_idx * chunk_size;
for i in 0..chunk_size {
let self_val = self_data[self_base_idx + i];
let other_val = other_data[other_base_idx + i];
result[result_start + i] = op(self_val, other_val);
}
for dim in (0..outer_dims).rev() {
indices[dim] += 1;
if indices[dim] < shape[dim] {
break;
}
indices[dim] = 0;
}
}
Self::new(result, self.shape().clone())
}
fn scalar_op<F>(&self, scalar: f32, op: F) -> Self
where
F: Fn(f32, f32) -> f32,
{
let data = self.data().read().unwrap().iter()
.map(|a| op(*a, scalar))
.collect();
Self::new(data, self.shape().clone())
}
fn sum_dim(&self, dims: &[bool]) -> Self {
if self.shape().len() == 1 && self.shape()[0] == 1 {
return self.clone();
}
let mut new_shape: Vec<usize> = self.shape().iter()
.zip(dims.iter().chain(std::iter::repeat(&false)))
.filter_map(|(&dim, &should_sum)| if !should_sum { Some(dim) } else { None })
.collect();
if new_shape.is_empty() {
let sum: f32 = self.data().read().unwrap().iter().sum();
return Self::new(vec![sum], vec![1]);
}
if new_shape.is_empty() {
new_shape.push(1);
}
let mut result = vec![0.0; new_shape.iter().product()];
let mut sum = 0.0;
let binding = self.data();
let data = binding.read().unwrap();
for i in 0..data.len() {
sum += data[i];
}
result[0] = sum;
Self::new(result, new_shape)
}
fn transpose(&self) -> Self {
if self.shape().len() != 2 { panic!("Must be 2-D Tensor (Matrix)"); }
let mut shape = self.shape().to_owned();
shape.reverse();
let mut stride = self.stride().to_owned();
stride.reverse();
Self::create(self.data(), shape, stride)
}
fn broadcast(&self, new_shape: &[usize]) -> Self {
let broadcast_shape = self.compute_broadcast_shape(new_shape);
let broadcast_strides = self.compute_broadcast_strides(&broadcast_shape);
Self::create(self.data(), broadcast_shape, broadcast_strides)
}
fn compute_broadcast_shape(&self, target_shape: &[usize]) -> Vec<usize> {
let self_rank = self.shape().len();
let target_rank = target_shape.len();
let max_rank = std::cmp::max(self_rank, target_rank);
let self_padded = self.pad_shape(max_rank);
let mut result_shape = Vec::with_capacity(max_rank);
for i in 0..max_rank {
let self_dim = self_padded[i];
let target_dim = if i >= max_rank - target_rank {
target_shape[i - (max_rank - target_rank)]
} else {
1
};
if self_dim == target_dim {
result_shape.push(self_dim);
} else if self_dim == 1 {
result_shape.push(target_dim);
} else if target_dim == 1 {
result_shape.push(self_dim);
} else {
panic!(
"Incompatible broadcast dimensions: {} and {}",
self_dim, target_dim
)
}
}
result_shape
}
fn compute_broadcast_strides(&self, broadcast_shape: &[usize]) -> Vec<usize> {
let self_rank = self.shape().len();
let broadcast_rank = broadcast_shape.len();
let rank_diff = broadcast_rank - self_rank;
let mut new_strides = vec![0; broadcast_rank];
for i in 0..self_rank {
let broadcast_idx = i + rank_diff;
if broadcast_shape[broadcast_idx] == self.shape()[i] {
new_strides[broadcast_idx] = self.stride()[i];
} else if self.shape()[i] == 1 {
new_strides[broadcast_idx] = 0; } else {
panic!("Invalid broadcast shape")
}
}
new_strides
}
fn pad_shape(&self, target_rank: usize) -> Vec<usize> {
let mut padded = vec![1; target_rank];
let rank_diff = target_rank - self.shape().len();
padded[rank_diff..].copy_from_slice(self.shape());
padded
}
fn broadcast_tensors(a: &Self, b: &Self) -> (Self, Self) {
let broadcast_shape = a.compute_broadcast_shape(b.shape());
let broadcast_a = a.broadcast(&broadcast_shape);
let broadcast_b = b.broadcast(&broadcast_shape);
(broadcast_a, broadcast_b)
}
}