Skip to main content

ferrum_runtime/backends/
candle_kernel_ops.rs

1//! Candle reference implementation of the `KernelOps` sub-traits.
2//!
3//! All operations are implemented using pure candle tensor ops (CPU/Metal/CUDA
4//! via candle's own dispatch). This serves as the baseline that any
5//! hardware-specific backend must match.
6
7use candle_core::IndexOp;
8use ferrum_interfaces::kernel_ops::{
9    ActivationOps, AttentionOps, AttentionParams, KernelOps, LinearOps, NormOps, PositionOps,
10    SamplingOps, SamplingParams,
11};
12use ferrum_interfaces::TensorRef;
13use ferrum_types::{FerrumError, Result};
14use std::sync::Arc;
15
16use super::candle::CandleTensor;
17#[cfg(test)]
18use super::candle::CandleTensorOps;
19
20// ---------------------------------------------------------------------------
21// Helpers
22// ---------------------------------------------------------------------------
23
24fn ct(tensor: &TensorRef) -> Result<&candle_core::Tensor> {
25    let concrete: &CandleTensor = unsafe { &*(Arc::as_ptr(tensor) as *const CandleTensor) };
26    Ok(concrete.inner())
27}
28
29fn wrap(tensor: candle_core::Tensor) -> Result<TensorRef> {
30    Ok(Arc::new(CandleTensor::new(tensor)?) as TensorRef)
31}
32
33fn err(msg: impl std::fmt::Display) -> FerrumError {
34    FerrumError::backend(msg.to_string())
35}
36
37// ---------------------------------------------------------------------------
38// NormOps
39// ---------------------------------------------------------------------------
40
41pub struct CandleNormOps;
42
43impl NormOps for CandleNormOps {
44    fn rms_norm(&self, input: &TensorRef, weight: &TensorRef, eps: f32) -> Result<TensorRef> {
45        let input = ct(input)?;
46        let weight = ct(weight)?;
47        let result = candle_nn::ops::rms_norm(input, weight, eps).map_err(err)?;
48        wrap(result)
49    }
50
51    fn rms_norm_residual(
52        &self,
53        input: &TensorRef,
54        residual: &TensorRef,
55        weight: &TensorRef,
56        eps: f32,
57    ) -> Result<(TensorRef, TensorRef)> {
58        let input = ct(input)?;
59        let residual = ct(residual)?;
60        let weight = ct(weight)?;
61
62        // updated_residual = input + residual
63        let updated = (input + residual).map_err(err)?;
64        // normed = rms_norm(updated_residual)
65        let normed = candle_nn::ops::rms_norm(&updated, weight, eps).map_err(err)?;
66
67        Ok((wrap(normed)?, wrap(updated)?))
68    }
69}
70
71// ---------------------------------------------------------------------------
72// PositionOps
73// ---------------------------------------------------------------------------
74
75pub struct CandlePositionOps;
76
77impl PositionOps for CandlePositionOps {
78    fn rotary_embedding(
79        &self,
80        x: &TensorRef,
81        cos_cache: &TensorRef,
82        sin_cache: &TensorRef,
83        position_ids: &[usize],
84    ) -> Result<TensorRef> {
85        use candle_core::D;
86
87        let x = ct(x)?;
88        let cos_cache = ct(cos_cache)?;
89        let sin_cache = ct(sin_cache)?;
90
91        let head_dim = *x.dims().last().ok_or_else(|| err("empty tensor"))?;
92        let half_dim = head_dim / 2;
93        let target_dtype = x.dtype();
94
95        // Index into cos/sin caches for the requested position.
96        let pos = position_ids
97            .first()
98            .copied()
99            .ok_or_else(|| err("empty position_ids"))?;
100        let cos = cos_cache.i(pos).map_err(err)?;
101        let sin = sin_cache.i(pos).map_err(err)?;
102
103        let cos = if cos.dtype() != target_dtype {
104            cos.to_dtype(target_dtype).map_err(err)?
105        } else {
106            cos
107        };
108        let sin = if sin.dtype() != target_dtype {
109            sin.to_dtype(target_dtype).map_err(err)?
110        } else {
111            sin
112        };
113
114        // Split into two halves along last dim.
115        let x1 = x.narrow(D::Minus1, 0, half_dim).map_err(err)?;
116        let x2 = x.narrow(D::Minus1, half_dim, half_dim).map_err(err)?;
117
118        // Apply rotation: [x1*cos - x2*sin, x1*sin + x2*cos]
119        let r1 = x1
120            .broadcast_mul(&cos)
121            .map_err(err)?
122            .broadcast_sub(&x2.broadcast_mul(&sin).map_err(err)?)
123            .map_err(err)?;
124        let r2 = x1
125            .broadcast_mul(&sin)
126            .map_err(err)?
127            .broadcast_add(&x2.broadcast_mul(&cos).map_err(err)?)
128            .map_err(err)?;
129
130        let result = candle_core::Tensor::cat(&[r1, r2], D::Minus1).map_err(err)?;
131        wrap(result)
132    }
133}
134
135// ---------------------------------------------------------------------------
136// AttentionOps
137// ---------------------------------------------------------------------------
138
139pub struct CandleAttentionOps;
140
141impl AttentionOps for CandleAttentionOps {
142    fn attention(
143        &self,
144        q: &TensorRef,
145        k: &TensorRef,
146        v: &TensorRef,
147        params: &AttentionParams,
148    ) -> Result<TensorRef> {
149        use candle_core::D;
150
151        let q = ct(q)?;
152        let k = ct(k)?;
153        let v = ct(v)?;
154
155        // Input layout: [batch, seq, heads, head_dim]
156        // Transpose to [batch, heads, seq, head_dim] for batched matmul.
157        let q = q.transpose(1, 2).map_err(err)?;
158        let k = k.transpose(1, 2).map_err(err)?;
159        let v = v.transpose(1, 2).map_err(err)?;
160
161        // Handle GQA: repeat KV heads to match Q heads.
162        let n_rep = params.num_heads / params.num_kv_heads;
163        let (k, v) = if n_rep > 1 {
164            (repeat_kv(&k, n_rep)?, repeat_kv(&v, n_rep)?)
165        } else {
166            (k, v)
167        };
168
169        // Ensure contiguous for Metal/CUDA matmul.
170        let q = q.contiguous().map_err(err)?;
171        let k = k.contiguous().map_err(err)?;
172
173        // scores = Q @ K^T / scale
174        let k_t = k.transpose(D::Minus2, D::Minus1).map_err(err)?;
175        let k_t = k_t.contiguous().map_err(err)?;
176        let scores = q.matmul(&k_t).map_err(err)?;
177        let scores = scores
178            .affine(params.softmax_scale as f64, 0.0)
179            .map_err(err)?;
180
181        // Causal mask.
182        let scores = if params.causal {
183            let (_, _, q_len, kv_len) = scores.dims4().map_err(err)?;
184            let past_len = kv_len.saturating_sub(q_len);
185            let mask_data: Vec<f32> = (0..q_len)
186                .flat_map(|i| {
187                    let max_k = past_len + i;
188                    (0..kv_len).map(move |j| if j <= max_k { 0.0 } else { f32::NEG_INFINITY })
189                })
190                .collect();
191            let mask =
192                candle_core::Tensor::from_vec(mask_data, (1, 1, q_len, kv_len), scores.device())
193                    .map_err(err)?;
194            let mask = if mask.dtype() != scores.dtype() {
195                mask.to_dtype(scores.dtype()).map_err(err)?
196            } else {
197                mask
198            };
199            scores.broadcast_add(&mask).map_err(err)?
200        } else {
201            scores
202        };
203
204        // softmax
205        let attn_weights = candle_nn::ops::softmax(&scores, D::Minus1).map_err(err)?;
206
207        // output = weights @ V
208        let output = attn_weights.matmul(&v).map_err(err)?;
209
210        // Transpose back: [batch, heads, seq, dim] -> [batch, seq, heads, dim]
211        let output = output.transpose(1, 2).map_err(err)?;
212        wrap(output)
213    }
214}
215
216fn repeat_kv(x: &candle_core::Tensor, n_rep: usize) -> Result<candle_core::Tensor> {
217    let (batch, num_kv_heads, seq_len, head_dim) = x.dims4().map_err(err)?;
218    let unsqueezed = x.unsqueeze(2).map_err(err)?;
219    let repeated: Vec<candle_core::Tensor> = (0..n_rep).map(|_| unsqueezed.clone()).collect();
220    let cat = candle_core::Tensor::cat(&repeated, 2).map_err(err)?;
221    cat.reshape((batch, num_kv_heads * n_rep, seq_len, head_dim))
222        .map_err(err)
223}
224
225// ---------------------------------------------------------------------------
226// ActivationOps
227// ---------------------------------------------------------------------------
228
229pub struct CandleActivationOps;
230
231impl ActivationOps for CandleActivationOps {
232    fn silu_mul(&self, gate: &TensorRef, up: &TensorRef) -> Result<TensorRef> {
233        let gate = ct(gate)?;
234        let up = ct(up)?;
235        let activated = candle_nn::ops::silu(gate).map_err(err)?;
236        let result = activated.mul(up).map_err(err)?;
237        wrap(result)
238    }
239
240    fn gelu(&self, input: &TensorRef) -> Result<TensorRef> {
241        let input = ct(input)?;
242        let result = input.gelu().map_err(err)?;
243        wrap(result)
244    }
245}
246
247// ---------------------------------------------------------------------------
248// LinearOps
249// ---------------------------------------------------------------------------
250
251pub struct CandleLinearOps;
252
253impl LinearOps for CandleLinearOps {
254    fn linear(&self, input: &TensorRef, weight: &TensorRef) -> Result<TensorRef> {
255        let input = ct(input)?;
256        let weight = ct(weight)?;
257        // weight is [out, in], so output = input @ weight^T
258        let w_t = weight.transpose(0, 1).map_err(err)?;
259        let result = input.matmul(&w_t).map_err(err)?;
260        wrap(result)
261    }
262}
263
264// ---------------------------------------------------------------------------
265// SamplingOps
266// ---------------------------------------------------------------------------
267
268pub struct CandleSamplingOps;
269
270impl SamplingOps for CandleSamplingOps {
271    fn sample_token(&self, logits: &TensorRef, _params: &SamplingParams) -> Result<u32> {
272        // Reference impl: greedy (full sampling pipeline is in ferrum-sampler).
273        self.argmax(logits)
274    }
275
276    fn argmax(&self, logits: &TensorRef) -> Result<u32> {
277        logits.argmax_last_dim_u32()
278    }
279}
280
281// ---------------------------------------------------------------------------
282// Umbrella: CandleKernelOps
283// ---------------------------------------------------------------------------
284
285/// Reference `KernelOps` implementation backed by Candle tensor ops.
286pub struct CandleKernelOps {
287    norm: CandleNormOps,
288    position: CandlePositionOps,
289    attention: CandleAttentionOps,
290    activation: CandleActivationOps,
291    linear: CandleLinearOps,
292    sampling: CandleSamplingOps,
293}
294
295impl CandleKernelOps {
296    pub fn new() -> Self {
297        Self {
298            norm: CandleNormOps,
299            position: CandlePositionOps,
300            attention: CandleAttentionOps,
301            activation: CandleActivationOps,
302            linear: CandleLinearOps,
303            sampling: CandleSamplingOps,
304        }
305    }
306}
307
308impl Default for CandleKernelOps {
309    fn default() -> Self {
310        Self::new()
311    }
312}
313
314impl KernelOps for CandleKernelOps {
315    fn norm_ops(&self) -> Option<&dyn NormOps> {
316        Some(&self.norm)
317    }
318    fn position_ops(&self) -> Option<&dyn PositionOps> {
319        Some(&self.position)
320    }
321    fn attention_ops(&self) -> Option<&dyn AttentionOps> {
322        Some(&self.attention)
323    }
324    fn activation_ops(&self) -> Option<&dyn ActivationOps> {
325        Some(&self.activation)
326    }
327    fn linear_ops(&self) -> Option<&dyn LinearOps> {
328        Some(&self.linear)
329    }
330    fn sampling_ops(&self) -> Option<&dyn SamplingOps> {
331        Some(&self.sampling)
332    }
333    fn backend_name(&self) -> &str {
334        "candle"
335    }
336}
337
338// ---------------------------------------------------------------------------
339// Tests
340// ---------------------------------------------------------------------------
341
#[cfg(test)]
mod tests {
    use super::*;
    use crate::backends::candle::CandleTensorFactory;
    use ferrum_interfaces::kernel_ops::KernelOpsDispatch;
    use ferrum_interfaces::{TensorFactory, TensorOps};
    use ferrum_types::{DataType, Device};

