use super::aligned::{SimdAllocator, SIMD_ALIGNMENT};
use crate::error::{RusTorchError, RusTorchResult};
use crate::tensor::Tensor;
use num_traits::Float;
/// Snapshot of how a tensor's element buffer is stored.
///
/// Produced by [`MemoryOptimization::memory_info`]. All values describe the
/// buffer at the moment the snapshot was taken; they are not kept in sync with
/// the tensor afterwards.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TensorMemoryInfo {
    /// Number of elements in the buffer.
    pub total_elements: usize,
    /// Size in bytes of a single element (`size_of::<T>()`).
    pub element_size: usize,
    /// `total_elements * element_size`.
    pub total_bytes: usize,
    /// Whether the elements are stored in standard (row-major, contiguous) order.
    pub is_contiguous: bool,
    /// Largest power-of-two alignment (in bytes) observed for the buffer's
    /// start address; `1` if none of the probed alignments match.
    pub alignment: usize,
    /// Whether the tensor reports itself as residing on a GPU.
    pub is_on_gpu: bool,
    /// Human-readable device name as reported by the tensor.
    pub device: String,
}
/// Inspection and layout-optimization hooks for containers of `T` elements.
pub trait MemoryOptimization<T: Float> {
    /// Describes how the backing element buffer is currently stored
    /// (size, contiguity, observed alignment, device).
    fn memory_info(&self) -> TensorMemoryInfo;

    /// Reports whether [`optimize_memory`](Self::optimize_memory) considers
    /// this value a candidate for re-layout.
    fn can_optimize_memory(&self) -> bool;

    /// Returns a copy of `self`, re-laid-out in memory when the implementation
    /// deems it worthwhile.
    fn optimize_memory(&self) -> Self;

    /// Fallible counterpart of [`optimize_memory`](Self::optimize_memory).
    ///
    /// # Errors
    ///
    /// Implementations may refuse the optimization and return an error; the
    /// `Tensor` implementation does so for inputs above a size limit.
    fn try_optimize_memory(&self) -> RusTorchResult<Self>
    where
        Self: Sized;
}
impl<T: Float + Clone + 'static> MemoryOptimization<T> for Tensor<T> {
    /// Inspects the tensor's element buffer.
    ///
    /// `alignment` is `SIMD_ALIGNMENT` when `SimdAllocator::is_aligned` accepts
    /// the buffer's start pointer; otherwise it is the first of 16, 8, 4 that
    /// divides the start address, or 1 if none does. It reflects where the
    /// buffer happens to live, not a guarantee made by any allocator.
    fn memory_info(&self) -> TensorMemoryInfo {
        let element_size = std::mem::size_of::<T>();
        let total_elements = self.data.len();
        let ptr = self.data.as_ptr();

        let alignment = if SimdAllocator::is_aligned(ptr) {
            SIMD_ALIGNMENT
        } else {
            let addr = ptr as usize;
            [16usize, 8, 4]
                .iter()
                .copied()
                .find(|&align| addr % align == 0)
                .unwrap_or(1)
        };

        TensorMemoryInfo {
            total_elements,
            element_size,
            total_bytes: total_elements * element_size,
            is_contiguous: self.data.is_standard_layout(),
            alignment,
            is_on_gpu: self.is_on_gpu(),
            device: self.device_type().to_string(),
        }
    }

    /// A tensor is a candidate when it is larger than 1 KiB and its buffer is
    /// not already SIMD-aligned.
    fn can_optimize_memory(&self) -> bool {
        let info = self.memory_info();
        info.total_bytes > 1024 && info.alignment < SIMD_ALIGNMENT
    }

    /// Returns a copy of the tensor with its elements stored contiguously in
    /// row-major (standard) order.
    ///
    /// Tensors rejected by [`can_optimize_memory`](Self::can_optimize_memory)
    /// are returned as a plain clone. A plain clone is also returned if the
    /// tensor cannot be rebuilt from the copied elements, so this never panics.
    ///
    /// Why the buffer is not taken from `SimdAllocator`: a `Vec<T>` always
    /// frees its storage with `align_of::<T>()`. Wrapping an over-aligned
    /// allocation with `Vec::from_raw_parts` therefore deallocates with a
    /// mismatched layout, which is undefined behavior. Until tensor storage can
    /// own a custom-aligned buffer, this method guarantees contiguity only. The
    /// resulting alignment is whatever the global allocator provides.
    fn optimize_memory(&self) -> Self {
        if !self.can_optimize_memory() {
            return self.clone();
        }

        // Walk the elements in logical order instead of memcpy-ing from
        // `as_ptr()`. A raw copy would scramble strided/permuted layouts, and
        // with negative strides it would read outside the allocation.
        let contiguous: Vec<T> = self.data.iter().cloned().collect();
        Self::try_from_vec(contiguous, self.shape().to_vec()).unwrap_or_else(|_| self.clone())
    }

    /// Like [`optimize_memory`](Self::optimize_memory), but refuses tensors
    /// whose buffer exceeds `MAX_OPTIMIZE_SIZE` bytes.
    ///
    /// # Errors
    ///
    /// Returns `RusTorchError::TensorOp` when `total_bytes` is above the limit.
    fn try_optimize_memory(&self) -> RusTorchResult<Self> {
        /// Upper bound (in bytes) on buffers we are willing to copy.
        const MAX_OPTIMIZE_SIZE: usize = 1_000_000_000;

        let info = self.memory_info();
        if info.total_bytes > MAX_OPTIMIZE_SIZE {
            return Err(RusTorchError::TensorOp {
                message: format!(
                    "Tensor too large to optimize: {} bytes exceeds maximum of {} bytes",
                    info.total_bytes, MAX_OPTIMIZE_SIZE
                ),
                source: None,
            });
        }
        Ok(self.optimize_memory())
    }
}