1use crate::{
6 ComputeBackend, DeviceMemoryManager, MemoryPool, TensorFactory, TensorLike, TensorOps,
7 TensorRef,
8};
9use async_trait::async_trait;
10use ferrum_interfaces::backend::{BackendCapabilities, BackendStatus, KernelExecutor};
11use ferrum_interfaces::kernel_ops::KernelOps;
12use ferrum_types::{DataType, Device, Result};
13use std::any::Any;
14use std::collections::HashMap;
15use std::sync::Arc;
16use tracing::debug;
17
18pub struct CandleTensor {
20 inner: candle_core::Tensor,
21 device: Device,
22 dtype: DataType,
23}
24
25impl CandleTensor {
26 pub fn new(tensor: candle_core::Tensor) -> Result<Self> {
27 let device = candle_device_to_ferrum(tensor.device())?;
28 let dtype = candle_dtype_to_ferrum(tensor.dtype())?;
29
30 Ok(Self {
31 inner: tensor,
32 device,
33 dtype,
34 })
35 }
36
37 pub fn inner(&self) -> &candle_core::Tensor {
38 &self.inner
39 }
40}
41
42impl std::fmt::Debug for CandleTensor {
43 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
44 f.debug_struct("CandleTensor")
45 .field("shape", &self.inner.dims())
46 .field("dtype", &self.dtype)
47 .field("device", &self.device)
48 .finish()
49 }
50}
51
52impl TensorLike for CandleTensor {
53 fn as_any(&self) -> &dyn Any {
54 self
55 }
56
57 fn shape(&self) -> &[usize] {
58 self.inner.dims()
59 }
60
61 fn dtype(&self) -> DataType {
62 self.dtype
63 }
64
65 fn device(&self) -> Device {
66 self.device.clone()
67 }
68
69 fn to_device(&self, device: &Device) -> Result<TensorRef> {
70 let candle_device = ferrum_device_to_candle(device.clone())?;
71 let moved = self
72 .inner
73 .to_device(&candle_device)
74 .map_err(|e| ferrum_types::FerrumError::backend(format!("Device transfer: {}", e)))?;
75 Ok(Arc::new(Self::new(moved)?))
76 }
77
78 fn to_dtype(&self, dtype: DataType) -> Result<TensorRef> {
79 let candle_dtype = ferrum_dtype_to_candle(dtype)?;
80 let converted = self
81 .inner
82 .to_dtype(candle_dtype)
83 .map_err(|e| ferrum_types::FerrumError::backend(format!("DType conversion: {}", e)))?;
84 Ok(Arc::new(Self::new(converted)?))
85 }
86
87 fn to_vec_f32(&self) -> Result<Vec<f32>> {
88 match self.inner.dims().len() {
90 1 => self
91 .inner
92 .to_vec1::<f32>()
93 .map_err(|e| ferrum_types::FerrumError::backend(format!("to_vec1 failed: {}", e))),
94 2 => {
95 let batch = self.inner.to_vec2::<f32>().map_err(|e| {
96 ferrum_types::FerrumError::backend(format!("to_vec2 failed: {}", e))
97 })?;
98 Ok(batch.into_iter().next().unwrap_or_default())
99 }
100 3 => {
101 let all = self.inner.to_vec3::<f32>().map_err(|e| {
102 ferrum_types::FerrumError::backend(format!("to_vec3 failed: {}", e))
103 })?;
104 Ok(all
105 .into_iter()
106 .next()
107 .and_then(|seq| seq.into_iter().last())
108 .unwrap_or_default())
109 }
110 _ => Err(ferrum_types::FerrumError::backend(format!(
111 "Unsupported tensor dimensions: {:?}",
112 self.inner.dims()
113 ))),
114 }
115 }
116
117 fn to_vec_u32(&self) -> Result<Vec<u32>> {
118 let cpu_tensor = self
119 .inner
120 .to_device(&candle_core::Device::Cpu)
121 .map_err(|e| ferrum_types::FerrumError::backend(format!("to_cpu failed: {}", e)))?;
122
123 match cpu_tensor.dims().len() {
124 1 => match cpu_tensor.to_vec1::<u32>() {
125 Ok(tokens) => Ok(tokens),
126 Err(_) => cpu_tensor
127 .to_vec1::<f32>()
128 .map(|tokens| tokens.into_iter().map(|x| x as u32).collect())
129 .map_err(|e| {
130 ferrum_types::FerrumError::backend(format!(
131 "to_vec1<u32/f32> failed: {}",
132 e
133 ))
134 }),
135 },
136 2 => match cpu_tensor.to_vec2::<u32>() {
137 Ok(batch) => Ok(batch.into_iter().next().unwrap_or_default()),
138 Err(_) => cpu_tensor
139 .to_vec2::<f32>()
140 .map(|batch| {
141 batch
142 .into_iter()
143 .next()
144 .unwrap_or_default()
145 .into_iter()
146 .map(|x| x as u32)
147 .collect()
148 })
149 .map_err(|e| {
150 ferrum_types::FerrumError::backend(format!(
151 "to_vec2<u32/f32> failed: {}",
152 e
153 ))
154 }),
155 },
156 _ => Err(ferrum_types::FerrumError::backend(format!(
157 "Unsupported tensor dimensions for token extraction: {:?}",
158 cpu_tensor.dims()
159 ))),
160 }
161 }
162
163 fn reshape(&self, shape: &[usize]) -> Result<TensorRef> {
164 let reshaped = self
165 .inner
166 .reshape(shape)
167 .map_err(|e| ferrum_types::FerrumError::backend(format!("Reshape: {}", e)))?;
168 Ok(Arc::new(Self::new(reshaped)?))
169 }
170
171 fn to_cpu(&self) -> Result<TensorRef> {
172 self.to_device(&Device::CPU)
173 }
174
175 fn view(&self, _start: &[usize], _end: &[usize]) -> Result<TensorRef> {
176 Ok(Arc::new(Self {
178 inner: self.inner.clone(),
179 device: self.device.clone(),
180 dtype: self.dtype,
181 }))
182 }
183
184 fn is_contiguous(&self) -> bool {
185 self.inner.is_contiguous()
186 }
187
188 fn argmax_last_dim_u32(&self) -> Result<u32> {
189 use candle_core::{IndexOp, D};
195
196 let dims = self.inner.dims();
197 let logits_1d = match dims.len() {
198 1 => self.inner.clone(),
199 2 => {
200 self.inner.i(0).map_err(|e| {
202 ferrum_types::FerrumError::backend(format!("Index batch failed: {}", e))
203 })?
204 }
205 3 => {
206 let seq_len = dims[1];
208 self.inner.i((0, seq_len.saturating_sub(1))).map_err(|e| {
209 ferrum_types::FerrumError::backend(format!("Index last token failed: {}", e))
210 })?
211 }
212 _ => {
213 return Err(ferrum_types::FerrumError::backend(format!(
214 "argmax_last_dim_u32 unsupported dims: {:?}",
215 dims
216 )))
217 }
218 };
219
220 let idx = logits_1d
222 .argmax(D::Minus1)
223 .map_err(|e| ferrum_types::FerrumError::backend(format!("Argmax failed: {}", e)))?
224 .to_device(&candle_core::Device::Cpu)
225 .map_err(|e| {
226 ferrum_types::FerrumError::backend(format!("Argmax to CPU failed: {}", e))
227 })?
228 .to_vec0::<u32>()
229 .map_err(|e| {
230 ferrum_types::FerrumError::backend(format!("Argmax readback failed: {}", e))
231 })?;
232
233 Ok(idx)
234 }
235}
236
237pub struct CandleTensorFactory {
239 device: Device,
240}
241
242impl CandleTensorFactory {
243 pub fn new(device: Device) -> Self {
244 Self { device }
245 }
246}
247
248impl std::fmt::Debug for CandleTensorFactory {
249 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
250 f.debug_struct("CandleTensorFactory")
251 .field("device", &self.device)
252 .finish()
253 }
254}
255
256impl TensorFactory for CandleTensorFactory {
257 fn empty(&self, shape: &[usize], dtype: DataType, device: Device) -> Result<TensorRef> {
258 let candle_device = ferrum_device_to_candle(device)?;
259 let candle_dtype = ferrum_dtype_to_candle(dtype)?;
260
261 let tensor = candle_core::Tensor::zeros(shape, candle_dtype, &candle_device)
262 .map_err(|e| ferrum_types::FerrumError::backend(e.to_string()))?;
263 Ok(Arc::new(CandleTensor::new(tensor)?))
264 }
265
266 fn from_slice(
267 &self,
268 data: &[f32],
269 shape: &[usize],
270 dtype: DataType,
271 device: Device,
272 ) -> Result<TensorRef> {
273 let candle_device = ferrum_device_to_candle(device)?;
274 let candle_dtype = ferrum_dtype_to_candle(dtype)?;
275
276 let tensor = candle_core::Tensor::from_slice(data, shape, &candle_device)
277 .and_then(|t| t.to_dtype(candle_dtype))
278 .map_err(|e| ferrum_types::FerrumError::backend(e.to_string()))?;
279
280 Ok(Arc::new(CandleTensor::new(tensor)?))
281 }
282
283 fn to_device(&self, tensor: &TensorRef, device: Device) -> Result<TensorRef> {
284 tensor.to_device(&device)
285 }
286
287 fn narrow(
288 &self,
289 tensor: &TensorRef,
290 dim: usize,
291 start: usize,
292 length: usize,
293 ) -> Result<TensorRef> {
294 let candle_tensor = get_candle_tensor(tensor)?;
295 let narrowed = candle_tensor
296 .narrow(dim, start, length)
297 .map_err(|e| ferrum_types::FerrumError::backend(e.to_string()))?;
298 Ok(Arc::new(CandleTensor::new(narrowed)?))
299 }
300
301 fn reshape(&self, tensor: &TensorRef, shape: &[usize]) -> Result<TensorRef> {
302 tensor.reshape(shape)
303 }
304
305 fn zeros_like(&self, tensor: &TensorRef) -> Result<TensorRef> {
306 let candle_tensor = get_candle_tensor(tensor)?;
307 let zeros = candle_core::Tensor::zeros(
308 candle_tensor.shape(),
309 candle_tensor.dtype(),
310 candle_tensor.device(),
311 )
312 .map_err(|e| ferrum_types::FerrumError::backend(e.to_string()))?;
313 Ok(Arc::new(CandleTensor::new(zeros)?))
314 }
315
316 fn zeros(&self, shape: &[usize], dtype: DataType, device: &Device) -> Result<TensorRef> {
317 let candle_device = ferrum_device_to_candle(device.clone())?;
318 let candle_dtype = ferrum_dtype_to_candle(dtype)?;
319
320 let tensor = candle_core::Tensor::zeros(shape, candle_dtype, &candle_device)
321 .map_err(|e| ferrum_types::FerrumError::backend(e.to_string()))?;
322 Ok(Arc::new(CandleTensor::new(tensor)?))
323 }
324
325 fn ones(&self, shape: &[usize], dtype: DataType, device: &Device) -> Result<TensorRef> {
326 let candle_device = ferrum_device_to_candle(device.clone())?;
327 let candle_dtype = ferrum_dtype_to_candle(dtype)?;
328
329 let tensor = candle_core::Tensor::ones(shape, candle_dtype, &candle_device)
330 .map_err(|e| ferrum_types::FerrumError::backend(e.to_string()))?;
331 Ok(Arc::new(CandleTensor::new(tensor)?))
332 }
333
334 fn uniform(
335 &self,
336 shape: &[usize],
337 low: f32,
338 high: f32,
339 dtype: DataType,
340 device: &Device,
341 ) -> Result<TensorRef> {
342 let candle_device = ferrum_device_to_candle(device.clone())?;
343 let candle_dtype = ferrum_dtype_to_candle(dtype)?;
344
345 let tensor = candle_core::Tensor::rand(low, high, shape, &candle_device)
346 .and_then(|t| t.to_dtype(candle_dtype))
347 .map_err(|e| ferrum_types::FerrumError::backend(e.to_string()))?;
348 Ok(Arc::new(CandleTensor::new(tensor)?))
349 }
350
351 fn normal(
352 &self,
353 shape: &[usize],
354 mean: f32,
355 std: f32,
356 dtype: DataType,
357 device: &Device,
358 ) -> Result<TensorRef> {
359 let candle_device = ferrum_device_to_candle(device.clone())?;
360 let candle_dtype = ferrum_dtype_to_candle(dtype)?;
361
362 let tensor = candle_core::Tensor::randn(mean, std, shape, &candle_device)
363 .and_then(|t| t.to_dtype(candle_dtype))
364 .map_err(|e| ferrum_types::FerrumError::backend(e.to_string()))?;
365 Ok(Arc::new(CandleTensor::new(tensor)?))
366 }
367
368 fn from_tensor(&self, tensor: &TensorRef, device: &Device) -> Result<TensorRef> {
369 tensor.to_device(device)
370 }
371}
372
/// Stateless implementation of the tensor operations on top of Candle.
#[derive(Debug, Clone, Default)]
pub struct CandleTensorOps;
376
377impl TensorOps for CandleTensorOps {
378 fn matmul(&self, a: &TensorRef, b: &TensorRef) -> Result<TensorRef> {
379 let a_candle = get_candle_tensor(a)?;
380 let b_candle = get_candle_tensor(b)?;
381
382 let result = a_candle
383 .matmul(b_candle)
384 .map_err(|e| ferrum_types::FerrumError::backend(e.to_string()))?;
385 Ok(Arc::new(CandleTensor::new(result)?))
386 }
387
388 fn add(&self, a: &TensorRef, b: &TensorRef) -> Result<TensorRef> {
389 let a_candle = get_candle_tensor(a)?;
390 let b_candle = get_candle_tensor(b)?;
391
392 let result =
393 (a_candle + b_candle).map_err(|e| ferrum_types::FerrumError::backend(e.to_string()))?;
394 Ok(Arc::new(CandleTensor::new(result)?))
395 }
396
397 fn mul(&self, a: &TensorRef, b: &TensorRef) -> Result<TensorRef> {
398 let a_candle = get_candle_tensor(a)?;
399 let b_candle = get_candle_tensor(b)?;
400
401 let result =
402 (a_candle * b_candle).map_err(|e| ferrum_types::FerrumError::backend(e.to_string()))?;
403 Ok(Arc::new(CandleTensor::new(result)?))
404 }
405
406 fn sub(&self, a: &TensorRef, b: &TensorRef) -> Result<TensorRef> {
407 let a_candle = get_candle_tensor(a)?;
408 let b_candle = get_candle_tensor(b)?;
409
410 let result =
411 (a_candle - b_candle).map_err(|e| ferrum_types::FerrumError::backend(e.to_string()))?;
412 Ok(Arc::new(CandleTensor::new(result)?))
413 }
414
415 fn div(&self, a: &TensorRef, b: &TensorRef) -> Result<TensorRef> {
416 let a_candle = get_candle_tensor(a)?;
417 let b_candle = get_candle_tensor(b)?;
418
419 let result =
420 (a_candle / b_candle).map_err(|e| ferrum_types::FerrumError::backend(e.to_string()))?;
421 Ok(Arc::new(CandleTensor::new(result)?))
422 }
423
424 fn softmax(&self, tensor: &TensorRef, dim: i32) -> Result<TensorRef> {
425 let candle_tensor = get_candle_tensor(tensor)?;
426 let dim_usize = if dim < 0 {
427 (candle_tensor.rank() as i32 + dim) as usize
428 } else {
429 dim as usize
430 };
431
432 let result = candle_nn::ops::softmax(candle_tensor, dim_usize)
433 .map_err(|e| ferrum_types::FerrumError::backend(e.to_string()))?;
434 Ok(Arc::new(CandleTensor::new(result)?))
435 }
436
437 fn layer_norm(
438 &self,
439 input: &TensorRef,
440 weight: &TensorRef,
441 bias: Option<&TensorRef>,
442 eps: f32,
443 ) -> Result<TensorRef> {
444 let input_candle = get_candle_tensor(input)?;
445 let weight_candle = get_candle_tensor(weight)?;
446 let _bias_candle = bias.map(|b| get_candle_tensor(b)).transpose()?;
447
448 let zero_bias = candle_core::Tensor::zeros(
450 weight_candle.shape(),
451 weight_candle.dtype(),
452 weight_candle.device(),
453 )
454 .map_err(|e| ferrum_types::FerrumError::backend(e.to_string()))?;
455
456 let bias_tensor = if let Some(b) = _bias_candle {
457 b
458 } else {
459 &zero_bias
460 };
461
462 let normalized = candle_nn::ops::layer_norm(input_candle, weight_candle, bias_tensor, eps)
463 .map_err(|e| ferrum_types::FerrumError::backend(e.to_string()))?;
464 Ok(Arc::new(CandleTensor::new(normalized)?))
465 }
466
467 fn rms_norm(&self, input: &TensorRef, weight: &TensorRef, eps: f32) -> Result<TensorRef> {
468 let input_candle = get_candle_tensor(input)?;
469 let weight_candle = get_candle_tensor(weight)?;
470
471 let _rms = candle_nn::RmsNorm::new(weight_candle.clone(), eps as f64);
472 let result = candle_nn::ops::rms_norm(input_candle, weight_candle, eps)
473 .map_err(|e| ferrum_types::FerrumError::backend(e.to_string()))?;
474 Ok(Arc::new(CandleTensor::new(result)?))
475 }
476
477 fn relu(&self, tensor: &TensorRef) -> Result<TensorRef> {
478 let candle_tensor = get_candle_tensor(tensor)?;
479 let result = candle_tensor
480 .relu()
481 .map_err(|e| ferrum_types::FerrumError::backend(e.to_string()))?;
482 Ok(Arc::new(CandleTensor::new(result)?))
483 }
484
485 fn gelu(&self, tensor: &TensorRef) -> Result<TensorRef> {
486 let candle_tensor = get_candle_tensor(tensor)?;
487 let result = candle_tensor
488 .gelu()
489 .map_err(|e| ferrum_types::FerrumError::backend(e.to_string()))?;
490 Ok(Arc::new(CandleTensor::new(result)?))
491 }
492
493 fn silu(&self, tensor: &TensorRef) -> Result<TensorRef> {
494 let candle_tensor = get_candle_tensor(tensor)?;
495 let result = candle_nn::ops::silu(candle_tensor)
496 .map_err(|e| ferrum_types::FerrumError::backend(e.to_string()))?;
497 Ok(Arc::new(CandleTensor::new(result)?))
498 }
499
500 fn concat(&self, tensors: &[&TensorRef], dim: usize) -> Result<TensorRef> {
501 let candle_tensors: Result<Vec<_>> = tensors.iter().map(|t| get_candle_tensor(t)).collect();
502 let candle_tensors = candle_tensors?;
503
504 let result = candle_core::Tensor::cat(&candle_tensors, dim)
505 .map_err(|e| ferrum_types::FerrumError::backend(e.to_string()))?;
506 Ok(Arc::new(CandleTensor::new(result)?))
507 }
508
509 fn split(&self, tensor: &TensorRef, sizes: &[usize], dim: usize) -> Result<Vec<TensorRef>> {
510 let candle_tensor = get_candle_tensor(tensor)?;
511 let mut result = Vec::new();
512 let mut offset = 0;
513
514 for &size in sizes {
515 let chunk = candle_tensor
516 .narrow(dim, offset, size)
517 .map_err(|e| ferrum_types::FerrumError::backend(e.to_string()))?;
518 result.push(Arc::new(CandleTensor::new(chunk)?) as TensorRef);
519 offset += size;
520 }
521
522 Ok(result)
523 }
524
525 fn transpose(&self, tensor: &TensorRef, dim0: usize, dim1: usize) -> Result<TensorRef> {
526 let candle_tensor = get_candle_tensor(tensor)?;
527 let result = candle_tensor
528 .transpose(dim0, dim1)
529 .map_err(|e| ferrum_types::FerrumError::backend(e.to_string()))?;
530 Ok(Arc::new(CandleTensor::new(result)?))
531 }
532
533 fn permute(&self, tensor: &TensorRef, dims: &[usize]) -> Result<TensorRef> {
534 let candle_tensor = get_candle_tensor(tensor)?;
535 let result = candle_tensor
536 .permute(dims)
537 .map_err(|e| ferrum_types::FerrumError::backend(e.to_string()))?;
538 Ok(Arc::new(CandleTensor::new(result)?))
539 }
540}
541
542pub struct CandleBackend {
544 device: Device,
545 tensor_factory: CandleTensorFactory,
546 tensor_ops: CandleTensorOps,
547 kernel_ops: super::candle_kernel_ops::CandleKernelOps,
548 memory_manager: MemoryPool,
549}
550
551impl std::fmt::Debug for CandleBackend {
552 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
553 f.debug_struct("CandleBackend")
554 .field("device", &self.device)
555 .finish()
556 }
557}
558
559impl CandleBackend {
560 pub async fn new(device: Device) -> Result<Self> {
561 debug!("Initializing Candle backend for: {:?}", device);
562
563 let tensor_factory = CandleTensorFactory::new(device.clone());
564 let tensor_ops = CandleTensorOps;
565 let kernel_ops = super::candle_kernel_ops::CandleKernelOps::new();
566 let memory_manager = MemoryPool::new(
567 device.clone(),
568 crate::memory::InternalMemoryPoolConfig {
569 initial_size: 1024 * 1024 * 1024, max_size: 4 * 1024 * 1024 * 1024, growth_factor: 1.5,
572 enable_defragmentation: true,
573 min_pooled_size: 256,
574 max_pooled_size: 1024 * 1024, size_buckets: 32,
576 },
577 );
578
579 Ok(Self {
580 device,
581 tensor_factory,
582 tensor_ops,
583 kernel_ops,
584 memory_manager,
585 })
586 }
587}
588
589#[async_trait]
590impl ComputeBackend for CandleBackend {
591 fn name(&self) -> &str {
592 "candle"
593 }
594
595 fn capabilities(&self) -> BackendCapabilities {
596 let supported_devices = {
597 #[cfg(all(feature = "metal", any(target_os = "macos", target_os = "ios")))]
598 {
599 vec![Device::CPU, Device::CUDA(0), Device::Metal]
600 }
601 #[cfg(not(all(feature = "metal", any(target_os = "macos", target_os = "ios"))))]
602 {
603 vec![Device::CPU, Device::CUDA(0)]
604 }
605 };
606
607 BackendCapabilities {
608 supported_dtypes: vec![DataType::FP32, DataType::FP16, DataType::BF16],
609 supported_devices,
610 max_tensor_dims: 8,
611 supports_fp16: true,
612 supports_bf16: true,
613 supports_int8: false,
614 supports_flash_attention: false,
615 supports_paged_attention: false,
616 supports_tensor_parallelism: false,
617 supports_pipeline_parallelism: false,
618 max_batch_size: 32,
619 max_sequence_length: 4096,
620 memory_alignment: 256,
621 supports_custom_kernels: false,
622 supports_cuda_graphs: false,
623 extra_capabilities: HashMap::new(),
624 }
625 }
626
627 fn tensor_ops(&self) -> &dyn TensorOps {
628 &self.tensor_ops
629 }
630
631 fn tensor_factory(&self) -> &dyn TensorFactory {
632 &self.tensor_factory
633 }
634
635 fn memory_manager(&self) -> &dyn DeviceMemoryManager {
636 &self.memory_manager
637 }
638
639 fn kernel_executor(&self) -> Option<&dyn KernelExecutor> {
640 None }
642
643 fn kernel_ops(&self) -> Option<&dyn KernelOps> {
644 Some(&self.kernel_ops)
645 }
646
647 async fn initialize(&mut self, _device: &Device) -> Result<()> {
648 Ok(())
650 }
651
652 fn supports_device(&self, device: &Device) -> bool {
653 match device {
654 Device::CPU | Device::CUDA(_) => true,
655 Device::ROCm(_) => false,
656 #[cfg(any(target_os = "macos", target_os = "ios"))]
657 Device::Metal => cfg!(feature = "metal"),
658 }
659 }
660
661 fn version(&self) -> String {
662 env!("CARGO_PKG_VERSION").to_string()
663 }
664
665 async fn synchronize(&self, _device: &Device) -> Result<()> {
666 Ok(())
668 }
669
670 fn status(&self) -> BackendStatus {
671 BackendStatus {
672 is_initialized: true,
673 is_ready: true,
674 active_devices: vec![self.device.clone()],
675 memory_usage: HashMap::new(),
676 operations_completed: 0,
677 last_error: None,
678 backend_specific: HashMap::new(),
679 }
680 }
681
682 async fn shutdown(&mut self) -> Result<()> {
683 debug!("Shutting down Candle backend");
684 Ok(())
685 }
686}
687
688fn get_candle_tensor(tensor: &TensorRef) -> Result<&candle_core::Tensor> {
693 let concrete_ref: &CandleTensor = unsafe {
695 &*(Arc::as_ptr(tensor) as *const CandleTensor)
697 };
698 Ok(&concrete_ref.inner)
699}
700
701fn ferrum_dtype_to_candle(dtype: DataType) -> Result<candle_core::DType> {
702 match dtype {
703 DataType::FP32 => Ok(candle_core::DType::F32),
704 DataType::FP16 => Ok(candle_core::DType::F16),
705 DataType::BF16 => Ok(candle_core::DType::BF16),
706 DataType::UINT32 => Ok(candle_core::DType::U32),
707 DataType::UINT8 => Ok(candle_core::DType::U8),
708 DataType::INT32 => Ok(candle_core::DType::U32), _ => Err(ferrum_types::FerrumError::backend(format!(
710 "Unsupported dtype: {:?}",
711 dtype
712 ))),
713 }
714}
715
716fn candle_dtype_to_ferrum(dtype: candle_core::DType) -> Result<DataType> {
717 match dtype {
718 candle_core::DType::F32 => Ok(DataType::FP32),
719 candle_core::DType::F16 => Ok(DataType::FP16),
720 candle_core::DType::BF16 => Ok(DataType::BF16),
721 candle_core::DType::U32 => Ok(DataType::UINT32),
722 candle_core::DType::U8 => Ok(DataType::UINT8),
723 _ => Err(ferrum_types::FerrumError::backend(format!(
724 "Unsupported Candle dtype: {:?}",
725 dtype
726 ))),
727 }
728}
729
730fn ferrum_device_to_candle(device: Device) -> Result<candle_core::Device> {
731 match device {
732 Device::CPU => Ok(candle_core::Device::Cpu),
733 Device::CUDA(id) => {
734 #[cfg(feature = "cuda")]
735 {
736 candle_core::Device::new_cuda(id as usize)
737 .map_err(|e| ferrum_types::FerrumError::backend(e.to_string()))
738 }
739 #[cfg(not(feature = "cuda"))]
740 {
741 let _ = id;
742 Err(ferrum_types::FerrumError::unsupported("CUDA not enabled"))
743 }
744 }
745 #[cfg(any(target_os = "macos", target_os = "ios"))]
746 Device::Metal => {
747 #[cfg(feature = "metal")]
748 {
749 candle_core::Device::new_metal(0)
750 .map_err(|e| ferrum_types::FerrumError::backend(e.to_string()))
751 }
752 #[cfg(not(feature = "metal"))]
753 {
754 Err(ferrum_types::FerrumError::unsupported("Metal not enabled"))
755 }
756 }
757 Device::ROCm(_) => Err(ferrum_types::FerrumError::unsupported("ROCm not supported")),
758 }
759}
760
761fn candle_device_to_ferrum(device: &candle_core::Device) -> Result<Device> {
762 match device {
763 candle_core::Device::Cpu => Ok(Device::CPU),
764 candle_core::Device::Cuda(_) => Ok(Device::CUDA(0)), candle_core::Device::Metal(_) => {
766 #[cfg(any(target_os = "macos", target_os = "ios"))]
767 {
768 Ok(Device::Metal)
769 }
770 #[cfg(not(any(target_os = "macos", target_os = "ios")))]
771 {
772 Err(ferrum_types::FerrumError::unsupported(
773 "Metal devices are not available on this platform",
774 ))
775 }
776 }
777 }
778}
779
#[cfg(test)]
mod tests {
    use super::*;