    /// CPU tensor factory shared by all tests.
    fn factory() -> CandleTensorFactory {
        CandleTensorFactory::new(Device::CPU)
    }

    /// Builds an FP32 CPU tensor from `data` with the given `shape`.
    fn tensor(f: &CandleTensorFactory, data: &[f32], shape: &[usize]) -> TensorRef {
        f.from_slice(data, shape, DataType::FP32, Device::CPU)
            .unwrap()
    }

    /// Asserts that two scalars differ by less than `1e-5`.
    fn assert_close(got: f32, want: f32) {
        assert!((got - want).abs() < 1e-5);
    }

    // -- NormOps --

    /// The kernel path and the generic `TensorOps` path must agree.
    #[test]
    fn test_rms_norm_matches_tensor_ops() {
        let f = factory();
        let input = tensor(&f, &[1.0, 2.0, 3.0, 4.0], &[1, 4]);
        let weight = tensor(&f, &[1.0, 1.0, 1.0, 1.0], &[4]);

        let from_kernel = CandleNormOps
            .rms_norm(&input, &weight, 1e-5)
            .unwrap()
            .to_vec_f32()
            .unwrap();
        let from_tensor_ops = CandleTensorOps
            .rms_norm(&input, &weight, 1e-5)
            .unwrap()
            .to_vec_f32()
            .unwrap();

        assert_eq!(from_kernel.len(), from_tensor_ops.len());
        for (a, b) in from_kernel.iter().zip(from_tensor_ops.iter()) {
            assert!((a - b).abs() < 1e-5, "mismatch: {} vs {}", a, b);
        }
    }

