1use ferrum_interfaces::{TensorFactory, TensorRef};
4use ferrum_types::{DataType, Device, FerrumError, Result};
5use std::sync::Arc;
6
7#[derive(Clone)]
10pub struct MockTensor {
11 shape: Vec<usize>,
12 dtype: DataType,
13 device: Device,
14 data_f32: Vec<f32>,
15}
16
17impl MockTensor {
18 pub fn zeros(shape: &[usize], dtype: DataType) -> Self {
20 let numel: usize = shape.iter().product();
21 Self {
22 shape: shape.to_vec(),
23 dtype,
24 device: Device::CPU,
25 data_f32: vec![0.0; numel],
26 }
27 }
28
29 pub fn from_f32(data: Vec<f32>, shape: &[usize]) -> Self {
31 Self {
32 shape: shape.to_vec(),
33 dtype: DataType::FP32,
34 device: Device::CPU,
35 data_f32: data,
36 }
37 }
38
39 pub fn from_u32(data: &[u32], shape: &[usize]) -> Self {
41 Self {
42 shape: shape.to_vec(),
43 dtype: DataType::FP32,
44 device: Device::CPU,
45 data_f32: data.iter().map(|&v| v as f32).collect(),
46 }
47 }
48
49 pub fn into_ref(self) -> TensorRef {
51 Arc::new(self)
52 }
53}
54
55impl std::fmt::Debug for MockTensor {
56 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
57 f.debug_struct("MockTensor")
58 .field("shape", &self.shape)
59 .field("dtype", &self.dtype)
60 .finish()
61 }
62}
63
64impl ferrum_interfaces::TensorLike for MockTensor {
65 fn as_any(&self) -> &dyn std::any::Any {
66 self
67 }
68
69 fn shape(&self) -> &[usize] {
70 &self.shape
71 }
72
73 fn dtype(&self) -> DataType {
74 self.dtype
75 }
76
77 fn device(&self) -> Device {
78 self.device.clone()
79 }
80
81 fn is_contiguous(&self) -> bool {
82 true
83 }
84
85 fn view(&self, start: &[usize], end: &[usize]) -> Result<TensorRef> {
86 let new_shape: Vec<usize> = start.iter().zip(end.iter()).map(|(s, e)| e - s).collect();
87
88 let ndim = self.shape.len();
90 let mut strides = vec![1usize; ndim];
91 for i in (0..ndim.saturating_sub(1)).rev() {
92 strides[i] = strides[i + 1] * self.shape[i + 1];
93 }
94
95 let new_numel: usize = new_shape.iter().product();
97 let mut data = Vec::with_capacity(new_numel);
98 let mut coords = start.to_vec();
99 loop {
100 let flat: usize = coords.iter().zip(strides.iter()).map(|(c, s)| c * s).sum();
102 data.push(self.data_f32[flat]);
103
104 let mut dim = ndim - 1;
106 loop {
107 coords[dim] += 1;
108 if coords[dim] < end[dim] {
109 break;
110 }
111 coords[dim] = start[dim];
112 if dim == 0 {
113 return Ok(MockTensor {
115 shape: new_shape,
116 dtype: self.dtype,
117 device: self.device.clone(),
118 data_f32: data,
119 }
120 .into_ref());
121 }
122 dim -= 1;
123 }
124 }
125 }
126
127 fn reshape(&self, shape: &[usize]) -> Result<TensorRef> {
128 let new_numel: usize = shape.iter().product();
129 if new_numel != self.data_f32.len() {
130 return Err(FerrumError::backend(format!(
131 "Cannot reshape {} elements to {:?}",
132 self.data_f32.len(),
133 shape
134 )));
135 }
136 Ok(MockTensor {
137 shape: shape.to_vec(),
138 dtype: self.dtype,
139 device: self.device.clone(),
140 data_f32: self.data_f32.clone(),
141 }
142 .into_ref())
143 }
144
145 fn to_cpu(&self) -> Result<TensorRef> {
146 Ok(self.clone().into_ref())
147 }
148
149 fn to_device(&self, _device: &Device) -> Result<TensorRef> {
150 Ok(self.clone().into_ref())
151 }
152
153 fn to_dtype(&self, dtype: DataType) -> Result<TensorRef> {
154 Ok(MockTensor {
155 shape: self.shape.clone(),
156 dtype,
157 device: self.device.clone(),
158 data_f32: self.data_f32.clone(),
159 }
160 .into_ref())
161 }
162
163 fn to_vec_f32(&self) -> Result<Vec<f32>> {
164 Ok(self.data_f32.clone())
165 }
166
167 fn to_vec_u32(&self) -> Result<Vec<u32>> {
168 Ok(self.data_f32.iter().map(|&v| v as u32).collect())
169 }
170
171 fn argmax_last_dim_u32(&self) -> Result<u32> {
172 self.data_f32
173 .iter()
174 .enumerate()
175 .max_by(|a, b| a.1.partial_cmp(b.1).unwrap_or(std::cmp::Ordering::Equal))
176 .map(|(i, _)| i as u32)
177 .ok_or_else(|| FerrumError::backend("Empty tensor"))
178 }
179}
180
/// Stateless factory that creates `MockTensor` values on the CPU.
#[derive(Debug, Clone, Copy, Default)]
pub struct MockTensorFactory;
183
184impl TensorFactory for MockTensorFactory {
185 fn empty(&self, shape: &[usize], dtype: DataType, _device: Device) -> Result<TensorRef> {
186 Ok(MockTensor::zeros(shape, dtype).into_ref())
187 }
188
189 fn zeros_like(&self, tensor: &TensorRef) -> Result<TensorRef> {
190 Ok(MockTensor::zeros(tensor.shape(), tensor.dtype()).into_ref())
191 }
192
193 fn from_slice(
194 &self,
195 data: &[f32],
196 shape: &[usize],
197 _dtype: DataType,
198 _device: Device,
199 ) -> Result<TensorRef> {
200 Ok(MockTensor::from_f32(data.to_vec(), shape).into_ref())
201 }
202
203 fn to_device(&self, tensor: &TensorRef, _device: Device) -> Result<TensorRef> {
204 Ok(MockTensor::zeros(tensor.shape(), tensor.dtype()).into_ref())
205 }
206
207 fn narrow(
208 &self,
209 tensor: &TensorRef,
210 dim: usize,
211 start: usize,
212 length: usize,
213 ) -> Result<TensorRef> {
214 let mut new_shape = tensor.shape().to_vec();
215 if dim < new_shape.len() {
216 new_shape[dim] = length;
217 }
218 let _ = start; Ok(MockTensor::zeros(&new_shape, tensor.dtype()).into_ref())
220 }
221
222 fn reshape(&self, tensor: &TensorRef, shape: &[usize]) -> Result<TensorRef> {
223 tensor.reshape(shape)
224 }
225
226 fn zeros(&self, shape: &[usize], dtype: DataType, _device: &Device) -> Result<TensorRef> {
227 Ok(MockTensor::zeros(shape, dtype).into_ref())
228 }
229
230 fn ones(&self, shape: &[usize], _dtype: DataType, _device: &Device) -> Result<TensorRef> {
231 let numel: usize = shape.iter().product();
232 Ok(MockTensor::from_f32(vec![1.0; numel], shape).into_ref())
233 }
234
235 fn uniform(
236 &self,
237 shape: &[usize],
238 _low: f32,
239 _high: f32,
240 dtype: DataType,
241 _device: &Device,
242 ) -> Result<TensorRef> {
243 Ok(MockTensor::zeros(shape, dtype).into_ref())
244 }
245
246 fn normal(
247 &self,
248 shape: &[usize],
249 _mean: f32,
250 _std: f32,
251 dtype: DataType,
252 _device: &Device,
253 ) -> Result<TensorRef> {
254 Ok(MockTensor::zeros(shape, dtype).into_ref())
255 }
256
257 fn from_tensor(&self, tensor: &TensorRef, _device: &Device) -> Result<TensorRef> {
258 Ok(MockTensor::zeros(tensor.shape(), tensor.dtype()).into_ref())
259 }
260}