Skip to main content

ferrum_runtime/backends/
candle.rs

1//! Candle backend - MVP implementation with core functionality
2//!
3//! Provides basic Candle tensor operations for CPU and GPU devices.
4
5use crate::{
6    ComputeBackend, DeviceMemoryManager, MemoryPool, TensorFactory, TensorLike, TensorOps,
7    TensorRef,
8};
9use async_trait::async_trait;
10use ferrum_interfaces::backend::{BackendCapabilities, BackendStatus, KernelExecutor};
11use ferrum_interfaces::kernel_ops::KernelOps;
12use ferrum_types::{DataType, Device, Result};
13use std::any::Any;
14use std::collections::HashMap;
15use std::sync::Arc;
16use tracing::debug;
17
18/// Candle tensor wrapper
19pub struct CandleTensor {
20    inner: candle_core::Tensor,
21    device: Device,
22    dtype: DataType,
23}
24
25impl CandleTensor {
26    pub fn new(tensor: candle_core::Tensor) -> Result<Self> {
27        let device = candle_device_to_ferrum(tensor.device())?;
28        let dtype = candle_dtype_to_ferrum(tensor.dtype())?;
29
30        Ok(Self {
31            inner: tensor,
32            device,
33            dtype,
34        })
35    }
36
37    pub fn inner(&self) -> &candle_core::Tensor {
38        &self.inner
39    }
40}
41
42impl std::fmt::Debug for CandleTensor {
43    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
44        f.debug_struct("CandleTensor")
45            .field("shape", &self.inner.dims())
46            .field("dtype", &self.dtype)
47            .field("device", &self.device)
48            .finish()
49    }
50}
51
52impl TensorLike for CandleTensor {
53    fn as_any(&self) -> &dyn Any {
54        self
55    }
56
57    fn shape(&self) -> &[usize] {
58        self.inner.dims()
59    }
60
61    fn dtype(&self) -> DataType {
62        self.dtype
63    }
64
65    fn device(&self) -> Device {
66        self.device.clone()
67    }
68
69    fn to_device(&self, device: &Device) -> Result<TensorRef> {
70        let candle_device = ferrum_device_to_candle(device.clone())?;
71        let moved = self
72            .inner
73            .to_device(&candle_device)
74            .map_err(|e| ferrum_types::FerrumError::backend(format!("Device transfer: {}", e)))?;
75        Ok(Arc::new(Self::new(moved)?))
76    }
77
78    fn to_dtype(&self, dtype: DataType) -> Result<TensorRef> {
79        let candle_dtype = ferrum_dtype_to_candle(dtype)?;
80        let converted = self
81            .inner
82            .to_dtype(candle_dtype)
83            .map_err(|e| ferrum_types::FerrumError::backend(format!("DType conversion: {}", e)))?;
84        Ok(Arc::new(Self::new(converted)?))
85    }
86
87    fn to_vec_f32(&self) -> Result<Vec<f32>> {
88        // Extract tensor data as Vec<f32>
89        match self.inner.dims().len() {
90            1 => self
91                .inner
92                .to_vec1::<f32>()
93                .map_err(|e| ferrum_types::FerrumError::backend(format!("to_vec1 failed: {}", e))),
94            2 => {
95                let batch = self.inner.to_vec2::<f32>().map_err(|e| {
96                    ferrum_types::FerrumError::backend(format!("to_vec2 failed: {}", e))
97                })?;
98                Ok(batch.into_iter().next().unwrap_or_default())
99            }
100            3 => {
101                let all = self.inner.to_vec3::<f32>().map_err(|e| {
102                    ferrum_types::FerrumError::backend(format!("to_vec3 failed: {}", e))
103                })?;
104                Ok(all
105                    .into_iter()
106                    .next()
107                    .and_then(|seq| seq.into_iter().last())
108                    .unwrap_or_default())
109            }
110            _ => Err(ferrum_types::FerrumError::backend(format!(
111                "Unsupported tensor dimensions: {:?}",
112                self.inner.dims()
113            ))),
114        }
115    }
116
117    fn to_vec_u32(&self) -> Result<Vec<u32>> {
118        let cpu_tensor = self
119            .inner
120            .to_device(&candle_core::Device::Cpu)
121            .map_err(|e| ferrum_types::FerrumError::backend(format!("to_cpu failed: {}", e)))?;
122
123        match cpu_tensor.dims().len() {
124            1 => match cpu_tensor.to_vec1::<u32>() {
125                Ok(tokens) => Ok(tokens),
126                Err(_) => cpu_tensor
127                    .to_vec1::<f32>()
128                    .map(|tokens| tokens.into_iter().map(|x| x as u32).collect())
129                    .map_err(|e| {
130                        ferrum_types::FerrumError::backend(format!(
131                            "to_vec1<u32/f32> failed: {}",
132                            e
133                        ))
134                    }),
135            },
136            2 => match cpu_tensor.to_vec2::<u32>() {
137                Ok(batch) => Ok(batch.into_iter().next().unwrap_or_default()),
138                Err(_) => cpu_tensor
139                    .to_vec2::<f32>()
140                    .map(|batch| {
141                        batch
142                            .into_iter()
143                            .next()
144                            .unwrap_or_default()
145                            .into_iter()
146                            .map(|x| x as u32)
147                            .collect()
148                    })
149                    .map_err(|e| {
150                        ferrum_types::FerrumError::backend(format!(
151                            "to_vec2<u32/f32> failed: {}",
152                            e
153                        ))
154                    }),
155            },
156            _ => Err(ferrum_types::FerrumError::backend(format!(
157                "Unsupported tensor dimensions for token extraction: {:?}",
158                cpu_tensor.dims()
159            ))),
160        }
161    }
162
163    fn reshape(&self, shape: &[usize]) -> Result<TensorRef> {
164        let reshaped = self
165            .inner
166            .reshape(shape)
167            .map_err(|e| ferrum_types::FerrumError::backend(format!("Reshape: {}", e)))?;
168        Ok(Arc::new(Self::new(reshaped)?))
169    }
170
171    fn to_cpu(&self) -> Result<TensorRef> {
172        self.to_device(&Device::CPU)
173    }
174
175    fn view(&self, _start: &[usize], _end: &[usize]) -> Result<TensorRef> {
176        // MVP: simplified, return clone
177        Ok(Arc::new(Self {
178            inner: self.inner.clone(),
179            device: self.device.clone(),
180            dtype: self.dtype,
181        }))
182    }
183
184    fn is_contiguous(&self) -> bool {
185        self.inner.is_contiguous()
186    }
187
188    fn argmax_last_dim_u32(&self) -> Result<u32> {
189        // Fast path for greedy sampling: compute argmax on the tensor's device,
190        // then transfer only a single scalar to CPU.
191        //
192        // This is intentionally conservative: it assumes batch=1 and returns the
193        // first element when batch exists.
194        use candle_core::{IndexOp, D};
195
196        let dims = self.inner.dims();
197        let logits_1d = match dims.len() {
198            1 => self.inner.clone(),
199            2 => {
200                // [batch, vocab] -> take batch 0 -> [vocab]
201                self.inner.i(0).map_err(|e| {
202                    ferrum_types::FerrumError::backend(format!("Index batch failed: {}", e))
203                })?
204            }
205            3 => {
206                // [batch, seq, vocab] -> take batch 0, last seq -> [vocab]
207                let seq_len = dims[1];
208                self.inner.i((0, seq_len.saturating_sub(1))).map_err(|e| {
209                    ferrum_types::FerrumError::backend(format!("Index last token failed: {}", e))
210                })?
211            }
212            _ => {
213                return Err(ferrum_types::FerrumError::backend(format!(
214                    "argmax_last_dim_u32 unsupported dims: {:?}",
215                    dims
216                )))
217            }
218        };
219
220        // Candle argmax returns a tensor; we read back a single u32.
221        let idx = logits_1d
222            .argmax(D::Minus1)
223            .map_err(|e| ferrum_types::FerrumError::backend(format!("Argmax failed: {}", e)))?
224            .to_device(&candle_core::Device::Cpu)
225            .map_err(|e| {
226                ferrum_types::FerrumError::backend(format!("Argmax to CPU failed: {}", e))
227            })?
228            .to_vec0::<u32>()
229            .map_err(|e| {
230                ferrum_types::FerrumError::backend(format!("Argmax readback failed: {}", e))
231            })?;
232
233        Ok(idx)
234    }
235}
236
237/// Candle tensor factory
238pub struct CandleTensorFactory {
239    device: Device,
240}
241
242impl CandleTensorFactory {
243    pub fn new(device: Device) -> Self {
244        Self { device }
245    }
246}
247
248impl std::fmt::Debug for CandleTensorFactory {
249    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
250        f.debug_struct("CandleTensorFactory")
251            .field("device", &self.device)
252            .finish()
253    }
254}
255
256impl TensorFactory for CandleTensorFactory {
257    fn empty(&self, shape: &[usize], dtype: DataType, device: Device) -> Result<TensorRef> {
258        let candle_device = ferrum_device_to_candle(device)?;
259        let candle_dtype = ferrum_dtype_to_candle(dtype)?;
260
261        let tensor = candle_core::Tensor::zeros(shape, candle_dtype, &candle_device)
262            .map_err(|e| ferrum_types::FerrumError::backend(e.to_string()))?;
263        Ok(Arc::new(CandleTensor::new(tensor)?))
264    }
265
266    fn from_slice(
267        &self,
268        data: &[f32],
269        shape: &[usize],
270        dtype: DataType,
271        device: Device,
272    ) -> Result<TensorRef> {
273        let candle_device = ferrum_device_to_candle(device)?;
274        let candle_dtype = ferrum_dtype_to_candle(dtype)?;
275
276        let tensor = candle_core::Tensor::from_slice(data, shape, &candle_device)
277            .and_then(|t| t.to_dtype(candle_dtype))
278            .map_err(|e| ferrum_types::FerrumError::backend(e.to_string()))?;
279
280        Ok(Arc::new(CandleTensor::new(tensor)?))
281    }
282
283    fn to_device(&self, tensor: &TensorRef, device: Device) -> Result<TensorRef> {
284        tensor.to_device(&device)
285    }
286
287    fn narrow(
288        &self,
289        tensor: &TensorRef,
290        dim: usize,
291        start: usize,
292        length: usize,
293    ) -> Result<TensorRef> {
294        let candle_tensor = get_candle_tensor(tensor)?;
295        let narrowed = candle_tensor
296            .narrow(dim, start, length)
297            .map_err(|e| ferrum_types::FerrumError::backend(e.to_string()))?;
298        Ok(Arc::new(CandleTensor::new(narrowed)?))
299    }
300
301    fn reshape(&self, tensor: &TensorRef, shape: &[usize]) -> Result<TensorRef> {
302        tensor.reshape(shape)
303    }
304
305    fn zeros_like(&self, tensor: &TensorRef) -> Result<TensorRef> {
306        let candle_tensor = get_candle_tensor(tensor)?;
307        let zeros = candle_core::Tensor::zeros(
308            candle_tensor.shape(),
309            candle_tensor.dtype(),
310            candle_tensor.device(),
311        )
312        .map_err(|e| ferrum_types::FerrumError::backend(e.to_string()))?;
313        Ok(Arc::new(CandleTensor::new(zeros)?))
314    }
315
316    fn zeros(&self, shape: &[usize], dtype: DataType, device: &Device) -> Result<TensorRef> {
317        let candle_device = ferrum_device_to_candle(device.clone())?;
318        let candle_dtype = ferrum_dtype_to_candle(dtype)?;
319
320        let tensor = candle_core::Tensor::zeros(shape, candle_dtype, &candle_device)
321            .map_err(|e| ferrum_types::FerrumError::backend(e.to_string()))?;
322        Ok(Arc::new(CandleTensor::new(tensor)?))
323    }
324
325    fn ones(&self, shape: &[usize], dtype: DataType, device: &Device) -> Result<TensorRef> {
326        let candle_device = ferrum_device_to_candle(device.clone())?;
327        let candle_dtype = ferrum_dtype_to_candle(dtype)?;
328
329        let tensor = candle_core::Tensor::ones(shape, candle_dtype, &candle_device)
330            .map_err(|e| ferrum_types::FerrumError::backend(e.to_string()))?;
331        Ok(Arc::new(CandleTensor::new(tensor)?))
332    }
333
334    fn uniform(
335        &self,
336        shape: &[usize],
337        low: f32,
338        high: f32,
339        dtype: DataType,
340        device: &Device,
341    ) -> Result<TensorRef> {
342        let candle_device = ferrum_device_to_candle(device.clone())?;
343        let candle_dtype = ferrum_dtype_to_candle(dtype)?;
344
345        let tensor = candle_core::Tensor::rand(low, high, shape, &candle_device)
346            .and_then(|t| t.to_dtype(candle_dtype))
347            .map_err(|e| ferrum_types::FerrumError::backend(e.to_string()))?;
348        Ok(Arc::new(CandleTensor::new(tensor)?))
349    }
350
351    fn normal(
352        &self,
353        shape: &[usize],
354        mean: f32,
355        std: f32,
356        dtype: DataType,
357        device: &Device,
358    ) -> Result<TensorRef> {
359        let candle_device = ferrum_device_to_candle(device.clone())?;
360        let candle_dtype = ferrum_dtype_to_candle(dtype)?;
361
362        let tensor = candle_core::Tensor::randn(mean, std, shape, &candle_device)
363            .and_then(|t| t.to_dtype(candle_dtype))
364            .map_err(|e| ferrum_types::FerrumError::backend(e.to_string()))?;
365        Ok(Arc::new(CandleTensor::new(tensor)?))
366    }
367
368    fn from_tensor(&self, tensor: &TensorRef, device: &Device) -> Result<TensorRef> {
369        tensor.to_device(device)
370    }
371}
372
373/// Candle tensor operations
374#[derive(Debug, Clone, Default)]
375pub struct CandleTensorOps;
376
377impl TensorOps for CandleTensorOps {
378    fn matmul(&self, a: &TensorRef, b: &TensorRef) -> Result<TensorRef> {
379        let a_candle = get_candle_tensor(a)?;
380        let b_candle = get_candle_tensor(b)?;
381
382        let result = a_candle
383            .matmul(b_candle)
384            .map_err(|e| ferrum_types::FerrumError::backend(e.to_string()))?;
385        Ok(Arc::new(CandleTensor::new(result)?))
386    }
387
388    fn add(&self, a: &TensorRef, b: &TensorRef) -> Result<TensorRef> {
389        let a_candle = get_candle_tensor(a)?;
390        let b_candle = get_candle_tensor(b)?;
391
392        let result =
393            (a_candle + b_candle).map_err(|e| ferrum_types::FerrumError::backend(e.to_string()))?;
394        Ok(Arc::new(CandleTensor::new(result)?))
395    }
396
397    fn mul(&self, a: &TensorRef, b: &TensorRef) -> Result<TensorRef> {
398        let a_candle = get_candle_tensor(a)?;
399        let b_candle = get_candle_tensor(b)?;
400
401        let result =
402            (a_candle * b_candle).map_err(|e| ferrum_types::FerrumError::backend(e.to_string()))?;
403        Ok(Arc::new(CandleTensor::new(result)?))
404    }
405
406    fn sub(&self, a: &TensorRef, b: &TensorRef) -> Result<TensorRef> {
407        let a_candle = get_candle_tensor(a)?;
408        let b_candle = get_candle_tensor(b)?;
409
410        let result =
411            (a_candle - b_candle).map_err(|e| ferrum_types::FerrumError::backend(e.to_string()))?;
412        Ok(Arc::new(CandleTensor::new(result)?))
413    }
414
415    fn div(&self, a: &TensorRef, b: &TensorRef) -> Result<TensorRef> {
416        let a_candle = get_candle_tensor(a)?;
417        let b_candle = get_candle_tensor(b)?;
418
419        let result =
420            (a_candle / b_candle).map_err(|e| ferrum_types::FerrumError::backend(e.to_string()))?;
421        Ok(Arc::new(CandleTensor::new(result)?))
422    }
423
424    fn softmax(&self, tensor: &TensorRef, dim: i32) -> Result<TensorRef> {
425        let candle_tensor = get_candle_tensor(tensor)?;
426        let dim_usize = if dim < 0 {
427            (candle_tensor.rank() as i32 + dim) as usize
428        } else {
429            dim as usize
430        };
431
432        let result = candle_nn::ops::softmax(candle_tensor, dim_usize)
433            .map_err(|e| ferrum_types::FerrumError::backend(e.to_string()))?;
434        Ok(Arc::new(CandleTensor::new(result)?))
435    }
436
437    fn layer_norm(
438        &self,
439        input: &TensorRef,
440        weight: &TensorRef,
441        bias: Option<&TensorRef>,
442        eps: f32,
443    ) -> Result<TensorRef> {
444        let input_candle = get_candle_tensor(input)?;
445        let weight_candle = get_candle_tensor(weight)?;
446        let _bias_candle = bias.map(|b| get_candle_tensor(b)).transpose()?;
447
448        // MVP: simplified layer norm
449        let zero_bias = candle_core::Tensor::zeros(
450            weight_candle.shape(),
451            weight_candle.dtype(),
452            weight_candle.device(),
453        )
454        .map_err(|e| ferrum_types::FerrumError::backend(e.to_string()))?;
455
456        let bias_tensor = if let Some(b) = _bias_candle {
457            b
458        } else {
459            &zero_bias
460        };
461
462        let normalized = candle_nn::ops::layer_norm(input_candle, weight_candle, bias_tensor, eps)
463            .map_err(|e| ferrum_types::FerrumError::backend(e.to_string()))?;
464        Ok(Arc::new(CandleTensor::new(normalized)?))
465    }
466
467    fn rms_norm(&self, input: &TensorRef, weight: &TensorRef, eps: f32) -> Result<TensorRef> {
468        let input_candle = get_candle_tensor(input)?;
469        let weight_candle = get_candle_tensor(weight)?;
470
471        let _rms = candle_nn::RmsNorm::new(weight_candle.clone(), eps as f64);
472        let result = candle_nn::ops::rms_norm(input_candle, weight_candle, eps)
473            .map_err(|e| ferrum_types::FerrumError::backend(e.to_string()))?;
474        Ok(Arc::new(CandleTensor::new(result)?))
475    }
476
477    fn relu(&self, tensor: &TensorRef) -> Result<TensorRef> {
478        let candle_tensor = get_candle_tensor(tensor)?;
479        let result = candle_tensor
480            .relu()
481            .map_err(|e| ferrum_types::FerrumError::backend(e.to_string()))?;
482        Ok(Arc::new(CandleTensor::new(result)?))
483    }
484
485    fn gelu(&self, tensor: &TensorRef) -> Result<TensorRef> {
486        let candle_tensor = get_candle_tensor(tensor)?;
487        let result = candle_tensor
488            .gelu()
489            .map_err(|e| ferrum_types::FerrumError::backend(e.to_string()))?;
490        Ok(Arc::new(CandleTensor::new(result)?))
491    }
492
493    fn silu(&self, tensor: &TensorRef) -> Result<TensorRef> {
494        let candle_tensor = get_candle_tensor(tensor)?;
495        let result = candle_nn::ops::silu(candle_tensor)
496            .map_err(|e| ferrum_types::FerrumError::backend(e.to_string()))?;
497        Ok(Arc::new(CandleTensor::new(result)?))
498    }
499
500    fn concat(&self, tensors: &[&TensorRef], dim: usize) -> Result<TensorRef> {
501        let candle_tensors: Result<Vec<_>> = tensors.iter().map(|t| get_candle_tensor(t)).collect();
502        let candle_tensors = candle_tensors?;
503
504        let result = candle_core::Tensor::cat(&candle_tensors, dim)
505            .map_err(|e| ferrum_types::FerrumError::backend(e.to_string()))?;
506        Ok(Arc::new(CandleTensor::new(result)?))
507    }
508
509    fn split(&self, tensor: &TensorRef, sizes: &[usize], dim: usize) -> Result<Vec<TensorRef>> {
510        let candle_tensor = get_candle_tensor(tensor)?;
511        let mut result = Vec::new();
512        let mut offset = 0;
513
514        for &size in sizes {
515            let chunk = candle_tensor
516                .narrow(dim, offset, size)
517                .map_err(|e| ferrum_types::FerrumError::backend(e.to_string()))?;
518            result.push(Arc::new(CandleTensor::new(chunk)?) as TensorRef);
519            offset += size;
520        }
521
522        Ok(result)
523    }
524
525    fn transpose(&self, tensor: &TensorRef, dim0: usize, dim1: usize) -> Result<TensorRef> {
526        let candle_tensor = get_candle_tensor(tensor)?;
527        let result = candle_tensor
528            .transpose(dim0, dim1)
529            .map_err(|e| ferrum_types::FerrumError::backend(e.to_string()))?;
530        Ok(Arc::new(CandleTensor::new(result)?))
531    }
532
533    fn permute(&self, tensor: &TensorRef, dims: &[usize]) -> Result<TensorRef> {
534        let candle_tensor = get_candle_tensor(tensor)?;
535        let result = candle_tensor
536            .permute(dims)
537            .map_err(|e| ferrum_types::FerrumError::backend(e.to_string()))?;
538        Ok(Arc::new(CandleTensor::new(result)?))
539    }
540}
541
542/// Candle backend
543pub struct CandleBackend {
544    device: Device,
545    tensor_factory: CandleTensorFactory,
546    tensor_ops: CandleTensorOps,
547    kernel_ops: super::candle_kernel_ops::CandleKernelOps,
548    memory_manager: MemoryPool,
549}
550
551impl std::fmt::Debug for CandleBackend {
552    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
553        f.debug_struct("CandleBackend")
554            .field("device", &self.device)
555            .finish()
556    }
557}
558
559impl CandleBackend {
560    pub async fn new(device: Device) -> Result<Self> {
561        debug!("Initializing Candle backend for: {:?}", device);
562
563        let tensor_factory = CandleTensorFactory::new(device.clone());
564        let tensor_ops = CandleTensorOps;
565        let kernel_ops = super::candle_kernel_ops::CandleKernelOps::new();
566        let memory_manager = MemoryPool::new(
567            device.clone(),
568            crate::memory::InternalMemoryPoolConfig {
569                initial_size: 1024 * 1024 * 1024, // 1GB
570                max_size: 4 * 1024 * 1024 * 1024, // 4GB
571                growth_factor: 1.5,
572                enable_defragmentation: true,
573                min_pooled_size: 256,
574                max_pooled_size: 1024 * 1024, // 1MB
575                size_buckets: 32,
576            },
577        );
578
579        Ok(Self {
580            device,
581            tensor_factory,
582            tensor_ops,
583            kernel_ops,
584            memory_manager,
585        })
586    }
587}
588
589#[async_trait]
590impl ComputeBackend for CandleBackend {
591    fn name(&self) -> &str {
592        "candle"
593    }
594
595    fn capabilities(&self) -> BackendCapabilities {
596        let supported_devices = {
597            #[cfg(all(feature = "metal", any(target_os = "macos", target_os = "ios")))]
598            {
599                vec![Device::CPU, Device::CUDA(0), Device::Metal]
600            }
601            #[cfg(not(all(feature = "metal", any(target_os = "macos", target_os = "ios"))))]
602            {
603                vec![Device::CPU, Device::CUDA(0)]
604            }
605        };
606
607        BackendCapabilities {
608            supported_dtypes: vec![DataType::FP32, DataType::FP16, DataType::BF16],
609            supported_devices,
610            max_tensor_dims: 8,
611            supports_fp16: true,
612            supports_bf16: true,
613            supports_int8: false,
614            supports_flash_attention: false,
615            supports_paged_attention: false,
616            supports_tensor_parallelism: false,
617            supports_pipeline_parallelism: false,
618            max_batch_size: 32,
619            max_sequence_length: 4096,
620            memory_alignment: 256,
621            supports_custom_kernels: false,
622            supports_cuda_graphs: false,
623            extra_capabilities: HashMap::new(),
624        }
625    }
626
627    fn tensor_ops(&self) -> &dyn TensorOps {
628        &self.tensor_ops
629    }
630
631    fn tensor_factory(&self) -> &dyn TensorFactory {
632        &self.tensor_factory
633    }
634
635    fn memory_manager(&self) -> &dyn DeviceMemoryManager {
636        &self.memory_manager
637    }
638
639    fn kernel_executor(&self) -> Option<&dyn KernelExecutor> {
640        None // MVP: no custom kernels
641    }
642
643    fn kernel_ops(&self) -> Option<&dyn KernelOps> {
644        Some(&self.kernel_ops)
645    }
646
647    async fn initialize(&mut self, _device: &Device) -> Result<()> {
648        // Already initialized in new()
649        Ok(())
650    }
651
652    fn supports_device(&self, device: &Device) -> bool {
653        match device {
654            Device::CPU | Device::CUDA(_) => true,
655            Device::ROCm(_) => false,
656            #[cfg(any(target_os = "macos", target_os = "ios"))]
657            Device::Metal => cfg!(feature = "metal"),
658        }
659    }
660
661    fn version(&self) -> String {
662        env!("CARGO_PKG_VERSION").to_string()
663    }
664
665    async fn synchronize(&self, _device: &Device) -> Result<()> {
666        // MVP: no-op for CPU, would need actual sync for GPU
667        Ok(())
668    }
669
670    fn status(&self) -> BackendStatus {
671        BackendStatus {
672            is_initialized: true,
673            is_ready: true,
674            active_devices: vec![self.device.clone()],
675            memory_usage: HashMap::new(),
676            operations_completed: 0,
677            last_error: None,
678            backend_specific: HashMap::new(),
679        }
680    }
681
682    async fn shutdown(&mut self) -> Result<()> {
683        debug!("Shutting down Candle backend");
684        Ok(())
685    }
686}
687
688// ============================================================================
689// Helper Functions
690// ============================================================================
691
692fn get_candle_tensor(tensor: &TensorRef) -> Result<&candle_core::Tensor> {
693    // MVP: use type_id check since as_any not in TensorLike trait yet
694    let concrete_ref: &CandleTensor = unsafe {
695        // This is safe if we always create tensors through this backend
696        &*(Arc::as_ptr(tensor) as *const CandleTensor)
697    };
698    Ok(&concrete_ref.inner)
699}
700
701fn ferrum_dtype_to_candle(dtype: DataType) -> Result<candle_core::DType> {
702    match dtype {
703        DataType::FP32 => Ok(candle_core::DType::F32),
704        DataType::FP16 => Ok(candle_core::DType::F16),
705        DataType::BF16 => Ok(candle_core::DType::BF16),
706        DataType::UINT32 => Ok(candle_core::DType::U32),
707        DataType::UINT8 => Ok(candle_core::DType::U8),
708        DataType::INT32 => Ok(candle_core::DType::U32), // Fallback
709        _ => Err(ferrum_types::FerrumError::backend(format!(
710            "Unsupported dtype: {:?}",
711            dtype
712        ))),
713    }
714}
715
716fn candle_dtype_to_ferrum(dtype: candle_core::DType) -> Result<DataType> {
717    match dtype {
718        candle_core::DType::F32 => Ok(DataType::FP32),
719        candle_core::DType::F16 => Ok(DataType::FP16),
720        candle_core::DType::BF16 => Ok(DataType::BF16),
721        candle_core::DType::U32 => Ok(DataType::UINT32),
722        candle_core::DType::U8 => Ok(DataType::UINT8),
723        _ => Err(ferrum_types::FerrumError::backend(format!(
724            "Unsupported Candle dtype: {:?}",
725            dtype
726        ))),
727    }
728}
729
730fn ferrum_device_to_candle(device: Device) -> Result<candle_core::Device> {
731    match device {
732        Device::CPU => Ok(candle_core::Device::Cpu),
733        Device::CUDA(id) => {
734            #[cfg(feature = "cuda")]
735            {
736                candle_core::Device::new_cuda(id as usize)
737                    .map_err(|e| ferrum_types::FerrumError::backend(e.to_string()))
738            }
739            #[cfg(not(feature = "cuda"))]
740            {
741                let _ = id;
742                Err(ferrum_types::FerrumError::unsupported("CUDA not enabled"))
743            }
744        }
745        #[cfg(any(target_os = "macos", target_os = "ios"))]
746        Device::Metal => {
747            #[cfg(feature = "metal")]
748            {
749                candle_core::Device::new_metal(0)
750                    .map_err(|e| ferrum_types::FerrumError::backend(e.to_string()))
751            }
752            #[cfg(not(feature = "metal"))]
753            {
754                Err(ferrum_types::FerrumError::unsupported("Metal not enabled"))
755            }
756        }
757        Device::ROCm(_) => Err(ferrum_types::FerrumError::unsupported("ROCm not supported")),
758    }
759}
760
761fn candle_device_to_ferrum(device: &candle_core::Device) -> Result<Device> {
762    match device {
763        candle_core::Device::Cpu => Ok(Device::CPU),
764        candle_core::Device::Cuda(_) => Ok(Device::CUDA(0)), // Default to GPU 0
765        candle_core::Device::Metal(_) => {
766            #[cfg(any(target_os = "macos", target_os = "ios"))]
767            {
768                Ok(Device::Metal)
769            }
770            #[cfg(not(any(target_os = "macos", target_os = "ios")))]
771            {
772                Err(ferrum_types::FerrumError::unsupported(
773                    "Metal devices are not available on this platform",
774                ))
775            }
776        }
777    }
778}
779
780// ============================================================================
781// Unit Tests
782// ============================================================================
783
#[cfg(test)]
mod tests {
    use super::*;

