use crate::{Tensor, TensorStorage};
use std::collections::HashMap;
use std::sync::{Arc, RwLock, Weak};
use torsh_core::{
device::DeviceType,
dtype::TensorElement,
error::{Result, TorshError},
shape::Shape,
};
/// A strided view over a tensor's underlying buffer.
///
/// A view reinterprets the parent tensor's data through its own `shape`,
/// `strides`, and `offset` without copying element-by-element at access
/// time. NOTE(review): for non-in-memory storage kinds the data is
/// snapshotted when the view is created — confirm that is intended.
#[derive(Debug, Clone)]
pub struct TensorView<T: TensorElement> {
    /// Shared handle to the view's backing-storage bookkeeping.
    storage: Arc<RwLock<ViewStorage<T>>>,
    /// Logical shape of the view.
    shape: Shape,
    /// Per-dimension stride, measured in elements (not bytes).
    strides: Vec<usize>,
    /// Flat offset (in elements) into the backing buffer where the view
    /// begins.
    offset: usize,
    /// Device the parent tensor lives on.
    device: DeviceType,
}
/// Internal bookkeeping shared by clones of a [`TensorView`].
#[derive(Debug)]
struct ViewStorage<T: TensorElement> {
    /// Weak back-reference to the parent buffer; observes whether the
    /// parent is still alive without keeping it alive.
    #[allow(dead_code)]
    parent: Weak<RwLock<Vec<T>>>,
    /// Strong handle to the data the view reads from; `None` means the
    /// data is no longer available.
    data_ref: Option<Arc<RwLock<Vec<T>>>>,
    /// Cache of derived sub-views keyed by (shape, strides, offset).
    /// NOTE(review): never populated anywhere in this file — confirm
    /// whether any external caller actually uses it.
    view_cache: HashMap<ViewKey, Arc<TensorView<T>>>,
    /// Informational count of active views on this storage.
    view_count: usize,
}
/// Cache key identifying a view by its (shape, strides, offset) triple.
#[derive(Debug, Hash, PartialEq, Eq, Clone)]
struct ViewKey {
    shape: Vec<usize>,
    strides: Vec<usize>,
    offset: usize,
}
impl<T: TensorElement + Copy> Tensor<T> {
    /// Row-major (C-contiguous) strides for `dims`: the last dimension
    /// has stride 1 and each earlier stride is the product of all later
    /// dimension sizes.
    ///
    /// Shared by `calculate_strides` and `create_view` so the two
    /// computations cannot drift apart (previously the same loop was
    /// duplicated inline in both places).
    fn contiguous_strides(dims: &[usize]) -> Vec<usize> {
        let mut strides = vec![1; dims.len()];
        for i in (0..dims.len().saturating_sub(1)).rev() {
            strides[i] = strides[i + 1] * dims[i + 1];
        }
        strides
    }

    /// Returns the row-major strides implied by this tensor's own shape.
    pub fn calculate_strides(&self) -> Vec<usize> {
        let shape_binding = self.shape();
        Self::contiguous_strides(shape_binding.dims())
    }

    /// Creates a contiguous (offset 0) view reinterpreting this tensor
    /// as `new_shape`.
    ///
    /// # Errors
    /// `InvalidOperation` if `new_shape` describes a different number of
    /// elements than this tensor holds.
    pub fn create_view(&self, new_shape: &[usize]) -> Result<TensorView<T>> {
        let new_numel = new_shape.iter().product::<usize>();
        if new_numel != self.numel() {
            return Err(TorshError::InvalidOperation(format!(
                "View shape {:?} has {} elements, but tensor has {} elements",
                new_shape,
                new_numel,
                self.numel()
            )));
        }
        let strides = Self::contiguous_strides(new_shape);
        self.create_view_with_strides(new_shape, &strides, 0)
    }

