
ferrum_models/tensor_wrapper.rs

//! Candle Tensor wrapper implementing TensorLike

use candle_core::Tensor;
use ferrum_interfaces::TensorLike;
use ferrum_types::{DataType, Device, FerrumError, Result};
use std::any::Any;

/// Wrapper for Candle Tensor to implement TensorLike
#[derive(Debug, Clone)]
pub struct CandleTensorWrapper {
    tensor: Tensor,
}

impl CandleTensorWrapper {
    pub fn new(tensor: Tensor) -> Self {
        Self { tensor }
    }

    pub fn inner(&self) -> &Tensor {
        &self.tensor
    }

    pub fn into_inner(self) -> Tensor {
        self.tensor
    }

    /// Safe extraction from Arc<dyn TensorLike>.
    ///
    /// Downcasts through `as_any`; returns `None` when the trait object is
    /// not backed by a `CandleTensorWrapper`.
    pub fn from_tensorref(tensor_ref: &ferrum_interfaces::TensorRef) -> Option<Tensor> {
        tensor_ref
            .as_any()
            .downcast_ref::<CandleTensorWrapper>()
            .map(|wrapper| wrapper.tensor.clone())
    }
}

impl TensorLike for CandleTensorWrapper {
    fn as_any(&self) -> &dyn Any {
        self
    }

    fn shape(&self) -> &[usize] {
        self.tensor.dims()
    }

    fn dtype(&self) -> DataType {
        match self.tensor.dtype() {
            candle_core::DType::F32 => DataType::FP32,
            candle_core::DType::F16 => DataType::FP16,
            candle_core::DType::BF16 => DataType::BF16,
            // Other Candle dtypes (e.g. integer types) have no mapping
            // here, so fall back to reporting FP32.
            _ => DataType::FP32,
        }
    }

    fn device(&self) -> Device {
        match self.tensor.device() {
            candle_core::Device::Cpu => Device::CPU,
            candle_core::Device::Cuda(_) => Device::CUDA(0),
            // cfg on match arms keeps this compiling on stable for every
            // target, unlike cfg on a block's trailing expression.
            #[cfg(any(target_os = "macos", target_os = "ios"))]
            candle_core::Device::Metal(_) => Device::Metal,
            #[cfg(not(any(target_os = "macos", target_os = "ios")))]
            candle_core::Device::Metal(_) => Device::CPU,
        }
    }

    fn is_contiguous(&self) -> bool {
        self.tensor.is_contiguous()
    }

    fn view(&self, start: &[usize], end: &[usize]) -> Result<ferrum_interfaces::TensorRef> {
        if start.len() != end.len() || start.len() != self.tensor.dims().len() {
            return Err(FerrumError::model(format!(
                "Invalid view dimensions: start={:?}, end={:?}, shape={:?}",
                start,
                end,
                self.tensor.dims()
            )));
        }

        let mut view = self.tensor.clone();
        for (dim, (&start_idx, &end_idx)) in start.iter().zip(end.iter()).enumerate() {
            if end_idx < start_idx {
                return Err(FerrumError::model(format!(
                    "Invalid view range on dim {}: {}..{}",
                    dim, start_idx, end_idx
                )));
            }

            let current_dim = view
                .dims()
                .get(dim)
                .copied()
                .ok_or_else(|| FerrumError::model("View dimension out of bounds"))?;
            if end_idx > current_dim {
                return Err(FerrumError::model(format!(
                    "View end out of bounds on dim {}: {} > {}",
                    dim, end_idx, current_dim
                )));
            }

            let length = end_idx - start_idx;
            if start_idx != 0 || length != current_dim {
                view = view.narrow(dim, start_idx, length).map_err(|e| {
                    FerrumError::model(format!("View narrow failed on dim {}: {}", dim, e))
                })?;
            }
        }

        Ok(std::sync::Arc::new(CandleTensorWrapper::new(view)))
    }

    fn reshape(&self, shape: &[usize]) -> Result<ferrum_interfaces::TensorRef> {
        let reshaped = self
            .tensor
            .reshape(shape)
            .map_err(|e| FerrumError::model(format!("Reshape failed: {}", e)))?;
        Ok(std::sync::Arc::new(CandleTensorWrapper::new(reshaped)))
    }

    fn to_cpu(&self) -> Result<ferrum_interfaces::TensorRef> {
        if matches!(self.tensor.device(), candle_core::Device::Cpu) {
            return Ok(std::sync::Arc::new(self.clone()));
        }

        let cpu_tensor = self
            .tensor
            .to_device(&candle_core::Device::Cpu)
            .map_err(|e| FerrumError::model(format!("to_cpu failed: {}", e)))?;
        Ok(std::sync::Arc::new(CandleTensorWrapper::new(cpu_tensor)))
    }

    fn to_device(&self, device: &Device) -> Result<ferrum_interfaces::TensorRef> {
        let candle_device = match device {
            Device::CPU => candle_core::Device::Cpu,
            Device::CUDA(id) => candle_core::Device::new_cuda(*id)
                .map_err(|e| FerrumError::device(format!("CUDA device error: {}", e)))?,
            #[cfg(any(target_os = "macos", target_os = "ios"))]
            Device::Metal => candle_core::Device::new_metal(0)
                .map_err(|e| FerrumError::device(format!("Metal device error: {}", e)))?,
            Device::ROCm(_) => {
                return Err(FerrumError::device("ROCm not supported yet"));
            }
        };

        let device_tensor = self
            .tensor
            .to_device(&candle_device)
            .map_err(|e| FerrumError::model(format!("to_device failed: {}", e)))?;
        Ok(std::sync::Arc::new(CandleTensorWrapper::new(device_tensor)))
    }

    fn to_dtype(&self, dtype: DataType) -> Result<ferrum_interfaces::TensorRef> {
        let candle_dtype = match &dtype {
            DataType::FP32 => candle_core::DType::F32,
            DataType::FP16 => candle_core::DType::F16,
            DataType::BF16 => candle_core::DType::BF16,
            _ => {
                return Err(FerrumError::model(format!(
                    "Unsupported dtype: {:?}",
                    dtype
                )))
            }
        };

        let converted = self
            .tensor
            .to_dtype(candle_dtype)
            .map_err(|e| FerrumError::model(format!("to_dtype failed: {}", e)))?;
        Ok(std::sync::Arc::new(CandleTensorWrapper::new(converted)))
    }

    /// Extract tensor data as Vec<f32> - Candle implementation
    fn to_vec_f32(&self) -> Result<Vec<f32>> {
        // Ensure F32 dtype (CUDA/Metal may produce F16/BF16 logits)
        let tensor = if self.tensor.dtype() != candle_core::DType::F32 {
            self.tensor
                .to_dtype(candle_core::DType::F32)
                .map_err(|e| FerrumError::model(format!("Cast to f32 failed: {}", e)))?
        } else {
            self.tensor.clone()
        };
        // Handle different tensor dimensions
        match tensor.dims().len() {
            1 => tensor
                .to_vec1::<f32>()
                .map_err(|e| FerrumError::model(format!("to_vec1 failed: {}", e))),
            2 => {
                // Take first batch: [batch, vocab] -> [vocab]
                let batch = tensor
                    .to_vec2::<f32>()
                    .map_err(|e| FerrumError::model(format!("to_vec2 failed: {}", e)))?;
                Ok(batch.into_iter().next().unwrap_or_default())
            }
            3 => {
                // Take last token of first batch: [batch, seq, vocab] -> [vocab]
                let all = tensor
                    .to_vec3::<f32>()
                    .map_err(|e| FerrumError::model(format!("to_vec3 failed: {}", e)))?;
                Ok(all
                    .into_iter()
                    .next()
                    .and_then(|seq| seq.into_iter().last())
                    .unwrap_or_default())
            }
            4 => {
                // Handle [batch, seq, extra, vocab]: squeeze the extra dim,
                // which must be size 1 for the 3D extraction below to
                // succeed, then take the last token of the first batch.
                let squeezed = tensor
                    .squeeze(2)
                    .map_err(|e| FerrumError::model(format!("Squeeze dim 2 failed: {}", e)))?;

                // Now extract as 3D: [batch, seq, vocab]
                let all = squeezed
                    .to_vec3::<f32>()
                    .map_err(|e| FerrumError::model(format!("to_vec3 (from 4D) failed: {}", e)))?;
                Ok(all
                    .into_iter()
                    .next()
                    .and_then(|seq| seq.into_iter().last())
                    .unwrap_or_default())
            }
            _ => Err(FerrumError::model(format!(
                "Unsupported dims: {:?}",
                self.tensor.dims()
            ))),
        }
    }

    fn to_vec_u32(&self) -> Result<Vec<u32>> {
        // Handle different tensor dimensions for token IDs
        match self.tensor.dims().len() {
            1 => self
                .tensor
                .to_vec1::<u32>()
                .map_err(|e| FerrumError::model(format!("to_vec1<u32> failed: {}", e))),
            2 => {
                // Take first batch: [batch, seq] -> [seq]
                let batch = self
                    .tensor
                    .to_vec2::<u32>()
                    .map_err(|e| FerrumError::model(format!("to_vec2<u32> failed: {}", e)))?;
                Ok(batch.into_iter().next().unwrap_or_default())
            }
            _ => Err(FerrumError::model(format!(
                "Unsupported dims for token extraction: {:?}",
                self.tensor.dims()
            ))),
        }
    }

    fn argmax_last_dim_u32(&self) -> Result<u32> {
        // Same strategy as runtime CandleTensor: argmax on-device, read back a scalar.
        use candle_core::{IndexOp, D};

        let dims = self.tensor.dims();
        let logits_1d = match dims.len() {
            1 => self.tensor.clone(),
            2 => self
                .tensor
                .i(0)
                .map_err(|e| FerrumError::model(format!("Index batch failed: {}", e)))?,
            3 => {
                let seq_len = dims[1];
                self.tensor
                    .i((0, seq_len.saturating_sub(1)))
                    .map_err(|e| FerrumError::model(format!("Index last token failed: {}", e)))?
            }
            4 => {
                // [batch, seq, extra, vocab] -> take batch 0, last seq, extra 0 -> [vocab]
                let seq_len = dims[1];
                self.tensor
                    .i((0, seq_len.saturating_sub(1), 0))
                    .map_err(|e| {
                        FerrumError::model(format!("Index last token (4D) failed: {}", e))
                    })?
            }
            _ => {
                return Err(FerrumError::model(format!(
                    "argmax_last_dim_u32 unsupported dims: {:?}",
                    dims
                )))
            }
        };

        let idx = logits_1d
            .argmax(D::Minus1)
            .map_err(|e| FerrumError::model(format!("Argmax failed: {}", e)))?
            .to_device(&candle_core::Device::Cpu)
            .map_err(|e| FerrumError::model(format!("Argmax to CPU failed: {}", e)))?
            .to_vec0::<u32>()
            .map_err(|e| FerrumError::model(format!("Argmax readback failed: {}", e)))?;

        Ok(idx)
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn view_extracts_last_sequence_slice() {
        let tensor = Tensor::from_vec(
            vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0],
            (1, 2, 3),
            &candle_core::Device::Cpu,
        )
        .expect("create tensor");
        let wrapper = CandleTensorWrapper::new(tensor);

        let view = wrapper.view(&[0, 1, 0], &[1, 2, 3]).expect("slice view");
        assert_eq!(view.shape(), &[1, 1, 3]);
        assert_eq!(view.to_vec_f32().expect("to_vec_f32"), vec![4.0, 5.0, 6.0]);
    }
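
    // A minimal round-trip sketch for `from_tensorref`, assuming `TensorRef`
    // is `Arc<dyn TensorLike>` as the `view` implementation above implies.
    #[test]
    fn from_tensorref_downcasts_wrapper() {
        let tensor = Tensor::from_vec(vec![1.0f32, 2.0], (1, 2), &candle_core::Device::Cpu)
            .expect("create tensor");
        let tensor_ref: ferrum_interfaces::TensorRef =
            std::sync::Arc::new(CandleTensorWrapper::new(tensor));

        let extracted =
            CandleTensorWrapper::from_tensorref(&tensor_ref).expect("downcast wrapper");
        assert_eq!(extracted.dims(), &[1, 2]);
    }

    // Checks that `argmax_last_dim_u32` reads back the max index of the
    // first batch row for a 2-D logits tensor.
    #[test]
    fn argmax_picks_max_index_of_last_dim() {
        let tensor = Tensor::from_vec(
            vec![0.1f32, 0.2, 0.9, 0.3],
            (1, 4),
            &candle_core::Device::Cpu,
        )
        .expect("create tensor");
        let wrapper = CandleTensorWrapper::new(tensor);

        assert_eq!(wrapper.argmax_last_dim_u32().expect("argmax"), 2);
    }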
}