    /// Tensor factory bound to the CPU device.
    fn cpu_factory() -> CandleTensorFactory {
        CandleTensorFactory::new(Device::CPU)
    }

    /// FP32 CPU tensor holding `values` laid out as `shape`.
    fn fp32_tensor(values: &[f32], shape: &[usize]) -> TensorRef {
        cpu_factory()
            .from_slice(values, shape, DataType::FP32, Device::CPU)
            .unwrap()
    }

    /// Zero-filled FP32 CPU tensor of shape `[2, 3]`.
    fn zeros_2x3() -> TensorRef {
        cpu_factory()
            .zeros(&[2, 3], DataType::FP32, &Device::CPU)
            .unwrap()
    }

    /// Candle backend created for the CPU device.
    async fn cpu_backend() -> CandleBackend {
        CandleBackend::new(Device::CPU).await.unwrap()
    }

    #[test]
    fn test_dtype_conversions() {
        // Each dtype must survive a Ferrum -> Candle -> Ferrum round trip.
        for &dtype in &[DataType::FP32, DataType::FP16] {
            let candle_dtype = ferrum_dtype_to_candle(dtype).unwrap();
            let round_tripped = candle_dtype_to_ferrum(candle_dtype).unwrap();
            assert_eq!(round_tripped, dtype);
        }
    }

