// ferrum_interfaces/tensor.rs

1//! Tensor abstraction with zero-copy and device-aware semantics
2//!
3//! This module provides the core tensor interface that abstracts over different
4//! ML frameworks (Candle, ONNX Runtime, etc.) while maintaining zero-copy
5//! semantics and device information.
6
7use ferrum_types::{DataType, Device, Result};
8use std::any::Any;
9use std::sync::Arc;
10
11/// Core tensor trait for zero-copy, device-aware operations
12pub trait TensorLike: Send + Sync + std::fmt::Debug {
13    /// Downcast support for backend-specific fast paths
14    fn as_any(&self) -> &dyn Any;
15
16    /// Get tensor shape
17    fn shape(&self) -> &[usize];
18
19    /// Get tensor data type
20    fn dtype(&self) -> DataType;
21
22    /// Get device where tensor resides
23    fn device(&self) -> Device;
24
25    /// Get total number of elements
26    fn numel(&self) -> usize {
27        self.shape().iter().product()
28    }
29
30    /// Get number of dimensions
31    fn ndim(&self) -> usize {
32        self.shape().len()
33    }
34
35    /// Check if tensor is scalar (0-dimensional)
36    fn is_scalar(&self) -> bool {
37        self.shape().is_empty()
38    }
39
40    /// Check if tensor is contiguous in memory
41    fn is_contiguous(&self) -> bool;
42
43    /// Get size in bytes for this tensor
44    fn size_bytes(&self) -> usize {
45        self.numel() * self.dtype().size_bytes()
46    }
47
48    /// Create a view/slice of this tensor
49    fn view(&self, start: &[usize], end: &[usize]) -> Result<TensorRef>;
50
51    /// Reshape tensor to new shape (must have same number of elements)
52    fn reshape(&self, shape: &[usize]) -> Result<TensorRef>;
53
54    /// Convert tensor to CPU device
55    fn to_cpu(&self) -> Result<TensorRef>;
56
57    /// Convert tensor to specific device  
58    fn to_device(&self, device: &Device) -> Result<TensorRef>;
59
60    /// Convert tensor to specific data type
61    fn to_dtype(&self, dtype: DataType) -> Result<TensorRef>;
62
63    /// Extract tensor data as Vec<f32> (for logits sampling)
64    /// This is a convenience method for backends that need to extract data
65    fn to_vec_f32(&self) -> Result<Vec<f32>> {
66        // Default implementation returns error - backends should override
67        Err(crate::FerrumError::model(
68            "to_vec_f32 not implemented for this tensor backend",
69        ))
70    }
71
72    /// Extract tensor data as Vec<u32> (for token IDs)
73    /// This is a convenience method for backends that need to extract token data
74    fn to_vec_u32(&self) -> Result<Vec<u32>> {
75        // Default implementation returns error - backends should override
76        Err(crate::FerrumError::model(
77            "to_vec_u32 not implemented for this tensor backend",
78        ))
79    }
80
81    /// Fast path: argmax over the last dimension, returning the selected token id.
82    ///
83    /// Backends may override this to avoid transferring full logits to CPU.
84    fn argmax_last_dim_u32(&self) -> Result<u32> {
85        Err(crate::FerrumError::model(
86            "argmax_last_dim_u32 not implemented for this tensor backend",
87        ))
88    }
89}
90
/// Reference-counted tensor handle for zero-copy sharing.
///
/// Cloning a `TensorRef` only bumps the `Arc` refcount; the underlying
/// tensor storage is shared, not copied.
pub type TensorRef = Arc<dyn TensorLike>;
93
/// Tensor factory for creating tensors on specific backends.
///
/// Methods tagged `[MVP]` are the minimal required surface; methods tagged
/// `[Phase 2+]` are optional extensions.
pub trait TensorFactory: Send + Sync {
    /// Create an uninitialized tensor with the given shape/dtype (`[MVP]`)
    fn empty(&self, shape: &[usize], dtype: DataType, device: Device) -> Result<TensorRef>;
    /// Create a zero-filled tensor matching an existing tensor's metadata (`[MVP]`)
    fn zeros_like(&self, tensor: &TensorRef) -> Result<TensorRef>;
    /// Create a tensor from slice data (`[MVP]`)
    /// NOTE(review): data arrives as `f32` alongside a separate `dtype` —
    /// presumably backends convert when `dtype != F32`; confirm with implementations.
    fn from_slice(
        &self,
        data: &[f32],
        shape: &[usize],
        dtype: DataType,
        device: Device,
    ) -> Result<TensorRef>;
    /// Move a tensor to the target device (`[MVP]`)
    fn to_device(&self, tensor: &TensorRef, device: Device) -> Result<TensorRef>;
    /// Narrow (take a sliced view of) a tensor along one dimension (`[MVP]`)
    fn narrow(
        &self,
        tensor: &TensorRef,
        dim: usize,
        start: usize,
        length: usize,
    ) -> Result<TensorRef>;
    /// Reshape a tensor (`[MVP]`)
    fn reshape(&self, tensor: &TensorRef, shape: &[usize]) -> Result<TensorRef>;

