1use std::sync::Arc;
8
9use smallvec::SmallVec;
10
11use crate::backend::{Stream, default_stream};
12use crate::graph::{OpKind, TensorMeta};
13use crate::{DType, MlxError, NodeId, Result, Shape};
14
/// Compute device on which tensor work is scheduled.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum Device {
    Cpu,
    Gpu,
}

impl Device {
    /// Returns the default device for the current platform.
    ///
    /// Both branches currently resolve to the CPU backend; the macOS case
    /// is kept distinct so a GPU default can be switched on there later.
    pub fn default_device() -> Self {
        if cfg!(target_os = "macos") {
            Device::Cpu
        } else {
            Device::Cpu
        }
    }
}
36
/// A lazily-evaluated tensor handle.
///
/// A `Tensor` does not hold data eagerly: it records a node in the
/// computation graph of its `Stream` and is only materialized by `eval()`
/// / `to_vec_f32()`. Cloning is cheap — it copies the metadata and bumps
/// the `Arc` on the stream.
#[derive(Clone)]
pub struct Tensor {
    // Graph node that produces this tensor's value.
    node_id: NodeId,
    // Logical shape of the (future) value.
    shape: Shape,
    // Element type of the value.
    dtype: DType,
    // Device this tensor is associated with.
    device: Device,
    // Stream (graph + executor) the node lives in.
    stream: Arc<Stream>,
}
50
51impl Tensor {
52 pub fn zeros(shape: &Shape, dtype: DType, device: &Device) -> Result<Self> {
56 let n = shape.numel() as usize;
57 Self::from_data(vec![0.0; n], shape, dtype, device)
58 }
59
60 pub fn ones(shape: &Shape, dtype: DType, device: &Device) -> Result<Self> {
62 let n = shape.numel() as usize;
63 Self::from_data(vec![1.0; n], shape, dtype, device)
64 }
65
66 pub fn from_f32(data: &[f32], shape: &Shape, device: &Device) -> Result<Self> {
68 let expected = shape.numel() as usize;
69 if data.len() != expected {
70 return Err(MlxError::InvalidArgument(format!(
71 "data length {} does not match shape {} (expected {})",
72 data.len(),
73 shape,
74 expected,
75 )));
76 }
77 Self::from_data(data.to_vec(), shape, DType::F32, device)
78 }
79
80 pub fn from_f32_on_stream(data: &[f32], shape: &Shape, stream: &Arc<Stream>) -> Result<Self> {
82 let expected = shape.numel() as usize;
83 if data.len() != expected {
84 return Err(MlxError::InvalidArgument(format!(
85 "data length {} does not match shape {} (expected {})",
86 data.len(),
87 shape,
88 expected,
89 )));
90 }
91 let meta = TensorMeta {
92 shape: shape.clone(),
93 dtype: DType::F32,
94 };
95 let node_id = stream.add_constant(data.to_vec(), meta);
96 Ok(Self {
97 node_id,
98 shape: shape.clone(),
99 dtype: DType::F32,
100 device: Device::Gpu, stream: Arc::clone(stream),
102 })
103 }
104
105 pub fn from_data_with_dtype(
110 data: Vec<f32>,
111 shape: &Shape,
112 dtype: DType,
113 device: &Device,
114 ) -> Result<Self> {
115 let expected = shape.numel() as usize;
116 if data.len() != expected {
117 return Err(MlxError::InvalidArgument(format!(
118 "data length {} does not match shape {} (expected {})",
119 data.len(),
120 shape,
121 expected,
122 )));
123 }
124 Self::from_data(data, shape, dtype, device)
125 }
126
127 fn from_data(data: Vec<f32>, shape: &Shape, dtype: DType, device: &Device) -> Result<Self> {
128 let stream = default_stream();
129 let meta = TensorMeta {
130 shape: shape.clone(),
131 dtype,
132 };
133 let node_id = stream.add_constant(data, meta);
134 Ok(Self {
135 node_id,
136 shape: shape.clone(),
137 dtype,
138 device: device.clone(),
139 stream,
140 })
141 }
142
    /// Appends an operation node to this tensor's stream and wraps the
    /// result in a new lazy `Tensor`.
    ///
    /// `inputs` are the graph node ids the op consumes (usually including
    /// `self.node_id`); `shape` and `dtype` describe the op's result. The
    /// returned tensor inherits this tensor's device and stream.
    fn lazy_op(
        &self,
        op: OpKind,
        inputs: SmallVec<[NodeId; 2]>,
        shape: Shape,
        dtype: DType,
    ) -> Self {
        let meta = TensorMeta {
            shape: shape.clone(),
            dtype,
        };
        let node_id = self.stream.add_op(op, inputs, meta);
        Tensor {
            node_id,
            shape,
            dtype,
            device: self.device.clone(),
            stream: Arc::clone(&self.stream),
        }
    }
163
164 pub fn add(&self, rhs: &Tensor) -> Result<Tensor> {
168 if self.shape != rhs.shape {
169 return Err(MlxError::ShapeMismatch {
170 expected: self.shape.0.clone(),
171 got: rhs.shape.0.clone(),
172 });
173 }
174 Ok(self.lazy_op(
175 OpKind::Add,
176 SmallVec::from_slice(&[self.node_id, rhs.node_id]),
177 self.shape.clone(),
178 self.dtype,
179 ))
180 }
181
182 pub fn sub(&self, rhs: &Tensor) -> Result<Tensor> {
184 if self.shape != rhs.shape {
185 return Err(MlxError::ShapeMismatch {
186 expected: self.shape.0.clone(),
187 got: rhs.shape.0.clone(),
188 });
189 }
190 Ok(self.lazy_op(
191 OpKind::Sub,
192 SmallVec::from_slice(&[self.node_id, rhs.node_id]),
193 self.shape.clone(),
194 self.dtype,
195 ))
196 }
197
198 pub fn mul(&self, rhs: &Tensor) -> Result<Tensor> {
200 if self.shape != rhs.shape {
201 return Err(MlxError::ShapeMismatch {
202 expected: self.shape.0.clone(),
203 got: rhs.shape.0.clone(),
204 });
205 }
206 Ok(self.lazy_op(
207 OpKind::Mul,
208 SmallVec::from_slice(&[self.node_id, rhs.node_id]),
209 self.shape.clone(),
210 self.dtype,
211 ))
212 }
213
214 pub fn div(&self, rhs: &Tensor) -> Result<Tensor> {
216 if self.shape != rhs.shape {
217 return Err(MlxError::ShapeMismatch {
218 expected: self.shape.0.clone(),
219 got: rhs.shape.0.clone(),
220 });
221 }
222 Ok(self.lazy_op(
223 OpKind::Div,
224 SmallVec::from_slice(&[self.node_id, rhs.node_id]),
225 self.shape.clone(),
226 self.dtype,
227 ))
228 }
229
230 pub fn neg(&self) -> Tensor {
232 self.lazy_op(
233 OpKind::Neg,
234 SmallVec::from_slice(&[self.node_id]),
235 self.shape.clone(),
236 self.dtype,
237 )
238 }
239
240 pub fn sum_axis(&self, axis: i32) -> Result<Tensor> {
244 let ndim = self.shape.ndim() as i32;
245 let ax = if axis < 0 { ndim + axis } else { axis };
246 if ax < 0 || ax >= ndim {
247 return Err(MlxError::InvalidArgument(format!(
248 "axis {axis} out of range for ndim {ndim}"
249 )));
250 }
251 let mut new_dims: Vec<i64> = self.shape.0.clone();
252 new_dims.remove(ax as usize);
253 Ok(self.lazy_op(
254 OpKind::Sum { axis: Some(ax) },
255 SmallVec::from_slice(&[self.node_id]),
256 Shape::new(new_dims),
257 self.dtype,
258 ))
259 }
260
261 pub fn sum_all(&self) -> Result<Tensor> {
263 Ok(self.lazy_op(
264 OpKind::Sum { axis: None },
265 SmallVec::from_slice(&[self.node_id]),
266 Shape::scalar(),
267 self.dtype,
268 ))
269 }
270
271 pub fn matmul(&self, rhs: &Tensor) -> Result<Tensor> {
275 if self.shape.ndim() != 2 || rhs.shape.ndim() != 2 {
276 return Err(MlxError::InvalidArgument(
277 "matmul requires 2D tensors".to_string(),
278 ));
279 }
280 let m = self.shape.0[0];
281 let k = self.shape.0[1];
282 let k2 = rhs.shape.0[0];
283 let n = rhs.shape.0[1];
284 if k != k2 {
285 return Err(MlxError::ShapeMismatch {
286 expected: self.shape.0.clone(),
287 got: rhs.shape.0.clone(),
288 });
289 }
290 Ok(self.lazy_op(
291 OpKind::MatMul,
292 SmallVec::from_slice(&[self.node_id, rhs.node_id]),
293 Shape::new(vec![m, n]),
294 self.dtype,
295 ))
296 }
297
298 pub fn reshape(&self, new_shape: &Shape) -> Result<Tensor> {
302 if self.shape.numel() != new_shape.numel() {
303 return Err(MlxError::ShapeMismatch {
304 expected: self.shape.0.clone(),
305 got: new_shape.0.clone(),
306 });
307 }
308 Ok(self.lazy_op(
309 OpKind::Reshape {
310 new_shape: new_shape.clone(),
311 },
312 SmallVec::from_slice(&[self.node_id]),
313 new_shape.clone(),
314 self.dtype,
315 ))
316 }
317
318 pub fn transpose(&self, axes: Option<&[usize]>) -> Result<Tensor> {
320 let ndim = self.shape.ndim();
321 let perm: Vec<usize> = match axes {
322 Some(ax) => {
323 if ax.len() != ndim {
324 return Err(MlxError::InvalidArgument(
325 "transpose axes length must match ndim".into(),
326 ));
327 }
328 let mut seen = vec![false; ndim];
329 for &axis in ax {
330 if axis >= ndim {
331 return Err(MlxError::InvalidArgument(format!(
332 "transpose axis {axis} out of range for ndim {ndim}"
333 )));
334 }
335 if seen[axis] {
336 return Err(MlxError::InvalidArgument(format!(
337 "duplicate transpose axis {axis} in axes; expected a permutation of 0..{ndim}"
338 )));
339 }
340 seen[axis] = true;
341 }
342 ax.to_vec()
343 }
344 None => (0..ndim).rev().collect(),
345 };
346 let new_dims: Vec<i64> = perm.iter().map(|&ax| self.shape.0[ax]).collect();
347 Ok(self.lazy_op(
348 OpKind::Transpose { axes: Some(perm) },
349 SmallVec::from_slice(&[self.node_id]),
350 Shape::new(new_dims),
351 self.dtype,
352 ))
353 }
354
355 pub fn softmax(&self, axis: i32) -> Result<Tensor> {
359 let ndim = self.shape.ndim() as i32;
360 let ax = if axis < 0 { ndim + axis } else { axis };
361 if ax < 0 || ax >= ndim {
362 return Err(MlxError::InvalidArgument(format!(
363 "axis {axis} out of range for ndim {ndim}"
364 )));
365 }
366 Ok(self.lazy_op(
367 OpKind::Softmax { axis },
368 SmallVec::from_slice(&[self.node_id]),
369 self.shape.clone(),
370 self.dtype,
371 ))
372 }
373
374 pub fn silu(&self) -> Tensor {
376 self.lazy_op(
377 OpKind::Silu,
378 SmallVec::from_slice(&[self.node_id]),
379 self.shape.clone(),
380 self.dtype,
381 )
382 }
383
384 pub fn gelu(&self) -> Tensor {
386 self.lazy_op(
387 OpKind::Gelu,
388 SmallVec::from_slice(&[self.node_id]),
389 self.shape.clone(),
390 self.dtype,
391 )
392 }
393
394 pub fn layer_norm(&self, eps: f32) -> Tensor {
398 self.lazy_op(
399 OpKind::LayerNorm { eps },
400 SmallVec::from_slice(&[self.node_id]),
401 self.shape.clone(),
402 self.dtype,
403 )
404 }
405
406 pub fn rms_norm(&self, eps: f32) -> Tensor {
408 self.lazy_op(
409 OpKind::RmsNorm { eps },
410 SmallVec::from_slice(&[self.node_id]),
411 self.shape.clone(),
412 self.dtype,
413 )
414 }
415
416 pub fn rope(&self, rotary_dim: usize, pos_offset: usize, theta: f32) -> Tensor {
418 self.lazy_op(
419 OpKind::Rope {
420 rotary_dim,
421 pos_offset,
422 theta,
423 },
424 SmallVec::from_slice(&[self.node_id]),
425 self.shape.clone(),
426 self.dtype,
427 )
428 }
429
430 pub fn layer_norm_vjp(&self, input: &Tensor, eps: f32) -> Result<Tensor> {
434 if self.shape != input.shape {
435 return Err(MlxError::ShapeMismatch {
436 expected: input.shape.0.clone(),
437 got: self.shape.0.clone(),
438 });
439 }
440 if self.dtype != input.dtype {
441 return Err(MlxError::InvalidArgument(
442 "layer_norm_vjp requires matching dtypes".into(),
443 ));
444 }
445 if self.device != input.device {
446 return Err(MlxError::InvalidArgument(
447 "layer_norm_vjp requires matching devices".into(),
448 ));
449 }
450 Ok(self.lazy_op(
451 OpKind::LayerNormVjp { eps },
452 SmallVec::from_slice(&[self.node_id, input.node_id]),
453 input.shape.clone(),
454 input.dtype,
455 ))
456 }
457
458 pub fn rms_norm_vjp(&self, input: &Tensor, eps: f32) -> Result<Tensor> {
460 if self.shape != input.shape {
461 return Err(MlxError::ShapeMismatch {
462 expected: input.shape.0.clone(),
463 got: self.shape.0.clone(),
464 });
465 }
466 if self.dtype != input.dtype {
467 return Err(MlxError::InvalidArgument(
468 "rms_norm_vjp requires matching dtypes".into(),
469 ));
470 }
471 if self.device != input.device {
472 return Err(MlxError::InvalidArgument(
473 "rms_norm_vjp requires matching devices".into(),
474 ));
475 }
476 Ok(self.lazy_op(
477 OpKind::RmsNormVjp { eps },
478 SmallVec::from_slice(&[self.node_id, input.node_id]),
479 input.shape.clone(),
480 input.dtype,
481 ))
482 }
483
484 pub fn softmax_vjp(&self, softmax_output: &Tensor, axis: i32) -> Result<Tensor> {
486 if self.shape != softmax_output.shape {
487 return Err(MlxError::ShapeMismatch {
488 expected: softmax_output.shape.0.clone(),
489 got: self.shape.0.clone(),
490 });
491 }
492 if self.dtype != softmax_output.dtype {
493 return Err(MlxError::InvalidArgument(
494 "softmax_vjp requires matching dtypes".into(),
495 ));
496 }
497 if self.device != softmax_output.device {
498 return Err(MlxError::InvalidArgument(
499 "softmax_vjp requires matching devices".into(),
500 ));
501 }
502 Ok(self.lazy_op(
503 OpKind::SoftmaxVjp { axis },
504 SmallVec::from_slice(&[self.node_id, softmax_output.node_id]),
505 softmax_output.shape.clone(),
506 softmax_output.dtype,
507 ))
508 }
509
510 pub fn silu_vjp(&self, input: &Tensor) -> Result<Tensor> {
512 if self.shape != input.shape {
513 return Err(MlxError::ShapeMismatch {
514 expected: input.shape.0.clone(),
515 got: self.shape.0.clone(),
516 });
517 }
518 if self.dtype != input.dtype {
519 return Err(MlxError::InvalidArgument(
520 "silu_vjp requires matching dtypes".into(),
521 ));
522 }
523 if self.device != input.device {
524 return Err(MlxError::InvalidArgument(
525 "silu_vjp requires matching devices".into(),
526 ));
527 }
528 Ok(self.lazy_op(
529 OpKind::SiluVjp,
530 SmallVec::from_slice(&[self.node_id, input.node_id]),
531 input.shape.clone(),
532 input.dtype,
533 ))
534 }
535
536 pub fn gelu_vjp(&self, input: &Tensor) -> Result<Tensor> {
538 if self.shape != input.shape {
539 return Err(MlxError::ShapeMismatch {
540 expected: input.shape.0.clone(),
541 got: self.shape.0.clone(),
542 });
543 }
544 if self.dtype != input.dtype {
545 return Err(MlxError::InvalidArgument(
546 "gelu_vjp requires matching dtypes".into(),
547 ));
548 }
549 if self.device != input.device {
550 return Err(MlxError::InvalidArgument(
551 "gelu_vjp requires matching devices".into(),
552 ));
553 }
554 Ok(self.lazy_op(
555 OpKind::GeluVjp,
556 SmallVec::from_slice(&[self.node_id, input.node_id]),
557 input.shape.clone(),
558 input.dtype,
559 ))
560 }
561
562 pub fn embedding_lookup(&self, indices: &Tensor) -> Result<Tensor> {
567 if self.shape.ndim() != 2 {
568 return Err(MlxError::InvalidArgument(
569 "embedding_lookup: weight must be 2D [vocab_size, embed_dim]".into(),
570 ));
571 }
572 if indices.shape.ndim() != 1 {
573 return Err(MlxError::InvalidArgument(
574 "embedding_lookup: indices must be 1D [seq_len]".into(),
575 ));
576 }
577 let seq_len = indices.shape.0[0];
578 let embed_dim = self.shape.0[1];
579 Ok(self.lazy_op(
580 OpKind::Embedding,
581 SmallVec::from_slice(&[self.node_id, indices.node_id]),
582 Shape::new(vec![seq_len, embed_dim]),
583 self.dtype,
584 ))
585 }
586
587 pub fn narrow(&self, axis: i32, start: i64, length: i64) -> Result<Tensor> {
589 let ndim = self.shape.ndim() as i32;
590 let ax = if axis < 0 { ndim + axis } else { axis };
591 if ax < 0 || ax >= ndim {
592 return Err(MlxError::InvalidArgument(format!(
593 "narrow: axis {axis} out of range for ndim {ndim}"
594 )));
595 }
596 let ax_usize = ax as usize;
597 let dim_size = self.shape.0[ax_usize];
598 if start < 0 || start + length > dim_size {
599 return Err(MlxError::InvalidArgument(format!(
600 "narrow: start {start} + length {length} exceeds dim size {dim_size}"
601 )));
602 }
603 let mut new_dims = self.shape.0.clone();
604 new_dims[ax_usize] = length;
605 Ok(self.lazy_op(
606 OpKind::Narrow {
607 axis: ax,
608 start,
609 length,
610 },
611 SmallVec::from_slice(&[self.node_id]),
612 Shape::new(new_dims),
613 self.dtype,
614 ))
615 }
616
    /// Concatenates `tensors` along `axis` (negative axes count from the
    /// end). All tensors must share ndim and agree on every dimension
    /// except the concatenation axis.
    ///
    /// NOTE(review): dtype, device and stream consistency across `tensors`
    /// is not checked — the op is recorded on the first tensor's stream and
    /// takes its dtype; confirm callers only mix tensors from one stream.
    pub fn cat(tensors: &[&Tensor], axis: i32) -> Result<Tensor> {
        if tensors.is_empty() {
            return Err(MlxError::InvalidArgument(
                "cat requires at least one tensor".into(),
            ));
        }
        let first = tensors[0];
        let ndim = first.shape.ndim() as i32;
        let ax = if axis < 0 { ndim + axis } else { axis };
        if ax < 0 || ax >= ndim {
            return Err(MlxError::InvalidArgument(format!(
                "cat: axis {axis} out of range for ndim {ndim}"
            )));
        }
        let ax_usize = ax as usize;

        // Validate shapes and accumulate the size of the concatenated axis.
        let mut total_dim: i64 = 0;
        for t in tensors {
            if t.shape.ndim() != first.shape.ndim() {
                return Err(MlxError::InvalidArgument(
                    "cat: all tensors must have same ndim".into(),
                ));
            }
            // Every non-concat dimension must match the first tensor.
            for (d, (&a, &b)) in first.shape.0.iter().zip(t.shape.0.iter()).enumerate() {
                if d != ax_usize && a != b {
                    return Err(MlxError::ShapeMismatch {
                        expected: first.shape.0.clone(),
                        got: t.shape.0.clone(),
                    });
                }
            }
            total_dim += t.shape.0[ax_usize];
        }

        let mut new_dims = first.shape.0.clone();
        new_dims[ax_usize] = total_dim;

        let inputs: SmallVec<[NodeId; 2]> = tensors.iter().map(|t| t.node_id).collect();

        Ok(first.lazy_op(
            OpKind::Concatenate { axis: ax },
            inputs,
            Shape::new(new_dims),
            first.dtype,
        ))
    }
665
666 pub fn attention(&self, k: &Tensor, v: &Tensor, scale: f32, causal: bool) -> Result<Tensor> {
669 if self.shape.ndim() != 2 || k.shape.ndim() != 2 || v.shape.ndim() != 2 {
670 return Err(MlxError::InvalidArgument(
671 "attention requires 2D tensors [seq, head_dim]".into(),
672 ));
673 }
674 let tq = self.shape.0[0];
675 let dh = self.shape.0[1];
676 if k.shape.0[1] != dh {
677 return Err(MlxError::ShapeMismatch {
678 expected: self.shape.0.clone(),
679 got: k.shape.0.clone(),
680 });
681 }
682 if v.shape.0[1] != dh || k.shape.0[0] != v.shape.0[0] {
683 return Err(MlxError::ShapeMismatch {
684 expected: k.shape.0.clone(),
685 got: v.shape.0.clone(),
686 });
687 }
688 Ok(self.lazy_op(
689 OpKind::Attention { scale, causal },
690 SmallVec::from_slice(&[self.node_id, k.node_id, v.node_id]),
691 Shape::new(vec![tq, dh]),
692 self.dtype,
693 ))
694 }
695
696 pub fn sqrt(&self) -> Tensor {
698 self.lazy_op(
699 OpKind::Sqrt,
700 SmallVec::from_slice(&[self.node_id]),
701 self.shape.clone(),
702 self.dtype,
703 )
704 }
705
    /// Forces evaluation of this tensor's node (and its dependencies) on
    /// its stream.
    pub fn eval(&self) -> Result<()> {
        self.stream.eval(self.node_id)
    }

    /// Evaluates the tensor and returns its buffer as a `Vec<f32>`.
    ///
    /// # Errors
    /// Propagates evaluation errors; returns `InvalidArgument` if the
    /// stream has no buffer for this node after a successful eval.
    pub fn to_vec_f32(&self) -> Result<Vec<f32>> {
        self.eval()?;
        self.stream
            .get_buffer(self.node_id)
            .ok_or_else(|| MlxError::InvalidArgument("buffer not found after eval".into()))
    }

    /// Logical shape of the tensor.
    pub fn shape(&self) -> &Shape {
        &self.shape
    }

    /// Element type of the tensor.
    pub fn dtype(&self) -> DType {
        self.dtype
    }

    /// Device this tensor is associated with.
    pub fn device(&self) -> &Device {
        &self.device
    }

    /// Total number of elements.
    pub fn numel(&self) -> i64 {
        self.shape.numel()
    }

    /// Graph node id backing this tensor.
    pub fn node_id(&self) -> NodeId {
        self.node_id
    }

    /// Clones the handle to this tensor's stream.
    pub fn stream(&self) -> Arc<Stream> {
        Arc::clone(&self.stream)
    }
752
    /// Wraps an existing graph node in a `Tensor` without any validation.
    ///
    /// The caller must ensure `node_id` belongs to `stream` and that
    /// `shape`/`dtype` match the node's recorded metadata.
    pub fn from_node_id(
        node_id: NodeId,
        shape: Shape,
        dtype: DType,
        device: Device,
        stream: Arc<Stream>,
    ) -> Self {
        Self {
            node_id,
            shape,
            dtype,
            device,
            stream,
        }
    }
771
772 pub fn broadcast_to(&self, target: &Shape) -> Result<Tensor> {
774 if &self.shape == target {
775 return Ok(self.clone());
776 }
777 let in_ndim = self.shape.ndim();
779 let out_ndim = target.ndim();
780 if in_ndim > out_ndim {
781 return Err(MlxError::InvalidArgument(format!(
782 "cannot broadcast shape {} to {}",
783 self.shape, target
784 )));
785 }
786 let pad = out_ndim - in_ndim;
787 for i in 0..in_ndim {
788 let in_dim = self.shape.0[i];
789 let out_dim = target.0[pad + i];
790 if in_dim != 1 && in_dim != out_dim {
791 return Err(MlxError::InvalidArgument(format!(
792 "cannot broadcast shape {} to {}",
793 self.shape, target
794 )));
795 }
796 }
797 Ok(self.lazy_op(
798 OpKind::Broadcast {
799 target_shape: target.clone(),
800 },
801 SmallVec::from_slice(&[self.node_id]),
802 target.clone(),
803 self.dtype,
804 ))
805 }
806}
807
808impl std::ops::Add for &Tensor {
809 type Output = Result<Tensor>;
810 fn add(self, rhs: &Tensor) -> Self::Output {
811 self.add(rhs)
812 }
813}
814
/// `&a - &b` — elementwise subtraction returning `Result<Tensor>`.
impl std::ops::Sub for &Tensor {
    type Output = Result<Tensor>;
    fn sub(self, rhs: &Tensor) -> Self::Output {
        Tensor::sub(self, rhs)
    }
}

