1use candle_core::IndexOp;
8use ferrum_interfaces::kernel_ops::{
9 ActivationOps, AttentionOps, AttentionParams, KernelOps, LinearOps, NormOps, PositionOps,
10 SamplingOps, SamplingParams,
11};
12use ferrum_interfaces::TensorRef;
13use ferrum_types::{FerrumError, Result};
14use std::sync::Arc;
15
16use super::candle::CandleTensor;
17#[cfg(test)]
18use super::candle::CandleTensorOps;
19
20fn ct(tensor: &TensorRef) -> Result<&candle_core::Tensor> {
25 let concrete: &CandleTensor = unsafe { &*(Arc::as_ptr(tensor) as *const CandleTensor) };
26 Ok(concrete.inner())
27}
28
29fn wrap(tensor: candle_core::Tensor) -> Result<TensorRef> {
30 Ok(Arc::new(CandleTensor::new(tensor)?) as TensorRef)
31}
32
33fn err(msg: impl std::fmt::Display) -> FerrumError {
34 FerrumError::backend(msg.to_string())
35}
36
37pub struct CandleNormOps;
42
43impl NormOps for CandleNormOps {
44 fn rms_norm(&self, input: &TensorRef, weight: &TensorRef, eps: f32) -> Result<TensorRef> {
45 let input = ct(input)?;
46 let weight = ct(weight)?;
47 let result = candle_nn::ops::rms_norm(input, weight, eps).map_err(err)?;
48 wrap(result)
49 }
50
51 fn rms_norm_residual(
52 &self,
53 input: &TensorRef,
54 residual: &TensorRef,
55 weight: &TensorRef,
56 eps: f32,
57 ) -> Result<(TensorRef, TensorRef)> {
58 let input = ct(input)?;
59 let residual = ct(residual)?;
60 let weight = ct(weight)?;
61
62 let updated = (input + residual).map_err(err)?;
64 let normed = candle_nn::ops::rms_norm(&updated, weight, eps).map_err(err)?;
66
67 Ok((wrap(normed)?, wrap(updated)?))
68 }
69}
70
71pub struct CandlePositionOps;
76
77impl PositionOps for CandlePositionOps {
78 fn rotary_embedding(
79 &self,
80 x: &TensorRef,
81 cos_cache: &TensorRef,
82 sin_cache: &TensorRef,
83 position_ids: &[usize],
84 ) -> Result<TensorRef> {
85 use candle_core::D;
86
87 let x = ct(x)?;
88 let cos_cache = ct(cos_cache)?;
89 let sin_cache = ct(sin_cache)?;
90
91 let head_dim = *x.dims().last().ok_or_else(|| err("empty tensor"))?;
92 let half_dim = head_dim / 2;
93 let target_dtype = x.dtype();
94
95 let pos = position_ids
97 .first()
98 .copied()
99 .ok_or_else(|| err("empty position_ids"))?;
100 let cos = cos_cache.i(pos).map_err(err)?;
101 let sin = sin_cache.i(pos).map_err(err)?;
102
103 let cos = if cos.dtype() != target_dtype {
104 cos.to_dtype(target_dtype).map_err(err)?
105 } else {
106 cos
107 };
108 let sin = if sin.dtype() != target_dtype {
109 sin.to_dtype(target_dtype).map_err(err)?
110 } else {
111 sin
112 };
113
114 let x1 = x.narrow(D::Minus1, 0, half_dim).map_err(err)?;
116 let x2 = x.narrow(D::Minus1, half_dim, half_dim).map_err(err)?;
117
118 let r1 = x1
120 .broadcast_mul(&cos)
121 .map_err(err)?
122 .broadcast_sub(&x2.broadcast_mul(&sin).map_err(err)?)
123 .map_err(err)?;
124 let r2 = x1
125 .broadcast_mul(&sin)
126 .map_err(err)?
127 .broadcast_add(&x2.broadcast_mul(&cos).map_err(err)?)
128 .map_err(err)?;
129
130 let result = candle_core::Tensor::cat(&[r1, r2], D::Minus1).map_err(err)?;
131 wrap(result)
132 }
133}
134
135pub struct CandleAttentionOps;
140
141impl AttentionOps for CandleAttentionOps {
142 fn attention(
143 &self,
144 q: &TensorRef,
145 k: &TensorRef,
146 v: &TensorRef,
147 params: &AttentionParams,
148 ) -> Result<TensorRef> {
149 use candle_core::D;
150
151 let q = ct(q)?;
152 let k = ct(k)?;
153 let v = ct(v)?;
154
155 let q = q.transpose(1, 2).map_err(err)?;
158 let k = k.transpose(1, 2).map_err(err)?;
159 let v = v.transpose(1, 2).map_err(err)?;
160
161 let n_rep = params.num_heads / params.num_kv_heads;
163 let (k, v) = if n_rep > 1 {
164 (repeat_kv(&k, n_rep)?, repeat_kv(&v, n_rep)?)
165 } else {
166 (k, v)
167 };
168
169 let q = q.contiguous().map_err(err)?;
171 let k = k.contiguous().map_err(err)?;
172
173 let k_t = k.transpose(D::Minus2, D::Minus1).map_err(err)?;
175 let k_t = k_t.contiguous().map_err(err)?;
176 let scores = q.matmul(&k_t).map_err(err)?;
177 let scores = scores
178 .affine(params.softmax_scale as f64, 0.0)
179 .map_err(err)?;
180
181 let scores = if params.causal {
183 let (_, _, q_len, kv_len) = scores.dims4().map_err(err)?;
184 let past_len = kv_len.saturating_sub(q_len);
185 let mask_data: Vec<f32> = (0..q_len)
186 .flat_map(|i| {
187 let max_k = past_len + i;
188 (0..kv_len).map(move |j| if j <= max_k { 0.0 } else { f32::NEG_INFINITY })
189 })
190 .collect();
191 let mask =
192 candle_core::Tensor::from_vec(mask_data, (1, 1, q_len, kv_len), scores.device())
193 .map_err(err)?;
194 let mask = if mask.dtype() != scores.dtype() {
195 mask.to_dtype(scores.dtype()).map_err(err)?
196 } else {
197 mask
198 };
199 scores.broadcast_add(&mask).map_err(err)?
200 } else {
201 scores
202 };
203
204 let attn_weights = candle_nn::ops::softmax(&scores, D::Minus1).map_err(err)?;
206
207 let output = attn_weights.matmul(&v).map_err(err)?;
209
210 let output = output.transpose(1, 2).map_err(err)?;
212 wrap(output)
213 }
214}
215
216fn repeat_kv(x: &candle_core::Tensor, n_rep: usize) -> Result<candle_core::Tensor> {
217 let (batch, num_kv_heads, seq_len, head_dim) = x.dims4().map_err(err)?;
218 let unsqueezed = x.unsqueeze(2).map_err(err)?;
219 let repeated: Vec<candle_core::Tensor> = (0..n_rep).map(|_| unsqueezed.clone()).collect();
220 let cat = candle_core::Tensor::cat(&repeated, 2).map_err(err)?;
221 cat.reshape((batch, num_kv_heads * n_rep, seq_len, head_dim))
222 .map_err(err)
223}
224
225pub struct CandleActivationOps;
230
231impl ActivationOps for CandleActivationOps {
232 fn silu_mul(&self, gate: &TensorRef, up: &TensorRef) -> Result<TensorRef> {
233 let gate = ct(gate)?;
234 let up = ct(up)?;
235 let activated = candle_nn::ops::silu(gate).map_err(err)?;
236 let result = activated.mul(up).map_err(err)?;
237 wrap(result)
238 }
239
240 fn gelu(&self, input: &TensorRef) -> Result<TensorRef> {
241 let input = ct(input)?;
242 let result = input.gelu().map_err(err)?;
243 wrap(result)
244 }
245}
246
247pub struct CandleLinearOps;
252
253impl LinearOps for CandleLinearOps {
254 fn linear(&self, input: &TensorRef, weight: &TensorRef) -> Result<TensorRef> {
255 let input = ct(input)?;
256 let weight = ct(weight)?;
257 let w_t = weight.transpose(0, 1).map_err(err)?;
259 let result = input.matmul(&w_t).map_err(err)?;
260 wrap(result)
261 }
262}
263
264pub struct CandleSamplingOps;
269
270impl SamplingOps for CandleSamplingOps {
271 fn sample_token(&self, logits: &TensorRef, _params: &SamplingParams) -> Result<u32> {
272 self.argmax(logits)
274 }
275
276 fn argmax(&self, logits: &TensorRef) -> Result<u32> {
277 logits.argmax_last_dim_u32()
278 }
279}
280
281pub struct CandleKernelOps {
287 norm: CandleNormOps,
288 position: CandlePositionOps,
289 attention: CandleAttentionOps,
290 activation: CandleActivationOps,
291 linear: CandleLinearOps,
292 sampling: CandleSamplingOps,
293}
294
295impl CandleKernelOps {
296 pub fn new() -> Self {
297 Self {
298 norm: CandleNormOps,
299 position: CandlePositionOps,
300 attention: CandleAttentionOps,
301 activation: CandleActivationOps,
302 linear: CandleLinearOps,
303 sampling: CandleSamplingOps,
304 }
305 }
306}
307
308impl Default for CandleKernelOps {
309 fn default() -> Self {
310 Self::new()
311 }
312}
313
314impl KernelOps for CandleKernelOps {
315 fn norm_ops(&self) -> Option<&dyn NormOps> {
316 Some(&self.norm)
317 }
318 fn position_ops(&self) -> Option<&dyn PositionOps> {
319 Some(&self.position)
320 }
321 fn attention_ops(&self) -> Option<&dyn AttentionOps> {
322 Some(&self.attention)
323 }
324 fn activation_ops(&self) -> Option<&dyn ActivationOps> {
325 Some(&self.activation)
326 }
327 fn linear_ops(&self) -> Option<&dyn LinearOps> {
328 Some(&self.linear)
329 }
330 fn sampling_ops(&self) -> Option<&dyn SamplingOps> {
331 Some(&self.sampling)
332 }
333 fn backend_name(&self) -> &str {
334 "candle"
335 }
336}
337
#[cfg(test)]
mod tests {
    use super::*;
    use crate::backends::candle::CandleTensorFactory;
    use ferrum_interfaces::kernel_ops::KernelOpsDispatch;
    use ferrum_interfaces::{TensorFactory, TensorOps};
    use ferrum_types::{DataType, Device};