    /// The fused op returns the summed residual and its normalized form.
    #[test]
    fn test_rms_norm_residual() {
        let f = factory();
        let input = tensor(&f, &[1.0, 2.0], &[1, 2]);
        let residual = tensor(&f, &[0.5, 0.5], &[1, 2]);
        let weight = tensor(&f, &[1.0, 1.0], &[2]);

        let (normed, updated) = CandleNormOps
            .rms_norm_residual(&input, &residual, &weight, 1e-5)
            .unwrap();

        // updated should be input + residual = [1.5, 2.5]
        let u = updated.to_vec_f32().unwrap();
        assert_close(u[0], 1.5);
        assert_close(u[1], 2.5);

        // normed should be rms_norm(updated)
        let expected = CandleNormOps
            .rms_norm(&updated, &weight, 1e-5)
            .unwrap()
            .to_vec_f32()
            .unwrap();
        let got = normed.to_vec_f32().unwrap();
        for (a, b) in got.iter().zip(expected.iter()) {
            assert_close(*a, *b);
        }
    }

    // -- ActivationOps --

    #[test]
    fn test_silu_mul() {
        let f = factory();
        let gate = tensor(&f, &[1.0, -1.0, 2.0, 0.0], &[4]);
        let up = tensor(&f, &[2.0, 2.0, 2.0, 2.0], &[4]);

        let vals = CandleActivationOps
            .silu_mul(&gate, &up)
            .unwrap()
            .to_vec_f32()
            .unwrap();

        // silu(x) = x * sigmoid(x)
        // silu(1.0) ≈ 0.7311, * 2 ≈ 1.4621
        assert!(vals[0] > 1.0 && vals[0] < 2.0);
        // silu(0.0) = 0
        assert!(vals[3].abs() < 1e-5);
    }

