mlx_core/cpu_kernels.rs

//! Built-in CPU reference backend — correctness oracle.
//!
//! This is an intentionally simple, safe Rust implementation of every op.
//! It prioritizes correctness and readability over performance.

use crate::backend::{Backend, NodeInput};
use crate::graph::{OpKind, TensorMeta};
use crate::{MlxError, Result};

/// Reference CPU backend.
pub struct CpuRefBackend;

impl Backend for CpuRefBackend {
    fn eval_node(
        &self,
        op: &OpKind,
        inputs: &[NodeInput<'_>],
        output_meta: &TensorMeta,
    ) -> Result<Vec<f32>> {
        match op {
            OpKind::Constant | OpKind::Parameter => Err(MlxError::InvalidArgument(
                "Constant/Parameter nodes should be pre-materialized".into(),
            )),
            OpKind::Add => binary_elementwise(inputs, |a, b| a + b),
            OpKind::Mul => binary_elementwise(inputs, |a, b| a * b),
            OpKind::Sub => binary_elementwise(inputs, |a, b| a - b),
            OpKind::Div => binary_elementwise(inputs, |a, b| a / b),
            OpKind::Neg => {
                let a = require_input(inputs, 0)?;
                Ok(a.data.iter().map(|x| -x).collect())
            }
            OpKind::Exp => {
                let a = require_input(inputs, 0)?;
                Ok(a.data.iter().map(|x| x.exp()).collect())
            }
            OpKind::Log => {
                let a = require_input(inputs, 0)?;
                Ok(a.data.iter().map(|x| x.ln()).collect())
            }
            OpKind::Sum { axis } => reduce_sum(inputs, *axis),
            OpKind::Mean { axis } => reduce_mean(inputs, *axis),
            OpKind::Max { axis } => reduce_max(inputs, *axis),
            OpKind::MatMul => matmul(inputs),
            OpKind::Reshape { .. } => {
                let a = require_input(inputs, 0)?;
                Ok(a.data.to_vec())
            }
            OpKind::Transpose { axes } => transpose(inputs, axes.as_deref()),
            OpKind::Softmax { axis } => softmax(inputs, *axis),
            OpKind::Silu => {
                let a = require_input(inputs, 0)?;
                Ok(a.data.iter().map(|&x| x * sigmoid(x)).collect())
            }
            OpKind::Gelu => {
                let a = require_input(inputs, 0)?;
                Ok(a.data
                    .iter()
                    .map(|&x| {
                        0.5 * x
                            * (1.0
                                + ((2.0 / std::f32::consts::PI).sqrt()
                                    * (x + 0.044715 * x * x * x))
                                    .tanh())
                    })
                    .collect())
            }
            OpKind::LayerNorm { eps } => layer_norm(inputs, *eps, output_meta),
            OpKind::RmsNorm { eps } => rms_norm(inputs, *eps, output_meta),
            OpKind::Broadcast { target_shape } => broadcast(inputs, target_shape),
            OpKind::ScaledMaskedSoftmax { scale, causal } => {
                scaled_masked_softmax(inputs, *scale, *causal)
            }
            OpKind::Attention { scale, causal } => cpu_attention(inputs, *scale, *causal),
            OpKind::Rope {
                rotary_dim,
                pos_offset,
                theta,
            } => cpu_rope(inputs, output_meta, *rotary_dim, *pos_offset, *theta),
            OpKind::LayerNormVjp { eps } => layer_norm_vjp(inputs, *eps),
            OpKind::RmsNormVjp { eps } => rms_norm_vjp(inputs, *eps),
            OpKind::SoftmaxVjp { axis } => softmax_vjp(inputs, *axis),
            OpKind::SiluVjp => silu_vjp(inputs),
            OpKind::GeluVjp => gelu_vjp(inputs),
            OpKind::Sqrt => {
                let a = require_input(inputs, 0)?;
                Ok(a.data.iter().map(|&x| x.sqrt()).collect())
            }
            OpKind::RoPE {
                base,
                offset,
                traditional,
            } => rope(inputs, *base, *offset, *traditional),
            OpKind::Embedding => embedding(inputs),
            OpKind::Narrow {
                axis,
                start,
                length,
            } => narrow(inputs, *axis, *start, *length),
            OpKind::Concatenate { axis } => concatenate(inputs, *axis),
        }
    }
}

fn sigmoid(x: f32) -> f32 {
    1.0 / (1.0 + (-x).exp())
}

fn require_input<'a>(inputs: &'a [NodeInput<'_>], idx: usize) -> Result<&'a NodeInput<'a>> {
    inputs
        .get(idx)
        .ok_or_else(|| MlxError::InvalidArgument(format!("expected input at index {idx}")))
}

fn binary_elementwise(inputs: &[NodeInput<'_>], f: fn(f32, f32) -> f32) -> Result<Vec<f32>> {
    let a = require_input(inputs, 0)?;
    let b = require_input(inputs, 1)?;
    if a.data.len() != b.data.len() {
        return Err(MlxError::ShapeMismatch {
            expected: a.shape.0.clone(),
            got: b.shape.0.clone(),
        });
    }
    Ok(a.data
        .iter()
        .zip(b.data.iter())
        .map(|(&x, &y)| f(x, y))
        .collect())
}

fn reduce_sum(inputs: &[NodeInput<'_>], axis: Option<i32>) -> Result<Vec<f32>> {
    let a = require_input(inputs, 0)?;
    match axis {
        None => Ok(vec![a.data.iter().sum()]),
        Some(axis) => reduce_along_axis(a, axis, |slice| slice.iter().sum()),
    }
}

fn reduce_mean(inputs: &[NodeInput<'_>], axis: Option<i32>) -> Result<Vec<f32>> {
    let a = require_input(inputs, 0)?;
    match axis {
        None => {
            let n = a.data.len() as f32;
            Ok(vec![a.data.iter().sum::<f32>() / n])
        }
        // Divide by the slice length instead of indexing the shape up front,
        // so an out-of-range axis is reported as an error by
        // `reduce_along_axis` rather than causing a panic here.
        Some(axis) => reduce_along_axis(a, axis, |slice| {
            slice.iter().sum::<f32>() / slice.len() as f32
        }),
    }
}

fn reduce_max(inputs: &[NodeInput<'_>], axis: Option<i32>) -> Result<Vec<f32>> {
    let a = require_input(inputs, 0)?;
    match axis {
        None => Ok(vec![
            a.data.iter().copied().fold(f32::NEG_INFINITY, f32::max),
        ]),
        Some(axis) => reduce_along_axis(a, axis, |slice| {
            slice.iter().copied().fold(f32::NEG_INFINITY, f32::max)
        }),
    }
}

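/// Reduces `a` over `axis` by viewing its shape as `outer × dim × inner`.
/// For example, shape [2, 3, 4] reduced over axis 1 gives outer = 2,
/// dim = 3, inner = 4: element (o, d, i) lives at flat index
/// o * dim * inner + d * inner + i, and the output has shape [2, 4].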
fn reduce_along_axis(
    a: &NodeInput<'_>,
    axis: i32,
    reducer: impl Fn(&[f32]) -> f32,
) -> Result<Vec<f32>> {
    let ndim = a.shape.ndim() as i32;
    let ax = if axis < 0 { ndim + axis } else { axis };
    if ax < 0 || ax >= ndim {
        return Err(MlxError::InvalidArgument(format!(
            "axis {axis} out of range for ndim {ndim}"
        )));
    }
    let ax = ax as usize;

    let outer: usize = a.shape.0[..ax].iter().product::<i64>() as usize;
    let dim: usize = a.shape.0[ax] as usize;
    let inner: usize = a.shape.0[ax + 1..].iter().product::<i64>() as usize;

    let mut result = Vec::with_capacity(outer * inner);
    for o in 0..outer {
        for i in 0..inner {
            let mut slice = Vec::with_capacity(dim);
            for d in 0..dim {
                slice.push(a.data[o * dim * inner + d * inner + i]);
            }
            result.push(reducer(&slice));
        }
    }
    Ok(result)
}

fn matmul(inputs: &[NodeInput<'_>]) -> Result<Vec<f32>> {
    let a = require_input(inputs, 0)?;
    let b = require_input(inputs, 1)?;

    if a.shape.ndim() != 2 || b.shape.ndim() != 2 {
        return Err(MlxError::InvalidArgument(
            "matmul requires 2D tensors".into(),
        ));
    }

    let m = a.shape.0[0] as usize;
    let k = a.shape.0[1] as usize;
    let k2 = b.shape.0[0] as usize;
    let n = b.shape.0[1] as usize;

    if k != k2 {
        return Err(MlxError::ShapeMismatch {
            expected: vec![m as i64, k as i64],
            got: vec![k2 as i64, n as i64],
        });
    }

    let mut data = vec![0.0f32; m * n];
    for i in 0..m {
        for j in 0..n {
            let mut sum = 0.0f32;
            for p in 0..k {
                sum += a.data[i * k + p] * b.data[p * n + j];
            }
            data[i * n + j] = sum;
        }
    }
    Ok(data)
}

fn transpose(inputs: &[NodeInput<'_>], axes: Option<&[usize]>) -> Result<Vec<f32>> {
    let a = require_input(inputs, 0)?;
    let ndim = a.shape.ndim();

    let perm: Vec<usize> = match axes {
        Some(ax) => ax.to_vec(),
        None => (0..ndim).rev().collect(),
    };

    if perm.len() != ndim {
        return Err(MlxError::InvalidArgument(
            "transpose axes length must match ndim".into(),
        ));
    }

    let old_shape: Vec<usize> = a.shape.0.iter().map(|&d| d as usize).collect();
    let new_shape: Vec<usize> = perm.iter().map(|&ax| old_shape[ax]).collect();

    // Compute strides for the old shape.
    let mut old_strides = vec![1usize; ndim];
    for i in (0..ndim.saturating_sub(1)).rev() {
        old_strides[i] = old_strides[i + 1] * old_shape[i + 1];
    }

    let total = a.data.len();
    let mut result = vec![0.0f32; total];

    for (flat, out) in result.iter_mut().enumerate() {
        // Convert flat index → multi-index in NEW shape.
        let mut remaining = flat;
        let mut old_flat = 0;
        for dim_idx in 0..ndim {
            let new_dim_size: usize = new_shape[dim_idx + 1..].iter().product::<usize>().max(1);
            let coord = remaining / new_dim_size;
            remaining %= new_dim_size;
            // This coord in the new tensor corresponds to perm[dim_idx] axis in old tensor.
            old_flat += coord * old_strides[perm[dim_idx]];
        }
        *out = a.data[old_flat];
    }

    Ok(result)
}

fn softmax(inputs: &[NodeInput<'_>], axis: i32) -> Result<Vec<f32>> {
    let a = require_input(inputs, 0)?;
    let ndim = a.shape.ndim() as i32;
    let ax = if axis < 0 { ndim + axis } else { axis };
    if ax < 0 || ax >= ndim {
        return Err(MlxError::InvalidArgument(format!(
            "axis {axis} out of range for ndim {ndim}"
        )));
    }
    let ax = ax as usize;

    let outer: usize = a.shape.0[..ax].iter().product::<i64>() as usize;
    let dim: usize = a.shape.0[ax] as usize;
    let inner: usize = a.shape.0[ax + 1..].iter().product::<i64>() as usize;

    let mut data = a.data.to_vec();

    for o in 0..outer {
        for i in 0..inner {
            let mut max_val = f32::NEG_INFINITY;
            for d in 0..dim {
                let idx = o * dim * inner + d * inner + i;
                if data[idx] > max_val {
                    max_val = data[idx];
                }
            }
            let mut sum_exp = 0.0f32;
            for d in 0..dim {
                let idx = o * dim * inner + d * inner + i;
                data[idx] = (data[idx] - max_val).exp();
                sum_exp += data[idx];
            }
            for d in 0..dim {
                let idx = o * dim * inner + d * inner + i;
                data[idx] /= sum_exp;
            }
        }
    }
    Ok(data)
}

fn layer_norm(inputs: &[NodeInput<'_>], eps: f32, _meta: &TensorMeta) -> Result<Vec<f32>> {
    let a = require_input(inputs, 0)?;
    // LayerNorm normalizes over the last dimension.
    let ndim = a.shape.ndim();
    if ndim == 0 {
        return Ok(a.data.to_vec());
    }
    let last_dim = a.shape.0[ndim - 1] as usize;
    let outer = a.data.len() / last_dim;

    let mut result = vec![0.0f32; a.data.len()];
    for o in 0..outer {
        let start = o * last_dim;
        let end = start + last_dim;
        let slice = &a.data[start..end];

        let mean = slice.iter().sum::<f32>() / last_dim as f32;
        let var = slice.iter().map(|x| (x - mean) * (x - mean)).sum::<f32>() / last_dim as f32;
        let std = (var + eps).sqrt();

        for (i, &x) in slice.iter().enumerate() {
            result[start + i] = (x - mean) / std;
        }
    }
    Ok(result)
}

/// Numpy-style broadcast to `target_shape`, aligning trailing dimensions.
fn broadcast(inputs: &[NodeInput<'_>], target_shape: &crate::Shape) -> Result<Vec<f32>> {
    let a = require_input(inputs, 0)?;
    let in_shape = &a.shape.0;
    let out_shape = &target_shape.0;
    let out_ndim = out_shape.len();
    let in_ndim = in_shape.len();
    // The input may not have more dims than the target; per-dimension
    // compatibility (each input dim is 1 or equals the target dim) is
    // assumed to have been validated when the graph was built.
    if in_ndim > out_ndim {
        return Err(MlxError::ShapeMismatch {
            expected: out_shape.clone(),
            got: in_shape.clone(),
        });
    }
    let pad = out_ndim - in_ndim;
    let total: usize = out_shape.iter().product::<i64>() as usize;

    let mut result = vec![0.0f32; total];
    for (out_flat, out) in result.iter_mut().enumerate() {
        let mut remaining = out_flat;
        let mut in_flat = 0usize;
        let mut in_stride = 1usize;

        for d in (0..out_ndim).rev() {
            let out_dim = out_shape[d] as usize;
            let coord = remaining % out_dim;
            remaining /= out_dim;

            if d >= pad {
                let in_d = d - pad;
                let in_dim = in_shape[in_d] as usize;
                let in_coord = if in_dim == 1 { 0 } else { coord };
                in_flat += in_coord * in_stride;
                in_stride *= in_dim;
            }
        }
        *out = a.data[in_flat];
    }
    Ok(result)
}

/// LayerNorm backward: inputs = [grad_output, original_input].
///
/// For x of shape [..., D] (normalized over last dim, no affine params):
///   x_hat = (x - mean) / std
///   dx_i = (1/std) * (dy_i - mean(dy) - x_hat_i * mean(dy * x_hat))
fn layer_norm_vjp(inputs: &[NodeInput<'_>], eps: f32) -> Result<Vec<f32>> {
    let dy = require_input(inputs, 0)?;
    let x = require_input(inputs, 1)?;
    if dy.shape != x.shape || dy.data.len() != x.data.len() {
        return Err(MlxError::ShapeMismatch {
            expected: x.shape.0.clone(),
            got: dy.shape.0.clone(),
        });
    }
    let ndim = x.shape.ndim();
    if ndim == 0 {
        return Ok(dy.data.to_vec());
    }
    let d = x.shape.0[ndim - 1] as usize;
    if d == 0 || x.data.is_empty() {
        return Ok(vec![0.0f32; x.data.len()]);
    }
    let d_f = d as f32;
    let outer = x.data.len() / d;

    let mut result = vec![0.0f32; x.data.len()];
    for o in 0..outer {
        let start = o * d;
        let end = start + d;
        let x_slice = &x.data[start..end];
        let dy_slice = &dy.data[start..end];

        // Forward recomputation
        let mean = x_slice.iter().sum::<f32>() / d_f;
        let var = x_slice.iter().map(|v| (v - mean) * (v - mean)).sum::<f32>() / d_f;
        let std = (var + eps).sqrt();
        let inv_std = 1.0 / std;

        // x_hat = (x - mean) / std
        let x_hat: Vec<f32> = x_slice.iter().map(|v| (v - mean) * inv_std).collect();

        // mean(dy) and mean(dy * x_hat)
        let mean_dy = dy_slice.iter().sum::<f32>() / d_f;
        let mean_dy_xhat: f32 = dy_slice
            .iter()
            .zip(x_hat.iter())
            .map(|(a, b)| a * b)
            .sum::<f32>()
            / d_f;

        // dx_i = inv_std * (dy_i - mean_dy - x_hat_i * mean_dy_xhat)
        for i in 0..d {
            result[start + i] = inv_std * (dy_slice[i] - mean_dy - x_hat[i] * mean_dy_xhat);
        }
    }
    Ok(result)
}

/// RmsNorm backward: inputs = [grad_output, original_input].
///
/// For x of shape [..., D] (normalized over last dim, no affine params):
///   rms = sqrt(mean(x^2) + eps)
///   y = x / rms
///   dx_i = (1/rms) * (dy_i - y_i * mean(dy * y))
fn rms_norm_vjp(inputs: &[NodeInput<'_>], eps: f32) -> Result<Vec<f32>> {
    let dy = require_input(inputs, 0)?;
    let x = require_input(inputs, 1)?;
    if dy.shape != x.shape || dy.data.len() != x.data.len() {
        return Err(MlxError::ShapeMismatch {
            expected: x.shape.0.clone(),
            got: dy.shape.0.clone(),
        });
    }
    let ndim = x.shape.ndim();
    if ndim == 0 {
        return Ok(dy.data.to_vec());
    }
    let d = x.shape.0[ndim - 1] as usize;
    if d == 0 || x.data.is_empty() {
        return Ok(vec![0.0f32; x.data.len()]);
    }
    let d_f = d as f32;
    let outer = x.data.len() / d;

    let mut result = vec![0.0f32; x.data.len()];
    for o in 0..outer {
        let start = o * d;
        let end = start + d;
        let x_slice = &x.data[start..end];
        let dy_slice = &dy.data[start..end];

        // Forward recomputation
        let rms = (x_slice.iter().map(|v| v * v).sum::<f32>() / d_f + eps).sqrt();
        let inv_rms = 1.0 / rms;

        // y = x / rms
        let y: Vec<f32> = x_slice.iter().map(|v| v * inv_rms).collect();

        // mean(dy * y)
        let mean_dy_y: f32 = dy_slice
            .iter()
            .zip(y.iter())
            .map(|(a, b)| a * b)
            .sum::<f32>()
            / d_f;

        // dx_i = inv_rms * (dy_i - y_i * mean_dy_y)
        for i in 0..d {
            result[start + i] = inv_rms * (dy_slice[i] - y[i] * mean_dy_y);
        }
    }
    Ok(result)
}

/// Softmax backward: inputs = [grad_output, softmax_output].
///
/// dx_i = s_i * (dy_i - sum(dy * s))
fn softmax_vjp(inputs: &[NodeInput<'_>], axis: i32) -> Result<Vec<f32>> {
    let dy = require_input(inputs, 0)?;
    let s = require_input(inputs, 1)?;
    if dy.data.len() != s.data.len() {
        return Err(MlxError::ShapeMismatch {
            expected: s.shape.0.clone(),
            got: dy.shape.0.clone(),
        });
    }
    let ndim = s.shape.ndim() as i32;
    let ax = if axis < 0 { ndim + axis } else { axis };
    if ax < 0 || ax >= ndim {
        return Err(MlxError::InvalidArgument(format!(
            "axis {axis} out of range for ndim {ndim}"
        )));
    }
    let ax = ax as usize;

    let outer: usize = s.shape.0[..ax].iter().product::<i64>() as usize;
    let dim: usize = s.shape.0[ax] as usize;
    let inner: usize = s.shape.0[ax + 1..].iter().product::<i64>() as usize;

    let mut result = vec![0.0f32; s.data.len()];
    for o in 0..outer {
        for i in 0..inner {
            // dot = sum(dy * s) along the axis
            let mut dot = 0.0f32;
            for d in 0..dim {
                let idx = o * dim * inner + d * inner + i;
                dot += dy.data[idx] * s.data[idx];
            }
            for d in 0..dim {
                let idx = o * dim * inner + d * inner + i;
                result[idx] = s.data[idx] * (dy.data[idx] - dot);
            }
        }
    }
    Ok(result)
}

/// SiLU backward: inputs = [grad_output, original_input].
///
/// d_silu/dx = σ(x) * (1 + x * (1 - σ(x)))
fn silu_vjp(inputs: &[NodeInput<'_>]) -> Result<Vec<f32>> {
    let dy = require_input(inputs, 0)?;
    let x = require_input(inputs, 1)?;
    if dy.data.len() != x.data.len() {
        return Err(MlxError::ShapeMismatch {
            expected: x.shape.0.clone(),
            got: dy.shape.0.clone(),
        });
    }
    Ok(dy
        .data
        .iter()
        .zip(x.data.iter())
        .map(|(&dy_i, &x_i)| {
            let sig = sigmoid(x_i);
            dy_i * sig * (1.0 + x_i * (1.0 - sig))
        })
        .collect())
}

/// GELU backward (tanh approximation): inputs = [grad_output, original_input].
///
/// gelu(x) = 0.5x(1 + tanh(a(x + bx³)))
/// d_gelu/dx = 0.5(1 + tanh(...)) + 0.5x * sech²(...) * a(1 + 3bx²)
fn gelu_vjp(inputs: &[NodeInput<'_>]) -> Result<Vec<f32>> {
    let dy = require_input(inputs, 0)?;
    let x = require_input(inputs, 1)?;
    if dy.data.len() != x.data.len() {
        return Err(MlxError::ShapeMismatch {
            expected: x.shape.0.clone(),
            got: dy.shape.0.clone(),
        });
    }
    let a = (2.0f32 / std::f32::consts::PI).sqrt();
    let b = 0.044715f32;
    Ok(dy
        .data
        .iter()
        .zip(x.data.iter())
        .map(|(&dy_i, &x_i)| {
            let inner = a * (x_i + b * x_i * x_i * x_i);
            let tanh_inner = inner.tanh();
            let sech2 = 1.0 - tanh_inner * tanh_inner;
            let dgelu =
                0.5 * (1.0 + tanh_inner) + 0.5 * x_i * sech2 * a * (1.0 + 3.0 * b * x_i * x_i);
            dy_i * dgelu
        })
        .collect())
}

fn cpu_rope(
    inputs: &[NodeInput<'_>],
    meta: &TensorMeta,
    rotary_dim: usize,
    pos_offset: usize,
    theta: f32,
) -> Result<Vec<f32>> {
    let x = require_input(inputs, 0)?;
    if meta.shape.ndim() != 2 {
        return Err(MlxError::InvalidArgument(
            "Rope input must be 2-D [tokens, head_dim]".into(),
        ));
    }
    let tokens = meta.shape.0[0] as usize;
    let head_dim = meta.shape.0[1] as usize;
    if rotary_dim > head_dim || !rotary_dim.is_multiple_of(2) {
        return Err(MlxError::InvalidArgument(
            "rotary_dim must be even and <= head_dim".into(),
        ));
    }

    let mut out = x.data.to_vec();
    for t in 0..tokens {
        for i in 0..rotary_dim / 2 {
            let inv_freq = theta.powf(-2.0 * i as f32 / rotary_dim as f32);
            let angle = (pos_offset + t) as f32 * inv_freq;
            let (s, c) = angle.sin_cos();

            let base = t * head_dim + i * 2;
            let x0 = x.data[base];
            let x1 = x.data[base + 1];

            out[base] = x0 * c - x1 * s;
            out[base + 1] = x0 * s + x1 * c;
        }
    }
    Ok(out)
}

fn rms_norm(inputs: &[NodeInput<'_>], eps: f32, _meta: &TensorMeta) -> Result<Vec<f32>> {
    let a = require_input(inputs, 0)?;
    let ndim = a.shape.ndim();
    if ndim == 0 {
        return Ok(a.data.to_vec());
    }
    let last_dim = a.shape.0[ndim - 1] as usize;
    let outer = a.data.len() / last_dim;

    let mut result = vec![0.0f32; a.data.len()];
    for o in 0..outer {
        let start = o * last_dim;
        let end = start + last_dim;
        let slice = &a.data[start..end];

        let rms = (slice.iter().map(|x| x * x).sum::<f32>() / last_dim as f32 + eps).sqrt();

        for (i, &x) in slice.iter().enumerate() {
            result[start + i] = x / rms;
        }
    }
    Ok(result)
}

fn scaled_masked_softmax(inputs: &[NodeInput<'_>], scale: f32, causal: bool) -> Result<Vec<f32>> {
    let a = require_input(inputs, 0)?;
    if a.shape.ndim() != 2 {
        return Err(MlxError::InvalidArgument(
            "ScaledMaskedSoftmax requires 2D input [Tq, Tk]".into(),
        ));
    }
    let tq = a.shape.0[0] as usize;
    let tk = a.shape.0[1] as usize;

    let mut data = vec![0.0f32; tq * tk];

    for i in 0..tq {
        // Scale + mask
        for j in 0..tk {
            let idx = i * tk + j;
            let mut val = a.data[idx] * scale;
            if causal && j > i {
                val = -1e9;
            }
            data[idx] = val;
        }

        // Numerically stable softmax per row
        let row_start = i * tk;
        let mut max_val = f32::NEG_INFINITY;
        for j in 0..tk {
            if data[row_start + j] > max_val {
                max_val = data[row_start + j];
            }
        }
        let mut sum_exp = 0.0f32;
        for j in 0..tk {
            data[row_start + j] = (data[row_start + j] - max_val).exp();
            sum_exp += data[row_start + j];
        }
        for j in 0..tk {
            data[row_start + j] /= sum_exp;
        }
    }
    Ok(data)
}

fn cpu_matmul_raw(a: &[f32], b: &[f32], m: usize, k: usize, n: usize) -> Vec<f32> {
    let mut out = vec![0.0f32; m * n];
    for i in 0..m {
        for j in 0..n {
            let mut sum = 0.0f32;
            for p in 0..k {
                sum += a[i * k + p] * b[p * n + j];
            }
            out[i * n + j] = sum;
        }
    }
    out
}

fn cpu_transpose_2d(data: &[f32], rows: usize, cols: usize) -> Vec<f32> {
    let mut out = vec![0.0f32; rows * cols];
    for r in 0..rows {
        for c in 0..cols {
            out[c * rows + r] = data[r * cols + c];
        }
    }
    out
}

fn cpu_attention(inputs: &[NodeInput<'_>], scale: f32, causal: bool) -> Result<Vec<f32>> {
    if inputs.len() != 3 {
        return Err(MlxError::InvalidArgument(
            "Attention requires exactly 3 inputs [Q, K, V]".into(),
        ));
    }
    let q = require_input(inputs, 0)?;
    let k = require_input(inputs, 1)?;
    let v = require_input(inputs, 2)?;

    if q.shape.ndim() != 2 || k.shape.ndim() != 2 || v.shape.ndim() != 2 {
        return Err(MlxError::InvalidArgument(
            "Attention inputs must be 2D".into(),
        ));
    }

    let tq = q.shape.0[0] as usize;
    let dh = q.shape.0[1] as usize;
    let tk = k.shape.0[0] as usize;
    let dh_k = k.shape.0[1] as usize;
    let tk_v = v.shape.0[0] as usize;
    let dh_v = v.shape.0[1] as usize;

    if dh != dh_k {
        return Err(MlxError::ShapeMismatch {
            expected: vec![tq as i64, dh as i64],
            got: vec![tk as i64, dh_k as i64],
        });
    }
    if tk != tk_v || dh != dh_v {
        return Err(MlxError::ShapeMismatch {
            expected: vec![tk as i64, dh as i64],
            got: vec![tk_v as i64, dh_v as i64],
        });
    }

    // 1. Transpose K: [Tk, Dh] -> [Dh, Tk]
    let kt = cpu_transpose_2d(k.data, tk, dh);

    // 2. scores = Q @ K^T: [Tq, Dh] @ [Dh, Tk] -> [Tq, Tk]
    let scores = cpu_matmul_raw(q.data, &kt, tq, dh, tk);

    // 3. Scaled masked softmax on scores
    let mut probs = vec![0.0f32; tq * tk];
    for i in 0..tq {
        for j in 0..tk {
            let idx = i * tk + j;
            let mut val = scores[idx] * scale;
            if causal && j > i {
                val = -1e9;
            }
            probs[idx] = val;
        }
        let row_start = i * tk;
        let mut max_val = f32::NEG_INFINITY;
        for j in 0..tk {
            if probs[row_start + j] > max_val {
                max_val = probs[row_start + j];
            }
        }
        let mut sum_exp = 0.0f32;
        for j in 0..tk {
            probs[row_start + j] = (probs[row_start + j] - max_val).exp();
            sum_exp += probs[row_start + j];
        }
        for j in 0..tk {
            probs[row_start + j] /= sum_exp;
        }
    }

    // 4. Y = P @ V: [Tq, Tk] @ [Tk, Dh] -> [Tq, Dh]
    let y = cpu_matmul_raw(&probs, v.data, tq, tk, dh);

    Ok(y)
}

fn embedding(inputs: &[NodeInput<'_>]) -> Result<Vec<f32>> {
    let weight = require_input(inputs, 0)?;
    let indices = require_input(inputs, 1)?;

    if weight.shape.ndim() != 2 {
        return Err(MlxError::InvalidArgument(
            "Embedding weight must be 2D [vocab_size, embed_dim]".into(),
        ));
    }
    if indices.shape.ndim() != 1 {
        return Err(MlxError::InvalidArgument(
            "Embedding indices must be 1D [seq_len]".into(),
        ));
    }
    let vocab_size = weight.shape.0[0] as usize;
    let embed_dim = weight.shape.0[1] as usize;
    let seq_len = indices.data.len();

    let mut result = Vec::with_capacity(seq_len * embed_dim);
    for &idx_f in indices.data {
        if idx_f < 0.0 || idx_f != idx_f.trunc() {
            return Err(MlxError::InvalidArgument(format!(
                "Embedding index must be a non-negative integer, got {idx_f}"
            )));
        }
        let idx = idx_f as usize;
        if idx >= vocab_size {
            return Err(MlxError::InvalidArgument(format!(
                "Embedding index {idx} out of range for vocab_size {vocab_size}"
            )));
        }
        let start = idx * embed_dim;
        result.extend_from_slice(&weight.data[start..start + embed_dim]);
    }
    Ok(result)
}

fn narrow(inputs: &[NodeInput<'_>], axis: i32, start: i64, length: i64) -> Result<Vec<f32>> {
    let a = require_input(inputs, 0)?;
    let ndim = a.shape.ndim() as i32;
    let ax = if axis < 0 { ndim + axis } else { axis };
    if ax < 0 || ax >= ndim {
        return Err(MlxError::InvalidArgument(format!(
            "narrow: axis {axis} out of range for ndim {ndim}"
        )));
    }
    let ax = ax as usize;
    let dim_size = a.shape.0[ax] as i64;
    if start < 0 || start + length > dim_size {
        return Err(MlxError::InvalidArgument(format!(
            "narrow: start {start} + length {length} out of range for dim size {dim_size}"
        )));
    }

    let outer: usize = a.shape.0[..ax].iter().product::<i64>() as usize;
    let dim: usize = a.shape.0[ax] as usize;
    let inner: usize = a.shape.0[ax + 1..].iter().product::<i64>() as usize;
    let start = start as usize;
    let length = length as usize;

    let mut result = Vec::with_capacity(outer * length * inner);
    for o in 0..outer {
        for d in start..start + length {
            let base = (o * dim + d) * inner;
            result.extend_from_slice(&a.data[base..base + inner]);
        }
    }
    Ok(result)
}

fn concatenate(inputs: &[NodeInput<'_>], axis: i32) -> Result<Vec<f32>> {
    if inputs.is_empty() {
        return Err(MlxError::InvalidArgument(
            "Concatenate requires at least one input".into(),
        ));
    }
    let first = &inputs[0];
    let ndim = first.shape.ndim() as i32;
    let ax = if axis < 0 { ndim + axis } else { axis };
    if ax < 0 || ax >= ndim {
        return Err(MlxError::InvalidArgument(format!(
            "concatenate: axis {axis} out of range for ndim {ndim}"
        )));
    }
    let ax = ax as usize;

    // Validate all inputs have same shape except along concat axis
    for inp in &inputs[1..] {
        if inp.shape.ndim() != first.shape.ndim() {
            return Err(MlxError::InvalidArgument(
                "Concatenate: all inputs must have same ndim".into(),
            ));
        }
        for (d, (&a, &b)) in first.shape.0.iter().zip(inp.shape.0.iter()).enumerate() {
            if d != ax && a != b {
                return Err(MlxError::ShapeMismatch {
                    expected: first.shape.0.clone(),
                    got: inp.shape.0.clone(),
                });
            }
        }
    }

    let outer: usize = first.shape.0[..ax].iter().product::<i64>() as usize;
    let inner: usize = first.shape.0[ax + 1..].iter().product::<i64>() as usize;

    let total_dim: usize = inputs.iter().map(|i| i.shape.0[ax] as usize).sum();
    let mut result = Vec::with_capacity(outer * total_dim * inner);

    for o in 0..outer {
        for inp in inputs {
            let dim = inp.shape.0[ax] as usize;
            let base = o * dim * inner;
            result.extend_from_slice(&inp.data[base..base + dim * inner]);
        }
    }
    Ok(result)
}

fn rope(inputs: &[NodeInput<'_>], base: f32, offset: usize, traditional: bool) -> Result<Vec<f32>> {
    let a = require_input(inputs, 0)?;
    let ndim = a.shape.ndim();
    if ndim < 1 {
        return Err(MlxError::InvalidArgument(
            "RoPE requires at least 1 dimension".into(),
        ));
    }

    let head_dim = a.shape.0[ndim - 1] as usize;
    if !head_dim.is_multiple_of(2) {
        return Err(MlxError::InvalidArgument(format!(
            "RoPE head_dim must be even, got {head_dim}"
        )));
    }
    let half_dim = head_dim / 2;

    let total = a.data.len();
    let num_heads_total = total / head_dim;

    let mut result = vec![0.0f32; total];

    for i in 0..num_heads_total {
        // Position is derived from the flattened head index: for a
        // [tokens, head_dim] input, head i is token i at position offset + i.
        // A more robust implementation would derive positions from the
        // shape explicitly.
        let pos = (offset + i) as f32;

        for d in 0..half_dim {
            let theta = pos * base.powf(-(2.0 * d as f32 / head_dim as f32));
            let cos_theta = theta.cos();
            let sin_theta = theta.sin();

            // Traditional RoPE rotates adjacent pairs (2d, 2d + 1);
            // OpenAI style rotates split halves (d, d + half_dim).
            let (idx0, idx1) = if traditional {
                (i * head_dim + 2 * d, i * head_dim + 2 * d + 1)
            } else {
                (i * head_dim + d, i * head_dim + d + half_dim)
            };

            let x0 = a.data[idx0];
            let x1 = a.data[idx1];

            result[idx0] = x0 * cos_theta - x1 * sin_theta;
            result[idx1] = x0 * sin_theta + x1 * cos_theta;
        }
    }
    Ok(result)
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::graph::TensorMeta;
    use crate::types::Shape;

    fn meta(shape: Vec<i64>) -> TensorMeta {
        TensorMeta {
            shape: Shape::new(shape),
            dtype: crate::DType::F32,
        }
    }

    fn input(data: &[f32], shape: Vec<i64>) -> NodeInput<'_> {
        // Leak the shape to obtain a `&'static` reference; acceptable in tests.
        NodeInput {
            data,
            shape: Box::leak(Box::new(Shape::new(shape))),
            dtype: crate::DType::F32,
        }
    }

    #[test]
    fn test_add() {
        let backend = CpuRefBackend;
        let a_data = [1.0, 2.0, 3.0];
        let b_data = [4.0, 5.0, 6.0];
        let result = backend
            .eval_node(
                &OpKind::Add,
                &[input(&a_data, vec![3]), input(&b_data, vec![3])],
                &meta(vec![3]),
            )
            .unwrap();
        assert_eq!(result, vec![5.0, 7.0, 9.0]);
    }
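
    // Embedding is a row lookup into a [vocab_size, embed_dim] weight matrix;
    // the expected values follow directly from the selected rows.
    #[test]
    fn test_embedding() {
        let backend = CpuRefBackend;
        let weight = [0.0, 1.0, 10.0, 11.0, 20.0, 21.0];
        let indices = [2.0, 0.0];
        let result = backend
            .eval_node(
                &OpKind::Embedding,
                &[input(&weight, vec![3, 2]), input(&indices, vec![2])],
                &meta(vec![2, 2]),
            )
            .unwrap();
        assert_eq!(result, vec![20.0, 21.0, 0.0, 1.0]);
    }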

    #[test]
    fn test_matmul() {
        let backend = CpuRefBackend;
        let a_data = [1.0, 2.0, 3.0, 4.0];
        let b_data = [5.0, 6.0, 7.0, 8.0];
        let result = backend
            .eval_node(
                &OpKind::MatMul,
                &[input(&a_data, vec![2, 2]), input(&b_data, vec![2, 2])],
                &meta(vec![2, 2]),
            )
            .unwrap();
        assert_eq!(result, vec![19.0, 22.0, 43.0, 50.0]);
    }
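
    // With axes = None, Transpose reverses the dimension order, so a
    // row-major [2, 3] matrix becomes its [3, 2] transpose.
    #[test]
    fn test_transpose() {
        let backend = CpuRefBackend;
        let data = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let result = backend
            .eval_node(
                &OpKind::Transpose { axes: None },
                &[input(&data, vec![2, 3])],
                &meta(vec![3, 2]),
            )
            .unwrap();
        assert_eq!(result, vec![1.0, 4.0, 2.0, 5.0, 3.0, 6.0]);
    }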

    #[test]
    fn test_softmax() {
        let backend = CpuRefBackend;
        let data = [1.0, 2.0, 3.0];
        let result = backend
            .eval_node(
                &OpKind::Softmax { axis: 0 },
                &[input(&data, vec![3])],
                &meta(vec![3]),
            )
            .unwrap();
        let sum: f32 = result.iter().sum();
        assert!((sum - 1.0).abs() < 1e-6);
        assert!(result[0] < result[1]);
        assert!(result[1] < result[2]);
    }
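
    // SoftmaxVjp takes inputs = [grad_output, softmax_output] and computes
    // dx_i = s_i * (dy_i - sum(dy * s)). With uniform s = 1/3 and
    // dy = [1, 2, 3], the dot product is 2, giving dx = [-1/3, 0, 1/3].
    #[test]
    fn test_softmax_vjp() {
        let backend = CpuRefBackend;
        let dy = [1.0, 2.0, 3.0];
        let s = [1.0 / 3.0, 1.0 / 3.0, 1.0 / 3.0];
        let result = backend
            .eval_node(
                &OpKind::SoftmaxVjp { axis: 0 },
                &[input(&dy, vec![3]), input(&s, vec![3])],
                &meta(vec![3]),
            )
            .unwrap();
        assert!((result[0] + 1.0 / 3.0).abs() < 1e-6);
        assert!(result[1].abs() < 1e-6);
        assert!((result[2] - 1.0 / 3.0).abs() < 1e-6);
    }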

    #[test]
    fn test_neg() {
        let backend = CpuRefBackend;
        let data = [1.0, -2.0, 3.0];
        let result = backend
            .eval_node(&OpKind::Neg, &[input(&data, vec![3])], &meta(vec![3]))
            .unwrap();
        assert_eq!(result, vec![-1.0, 2.0, -3.0]);
    }
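
    // Broadcasting a [3] vector to [2, 3] repeats it along the new
    // leading axis.
    #[test]
    fn test_broadcast() {
        let backend = CpuRefBackend;
        let data = [1.0, 2.0, 3.0];
        let result = backend
            .eval_node(
                &OpKind::Broadcast {
                    target_shape: Shape::new(vec![2, 3]),
                },
                &[input(&data, vec![3])],
                &meta(vec![2, 3]),
            )
            .unwrap();
        assert_eq!(result, vec![1.0, 2.0, 3.0, 1.0, 2.0, 3.0]);
    }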

    #[test]
    fn test_layer_norm() {
        let backend = CpuRefBackend;
        let data = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let result = backend
            .eval_node(
                &OpKind::LayerNorm { eps: 1e-5 },
                &[input(&data, vec![2, 3])],
                &meta(vec![2, 3]),
            )
            .unwrap();
        // Each row should be normalized to mean≈0, std≈1
        let row1_mean: f32 = result[0..3].iter().sum::<f32>() / 3.0;
        assert!(row1_mean.abs() < 1e-5);
    }
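
    // For LayerNorm without affine parameters, the input gradient of each
    // row sums to zero: in dx = inv_std * (dy - mean(dy) - x_hat * mean(dy * x_hat)),
    // both dy - mean(dy) and x_hat have zero mean over the row.
    #[test]
    fn test_layer_norm_vjp_sums_to_zero() {
        let backend = CpuRefBackend;
        let dy = [0.5, -1.0, 2.0, 0.25];
        let x = [1.0, 2.0, 3.0, 4.0];
        let result = backend
            .eval_node(
                &OpKind::LayerNormVjp { eps: 1e-5 },
                &[input(&dy, vec![4]), input(&x, vec![4])],
                &meta(vec![4]),
            )
            .unwrap();
        let sum: f32 = result.iter().sum();
        assert!(sum.abs() < 1e-5);
    }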

    #[test]
    fn test_reduce_sum_axis() {
        let backend = CpuRefBackend;
        let data = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let result = backend
            .eval_node(
                &OpKind::Sum { axis: Some(0) },
                &[input(&data, vec![2, 3])],
                &meta(vec![3]),
            )
            .unwrap();
        assert_eq!(result, vec![5.0, 7.0, 9.0]);
    }
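
    // Narrow keeps `length` entries starting at `start` along `axis`;
    // here columns 1..3 of a [2, 3] matrix.
    #[test]
    fn test_narrow() {
        let backend = CpuRefBackend;
        let data = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let result = backend
            .eval_node(
                &OpKind::Narrow {
                    axis: 1,
                    start: 1,
                    length: 2,
                },
                &[input(&data, vec![2, 3])],
                &meta(vec![2, 2]),
            )
            .unwrap();
        assert_eq!(result, vec![2.0, 3.0, 5.0, 6.0]);
    }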

    #[test]
    fn test_reduce_sum_all() {
        let backend = CpuRefBackend;
        let data = [1.0, 2.0, 3.0];
        let result = backend
            .eval_node(
                &OpKind::Sum { axis: None },
                &[input(&data, vec![3])],
                &meta(vec![]),
            )
            .unwrap();
        assert_eq!(result, vec![6.0]);
    }
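
    // Mean over axis 1 of a [2, 3] matrix averages each row.
    #[test]
    fn test_reduce_mean_axis() {
        let backend = CpuRefBackend;
        let data = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let result = backend
            .eval_node(
                &OpKind::Mean { axis: Some(1) },
                &[input(&data, vec![2, 3])],
                &meta(vec![2]),
            )
            .unwrap();
        assert_eq!(result, vec![2.0, 5.0]);
    }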

    #[test]
    fn test_silu() {
        let backend = CpuRefBackend;
        let data = [0.0, 1.0, -1.0];
        let result = backend
            .eval_node(&OpKind::Silu, &[input(&data, vec![3])], &meta(vec![3]))
            .unwrap();
        // silu(0) = 0, silu(1) ≈ 0.7311, silu(-1) ≈ -0.2689
        assert!((result[1] - 0.7311).abs() < 1e-3);
        assert!((result[2] - (-0.2689)).abs() < 1e-3);
    }
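
    // A finite-difference sketch for the SiLU VJP: the analytic gradient is
    // compared against a central difference of the forward function. The
    // tolerance is loose on the assumption that h = 1e-3 keeps f32
    // truncation and rounding error well below 1e-3.
    #[test]
    fn test_silu_vjp_matches_finite_difference() {
        let backend = CpuRefBackend;
        let x = [0.5f32];
        let dy = [1.0f32];
        let grad = backend
            .eval_node(
                &OpKind::SiluVjp,
                &[input(&dy, vec![1]), input(&x, vec![1])],
                &meta(vec![1]),
            )
            .unwrap();
        let silu = |v: f32| v * sigmoid(v);
        let h = 1e-3f32;
        let numeric = (silu(0.5 + h) - silu(0.5 - h)) / (2.0 * h);
        assert!((grad[0] - numeric).abs() < 1e-3);
    }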

    #[test]
    fn test_rope_offsets() {
        let backend = CpuRefBackend;
        let theta = 10_000.0;
        let pos_offset = 100usize;
        let rotary_dim = 4;
        // Shape: 1 token, head_dim 4. total = 4 floats.
        let data = [1.0, 0.0, 0.0, 1.0];
        let result = backend
            .eval_node(
                &OpKind::Rope {
                    rotary_dim,
                    pos_offset,
                    theta,
                },
                &[input(&data, vec![1, 4])],
                &meta(vec![1, 4]),
            )
            .unwrap();

        // Expected values (interleaved pairs)
        // i=0: inv_freq = 10000^0 = 1.0, angle = 100 * 1.0 = 100.0.
        let cos100 = 100.0f32.cos();
        let sin100 = 100.0f32.sin();
        // i=1: inv_freq = 10000^-0.5 = 0.01, angle = 100 * 0.01 = 1.0.
        let cos1 = 1.0f32.cos();
        let sin1 = 1.0f32.sin();

        // data[0]=1, data[1]=0 -> out[0]=cos, out[1]=sin
        // data[2]=0, data[3]=1 -> out[2]=-sin, out[3]=cos
        assert!((result[0] - cos100).abs() < 1e-5);
        assert!((result[1] - sin100).abs() < 1e-5);
        assert!((result[2] - (-sin1)).abs() < 1e-5);
        assert!((result[3] - cos1).abs() < 1e-5);
    }
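
    // For the OpKind::RoPE variant, position 0 means every rotation angle
    // is 0, so the first token must pass through unchanged regardless of
    // the `traditional` pair layout.
    #[test]
    fn test_rope_identity_at_position_zero() {
        let backend = CpuRefBackend;
        let data = [1.0, 2.0, 3.0, 4.0];
        for traditional in [false, true] {
            let result = backend
                .eval_node(
                    &OpKind::RoPE {
                        base: 10_000.0,
                        offset: 0,
                        traditional,
                    },
                    &[input(&data, vec![1, 4])],
                    &meta(vec![1, 4]),
                )
                .unwrap();
            assert_eq!(result, vec![1.0, 2.0, 3.0, 4.0]);
        }
    }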

    #[test]
    fn test_rope_large() {
        let backend = CpuRefBackend;
        let shape = vec![128, 128];
        let numel = 128 * 128;
        let data = vec![1.0; numel];
        let result = backend.eval_node(
            &OpKind::Rope {
                rotary_dim: 128,
                pos_offset: 0,
                theta: 10000.0,
            },
            &[input(&data, shape.clone())],
            &meta(shape.clone()),
        );
        assert!(result.is_ok());
        assert_eq!(result.unwrap().len(), numel);
    }
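
    // With a causal mask, the first query can only attend to the first key,
    // so after softmax its attention weight is 1 and the first output row
    // must equal the first row of V.
    #[test]
    fn test_causal_attention_first_row() {
        let backend = CpuRefBackend;
        let q = [1.0, 0.0, 0.0, 1.0];
        let k = [1.0, 0.0, 0.0, 1.0];
        let v = [5.0, 6.0, 7.0, 8.0];
        let result = backend
            .eval_node(
                &OpKind::Attention {
                    scale: 1.0,
                    causal: true,
                },
                &[
                    input(&q, vec![2, 2]),
                    input(&k, vec![2, 2]),
                    input(&v, vec![2, 2]),
                ],
                &meta(vec![2, 2]),
            )
            .unwrap();
        assert!((result[0] - 5.0).abs() < 1e-4);
        assert!((result[1] - 6.0).abs() < 1e-4);
    }

    // Concatenating along axis 1 interleaves the per-row blocks of each input.
    #[test]
    fn test_concatenate() {
        let backend = CpuRefBackend;
        let a = [1.0, 2.0, 3.0, 4.0];
        let b = [5.0, 6.0];
        let result = backend
            .eval_node(
                &OpKind::Concatenate { axis: 1 },
                &[input(&a, vec![2, 2]), input(&b, vec![2, 1])],
                &meta(vec![2, 3]),
            )
            .unwrap();
        assert_eq!(result, vec![1.0, 2.0, 5.0, 3.0, 4.0, 6.0]);
    }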
}