// mlx_core/tensor.rs
1//! Tensor type — a lazy handle to a node in the computation graph.
2//!
3//! Operations on tensors record nodes in the graph. Actual computation is
4//! deferred until `eval()` (or `to_vec_f32()`) is called, at which point the
5//! stream topologically sorts the subgraph and dispatches to the backend.
6
7use std::sync::Arc;
8
9use smallvec::SmallVec;
10
11use crate::backend::{Stream, default_stream};
12use crate::graph::{OpKind, TensorMeta};
13use crate::{DType, MlxError, NodeId, Result, Shape};
14
/// Compute device.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum Device {
    /// Host CPU backend.
    Cpu,
    /// GPU backend. Currently unused: `default_device` always returns `Cpu`
    /// until the Metal backend is enabled.
    Gpu,
}
21
impl Device {
    /// Return the default device for this platform.
    ///
    /// Currently always `Device::Cpu`: the macOS arm is a placeholder that
    /// will switch to the Metal/GPU backend once it is available.
    pub fn default_device() -> Self {
        #[cfg(target_os = "macos")]
        {
            // TODO: Enable Metal backend when available.
            Device::Cpu
        }
        #[cfg(not(target_os = "macos"))]
        {
            Device::Cpu
        }
    }
}
36
/// A tensor handle.
///
/// In the lazy graph model a `Tensor` is a lightweight reference to a node in
/// the computation graph. Operations build up the graph; actual computation
/// happens when `eval()` is called (or implicitly via `to_vec_f32()`).
#[derive(Clone)]
pub struct Tensor {
    // Graph node this handle refers to.
    node_id: NodeId,
    // Logical shape of the (possibly not-yet-computed) result.
    shape: Shape,
    // Logical element type; data may still be stored as f32 internally
    // (see `from_data_with_dtype`).
    dtype: DType,
    // Device this tensor is associated with.
    device: Device,
    // Stream owning the graph that `node_id` lives in.
    stream: Arc<Stream>,
}
50
51impl Tensor {
52    // ── Constructors ────────────────────────────────────────────────────
53
54    /// Create a tensor filled with zeros.
55    pub fn zeros(shape: &Shape, dtype: DType, device: &Device) -> Result<Self> {
56        let n = shape.numel() as usize;
57        Self::from_data(vec![0.0; n], shape, dtype, device)
58    }
59
60    /// Create a tensor filled with ones.
61    pub fn ones(shape: &Shape, dtype: DType, device: &Device) -> Result<Self> {
62        let n = shape.numel() as usize;
63        Self::from_data(vec![1.0; n], shape, dtype, device)
64    }
65
66    /// Create a tensor from f32 data.
67    pub fn from_f32(data: &[f32], shape: &Shape, device: &Device) -> Result<Self> {
68        let expected = shape.numel() as usize;
69        if data.len() != expected {
70            return Err(MlxError::InvalidArgument(format!(
71                "data length {} does not match shape {} (expected {})",
72                data.len(),
73                shape,
74                expected,
75            )));
76        }
77        Self::from_data(data.to_vec(), shape, DType::F32, device)
78    }
79
80    /// Create a tensor from f32 data on a specific stream.
81    pub fn from_f32_on_stream(data: &[f32], shape: &Shape, stream: &Arc<Stream>) -> Result<Self> {
82        let expected = shape.numel() as usize;
83        if data.len() != expected {
84            return Err(MlxError::InvalidArgument(format!(
85                "data length {} does not match shape {} (expected {})",
86                data.len(),
87                shape,
88                expected,
89            )));
90        }
91        let meta = TensorMeta {
92            shape: shape.clone(),
93            dtype: DType::F32,
94        };
95        let node_id = stream.add_constant(data.to_vec(), meta);
96        Ok(Self {
97            node_id,
98            shape: shape.clone(),
99            dtype: DType::F32,
100            device: Device::Gpu, // Assume GPU if custom stream for now, or detect
101            stream: Arc::clone(stream),
102        })
103    }
104
105    /// Create a tensor from f32 data with a specified dtype.
106    ///
107    /// The data is stored as f32 internally; `dtype` records the logical type
108    /// (e.g. F16 weights that were converted to f32 on load).
109    pub fn from_data_with_dtype(
110        data: Vec<f32>,
111        shape: &Shape,
112        dtype: DType,
113        device: &Device,
114    ) -> Result<Self> {
115        let expected = shape.numel() as usize;
116        if data.len() != expected {
117            return Err(MlxError::InvalidArgument(format!(
118                "data length {} does not match shape {} (expected {})",
119                data.len(),
120                shape,
121                expected,
122            )));
123        }
124        Self::from_data(data, shape, dtype, device)
125    }
126
127    fn from_data(data: Vec<f32>, shape: &Shape, dtype: DType, device: &Device) -> Result<Self> {
128        let stream = default_stream();
129        let meta = TensorMeta {
130            shape: shape.clone(),
131            dtype,
132        };
133        let node_id = stream.add_constant(data, meta);
134        Ok(Self {
135            node_id,
136            shape: shape.clone(),
137            dtype,
138            device: device.clone(),
139            stream,
140        })
141    }
142
    /// Record a new op node on this tensor's stream and wrap the resulting
    /// node in a fresh `Tensor` handle.
    ///
    /// `inputs` are the graph nodes the op consumes; `shape`/`dtype` describe
    /// the op's output. The result inherits this tensor's device and stream.
    /// No computation happens here — the node is evaluated lazily via `eval`.
    fn lazy_op(
        &self,
        op: OpKind,
        inputs: SmallVec<[NodeId; 2]>,
        shape: Shape,
        dtype: DType,
    ) -> Self {
        let meta = TensorMeta {
            shape: shape.clone(),
            dtype,
        };
        let node_id = self.stream.add_op(op, inputs, meta);
        Tensor {
            node_id,
            shape,
            dtype,
            device: self.device.clone(),
            stream: Arc::clone(&self.stream),
        }
    }
163
164    // ── Elementwise ops ─────────────────────────────────────────────────
165
166    /// Element-wise addition.
167    pub fn add(&self, rhs: &Tensor) -> Result<Tensor> {
168        if self.shape != rhs.shape {
169            return Err(MlxError::ShapeMismatch {
170                expected: self.shape.0.clone(),
171                got: rhs.shape.0.clone(),
172            });
173        }
174        Ok(self.lazy_op(
175            OpKind::Add,
176            SmallVec::from_slice(&[self.node_id, rhs.node_id]),
177            self.shape.clone(),
178            self.dtype,
179        ))
180    }
181
182    /// Element-wise subtraction.
183    pub fn sub(&self, rhs: &Tensor) -> Result<Tensor> {
184        if self.shape != rhs.shape {
185            return Err(MlxError::ShapeMismatch {
186                expected: self.shape.0.clone(),
187                got: rhs.shape.0.clone(),
188            });
189        }
190        Ok(self.lazy_op(
191            OpKind::Sub,
192            SmallVec::from_slice(&[self.node_id, rhs.node_id]),
193            self.shape.clone(),
194            self.dtype,
195        ))
196    }
197
198    /// Element-wise multiplication.
199    pub fn mul(&self, rhs: &Tensor) -> Result<Tensor> {
200        if self.shape != rhs.shape {
201            return Err(MlxError::ShapeMismatch {
202                expected: self.shape.0.clone(),
203                got: rhs.shape.0.clone(),
204            });
205        }
206        Ok(self.lazy_op(
207            OpKind::Mul,
208            SmallVec::from_slice(&[self.node_id, rhs.node_id]),
209            self.shape.clone(),
210            self.dtype,
211        ))
212    }
213
214    /// Element-wise division.
215    pub fn div(&self, rhs: &Tensor) -> Result<Tensor> {
216        if self.shape != rhs.shape {
217            return Err(MlxError::ShapeMismatch {
218                expected: self.shape.0.clone(),
219                got: rhs.shape.0.clone(),
220            });
221        }
222        Ok(self.lazy_op(
223            OpKind::Div,
224            SmallVec::from_slice(&[self.node_id, rhs.node_id]),
225            self.shape.clone(),
226            self.dtype,
227        ))
228    }
229
230    /// Element-wise negation.
231    pub fn neg(&self) -> Tensor {
232        self.lazy_op(
233            OpKind::Neg,
234            SmallVec::from_slice(&[self.node_id]),
235            self.shape.clone(),
236            self.dtype,
237        )
238    }
239
240    // ── Reductions ──────────────────────────────────────────────────────
241
242    /// Sum along an axis.
243    pub fn sum_axis(&self, axis: i32) -> Result<Tensor> {
244        let ndim = self.shape.ndim() as i32;
245        let ax = if axis < 0 { ndim + axis } else { axis };
246        if ax < 0 || ax >= ndim {
247            return Err(MlxError::InvalidArgument(format!(
248                "axis {axis} out of range for ndim {ndim}"
249            )));
250        }
251        let mut new_dims: Vec<i64> = self.shape.0.clone();
252        new_dims.remove(ax as usize);
253        Ok(self.lazy_op(
254            OpKind::Sum { axis: Some(ax) },
255            SmallVec::from_slice(&[self.node_id]),
256            Shape::new(new_dims),
257            self.dtype,
258        ))
259    }
260
261    /// Sum all elements to a scalar.
262    pub fn sum_all(&self) -> Result<Tensor> {
263        Ok(self.lazy_op(
264            OpKind::Sum { axis: None },
265            SmallVec::from_slice(&[self.node_id]),
266            Shape::scalar(),
267            self.dtype,
268        ))
269    }
270
271    // ── Linear algebra ──────────────────────────────────────────────────
272
273    /// Matrix multiplication (2D only for now).
274    pub fn matmul(&self, rhs: &Tensor) -> Result<Tensor> {
275        if self.shape.ndim() != 2 || rhs.shape.ndim() != 2 {
276            return Err(MlxError::InvalidArgument(
277                "matmul requires 2D tensors".to_string(),
278            ));
279        }
280        let m = self.shape.0[0];
281        let k = self.shape.0[1];
282        let k2 = rhs.shape.0[0];
283        let n = rhs.shape.0[1];
284        if k != k2 {
285            return Err(MlxError::ShapeMismatch {
286                expected: self.shape.0.clone(),
287                got: rhs.shape.0.clone(),
288            });
289        }
290        Ok(self.lazy_op(
291            OpKind::MatMul,
292            SmallVec::from_slice(&[self.node_id, rhs.node_id]),
293            Shape::new(vec![m, n]),
294            self.dtype,
295        ))
296    }
297
298    // ── Shape manipulation ──────────────────────────────────────────────
299
300    /// Reshape the tensor.
301    pub fn reshape(&self, new_shape: &Shape) -> Result<Tensor> {
302        if self.shape.numel() != new_shape.numel() {
303            return Err(MlxError::ShapeMismatch {
304                expected: self.shape.0.clone(),
305                got: new_shape.0.clone(),
306            });
307        }
308        Ok(self.lazy_op(
309            OpKind::Reshape {
310                new_shape: new_shape.clone(),
311            },
312            SmallVec::from_slice(&[self.node_id]),
313            new_shape.clone(),
314            self.dtype,
315        ))
316    }
317
318    /// Transpose (reverses axes by default, or use specified permutation).
319    pub fn transpose(&self, axes: Option<&[usize]>) -> Result<Tensor> {
320        let ndim = self.shape.ndim();
321        let perm: Vec<usize> = match axes {
322            Some(ax) => {
323                if ax.len() != ndim {
324                    return Err(MlxError::InvalidArgument(
325                        "transpose axes length must match ndim".into(),
326                    ));
327                }
328                let mut seen = vec![false; ndim];
329                for &axis in ax {
330                    if axis >= ndim {
331                        return Err(MlxError::InvalidArgument(format!(
332                            "transpose axis {axis} out of range for ndim {ndim}"
333                        )));
334                    }
335                    if seen[axis] {
336                        return Err(MlxError::InvalidArgument(format!(
337                            "duplicate transpose axis {axis} in axes; expected a permutation of 0..{ndim}"
338                        )));
339                    }
340                    seen[axis] = true;
341                }
342                ax.to_vec()
343            }
344            None => (0..ndim).rev().collect(),
345        };
346        let new_dims: Vec<i64> = perm.iter().map(|&ax| self.shape.0[ax]).collect();
347        Ok(self.lazy_op(
348            OpKind::Transpose { axes: Some(perm) },
349            SmallVec::from_slice(&[self.node_id]),
350            Shape::new(new_dims),
351            self.dtype,
352        ))
353    }
354
355    // ── Activations ─────────────────────────────────────────────────────
356
357    /// Softmax along an axis.
358    pub fn softmax(&self, axis: i32) -> Result<Tensor> {
359        let ndim = self.shape.ndim() as i32;
360        let ax = if axis < 0 { ndim + axis } else { axis };
361        if ax < 0 || ax >= ndim {
362            return Err(MlxError::InvalidArgument(format!(
363                "axis {axis} out of range for ndim {ndim}"
364            )));
365        }
366        Ok(self.lazy_op(
367            OpKind::Softmax { axis },
368            SmallVec::from_slice(&[self.node_id]),
369            self.shape.clone(),
370            self.dtype,
371        ))
372    }
373
374    /// SiLU (Sigmoid Linear Unit) activation.
375    pub fn silu(&self) -> Tensor {
376        self.lazy_op(
377            OpKind::Silu,
378            SmallVec::from_slice(&[self.node_id]),
379            self.shape.clone(),
380            self.dtype,
381        )
382    }
383
384    /// GELU (Gaussian Error Linear Unit) activation.
385    pub fn gelu(&self) -> Tensor {
386        self.lazy_op(
387            OpKind::Gelu,
388            SmallVec::from_slice(&[self.node_id]),
389            self.shape.clone(),
390            self.dtype,
391        )
392    }
393
394    // ── Normalization ───────────────────────────────────────────────────
395
396    /// Layer normalization over the last dimension.
397    pub fn layer_norm(&self, eps: f32) -> Tensor {
398        self.lazy_op(
399            OpKind::LayerNorm { eps },
400            SmallVec::from_slice(&[self.node_id]),
401            self.shape.clone(),
402            self.dtype,
403        )
404    }
405
406    /// RMS normalization over the last dimension.
407    pub fn rms_norm(&self, eps: f32) -> Tensor {
408        self.lazy_op(
409            OpKind::RmsNorm { eps },
410            SmallVec::from_slice(&[self.node_id]),
411            self.shape.clone(),
412            self.dtype,
413        )
414    }
415
416    /// Apply Rotary Positional Embeddings.
417    pub fn rope(&self, rotary_dim: usize, pos_offset: usize, theta: f32) -> Tensor {
418        self.lazy_op(
419            OpKind::Rope {
420                rotary_dim,
421                pos_offset,
422                theta,
423            },
424            SmallVec::from_slice(&[self.node_id]),
425            self.shape.clone(),
426            self.dtype,
427        )
428    }
429
430    // ── Backward (VJP) helpers ─────────────────────────────────────────
431
432    /// LayerNorm VJP: compute grad_input given grad_output and original input.
433    pub fn layer_norm_vjp(&self, input: &Tensor, eps: f32) -> Result<Tensor> {
434        if self.shape != input.shape {
435            return Err(MlxError::ShapeMismatch {
436                expected: input.shape.0.clone(),
437                got: self.shape.0.clone(),
438            });
439        }
440        if self.dtype != input.dtype {
441            return Err(MlxError::InvalidArgument(
442                "layer_norm_vjp requires matching dtypes".into(),
443            ));
444        }
445        if self.device != input.device {
446            return Err(MlxError::InvalidArgument(
447                "layer_norm_vjp requires matching devices".into(),
448            ));
449        }
450        Ok(self.lazy_op(
451            OpKind::LayerNormVjp { eps },
452            SmallVec::from_slice(&[self.node_id, input.node_id]),
453            input.shape.clone(),
454            input.dtype,
455        ))
456    }
457
458    /// RmsNorm VJP: compute grad_input given grad_output and original input.
459    pub fn rms_norm_vjp(&self, input: &Tensor, eps: f32) -> Result<Tensor> {
460        if self.shape != input.shape {
461            return Err(MlxError::ShapeMismatch {
462                expected: input.shape.0.clone(),
463                got: self.shape.0.clone(),
464            });
465        }
466        if self.dtype != input.dtype {
467            return Err(MlxError::InvalidArgument(
468                "rms_norm_vjp requires matching dtypes".into(),
469            ));
470        }
471        if self.device != input.device {
472            return Err(MlxError::InvalidArgument(
473                "rms_norm_vjp requires matching devices".into(),
474            ));
475        }
476        Ok(self.lazy_op(
477            OpKind::RmsNormVjp { eps },
478            SmallVec::from_slice(&[self.node_id, input.node_id]),
479            input.shape.clone(),
480            input.dtype,
481        ))
482    }
483
484    /// Softmax VJP: compute grad_input given grad_output (self) and softmax output.
485    pub fn softmax_vjp(&self, softmax_output: &Tensor, axis: i32) -> Result<Tensor> {
486        if self.shape != softmax_output.shape {
487            return Err(MlxError::ShapeMismatch {
488                expected: softmax_output.shape.0.clone(),
489                got: self.shape.0.clone(),
490            });
491        }
492        if self.dtype != softmax_output.dtype {
493            return Err(MlxError::InvalidArgument(
494                "softmax_vjp requires matching dtypes".into(),
495            ));
496        }
497        if self.device != softmax_output.device {
498            return Err(MlxError::InvalidArgument(
499                "softmax_vjp requires matching devices".into(),
500            ));
501        }
502        Ok(self.lazy_op(
503            OpKind::SoftmaxVjp { axis },
504            SmallVec::from_slice(&[self.node_id, softmax_output.node_id]),
505            softmax_output.shape.clone(),
506            softmax_output.dtype,
507        ))
508    }
509
510    /// SiLU VJP: compute grad_input given grad_output (self) and original input.
511    pub fn silu_vjp(&self, input: &Tensor) -> Result<Tensor> {
512        if self.shape != input.shape {
513            return Err(MlxError::ShapeMismatch {
514                expected: input.shape.0.clone(),
515                got: self.shape.0.clone(),
516            });
517        }
518        if self.dtype != input.dtype {
519            return Err(MlxError::InvalidArgument(
520                "silu_vjp requires matching dtypes".into(),
521            ));
522        }
523        if self.device != input.device {
524            return Err(MlxError::InvalidArgument(
525                "silu_vjp requires matching devices".into(),
526            ));
527        }
528        Ok(self.lazy_op(
529            OpKind::SiluVjp,
530            SmallVec::from_slice(&[self.node_id, input.node_id]),
531            input.shape.clone(),
532            input.dtype,
533        ))
534    }
535
536    /// GELU VJP: compute grad_input given grad_output (self) and original input.
537    pub fn gelu_vjp(&self, input: &Tensor) -> Result<Tensor> {
538        if self.shape != input.shape {
539            return Err(MlxError::ShapeMismatch {
540                expected: input.shape.0.clone(),
541                got: self.shape.0.clone(),
542            });
543        }
544        if self.dtype != input.dtype {
545            return Err(MlxError::InvalidArgument(
546                "gelu_vjp requires matching dtypes".into(),
547            ));
548        }
549        if self.device != input.device {
550            return Err(MlxError::InvalidArgument(
551                "gelu_vjp requires matching devices".into(),
552            ));
553        }
554        Ok(self.lazy_op(
555            OpKind::GeluVjp,
556            SmallVec::from_slice(&[self.node_id, input.node_id]),
557            input.shape.clone(),
558            input.dtype,
559        ))
560    }
561
562    // ── Indexing / gathering ──────────────────────────────────────────
563
564    /// Embedding lookup: gather rows from this weight matrix [vocab, dim]
565    /// using `indices` [seq_len]. Returns [seq_len, dim].
566    pub fn embedding_lookup(&self, indices: &Tensor) -> Result<Tensor> {
567        if self.shape.ndim() != 2 {
568            return Err(MlxError::InvalidArgument(
569                "embedding_lookup: weight must be 2D [vocab_size, embed_dim]".into(),
570            ));
571        }
572        if indices.shape.ndim() != 1 {
573            return Err(MlxError::InvalidArgument(
574                "embedding_lookup: indices must be 1D [seq_len]".into(),
575            ));
576        }
577        let seq_len = indices.shape.0[0];
578        let embed_dim = self.shape.0[1];
579        Ok(self.lazy_op(
580            OpKind::Embedding,
581            SmallVec::from_slice(&[self.node_id, indices.node_id]),
582            Shape::new(vec![seq_len, embed_dim]),
583            self.dtype,
584        ))
585    }
586
587    /// Narrow (slice) along an axis: extract `length` elements starting at `start`.
588    pub fn narrow(&self, axis: i32, start: i64, length: i64) -> Result<Tensor> {
589        let ndim = self.shape.ndim() as i32;
590        let ax = if axis < 0 { ndim + axis } else { axis };
591        if ax < 0 || ax >= ndim {
592            return Err(MlxError::InvalidArgument(format!(
593                "narrow: axis {axis} out of range for ndim {ndim}"
594            )));
595        }
596        let ax_usize = ax as usize;
597        let dim_size = self.shape.0[ax_usize];
598        if start < 0 || start + length > dim_size {
599            return Err(MlxError::InvalidArgument(format!(
600                "narrow: start {start} + length {length} exceeds dim size {dim_size}"
601            )));
602        }
603        let mut new_dims = self.shape.0.clone();
604        new_dims[ax_usize] = length;
605        Ok(self.lazy_op(
606            OpKind::Narrow {
607                axis: ax,
608                start,
609                length,
610            },
611            SmallVec::from_slice(&[self.node_id]),
612            Shape::new(new_dims),
613            self.dtype,
614        ))
615    }
616
617    /// Concatenate tensors along an axis.
618    pub fn cat(tensors: &[&Tensor], axis: i32) -> Result<Tensor> {
619        if tensors.is_empty() {
620            return Err(MlxError::InvalidArgument(
621                "cat requires at least one tensor".into(),
622            ));
623        }
624        let first = tensors[0];
625        let ndim = first.shape.ndim() as i32;
626        let ax = if axis < 0 { ndim + axis } else { axis };
627        if ax < 0 || ax >= ndim {
628            return Err(MlxError::InvalidArgument(format!(
629                "cat: axis {axis} out of range for ndim {ndim}"
630            )));
631        }
632        let ax_usize = ax as usize;
633
634        // Validate shapes and compute output dim
635        let mut total_dim: i64 = 0;
636        for t in tensors {
637            if t.shape.ndim() != first.shape.ndim() {
638                return Err(MlxError::InvalidArgument(
639                    "cat: all tensors must have same ndim".into(),
640                ));
641            }
642            for (d, (&a, &b)) in first.shape.0.iter().zip(t.shape.0.iter()).enumerate() {
643                if d != ax_usize && a != b {
644                    return Err(MlxError::ShapeMismatch {
645                        expected: first.shape.0.clone(),
646                        got: t.shape.0.clone(),
647                    });
648                }
649            }
650            total_dim += t.shape.0[ax_usize];
651        }
652
653        let mut new_dims = first.shape.0.clone();
654        new_dims[ax_usize] = total_dim;
655
656        let inputs: SmallVec<[NodeId; 2]> = tensors.iter().map(|t| t.node_id).collect();
657
658        Ok(first.lazy_op(
659            OpKind::Concatenate { axis: ax },
660            inputs,
661            Shape::new(new_dims),
662            first.dtype,
663        ))
664    }
665
666    /// Single-head attention: Q @ K^T * scale → causal mask → softmax → @ V.
667    /// Q: [Tq, Dh], K: [Tk, Dh], V: [Tk, Dh] → Output: [Tq, Dh]
668    pub fn attention(&self, k: &Tensor, v: &Tensor, scale: f32, causal: bool) -> Result<Tensor> {
669        if self.shape.ndim() != 2 || k.shape.ndim() != 2 || v.shape.ndim() != 2 {
670            return Err(MlxError::InvalidArgument(
671                "attention requires 2D tensors [seq, head_dim]".into(),
672            ));
673        }
674        let tq = self.shape.0[0];
675        let dh = self.shape.0[1];
676        if k.shape.0[1] != dh {
677            return Err(MlxError::ShapeMismatch {
678                expected: self.shape.0.clone(),
679                got: k.shape.0.clone(),
680            });
681        }
682        if v.shape.0[1] != dh || k.shape.0[0] != v.shape.0[0] {
683            return Err(MlxError::ShapeMismatch {
684                expected: k.shape.0.clone(),
685                got: v.shape.0.clone(),
686            });
687        }
688        Ok(self.lazy_op(
689            OpKind::Attention { scale, causal },
690            SmallVec::from_slice(&[self.node_id, k.node_id, v.node_id]),
691            Shape::new(vec![tq, dh]),
692            self.dtype,
693        ))
694    }
695
696    /// Element-wise square root.
697    pub fn sqrt(&self) -> Tensor {
698        self.lazy_op(
699            OpKind::Sqrt,
700            SmallVec::from_slice(&[self.node_id]),
701            self.shape.clone(),
702            self.dtype,
703        )
704    }
705
706    // ── Materialization ─────────────────────────────────────────────────
707
    /// Materialize the tensor — triggers evaluation of the computation graph.
    ///
    /// Delegates to the owning stream, which evaluates the subgraph rooted at
    /// this tensor's node.
    pub fn eval(&self) -> Result<()> {
        self.stream.eval(self.node_id)
    }
712
713    /// Copy data out as Vec<f32>. Triggers evaluation if needed.
714    pub fn to_vec_f32(&self) -> Result<Vec<f32>> {
715        self.eval()?;
716        self.stream
717            .get_buffer(self.node_id)
718            .ok_or_else(|| MlxError::InvalidArgument("buffer not found after eval".into()))
719    }
720
721    // ── Accessors ───────────────────────────────────────────────────────
722
    /// Get the tensor shape.
    pub fn shape(&self) -> &Shape {
        &self.shape
    }

    /// Get the tensor dtype.
    pub fn dtype(&self) -> DType {
        self.dtype
    }

    /// Get the tensor device.
    pub fn device(&self) -> &Device {
        &self.device
    }

    /// Number of elements.
    pub fn numel(&self) -> i64 {
        self.shape.numel()
    }

    /// Get the graph node ID.
    pub fn node_id(&self) -> NodeId {
        self.node_id
    }

    /// Get the stream this tensor belongs to.
    ///
    /// Returns a new `Arc` handle to the same underlying stream.
    pub fn stream(&self) -> Arc<Stream> {
        Arc::clone(&self.stream)
    }
752
    /// Reconstruct a tensor handle from a node ID and metadata.
    ///
    /// Used by autograd to create handles for graph introspection.
    ///
    /// The caller is responsible for passing metadata (`shape`, `dtype`,
    /// `device`) that actually matches the node in `stream` — nothing is
    /// validated here.
    pub fn from_node_id(
        node_id: NodeId,
        shape: Shape,
        dtype: DType,
        device: Device,
        stream: Arc<Stream>,
    ) -> Self {
        Self {
            node_id,
            shape,
            dtype,
            device,
            stream,
        }
    }
771
772    /// Broadcast this tensor to the target shape (numpy-style rules).
773    pub fn broadcast_to(&self, target: &Shape) -> Result<Tensor> {
774        if &self.shape == target {
775            return Ok(self.clone());
776        }
777        // Validate broadcast compatibility: dimensions are compared from the right.
778        let in_ndim = self.shape.ndim();
779        let out_ndim = target.ndim();
780        if in_ndim > out_ndim {
781            return Err(MlxError::InvalidArgument(format!(
782                "cannot broadcast shape {} to {}",
783                self.shape, target
784            )));
785        }
786        let pad = out_ndim - in_ndim;
787        for i in 0..in_ndim {
788            let in_dim = self.shape.0[i];
789            let out_dim = target.0[pad + i];
790            if in_dim != 1 && in_dim != out_dim {
791                return Err(MlxError::InvalidArgument(format!(
792                    "cannot broadcast shape {} to {}",
793                    self.shape, target
794                )));
795            }
796        }
797        Ok(self.lazy_op(
798            OpKind::Broadcast {
799                target_shape: target.clone(),
800            },
801            SmallVec::from_slice(&[self.node_id]),
802            target.clone(),
803            self.dtype,
804        ))
805    }
806}
807
808impl std::ops::Add for &Tensor {
809    type Output = Result<Tensor>;
810    fn add(self, rhs: &Tensor) -> Self::Output {
811        self.add(rhs)
812    }
813}
814
/// `&a - &b` — delegates to `Tensor::sub`; returns `Result` because the
/// shapes may mismatch.
impl std::ops::Sub for &Tensor {
    type Output = Result<Tensor>;
    fn sub(self, rhs: &Tensor) -> Self::Output {
        Tensor::sub(self, rhs)
    }
}
821
822impl std::ops::Mul for &Tensor {
823    type Output = Result<Tensor>;
824    fn mul(self, rhs: &Tensor) -> Self::Output {
825        Tensor::mul(self, rhs)
826    }
827}
828
/// `-&a` — delegates to `Tensor::neg` (infallible, unlike the binary ops).
impl std::ops::Neg for &Tensor {
    type Output = Tensor;
    fn neg(self) -> Self::Output {
        Tensor::neg(self)
    }
}
835
836#[cfg(test)]
837mod tests {
838    use super::*;
839
    // All tests run on the CPU device.
    fn cpu() -> Device {
        Device::Cpu
    }

    // Constructor round-trips: build a tensor, evaluate, read data back.
    #[test]
    fn test_zeros() {
        let t = Tensor::zeros(&Shape::new(vec![2, 3]), DType::F32, &cpu()).unwrap();
        assert_eq!(t.to_vec_f32().unwrap(), vec![0.0; 6]);
        assert_eq!(t.shape(), &Shape::new(vec![2, 3]));
    }

    #[test]
    fn test_ones() {
        let t = Tensor::ones(&Shape::new(vec![3]), DType::F32, &cpu()).unwrap();
        assert_eq!(t.to_vec_f32().unwrap(), vec![1.0; 3]);
    }

    #[test]
    fn test_from_f32() {
        let t = Tensor::from_f32(&[1.0, 2.0, 3.0, 4.0], &Shape::new(vec![2, 2]), &cpu()).unwrap();
        assert_eq!(t.to_vec_f32().unwrap(), vec![1.0, 2.0, 3.0, 4.0]);
    }

    // Length/shape disagreement must be rejected at construction time.
    #[test]
    fn test_from_f32_shape_mismatch() {
        let r = Tensor::from_f32(&[1.0, 2.0], &Shape::new(vec![3]), &cpu());
        assert!(r.is_err());
    }
868
    // Element-wise binary/unary ops evaluated through the lazy graph.
    #[test]
    fn test_add() {
        let a = Tensor::from_f32(&[1.0, 2.0, 3.0], &Shape::new(vec![3]), &cpu()).unwrap();
        let b = Tensor::from_f32(&[4.0, 5.0, 6.0], &Shape::new(vec![3]), &cpu()).unwrap();
        let c = a.add(&b).unwrap();
        assert_eq!(c.to_vec_f32().unwrap(), vec![5.0, 7.0, 9.0]);
    }

    #[test]
    fn test_sub() {
        let a = Tensor::from_f32(&[5.0, 7.0, 9.0], &Shape::new(vec![3]), &cpu()).unwrap();
        let b = Tensor::from_f32(&[1.0, 2.0, 3.0], &Shape::new(vec![3]), &cpu()).unwrap();
        let c = a.sub(&b).unwrap();
        assert_eq!(c.to_vec_f32().unwrap(), vec![4.0, 5.0, 6.0]);
    }

    #[test]
    fn test_mul() {
        let a = Tensor::from_f32(&[2.0, 3.0], &Shape::new(vec![2]), &cpu()).unwrap();
        let b = Tensor::from_f32(&[4.0, 5.0], &Shape::new(vec![2]), &cpu()).unwrap();
        let c = a.mul(&b).unwrap();
        assert_eq!(c.to_vec_f32().unwrap(), vec![8.0, 15.0]);
    }

    #[test]
    fn test_div() {
        let a = Tensor::from_f32(&[10.0, 9.0], &Shape::new(vec![2]), &cpu()).unwrap();
        let b = Tensor::from_f32(&[2.0, 3.0], &Shape::new(vec![2]), &cpu()).unwrap();
        let c = a.div(&b).unwrap();
        assert_eq!(c.to_vec_f32().unwrap(), vec![5.0, 3.0]);
    }

    #[test]
    fn test_neg() {
        let a = Tensor::from_f32(&[1.0, -2.0, 3.0], &Shape::new(vec![3]), &cpu()).unwrap();
        let b = a.neg();
        assert_eq!(b.to_vec_f32().unwrap(), vec![-1.0, 2.0, -3.0]);
    }
907
    // Row-major 2x2 matmul: [[1,2],[3,4]] @ [[5,6],[7,8]].
    #[test]
    fn test_matmul() {
        let a = Tensor::from_f32(&[1.0, 2.0, 3.0, 4.0], &Shape::new(vec![2, 2]), &cpu()).unwrap();
        let b = Tensor::from_f32(&[5.0, 6.0, 7.0, 8.0], &Shape::new(vec![2, 2]), &cpu()).unwrap();
        let c = a.matmul(&b).unwrap();
        assert_eq!(c.to_vec_f32().unwrap(), vec![19.0, 22.0, 43.0, 50.0]);
    }

    // Axis 0 sums columns; axis 1 sums rows.
    #[test]
    fn test_sum_axis() {
        let a = Tensor::from_f32(
            &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
            &Shape::new(vec![2, 3]),
            &cpu(),
        )
        .unwrap();
        let s0 = a.sum_axis(0).unwrap();
        assert_eq!(s0.to_vec_f32().unwrap(), vec![5.0, 7.0, 9.0]);
        let s1 = a.sum_axis(1).unwrap();
        assert_eq!(s1.to_vec_f32().unwrap(), vec![6.0, 15.0]);
    }

    #[test]
    fn test_sum_all() {
        let a = Tensor::from_f32(&[1.0, 2.0, 3.0], &Shape::new(vec![3]), &cpu()).unwrap();
        let s = a.sum_all().unwrap();
        assert_eq!(s.to_vec_f32().unwrap(), vec![6.0]);
    }

    // Softmax output sums to 1 and is monotone in its input.
    #[test]
    fn test_softmax() {
        let a = Tensor::from_f32(&[1.0, 2.0, 3.0], &Shape::new(vec![3]), &cpu()).unwrap();
        let s = a.softmax(0).unwrap();
        let vals = s.to_vec_f32().unwrap();
        let sum: f32 = vals.iter().sum();
        assert!((sum - 1.0).abs() < 1e-6);
        assert!(vals[0] < vals[1]);
        assert!(vals[1] < vals[2]);
    }
947
948    #[test]
949    fn test_reshape() {
950        let a = Tensor::from_f32(
951            &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
952            &Shape::new(vec![2, 3]),
953            &cpu(),
954        )
955        .unwrap();
956        let b = a.reshape(&Shape::new(vec![3, 2])).unwrap();
957        assert_eq!(b.shape(), &Shape::new(vec![3, 2]));
958        assert_eq!(b.to_vec_f32().unwrap(), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
959    }
960
961    #[test]
962    fn test_transpose() {
963        let a = Tensor::from_f32(
964            &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
965            &Shape::new(vec![2, 3]),
966            &cpu(),
967        )
968        .unwrap();
969        let b = a.transpose(None).unwrap();
970        assert_eq!(b.shape(), &Shape::new(vec![3, 2]));
971        // [[1,2,3],[4,5,6]] transposed = [[1,4],[2,5],[3,6]]
972        assert_eq!(b.to_vec_f32().unwrap(), vec![1.0, 4.0, 2.0, 5.0, 3.0, 6.0]);
973    }
974
975    #[test]
976    fn test_operator_add() {
977        let a = Tensor::from_f32(&[1.0, 2.0], &Shape::new(vec![2]), &cpu()).unwrap();
978        let b = Tensor::from_f32(&[3.0, 4.0], &Shape::new(vec![2]), &cpu()).unwrap();
979        let c = (&a + &b).unwrap();
980        assert_eq!(c.to_vec_f32().unwrap(), vec![4.0, 6.0]);
981    }
982
983    #[test]
984    fn test_operator_neg() {
985        let a = Tensor::from_f32(&[1.0, -2.0], &Shape::new(vec![2]), &cpu()).unwrap();
986        let b = -&a;
987        assert_eq!(b.to_vec_f32().unwrap(), vec![-1.0, 2.0]);
988    }
989
990    #[test]
991    fn test_lazy_chain() {
992        // Build a chain: (a + b) * c — nothing evaluated until to_vec_f32
993        let a = Tensor::from_f32(&[1.0, 2.0], &Shape::new(vec![2]), &cpu()).unwrap();
994        let b = Tensor::from_f32(&[3.0, 4.0], &Shape::new(vec![2]), &cpu()).unwrap();
995        let c = Tensor::from_f32(&[2.0, 3.0], &Shape::new(vec![2]), &cpu()).unwrap();
996        let d = a.add(&b).unwrap().mul(&c).unwrap();
997        // Only now does evaluation happen:
998        assert_eq!(d.to_vec_f32().unwrap(), vec![8.0, 18.0]);
999    }
1000
1001    #[test]
1002    fn test_silu() {
1003        let a = Tensor::from_f32(&[0.0, 1.0], &Shape::new(vec![2]), &cpu()).unwrap();
1004        let b = a.silu();
1005        let vals = b.to_vec_f32().unwrap();
1006        assert!((vals[0]).abs() < 1e-6);
1007        assert!((vals[1] - 0.7311).abs() < 1e-3);
1008    }
1009
1010    #[test]
1011    fn test_layer_norm() {
1012        let a = Tensor::from_f32(&[1.0, 2.0, 3.0], &Shape::new(vec![3]), &cpu()).unwrap();
1013        let b = a.layer_norm(1e-5);
1014        let vals = b.to_vec_f32().unwrap();
1015        let mean: f32 = vals.iter().sum::<f32>() / 3.0;
1016        assert!(mean.abs() < 1e-5);
1017    }
1018
1019    #[test]
1020    fn test_reduce_zero_dim_bug() {
1021        let x = Tensor::from_f32(&[], &Shape::new(vec![2, 3, 0]), &cpu()).unwrap();
1022        let s = x.sum_axis(1).unwrap(); // Should return shape [2, 0]
1023        assert_eq!(s.shape(), &Shape::new(vec![2, 0]));
1024        let vals = s.to_vec_f32().unwrap();
1025        assert_eq!(vals.len(), 0);
1026    }
1027
1028    #[test]
1029    fn test_softmax_zero_trailing_dim() {
1030        let x = Tensor::from_f32(&[], &Shape::new(vec![2, 3, 0]), &cpu()).unwrap();
1031        let s = x.softmax(1).unwrap();
1032        assert_eq!(s.shape(), &Shape::new(vec![2, 3, 0]));
1033        let vals = s.to_vec_f32().unwrap();
1034        assert_eq!(vals.len(), 0);
1035    }
1036}