    /// Create tensor filled with zeros (`[Phase 2+]`, optional)
    fn zeros(&self, shape: &[usize], dtype: DataType, device: &Device) -> Result<TensorRef>;

    /// Create tensor filled with ones (`[Phase 2+]`)
    fn ones(&self, shape: &[usize], dtype: DataType, device: &Device) -> Result<TensorRef>;

    /// Create tensor from uniform random distribution (`[Phase 2+]`)
    fn uniform(
        &self,
        shape: &[usize],
        low: f32,
        high: f32,
        dtype: DataType,
        device: &Device,
    ) -> Result<TensorRef>;

    /// Create tensor from normal distribution (`[Phase 2+]`)
    fn normal(
        &self,
        shape: &[usize],
        mean: f32,
        std: f32,
        dtype: DataType,
        device: &Device,
    ) -> Result<TensorRef>;

    /// Create tensor from existing tensor reference (may involve copying)
    /// NOTE(review): `[MVP]` methods take `Device` by value while `[Phase 2+]`
    /// methods take `&Device` — consider unifying in a future revision.
    fn from_tensor(&self, tensor: &TensorRef, device: &Device) -> Result<TensorRef>;
}
150
/// Basic tensor operations.
///
/// All methods take inputs by shared reference and return new tensors
/// wrapped in `Result`; implementations are backend-specific.
pub trait TensorOps: Send + Sync {
    /// Matrix multiplication
    fn matmul(&self, a: &TensorRef, b: &TensorRef) -> Result<TensorRef>;

    /// Element-wise addition
    fn add(&self, a: &TensorRef, b: &TensorRef) -> Result<TensorRef>;

    /// Element-wise subtraction
    fn sub(&self, a: &TensorRef, b: &TensorRef) -> Result<TensorRef>;

    /// Element-wise multiplication
    fn mul(&self, a: &TensorRef, b: &TensorRef) -> Result<TensorRef>;

    /// Element-wise division
    fn div(&self, a: &TensorRef, b: &TensorRef) -> Result<TensorRef>;

    /// Apply softmax along specified dimension
    /// (`dim` is `i32` — negative values presumably index from the end,
    /// as in PyTorch; confirm against backend implementations)
    fn softmax(&self, tensor: &TensorRef, dim: i32) -> Result<TensorRef>;

    /// Apply layer normalization (optional `bias`; `eps` guards the variance denominator)
    fn layer_norm(
        &self,
        input: &TensorRef,
        weight: &TensorRef,
        bias: Option<&TensorRef>,
        eps: f32,
    ) -> Result<TensorRef>;

    /// Apply RMS normalization
    fn rms_norm(&self, input: &TensorRef, weight: &TensorRef, eps: f32) -> Result<TensorRef>;

    /// Apply ReLU activation
    fn relu(&self, tensor: &TensorRef) -> Result<TensorRef>;

    /// Apply GELU activation
    fn gelu(&self, tensor: &TensorRef) -> Result<TensorRef>;

    /// Apply SiLU (Swish) activation
    fn silu(&self, tensor: &TensorRef) -> Result<TensorRef>;

    /// Concatenate tensors along specified dimension
    fn concat(&self, tensors: &[&TensorRef], dim: usize) -> Result<TensorRef>;

    /// Split tensor along specified dimension into chunks of the given sizes
    fn split(&self, tensor: &TensorRef, sizes: &[usize], dim: usize) -> Result<Vec<TensorRef>>;

    /// Transpose (swap) two tensor dimensions
    fn transpose(&self, tensor: &TensorRef, dim0: usize, dim1: usize) -> Result<TensorRef>;

    /// Permute tensor dimensions into the order given by `dims`
    fn permute(&self, tensor: &TensorRef, dims: &[usize]) -> Result<TensorRef>;
}
204
/// Asynchronous tensor operations (e.g. for GPU backends whose kernels
/// complete asynchronously).
#[async_trait::async_trait]
pub trait AsyncTensorOps: TensorOps {
    /// Asynchronous matrix multiplication
    async fn matmul_async(&self, a: &TensorRef, b: &TensorRef) -> Result<TensorRef>;

    /// Asynchronous softmax
    async fn softmax_async(&self, tensor: &TensorRef, dim: i32) -> Result<TensorRef>;

    /// Synchronize all pending operations
    /// (resolves once previously submitted asynchronous work has completed)
    async fn synchronize(&self) -> Result<()>;
}
217
/// Tensor batch operations for efficient processing
pub trait TensorBatchOps: Send + Sync {
    /// Batch matrix multiplication for multiple pairs
    /// (pairs are matched by index: `a_batch[i] x b_batch[i]`)
    fn batch_matmul(
        &self,
        a_batch: &[&TensorRef],
        b_batch: &[&TensorRef],
    ) -> Result<Vec<TensorRef>>;

    /// Stack tensors along a new dimension inserted at `dim`
    fn stack(&self, tensors: &[&TensorRef], dim: usize) -> Result<TensorRef>;