    fn factory() -> CandleTensorFactory {
        CandleTensorFactory::new(Device::CPU)
    }

    // Builds an FP32 CPU tensor from literal data and a literal shape using the
    // given factory, panicking on failure.
    macro_rules! cpu_tensor {
        ($factory:expr, $data:expr, $shape:expr) => {
            $factory
                .from_slice(&$data, &$shape, DataType::FP32, Device::CPU)
                .unwrap()
        };
    }

    // The kernel-ops RMS norm must agree with the generic tensor-ops version.
    #[test]
    fn test_rms_norm_matches_tensor_ops() {
        let f = factory();
        let x = cpu_tensor!(f, [1.0, 2.0, 3.0, 4.0], [1, 4]);
        let gamma = cpu_tensor!(f, [1.0, 1.0, 1.0, 1.0], [4]);

        let from_kernel = CandleNormOps.rms_norm(&x, &gamma, 1e-5).unwrap();
        let from_tensor_ops = CandleTensorOps.rms_norm(&x, &gamma, 1e-5).unwrap();

        let k = from_kernel.to_vec_f32().unwrap();
        let t = from_tensor_ops.to_vec_f32().unwrap();
        assert_eq!(k.len(), t.len());
        for (a, b) in k.iter().zip(t.iter()) {
            assert!((a - b).abs() < 1e-5, "mismatch: {} vs {}", a, b);
        }
    }