    /// Creates a view with caller-provided strides and offset 0.
    ///
    /// # Errors
    /// `InvalidOperation` if `new_shape` and `strides` differ in length.
    /// Out-of-range strides are not validated here; they are caught at
    /// element-access time (`get` / `to_vec`).
    pub fn view_with_strides(
        &self,
        new_shape: &[usize],
        strides: &[usize],
    ) -> Result<TensorView<T>> {
        if new_shape.len() != strides.len() {
            return Err(TorshError::InvalidOperation(
                "Shape and strides must have same length".to_string(),
            ));
        }
        self.create_view_with_strides(new_shape, strides, 0)
    }

    /// Creates a view of the half-open range `[start, end)` along `dim`,
    /// keeping every other dimension intact.
    ///
    /// # Errors
    /// `InvalidOperation` if `dim` is out of bounds, or the range is
    /// invalid (`start >= end` or `end > size`). Note that empty slices
    /// (`start == end`) are rejected by this check.
    pub fn slice(&self, dim: usize, start: usize, end: usize) -> Result<TensorView<T>> {
        let shape_binding = self.shape();
        let dims = shape_binding.dims();
        if dim >= dims.len() {
            return Err(TorshError::InvalidOperation(format!(
                "Dimension {} out of bounds for tensor with {} dimensions",
                dim,
                dims.len()
            )));
        }
        if start >= end || end > dims[dim] {
            return Err(TorshError::InvalidOperation(format!(
                "Invalid slice range [{}:{}] for dimension of size {}",
                start, end, dims[dim]
            )));
        }
        let mut new_shape = dims.to_vec();
        new_shape[dim] = end - start;
        // The slice keeps the parent's (contiguous) strides; skipping
        // `start` rows along `dim` is expressed purely via the offset.
        let strides = self.calculate_strides();
        let offset = start * strides[dim];
        self.create_view_with_strides(&new_shape, &strides, offset)
    }

    /// Builds the [`TensorView`] for `shape`/`strides`/`offset`.
    ///
    /// For `InMemory` storage the view shares the tensor's buffer. For
    /// every other storage kind the data is snapshotted into a fresh
    /// buffer, so such "views" will not observe later mutation of the
    /// parent — NOTE(review): confirm this copy-on-create semantics is
    /// intended for memory-mapped and SIMD storage.
    fn create_view_with_strides(
        &self,
        shape: &[usize],
        strides: &[usize],
        offset: usize,
    ) -> Result<TensorView<T>> {
        let data_ref = match &self.storage {
            TensorStorage::InMemory(data) => data.clone(),
            TensorStorage::MemoryMapped(_) => {
                let data = self.to_vec()?;
                Arc::new(RwLock::new(data))
            }
            #[cfg(feature = "simd")]
            TensorStorage::Aligned(data) => {
                let aligned_data = data.read().expect("lock should not be poisoned");
                let vec_data = aligned_data.as_slice().to_vec();
                Arc::new(RwLock::new(vec_data))
            }
            #[cfg(feature = "simd")]
            TensorStorage::SimdOptimized(storage) => {
                let vec_data = storage.as_slice().to_vec();
                Arc::new(RwLock::new(vec_data))
            }
        };
        let view_storage = ViewStorage {
            parent: Arc::downgrade(&data_ref),
            data_ref: Some(data_ref),
            view_cache: HashMap::new(),
            view_count: 1,
        };
        Ok(TensorView {
            storage: Arc::new(RwLock::new(view_storage)),
            shape: Shape::new(shape.to_vec()),
            strides: strides.to_vec(),
            offset,
            device: self.device,
        })
    }

    /// Creates an immutable alias (shared handle) of this tensor.
    pub fn alias(&self) -> TensorAlias<T> {
        TensorAlias {
            tensor: self.clone(),
            is_mutable: false,
        }
    }

