1use ferrum_types::{DataType, Device, Result};
8use std::any::Any;
9use std::sync::Arc;
10
11pub trait TensorLike: Send + Sync + std::fmt::Debug {
13 fn as_any(&self) -> &dyn Any;
15
16 fn shape(&self) -> &[usize];
18
19 fn dtype(&self) -> DataType;
21
22 fn device(&self) -> Device;
24
25 fn numel(&self) -> usize {
27 self.shape().iter().product()
28 }
29
30 fn ndim(&self) -> usize {
32 self.shape().len()
33 }
34
35 fn is_scalar(&self) -> bool {
37 self.shape().is_empty()
38 }
39
40 fn is_contiguous(&self) -> bool;
42
43 fn size_bytes(&self) -> usize {
45 self.numel() * self.dtype().size_bytes()
46 }
47
48 fn view(&self, start: &[usize], end: &[usize]) -> Result<TensorRef>;
50
51 fn reshape(&self, shape: &[usize]) -> Result<TensorRef>;
53
54 fn to_cpu(&self) -> Result<TensorRef>;
56
57 fn to_device(&self, device: &Device) -> Result<TensorRef>;
59
60 fn to_dtype(&self, dtype: DataType) -> Result<TensorRef>;
62
63 fn to_vec_f32(&self) -> Result<Vec<f32>> {
66 Err(crate::FerrumError::model(
68 "to_vec_f32 not implemented for this tensor backend",
69 ))
70 }
71
72 fn to_vec_u32(&self) -> Result<Vec<u32>> {
75 Err(crate::FerrumError::model(
77 "to_vec_u32 not implemented for this tensor backend",
78 ))
79 }
80
81 fn argmax_last_dim_u32(&self) -> Result<u32> {
85 Err(crate::FerrumError::model(
86 "argmax_last_dim_u32 not implemented for this tensor backend",
87 ))
88 }
89}
90
/// Shared, reference-counted handle to a backend tensor.
///
/// `Arc` makes clones cheap (a refcount bump) and, because `TensorLike`
/// requires `Send + Sync`, the handle can cross thread boundaries.
pub type TensorRef = Arc<dyn TensorLike>;
93
/// Constructs tensors for a specific backend.
///
/// Consumers allocate through this trait so the concrete backend stays
/// hidden behind [`TensorRef`].
pub trait TensorFactory: Send + Sync {
    /// Allocate an uninitialized tensor with the given shape/dtype on `device`.
    fn empty(&self, shape: &[usize], dtype: DataType, device: Device) -> Result<TensorRef>;
    /// Allocate a zero-filled tensor matching `tensor`'s shape, dtype and device.
    fn zeros_like(&self, tensor: &TensorRef) -> Result<TensorRef>;
    /// Build a tensor from host `f32` data.
    ///
    /// NOTE(review): `data` is always `&[f32]` even when `dtype` differs —
    /// presumably the backend converts on upload; confirm with implementations.
    fn from_slice(
        &self,
        data: &[f32],
        shape: &[usize],
        dtype: DataType,
        device: Device,
    ) -> Result<TensorRef>;
    /// Produce a copy of `tensor` on `device`.
    fn to_device(&self, tensor: &TensorRef, device: Device) -> Result<TensorRef>;
    /// Slice `length` elements starting at `start` along dimension `dim`.
    fn narrow(
        &self,
        tensor: &TensorRef,
        dim: usize,
        start: usize,
        length: usize,
    ) -> Result<TensorRef>;
    /// Reinterpret `tensor` with a new shape.
    fn reshape(&self, tensor: &TensorRef, shape: &[usize]) -> Result<TensorRef>;

    /// Zero-filled tensor.
    ///
    /// NOTE(review): takes `&Device` while `empty`/`from_slice`/`to_device`
    /// take `Device` by value — consider unifying the signatures.
    fn zeros(&self, shape: &[usize], dtype: DataType, device: &Device) -> Result<TensorRef>;

    /// One-filled tensor.
    fn ones(&self, shape: &[usize], dtype: DataType, device: &Device) -> Result<TensorRef>;

    /// Tensor with elements drawn uniformly between `low` and `high`.
    ///
    /// NOTE(review): open/closed bounds are backend-defined here — confirm.
    fn uniform(
        &self,
        shape: &[usize],
        low: f32,
        high: f32,
        dtype: DataType,
        device: &Device,
    ) -> Result<TensorRef>;

    /// Tensor with elements drawn from a normal distribution
    /// parameterized by `mean` and `std`.
    fn normal(
        &self,
        shape: &[usize],
        mean: f32,
        std: f32,
        dtype: DataType,
        device: &Device,
    ) -> Result<TensorRef>;

