Skip to main content

ferrum_testkit/
tensor.rs

1//! Mock tensor and tensor factory for testing without any ML backend.
2
3use ferrum_interfaces::{TensorFactory, TensorRef};
4use ferrum_types::{DataType, Device, FerrumError, Result};
5use std::sync::Arc;
6
7/// A mock tensor that stores shape and optional f32 data.
8/// No GPU, no Candle — pure Rust.
9#[derive(Clone)]
10pub struct MockTensor {
11    shape: Vec<usize>,
12    dtype: DataType,
13    device: Device,
14    data_f32: Vec<f32>,
15}
16
17impl MockTensor {
18    /// Create a zero-filled tensor with given shape.
19    pub fn zeros(shape: &[usize], dtype: DataType) -> Self {
20        let numel: usize = shape.iter().product();
21        Self {
22            shape: shape.to_vec(),
23            dtype,
24            device: Device::CPU,
25            data_f32: vec![0.0; numel],
26        }
27    }
28
29    /// Create a tensor from f32 data.
30    pub fn from_f32(data: Vec<f32>, shape: &[usize]) -> Self {
31        Self {
32            shape: shape.to_vec(),
33            dtype: DataType::FP32,
34            device: Device::CPU,
35            data_f32: data,
36        }
37    }
38
39    /// Create a tensor from u32 token IDs (stored as f32 internally).
40    pub fn from_u32(data: &[u32], shape: &[usize]) -> Self {
41        Self {
42            shape: shape.to_vec(),
43            dtype: DataType::FP32,
44            device: Device::CPU,
45            data_f32: data.iter().map(|&v| v as f32).collect(),
46        }
47    }
48
49    /// Wrap as TensorRef.
50    pub fn into_ref(self) -> TensorRef {
51        Arc::new(self)
52    }
53}
54
55impl std::fmt::Debug for MockTensor {
56    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
57        f.debug_struct("MockTensor")
58            .field("shape", &self.shape)
59            .field("dtype", &self.dtype)
60            .finish()
61    }
62}
63
64impl ferrum_interfaces::TensorLike for MockTensor {
65    fn as_any(&self) -> &dyn std::any::Any {
66        self
67    }
68
69    fn shape(&self) -> &[usize] {
70        &self.shape
71    }
72
73    fn dtype(&self) -> DataType {
74        self.dtype
75    }
76
77    fn device(&self) -> Device {
78        self.device.clone()
79    }
80
81    fn is_contiguous(&self) -> bool {
82        true
83    }
84
85    fn view(&self, start: &[usize], end: &[usize]) -> Result<TensorRef> {
86        let new_shape: Vec<usize> = start.iter().zip(end.iter()).map(|(s, e)| e - s).collect();
87
88        // Compute strides for the original shape (row-major)
89        let ndim = self.shape.len();
90        let mut strides = vec![1usize; ndim];
91        for i in (0..ndim.saturating_sub(1)).rev() {
92            strides[i] = strides[i + 1] * self.shape[i + 1];
93        }
94
95        // Copy the viewed region
96        let new_numel: usize = new_shape.iter().product();
97        let mut data = Vec::with_capacity(new_numel);
98        let mut coords = start.to_vec();
99        loop {
100            // Compute flat index
101            let flat: usize = coords.iter().zip(strides.iter()).map(|(c, s)| c * s).sum();
102            data.push(self.data_f32[flat]);
103
104            // Increment coords (innermost first)
105            let mut dim = ndim - 1;
106            loop {
107                coords[dim] += 1;
108                if coords[dim] < end[dim] {
109                    break;
110                }
111                coords[dim] = start[dim];
112                if dim == 0 {
113                    // All done
114                    return Ok(MockTensor {
115                        shape: new_shape,
116                        dtype: self.dtype,
117                        device: self.device.clone(),
118                        data_f32: data,
119                    }
120                    .into_ref());
121                }
122                dim -= 1;
123            }
124        }
125    }
126
127    fn reshape(&self, shape: &[usize]) -> Result<TensorRef> {
128        let new_numel: usize = shape.iter().product();
129        if new_numel != self.data_f32.len() {
130            return Err(FerrumError::backend(format!(
131                "Cannot reshape {} elements to {:?}",
132                self.data_f32.len(),
133                shape
134            )));
135        }
136        Ok(MockTensor {
137            shape: shape.to_vec(),
138            dtype: self.dtype,
139            device: self.device.clone(),
140            data_f32: self.data_f32.clone(),
141        }
142        .into_ref())
143    }
144
145    fn to_cpu(&self) -> Result<TensorRef> {
146        Ok(self.clone().into_ref())
147    }
148
149    fn to_device(&self, _device: &Device) -> Result<TensorRef> {
150        Ok(self.clone().into_ref())
151    }
152
153    fn to_dtype(&self, dtype: DataType) -> Result<TensorRef> {
154        Ok(MockTensor {
155            shape: self.shape.clone(),
156            dtype,
157            device: self.device.clone(),
158            data_f32: self.data_f32.clone(),
159        }
160        .into_ref())
161    }
162
163    fn to_vec_f32(&self) -> Result<Vec<f32>> {
164        Ok(self.data_f32.clone())
165    }
166
167    fn to_vec_u32(&self) -> Result<Vec<u32>> {
168        Ok(self.data_f32.iter().map(|&v| v as u32).collect())
169    }
170
171    fn argmax_last_dim_u32(&self) -> Result<u32> {
172        self.data_f32
173            .iter()
174            .enumerate()
175            .max_by(|a, b| a.1.partial_cmp(b.1).unwrap_or(std::cmp::Ordering::Equal))
176            .map(|(i, _)| i as u32)
177            .ok_or_else(|| FerrumError::backend("Empty tensor"))
178    }
179}
180
/// Stateless [`TensorFactory`] that produces [`MockTensor`]s without any ML backend.
///
/// It is a unit struct, so it is trivially copyable and can be written either as
/// `MockTensorFactory` or `MockTensorFactory::default()`.
#[derive(Debug, Clone, Copy, Default)]
pub struct MockTensorFactory;
183
184impl TensorFactory for MockTensorFactory {
185    fn empty(&self, shape: &[usize], dtype: DataType, _device: Device) -> Result<TensorRef> {
186        Ok(MockTensor::zeros(shape, dtype).into_ref())
187    }
188
189    fn zeros_like(&self, tensor: &TensorRef) -> Result<TensorRef> {
190        Ok(MockTensor::zeros(tensor.shape(), tensor.dtype()).into_ref())
191    }
192
193    fn from_slice(
194        &self,
195        data: &[f32],
196        shape: &[usize],
197        _dtype: DataType,
198        _device: Device,
199    ) -> Result<TensorRef> {
200        Ok(MockTensor::from_f32(data.to_vec(), shape).into_ref())
201    }
202
203    fn to_device(&self, tensor: &TensorRef, _device: Device) -> Result<TensorRef> {
204        Ok(MockTensor::zeros(tensor.shape(), tensor.dtype()).into_ref())
205    }
206
207    fn narrow(
208        &self,
209        tensor: &TensorRef,
210        dim: usize,
211        start: usize,
212        length: usize,
213    ) -> Result<TensorRef> {
214        let mut new_shape = tensor.shape().to_vec();
215        if dim < new_shape.len() {
216            new_shape[dim] = length;
217        }
218        let _ = start; // mock ignores actual data slicing
219        Ok(MockTensor::zeros(&new_shape, tensor.dtype()).into_ref())
220    }
221
222    fn reshape(&self, tensor: &TensorRef, shape: &[usize]) -> Result<TensorRef> {
223        tensor.reshape(shape)
224    }
225
226    fn zeros(&self, shape: &[usize], dtype: DataType, _device: &Device) -> Result<TensorRef> {
227        Ok(MockTensor::zeros(shape, dtype).into_ref())
228    }
229
230    fn ones(&self, shape: &[usize], _dtype: DataType, _device: &Device) -> Result<TensorRef> {
231        let numel: usize = shape.iter().product();
232        Ok(MockTensor::from_f32(vec![1.0; numel], shape).into_ref())
233    }
234
235    fn uniform(
236        &self,
237        shape: &[usize],
238        _low: f32,
239        _high: f32,
240        dtype: DataType,
241        _device: &Device,
242    ) -> Result<TensorRef> {
243        Ok(MockTensor::zeros(shape, dtype).into_ref())
244    }
245
246    fn normal(
247        &self,
248        shape: &[usize],
249        _mean: f32,
250        _std: f32,
251        dtype: DataType,
252        _device: &Device,
253    ) -> Result<TensorRef> {
254        Ok(MockTensor::zeros(shape, dtype).into_ref())
255    }
256
257    fn from_tensor(&self, tensor: &TensorRef, _device: &Device) -> Result<TensorRef> {
258        Ok(MockTensor::zeros(tensor.shape(), tensor.dtype()).into_ref())
259    }
260}