    /// Creates an alias flagged as mutable. The flag is advisory: the
    /// alias still just holds a clone of the tensor handle.
    pub fn alias_mut(&mut self) -> TensorAlias<T> {
        TensorAlias {
            tensor: self.clone(),
            is_mutable: true,
        }
    }
}
impl<T: TensorElement + Copy> TensorView<T> {
    /// Logical shape of the view.
    pub fn shape(&self) -> &Shape {
        &self.shape
    }

    /// Per-dimension element strides.
    pub fn strides(&self) -> &[usize] {
        &self.strides
    }

    /// Flat element offset into the backing buffer.
    pub fn offset(&self) -> usize {
        self.offset
    }

    /// Materializes the view into an owned tensor with the same shape.
    pub fn to_tensor(&self) -> Result<Tensor<T>> {
        let data = self.to_vec()?;
        Tensor::from_data(data, self.shape.dims().to_vec(), self.device)
    }

    /// Copies the view's elements, in row-major logical order, into a
    /// freshly allocated `Vec`.
    ///
    /// # Errors
    /// `InvalidOperation` if the backing data is no longer available or
    /// a computed index falls outside the backing buffer.
    pub fn to_vec(&self) -> Result<Vec<T>> {
        let storage = self.storage.read().expect("lock should not be poisoned");
        if let Some(data_ref) = &storage.data_ref {
            let data = data_ref.read().expect("lock should not be poisoned");
            let mut result = Vec::with_capacity(self.shape.numel());
            self.extract_view_data(&data, &mut result, &mut vec![0; self.shape.ndim()], 0)?;
            Ok(result)
        } else {
            Err(TorshError::InvalidOperation(
                "View data no longer available".to_string(),
            ))
        }
    }

    /// Recursive walk over every logical index vector; at full depth the
    /// index is mapped through `offset + Σ idx·stride` and the element is
    /// appended to `result`.
    fn extract_view_data(
        &self,
        data: &[T],
        result: &mut Vec<T>,
        indices: &mut [usize],
        dim: usize,
    ) -> Result<()> {
        if dim == self.shape.ndim() {
            let flat_index = self.offset
                + indices
                    .iter()
                    .zip(self.strides.iter())
                    .map(|(&idx, &stride)| idx * stride)
                    .sum::<usize>();
            if flat_index < data.len() {
                result.push(data[flat_index]);
            } else {
                return Err(TorshError::InvalidOperation(
                    "View index out of bounds".to_string(),
                ));
            }
        } else {
            for i in 0..self.shape.dims()[dim] {
                indices[dim] = i;
                self.extract_view_data(data, result, indices, dim + 1)?;
            }
        }
        Ok(())
    }

    /// True if the strides are exactly the row-major strides of the
    /// view's shape (the offset is ignored, per the usual definition of
    /// contiguity).
    pub fn is_contiguous(&self) -> bool {
        let dims = self.shape.dims();
        let mut expected_strides = vec![1; dims.len()];
        for i in (0..dims.len().saturating_sub(1)).rev() {
            expected_strides[i] = expected_strides[i + 1] * dims[i + 1];
        }
        self.strides == expected_strides
    }

    /// Always true: this type is by definition a view.
    pub fn is_view(&self) -> bool {
        true
    }

    /// Reads a single element at the given multi-dimensional `indices`.
    ///
    /// # Errors
    /// `InvalidOperation` if the number of indices does not match the
    /// view's rank, an index exceeds its dimension, the backing data is
    /// gone, or the flat index falls outside the backing buffer.
    pub fn get(&self, indices: &[usize]) -> Result<T> {
        if indices.len() != self.shape.ndim() {
            return Err(TorshError::InvalidOperation(format!(
                "Expected {} indices, got {}",
                self.shape.ndim(),
                indices.len()
            )));
        }
        for (i, &idx) in indices.iter().enumerate() {
            if idx >= self.shape.dims()[i] {
                return Err(TorshError::InvalidOperation(format!(
                    "Index {} out of bounds for dimension {} (size {})",
                    idx,
                    i,
                    self.shape.dims()[i]
                )));
            }
        }
        let storage = self.storage.read().expect("lock should not be poisoned");
        if let Some(data_ref) = &storage.data_ref {
            let data = data_ref.read().expect("lock should not be poisoned");
            let flat_index = self.offset
                + indices
                    .iter()
                    .zip(self.strides.iter())
                    .map(|(&idx, &stride)| idx * stride)
                    .sum::<usize>();
            if flat_index < data.len() {
                Ok(data[flat_index])
            } else {
                Err(TorshError::InvalidOperation(
                    "View index out of bounds".to_string(),
                ))
            }
        } else {
            Err(TorshError::InvalidOperation(
                "View data no longer available".to_string(),
            ))
        }
    }