    /// Re-create `tensor` on `device` (potentially crossing backends).
    fn from_tensor(&self, tensor: &TensorRef, device: &Device) -> Result<TensorRef>;
}
150
/// Core math operations over [`TensorRef`]s.
pub trait TensorOps: Send + Sync {
    /// Matrix multiplication of `a` and `b`.
    fn matmul(&self, a: &TensorRef, b: &TensorRef) -> Result<TensorRef>;

    /// Element-wise addition.
    fn add(&self, a: &TensorRef, b: &TensorRef) -> Result<TensorRef>;

    /// Element-wise subtraction.
    fn sub(&self, a: &TensorRef, b: &TensorRef) -> Result<TensorRef>;

    /// Element-wise multiplication.
    fn mul(&self, a: &TensorRef, b: &TensorRef) -> Result<TensorRef>;

    /// Element-wise division.
    fn div(&self, a: &TensorRef, b: &TensorRef) -> Result<TensorRef>;

    /// Softmax along `dim`.
    ///
    /// NOTE(review): `dim` is `i32`, presumably so negative values index
    /// from the end — confirm with implementations.
    fn softmax(&self, tensor: &TensorRef, dim: i32) -> Result<TensorRef>;

    /// Layer normalization with learned `weight` and optional `bias`;
    /// `eps` guards against division by zero.
    fn layer_norm(
        &self,
        input: &TensorRef,
        weight: &TensorRef,
        bias: Option<&TensorRef>,
        eps: f32,
    ) -> Result<TensorRef>;

    /// RMS normalization with learned `weight`.
    fn rms_norm(&self, input: &TensorRef, weight: &TensorRef, eps: f32) -> Result<TensorRef>;

    /// ReLU activation.
    fn relu(&self, tensor: &TensorRef) -> Result<TensorRef>;

    /// GELU activation.
    fn gelu(&self, tensor: &TensorRef) -> Result<TensorRef>;

    /// SiLU (swish) activation.
    fn silu(&self, tensor: &TensorRef) -> Result<TensorRef>;

    /// Concatenate `tensors` along existing dimension `dim`.
    fn concat(&self, tensors: &[&TensorRef], dim: usize) -> Result<TensorRef>;

    /// Split `tensor` into chunks of the given `sizes` along `dim`.
    fn split(&self, tensor: &TensorRef, sizes: &[usize], dim: usize) -> Result<Vec<TensorRef>>;

    /// Swap dimensions `dim0` and `dim1`.
    fn transpose(&self, tensor: &TensorRef, dim0: usize, dim1: usize) -> Result<TensorRef>;

    /// Reorder all dimensions according to `dims`.
    fn permute(&self, tensor: &TensorRef, dims: &[usize]) -> Result<TensorRef>;
}
204
/// Asynchronous variants of selected [`TensorOps`] operations.
#[async_trait::async_trait]
pub trait AsyncTensorOps: TensorOps {
    /// Async matrix multiplication of `a` and `b`.
    async fn matmul_async(&self, a: &TensorRef, b: &TensorRef) -> Result<TensorRef>;

    /// Async softmax along `dim`.
    async fn softmax_async(&self, tensor: &TensorRef, dim: i32) -> Result<TensorRef>;

    /// Wait for queued device work to complete.
    ///
    /// NOTE(review): exact scope (stream vs whole device) is
    /// backend-defined here — confirm with implementations.
    async fn synchronize(&self) -> Result<()>;
}
217
/// Operations over batches (groups) of tensors.
pub trait TensorBatchOps: Send + Sync {
    /// Pairwise matmul: one result per `(a_batch[i], b_batch[i])` pair.
    fn batch_matmul(
        &self,
        a_batch: &[&TensorRef],
        b_batch: &[&TensorRef],
    ) -> Result<Vec<TensorRef>>;

    /// Stack `tensors` along a new dimension `dim`.
    fn stack(&self, tensors: &[&TensorRef], dim: usize) -> Result<TensorRef>;

    /// Inverse of `stack`: split `tensor` into slices along `dim`.
    fn unstack(&self, tensor: &TensorRef, dim: usize) -> Result<Vec<TensorRef>>;

