use ferrum_types::{DataType, Device, Result};
use std::any::Any;
use std::sync::Arc;
pub trait TensorLike: Send + Sync + std::fmt::Debug {
fn as_any(&self) -> &dyn Any;
fn shape(&self) -> &[usize];
fn dtype(&self) -> DataType;
fn device(&self) -> Device;
fn numel(&self) -> usize {
self.shape().iter().product()
}
fn ndim(&self) -> usize {
self.shape().len()
}
fn is_scalar(&self) -> bool {
self.shape().is_empty()
}
fn is_contiguous(&self) -> bool;
fn size_bytes(&self) -> usize {
self.numel() * self.dtype().size_bytes()
}
fn view(&self, start: &[usize], end: &[usize]) -> Result<TensorRef>;
fn reshape(&self, shape: &[usize]) -> Result<TensorRef>;
fn to_cpu(&self) -> Result<TensorRef>;
fn to_device(&self, device: &Device) -> Result<TensorRef>;
fn to_dtype(&self, dtype: DataType) -> Result<TensorRef>;
fn to_vec_f32(&self) -> Result<Vec<f32>> {
Err(crate::FerrumError::model(
"to_vec_f32 not implemented for this tensor backend",
))
}
fn to_vec_u32(&self) -> Result<Vec<u32>> {
Err(crate::FerrumError::model(
"to_vec_u32 not implemented for this tensor backend",
))
}
fn argmax_last_dim_u32(&self) -> Result<u32> {
Err(crate::FerrumError::model(
"argmax_last_dim_u32 not implemented for this tensor backend",
))
}
}
/// Shared, thread-safe handle to a backend tensor.
pub type TensorRef = Arc<dyn TensorLike>;
/// Constructs tensors for a specific backend.
///
/// NOTE(review): some methods take `Device` by value and others by
/// reference, and `reshape`/`to_device`/`narrow` overlap with methods on
/// `TensorLike` — presumably kept for backward compatibility; confirm
/// before consolidating.
pub trait TensorFactory: Send + Sync {
    /// Allocate a tensor of the given shape/dtype/device; contents unspecified.
    fn empty(&self, shape: &[usize], dtype: DataType, device: Device) -> Result<TensorRef>;

    /// Zero-filled tensor matching `tensor`'s shape, dtype, and device.
    fn zeros_like(&self, tensor: &TensorRef) -> Result<TensorRef>;

    /// Build a tensor from host data.
    ///
    /// NOTE(review): `data` is always `&[f32]` even when `dtype` differs —
    /// presumably the backend converts on upload; confirm.
    fn from_slice(
        &self,
        data: &[f32],
        shape: &[usize],
        dtype: DataType,
        device: Device,
    ) -> Result<TensorRef>;

    /// Transfer `tensor` to `device`.
    fn to_device(&self, tensor: &TensorRef, device: Device) -> Result<TensorRef>;

    /// Slice `length` elements starting at `start` along dimension `dim`.
    fn narrow(
        &self,
        tensor: &TensorRef,
        dim: usize,
        start: usize,
        length: usize,
    ) -> Result<TensorRef>;

    /// Reinterpret `tensor` with a new shape.
    fn reshape(&self, tensor: &TensorRef, shape: &[usize]) -> Result<TensorRef>;

    /// Zero-filled tensor.
    fn zeros(&self, shape: &[usize], dtype: DataType, device: &Device) -> Result<TensorRef>;

    /// One-filled tensor.
    fn ones(&self, shape: &[usize], dtype: DataType, device: &Device) -> Result<TensorRef>;

    /// Tensor sampled uniformly between `low` and `high` — TODO confirm
    /// the bound convention with the backend implementations.
    fn uniform(
        &self,
        shape: &[usize],
        low: f32,
        high: f32,
        dtype: DataType,
        device: &Device,
    ) -> Result<TensorRef>;

    /// Tensor sampled from a normal distribution with `mean` and `std`.
    fn normal(
        &self,
        shape: &[usize],
        mean: f32,
        std: f32,
        dtype: DataType,
        device: &Device,
    ) -> Result<TensorRef>;

    /// Re-create `tensor` on `device` (presumably a copy/import into this
    /// factory's backend — confirm against implementations).
    fn from_tensor(&self, tensor: &TensorRef, device: &Device) -> Result<TensorRef>;
}
/// Core mathematical operations over [`TensorRef`]s.
pub trait TensorOps: Send + Sync {
    /// Matrix product of `a` and `b`.
    fn matmul(&self, a: &TensorRef, b: &TensorRef) -> Result<TensorRef>;