    /// Reports the view's memory footprint relative to its backing
    /// buffer.
    ///
    /// The storage lock is taken exactly ONCE and everything is derived
    /// from that single guard. Previously this method called a helper
    /// that re-acquired `self.storage.read()` while the read guard was
    /// already held here; `std::sync::RwLock` documents that recursive
    /// read acquisition may deadlock (e.g. when a writer is queued).
    pub fn view_memory_usage(&self) -> ViewMemoryUsage {
        let storage = self.storage.read().expect("lock should not be poisoned");
        let total_elements = storage
            .data_ref
            .as_ref()
            .map(|data| data.read().expect("lock should not be poisoned").len())
            .unwrap_or(0);
        let view_elements = self.shape.numel();
        ViewMemoryUsage {
            view_elements,
            total_elements,
            active_views: storage.view_count,
            is_contiguous: self.is_contiguous(),
            memory_efficiency: Self::memory_efficiency(view_elements, total_elements),
        }
    }

    /// Ratio of elements exposed by the view to elements in the backing
    /// buffer. Returns 0.0 when the backing data is gone; the previous
    /// implementation divided by a fallback of 1 in that case, reporting
    /// a nonsensical "efficiency" equal to the view size.
    fn memory_efficiency(view_elements: usize, total_elements: usize) -> f64 {
        if total_elements == 0 {
            0.0
        } else {
            view_elements as f64 / total_elements as f64
        }
    }
}
/// A handle to a tensor that records whether it was requested as
/// mutable (via `alias_mut`) or immutable (via `alias`).
#[derive(Debug, Clone)]
pub struct TensorAlias<T: TensorElement> {
    /// The aliased tensor handle; `ref_count` counts the strong `Arc`
    /// references to its underlying storage.
    tensor: Tensor<T>,
    /// Advisory flag: true when created via `alias_mut`.
    is_mutable: bool,
}
impl<T: TensorElement + Copy> TensorAlias<T> {
    /// Borrows the aliased tensor.
    pub fn tensor(&self) -> &Tensor<T> {
        &self.tensor
    }

    /// Whether this alias was created through `alias_mut`.
    pub fn is_mutable(&self) -> bool {
        self.is_mutable
    }

    /// Materializes an owned tensor handle from the alias.
    pub fn to_owned(&self) -> Result<Tensor<T>> {
        let owned = self.tensor.clone();
        Ok(owned)
    }

    /// Number of strong references to the underlying storage buffer,
    /// regardless of which storage variant backs the tensor.
    pub fn ref_count(&self) -> usize {
        let count = match &self.tensor.storage {
            TensorStorage::InMemory(data) => Arc::strong_count(data),
            TensorStorage::MemoryMapped(storage) => Arc::strong_count(storage),
            #[cfg(feature = "simd")]
            TensorStorage::Aligned(data) => Arc::strong_count(data),
            #[cfg(feature = "simd")]
            TensorStorage::SimdOptimized(storage) => Arc::strong_count(storage),
        };
        count
    }