/// `&a * &b` — elementwise multiplication returning `Result<Tensor>`.
impl std::ops::Mul for &Tensor {
    type Output = Result<Tensor>;
    fn mul(self, rhs: &Tensor) -> Self::Output {
        Tensor::mul(self, rhs)
    }
}

/// `-&a` — elementwise negation (infallible, so `Output` is `Tensor`).
impl std::ops::Neg for &Tensor {
    type Output = Tensor;
    fn neg(self) -> Self::Output {
        Tensor::neg(self)
    }
}
835
#[cfg(test)]
mod tests {
    use super::*;

    // Shorthand: all tests run against the CPU backend.
    fn cpu() -> Device {
        Device::Cpu
    }

    #[test]
    fn test_zeros() {
        let t = Tensor::zeros(&Shape::new(vec![2, 3]), DType::F32, &cpu()).unwrap();
        assert_eq!(t.to_vec_f32().unwrap(), vec![0.0; 6]);
        assert_eq!(t.shape(), &Shape::new(vec![2, 3]));
    }

    #[test]
    fn test_ones() {
        let t = Tensor::ones(&Shape::new(vec![3]), DType::F32, &cpu()).unwrap();
        assert_eq!(t.to_vec_f32().unwrap(), vec![1.0; 3]);
    }

