use anyhow::{Result, anyhow};
use std::sync::Arc;
use crate::block_manager::v2::memory::TorchTensor;
/// Memory layout classification produced by [`validate_tensor_strides`].
///
/// NOTE(review): the variant names presumably denote the order of
/// token / head / head-dim axes; confirm against the kernels consuming this.
// The variant names are part of the public API, so the all-caps spelling is
// kept rather than renamed to satisfy clippy.
#[allow(clippy::upper_case_acronyms)]
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum TensorFormat {
    /// Reported when the outermost stride is larger than the next one.
    NHD,
    /// Reported when the outermost stride is smaller than the next one.
    HND,
    /// No layout could be inferred from the strides.
    Unknown,
}
/// Validates the stride layout of every tensor in `tensors` and infers a
/// [`TensorFormat`] from the first two strides.
///
/// Each tensor must report at least two strides. Its strides must be
/// non-increasing from the first to the last entry; equal neighbouring
/// strides are accepted.
///
/// Only tensors whose shape has three or more dimensions contribute to the
/// inferred format. A tensor with `stride[0] > stride[1]` sets the result to
/// [`TensorFormat::NHD`]. A tensor with equal leading strides leaves the
/// previous result untouched. When several tensors qualify, the last one that
/// distinguishes its two leading strides wins. If none does,
/// [`TensorFormat::Unknown`] is returned.
///
/// # Errors
///
/// Returns an error if:
/// - `tensors` is empty;
/// - any tensor has fewer than two strides;
/// - any stride is larger than the stride immediately before it. The reported
///   position is the index of the offending stride.
pub fn validate_tensor_strides(tensors: &[Arc<dyn TorchTensor>]) -> Result<TensorFormat> {
    if tensors.is_empty() {
        return Err(anyhow!("Cannot validate empty tensor list"));
    }

    let mut format = TensorFormat::Unknown;
    for tensor in tensors {
        let stride = tensor.stride();
        let shape = tensor.shape();

        if stride.len() < 2 {
            return Err(anyhow!(
                "Tensor must have at least 2 dimensions, got stride: {:?}",
                stride
            ));
        }

        // Each window is (previous, current). The first stride has no
        // predecessor and can never violate the ordering, so the offending
        // stride's index is `pos + 1`.
        if let Some(pos) = stride.windows(2).position(|pair| pair[1] > pair[0]) {
            return Err(anyhow!(
                "Tensor strides must be monotonically decreasing (until inner dimension). \
                 Got stride: {:?} at position {}",
                stride,
                pos + 1
            ));
        }

        // NOTE(review): the check above already guarantees
        // `stride[0] >= stride[1]`, so the `HND` branch below can never be
        // taken. Confirm whether HND detection was meant to use a different
        // criterion.
        if shape.len() >= 3 {
            if stride[0] < stride[1] {
                format = TensorFormat::HND;
            } else if stride[0] > stride[1] {
                format = TensorFormat::NHD;
            }
        }
    }

    Ok(format)
}
/// Checks that every tensor in `tensors` has the same shape as the first one,
/// and returns that common shape.
///
/// # Errors
///
/// Returns an error if `tensors` is empty. It also returns an error for the
/// first tensor whose shape differs from the first tensor's shape; the
/// message includes both shapes.
pub fn validate_tensor_shapes(tensors: &[Arc<dyn TorchTensor>]) -> Result<Vec<usize>> {
    let (head, rest) = match tensors.split_first() {
        Some(parts) => parts,
        None => return Err(anyhow!("Cannot validate empty tensor list")),
    };

    let expected = head.shape();
    let mismatch = rest.iter().find(|candidate| candidate.shape() != expected);

    match mismatch {
        Some(offender) => Err(anyhow!(
            "All tensors must have the same shape. Expected {:?}, got {:?}",
            expected,
            offender.shape()
        )),
        None => Ok(expected),
    }
}
/// Returns the product of all dimensions in `shape`, i.e. the total element
/// count the shape describes. An empty shape yields `1`.
// NOTE(review): `dead_code` is suppressed here; confirm whether any caller
// still relies on this function.
#[allow(dead_code)]
pub fn determine_compressed_shape(shape: &[usize]) -> usize {
    shape.iter().fold(1, |count, &dim| count * dim)
}
#[cfg(test)]
mod tests {
    // TODO(review): no tests are defined yet. The validators need a test
    // implementation of `TorchTensor` to exercise their stride and shape checks.
}