    /// True when no other handle shares the underlying storage.
    pub fn is_unique(&self) -> bool {
        matches!(self.ref_count(), 1)
    }
}
/// Snapshot of a view's memory footprint relative to its backing buffer.
#[derive(Debug, Clone)]
pub struct ViewMemoryUsage {
    /// Number of elements the view exposes.
    pub view_elements: usize,
    /// Number of elements in the backing buffer (0 if the data is gone).
    pub total_elements: usize,
    /// Active-view count recorded by the storage bookkeeping.
    pub active_views: usize,
    /// Whether the view's strides are row-major for its shape.
    pub is_contiguous: bool,
    /// Ratio of `view_elements` to the backing buffer's element count.
    pub memory_efficiency: f64,
}
/// Explicit teardown of a view's storage bookkeeping.
///
/// The generic bound here must match the struct declaration exactly
/// (`T: TensorElement`): a `Drop` impl may not require stricter bounds
/// than the type definition (rustc error E0367) unless they are already
/// implied by a supertrait. The previous `+ Copy` bound is therefore
/// removed; dropping it is safe in either case.
impl<T: TensorElement> Drop for ViewStorage<T> {
    fn drop(&mut self) {
        // Dropping the map would release these entries anyway; clearing
        // explicitly makes the teardown order obvious.
        self.view_cache.clear();
        self.view_count = 0;
    }
}
#[cfg(test)]
mod tests {
    use crate::creation::*;

    #[test]
    fn test_tensor_view() {
        // Reshaping [2, 3, 4] (24 elements) into [6, 4] preserves numel.
        let base = ones::<f32>(&[2, 3, 4]).expect("ones creation should succeed");
        let reshaped = base
            .create_view(&[6, 4])
            .expect("create_view should succeed");
        assert_eq!(reshaped.shape().dims(), &[6, 4]);
        assert_eq!(reshaped.shape().numel(), 24);
    }

    #[test]
    fn test_tensor_slice() {
        // A 1-D range of 12 values can be viewed as a 3x4 matrix.
        let values = arange(0.0f32, 12.0, 1.0).expect("arange should succeed");
        let _matrix = values
            .create_view(&[3, 4])
            .expect("create_view should succeed");
    }

    #[test]
    fn test_tensor_squeeze_unsqueeze() {
        let base = ones::<f32>(&[1, 3, 1, 4]).expect("ones creation should succeed");

        // squeeze(0) removes only the leading singleton dimension.
        let squeezed = base.squeeze(0).expect("squeeze should succeed");
        assert_eq!(squeezed.shape().dims(), &[3, 1, 4]);

        // squeeze_all removes every singleton dimension.
        let compact = base.squeeze_all().expect("squeeze_all should succeed");
        assert_eq!(compact.shape().dims(), &[3, 4]);

        // unsqueeze(2) inserts a fresh singleton at position 2.
        let expanded = base.unsqueeze(2).expect("unsqueeze should succeed");
        assert_eq!(expanded.shape().dims(), &[1, 3, 1, 1, 4]);
    }

    #[test]
    fn test_tensor_permute() {
        let base = ones::<f32>(&[2, 3, 4]).expect("ones creation should succeed");
        let transposed = base.permute(&[2, 0, 1]).expect("permute should succeed");
        assert_eq!(transposed.shape().dims(), &[4, 2, 3]);
    }

    #[test]
    fn test_tensor_alias() {
        let base = ones::<f32>(&[10, 10]).expect("ones creation should succeed");
        let alias = base.alias();
        assert!(!alias.is_mutable());
        // The alias plus the original handle hold at least two strong refs.
        assert!(alias.ref_count() >= 2);
    }

    #[test]
    fn test_view_memory_usage() {
        let base = ones::<f32>(&[100, 100]).expect("ones creation should succeed");
        let view = base
            .create_view(&[1000, 10])
            .expect("create_view should succeed");
        let usage = view.view_memory_usage();
        assert_eq!(usage.view_elements, 10000);
        // A full reshape covers the entire backing buffer.
        assert_eq!(usage.memory_efficiency, 1.0);
    }
}