    #[test]
    fn test_gelu() {
        let f = factory();
        let input = tensor(&f, &[0.0, 1.0, -1.0], &[3]);

        let vals = CandleActivationOps
            .gelu(&input)
            .unwrap()
            .to_vec_f32()
            .unwrap();
        // gelu(0) = 0
        assert!(vals[0].abs() < 1e-5);
        // gelu(1) ≈ 0.8412
        assert!(vals[1] > 0.8 && vals[1] < 0.9);
    }

    // -- LinearOps --

    /// Multiplying by an identity weight must return the input unchanged.
    #[test]
    fn test_linear_identity() {
        let f = factory();
        let input = tensor(&f, &[1.0, 2.0, 3.0], &[1, 3]);
        // Identity weight [3, 3]
        let weight = tensor(
            &f,
            &[1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0],
            &[3, 3],
        );

        let vals = CandleLinearOps
            .linear(&input, &weight)
            .unwrap()
            .to_vec_f32()
            .unwrap();
        assert_close(vals[0], 1.0);
        assert_close(vals[1], 2.0);
        assert_close(vals[2], 3.0);
    }

    // -- SamplingOps --

    #[test]
    fn test_argmax() {
        let f = factory();
        let logits = tensor(&f, &[0.1, 0.5, 0.3, 0.9, 0.2], &[5]);

        let token = CandleSamplingOps.argmax(&logits).unwrap();
        assert_eq!(token, 3); // 0.9 is at index 3
    }