    #[test]
    fn test_from_f32() {
        let t = Tensor::from_f32(&[1.0, 2.0, 3.0, 4.0], &Shape::new(vec![2, 2]), &cpu()).unwrap();
        assert_eq!(t.to_vec_f32().unwrap(), vec![1.0, 2.0, 3.0, 4.0]);
    }

    #[test]
    fn test_from_f32_shape_mismatch() {
        let r = Tensor::from_f32(&[1.0, 2.0], &Shape::new(vec![3]), &cpu());
        assert!(r.is_err());
    }

    #[test]
    fn test_add() {
        let a = Tensor::from_f32(&[1.0, 2.0, 3.0], &Shape::new(vec![3]), &cpu()).unwrap();
        let b = Tensor::from_f32(&[4.0, 5.0, 6.0], &Shape::new(vec![3]), &cpu()).unwrap();
        let c = a.add(&b).unwrap();
        assert_eq!(c.to_vec_f32().unwrap(), vec![5.0, 7.0, 9.0]);
    }

    #[test]
    fn test_sub() {
        let a = Tensor::from_f32(&[5.0, 7.0, 9.0], &Shape::new(vec![3]), &cpu()).unwrap();
        let b = Tensor::from_f32(&[1.0, 2.0, 3.0], &Shape::new(vec![3]), &cpu()).unwrap();
        let c = a.sub(&b).unwrap();
        assert_eq!(c.to_vec_f32().unwrap(), vec![4.0, 5.0, 6.0]);
    }