    #[test]
    fn test_device_conversions_cpu() {
        let candle_device = ferrum_device_to_candle(Device::CPU).unwrap();
        let round_tripped = candle_device_to_ferrum(&candle_device).unwrap();
        assert_eq!(round_tripped, Device::CPU);
    }

    #[tokio::test]
    async fn test_candle_backend_creation() {
        let created = CandleBackend::new(Device::CPU).await;
        assert!(created.is_ok());
    }

    #[tokio::test]
    async fn test_candle_backend_name() {
        assert_eq!(cpu_backend().await.name(), "candle");
    }

    #[tokio::test]
    async fn test_candle_backend_capabilities() {
        let caps = cpu_backend().await.capabilities();

        assert!(caps.supported_dtypes.contains(&DataType::FP32));
        assert!(caps.max_tensor_dims > 0);
    }

    #[tokio::test]
    async fn test_candle_backend_supports_cpu() {
        assert!(cpu_backend().await.supports_device(&Device::CPU));
    }

    #[test]
    fn test_tensor_factory_zeros() {
        let tensor = zeros_2x3();

        assert_eq!(tensor.shape(), &[2, 3]);
        assert_eq!(tensor.dtype(), DataType::FP32);
    }

    #[test]
    fn test_tensor_factory_ones() {
        let tensor = cpu_factory()
            .ones(&[2, 2], DataType::FP32, &Device::CPU)
            .unwrap();

        assert_eq!(tensor.shape(), &[2, 2]);
    }