    /// Factory used by the synchronous tensor tests.
    fn cpu_factory() -> CandleTensorFactory {
        CandleTensorFactory::new(Device::CPU)
    }

    /// FP32 CPU tensor holding `data` laid out as `shape`.
    fn fp32_cpu(data: &[f32], shape: &[usize]) -> TensorRef {
        cpu_factory()
            .from_slice(data, shape, DataType::FP32, Device::CPU)
            .unwrap()
    }

    /// Zero-filled FP32 CPU tensor of `shape`.
    fn zeros_cpu(shape: &[usize]) -> TensorRef {
        cpu_factory()
            .zeros(shape, DataType::FP32, &Device::CPU)
            .unwrap()
    }

    /// CPU backend; panics if construction fails.
    async fn cpu_backend() -> CandleBackend {
        CandleBackend::new(Device::CPU).await.unwrap()
    }

    #[test]
    fn test_dtype_conversions() {
        // FP32 and FP16 must survive a Ferrum -> Candle -> Ferrum round trip.
        for dtype in [DataType::FP32, DataType::FP16] {
            let candle = ferrum_dtype_to_candle(dtype).unwrap();
            assert_eq!(candle_dtype_to_ferrum(candle).unwrap(), dtype);
        }
    }

    #[test]
    fn test_device_conversions_cpu() {
        let candle_device = ferrum_device_to_candle(Device::CPU).unwrap();
        assert_eq!(candle_device_to_ferrum(&candle_device).unwrap(), Device::CPU);
    }