    /// Unstack tensor along specified dimension into individual tensors
    fn unstack(&self, tensor: &TensorRef, dim: usize) -> Result<Vec<TensorRef>>;

    /// Pad tensors in batch to the same `target_shape`
    fn pad_batch(&self, tensors: &[&TensorRef], target_shape: &[usize]) -> Result<Vec<TensorRef>>;
}
236
/// Device-specific tensor memory management
pub trait TensorMemoryManager: Send + Sync {
    /// Pre-allocate tensor of given shape for reuse
    fn preallocate(&self, shape: &[usize], dtype: DataType, device: &Device) -> Result<TensorRef>;

    /// Clear tensor data (set to zeros) without deallocation
    fn clear(&self, tensor: &TensorRef) -> Result<()>;

    /// Get memory usage statistics
    fn memory_stats(&self) -> TensorMemoryStats;

    /// Force garbage collection of unused tensors
    fn gc(&self) -> Result<()>;
}
251
/// Tensor memory usage statistics.
///
/// A plain value type: all counters are byte/count totals reported by a
/// [`TensorMemoryManager`]. Derives `Default` so an all-zero snapshot can be
/// constructed cheaply, and `PartialEq`/`Eq` so snapshots can be compared.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct TensorMemoryStats {
    /// Total allocated memory in bytes
    pub total_allocated: usize,
    /// Currently used memory in bytes
    pub used_memory: usize,
    /// Number of active tensor references
    pub active_tensors: usize,
    /// Peak memory usage in bytes
    pub peak_memory: usize,
}
264
/// Tensor data access for interop.
///
/// NOTE(review): unlike the other traits in this file, this trait is not
/// bounded by `Send + Sync` — confirm whether that is intentional. Also,
/// `to_vec_f32` here has the same name as `TensorLike::to_vec_f32`; a type
/// implementing both may require fully-qualified call syntax at call sites.
pub trait TensorDataAccess {
    /// Get read-only access to raw data (CPU only)
    /// Returns None if tensor is not on CPU or data is not contiguous
    fn data_f32(&self) -> Option<&[f32]>;

    /// Get read-only access to raw data as bytes
    fn data_bytes(&self) -> Option<&[u8]>;

    /// Copy tensor data to a Vec<f32> (may involve device-to-host transfer)
    fn to_vec_f32(&self) -> Result<Vec<f32>>;

    /// Copy tensor data to a Vec<u8>
    fn to_vec_u8(&self) -> Result<Vec<u8>>;
}
280
281/// Utility functions for tensor operations
282pub mod utils {
283    use super::*;
284
285    /// Calculate output shape for matrix multiplication
286    pub fn matmul_output_shape(a_shape: &[usize], b_shape: &[usize]) -> Result<Vec<usize>> {
287        if a_shape.len() < 2 || b_shape.len() < 2 {
288            return Err(ferrum_types::FerrumError::backend(
289                "Matrix multiplication requires at least 2D tensors",
290            ));
291        }
292
293        let a_rows = a_shape[a_shape.len() - 2];
294        let a_cols = a_shape[a_shape.len() - 1];
295        let b_rows = b_shape[b_shape.len() - 2];
296        let b_cols = b_shape[b_shape.len() - 1];
297
298        if a_cols != b_rows {
299            return Err(ferrum_types::FerrumError::backend(format!(
300                "Matrix dimensions mismatch: {} vs {}",
301                a_cols, b_rows
302            )));
303        }
304
305        let mut output_shape = a_shape[..a_shape.len() - 2].to_vec();
306        output_shape.push(a_rows);
307        output_shape.push(b_cols);
308
309        Ok(output_shape)
310    }
311
312    /// Check if shapes are broadcastable
313    pub fn are_broadcastable(shape1: &[usize], shape2: &[usize]) -> bool {
314        let max_ndim = shape1.len().max(shape2.len());
315
316        for i in 0..max_ndim {
317            let dim1 = shape1.get(shape1.len() - 1 - i).copied().unwrap_or(1);
318            let dim2 = shape2.get(shape2.len() - 1 - i).copied().unwrap_or(1);
319
320            if dim1 != dim2 && dim1 != 1 && dim2 != 1 {
321                return false;
322            }
323        }
324
325        true
326    }
327
328    /// Calculate output shape after broadcasting
329    pub fn broadcast_shapes(shape1: &[usize], shape2: &[usize]) -> Option<Vec<usize>> {
330        if !are_broadcastable(shape1, shape2) {
331            return None;
332        }
333
334        let max_ndim = shape1.len().max(shape2.len());
335        let mut output_shape = Vec::with_capacity(max_ndim);
336
337        for i in 0..max_ndim {
338            let dim1 = shape1.get(shape1.len() - 1 - i).copied().unwrap_or(1);
339            let dim2 = shape2.get(shape2.len() - 1 - i).copied().unwrap_or(1);
340
341            output_shape.push(dim1.max(dim2));
342        }
343
344        output_shape.reverse();
345        Some(output_shape)
346    }
347}