    /// Element-wise addition.
    fn add(&self, a: &TensorRef, b: &TensorRef) -> Result<TensorRef>;

    /// Element-wise subtraction (`a - b`).
    fn sub(&self, a: &TensorRef, b: &TensorRef) -> Result<TensorRef>;

    /// Element-wise multiplication.
    fn mul(&self, a: &TensorRef, b: &TensorRef) -> Result<TensorRef>;

    /// Element-wise division (`a / b`).
    fn div(&self, a: &TensorRef, b: &TensorRef) -> Result<TensorRef>;

    /// Softmax along `dim` (`i32`, presumably allowing negative indexing
    /// from the end — confirm with implementations).
    fn softmax(&self, tensor: &TensorRef, dim: i32) -> Result<TensorRef>;

    /// Layer normalization with learned `weight` and optional `bias`;
    /// `eps` is the variance-stabilizing epsilon.
    fn layer_norm(
        &self,
        input: &TensorRef,
        weight: &TensorRef,
        bias: Option<&TensorRef>,
        eps: f32,
    ) -> Result<TensorRef>;

    /// RMS normalization with learned `weight`.
    fn rms_norm(&self, input: &TensorRef, weight: &TensorRef, eps: f32) -> Result<TensorRef>;

    /// ReLU activation.
    fn relu(&self, tensor: &TensorRef) -> Result<TensorRef>;

    /// GELU activation.
    fn gelu(&self, tensor: &TensorRef) -> Result<TensorRef>;

    /// SiLU (swish) activation.
    fn silu(&self, tensor: &TensorRef) -> Result<TensorRef>;

    /// Concatenate `tensors` along `dim`.
    fn concat(&self, tensors: &[&TensorRef], dim: usize) -> Result<TensorRef>;

    /// Split `tensor` along `dim` into pieces of the given `sizes`.
    fn split(&self, tensor: &TensorRef, sizes: &[usize], dim: usize) -> Result<Vec<TensorRef>>;

    /// Swap dimensions `dim0` and `dim1`.
    fn transpose(&self, tensor: &TensorRef, dim0: usize, dim1: usize) -> Result<TensorRef>;

    /// Reorder all dimensions according to `dims`.
    fn permute(&self, tensor: &TensorRef, dims: &[usize]) -> Result<TensorRef>;
}
/// Asynchronous variants of selected [`TensorOps`] operations.
///
/// Intended for backends with an asynchronous execution queue — TODO
/// confirm the exact completion semantics with an implementation.
#[async_trait::async_trait]
pub trait AsyncTensorOps: TensorOps {
    /// Async matrix product of `a` and `b`.
    async fn matmul_async(&self, a: &TensorRef, b: &TensorRef) -> Result<TensorRef>;

    /// Async softmax along `dim`.
    async fn softmax_async(&self, tensor: &TensorRef, dim: i32) -> Result<TensorRef>;

    /// Wait until previously submitted work has completed.
    async fn synchronize(&self) -> Result<()>;
}
/// Operations over batches of tensors.
pub trait TensorBatchOps: Send + Sync {
    /// Pairwise matrix products: one result per `(a_batch[i], b_batch[i])`
    /// pair — presumably the two batches must have equal length; confirm.
    fn batch_matmul(
        &self,
        a_batch: &[&TensorRef],
        b_batch: &[&TensorRef],
    ) -> Result<Vec<TensorRef>>;

    /// Stack `tensors` along a new dimension `dim`.
    fn stack(&self, tensors: &[&TensorRef], dim: usize) -> Result<TensorRef>;

    /// Inverse of `stack`: split `tensor` into per-slice tensors along `dim`.
    fn unstack(&self, tensor: &TensorRef, dim: usize) -> Result<Vec<TensorRef>>;

    /// Pad each tensor in the batch up to `target_shape`.
    fn pad_batch(&self, tensors: &[&TensorRef], target_shape: &[usize]) -> Result<Vec<TensorRef>>;
}
/// Backend memory management: pre-allocation, release, and statistics.
pub trait TensorMemoryManager: Send + Sync {
    /// Reserve a tensor up front with the given shape/dtype/device.
    fn preallocate(&self, shape: &[usize], dtype: DataType, device: &Device) -> Result<TensorRef>;

    /// Release/clear the given tensor's memory — TODO confirm whether this
    /// zeroes contents or returns the allocation to a pool.
    fn clear(&self, tensor: &TensorRef) -> Result<()>;

    /// Snapshot of current memory usage.
    fn memory_stats(&self) -> TensorMemoryStats;