    // -- KernelOps umbrella --

    #[test]
    fn test_candle_kernel_ops_all_present() {
        let ops = CandleKernelOps::new();
        assert!(ops.norm_ops().is_some());
        assert!(ops.position_ops().is_some());
        assert!(ops.attention_ops().is_some());
        assert!(ops.activation_ops().is_some());
        assert!(ops.linear_ops().is_some());
        assert!(ops.sampling_ops().is_some());
        assert_eq!(ops.backend_name(), "candle");
    }

    // -- KernelOpsDispatch fallback --

    #[test]
    fn test_dispatch_fallback_rms_norm() {
        let f = factory();
        let input = tensor(&f, &[1.0, 2.0, 3.0, 4.0], &[1, 4]);
        let weight = tensor(&f, &[1.0, 1.0, 1.0, 1.0], &[4]);

        let tensor_ops = CandleTensorOps;

        // With kernel_ops = None, should fall back to TensorOps
        let dispatch = KernelOpsDispatch::new(None, &tensor_ops);
        let vals = dispatch
            .rms_norm(&input, &weight, 1e-5)
            .unwrap()
            .to_vec_f32()
            .unwrap();
        assert_eq!(vals.len(), 4);
    }

    #[test]
    fn test_dispatch_with_kernel_ops_rms_norm() {
        let f = factory();
        let input = tensor(&f, &[1.0, 2.0, 3.0, 4.0], &[1, 4]);
        let weight = tensor(&f, &[1.0, 1.0, 1.0, 1.0], &[4]);

        let kernel_ops = CandleKernelOps::new();
        let tensor_ops = CandleTensorOps;

        // With kernel_ops present, should use KernelOps path
        let dispatch = KernelOpsDispatch::new(Some(&kernel_ops), &tensor_ops);
        let vals = dispatch
            .rms_norm(&input, &weight, 1e-5)
            .unwrap()
            .to_vec_f32()
            .unwrap();
        assert_eq!(vals.len(), 4);
    }

    #[test]
    fn test_dispatch_silu_mul_fallback() {
        let f = factory();
        let gate = tensor(&f, &[1.0, 2.0], &[2]);
        let up = tensor(&f, &[3.0, 4.0], &[2]);

        let tensor_ops = CandleTensorOps;

        // No kernel ops → falls back to silu(gate) * up via TensorOps
        let dispatch = KernelOpsDispatch::new(None, &tensor_ops);
        let vals = dispatch
            .silu_mul(&gate, &up)
            .unwrap()
            .to_vec_f32()
            .unwrap();
        assert_eq!(vals.len(), 2);
        // silu(1.0)*3 ≈ 2.19
        assert!(vals[0] > 2.0 && vals[0] < 2.5);
    }
}