    #[test]
    fn test_tensor_ops_add() {
        let lhs = fp32_tensor(&[1.0, 2.0], &[2]);
        let rhs = fp32_tensor(&[3.0, 4.0], &[2]);

        let sum = CandleTensorOps.add(&lhs, &rhs).unwrap();
        let values = sum.to_vec_f32().unwrap();

        assert!((values[0] - 4.0).abs() < 1e-5);
        assert!((values[1] - 6.0).abs() < 1e-5);
    }

    #[test]
    fn test_tensor_ops_matmul() {
        let matrix = fp32_tensor(&[1.0, 2.0, 3.0, 4.0], &[2, 2]);
        let identity = fp32_tensor(&[1.0, 0.0, 0.0, 1.0], &[2, 2]);

        let product = CandleTensorOps.matmul(&matrix, &identity).unwrap();
        assert_eq!(product.shape(), &[2, 2]);
    }

    #[test]
    fn test_tensor_reshape() {
        let reshaped = zeros_2x3().reshape(&[3, 2]).unwrap();
        assert_eq!(reshaped.shape(), &[3, 2]);
    }

    #[test]
    fn test_tensor_to_cpu() {
        let on_cpu = zeros_2x3().to_cpu().unwrap();
        assert_eq!(on_cpu.device(), Device::CPU);
    }

    #[test]
    fn test_tensor_to_vec_u32_from_fp32_ids() {
        // FP32-encoded ids in a [1, 3] tensor are read back as u32 tokens.
        let ids = fp32_tensor(&[1.0, 2.0, 3.0], &[1, 3]);

        let tokens = ids.to_vec_u32().unwrap();
        assert_eq!(tokens, vec![1, 2, 3]);
    }
}