use crate::backend::{Backend, NodeInput};
use crate::graph::{OpKind, TensorMeta};
use crate::{MlxError, Result};

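/// Reference CPU backend: plain, unoptimized f32 implementations of every op.
/// It favors readability over speed and is meant as a correctness baseline
/// for other backends.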
pub struct CpuRefBackend;

impl Backend for CpuRefBackend {
    fn eval_node(
        &self,
        op: &OpKind,
        inputs: &[NodeInput<'_>],
        output_meta: &TensorMeta,
    ) -> Result<Vec<f32>> {
        match op {
            OpKind::Constant | OpKind::Parameter => Err(MlxError::InvalidArgument(
                "Constant/Parameter nodes should be pre-materialized".into(),
            )),
            OpKind::Add => binary_elementwise(inputs, |a, b| a + b),
            OpKind::Mul => binary_elementwise(inputs, |a, b| a * b),
            OpKind::Sub => binary_elementwise(inputs, |a, b| a - b),
            OpKind::Div => binary_elementwise(inputs, |a, b| a / b),
            OpKind::Neg => {
                let a = require_input(inputs, 0)?;
                Ok(a.data.iter().map(|x| -x).collect())
            }
            OpKind::Exp => {
                let a = require_input(inputs, 0)?;
                Ok(a.data.iter().map(|x| x.exp()).collect())
            }
            OpKind::Log => {
                let a = require_input(inputs, 0)?;
                Ok(a.data.iter().map(|x| x.ln()).collect())
            }
            OpKind::Sum { axis } => reduce_sum(inputs, *axis),
            OpKind::Mean { axis } => reduce_mean(inputs, *axis),
            OpKind::Max { axis } => reduce_max(inputs, *axis),
            OpKind::MatMul => matmul(inputs),
            OpKind::Reshape { .. } => {
                let a = require_input(inputs, 0)?;
                Ok(a.data.to_vec())
            }
            OpKind::Transpose { axes } => transpose(inputs, axes.as_deref()),
            OpKind::Softmax { axis } => softmax(inputs, *axis),
            OpKind::Silu => {
                let a = require_input(inputs, 0)?;
                Ok(a.data.iter().map(|&x| x * sigmoid(x)).collect())
            }
            OpKind::Gelu => {
                let a = require_input(inputs, 0)?;
                Ok(a.data
                    .iter()
                    .map(|&x| {
                        0.5 * x
                            * (1.0
                                + ((2.0 / std::f32::consts::PI).sqrt()
                                    * (x + 0.044715 * x * x * x))
                                    .tanh())
                    })
                    .collect())
            }
            OpKind::LayerNorm { eps } => layer_norm(inputs, *eps, output_meta),
            OpKind::RmsNorm { eps } => rms_norm(inputs, *eps, output_meta),
            OpKind::Broadcast { target_shape } => broadcast(inputs, target_shape),
            OpKind::ScaledMaskedSoftmax { scale, causal } => {
                scaled_masked_softmax(inputs, *scale, *causal)
            }
            OpKind::Attention { scale, causal } => cpu_attention(inputs, *scale, *causal),
            OpKind::Rope {
                rotary_dim,
                pos_offset,
                theta,
            } => cpu_rope(inputs, output_meta, *rotary_dim, *pos_offset, *theta),
            OpKind::LayerNormVjp { eps } => layer_norm_vjp(inputs, *eps),
            OpKind::RmsNormVjp { eps } => rms_norm_vjp(inputs, *eps),
            OpKind::SoftmaxVjp { axis } => softmax_vjp(inputs, *axis),
            OpKind::SiluVjp => silu_vjp(inputs),
            OpKind::GeluVjp => gelu_vjp(inputs),
            OpKind::Sqrt => {
                let a = require_input(inputs, 0)?;
                Ok(a.data.iter().map(|&x| x.sqrt()).collect())
            }
            OpKind::RoPE {
                base,
                offset,
                traditional,
            } => rope(inputs, *base, *offset, *traditional),
            OpKind::Embedding => embedding(inputs),
            OpKind::Narrow {
                axis,
                start,
                length,
            } => narrow(inputs, *axis, *start, *length),
            OpKind::Concatenate { axis } => concatenate(inputs, *axis),
        }
    }
}

fn sigmoid(x: f32) -> f32 {
    1.0 / (1.0 + (-x).exp())
}

fn require_input<'a>(inputs: &'a [NodeInput<'_>], idx: usize) -> Result<&'a NodeInput<'a>> {
    inputs
        .get(idx)
        .ok_or_else(|| MlxError::InvalidArgument(format!("expected input at index {idx}")))
}

fn binary_elementwise(inputs: &[NodeInput<'_>], f: fn(f32, f32) -> f32) -> Result<Vec<f32>> {
    let a = require_input(inputs, 0)?;
    let b = require_input(inputs, 1)?;
    if a.data.len() != b.data.len() {
        return Err(MlxError::ShapeMismatch {
            expected: a.shape.0.clone(),
            got: b.shape.0.clone(),
        });
    }
    Ok(a.data
        .iter()
        .zip(b.data.iter())
        .map(|(&x, &y)| f(x, y))
        .collect())
}

fn reduce_sum(inputs: &[NodeInput<'_>], axis: Option<i32>) -> Result<Vec<f32>> {
    let a = require_input(inputs, 0)?;
    match axis {
        None => Ok(vec![a.data.iter().sum()]),
        Some(axis) => reduce_along_axis(a, axis, |slice| slice.iter().sum()),
    }
}

fn reduce_mean(inputs: &[NodeInput<'_>], axis: Option<i32>) -> Result<Vec<f32>> {
    let a = require_input(inputs, 0)?;
    match axis {
        None => {
            let n = a.data.len() as f32;
            Ok(vec![a.data.iter().sum::<f32>() / n])
        }
        Some(axis) => {
            // Validate the axis before indexing the shape, so an out-of-range
            // axis returns an error instead of panicking.
            let ndim = a.shape.ndim() as i32;
            let ax = if axis < 0 { ndim + axis } else { axis };
            if ax < 0 || ax >= ndim {
                return Err(MlxError::InvalidArgument(format!(
                    "axis {axis} out of range for ndim {ndim}"
                )));
            }
            let dim = a.shape.0[ax as usize] as f32;
            reduce_along_axis(a, axis, |slice| slice.iter().sum::<f32>() / dim)
        }
    }
}

fn reduce_max(inputs: &[NodeInput<'_>], axis: Option<i32>) -> Result<Vec<f32>> {
    let a = require_input(inputs, 0)?;
    match axis {
        None => Ok(vec![
            a.data.iter().copied().fold(f32::NEG_INFINITY, f32::max),
        ]),
        Some(axis) => reduce_along_axis(a, axis, |slice| {
            slice.iter().copied().fold(f32::NEG_INFINITY, f32::max)
        }),
    }
}

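/// Reduce one axis by viewing the row-major tensor as [outer, dim, inner]:
/// each (outer, inner) pair selects a strided slice of length `dim`, which is
/// gathered into a scratch buffer and passed to `reducer`.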
fn reduce_along_axis(
    a: &NodeInput<'_>,
    axis: i32,
    reducer: impl Fn(&[f32]) -> f32,
) -> Result<Vec<f32>> {
    let ndim = a.shape.ndim() as i32;
    let ax = if axis < 0 { ndim + axis } else { axis };
    if ax < 0 || ax >= ndim {
        return Err(MlxError::InvalidArgument(format!(
            "axis {axis} out of range for ndim {ndim}"
        )));
    }
    let ax = ax as usize;

    let outer: usize = a.shape.0[..ax].iter().product::<i64>() as usize;
    let dim: usize = a.shape.0[ax] as usize;
    let inner: usize = a.shape.0[ax + 1..].iter().product::<i64>() as usize;

    let mut result = Vec::with_capacity(outer * inner);
    for o in 0..outer {
        for i in 0..inner {
            let mut slice = Vec::with_capacity(dim);
            for d in 0..dim {
                slice.push(a.data[o * dim * inner + d * inner + i]);
            }
            result.push(reducer(&slice));
        }
    }
    Ok(result)
}

fn matmul(inputs: &[NodeInput<'_>]) -> Result<Vec<f32>> {
    let a = require_input(inputs, 0)?;
    let b = require_input(inputs, 1)?;

    if a.shape.ndim() != 2 || b.shape.ndim() != 2 {
        return Err(MlxError::InvalidArgument(
            "matmul requires 2D tensors".into(),
        ));
    }

    let m = a.shape.0[0] as usize;
    let k = a.shape.0[1] as usize;
    let k2 = b.shape.0[0] as usize;
    let n = b.shape.0[1] as usize;

    if k != k2 {
        return Err(MlxError::ShapeMismatch {
            expected: vec![m as i64, k as i64],
            got: vec![k2 as i64, n as i64],
        });
    }

    let mut data = vec![0.0f32; m * n];
    for i in 0..m {
        for j in 0..n {
            let mut sum = 0.0f32;
            for p in 0..k {
                sum += a.data[i * k + p] * b.data[p * n + j];
            }
            data[i * n + j] = sum;
        }
    }
    Ok(data)
}

fn transpose(inputs: &[NodeInput<'_>], axes: Option<&[usize]>) -> Result<Vec<f32>> {
    let a = require_input(inputs, 0)?;
    let ndim = a.shape.ndim();

    let perm: Vec<usize> = match axes {
        Some(ax) => ax.to_vec(),
        None => (0..ndim).rev().collect(),
    };

    if perm.len() != ndim {
        return Err(MlxError::InvalidArgument(
            "transpose axes length must match ndim".into(),
        ));
    }
    // Reject out-of-range axes up front rather than panicking on the
    // old_shape lookup below.
    if perm.iter().any(|&ax| ax >= ndim) {
        return Err(MlxError::InvalidArgument(
            "transpose axes contain an out-of-range axis".into(),
        ));
    }

    let old_shape: Vec<usize> = a.shape.0.iter().map(|&d| d as usize).collect();
    let new_shape: Vec<usize> = perm.iter().map(|&ax| old_shape[ax]).collect();

    let mut old_strides = vec![1usize; ndim];
    for i in (0..ndim.saturating_sub(1)).rev() {
        old_strides[i] = old_strides[i + 1] * old_shape[i + 1];
    }

    let total = a.data.len();
    let mut result = vec![0.0f32; total];

    for (flat, out) in result.iter_mut().enumerate() {
        let mut remaining = flat;
        let mut old_flat = 0;
        for dim_idx in 0..ndim {
            let new_dim_size: usize = new_shape[dim_idx + 1..].iter().product::<usize>().max(1);
            let coord = remaining / new_dim_size;
            remaining %= new_dim_size;
            old_flat += coord * old_strides[perm[dim_idx]];
        }
        *out = a.data[old_flat];
    }

    Ok(result)
}

fn softmax(inputs: &[NodeInput<'_>], axis: i32) -> Result<Vec<f32>> {
    let a = require_input(inputs, 0)?;
    let ndim = a.shape.ndim() as i32;
    let ax = if axis < 0 { ndim + axis } else { axis };
    if ax < 0 || ax >= ndim {
        return Err(MlxError::InvalidArgument(format!(
            "axis {axis} out of range for ndim {ndim}"
        )));
    }
    let ax = ax as usize;

    let outer: usize = a.shape.0[..ax].iter().product::<i64>() as usize;
    let dim: usize = a.shape.0[ax] as usize;
    let inner: usize = a.shape.0[ax + 1..].iter().product::<i64>() as usize;

    let mut data = a.data.to_vec();

    for o in 0..outer {
        for i in 0..inner {
            let mut max_val = f32::NEG_INFINITY;
            for d in 0..dim {
                let idx = o * dim * inner + d * inner + i;
                if data[idx] > max_val {
                    max_val = data[idx];
                }
            }
            let mut sum_exp = 0.0f32;
            for d in 0..dim {
                let idx = o * dim * inner + d * inner + i;
                data[idx] = (data[idx] - max_val).exp();
                sum_exp += data[idx];
            }
            for d in 0..dim {
                let idx = o * dim * inner + d * inner + i;
                data[idx] /= sum_exp;
            }
        }
    }
    Ok(data)
}

fn layer_norm(inputs: &[NodeInput<'_>], eps: f32, _meta: &TensorMeta) -> Result<Vec<f32>> {
    let a = require_input(inputs, 0)?;
    let ndim = a.shape.ndim();
    if ndim == 0 {
        return Ok(a.data.to_vec());
    }
    let last_dim = a.shape.0[ndim - 1] as usize;
    // Guard against empty tensors, mirroring layer_norm_vjp, so the `outer`
    // division below cannot divide by zero.
    if last_dim == 0 || a.data.is_empty() {
        return Ok(vec![0.0f32; a.data.len()]);
    }
    let outer = a.data.len() / last_dim;

    let mut result = vec![0.0f32; a.data.len()];
    for o in 0..outer {
        let start = o * last_dim;
        let end = start + last_dim;
        let slice = &a.data[start..end];

        let mean = slice.iter().sum::<f32>() / last_dim as f32;
        let var = slice.iter().map(|x| (x - mean) * (x - mean)).sum::<f32>() / last_dim as f32;
        let std = (var + eps).sqrt();

        for (i, &x) in slice.iter().enumerate() {
            result[start + i] = (x - mean) / std;
        }
    }
    Ok(result)
}

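/// NumPy-style broadcasting: the input shape is right-aligned against
/// `target_shape`; size-1 (or missing) input dimensions reuse the same data
/// for every output coordinate along that dimension.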
fn broadcast(inputs: &[NodeInput<'_>], target_shape: &crate::Shape) -> Result<Vec<f32>> {
    let a = require_input(inputs, 0)?;
    let in_shape = &a.shape.0;
    let out_shape = &target_shape.0;
    let out_ndim = out_shape.len();
    let in_ndim = in_shape.len();
    // An input with more dimensions than the target cannot broadcast, and the
    // `pad` subtraction below would underflow; report it as an error instead.
    if in_ndim > out_ndim {
        return Err(MlxError::InvalidArgument(
            "broadcast: input has more dimensions than target shape".into(),
        ));
    }
    let pad = out_ndim - in_ndim;
    let total: usize = out_shape.iter().product::<i64>() as usize;

    let mut result = vec![0.0f32; total];
    for (out_flat, out) in result.iter_mut().enumerate() {
        let mut remaining = out_flat;
        let mut in_flat = 0usize;
        let mut in_stride = 1usize;

        for d in (0..out_ndim).rev() {
            let out_dim = out_shape[d] as usize;
            let coord = remaining % out_dim;
            remaining /= out_dim;

            if d >= pad {
                let in_d = d - pad;
                let in_dim = in_shape[in_d] as usize;
                let in_coord = if in_dim == 1 { 0 } else { coord };
                in_flat += in_coord * in_stride;
                in_stride *= in_dim;
            }
        }
        *out = a.data[in_flat];
    }
    Ok(result)
}

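/// VJP of layer normalization (without affine parameters) over the last axis.
/// With x_hat = (x - mean) / std, the per-row gradient is
///   dx = (dy - mean(dy) - x_hat * mean(dy * x_hat)) / std.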
fn layer_norm_vjp(inputs: &[NodeInput<'_>], eps: f32) -> Result<Vec<f32>> {
    let dy = require_input(inputs, 0)?;
    let x = require_input(inputs, 1)?;
    if dy.shape != x.shape || dy.data.len() != x.data.len() {
        return Err(MlxError::ShapeMismatch {
            expected: x.shape.0.clone(),
            got: dy.shape.0.clone(),
        });
    }
    let ndim = x.shape.ndim();
    if ndim == 0 {
        return Ok(dy.data.to_vec());
    }
    let d = x.shape.0[ndim - 1] as usize;
    if d == 0 || x.data.is_empty() {
        return Ok(vec![0.0f32; x.data.len()]);
    }
    let d_f = d as f32;
    let outer = x.data.len() / d;

    let mut result = vec![0.0f32; x.data.len()];
    for o in 0..outer {
        let start = o * d;
        let end = start + d;
        let x_slice = &x.data[start..end];
        let dy_slice = &dy.data[start..end];

        let mean = x_slice.iter().sum::<f32>() / d_f;
        let var = x_slice.iter().map(|v| (v - mean) * (v - mean)).sum::<f32>() / d_f;
        let std = (var + eps).sqrt();
        let inv_std = 1.0 / std;

        let x_hat: Vec<f32> = x_slice.iter().map(|v| (v - mean) * inv_std).collect();

        let mean_dy = dy_slice.iter().sum::<f32>() / d_f;
        let mean_dy_xhat: f32 = dy_slice
            .iter()
            .zip(x_hat.iter())
            .map(|(a, b)| a * b)
            .sum::<f32>()
            / d_f;

        for i in 0..d {
            result[start + i] = inv_std * (dy_slice[i] - mean_dy - x_hat[i] * mean_dy_xhat);
        }
    }
    Ok(result)
}

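/// VJP of RMS normalization over the last axis. With y = x / rms and
/// rms = sqrt(mean(x^2) + eps), the per-row gradient is
///   dx = (dy - y * mean(dy * y)) / rms.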
fn rms_norm_vjp(inputs: &[NodeInput<'_>], eps: f32) -> Result<Vec<f32>> {
    let dy = require_input(inputs, 0)?;
    let x = require_input(inputs, 1)?;
    if dy.shape != x.shape || dy.data.len() != x.data.len() {
        return Err(MlxError::ShapeMismatch {
            expected: x.shape.0.clone(),
            got: dy.shape.0.clone(),
        });
    }
    let ndim = x.shape.ndim();
    if ndim == 0 {
        return Ok(dy.data.to_vec());
    }
    let d = x.shape.0[ndim - 1] as usize;
    if d == 0 || x.data.is_empty() {
        return Ok(vec![0.0f32; x.data.len()]);
    }
    let d_f = d as f32;
    let outer = x.data.len() / d;

    let mut result = vec![0.0f32; x.data.len()];
    for o in 0..outer {
        let start = o * d;
        let end = start + d;
        let x_slice = &x.data[start..end];
        let dy_slice = &dy.data[start..end];

        let rms = (x_slice.iter().map(|v| v * v).sum::<f32>() / d_f + eps).sqrt();
        let inv_rms = 1.0 / rms;

        let y: Vec<f32> = x_slice.iter().map(|v| v * inv_rms).collect();

        let mean_dy_y: f32 = dy_slice
            .iter()
            .zip(y.iter())
            .map(|(a, b)| a * b)
            .sum::<f32>()
            / d_f;

        for i in 0..d {
            result[start + i] = inv_rms * (dy_slice[i] - y[i] * mean_dy_y);
        }
    }
    Ok(result)
}

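/// VJP of softmax along `axis`: dx = s * (dy - dot(dy, s)), where the dot
/// product runs over the softmax axis and s is the forward softmax output.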
fn softmax_vjp(inputs: &[NodeInput<'_>], axis: i32) -> Result<Vec<f32>> {
    let dy = require_input(inputs, 0)?;
    let s = require_input(inputs, 1)?;
    if dy.data.len() != s.data.len() {
        return Err(MlxError::ShapeMismatch {
            expected: s.shape.0.clone(),
            got: dy.shape.0.clone(),
        });
    }
    let ndim = s.shape.ndim() as i32;
    let ax = if axis < 0 { ndim + axis } else { axis };
    if ax < 0 || ax >= ndim {
        return Err(MlxError::InvalidArgument(format!(
            "axis {axis} out of range for ndim {ndim}"
        )));
    }
    let ax = ax as usize;

    let outer: usize = s.shape.0[..ax].iter().product::<i64>() as usize;
    let dim: usize = s.shape.0[ax] as usize;
    let inner: usize = s.shape.0[ax + 1..].iter().product::<i64>() as usize;

    let mut result = vec![0.0f32; s.data.len()];
    for o in 0..outer {
        for i in 0..inner {
            let mut dot = 0.0f32;
            for d in 0..dim {
                let idx = o * dim * inner + d * inner + i;
                dot += dy.data[idx] * s.data[idx];
            }
            for d in 0..dim {
                let idx = o * dim * inner + d * inner + i;
                result[idx] = s.data[idx] * (dy.data[idx] - dot);
            }
        }
    }
    Ok(result)
}

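/// VJP of SiLU: d/dx [x * sigmoid(x)] = sigmoid(x) * (1 + x * (1 - sigmoid(x))).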
fn silu_vjp(inputs: &[NodeInput<'_>]) -> Result<Vec<f32>> {
    let dy = require_input(inputs, 0)?;
    let x = require_input(inputs, 1)?;
    if dy.data.len() != x.data.len() {
        return Err(MlxError::ShapeMismatch {
            expected: x.shape.0.clone(),
            got: dy.shape.0.clone(),
        });
    }
    Ok(dy
        .data
        .iter()
        .zip(x.data.iter())
        .map(|(&dy_i, &x_i)| {
            let sig = sigmoid(x_i);
            dy_i * sig * (1.0 + x_i * (1.0 - sig))
        })
        .collect())
}

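/// VJP of the tanh-approximated GELU. With u = sqrt(2/pi) * (x + 0.044715 x^3),
///   gelu'(x) = 0.5 * (1 + tanh(u)) + 0.5 * x * sech^2(u) * sqrt(2/pi) * (1 + 3 * 0.044715 x^2).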
fn gelu_vjp(inputs: &[NodeInput<'_>]) -> Result<Vec<f32>> {
    let dy = require_input(inputs, 0)?;
    let x = require_input(inputs, 1)?;
    if dy.data.len() != x.data.len() {
        return Err(MlxError::ShapeMismatch {
            expected: x.shape.0.clone(),
            got: dy.shape.0.clone(),
        });
    }
    let a = (2.0f32 / std::f32::consts::PI).sqrt();
    let b = 0.044715f32;
    Ok(dy
        .data
        .iter()
        .zip(x.data.iter())
        .map(|(&dy_i, &x_i)| {
            let inner = a * (x_i + b * x_i * x_i * x_i);
            let tanh_inner = inner.tanh();
            let sech2 = 1.0 - tanh_inner * tanh_inner;
            let dgelu =
                0.5 * (1.0 + tanh_inner) + 0.5 * x_i * sech2 * a * (1.0 + 3.0 * b * x_i * x_i);
            dy_i * dgelu
        })
        .collect())
}

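/// RoPE over a [tokens, head_dim] matrix in the interleaved-pair layout:
/// pair (2i, 2i + 1) of token t is rotated by the angle
/// (pos_offset + t) * theta^(-2i / rotary_dim).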
fn cpu_rope(
    inputs: &[NodeInput<'_>],
    meta: &TensorMeta,
    rotary_dim: usize,
    pos_offset: usize,
    theta: f32,
) -> Result<Vec<f32>> {
    let x = require_input(inputs, 0)?;
    if meta.shape.ndim() != 2 {
        return Err(MlxError::InvalidArgument(
            "Rope input must be 2-D [tokens, head_dim]".into(),
        ));
    }
    let tokens = meta.shape.0[0] as usize;
    let head_dim = meta.shape.0[1] as usize;
    if rotary_dim > head_dim || !rotary_dim.is_multiple_of(2) {
        return Err(MlxError::InvalidArgument(
            "rotary_dim must be even and <= head_dim".into(),
        ));
    }

    let mut out = x.data.to_vec();
    for t in 0..tokens {
        for i in 0..rotary_dim / 2 {
            let inv_freq = theta.powf(-2.0 * i as f32 / rotary_dim as f32);
            let angle = (pos_offset + t) as f32 * inv_freq;
            let (s, c) = angle.sin_cos();

            let base = t * head_dim + i * 2;
            let x0 = x.data[base];
            let x1 = x.data[base + 1];

            out[base] = x0 * c - x1 * s;
            out[base + 1] = x0 * s + x1 * c;
        }
    }
    Ok(out)
}

fn rms_norm(inputs: &[NodeInput<'_>], eps: f32, _meta: &TensorMeta) -> Result<Vec<f32>> {
    let a = require_input(inputs, 0)?;
    let ndim = a.shape.ndim();
    if ndim == 0 {
        return Ok(a.data.to_vec());
    }
    let last_dim = a.shape.0[ndim - 1] as usize;
    // Guard against empty tensors, mirroring rms_norm_vjp, so the `outer`
    // division below cannot divide by zero.
    if last_dim == 0 || a.data.is_empty() {
        return Ok(vec![0.0f32; a.data.len()]);
    }
    let outer = a.data.len() / last_dim;

    let mut result = vec![0.0f32; a.data.len()];
    for o in 0..outer {
        let start = o * last_dim;
        let end = start + last_dim;
        let slice = &a.data[start..end];

        let rms = (slice.iter().map(|x| x * x).sum::<f32>() / last_dim as f32 + eps).sqrt();

        for (i, &x) in slice.iter().enumerate() {
            result[start + i] = x / rms;
        }
    }
    Ok(result)
}

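/// Scale the 2D logits, optionally apply a causal mask (a large negative
/// value above the diagonal), then take a numerically stable row-wise softmax.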
fn scaled_masked_softmax(inputs: &[NodeInput<'_>], scale: f32, causal: bool) -> Result<Vec<f32>> {
    let a = require_input(inputs, 0)?;
    if a.shape.ndim() != 2 {
        return Err(MlxError::InvalidArgument(
            "ScaledMaskedSoftmax requires 2D input [Tq, Tk]".into(),
        ));
    }
    let tq = a.shape.0[0] as usize;
    let tk = a.shape.0[1] as usize;

    let mut data = vec![0.0f32; tq * tk];

    for i in 0..tq {
        for j in 0..tk {
            let idx = i * tk + j;
            let mut val = a.data[idx] * scale;
            if causal && j > i {
                val = -1e9;
            }
            data[idx] = val;
        }

        let row_start = i * tk;
        let mut max_val = f32::NEG_INFINITY;
        for j in 0..tk {
            if data[row_start + j] > max_val {
                max_val = data[row_start + j];
            }
        }
        let mut sum_exp = 0.0f32;
        for j in 0..tk {
            data[row_start + j] = (data[row_start + j] - max_val).exp();
            sum_exp += data[row_start + j];
        }
        for j in 0..tk {
            data[row_start + j] /= sum_exp;
        }
    }
    Ok(data)
}

fn cpu_matmul_raw(a: &[f32], b: &[f32], m: usize, k: usize, n: usize) -> Vec<f32> {
    let mut out = vec![0.0f32; m * n];
    for i in 0..m {
        for j in 0..n {
            let mut sum = 0.0f32;
            for p in 0..k {
                sum += a[i * k + p] * b[p * n + j];
            }
            out[i * n + j] = sum;
        }
    }
    out
}

fn cpu_transpose_2d(data: &[f32], rows: usize, cols: usize) -> Vec<f32> {
    let mut out = vec![0.0f32; rows * cols];
    for r in 0..rows {
        for c in 0..cols {
            out[c * rows + r] = data[r * cols + c];
        }
    }
    out
}

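/// Single-head attention on 2D inputs: softmax(Q K^T * scale) V, with an
/// optional causal mask, composed from the raw matmul/transpose helpers above.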
fn cpu_attention(inputs: &[NodeInput<'_>], scale: f32, causal: bool) -> Result<Vec<f32>> {
    if inputs.len() != 3 {
        return Err(MlxError::InvalidArgument(
            "Attention requires exactly 3 inputs [Q, K, V]".into(),
        ));
    }
    let q = require_input(inputs, 0)?;
    let k = require_input(inputs, 1)?;
    let v = require_input(inputs, 2)?;

    if q.shape.ndim() != 2 || k.shape.ndim() != 2 || v.shape.ndim() != 2 {
        return Err(MlxError::InvalidArgument(
            "Attention inputs must be 2D".into(),
        ));
    }

    let tq = q.shape.0[0] as usize;
    let dh = q.shape.0[1] as usize;
    let tk = k.shape.0[0] as usize;
    let dh_k = k.shape.0[1] as usize;
    let tk_v = v.shape.0[0] as usize;
    let dh_v = v.shape.0[1] as usize;

    if dh != dh_k {
        return Err(MlxError::ShapeMismatch {
            expected: vec![tq as i64, dh as i64],
            got: vec![tk as i64, dh_k as i64],
        });
    }
    if tk != tk_v || dh != dh_v {
        return Err(MlxError::ShapeMismatch {
            expected: vec![tk as i64, dh as i64],
            got: vec![tk_v as i64, dh_v as i64],
        });
    }

    let kt = cpu_transpose_2d(k.data, tk, dh);

    let scores = cpu_matmul_raw(q.data, &kt, tq, dh, tk);

    let mut probs = vec![0.0f32; tq * tk];
    for i in 0..tq {
        for j in 0..tk {
            let idx = i * tk + j;
            let mut val = scores[idx] * scale;
            if causal && j > i {
                val = -1e9;
            }
            probs[idx] = val;
        }
        let row_start = i * tk;
        let mut max_val = f32::NEG_INFINITY;
        for j in 0..tk {
            if probs[row_start + j] > max_val {
                max_val = probs[row_start + j];
            }
        }
        let mut sum_exp = 0.0f32;
        for j in 0..tk {
            probs[row_start + j] = (probs[row_start + j] - max_val).exp();
            sum_exp += probs[row_start + j];
        }
        for j in 0..tk {
            probs[row_start + j] /= sum_exp;
        }
    }

    let y = cpu_matmul_raw(&probs, v.data, tq, tk, dh);

    Ok(y)
}

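/// Row gather for embeddings. Indices arrive as f32 (the backend carries all
/// data as f32) and must be non-negative integers below `vocab_size`.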
fn embedding(inputs: &[NodeInput<'_>]) -> Result<Vec<f32>> {
    let weight = require_input(inputs, 0)?;
    let indices = require_input(inputs, 1)?;

    if weight.shape.ndim() != 2 {
        return Err(MlxError::InvalidArgument(
            "Embedding weight must be 2D [vocab_size, embed_dim]".into(),
        ));
    }
    if indices.shape.ndim() != 1 {
        return Err(MlxError::InvalidArgument(
            "Embedding indices must be 1D [seq_len]".into(),
        ));
    }
    let vocab_size = weight.shape.0[0] as usize;
    let embed_dim = weight.shape.0[1] as usize;
    let seq_len = indices.data.len();

    let mut result = Vec::with_capacity(seq_len * embed_dim);
    for &idx_f in indices.data {
        if idx_f < 0.0 || idx_f != idx_f.trunc() {
            return Err(MlxError::InvalidArgument(format!(
                "Embedding index must be a non-negative integer, got {idx_f}"
            )));
        }
        let idx = idx_f as usize;
        if idx >= vocab_size {
            return Err(MlxError::InvalidArgument(format!(
                "Embedding index {idx} out of range for vocab_size {vocab_size}"
            )));
        }
        let start = idx * embed_dim;
        result.extend_from_slice(&weight.data[start..start + embed_dim]);
    }
    Ok(result)
}

fn narrow(inputs: &[NodeInput<'_>], axis: i32, start: i64, length: i64) -> Result<Vec<f32>> {
    let a = require_input(inputs, 0)?;
    let ndim = a.shape.ndim() as i32;
    let ax = if axis < 0 { ndim + axis } else { axis };
    if ax < 0 || ax >= ndim {
        return Err(MlxError::InvalidArgument(format!(
            "narrow: axis {axis} out of range for ndim {ndim}"
        )));
    }
    let ax = ax as usize;
    let dim_size = a.shape.0[ax] as i64;
    if start < 0 || start + length > dim_size {
        return Err(MlxError::InvalidArgument(format!(
            "narrow: start {start} + length {length} exceeds dim size {dim_size}"
        )));
    }

    let outer: usize = a.shape.0[..ax].iter().product::<i64>() as usize;
    let dim: usize = a.shape.0[ax] as usize;
    let inner: usize = a.shape.0[ax + 1..].iter().product::<i64>() as usize;
    let start = start as usize;
    let length = length as usize;

    let mut result = Vec::with_capacity(outer * length * inner);
    for o in 0..outer {
        for d in start..start + length {
            let base = (o * dim + d) * inner;
            result.extend_from_slice(&a.data[base..base + inner]);
        }
    }
    Ok(result)
}

fn concatenate(inputs: &[NodeInput<'_>], axis: i32) -> Result<Vec<f32>> {
    if inputs.is_empty() {
        return Err(MlxError::InvalidArgument(
            "Concatenate requires at least one input".into(),
        ));
    }
    let first = &inputs[0];
    let ndim = first.shape.ndim() as i32;
    let ax = if axis < 0 { ndim + axis } else { axis };
    if ax < 0 || ax >= ndim {
        return Err(MlxError::InvalidArgument(format!(
            "concatenate: axis {axis} out of range for ndim {ndim}"
        )));
    }
    let ax = ax as usize;

    for inp in &inputs[1..] {
        if inp.shape.ndim() != first.shape.ndim() {
            return Err(MlxError::InvalidArgument(
                "Concatenate: all inputs must have same ndim".into(),
            ));
        }
        for (d, (&a, &b)) in first.shape.0.iter().zip(inp.shape.0.iter()).enumerate() {
            if d != ax && a != b {
                return Err(MlxError::ShapeMismatch {
                    expected: first.shape.0.clone(),
                    got: inp.shape.0.clone(),
                });
            }
        }
    }

    let outer: usize = first.shape.0[..ax].iter().product::<i64>() as usize;
    let inner: usize = first.shape.0[ax + 1..].iter().product::<i64>() as usize;

    let total_dim: usize = inputs.iter().map(|i| i.shape.0[ax] as usize).sum();
    let mut result = Vec::with_capacity(outer * total_dim * inner);

    for o in 0..outer {
        for inp in inputs {
            let dim = inp.shape.0[ax] as usize;
            let base = o * dim * inner;
            result.extend_from_slice(&inp.data[base..base + dim * inner]);
        }
    }
    Ok(result)
}

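/// RoPE variant with a `traditional` flag (as in MLX's rope op, assuming the
/// same convention): `traditional` rotates adjacent pairs (2d, 2d + 1), while
/// the default pairs element d with d + half_dim. The position angle advances
/// with the flat row index plus `offset`.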
fn rope(inputs: &[NodeInput<'_>], base: f32, offset: usize, traditional: bool) -> Result<Vec<f32>> {
    let a = require_input(inputs, 0)?;
    let ndim = a.shape.ndim();
    if ndim < 1 {
        return Err(MlxError::InvalidArgument(
            "RoPE requires at least 1 dimension".into(),
        ));
    }

    let head_dim = a.shape.0[ndim - 1] as usize;
    if !head_dim.is_multiple_of(2) {
        return Err(MlxError::InvalidArgument(format!(
            "RoPE head_dim must be even, got {head_dim}"
        )));
    }
    let half_dim = head_dim / 2;

    let total = a.data.len();
    let num_heads_total = total / head_dim;

    let mut result = vec![0.0f32; total];

    for i in 0..num_heads_total {
        let pos = (offset + i) as f32;

        for d in 0..half_dim {
            let theta = pos * base.powf(-(2.0 * d as f32 / head_dim as f32));
            let cos_theta = theta.cos();
            let sin_theta = theta.sin();

            if traditional {
                let idx0 = i * head_dim + 2 * d;
                let idx1 = idx0 + 1;

                let x0 = a.data[idx0];
                let x1 = a.data[idx1];

                result[idx0] = x0 * cos_theta - x1 * sin_theta;
                result[idx1] = x0 * sin_theta + x1 * cos_theta;
            } else {
                let idx0 = i * head_dim + d;
                let idx1 = i * head_dim + d + half_dim;

                let x0 = a.data[idx0];
                let x1 = a.data[idx1];

                result[idx0] = x0 * cos_theta - x1 * sin_theta;
                result[idx1] = x0 * sin_theta + x1 * cos_theta;
            }
        }
    }
    Ok(result)
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::graph::TensorMeta;
    use crate::types::Shape;

    fn meta(shape: Vec<i64>) -> TensorMeta {
        TensorMeta {
            shape: Shape::new(shape),
            dtype: crate::DType::F32,
        }
    }

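    // `NodeInput` borrows its shape, so these tests leak a boxed `Shape` to
    // obtain a sufficiently long-lived reference; the leak is intentional and
    // confined to the test binary.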
    fn input(data: &[f32], shape: Vec<i64>) -> NodeInput<'_> {
        NodeInput {
            data,
            shape: Box::leak(Box::new(Shape::new(shape))),
            dtype: crate::DType::F32,
        }
    }

    #[test]
    fn test_add() {
        let backend = CpuRefBackend;
        let a_data = [1.0, 2.0, 3.0];
        let b_data = [4.0, 5.0, 6.0];
        let result = backend
            .eval_node(
                &OpKind::Add,
                &[input(&a_data, vec![3]), input(&b_data, vec![3])],
                &meta(vec![3]),
            )
            .unwrap();
        assert_eq!(result, vec![5.0, 7.0, 9.0]);
    }

    #[test]
    fn test_matmul() {
        let backend = CpuRefBackend;
        let a_data = [1.0, 2.0, 3.0, 4.0];
        let b_data = [5.0, 6.0, 7.0, 8.0];
        let result = backend
            .eval_node(
                &OpKind::MatMul,
                &[input(&a_data, vec![2, 2]), input(&b_data, vec![2, 2])],
                &meta(vec![2, 2]),
            )
            .unwrap();
        assert_eq!(result, vec![19.0, 22.0, 43.0, 50.0]);
    }

    #[test]
    fn test_softmax() {
        let backend = CpuRefBackend;
        let data = [1.0, 2.0, 3.0];
        let result = backend
            .eval_node(
                &OpKind::Softmax { axis: 0 },
                &[input(&data, vec![3])],
                &meta(vec![3]),
            )
            .unwrap();
        let sum: f32 = result.iter().sum();
        assert!((sum - 1.0).abs() < 1e-6);
        assert!(result[0] < result[1]);
        assert!(result[1] < result[2]);
    }

    #[test]
    fn test_neg() {
        let backend = CpuRefBackend;
        let data = [1.0, -2.0, 3.0];
        let result = backend
            .eval_node(&OpKind::Neg, &[input(&data, vec![3])], &meta(vec![3]))
            .unwrap();
        assert_eq!(result, vec![-1.0, 2.0, -3.0]);
    }

    #[test]
    fn test_layer_norm() {
        let backend = CpuRefBackend;
        let data = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let result = backend
            .eval_node(
                &OpKind::LayerNorm { eps: 1e-5 },
                &[input(&data, vec![2, 3])],
                &meta(vec![2, 3]),
            )
            .unwrap();
        let row1_mean: f32 = result[0..3].iter().sum::<f32>() / 3.0;
        assert!(row1_mean.abs() < 1e-5);
    }

    #[test]
    fn test_reduce_sum_axis() {
        let backend = CpuRefBackend;
        let data = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let result = backend
            .eval_node(
                &OpKind::Sum { axis: Some(0) },
                &[input(&data, vec![2, 3])],
                &meta(vec![3]),
            )
            .unwrap();
        assert_eq!(result, vec![5.0, 7.0, 9.0]);
    }

    #[test]
    fn test_reduce_sum_all() {
        let backend = CpuRefBackend;
        let data = [1.0, 2.0, 3.0];
        let result = backend
            .eval_node(
                &OpKind::Sum { axis: None },
                &[input(&data, vec![3])],
                &meta(vec![]),
            )
            .unwrap();
        assert_eq!(result, vec![6.0]);
    }

    #[test]
    fn test_silu() {
        let backend = CpuRefBackend;
        let data = [0.0, 1.0, -1.0];
        let result = backend
            .eval_node(&OpKind::Silu, &[input(&data, vec![3])], &meta(vec![3]))
            .unwrap();
        assert!((result[1] - 0.7311).abs() < 1e-3);
        assert!((result[2] - (-0.2689)).abs() < 1e-3);
    }

    #[test]
    fn test_rope_offsets() {
        let backend = CpuRefBackend;
        let theta = 10_000.0;
        let pos_offset = 100usize;
        let rotary_dim = 4;
        let data = [1.0, 0.0, 0.0, 1.0];
        let result = backend
            .eval_node(
                &OpKind::Rope {
                    rotary_dim,
                    pos_offset,
                    theta,
                },
                &[input(&data, vec![1, 4])],
                &meta(vec![1, 4]),
            )
            .unwrap();

        let cos100 = 100.0f32.cos();
        let sin100 = 100.0f32.sin();
        let cos1 = 1.0f32.cos();
        let sin1 = 1.0f32.sin();

        assert!((result[0] - cos100).abs() < 1e-5);
        assert!((result[1] - sin100).abs() < 1e-5);
        assert!((result[2] - (-sin1)).abs() < 1e-5);
        assert!((result[3] - cos1).abs() < 1e-5);
    }

    #[test]
    fn test_rope_large() {
        let backend = CpuRefBackend;
        let shape = vec![128, 128];
        let numel = 128 * 128;
        let data = vec![1.0; numel];
        let result = backend.eval_node(
            &OpKind::Rope {
                rotary_dim: 128,
                pos_offset: 0,
                theta: 10000.0,
            },
            &[input(&data, shape.clone())],
            &meta(shape.clone()),
        );
        assert!(result.is_ok());
        assert_eq!(result.unwrap().len(), numel);
    }
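
    // Hand-checked examples for the data-movement ops, added as illustrative
    // coverage; expected values follow from the row-major [outer, dim, inner]
    // layout described above.
    #[test]
    fn test_transpose_2d() {
        let backend = CpuRefBackend;
        let data = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        // Default transpose reverses the axes: [2, 3] -> [3, 2].
        let result = backend
            .eval_node(
                &OpKind::Transpose { axes: None },
                &[input(&data, vec![2, 3])],
                &meta(vec![3, 2]),
            )
            .unwrap();
        assert_eq!(result, vec![1.0, 4.0, 2.0, 5.0, 3.0, 6.0]);
    }

    #[test]
    fn test_narrow_axis1() {
        let backend = CpuRefBackend;
        let data = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        // Keep columns 1..3 of each row of a [2, 3] matrix.
        let result = backend
            .eval_node(
                &OpKind::Narrow {
                    axis: 1,
                    start: 1,
                    length: 2,
                },
                &[input(&data, vec![2, 3])],
                &meta(vec![2, 2]),
            )
            .unwrap();
        assert_eq!(result, vec![2.0, 3.0, 5.0, 6.0]);
    }

    #[test]
    fn test_concatenate_axis0() {
        let backend = CpuRefBackend;
        let a = [1.0, 2.0];
        let b = [3.0, 4.0];
        // Stacking two [1, 2] rows along axis 0 yields a [2, 2] matrix.
        let result = backend
            .eval_node(
                &OpKind::Concatenate { axis: 0 },
                &[input(&a, vec![1, 2]), input(&b, vec![1, 2])],
                &meta(vec![2, 2]),
            )
            .unwrap();
        assert_eq!(result, vec![1.0, 2.0, 3.0, 4.0]);
    }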
}