    // The fused variant returns (norm(input + residual), input + residual).
    #[test]
    fn test_rms_norm_residual() {
        let f = factory();
        let x = cpu_tensor!(f, [1.0, 2.0], [1, 2]);
        let skip = cpu_tensor!(f, [0.5, 0.5], [1, 2]);
        let gamma = cpu_tensor!(f, [1.0, 1.0], [2]);

        let (normed, updated) = CandleNormOps
            .rms_norm_residual(&x, &skip, &gamma, 1e-5)
            .unwrap();

        // Second output: the plain sum.
        let sum = updated.to_vec_f32().unwrap();
        assert!((sum[0] - 1.5).abs() < 1e-5);
        assert!((sum[1] - 2.5).abs() < 1e-5);

        // First output: the same as normalizing that sum directly.
        let reference = CandleNormOps.rms_norm(&updated, &gamma, 1e-5).unwrap();
        let expected = reference.to_vec_f32().unwrap();
        let got = normed.to_vec_f32().unwrap();
        for (a, b) in got.iter().zip(expected.iter()) {
            assert!((a - b).abs() < 1e-5);
        }
    }

    #[test]
    fn test_silu_mul() {
        let f = factory();
        let gate = cpu_tensor!(f, [1.0, -1.0, 2.0, 0.0], [4]);
        let up = cpu_tensor!(f, [2.0, 2.0, 2.0, 2.0], [4]);

        let vals = CandleActivationOps
            .silu_mul(&gate, &up)
            .unwrap()
            .to_vec_f32()
            .unwrap();

        // silu(1) * 2 lies strictly between 1 and 2.
        assert!(vals[0] > 1.0 && vals[0] < 2.0);
        // silu(0) * 2 == 0.
        assert!(vals[3].abs() < 1e-5);
    }