    /// Pad each tensor in the batch up to `target_shape`.
    ///
    /// NOTE(review): padding value and side are backend-defined here — confirm.
    fn pad_batch(&self, tensors: &[&TensorRef], target_shape: &[usize]) -> Result<Vec<TensorRef>>;
}
236
/// Memory-management hooks a backend may expose.
pub trait TensorMemoryManager: Send + Sync {
    /// Reserve a tensor up front (e.g. to warm an allocator pool).
    fn preallocate(&self, shape: &[usize], dtype: DataType, device: &Device) -> Result<TensorRef>;

    /// Release the resources behind `tensor`.
    ///
    /// NOTE(review): whether this zeroes or frees storage is
    /// backend-defined here — confirm with implementations.
    fn clear(&self, tensor: &TensorRef) -> Result<()>;

    /// Snapshot of current allocator statistics.
    fn memory_stats(&self) -> TensorMemoryStats;

    /// Trigger a garbage-collection / reclamation pass.
    fn gc(&self) -> Result<()>;
}
251
/// Allocator statistics reported by [`TensorMemoryManager::memory_stats`].
///
/// NOTE(review): memory fields are assumed to be in bytes — confirm
/// against the implementations that populate this struct.
#[derive(Debug, Clone)]
pub struct TensorMemoryStats {
    // Total memory allocated from the underlying device/heap.
    pub total_allocated: usize,
    // Memory currently in use by live tensors.
    pub used_memory: usize,
    // Number of live tensors tracked by the manager.
    pub active_tensors: usize,
    // High-water mark of memory usage.
    pub peak_memory: usize,
}
264
/// Direct host-side access to a tensor's storage.
///
/// NOTE(review): unlike the other traits in this file, this one has no
/// `Send + Sync` bound — confirm whether that is intentional. Its
/// `to_vec_f32` also overlaps with `TensorLike::to_vec_f32`; types
/// implementing both may need fully-qualified calls to disambiguate.
pub trait TensorDataAccess {
    /// Borrow the storage as `f32`s, if the backend/layout permits.
    fn data_f32(&self) -> Option<&[f32]>;

    /// Borrow the raw backing bytes, if available.
    fn data_bytes(&self) -> Option<&[u8]>;

    /// Copy the contents out as an owned `Vec<f32>`.
    fn to_vec_f32(&self) -> Result<Vec<f32>>;

    /// Copy the contents out as raw bytes.
    fn to_vec_u8(&self) -> Result<Vec<u8>>;
}
280
281pub mod utils {
283 use super::*;
284
285 pub fn matmul_output_shape(a_shape: &[usize], b_shape: &[usize]) -> Result<Vec<usize>> {
287 if a_shape.len() < 2 || b_shape.len() < 2 {
288 return Err(ferrum_types::FerrumError::backend(
289 "Matrix multiplication requires at least 2D tensors",
290 ));
291 }
292
293 let a_rows = a_shape[a_shape.len() - 2];
294 let a_cols = a_shape[a_shape.len() - 1];
295 let b_rows = b_shape[b_shape.len() - 2];
296 let b_cols = b_shape[b_shape.len() - 1];
297
298 if a_cols != b_rows {
299 return Err(ferrum_types::FerrumError::backend(format!(
300 "Matrix dimensions mismatch: {} vs {}",
301 a_cols, b_rows
302 )));
303 }
304
305 let mut output_shape = a_shape[..a_shape.len() - 2].to_vec();
306 output_shape.push(a_rows);
307 output_shape.push(b_cols);
308
309 Ok(output_shape)
310 }
311
312 pub fn are_broadcastable(shape1: &[usize], shape2: &[usize]) -> bool {
314 let max_ndim = shape1.len().max(shape2.len());
315
316 for i in 0..max_ndim {
317 let dim1 = shape1.get(shape1.len() - 1 - i).copied().unwrap_or(1);
318 let dim2 = shape2.get(shape2.len() - 1 - i).copied().unwrap_or(1);
319
320 if dim1 != dim2 && dim1 != 1 && dim2 != 1 {
321 return false;
322 }
323 }
324
325 true
326 }
327
328 pub fn broadcast_shapes(shape1: &[usize], shape2: &[usize]) -> Option<Vec<usize>> {
330 if !are_broadcastable(shape1, shape2) {
331 return None;
332 }
333
334 let max_ndim = shape1.len().max(shape2.len());
335 let mut output_shape = Vec::with_capacity(max_ndim);
336
337 for i in 0..max_ndim {
338 let dim1 = shape1.get(shape1.len() - 1 - i).copied().unwrap_or(1);
339 let dim2 = shape2.get(shape2.len() - 1 - i).copied().unwrap_or(1);
340
341 output_shape.push(dim1.max(dim2));
342 }
343
344 output_shape.reverse();
345 Some(output_shape)
346 }
347}