1use candle_core::Tensor;
4use ferrum_interfaces::TensorLike;
5use ferrum_types::{DataType, Device, FerrumError, Result};
6use std::any::Any;
7
/// Adapter exposing a candle [`Tensor`] through the backend-agnostic
/// `ferrum_interfaces::TensorLike` trait.
#[derive(Debug, Clone)]
pub struct CandleTensorWrapper {
    // The wrapped candle tensor. Cloning is cheap: candle tensors share
    // their storage internally, so `Clone` copies only metadata.
    tensor: Tensor,
}
13
14impl CandleTensorWrapper {
15 pub fn new(tensor: Tensor) -> Self {
16 Self { tensor }
17 }
18
19 pub fn inner(&self) -> &Tensor {
20 &self.tensor
21 }
22
23 pub fn into_inner(self) -> Tensor {
24 self.tensor
25 }
26
27 pub fn from_tensorref(tensor_ref: &ferrum_interfaces::TensorRef) -> Option<Tensor> {
29 let _ = tensor_ref;
32
33 None
36 }
37}
38
39impl TensorLike for CandleTensorWrapper {
40 fn as_any(&self) -> &dyn Any {
41 self
42 }
43
44 fn shape(&self) -> &[usize] {
45 self.tensor.dims()
46 }
47
48 fn dtype(&self) -> DataType {
49 match self.tensor.dtype() {
50 candle_core::DType::F32 => DataType::FP32,
51 candle_core::DType::F16 => DataType::FP16,
52 candle_core::DType::BF16 => DataType::BF16,
53 _ => DataType::FP32,
54 }
55 }
56
57 fn device(&self) -> Device {
58 match self.tensor.device() {
59 candle_core::Device::Cpu => Device::CPU,
60 candle_core::Device::Cuda(_) => Device::CUDA(0),
61 candle_core::Device::Metal(_) => {
62 #[cfg(any(target_os = "macos", target_os = "ios"))]
63 return Device::Metal;
64 #[cfg(not(any(target_os = "macos", target_os = "ios")))]
65 Device::CPU
66 }
67 }
68 }
69
70 fn is_contiguous(&self) -> bool {
71 self.tensor.is_contiguous()
72 }
73
74 fn view(&self, start: &[usize], end: &[usize]) -> Result<ferrum_interfaces::TensorRef> {
75 if start.len() != end.len() || start.len() != self.tensor.dims().len() {
76 return Err(FerrumError::model(format!(
77 "Invalid view dimensions: start={:?}, end={:?}, shape={:?}",
78 start,
79 end,
80 self.tensor.dims()
81 )));
82 }
83
84 let mut view = self.tensor.clone();
85 for (dim, (&start_idx, &end_idx)) in start.iter().zip(end.iter()).enumerate() {
86 if end_idx < start_idx {
87 return Err(FerrumError::model(format!(
88 "Invalid view range on dim {}: {}..{}",
89 dim, start_idx, end_idx
90 )));
91 }
92
93 let current_dim = view
94 .dims()
95 .get(dim)
96 .copied()
97 .ok_or_else(|| FerrumError::model("View dimension out of bounds"))?;
98 if end_idx > current_dim {
99 return Err(FerrumError::model(format!(
100 "View end out of bounds on dim {}: {} > {}",
101 dim, end_idx, current_dim
102 )));
103 }
104
105 let length = end_idx - start_idx;
106 if start_idx != 0 || length != current_dim {
107 view = view.narrow(dim, start_idx, length).map_err(|e| {
108 FerrumError::model(format!("View narrow failed on dim {}: {}", dim, e))
109 })?;
110 }
111 }
112
113 Ok(std::sync::Arc::new(CandleTensorWrapper::new(view)))
114 }
115
116 fn reshape(&self, shape: &[usize]) -> Result<ferrum_interfaces::TensorRef> {
117 let reshaped = self
118 .tensor
119 .reshape(shape)
120 .map_err(|e| FerrumError::model(format!("Reshape failed: {}", e)))?;
121 Ok(std::sync::Arc::new(CandleTensorWrapper::new(reshaped)))
122 }
123
124 fn to_cpu(&self) -> Result<ferrum_interfaces::TensorRef> {
125 if matches!(self.tensor.device(), candle_core::Device::Cpu) {
126 return Ok(std::sync::Arc::new(self.clone()));
127 }
128
129 let cpu_tensor = self
130 .tensor
131 .to_device(&candle_core::Device::Cpu)
132 .map_err(|e| FerrumError::model(format!("to_cpu failed: {}", e)))?;
133 Ok(std::sync::Arc::new(CandleTensorWrapper::new(cpu_tensor)))
134 }
135
136 fn to_device(&self, device: &Device) -> Result<ferrum_interfaces::TensorRef> {
137 let candle_device = match device {
138 Device::CPU => candle_core::Device::Cpu,
139 Device::CUDA(id) => candle_core::Device::new_cuda(*id)
140 .map_err(|e| FerrumError::device(format!("CUDA device error: {}", e)))?,
141 #[cfg(any(target_os = "macos", target_os = "ios"))]
142 Device::Metal => candle_core::Device::new_metal(0)
143 .map_err(|e| FerrumError::device(format!("Metal device error: {}", e)))?,
144 Device::ROCm(_) => {
145 return Err(FerrumError::device("ROCm not supported yet"));
146 }
147 };
148
149 let device_tensor = self
150 .tensor
151 .to_device(&candle_device)
152 .map_err(|e| FerrumError::model(format!("to_device failed: {}", e)))?;
153 Ok(std::sync::Arc::new(CandleTensorWrapper::new(device_tensor)))
154 }
155
156 fn to_dtype(&self, dtype: DataType) -> Result<ferrum_interfaces::TensorRef> {
157 let candle_dtype = match &dtype {
158 DataType::FP32 => candle_core::DType::F32,
159 DataType::FP16 => candle_core::DType::F16,
160 DataType::BF16 => candle_core::DType::BF16,
161 _ => {
162 return Err(FerrumError::model(format!(
163 "Unsupported dtype: {:?}",
164 dtype
165 )))
166 }
167 };
168
169 let converted = self
170 .tensor
171 .to_dtype(candle_dtype)
172 .map_err(|e| FerrumError::model(format!("to_dtype failed: {}", e)))?;
173 Ok(std::sync::Arc::new(CandleTensorWrapper::new(converted)))
174 }
175
176 fn to_vec_f32(&self) -> Result<Vec<f32>> {
178 let tensor = if self.tensor.dtype() != candle_core::DType::F32 {
180 self.tensor
181 .to_dtype(candle_core::DType::F32)
182 .map_err(|e| FerrumError::model(format!("Cast to f32 failed: {}", e)))?
183 } else {
184 self.tensor.clone()
185 };
186 match tensor.dims().len() {
188 1 => tensor
189 .to_vec1::<f32>()
190 .map_err(|e| FerrumError::model(format!("to_vec1 failed: {}", e))),
191 2 => {
192 let batch = tensor
194 .to_vec2::<f32>()
195 .map_err(|e| FerrumError::model(format!("to_vec2 failed: {}", e)))?;
196 Ok(batch.into_iter().next().unwrap_or_default())
197 }
198 3 => {
199 let all = tensor
201 .to_vec3::<f32>()
202 .map_err(|e| FerrumError::model(format!("to_vec3 failed: {}", e)))?;
203 Ok(all
204 .into_iter()
205 .next()
206 .and_then(|seq| seq.into_iter().last())
207 .unwrap_or_default())
208 }
209 4 => {
210 let squeezed = tensor
213 .squeeze(2)
214 .map_err(|e| FerrumError::model(format!("Squeeze dim 2 failed: {}", e)))?;
215
216 let all = squeezed
218 .to_vec3::<f32>()
219 .map_err(|e| FerrumError::model(format!("to_vec3 (from 4D) failed: {}", e)))?;
220 Ok(all
221 .into_iter()
222 .next()
223 .and_then(|seq| seq.into_iter().last())
224 .unwrap_or_default())
225 }
226 _ => Err(FerrumError::model(format!(
227 "Unsupported dims: {:?}",
228 self.tensor.dims()
229 ))),
230 }
231 }
232
233 fn to_vec_u32(&self) -> Result<Vec<u32>> {
234 match self.tensor.dims().len() {
236 1 => self
237 .tensor
238 .to_vec1::<u32>()
239 .map_err(|e| FerrumError::model(format!("to_vec1<u32> failed: {}", e))),
240 2 => {
241 let batch = self
243 .tensor
244 .to_vec2::<u32>()
245 .map_err(|e| FerrumError::model(format!("to_vec2<u32> failed: {}", e)))?;
246 Ok(batch.into_iter().next().unwrap_or_default())
247 }
248 _ => Err(FerrumError::model(format!(
249 "Unsupported dims for token extraction: {:?}",
250 self.tensor.dims()
251 ))),
252 }
253 }
254
255 fn argmax_last_dim_u32(&self) -> Result<u32> {
256 use candle_core::{IndexOp, D};
258
259 let dims = self.tensor.dims();
260 let logits_1d = match dims.len() {
261 1 => self.tensor.clone(),
262 2 => self
263 .tensor
264 .i(0)
265 .map_err(|e| FerrumError::model(format!("Index batch failed: {}", e)))?,
266 3 => {
267 let seq_len = dims[1];
268 self.tensor
269 .i((0, seq_len.saturating_sub(1)))
270 .map_err(|e| FerrumError::model(format!("Index last token failed: {}", e)))?
271 }
272 4 => {
273 let seq_len = dims[1];
275 self.tensor
276 .i((0, seq_len.saturating_sub(1), 0))
277 .map_err(|e| {
278 FerrumError::model(format!("Index last token (4D) failed: {}", e))
279 })?
280 }
281 _ => {
282 return Err(FerrumError::model(format!(
283 "argmax_last_dim_u32 unsupported dims: {:?}",
284 dims
285 )))
286 }
287 };
288
289 let idx = logits_1d
290 .argmax(D::Minus1)
291 .map_err(|e| FerrumError::model(format!("Argmax failed: {}", e)))?
292 .to_device(&candle_core::Device::Cpu)
293 .map_err(|e| FerrumError::model(format!("Argmax to CPU failed: {}", e)))?
294 .to_vec0::<u32>()
295 .map_err(|e| FerrumError::model(format!("Argmax readback failed: {}", e)))?;
296
297 Ok(idx)
298 }
299}
300
#[cfg(test)]
mod tests {
    use super::*;

    /// `view` should narrow the tensor to the requested slice and expose the
    /// result through the `TensorLike` accessors.
    #[test]
    fn view_extracts_last_sequence_slice() {
        let data = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
        let source =
            Tensor::from_vec(data, (1, 2, 3), &candle_core::Device::Cpu).expect("create tensor");
        let wrapped = CandleTensorWrapper::new(source);

        // Select the last sequence position: [0..1, 1..2, 0..3].
        let sliced = wrapped.view(&[0, 1, 0], &[1, 2, 3]).expect("slice view");

        assert_eq!(sliced.shape(), &[1, 1, 3]);
        assert_eq!(sliced.to_vec_f32().expect("to_vec_f32"), vec![4.0, 5.0, 6.0]);
    }
}