    #[test]
    fn test_gelu() {
        let f = factory();
        let x = cpu_tensor!(f, [0.0, 1.0, -1.0], [3]);

        let vals = CandleActivationOps.gelu(&x).unwrap().to_vec_f32().unwrap();
        // gelu(0) == 0.
        assert!(vals[0].abs() < 1e-5);
        // gelu(1) is between 0.8 and 0.9.
        assert!(vals[1] > 0.8 && vals[1] < 0.9);
    }

    // Projecting through an identity weight must return the input unchanged.
    #[test]
    fn test_linear_identity() {
        let f = factory();
        let x = cpu_tensor!(f, [1.0, 2.0, 3.0], [1, 3]);
        let identity = cpu_tensor!(f, [1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0], [3, 3]);

        let vals = CandleLinearOps
            .linear(&x, &identity)
            .unwrap()
            .to_vec_f32()
            .unwrap();
        assert!((vals[0] - 1.0).abs() < 1e-5);
        assert!((vals[1] - 2.0).abs() < 1e-5);
        assert!((vals[2] - 3.0).abs() < 1e-5);
    }

    #[test]
    fn test_argmax() {
        let f = factory();
        let logits = cpu_tensor!(f, [0.1, 0.5, 0.3, 0.9, 0.2], [5]);

        // The largest logit (0.9) sits at index 3.
        let token = CandleSamplingOps.argmax(&logits).unwrap();
        assert_eq!(token, 3);
    }

    // Every op family is provided by the candle backend.
    #[test]
    fn test_candle_kernel_ops_all_present() {
        let ops = CandleKernelOps::new();
        assert!(ops.norm_ops().is_some());
        assert!(ops.position_ops().is_some());
        assert!(ops.attention_ops().is_some());
        assert!(ops.activation_ops().is_some());
        assert!(ops.linear_ops().is_some());
        assert!(ops.sampling_ops().is_some());
        assert_eq!(ops.backend_name(), "candle");
    }

    // Dispatch constructed without kernel ops (`None`).
    #[test]
    fn test_dispatch_fallback_rms_norm() {
        let f = factory();
        let x = cpu_tensor!(f, [1.0, 2.0, 3.0, 4.0], [1, 4]);
        let gamma = cpu_tensor!(f, [1.0, 1.0, 1.0, 1.0], [4]);

        let tensor_ops = CandleTensorOps;
        let dispatch = KernelOpsDispatch::new(None, &tensor_ops);

        let vals = dispatch
            .rms_norm(&x, &gamma, 1e-5)
            .unwrap()
            .to_vec_f32()
            .unwrap();
        assert_eq!(vals.len(), 4);
    }

    // Dispatch constructed with the candle kernel ops.
    #[test]
    fn test_dispatch_with_kernel_ops_rms_norm() {
        let f = factory();
        let x = cpu_tensor!(f, [1.0, 2.0, 3.0, 4.0], [1, 4]);
        let gamma = cpu_tensor!(f, [1.0, 1.0, 1.0, 1.0], [4]);

        let kernel_ops = CandleKernelOps::new();
        let tensor_ops = CandleTensorOps;
        let dispatch = KernelOpsDispatch::new(Some(&kernel_ops), &tensor_ops);

        let vals = dispatch
            .rms_norm(&x, &gamma, 1e-5)
            .unwrap()
            .to_vec_f32()
            .unwrap();
        assert_eq!(vals.len(), 4);
    }

    // Dispatch constructed without kernel ops (`None`), exercising `silu_mul`.
    #[test]
    fn test_dispatch_silu_mul_fallback() {
        let f = factory();
        let gate = cpu_tensor!(f, [1.0, 2.0], [2]);
        let up = cpu_tensor!(f, [3.0, 4.0], [2]);

        let tensor_ops = CandleTensorOps;
        let dispatch = KernelOpsDispatch::new(None, &tensor_ops);

        let vals = dispatch.silu_mul(&gate, &up).unwrap().to_vec_f32().unwrap();
        assert_eq!(vals.len(), 2);
        // silu(1) * 3 lies between 2.0 and 2.5.
        assert!(vals[0] > 2.0 && vals[0] < 2.5);
    }
}