    /// Trigger a garbage-collection / pool-compaction pass.
    fn gc(&self) -> Result<()>;
}
/// Memory usage snapshot reported by [`TensorMemoryManager::memory_stats`].
///
/// NOTE(review): units are presumably bytes for the memory fields —
/// confirm against the reporting backend.
#[derive(Debug, Clone)]
pub struct TensorMemoryStats {
    // Total memory allocated by the backend.
    pub total_allocated: usize,
    // Memory currently in use (as opposed to reserved/pooled).
    pub used_memory: usize,
    // Number of live tensors.
    pub active_tensors: usize,
    // High-water mark of memory usage.
    pub peak_memory: usize,
}
/// Direct host-side access to a tensor's raw data.
///
/// NOTE(review): unlike the other traits in this file this one is not
/// `Send + Sync`, and `to_vec_f32` duplicates the method name on
/// `TensorLike` — confirm both are intentional.
pub trait TensorDataAccess {
    /// Borrow the data as an `f32` slice, or `None` when that is not
    /// possible (e.g. not host-resident — TODO confirm the conditions).
    fn data_f32(&self) -> Option<&[f32]>;

    /// Borrow the raw bytes, or `None` when not possible.
    fn data_bytes(&self) -> Option<&[u8]>;

    /// Copy the contents out as `f32` values.
    fn to_vec_f32(&self) -> Result<Vec<f32>>;

    /// Copy the contents out as raw bytes.
    fn to_vec_u8(&self) -> Result<Vec<u8>>;
}
pub mod utils {
    //! Shape arithmetic shared by tensor backends.
    use super::*;

    /// Output shape of `a @ b` for tensors of rank >= 2.
    ///
    /// Leading (batch) dimensions are taken from `a_shape` unchanged; only
    /// the trailing two dimensions participate in the multiplication.
    ///
    /// # Errors
    /// Returns a backend error when either shape has fewer than two
    /// dimensions or the inner dimensions disagree.
    pub fn matmul_output_shape(a_shape: &[usize], b_shape: &[usize]) -> Result<Vec<usize>> {
        if a_shape.len() < 2 || b_shape.len() < 2 {
            return Err(ferrum_types::FerrumError::backend(
                "Matrix multiplication requires at least 2D tensors",
            ));
        }
        let a_rows = a_shape[a_shape.len() - 2];
        let a_cols = a_shape[a_shape.len() - 1];
        let b_rows = b_shape[b_shape.len() - 2];
        let b_cols = b_shape[b_shape.len() - 1];
        if a_cols != b_rows {
            return Err(ferrum_types::FerrumError::backend(format!(
                "Matrix dimensions mismatch: {} vs {}",
                a_cols, b_rows
            )));
        }
        // Batch dims from `a`, then the (rows, cols) of the product.
        let mut output_shape = a_shape[..a_shape.len() - 2].to_vec();
        output_shape.push(a_rows);
        output_shape.push(b_cols);
        Ok(output_shape)
    }

    /// Dimension of `shape` counted `i` places in from the trailing axis,
    /// padding with 1 (NumPy broadcasting rule) once `i` exceeds the rank.
    ///
    /// The checked subtraction matters: the previous formulation computed
    /// `shape.len() - 1 - i` unconditionally, which underflows `usize` —
    /// panicking in debug builds — whenever the two shapes being compared
    /// have different ranks or a shape is empty.
    fn trailing_dim(shape: &[usize], i: usize) -> usize {
        shape.len().checked_sub(1 + i).map_or(1, |idx| shape[idx])
    }

    /// Whether two shapes are compatible under NumPy-style broadcasting:
    /// aligned from the trailing axis, each dimension pair must be equal
    /// or contain a 1.
    pub fn are_broadcastable(shape1: &[usize], shape2: &[usize]) -> bool {
        let max_ndim = shape1.len().max(shape2.len());
        (0..max_ndim).all(|i| {
            let dim1 = trailing_dim(shape1, i);
            let dim2 = trailing_dim(shape2, i);
            dim1 == dim2 || dim1 == 1 || dim2 == 1
        })
    }

    /// Broadcast result shape (element-wise max of the aligned dimensions),
    /// or `None` when the shapes are incompatible.
    pub fn broadcast_shapes(shape1: &[usize], shape2: &[usize]) -> Option<Vec<usize>> {
        if !are_broadcastable(shape1, shape2) {
            return None;
        }
        let max_ndim = shape1.len().max(shape2.len());
        // Built trailing-axis-first, then reversed into conventional order.
        let mut output_shape: Vec<usize> = (0..max_ndim)
            .map(|i| trailing_dim(shape1, i).max(trailing_dim(shape2, i)))
            .collect();
        output_shape.reverse();
        Some(output_shape)
    }
}