    #[tokio::test]
    async fn test_candle_backend_creation() {
        assert!(CandleBackend::new(Device::CPU).await.is_ok());
    }

    #[tokio::test]
    async fn test_candle_backend_name() {
        assert_eq!(cpu_backend().await.name(), "candle");
    }

    #[tokio::test]
    async fn test_candle_backend_capabilities() {
        let caps = cpu_backend().await.capabilities();
        assert!(caps.supported_dtypes.contains(&DataType::FP32));
        assert!(caps.max_tensor_dims > 0);
    }

    #[tokio::test]
    async fn test_candle_backend_supports_cpu() {
        assert!(cpu_backend().await.supports_device(&Device::CPU));
    }

    #[test]
    fn test_tensor_factory_zeros() {
        let tensor = zeros_cpu(&[2, 3]);
        assert_eq!(tensor.shape(), &[2, 3]);
        assert_eq!(tensor.dtype(), DataType::FP32);
    }

    #[test]
    fn test_tensor_factory_ones() {
        let tensor = cpu_factory()
            .ones(&[2, 2], DataType::FP32, &Device::CPU)
            .unwrap();
        assert_eq!(tensor.shape(), &[2, 2]);
    }

    #[test]
    fn test_tensor_ops_add() {
        let ops = CandleTensorOps;
        let lhs = fp32_cpu(&[1.0, 2.0], &[2]);
        let rhs = fp32_cpu(&[3.0, 4.0], &[2]);

        let data = ops.add(&lhs, &rhs).unwrap().to_vec_f32().unwrap();

        assert!((data[0] - 4.0).abs() < 1e-5);
        assert!((data[1] - 6.0).abs() < 1e-5);
    }

    #[test]
    fn test_tensor_ops_matmul() {
        let ops = CandleTensorOps;
        // A 2x2 matrix times the 2x2 identity.
        let lhs = fp32_cpu(&[1.0, 2.0, 3.0, 4.0], &[2, 2]);
        let identity = fp32_cpu(&[1.0, 0.0, 0.0, 1.0], &[2, 2]);

        let product = ops.matmul(&lhs, &identity).unwrap();
        assert_eq!(product.shape(), &[2, 2]);
    }

    #[test]
    fn test_tensor_reshape() {
        let reshaped = zeros_cpu(&[2, 3]).reshape(&[3, 2]).unwrap();
        assert_eq!(reshaped.shape(), &[3, 2]);
    }

    #[test]
    fn test_tensor_to_cpu() {
        let moved = zeros_cpu(&[2, 3]).to_cpu().unwrap();
        assert_eq!(moved.device(), Device::CPU);
    }

    #[test]
    fn test_tensor_to_vec_u32_from_fp32_ids() {
        let ids = fp32_cpu(&[1.0, 2.0, 3.0], &[1, 3]);
        assert_eq!(ids.to_vec_u32().unwrap(), vec![1, 2, 3]);
    }
}