    #[test]
    fn test_mul() {
        let a = Tensor::from_f32(&[2.0, 3.0], &Shape::new(vec![2]), &cpu()).unwrap();
        let b = Tensor::from_f32(&[4.0, 5.0], &Shape::new(vec![2]), &cpu()).unwrap();
        let c = a.mul(&b).unwrap();
        assert_eq!(c.to_vec_f32().unwrap(), vec![8.0, 15.0]);
    }

    #[test]
    fn test_div() {
        let a = Tensor::from_f32(&[10.0, 9.0], &Shape::new(vec![2]), &cpu()).unwrap();
        let b = Tensor::from_f32(&[2.0, 3.0], &Shape::new(vec![2]), &cpu()).unwrap();
        let c = a.div(&b).unwrap();
        assert_eq!(c.to_vec_f32().unwrap(), vec![5.0, 3.0]);
    }

    #[test]
    fn test_neg() {
        let a = Tensor::from_f32(&[1.0, -2.0, 3.0], &Shape::new(vec![3]), &cpu()).unwrap();
        let b = a.neg();
        assert_eq!(b.to_vec_f32().unwrap(), vec![-1.0, 2.0, -3.0]);
    }

    #[test]
    fn test_matmul() {
        let a = Tensor::from_f32(&[1.0, 2.0, 3.0, 4.0], &Shape::new(vec![2, 2]), &cpu()).unwrap();
        let b = Tensor::from_f32(&[5.0, 6.0, 7.0, 8.0], &Shape::new(vec![2, 2]), &cpu()).unwrap();
        let c = a.matmul(&b).unwrap();
        assert_eq!(c.to_vec_f32().unwrap(), vec![19.0, 22.0, 43.0, 50.0]);
    }

    #[test]
    fn test_sum_axis() {
        let a = Tensor::from_f32(
            &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
            &Shape::new(vec![2, 3]),
            &cpu(),
        )
        .unwrap();
        let s0 = a.sum_axis(0).unwrap();
        assert_eq!(s0.to_vec_f32().unwrap(), vec![5.0, 7.0, 9.0]);
        let s1 = a.sum_axis(1).unwrap();
        assert_eq!(s1.to_vec_f32().unwrap(), vec![6.0, 15.0]);
    }

    #[test]
    fn test_sum_all() {
        let a = Tensor::from_f32(&[1.0, 2.0, 3.0], &Shape::new(vec![3]), &cpu()).unwrap();
        let s = a.sum_all().unwrap();
        assert_eq!(s.to_vec_f32().unwrap(), vec![6.0]);
    }

    #[test]
    fn test_softmax() {
        let a = Tensor::from_f32(&[1.0, 2.0, 3.0], &Shape::new(vec![3]), &cpu()).unwrap();
        let s = a.softmax(0).unwrap();
        let vals = s.to_vec_f32().unwrap();
        let sum: f32 = vals.iter().sum();
        assert!((sum - 1.0).abs() < 1e-6);
        assert!(vals[0] < vals[1]);
        assert!(vals[1] < vals[2]);
    }

    #[test]
    fn test_reshape() {
        let a = Tensor::from_f32(
            &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
            &Shape::new(vec![2, 3]),
            &cpu(),
        )
        .unwrap();
        let b = a.reshape(&Shape::new(vec![3, 2])).unwrap();
        assert_eq!(b.shape(), &Shape::new(vec![3, 2]));
        assert_eq!(b.to_vec_f32().unwrap(), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
    }

    #[test]
    fn test_transpose() {
        let a = Tensor::from_f32(
            &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
            &Shape::new(vec![2, 3]),
            &cpu(),
        )
        .unwrap();
        let b = a.transpose(None).unwrap();
        assert_eq!(b.shape(), &Shape::new(vec![3, 2]));
        assert_eq!(b.to_vec_f32().unwrap(), vec![1.0, 4.0, 2.0, 5.0, 3.0, 6.0]);
    }

    #[test]
    fn test_operator_add() {
        let a = Tensor::from_f32(&[1.0, 2.0], &Shape::new(vec![2]), &cpu()).unwrap();
        let b = Tensor::from_f32(&[3.0, 4.0], &Shape::new(vec![2]), &cpu()).unwrap();
        let c = (&a + &b).unwrap();
        assert_eq!(c.to_vec_f32().unwrap(), vec![4.0, 6.0]);
    }

    #[test]
    fn test_operator_neg() {
        let a = Tensor::from_f32(&[1.0, -2.0], &Shape::new(vec![2]), &cpu()).unwrap();
        let b = -&a;
        assert_eq!(b.to_vec_f32().unwrap(), vec![-1.0, 2.0]);
    }

    #[test]
    fn test_lazy_chain() {
        let a = Tensor::from_f32(&[1.0, 2.0], &Shape::new(vec![2]), &cpu()).unwrap();
        let b = Tensor::from_f32(&[3.0, 4.0], &Shape::new(vec![2]), &cpu()).unwrap();
        let c = Tensor::from_f32(&[2.0, 3.0], &Shape::new(vec![2]), &cpu()).unwrap();
        let d = a.add(&b).unwrap().mul(&c).unwrap();
        assert_eq!(d.to_vec_f32().unwrap(), vec![8.0, 18.0]);
    }

    #[test]
    fn test_silu() {
        let a = Tensor::from_f32(&[0.0, 1.0], &Shape::new(vec![2]), &cpu()).unwrap();
        let b = a.silu();
        let vals = b.to_vec_f32().unwrap();
        assert!((vals[0]).abs() < 1e-6);
        assert!((vals[1] - 0.7311).abs() < 1e-3);
    }

    #[test]
    fn test_layer_norm() {
        let a = Tensor::from_f32(&[1.0, 2.0, 3.0], &Shape::new(vec![3]), &cpu()).unwrap();
        let b = a.layer_norm(1e-5);
        let vals = b.to_vec_f32().unwrap();
        let mean: f32 = vals.iter().sum::<f32>() / 3.0;
        assert!(mean.abs() < 1e-5);
    }

    #[test]
    fn test_reduce_zero_dim_bug() {
        let x = Tensor::from_f32(&[], &Shape::new(vec![2, 3, 0]), &cpu()).unwrap();
        // Formatting fix: the statements below were previously joined on
        // one line; behavior is unchanged.
        let s = x.sum_axis(1).unwrap();
        assert_eq!(s.shape(), &Shape::new(vec![2, 0]));
        let vals = s.to_vec_f32().unwrap();
        assert_eq!(vals.len(), 0);
    }

    #[test]
    fn test_softmax_zero_trailing_dim() {
        let x = Tensor::from_f32(&[], &Shape::new(vec![2, 3, 0]), &cpu()).unwrap();
        let s = x.softmax(1).unwrap();
        assert_eq!(s.shape(), &Shape::new(vec![2, 3, 0]));
        let vals = s.to_vec_f32().unwrap();
        assert_eq!(vals.len(), 0);
    }
}