// axonml_tensor/src/tensor.rs
1//! Tensor - Core N-Dimensional Array Type
2//!
3//! # File
4//! `crates/axonml-tensor/src/tensor.rs`
5//!
6//! # Author
7//! Andrew Jewell Sr - AutomataNexus
8//!
9//! # Updated
10//! March 8, 2026
11//!
12//! # Disclaimer
13//! Use at own risk. This software is provided "as is", without warranty of any
14//! kind, express or implied. The author and AutomataNexus shall not be held
15//! liable for any damages arising from the use of this software.
16
17use core::fmt;
18use core::ops::{Add, Div, Mul, Neg, Sub};
19
20use axonml_core::Device;
21use axonml_core::backends::CpuBackend;
22#[cfg(feature = "cuda")]
23use axonml_core::backends::CudaBackend;
24use axonml_core::dtype::{Float, Numeric, Scalar};
25use axonml_core::error::{Error, Result};
26use axonml_core::storage::Storage;
27use num_traits::NumCast;
28
29// =============================================================================
30// CUDA Acceleration
31// =============================================================================
32
#[cfg(feature = "cuda")]
mod cuda_accel {
    use super::*;
    use axonml_core::backends::cuda::get_cuda_backend;

    /// Get the global CUDA backend (delegates to core singleton).
    ///
    /// Returns `None` when no CUDA backend is available.
    pub fn get_cuda() -> Option<&'static CudaBackend> {
        get_cuda_backend()
    }

    /// GPU-accelerated matmul: copies data to GPU, runs cuBLAS GEMM, copies back.
    /// Returns None if GPU is unavailable or an error occurs.
    ///
    /// Expects `a` as (m, k) and `b` as (k, n), both row-major; the result is
    /// the row-major (m, n) product.
    pub fn cuda_matmul(a: &[f32], b: &[f32], m: usize, n: usize, k: usize) -> Option<Vec<f32>> {
        let cuda = get_cuda()?;

        // Host-to-device copies of both operands plus an output buffer.
        let a_gpu = cuda.htod_copy(a).ok()?;
        let b_gpu = cuda.htod_copy(b).ok()?;
        let mut c_gpu = cuda.alloc::<f32>(m * n).ok()?;

        // cuBLAS GEMM: C(m,n) = A(m,k) @ B(k,n) in row-major
        // In column-major terms: C^T(n,m) = B^T(n,k) @ A^T(k,m)
        // — hence the swapped operand order (b before a) and the leading
        // dimensions n / k / n below, with no explicit transposes.
        cuda.gemm_f32(
            false, false, n, m, k, 1.0, &b_gpu, n, &a_gpu, k, 0.0, &mut c_gpu, n,
        )
        .ok()?;

        // Device-to-host copy of the finished product.
        cuda.dtoh_copy(&c_gpu).ok()
    }
}
62
63use crate::shape::{
64    Shape, Strides, broadcast_shape, broadcast_strides, contiguous_strides, is_contiguous,
65    linear_index, normalize_dim, numel, reshape, squeeze, transpose_shape, transpose_strides,
66    unsqueeze,
67};
68
69// =============================================================================
70// GPU Dispatch Helpers
71// =============================================================================
72//
73// These enable calling Tensor<f32> GPU methods from generic Tensor<T> code
74// when T is verified to be f32 via TypeId check at runtime.
75
/// Reinterprets a `&Tensor<T>` as a `&Tensor<f32>` for GPU dispatch.
///
/// # Safety
/// Caller must only invoke this when `T` is actually `f32`; the runtime
/// assert enforces that, after which both references point at the same
/// monomorphized type and the pointer cast is a no-op.
#[cfg(feature = "cuda")]
unsafe fn gpu_ref<T: Scalar>(t: &Tensor<T>) -> &Tensor<f32> {
    assert!(
        is_f32::<T>(),
        "gpu_ref: only Tensor<f32> can be used for GPU operations, got {:?}",
        T::DTYPE
    );
    // SAFETY: T is f32 (asserted above), Tensor<f32> and Tensor<T> have identical layout
    unsafe { &*(t as *const Tensor<T> as *const Tensor<f32>) }
}
86
/// Transfers ownership of a `Tensor<f32>` into a `Tensor<T>` (with `T = f32`).
///
/// # Safety
/// Caller must only invoke this when `T` is actually `f32`; the runtime
/// assert enforces that. Ownership is moved via `ptr::read` and the source
/// is `mem::forget`-ed so the storage is neither dropped twice nor leaked.
#[cfg(feature = "cuda")]
unsafe fn gpu_into<T: Scalar>(t: Tensor<f32>) -> Tensor<T> {
    assert!(
        is_f32::<T>(),
        "gpu_into: only Tensor<f32> can be produced from GPU operations, got {:?}",
        T::DTYPE
    );
    // SAFETY: T is f32 (asserted above), ownership transfer via ptr::read + forget
    unsafe {
        let out = std::ptr::read(&t as *const Tensor<f32> as *const Tensor<T>);
        std::mem::forget(t);
        out
    }
}
101
/// Runtime check that the generic parameter `T` is exactly `f32`.
///
/// Used to guard the unsafe `gpu_ref` / `gpu_into` reinterpretations.
#[cfg(feature = "cuda")]
fn is_f32<T: 'static>() -> bool {
    std::any::TypeId::of::<T>() == std::any::TypeId::of::<f32>()
}
106
107// =============================================================================
108// Tensor Struct
109// =============================================================================
110
/// An N-dimensional array of numeric values.
///
/// Tensors are the core data structure for all computations in Axonml.
/// They support arbitrary dimensions, automatic broadcasting, and efficient
/// memory sharing between views.
///
/// Views produced by shape operations (reshape/transpose/permute/squeeze/
/// unsqueeze) share the same reference-counted storage and differ only in
/// `shape`, `strides`, and `offset`.
#[derive(Clone)]
pub struct Tensor<T: Scalar> {
    /// Underlying data storage (reference-counted; shared between views).
    pub(crate) storage: Storage<T>,
    /// Shape of the tensor (dimensions); empty for 0-dimensional scalars.
    pub(crate) shape: Shape,
    /// Strides for each dimension, in elements (not bytes).
    pub(crate) strides: Strides,
    /// Offset into storage (for views).
    pub(crate) offset: usize,
}
127
128impl<T: Scalar> Tensor<T> {
129    // =========================================================================
130    // Constructors
131    // =========================================================================
132
133    /// Creates a new tensor from storage with the given shape.
134    ///
135    /// # Arguments
136    /// * `storage` - The underlying data storage
137    /// * `shape` - Shape of the tensor
138    ///
139    /// # Returns
140    /// New tensor, or error if shape doesn't match storage size.
141    pub fn from_storage(storage: Storage<T>, shape: &[usize]) -> Result<Self> {
142        let total = numel(shape);
143        if total != storage.len() {
144            return Err(Error::shape_mismatch(&[storage.len()], shape));
145        }
146
147        let shape = Shape::from_slice(shape);
148        let strides = contiguous_strides(&shape);
149
150        Ok(Self {
151            storage,
152            shape,
153            strides,
154            offset: 0,
155        })
156    }
157
158    /// Creates a new tensor from a vector with the given shape.
159    ///
160    /// # Arguments
161    /// * `data` - Vector of data
162    /// * `shape` - Shape of the tensor
163    ///
164    /// # Returns
165    /// New tensor, or error if shape doesn't match data length.
166    pub fn from_vec(data: Vec<T>, shape: &[usize]) -> Result<Self> {
167        let storage = Storage::from_vec(data, Device::Cpu);
168        Self::from_storage(storage, shape)
169    }
170
171    /// Creates a new tensor from a slice with the given shape.
172    ///
173    /// # Arguments
174    /// * `data` - Slice of data to copy
175    /// * `shape` - Shape of the tensor
176    ///
177    /// # Returns
178    /// New tensor, or error if shape doesn't match data length.
179    pub fn from_slice(data: &[T], shape: &[usize]) -> Result<Self> {
180        let storage = Storage::from_slice(data, Device::Cpu);
181        Self::from_storage(storage, shape)
182    }
183
184    /// Creates a scalar tensor (0-dimensional).
185    ///
186    /// # Arguments
187    /// * `value` - The scalar value
188    ///
189    /// # Returns
190    /// New 0-dimensional tensor.
191    pub fn scalar(value: T) -> Self {
192        Self {
193            storage: Storage::from_vec(vec![value], Device::Cpu),
194            shape: Shape::new(),
195            strides: Strides::new(),
196            offset: 0,
197        }
198    }
199
    /// Creates a tensor filled with zeros.
    ///
    /// Delegates to [`crate::creation::zeros`].
    #[must_use]
    pub fn zeros(shape: &[usize]) -> Self {
        crate::creation::zeros(shape)
    }
205
    /// Creates a tensor filled with ones.
    ///
    /// Delegates to [`crate::creation::ones`]; requires `T: Numeric` for a
    /// well-defined "one" value.
    #[must_use]
    pub fn ones(shape: &[usize]) -> Self
    where
        T: Numeric,
    {
        crate::creation::ones(shape)
    }
214
    /// Creates a tensor filled with a constant value.
    ///
    /// Delegates to [`crate::creation::full`].
    #[must_use]
    pub fn full(shape: &[usize], value: T) -> Self {
        crate::creation::full(shape, value)
    }
220
    /// Creates a tensor with random values from standard normal distribution.
    ///
    /// Delegates to [`crate::creation::randn`]; the bound requires the RNG's
    /// `StandardNormal` distribution to be able to produce `T`.
    #[must_use]
    pub fn randn(shape: &[usize]) -> Self
    where
        T: Float,
        rand_distr::StandardNormal: rand::distributions::Distribution<T>,
    {
        crate::creation::randn(shape)
    }
230
    /// Creates a tensor with random values from uniform distribution [0, 1).
    ///
    /// Delegates to [`crate::creation::rand`]; the bound requires the RNG's
    /// `Standard` distribution to be able to produce `T`.
    #[must_use]
    pub fn rand(shape: &[usize]) -> Self
    where
        T: Float,
        rand::distributions::Standard: rand::distributions::Distribution<T>,
    {
        crate::creation::rand(shape)
    }
240
241    // =========================================================================
242    // Properties
243    // =========================================================================
244
    /// Returns the shape of the tensor.
    ///
    /// One entry per dimension; empty for 0-dimensional (scalar) tensors.
    #[must_use]
    pub fn shape(&self) -> &[usize] {
        &self.shape
    }
250
    /// Returns the strides of the tensor.
    ///
    /// One entry per dimension, measured in elements (not bytes).
    #[must_use]
    pub fn strides(&self) -> &[isize] {
        &self.strides
    }
256
    /// Returns the number of dimensions (0 for scalar tensors).
    #[must_use]
    pub fn ndim(&self) -> usize {
        self.shape.len()
    }
262
    /// Returns the total number of elements (product of all dimensions;
    /// 1 for scalar tensors).
    #[must_use]
    pub fn numel(&self) -> usize {
        numel(&self.shape)
    }
268
    /// Returns true if the tensor is empty (has zero elements, i.e. some
    /// dimension has size 0).
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.numel() == 0
    }
274
275    /// Returns the size of a specific dimension.
276    ///
277    /// # Arguments
278    /// * `dim` - Dimension index (supports negative indexing)
279    pub fn size(&self, dim: i64) -> Result<usize> {
280        let idx = normalize_dim(dim, self.ndim())?;
281        Ok(self.shape[idx])
282    }
283
    /// Returns the device this tensor is on (delegates to the storage).
    #[must_use]
    pub fn device(&self) -> Device {
        self.storage.device()
    }
289
    /// Returns true if the tensor is contiguous in memory
    /// (strides match the canonical row-major layout for its shape).
    #[must_use]
    pub fn is_contiguous(&self) -> bool {
        is_contiguous(&self.shape, &self.strides)
    }
295
    /// Returns true if this tensor is a scalar (0-dimensional, empty shape).
    #[must_use]
    pub fn is_scalar(&self) -> bool {
        self.shape.is_empty()
    }
301
302    // =========================================================================
303    // Data Access
304    // =========================================================================
305
306    /// Returns the element at the given indices.
307    ///
308    /// # Arguments
309    /// * `indices` - Multi-dimensional indices
310    pub fn get(&self, indices: &[usize]) -> Result<T> {
311        if indices.len() != self.ndim() {
312            return Err(Error::invalid_operation(format!(
313                "Expected {} indices, got {}",
314                self.ndim(),
315                indices.len()
316            )));
317        }
318
319        for (&idx, &dim) in indices.iter().zip(self.shape.iter()) {
320            if idx >= dim {
321                return Err(Error::IndexOutOfBounds {
322                    index: idx,
323                    size: dim,
324                });
325            }
326        }
327
328        let offset = self.offset + linear_index(indices, &self.strides);
329        Ok(self.storage.as_slice()[offset])
330    }
331
332    /// Sets the element at the given indices.
333    ///
334    /// # Arguments
335    /// * `indices` - Multi-dimensional indices
336    /// * `value` - Value to set
337    pub fn set(&self, indices: &[usize], value: T) -> Result<()> {
338        if indices.len() != self.ndim() {
339            return Err(Error::invalid_operation(format!(
340                "Expected {} indices, got {}",
341                self.ndim(),
342                indices.len()
343            )));
344        }
345
346        for (&idx, &dim) in indices.iter().zip(self.shape.iter()) {
347            if idx >= dim {
348                return Err(Error::IndexOutOfBounds {
349                    index: idx,
350                    size: dim,
351                });
352            }
353        }
354
355        let offset = self.offset + linear_index(indices, &self.strides);
356        self.storage.as_slice_mut()[offset] = value;
357        Ok(())
358    }
359
360    /// Returns the scalar value for a 0-dimensional tensor.
361    pub fn item(&self) -> Result<T> {
362        if self.numel() != 1 {
363            return Err(Error::invalid_operation(
364                "item() only works on single-element tensors",
365            ));
366        }
367
368        // Use to_vec() which handles both CPU and GPU tensors safely
369        let data = self.to_vec();
370        if data.is_empty() {
371            Err(Error::invalid_operation("item() on empty tensor"))
372        } else {
373            Ok(data[0])
374        }
375    }
376
    /// Returns the data as a contiguous vector.
    ///
    /// Always produces an owned `Vec` in row-major order: a direct slice
    /// copy when the tensor is contiguous, or a stride-aware gather
    /// otherwise. For GPU tensors (f32 only), performs a D2H copy.
    #[must_use]
    pub fn to_vec(&self) -> Vec<T> {
        // GPU path: GPU storage is always f32
        #[cfg(feature = "cuda")]
        if self.storage.is_gpu() {
            assert!(is_f32::<T>(), "GPU tensors are only supported for f32");
            let self_f32 = unsafe { gpu_ref(self) };
            let f32_vec = self_f32.to_vec_gpu();
            // SAFETY: T is f32 (asserted above), so Vec<f32> and Vec<T> are the
            // same type; rebuild the Vec from raw parts and suppress the drop
            // of the original so the buffer is owned exactly once.
            unsafe {
                let mut v = std::mem::ManuallyDrop::new(f32_vec);
                return Vec::from_raw_parts(v.as_mut_ptr() as *mut T, v.len(), v.capacity());
            }
        }

        if self.is_contiguous() {
            // Fast path: one slice copy starting at the view offset.
            let storage = self.storage.as_slice();
            storage[self.offset..self.offset + self.numel()].to_vec()
        } else {
            // Slow path: gather element-by-element through the strides.
            let mut result = Vec::with_capacity(self.numel());
            self.copy_data_to(&mut result);
            result
        }
    }
405
406    /// Copies data to a slice, handling non-contiguous layouts.
407    fn copy_data_to(&self, dst: &mut Vec<T>) {
408        dst.clear();
409        let storage = self.storage.as_slice();
410
411        // Iterate through all indices
412        let total = self.numel();
413        for i in 0..total {
414            let indices = crate::shape::unravel_index(i, &self.shape);
415            let offset = self.offset + linear_index(&indices, &self.strides);
416            dst.push(storage[offset]);
417        }
418    }
419
420    // =========================================================================
421    // Shape Operations
422    // =========================================================================
423
424    /// Returns a new tensor with the specified shape.
425    ///
426    /// The total number of elements must remain the same.
427    /// Supports -1 in one dimension to infer the size.
428    ///
429    /// # Arguments
430    /// * `new_shape` - Target shape
431    pub fn reshape(&self, new_shape: &[isize]) -> Result<Self> {
432        let shape = reshape(&self.shape, new_shape)?;
433
434        if self.is_contiguous() {
435            // Can just change shape without copying
436            Ok(Self {
437                storage: self.storage.clone(),
438                strides: contiguous_strides(&shape),
439                shape,
440                offset: self.offset,
441            })
442        } else {
443            // Need to make contiguous first
444            let contig = self.contiguous();
445            Ok(Self {
446                storage: contig.storage,
447                strides: contiguous_strides(&shape),
448                shape,
449                offset: 0,
450            })
451        }
452    }
453
    /// Returns a new tensor with a flattened 1-D shape.
    ///
    /// # Panics
    /// Never in practice: reshaping to `[-1]` is valid for any element count.
    #[must_use]
    pub fn flatten(&self) -> Self {
        self.reshape(&[-1]).expect("Flatten should never fail")
    }
459
    /// Returns a new tensor with dimensions of size 1 removed.
    ///
    /// The result is a view sharing this tensor's storage.
    ///
    /// # Arguments
    /// * `dim` - Optional specific dimension to squeeze; `None` removes all
    ///   size-1 dimensions.
    pub fn squeeze(&self, dim: Option<i64>) -> Result<Self> {
        // Normalize a possibly-negative dimension index first.
        let dim = match dim {
            Some(d) => Some(normalize_dim(d, self.ndim())?),
            None => None,
        };

        let new_shape = squeeze(&self.shape, dim);
        // Strides must be filtered in lockstep with the shape so surviving
        // dimensions keep their original strides.
        let new_strides: Strides = match dim {
            Some(d) => {
                let mut s = self.strides.clone();
                // Only drop the stride if that dimension really has size 1;
                // otherwise squeeze() left the shape unchanged.
                if d < self.shape.len() && self.shape[d] == 1 {
                    s.remove(d);
                }
                s
            }
            None => self
                .shape
                .iter()
                .zip(self.strides.iter())
                .filter(|(dim, _)| **dim != 1)
                .map(|(_, stride)| *stride)
                .collect(),
        };

        Ok(Self {
            storage: self.storage.clone(),
            shape: new_shape,
            strides: new_strides,
            offset: self.offset,
        })
    }
495
    /// Returns a new tensor with a dimension of size 1 inserted.
    ///
    /// The result is a view sharing this tensor's storage.
    ///
    /// # Arguments
    /// * `dim` - Position to insert the new dimension (negative values count
    ///   from the end of the *new* shape)
    pub fn unsqueeze(&self, dim: i64) -> Result<Self> {
        // `+ 1` because the insertion position ranges over ndim + 1 slots.
        // NOTE(review): a dim below -(ndim + 1) wraps around in the `as usize`
        // cast; presumably `unsqueeze()` rejects the resulting index — confirm.
        let normalized = if dim < 0 {
            (dim + self.ndim() as i64 + 1) as usize
        } else {
            dim as usize
        };

        let new_shape = unsqueeze(&self.shape, normalized)?;
        let mut new_strides = Strides::with_capacity(new_shape.len());

        // Rebuild strides around the insertion point: positions before it map
        // 1:1, positions after it shift by one.
        for (i, _) in new_shape.iter().enumerate() {
            if i < normalized {
                new_strides.push(self.strides.get(i).copied().unwrap_or(1));
            } else if i == normalized {
                // Stride for new dimension (doesn't matter since size is 1)
                new_strides.push(1);
            } else {
                new_strides.push(self.strides[i - 1]);
            }
        }

        Ok(Self {
            storage: self.storage.clone(),
            shape: new_shape,
            strides: new_strides,
            offset: self.offset,
        })
    }
528
529    /// Transposes two dimensions.
530    ///
531    /// # Arguments
532    /// * `dim0` - First dimension
533    /// * `dim1` - Second dimension
534    pub fn transpose(&self, dim0: i64, dim1: i64) -> Result<Self> {
535        let d0 = normalize_dim(dim0, self.ndim())?;
536        let d1 = normalize_dim(dim1, self.ndim())?;
537
538        let new_shape = transpose_shape(&self.shape, d0, d1)?;
539        let new_strides = transpose_strides(&self.strides, d0, d1);
540
541        Ok(Self {
542            storage: self.storage.clone(),
543            shape: new_shape,
544            strides: new_strides,
545            offset: self.offset,
546        })
547    }
548
549    /// Returns the transpose of a 2D tensor.
550    pub fn t(&self) -> Result<Self> {
551        if self.ndim() != 2 {
552            return Err(Error::invalid_operation("t() only works on 2D tensors"));
553        }
554        self.transpose(0, 1)
555    }
556
557    /// Returns a permuted tensor with dimensions reordered.
558    ///
559    /// # Arguments
560    /// * `dims` - New order of dimensions
561    pub fn permute(&self, dims: &[usize]) -> Result<Self> {
562        if dims.len() != self.ndim() {
563            return Err(Error::invalid_operation(format!(
564                "Expected {} dimensions, got {}",
565                self.ndim(),
566                dims.len()
567            )));
568        }
569
570        // Check that dims is a permutation
571        let mut seen = vec![false; self.ndim()];
572        for &d in dims {
573            if d >= self.ndim() {
574                return Err(Error::InvalidDimension {
575                    index: d as i64,
576                    ndim: self.ndim(),
577                });
578            }
579            if seen[d] {
580                return Err(Error::invalid_operation("Duplicate dimension in permute"));
581            }
582            seen[d] = true;
583        }
584
585        let new_shape: Shape = dims.iter().map(|&d| self.shape[d]).collect();
586        let new_strides: Strides = dims.iter().map(|&d| self.strides[d]).collect();
587
588        Ok(Self {
589            storage: self.storage.clone(),
590            shape: new_shape,
591            strides: new_strides,
592            offset: self.offset,
593        })
594    }
595
    /// Returns a contiguous copy of the tensor.
    ///
    /// Cheap clone when already contiguous with no view offset; otherwise
    /// the data is materialized into fresh storage.
    #[must_use]
    pub fn contiguous(&self) -> Self {
        if self.is_contiguous() && self.offset == 0 {
            return self.clone();
        }

        // GPU path: delegate to the f32-only on-device implementation.
        #[cfg(feature = "cuda")]
        if self.storage.is_gpu() {
            assert!(is_f32::<T>(), "GPU tensors are only supported for f32");
            let self_f32 = unsafe { gpu_ref(self) };
            let result = self_f32.contiguous_gpu();
            return unsafe { gpu_into(result) };
        }

        // CPU path: gather into a fresh row-major buffer.
        let data = self.to_vec();
        Self::from_vec(data, &self.shape).expect("Contiguous should never fail")
    }
614
615    // =========================================================================
616    // Functional Map Operations (zero-copy for CPU tensors)
617    // =========================================================================
618
    /// Apply a function element-wise, producing a new tensor with the same shape.
    ///
    /// Reads the elements out as one contiguous vector (`to_vec`), applies
    /// `f` to each, and builds a new CPU tensor from the result.
    #[must_use]
    pub fn map<F: Fn(T) -> T>(&self, f: F) -> Self {
        let data = self.to_vec(); // contiguous read
        let result: Vec<T> = data.into_iter().map(f).collect();
        Self::from_vec(result, &self.shape).unwrap()
    }
629
    /// Apply a binary function element-wise with another tensor of the same shape.
    ///
    /// Convenience wrapper over the manual `a.to_vec()` + `b.to_vec()` + zip +
    /// `from_vec()` pattern; allocates one output vector (plus the contiguous
    /// reads of both inputs).
    ///
    /// The element-count match is only verified by a `debug_assert`; in
    /// release builds a shorter input silently truncates the zip.
    #[must_use]
    pub fn zip_map<F: Fn(T, T) -> T>(&self, other: &Self, f: F) -> Self {
        let a = self.to_vec();
        let b = other.to_vec();
        debug_assert_eq!(
            a.len(),
            b.len(),
            "zip_map requires same number of elements: {} vs {}",
            a.len(),
            b.len()
        );
        let result: Vec<T> = a.into_iter().zip(b).map(|(x, y)| f(x, y)).collect();
        Self::from_vec(result, &self.shape).unwrap()
    }
649
650    /// Apply a ternary function element-wise with two other tensors.
651    #[must_use]
652    pub fn zip_map3<F: Fn(T, T, T) -> T>(&self, b: &Self, c: &Self, f: F) -> Self {
653        let a_data = self.to_vec();
654        let b_data = b.to_vec();
655        let c_data = c.to_vec();
656        debug_assert_eq!(a_data.len(), b_data.len());
657        debug_assert_eq!(a_data.len(), c_data.len());
658        let result: Vec<T> = a_data
659            .into_iter()
660            .zip(b_data)
661            .zip(c_data)
662            .map(|((a, b), c)| f(a, b, c))
663            .collect();
664        Self::from_vec(result, &self.shape).unwrap()
665    }
666
667    // =========================================================================
668    // Device Operations
669    // =========================================================================
670
    /// Transfers the tensor to a different device.
    ///
    /// Returns a cheap clone when the tensor is already on `device`.
    ///
    /// # Arguments
    /// * `device` - Target device
    pub fn to_device(&self, device: Device) -> Result<Self> {
        if self.device() == device {
            return Ok(self.clone());
        }

        // GPU involvement on either side goes through the f32-only path.
        #[cfg(feature = "cuda")]
        if self.storage.is_gpu() || device.is_gpu() {
            assert!(is_f32::<T>(), "GPU tensors are only supported for f32");
            let self_f32 = unsafe { gpu_ref(self) };
            let result = self_f32.to_device_f32(device)?;
            return Ok(unsafe { gpu_into(result) });
        }

        // CPU-to-CPU-like path: materialize contiguously, then move storage.
        let contig = self.contiguous();
        let new_storage = contig.storage.to_device(device)?;

        Ok(Self {
            storage: new_storage,
            shape: self.shape.clone(),
            strides: self.strides.clone(),
            offset: 0,
        })
    }
698
    /// Transfers to CPU.
    ///
    /// Equivalent to `to_device(Device::Cpu)`; a cheap clone if already on CPU.
    pub fn cpu(&self) -> Result<Self> {
        self.to_device(Device::Cpu)
    }
703
704    // =========================================================================
705    // Deep Copy
706    // =========================================================================
707
    /// Creates a deep copy of this tensor with its own storage.
    ///
    /// The data is first materialized on the host, then moved back to the
    /// source device when that device is a GPU.
    #[must_use]
    pub fn clone_deep(&self) -> Self {
        let data = self.to_vec();
        let cpu = Self::from_vec(data, &self.shape).expect("Deep clone should never fail");
        #[cfg(feature = "cuda")]
        if self.device().is_gpu() {
            return cpu.to_device(self.device()).unwrap();
        }
        cpu
    }
719}
720
721// =============================================================================
722// Numeric Operations
723// =============================================================================
724
725impl<T: Numeric> Tensor<T> {
    /// Fills the tensor with a value.
    ///
    /// Takes `&self` but writes through the shared, reference-counted
    /// storage, so other views of the same storage observe the change.
    ///
    /// # Panics
    /// Panics on GPU tensors. Use `Tensor::from_vec(vec![value; n], shape)`
    /// followed by `.to_device()` instead.
    pub fn fill_(&self, value: T) {
        assert!(
            self.storage.is_cpu(),
            "fill_() not supported on GPU tensors — create a new tensor and transfer instead"
        );
        let mut data = self.storage.as_slice_mut();
        CpuBackend::fill(&mut data, value);
    }
739
    /// Fills the tensor with zeros, in place.
    ///
    /// # Panics
    /// Panics on GPU tensors (see [`Tensor::fill_`]).
    pub fn zero_(&self) {
        self.fill_(T::zero());
    }
744
745    // =========================================================================
746    // Reduction Operations
747    // =========================================================================
748
    /// Returns the sum of all elements as a scalar tensor.
    ///
    /// On GPU, uses native CUDA reduction kernels (no CPU round-trip).
    #[must_use]
    pub fn sum(&self) -> Self {
        #[cfg(feature = "cuda")]
        if self.device().is_gpu() {
            assert!(is_f32::<T>(), "GPU tensors are only supported for f32");
            let self_f32 = unsafe { gpu_ref(self) };
            let mut t = self_f32.clone();
            // Collapse one leading dimension at a time until rank <= 1.
            while t.ndim() > 1 {
                t = t.sum_dim_cuda(0);
            }
            // A remaining 1-D tensor with several elements needs one final reduce.
            if t.numel() > 1 {
                t = t.sum_dim_cuda(0);
            }
            return unsafe { gpu_into(t) };
        }

        let data = self.to_vec();
        let result = CpuBackend::sum(&data);
        Self::scalar(result)
    }
772
    /// Returns the product of all elements.
    ///
    /// GPU: D2H round-trip (no CUDA prod reduction kernel yet); the scalar
    /// result is transferred back to the source device.
    #[must_use]
    pub fn prod(&self) -> Self {
        let data = self.to_vec();
        let result = CpuBackend::prod(&data);
        let s = Self::scalar(result);
        #[cfg(feature = "cuda")]
        if self.device().is_gpu() {
            return s
                .to_device(self.device())
                .expect("prod: device transfer failed");
        }
        s
    }
789
    /// Returns the maximum element as a scalar tensor.
    ///
    /// GPU: D2H round-trip (no CUDA max reduction kernel yet); the scalar
    /// result is transferred back to the source device.
    ///
    /// # Errors
    /// Fails on an empty tensor.
    pub fn max(&self) -> Result<Self> {
        if self.is_empty() {
            return Err(Error::EmptyTensor);
        }
        let data = self.to_vec();
        let result = CpuBackend::max(&data).expect("max on non-empty tensor");
        let s = Self::scalar(result);
        #[cfg(feature = "cuda")]
        if self.device().is_gpu() {
            return Ok(s
                .to_device(self.device())
                .expect("max: device transfer failed"));
        }
        Ok(s)
    }
808
    /// Returns the minimum element as a scalar tensor.
    ///
    /// GPU: D2H round-trip (no CUDA min reduction kernel yet); the scalar
    /// result is transferred back to the source device.
    ///
    /// # Errors
    /// Fails on an empty tensor.
    pub fn min(&self) -> Result<Self> {
        if self.is_empty() {
            return Err(Error::EmptyTensor);
        }
        let data = self.to_vec();
        let result = CpuBackend::min(&data).expect("min on non-empty tensor");
        let s = Self::scalar(result);
        #[cfg(feature = "cuda")]
        if self.device().is_gpu() {
            return Ok(s
                .to_device(self.device())
                .expect("min: device transfer failed"));
        }
        Ok(s)
    }
827
828    /// Returns the index of the maximum element.
829    pub fn argmax(&self) -> Result<usize> {
830        if self.is_empty() {
831            return Err(Error::EmptyTensor);
832        }
833        let data = self.to_vec();
834        Ok(CpuBackend::argmax(&data).unwrap())
835    }
836
837    /// Returns the index of the minimum element.
838    pub fn argmin(&self) -> Result<usize> {
839        if self.is_empty() {
840            return Err(Error::EmptyTensor);
841        }
842        let data = self.to_vec();
843        Ok(CpuBackend::argmin(&data).unwrap())
844    }
845
    /// Concatenates tensors along a dimension.
    ///
    /// All tensors must have the same shape except along the cat dimension.
    /// The copy itself runs on the CPU; when `tensors[0]` lives on the GPU
    /// the result is transferred back to that device.
    ///
    /// NOTE(review): device mismatches between inputs are not checked — the
    /// output device follows `tensors[0]` only; confirm callers guarantee
    /// homogeneous devices.
    pub fn cat(tensors: &[&Self], dim: usize) -> Result<Self> {
        if tensors.is_empty() {
            return Err(Error::invalid_operation("cat requires at least one tensor"));
        }
        let ndim = tensors[0].ndim();
        if dim >= ndim {
            return Err(Error::invalid_operation("cat dimension out of range"));
        }

        // Validate rank and non-cat dimension sizes against the first tensor.
        for t in &tensors[1..] {
            if t.ndim() != ndim {
                return Err(Error::invalid_operation(
                    "cat: all tensors must have same ndim",
                ));
            }
            for d in 0..ndim {
                if d != dim && t.shape[d] != tensors[0].shape[d] {
                    return Err(Error::invalid_operation(
                        "cat: shapes must match on non-cat dims",
                    ));
                }
            }
        }

        // Output shape: cat dimension is the sum of the inputs' sizes there.
        let total_dim_size: usize = tensors.iter().map(|t| t.shape[dim]).sum();
        let mut out_shape: Vec<usize> = tensors[0].shape.to_vec();
        out_shape[dim] = total_dim_size;

        // Treat the layout as (outer, cat-dim, inner) for block copies.
        let outer_size: usize = out_shape[..dim].iter().product();
        let inner_size: usize = out_shape[dim + 1..].iter().product();
        let total_numel: usize = out_shape.iter().product();
        let mut result = vec![T::zero(); total_numel];

        // Copy each input's rows into its slot along the cat dimension.
        let mut dim_offset = 0;
        for t in tensors {
            let t_data = t.contiguous().to_vec();
            let t_dim_size = t.shape[dim];
            for outer in 0..outer_size {
                for d in 0..t_dim_size {
                    let src_base = outer * t_dim_size * inner_size + d * inner_size;
                    let dst_base =
                        outer * total_dim_size * inner_size + (dim_offset + d) * inner_size;
                    result[dst_base..dst_base + inner_size]
                        .copy_from_slice(&t_data[src_base..src_base + inner_size]);
                }
            }
            dim_offset += t_dim_size;
        }

        let out = Self::from_vec(result, &out_shape)?;
        #[cfg(feature = "cuda")]
        if tensors[0].device().is_gpu() {
            return Ok(out.to_device(tensors[0].device()).unwrap());
        }
        Ok(out)
    }
905}
906
907// =============================================================================
908// Float Operations
909// =============================================================================
910
911impl<T: Float> Tensor<T> {
912    /// Returns the mean of all elements.
913    /// Returns the mean of all elements.
914    ///
915    /// On GPU, uses native CUDA sum reduction then divides by numel.
916    pub fn mean(&self) -> Result<Self> {
917        if self.is_empty() {
918            return Err(Error::EmptyTensor);
919        }
920        #[cfg(feature = "cuda")]
921        if self.device().is_gpu() {
922            let s = self.sum(); // uses CUDA sum_dim chain
923            let n = self.numel() as f32;
924            // mul_scalar stays on GPU
925            return Ok(s.mul_scalar(T::from(1.0 / n as f64).unwrap_or(T::zero())));
926        }
927
928        let data = self.to_vec();
929        let result = CpuBackend::mean(&data).expect("mean on non-empty tensor");
930        Ok(Self::scalar(result))
931    }
932
933    // =========================================================================
934    // Activation Functions
935    // =========================================================================
936
937    /// Applies `ReLU` activation: max(0, x).
938    #[must_use]
939    pub fn relu(&self) -> Self {
940        #[cfg(feature = "cuda")]
941        if self.device().is_gpu() {
942            assert!(is_f32::<T>(), "GPU tensors are only supported for f32");
943            return unsafe { gpu_into(gpu_ref(self).relu_cuda()) };
944        }
945        let data = self.to_vec();
946        let mut result = vec![T::zero(); data.len()];
947        CpuBackend::relu(&mut result, &data);
948        Self::from_vec(result, &self.shape).unwrap()
949    }
950
951    /// Applies sigmoid activation: 1 / (1 + exp(-x)).
952    #[must_use]
953    pub fn sigmoid(&self) -> Self {
954        #[cfg(feature = "cuda")]
955        if self.device().is_gpu() {
956            assert!(is_f32::<T>(), "GPU tensors are only supported for f32");
957            return unsafe { gpu_into(gpu_ref(self).sigmoid_cuda()) };
958        }
959        let data = self.to_vec();
960        let mut result = vec![T::zero(); data.len()];
961        CpuBackend::sigmoid(&mut result, &data);
962        Self::from_vec(result, &self.shape).unwrap()
963    }
964
965    /// Applies tanh activation.
966    #[must_use]
967    pub fn tanh(&self) -> Self {
968        #[cfg(feature = "cuda")]
969        if self.device().is_gpu() {
970            assert!(is_f32::<T>(), "GPU tensors are only supported for f32");
971            return unsafe { gpu_into(gpu_ref(self).tanh_cuda()) };
972        }
973        let data = self.to_vec();
974        let mut result = vec![T::zero(); data.len()];
975        CpuBackend::tanh(&mut result, &data);
976        Self::from_vec(result, &self.shape).unwrap()
977    }
978
979    /// Applies exponential function.
980    #[must_use]
981    pub fn exp(&self) -> Self {
982        #[cfg(feature = "cuda")]
983        if self.device().is_gpu() {
984            assert!(is_f32::<T>(), "GPU tensors are only supported for f32");
985            return unsafe { gpu_into(gpu_ref(self).exp_cuda()) };
986        }
987        let data = self.to_vec();
988        let mut result = vec![T::zero(); data.len()];
989        CpuBackend::exp(&mut result, &data);
990        Self::from_vec(result, &self.shape).unwrap()
991    }
992
993    /// Applies natural logarithm.
994    #[must_use]
995    pub fn ln(&self) -> Self {
996        #[cfg(feature = "cuda")]
997        if self.device().is_gpu() {
998            assert!(is_f32::<T>(), "GPU tensors are only supported for f32");
999            return unsafe { gpu_into(gpu_ref(self).ln_cuda()) };
1000        }
1001        let data = self.to_vec();
1002        let mut result = vec![T::zero(); data.len()];
1003        CpuBackend::ln(&mut result, &data);
1004        Self::from_vec(result, &self.shape).unwrap()
1005    }
1006
1007    /// Applies square root.
1008    #[must_use]
1009    pub fn sqrt(&self) -> Self {
1010        #[cfg(feature = "cuda")]
1011        if self.device().is_gpu() {
1012            assert!(is_f32::<T>(), "GPU tensors are only supported for f32");
1013            return unsafe { gpu_into(gpu_ref(self).sqrt_cuda()) };
1014        }
1015        let data = self.to_vec();
1016        let mut result = vec![T::zero(); data.len()];
1017        CpuBackend::sqrt(&mut result, &data);
1018        Self::from_vec(result, &self.shape).unwrap()
1019    }
1020
1021    /// Computes element-wise power.
1022    #[must_use]
1023    pub fn pow(&self, exp: T) -> Self {
1024        #[cfg(feature = "cuda")]
1025        if self.device().is_gpu() {
1026            assert!(is_f32::<T>(), "GPU tensors are only supported for f32");
1027            let exp_f32: f32 = unsafe { *(&exp as *const T as *const f32) };
1028            return unsafe { gpu_into(gpu_ref(self).pow_cuda(exp_f32)) };
1029        }
1030        let data = self.to_vec();
1031        let result: Vec<T> = data.iter().map(|&x| x.pow_value(exp)).collect();
1032        Self::from_vec(result, &self.shape).unwrap()
1033    }
1034
1035    /// GELU activation function (Gaussian Error Linear Unit).
1036    #[must_use]
1037    pub fn gelu(&self) -> Self {
1038        #[cfg(feature = "cuda")]
1039        if self.device().is_gpu() {
1040            assert!(is_f32::<T>(), "GPU tensors are only supported for f32");
1041            return unsafe { gpu_into(gpu_ref(self).gelu_cuda()) };
1042        }
1043        crate::ops::gelu(self)
1044    }
1045
1046    /// SiLU/Swish activation function.
1047    #[must_use]
1048    pub fn silu(&self) -> Self {
1049        #[cfg(feature = "cuda")]
1050        if self.device().is_gpu() {
1051            assert!(is_f32::<T>(), "GPU tensors are only supported for f32");
1052            return unsafe { gpu_into(gpu_ref(self).silu_cuda()) };
1053        }
1054        crate::ops::silu(self)
1055    }
1056
1057    /// Softmax along specified dimension.
1058    #[must_use]
1059    pub fn softmax(&self, dim: i32) -> Self {
1060        #[cfg(feature = "cuda")]
1061        if self.device().is_gpu() {
1062            assert!(is_f32::<T>(), "GPU tensors are only supported for f32");
1063            let self_f32 = unsafe { gpu_ref(self) };
1064            return unsafe { gpu_into(self_f32.softmax_cuda(dim).expect("CUDA softmax failed")) };
1065        }
1066        crate::ops::softmax(self, dim as i64).unwrap_or_else(|_| self.clone())
1067    }
1068
1069    /// Log softmax along specified dimension.
1070    #[must_use]
1071    pub fn log_softmax(&self, dim: i32) -> Self {
1072        let softmax_result = self.softmax(dim);
1073        softmax_result.ln()
1074    }
1075
1076    /// Mean along a dimension.
1077    #[must_use]
1078    pub fn mean_dim(&self, dim: i32, keepdim: bool) -> Self {
1079        let ndim = self.ndim();
1080        let dim = if dim < 0 {
1081            (ndim as i32 + dim) as usize
1082        } else {
1083            dim as usize
1084        };
1085
1086        if dim >= ndim {
1087            return self.clone();
1088        }
1089
1090        // GPU fast path: sum_dim then divide by dim_size (all on GPU)
1091        #[cfg(feature = "cuda")]
1092        if self.device().is_gpu() {
1093            assert!(is_f32::<T>(), "GPU tensors are only supported for f32");
1094            let self_f32 = unsafe { gpu_ref(self) };
1095            let summed = if keepdim {
1096                self_f32.sum_dim_keepdim_cuda(dim)
1097            } else {
1098                self_f32.sum_dim_cuda(dim)
1099            };
1100            let dim_size = self.shape[dim];
1101            let result = summed.mul_scalar_cuda(1.0 / dim_size as f32);
1102            return unsafe { gpu_into(result) };
1103        }
1104
1105        let dim_size = self.shape[dim];
1106        let data = self.to_vec();
1107        let mut new_shape = self.shape.clone();
1108
1109        if keepdim {
1110            new_shape[dim] = 1;
1111        } else {
1112            new_shape.remove(dim);
1113        }
1114
1115        if new_shape.is_empty() {
1116            new_shape = smallvec::smallvec![1];
1117        }
1118
1119        let new_numel: usize = new_shape.iter().product();
1120        let mut result = vec![T::zero(); new_numel];
1121
1122        let outer_size: usize = self.shape[..dim].iter().product();
1123        let inner_size: usize = self.shape[dim + 1..].iter().product();
1124
1125        for outer in 0..outer_size {
1126            for inner in 0..inner_size {
1127                let mut sum = T::zero();
1128                for d in 0..dim_size {
1129                    let idx = outer * dim_size * inner_size + d * inner_size + inner;
1130                    sum = sum + data[idx];
1131                }
1132                let mean = sum / NumCast::from(dim_size).unwrap();
1133                let result_idx = outer * inner_size + inner;
1134                result[result_idx] = mean;
1135            }
1136        }
1137
1138        Self::from_vec(result, &new_shape).unwrap()
1139    }
1140
1141    /// Sum along a dimension.
1142    #[must_use]
1143    pub fn sum_dim(&self, dim: i32, keepdim: bool) -> Self {
1144        let ndim = self.ndim();
1145        let dim = if dim < 0 {
1146            (ndim as i32 + dim) as usize
1147        } else {
1148            dim as usize
1149        };
1150
1151        if dim >= ndim {
1152            return self.clone();
1153        }
1154
1155        // GPU fast path: use CUDA sum_dim kernel (no CPU copies)
1156        #[cfg(feature = "cuda")]
1157        if self.device().is_gpu() {
1158            assert!(is_f32::<T>(), "GPU tensors are only supported for f32");
1159            let self_f32 = unsafe { gpu_ref(self) };
1160            let result = if keepdim {
1161                self_f32.sum_dim_keepdim_cuda(dim)
1162            } else {
1163                self_f32.sum_dim_cuda(dim)
1164            };
1165            return unsafe { gpu_into(result) };
1166        }
1167
1168        let dim_size = self.shape[dim];
1169        let data = self.to_vec();
1170        let mut new_shape = self.shape.clone();
1171
1172        if keepdim {
1173            new_shape[dim] = 1;
1174        } else {
1175            new_shape.remove(dim);
1176        }
1177
1178        if new_shape.is_empty() {
1179            new_shape = smallvec::smallvec![1];
1180        }
1181
1182        let new_numel: usize = new_shape.iter().product();
1183        let mut result = vec![T::zero(); new_numel];
1184
1185        let outer_size: usize = self.shape[..dim].iter().product();
1186        let inner_size: usize = self.shape[dim + 1..].iter().product();
1187
1188        for outer in 0..outer_size {
1189            for inner in 0..inner_size {
1190                let mut sum = T::zero();
1191                for d in 0..dim_size {
1192                    let idx = outer * dim_size * inner_size + d * inner_size + inner;
1193                    sum = sum + data[idx];
1194                }
1195                let result_idx = outer * inner_size + inner;
1196                result[result_idx] = sum;
1197            }
1198        }
1199
1200        Self::from_vec(result, &new_shape).unwrap()
1201    }
1202
1203    /// Variance along a dimension.
1204    #[must_use]
1205    pub fn var_dim(&self, dim: i32, keepdim: bool) -> Self {
1206        // variance = E[x²] - E[x]²  (saves one full-size intermediate allocation)
1207        let mean = self.mean_dim(dim, true);
1208        let sq = self.mul(self).unwrap_or_else(|_| self.clone());
1209        let mean_sq = sq.mean_dim(dim, keepdim);
1210        let mean_keepdim = if keepdim {
1211            mean.clone()
1212        } else {
1213            self.mean_dim(dim, keepdim)
1214        };
1215        let mean_squared = mean_keepdim
1216            .mul(&mean_keepdim)
1217            .unwrap_or_else(|_| mean_keepdim.clone());
1218        mean_sq
1219            .sub(&mean_squared)
1220            .unwrap_or_else(|_| mean_sq.clone())
1221    }
1222
1223    /// Broadcasts tensor to a new shape.
1224    #[must_use]
1225    pub fn broadcast_to(&self, shape: &[usize]) -> Self {
1226        if self.shape.as_slice() == shape {
1227            return self.clone();
1228        }
1229
1230        #[cfg(feature = "cuda")]
1231        if self.device().is_gpu() {
1232            assert!(is_f32::<T>(), "GPU tensors are only supported for f32");
1233            let self_f32 = unsafe { gpu_ref(self) };
1234            return unsafe {
1235                gpu_into(
1236                    self_f32
1237                        .broadcast_to_cuda(shape)
1238                        .expect("CUDA broadcast_to failed"),
1239                )
1240            };
1241        }
1242
1243        let result_shape = broadcast_shape(&self.shape, shape).unwrap_or_else(|_| shape.into());
1244        let self_strides = broadcast_strides(&self.shape, &self.strides, &result_shape);
1245
1246        let total = numel(&result_shape);
1247        let mut result_data = vec![T::zero(); total];
1248        let self_data = self.storage.as_slice();
1249
1250        for i in 0..total {
1251            let indices = crate::shape::unravel_index(i, &result_shape);
1252            let self_idx = self.offset + linear_index(&indices, &self_strides);
1253            result_data[i] = self_data[self_idx];
1254        }
1255
1256        Self::from_vec(result_data, &result_shape).unwrap()
1257    }
1258
1259    /// Slices the tensor using ranges for each dimension.
1260    #[must_use]
1261    pub fn slice(&self, ranges: &[std::ops::Range<usize>]) -> Self {
1262        let mut new_shape = Vec::with_capacity(self.ndim());
1263        for (i, range) in ranges.iter().enumerate() {
1264            if i < self.ndim() {
1265                new_shape.push(range.end - range.start);
1266            }
1267        }
1268        // Keep remaining dimensions unchanged
1269        for i in ranges.len()..self.ndim() {
1270            new_shape.push(self.shape[i]);
1271        }
1272
1273        let new_numel: usize = new_shape.iter().product();
1274        let mut result_data = vec![T::zero(); new_numel];
1275        let self_data = self.to_vec();
1276
1277        // Copy data with proper indexing
1278        let mut result_idx = 0;
1279        Self::slice_recursive(
1280            &self_data,
1281            &self.shape,
1282            ranges,
1283            0,
1284            0,
1285            &mut result_data,
1286            &mut result_idx,
1287        );
1288
1289        let out = Self::from_vec(result_data, &new_shape).unwrap();
1290        #[cfg(feature = "cuda")]
1291        if self.device().is_gpu() {
1292            return out.to_device(self.device()).unwrap();
1293        }
1294        out
1295    }
1296
1297    fn slice_recursive(
1298        data: &[T],
1299        shape: &[usize],
1300        ranges: &[std::ops::Range<usize>],
1301        dim: usize,
1302        offset: usize,
1303        result: &mut [T],
1304        result_idx: &mut usize,
1305    ) {
1306        if dim == shape.len() {
1307            result[*result_idx] = data[offset];
1308            *result_idx += 1;
1309            return;
1310        }
1311
1312        let stride: usize = shape[dim + 1..].iter().product();
1313        let (start, end) = if dim < ranges.len() {
1314            (ranges[dim].start, ranges[dim].end)
1315        } else {
1316            (0, shape[dim])
1317        };
1318
1319        for i in start..end {
1320            Self::slice_recursive(
1321                data,
1322                shape,
1323                ranges,
1324                dim + 1,
1325                offset + i * stride,
1326                result,
1327                result_idx,
1328            );
1329        }
1330    }
1331}
1332
1333// =============================================================================
1334// Arithmetic Operator Implementations
1335// =============================================================================
1336
1337impl<T: Numeric> Tensor<T> {
1338    /// Element-wise addition with broadcasting.
1339    pub fn add(&self, other: &Self) -> Result<Self> {
1340        #[cfg(feature = "cuda")]
1341        {
1342            let self_gpu = self.device().is_gpu();
1343            let other_gpu = other.device().is_gpu();
1344            if self_gpu || other_gpu {
1345                assert!(is_f32::<T>(), "GPU tensors are only supported for f32");
1346                if self_gpu && other_gpu {
1347                    let (s, o) = unsafe { (gpu_ref(self), gpu_ref(other)) };
1348                    if self.shape == other.shape {
1349                        return Ok(unsafe { gpu_into(s.add_cuda(o)?) });
1350                    } else {
1351                        return Ok(unsafe { gpu_into(s.broadcast_add_cuda(o)?) });
1352                    }
1353                }
1354                // Mixed device — move to GPU, then operate
1355                let target_device = if self_gpu {
1356                    self.device()
1357                } else {
1358                    other.device()
1359                };
1360                let a_gpu = if self_gpu {
1361                    self.clone()
1362                } else {
1363                    self.to_device(target_device)?
1364                };
1365                let b_gpu = if other_gpu {
1366                    other.clone()
1367                } else {
1368                    other.to_device(target_device)?
1369                };
1370                return a_gpu.add(&b_gpu);
1371            }
1372        }
1373        // Fast path: same shape, both contiguous — no index arithmetic needed
1374        if self.shape == other.shape && self.is_contiguous() && other.is_contiguous() {
1375            let a = self.storage.as_slice();
1376            let b = other.storage.as_slice();
1377            let ao = self.offset;
1378            let bo = other.offset;
1379            let n = numel(&self.shape);
1380            let mut result_data = vec![T::zero(); n];
1381            for i in 0..n {
1382                result_data[i] = a[ao + i] + b[bo + i];
1383            }
1384            return Self::from_vec(result_data, &self.shape);
1385        }
1386
1387        let result_shape = broadcast_shape(&self.shape, &other.shape)?;
1388        let self_strides = broadcast_strides(&self.shape, &self.strides, &result_shape);
1389        let other_strides = broadcast_strides(&other.shape, &other.strides, &result_shape);
1390
1391        let total = numel(&result_shape);
1392        let mut result_data = vec![T::zero(); total];
1393
1394        let self_data = self.storage.as_slice();
1395        let other_data = other.storage.as_slice();
1396
1397        for i in 0..total {
1398            let indices = crate::shape::unravel_index(i, &result_shape);
1399            let self_idx = self.offset + linear_index(&indices, &self_strides);
1400            let other_idx = other.offset + linear_index(&indices, &other_strides);
1401            result_data[i] = self_data[self_idx] + other_data[other_idx];
1402        }
1403
1404        Self::from_vec(result_data, &result_shape)
1405    }
1406
1407    /// Element-wise subtraction with broadcasting.
1408    pub fn sub(&self, other: &Self) -> Result<Self> {
1409        #[cfg(feature = "cuda")]
1410        {
1411            let self_gpu = self.device().is_gpu();
1412            let other_gpu = other.device().is_gpu();
1413            if self_gpu || other_gpu {
1414                assert!(is_f32::<T>(), "GPU tensors are only supported for f32");
1415                if self_gpu && other_gpu {
1416                    let (s, o) = unsafe { (gpu_ref(self), gpu_ref(other)) };
1417                    if self.shape == other.shape {
1418                        return Ok(unsafe { gpu_into(s.sub_cuda(o)?) });
1419                    } else {
1420                        return Ok(unsafe { gpu_into(s.broadcast_sub_cuda(o)?) });
1421                    }
1422                }
1423                let target = if self_gpu {
1424                    self.device()
1425                } else {
1426                    other.device()
1427                };
1428                let a_gpu = if self_gpu {
1429                    self.clone()
1430                } else {
1431                    self.to_device(target)?
1432                };
1433                let b_gpu = if other_gpu {
1434                    other.clone()
1435                } else {
1436                    other.to_device(target)?
1437                };
1438                return a_gpu.sub(&b_gpu);
1439            }
1440        }
1441        // Fast path: same shape, contiguous
1442        if self.shape == other.shape && self.is_contiguous() && other.is_contiguous() {
1443            let a = self.storage.as_slice();
1444            let b = other.storage.as_slice();
1445            let (ao, bo) = (self.offset, other.offset);
1446            let n = numel(&self.shape);
1447            let mut r = vec![T::zero(); n];
1448            for i in 0..n {
1449                r[i] = a[ao + i] - b[bo + i];
1450            }
1451            return Self::from_vec(r, &self.shape);
1452        }
1453
1454        let result_shape = broadcast_shape(&self.shape, &other.shape)?;
1455        let self_strides = broadcast_strides(&self.shape, &self.strides, &result_shape);
1456        let other_strides = broadcast_strides(&other.shape, &other.strides, &result_shape);
1457
1458        let total = numel(&result_shape);
1459        let mut result_data = vec![T::zero(); total];
1460
1461        let self_data = self.storage.as_slice();
1462        let other_data = other.storage.as_slice();
1463
1464        for i in 0..total {
1465            let indices = crate::shape::unravel_index(i, &result_shape);
1466            let self_idx = self.offset + linear_index(&indices, &self_strides);
1467            let other_idx = other.offset + linear_index(&indices, &other_strides);
1468            result_data[i] = self_data[self_idx] - other_data[other_idx];
1469        }
1470
1471        Self::from_vec(result_data, &result_shape)
1472    }
1473
1474    /// Element-wise multiplication with broadcasting.
1475    pub fn mul(&self, other: &Self) -> Result<Self> {
1476        #[cfg(feature = "cuda")]
1477        {
1478            let self_gpu = self.device().is_gpu();
1479            let other_gpu = other.device().is_gpu();
1480            if self_gpu || other_gpu {
1481                assert!(is_f32::<T>(), "GPU tensors are only supported for f32");
1482                if self_gpu && other_gpu {
1483                    let (s, o) = unsafe { (gpu_ref(self), gpu_ref(other)) };
1484                    if self.shape == other.shape {
1485                        return Ok(unsafe { gpu_into(s.mul_cuda(o)?) });
1486                    } else {
1487                        return Ok(unsafe { gpu_into(s.broadcast_mul_cuda(o)?) });
1488                    }
1489                }
1490                let target = if self_gpu {
1491                    self.device()
1492                } else {
1493                    other.device()
1494                };
1495                let a_gpu = if self_gpu {
1496                    self.clone()
1497                } else {
1498                    self.to_device(target)?
1499                };
1500                let b_gpu = if other_gpu {
1501                    other.clone()
1502                } else {
1503                    other.to_device(target)?
1504                };
1505                return a_gpu.mul(&b_gpu);
1506            }
1507        }
1508        // Fast path: same shape, contiguous
1509        if self.shape == other.shape && self.is_contiguous() && other.is_contiguous() {
1510            let a = self.storage.as_slice();
1511            let b = other.storage.as_slice();
1512            let (ao, bo) = (self.offset, other.offset);
1513            let n = numel(&self.shape);
1514            let mut r = vec![T::zero(); n];
1515            for i in 0..n {
1516                r[i] = a[ao + i] * b[bo + i];
1517            }
1518            return Self::from_vec(r, &self.shape);
1519        }
1520
1521        let result_shape = broadcast_shape(&self.shape, &other.shape)?;
1522        let self_strides = broadcast_strides(&self.shape, &self.strides, &result_shape);
1523        let other_strides = broadcast_strides(&other.shape, &other.strides, &result_shape);
1524
1525        let total = numel(&result_shape);
1526        let mut result_data = vec![T::zero(); total];
1527
1528        let self_data = self.storage.as_slice();
1529        let other_data = other.storage.as_slice();
1530
1531        for i in 0..total {
1532            let indices = crate::shape::unravel_index(i, &result_shape);
1533            let self_idx = self.offset + linear_index(&indices, &self_strides);
1534            let other_idx = other.offset + linear_index(&indices, &other_strides);
1535            result_data[i] = self_data[self_idx] * other_data[other_idx];
1536        }
1537
1538        Self::from_vec(result_data, &result_shape)
1539    }
1540
1541    /// Element-wise division with broadcasting.
1542    pub fn div(&self, other: &Self) -> Result<Self> {
1543        #[cfg(feature = "cuda")]
1544        {
1545            let self_gpu = self.device().is_gpu();
1546            let other_gpu = other.device().is_gpu();
1547            if self_gpu || other_gpu {
1548                assert!(is_f32::<T>(), "GPU tensors are only supported for f32");
1549                if self_gpu && other_gpu {
1550                    let (s, o) = unsafe { (gpu_ref(self), gpu_ref(other)) };
1551                    if self.shape == other.shape {
1552                        return Ok(unsafe { gpu_into(s.div_cuda(o)?) });
1553                    } else {
1554                        return Ok(unsafe { gpu_into(s.broadcast_div_cuda(o)?) });
1555                    }
1556                }
1557                let target = if self_gpu {
1558                    self.device()
1559                } else {
1560                    other.device()
1561                };
1562                let a_gpu = if self_gpu {
1563                    self.clone()
1564                } else {
1565                    self.to_device(target)?
1566                };
1567                let b_gpu = if other_gpu {
1568                    other.clone()
1569                } else {
1570                    other.to_device(target)?
1571                };
1572                return a_gpu.div(&b_gpu);
1573            }
1574        }
1575        // Fast path: same shape, contiguous
1576        if self.shape == other.shape && self.is_contiguous() && other.is_contiguous() {
1577            let a = self.storage.as_slice();
1578            let b = other.storage.as_slice();
1579            let (ao, bo) = (self.offset, other.offset);
1580            let n = numel(&self.shape);
1581            let mut r = vec![T::zero(); n];
1582            for i in 0..n {
1583                r[i] = a[ao + i] / b[bo + i];
1584            }
1585            return Self::from_vec(r, &self.shape);
1586        }
1587
1588        let result_shape = broadcast_shape(&self.shape, &other.shape)?;
1589        let self_strides = broadcast_strides(&self.shape, &self.strides, &result_shape);
1590        let other_strides = broadcast_strides(&other.shape, &other.strides, &result_shape);
1591
1592        let total = numel(&result_shape);
1593        let mut result_data = vec![T::zero(); total];
1594
1595        let self_data = self.storage.as_slice();
1596        let other_data = other.storage.as_slice();
1597
1598        for i in 0..total {
1599            let indices = crate::shape::unravel_index(i, &result_shape);
1600            let self_idx = self.offset + linear_index(&indices, &self_strides);
1601            let other_idx = other.offset + linear_index(&indices, &other_strides);
1602            result_data[i] = self_data[self_idx] / other_data[other_idx];
1603        }
1604
1605        Self::from_vec(result_data, &result_shape)
1606    }
1607
1608    /// Scalar addition.
1609    #[must_use]
1610    pub fn add_scalar(&self, scalar: T) -> Self {
1611        #[cfg(feature = "cuda")]
1612        if self.device().is_gpu() {
1613            assert!(is_f32::<T>(), "GPU tensors are only supported for f32");
1614            let self_f32 = unsafe { gpu_ref(self) };
1615            let scalar_f32: f32 = unsafe { *(&scalar as *const T as *const f32) };
1616            return unsafe { gpu_into(self_f32.add_scalar_cuda(scalar_f32)) };
1617        }
1618        let data = self.to_vec();
1619        let mut result = vec![T::zero(); data.len()];
1620        CpuBackend::add_scalar(&mut result, &data, scalar);
1621        Self::from_vec(result, &self.shape).unwrap()
1622    }
1623
1624    /// Scalar multiplication.
1625    #[must_use]
1626    pub fn mul_scalar(&self, scalar: T) -> Self {
1627        #[cfg(feature = "cuda")]
1628        if self.device().is_gpu() {
1629            assert!(is_f32::<T>(), "GPU tensors are only supported for f32");
1630            let self_f32 = unsafe { gpu_ref(self) };
1631            let scalar_f32: f32 = unsafe { *(&scalar as *const T as *const f32) };
1632            return unsafe { gpu_into(self_f32.mul_scalar_cuda(scalar_f32)) };
1633        }
1634        let data = self.to_vec();
1635        let mut result = vec![T::zero(); data.len()];
1636        CpuBackend::mul_scalar(&mut result, &data, scalar);
1637        Self::from_vec(result, &self.shape).unwrap()
1638    }
1639
1640    /// Element-wise negation.
1641    #[must_use]
1642    pub fn neg(&self) -> Self {
1643        #[cfg(feature = "cuda")]
1644        if self.device().is_gpu() {
1645            assert!(is_f32::<T>(), "GPU tensors are only supported for f32");
1646            let self_f32 = unsafe { gpu_ref(self) };
1647            return unsafe { gpu_into(self_f32.neg_cuda()) };
1648        }
1649        let data = self.to_vec();
1650        let mut result = vec![T::zero(); data.len()];
1651        CpuBackend::neg(&mut result, &data);
1652        Self::from_vec(result, &self.shape).unwrap()
1653    }
1654
1655    /// Matrix multiplication with batching support.
1656    ///
1657    /// Supports:
1658    /// - 2D @ 2D: [m, k] @ [k, n] -> [m, n]
1659    /// - 3D @ 3D: [batch, m, k] @ [batch, k, n] -> [batch, m, n]
1660    /// - 4D @ 4D: [b1, b2, m, k] @ [b1, b2, k, n] -> [b1, b2, m, n]
1661    pub fn matmul(&self, other: &Self) -> Result<Self> {
1662        #[cfg(feature = "cuda")]
1663        if self.device().is_gpu() {
1664            assert!(is_f32::<T>(), "GPU tensors are only supported for f32");
1665            let (s, o) = unsafe { (gpu_ref(self), gpu_ref(other)) };
1666            return Ok(unsafe { gpu_into(s.matmul_cuda(o)?) });
1667        }
1668        if self.ndim() < 2 || other.ndim() < 2 {
1669            return Err(Error::invalid_operation(
1670                "matmul requires at least 2D tensors",
1671            ));
1672        }
1673
1674        let m = self.shape[self.ndim() - 2];
1675        let k1 = self.shape[self.ndim() - 1];
1676        let k2 = other.shape[other.ndim() - 2];
1677        let n = other.shape[other.ndim() - 1];
1678
1679        if k1 != k2 {
1680            return Err(Error::invalid_operation(format!(
1681                "matmul inner dimensions must match: {k1} vs {k2}"
1682            )));
1683        }
1684
1685        // For 2D matrices, do simple matmul
1686        if self.ndim() == 2 && other.ndim() == 2 {
1687            let a_data = self.contiguous().to_vec();
1688            let b_data = other.contiguous().to_vec();
1689
1690            // GPU-accelerated matmul for CPU tensors: only for very large matrices
1691            // where transfer overhead is negligible relative to compute.
1692            // For GPU-resident tensors, the dispatch at the top of matmul() handles it.
1693            #[cfg(feature = "cuda")]
1694            {
1695                let flops = m * n * k1;
1696                if std::any::TypeId::of::<T>() == std::any::TypeId::of::<f32>()
1697                    && flops >= 4_000_000
1698                {
1699                    debug_assert!(std::mem::size_of::<T>() == std::mem::size_of::<f32>());
1700                    // SAFETY: T is f32 (checked by TypeId above), same size and layout
1701                    let a_f32: &[f32] = unsafe { std::mem::transmute(a_data.as_slice()) };
1702                    let b_f32: &[f32] = unsafe { std::mem::transmute(b_data.as_slice()) };
1703                    if let Some(c_f32) = cuda_accel::cuda_matmul(a_f32, b_f32, m, n, k1) {
1704                        // SAFETY: T is f32, Vec<f32> → Vec<T> is a no-op transmute
1705                        let c_t: Vec<T> = unsafe {
1706                            let mut v = std::mem::ManuallyDrop::new(c_f32);
1707                            Vec::from_raw_parts(v.as_mut_ptr() as *mut T, v.len(), v.capacity())
1708                        };
1709                        return Self::from_vec(c_t, &[m, n]);
1710                    }
1711                }
1712            }
1713
1714            let mut c_data = vec![T::zero(); m * n];
1715            CpuBackend::matmul(&mut c_data, &a_data, &b_data, m, n, k1);
1716            return Self::from_vec(c_data, &[m, n]);
1717        }
1718
1719        // For batched matmul, compute batch size
1720        let batch_dims_self: Vec<usize> = self.shape[..self.ndim() - 2].to_vec();
1721        let batch_dims_other: Vec<usize> = other.shape[..other.ndim() - 2].to_vec();
1722
1723        // Broadcast batch dimensions (PyTorch parity)
1724        let broadcast_batch = if batch_dims_self == batch_dims_other {
1725            None
1726        } else {
1727            // Pad to same length
1728            let max_len = batch_dims_self.len().max(batch_dims_other.len());
1729            let pad_a = vec![1usize; max_len - batch_dims_self.len()];
1730            let pad_b = vec![1usize; max_len - batch_dims_other.len()];
1731            let a_dims: Vec<usize> = pad_a
1732                .iter()
1733                .chain(batch_dims_self.iter())
1734                .copied()
1735                .collect();
1736            let b_dims: Vec<usize> = pad_b
1737                .iter()
1738                .chain(batch_dims_other.iter())
1739                .copied()
1740                .collect();
1741
1742            let mut out_dims = Vec::with_capacity(max_len);
1743            for i in 0..max_len {
1744                if a_dims[i] == b_dims[i] {
1745                    out_dims.push(a_dims[i]);
1746                } else if a_dims[i] == 1 {
1747                    out_dims.push(b_dims[i]);
1748                } else if b_dims[i] == 1 {
1749                    out_dims.push(a_dims[i]);
1750                } else {
1751                    return Err(Error::invalid_operation(format!(
1752                        "matmul batch dimensions not broadcastable: {:?} vs {:?}",
1753                        batch_dims_self, batch_dims_other
1754                    )));
1755                }
1756            }
1757            Some((a_dims, b_dims, out_dims))
1758        };
1759
1760        let (batch_size, a_batch_idx, b_batch_idx) =
1761            if let Some((a_dims, b_dims, out_dims)) = &broadcast_batch {
1762                let bs: usize = out_dims.iter().product();
1763                // Build index mapping: for each output batch, which a and b batch to use
1764                let mut a_idx = Vec::with_capacity(bs);
1765                let mut b_idx = Vec::with_capacity(bs);
1766                for flat in 0..bs {
1767                    let mut remaining = flat;
1768                    let mut ai = 0usize;
1769                    let mut bi = 0usize;
1770                    let mut a_stride_acc = 1usize;
1771                    let mut b_stride_acc = 1usize;
1772                    for d in (0..out_dims.len()).rev() {
1773                        let out_d = out_dims[d];
1774                        let idx = remaining % out_d;
1775                        remaining /= out_d;
1776                        let a_d = a_dims[d];
1777                        let b_d = b_dims[d];
1778                        ai += (idx % a_d) * a_stride_acc;
1779                        bi += (idx % b_d) * b_stride_acc;
1780                        a_stride_acc *= a_d;
1781                        b_stride_acc *= b_d;
1782                    }
1783                    a_idx.push(ai);
1784                    b_idx.push(bi);
1785                }
1786                (bs, a_idx, b_idx)
1787            } else {
1788                let bs: usize = batch_dims_self.iter().product();
1789                let idx: Vec<usize> = (0..bs).collect();
1790                (bs, idx.clone(), idx)
1791            };
1792
1793        let a_stride = m * k1;
1794        let b_stride = k1 * n;
1795        let c_stride = m * n;
1796
1797        let a_data = self.contiguous().to_vec();
1798        let b_data = other.contiguous().to_vec();
1799        let mut c_data = vec![T::zero(); batch_size * m * n];
1800
1801        // Try GPU acceleration for f32 batched matmul (only for large enough matrices)
1802        #[cfg(feature = "cuda")]
1803        {
1804            let flops = m * n * k1;
1805            if std::any::TypeId::of::<T>() == std::any::TypeId::of::<f32>() && flops >= 4_000_000 {
1806                let a_f32: &[f32] = unsafe { std::mem::transmute(a_data.as_slice()) };
1807                let b_f32: &[f32] = unsafe { std::mem::transmute(b_data.as_slice()) };
1808                let mut gpu_ok = true;
1809                for batch in 0..batch_size {
1810                    let ai = a_batch_idx[batch];
1811                    let bi = b_batch_idx[batch];
1812                    let a_slice = &a_f32[ai * a_stride..(ai + 1) * a_stride];
1813                    let b_slice = &b_f32[bi * b_stride..(bi + 1) * b_stride];
1814                    if let Some(c_batch) = cuda_accel::cuda_matmul(a_slice, b_slice, m, n, k1) {
1815                        c_data[batch * c_stride..(batch + 1) * c_stride]
1816                            .copy_from_slice(unsafe { std::mem::transmute(c_batch.as_slice()) });
1817                    } else {
1818                        gpu_ok = false;
1819                        break;
1820                    }
1821                }
1822                if gpu_ok {
1823                    let mut output_shape = batch_dims_self;
1824                    output_shape.push(m);
1825                    output_shape.push(n);
1826                    return Self::from_vec(c_data, &output_shape);
1827                }
1828                // Fall through to CPU if GPU failed
1829                c_data = vec![T::zero(); batch_size * m * n];
1830            }
1831        }
1832
1833        // CPU fallback: loop over batches
1834        for batch in 0..batch_size {
1835            let ai = a_batch_idx[batch];
1836            let bi = b_batch_idx[batch];
1837            let a_slice = &a_data[ai * a_stride..(ai + 1) * a_stride];
1838            let b_slice = &b_data[bi * b_stride..(bi + 1) * b_stride];
1839            let c_slice = &mut c_data[batch * c_stride..(batch + 1) * c_stride];
1840            CpuBackend::matmul(c_slice, a_slice, b_slice, m, n, k1);
1841        }
1842
1843        // Build output shape: broadcast batch dims + [m, n]
1844        let mut output_shape = if let Some((_, _, ref out_dims)) = broadcast_batch {
1845            out_dims.clone()
1846        } else {
1847            batch_dims_self
1848        };
1849        output_shape.push(m);
1850        output_shape.push(n);
1851
1852        Self::from_vec(c_data, &output_shape)
1853    }
1854
1855    /// Dot product for 1D tensors.
1856    pub fn dot(&self, other: &Self) -> Result<Self> {
1857        if self.ndim() != 1 || other.ndim() != 1 {
1858            return Err(Error::invalid_operation("dot requires 1D tensors"));
1859        }
1860
1861        if self.shape[0] != other.shape[0] {
1862            return Err(Error::shape_mismatch(&self.shape, &other.shape));
1863        }
1864
1865        let a_data = self.to_vec();
1866        let b_data = other.to_vec();
1867        let result = CpuBackend::dot(&a_data, &b_data);
1868
1869        Ok(Self::scalar(result))
1870    }
1871}
1872
1873// =============================================================================
1874// Operator Trait Implementations
1875// =============================================================================
1876
1877impl<T: Numeric> Add for &Tensor<T> {
1878    type Output = Tensor<T>;
1879
1880    fn add(self, other: Self) -> Self::Output {
1881        self.add(other).expect("Addition failed")
1882    }
1883}
1884
1885impl<T: Numeric> Sub for &Tensor<T> {
1886    type Output = Tensor<T>;
1887
1888    fn sub(self, other: Self) -> Self::Output {
1889        self.sub(other).expect("Subtraction failed")
1890    }
1891}
1892
1893impl<T: Numeric> Mul for &Tensor<T> {
1894    type Output = Tensor<T>;
1895
1896    fn mul(self, other: Self) -> Self::Output {
1897        self.mul(other).expect("Multiplication failed")
1898    }
1899}
1900
1901impl<T: Numeric> Div for &Tensor<T> {
1902    type Output = Tensor<T>;
1903
1904    fn div(self, other: Self) -> Self::Output {
1905        self.div(other).expect("Division failed")
1906    }
1907}
1908
1909impl<T: Numeric> Neg for &Tensor<T> {
1910    type Output = Tensor<T>;
1911
1912    fn neg(self) -> Self::Output {
1913        self.neg()
1914    }
1915}
1916
1917// Scalar operations
1918impl<T: Numeric> Add<T> for &Tensor<T> {
1919    type Output = Tensor<T>;
1920
1921    fn add(self, scalar: T) -> Self::Output {
1922        self.add_scalar(scalar)
1923    }
1924}
1925
1926impl<T: Numeric> Mul<T> for &Tensor<T> {
1927    type Output = Tensor<T>;
1928
1929    fn mul(self, scalar: T) -> Self::Output {
1930        self.mul_scalar(scalar)
1931    }
1932}
1933
1934// =============================================================================
1935// Display Implementation
1936// =============================================================================
1937
1938impl<T: Scalar + fmt::Display> fmt::Debug for Tensor<T> {
1939    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
1940        write!(
1941            f,
1942            "Tensor(shape={:?}, device={}",
1943            self.shape(),
1944            self.device()
1945        )?;
1946        if self.numel() <= 10 {
1947            write!(f, ", data={:?}", self.to_vec())?;
1948        }
1949        write!(f, ")")
1950    }
1951}
1952
1953impl<T: Scalar + fmt::Display> fmt::Display for Tensor<T> {
1954    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
1955        if self.is_scalar() {
1956            write!(f, "{}", self.item().unwrap())
1957        } else if self.ndim() == 1 {
1958            write!(f, "[")?;
1959            let data = self.to_vec();
1960            for (i, val) in data.iter().enumerate() {
1961                if i > 0 {
1962                    write!(f, ", ")?;
1963                }
1964                write!(f, "{val}")?;
1965            }
1966            write!(f, "]")
1967        } else {
1968            write!(f, "Tensor(shape={:?})", self.shape())
1969        }
1970    }
1971}
1972
1973// =============================================================================
1974// f32 ↔ f16 Casting for AMP (Automatic Mixed Precision)
1975// =============================================================================
1976
1977impl Tensor<f32> {
1978    /// Cast this f32 tensor to f16 values stored as f32.
1979    ///
1980    /// Each value is rounded to f16 precision. This simulates half-precision
1981    /// computation while keeping the tensor type as f32, which is how AMP
1982    /// works — the autograd graph stays f32 but computation uses f16 precision.
1983    ///
1984    /// On GPU, this uses a CUDA kernel for fast conversion.
1985    /// On CPU, this uses the `half` crate.
1986    #[must_use]
1987    pub fn to_f16_precision(&self) -> Self {
1988        let data = self.to_vec();
1989        let f16_data: Vec<f32> = data
1990            .iter()
1991            .map(|&v| {
1992                let h = half::f16::from_f32(v);
1993                h.to_f32()
1994            })
1995            .collect();
1996        Self::from_vec(f16_data, self.shape()).unwrap()
1997    }
1998
1999    /// Cast f16-precision values back to full f32 precision.
2000    ///
2001    /// This is a no-op since the data is already stored as f32.
2002    /// Included for API symmetry with `to_f16_precision()`.
2003    #[must_use]
2004    pub fn to_f32_precision(&self) -> Self {
2005        self.clone()
2006    }
2007
2008    /// Returns true if applying f16 precision would change any values.
2009    /// Useful for debugging AMP-related numerical issues.
2010    #[must_use]
2011    pub fn has_f16_rounding_error(&self) -> bool {
2012        let data = self.to_vec();
2013        data.iter().any(|&v| {
2014            let h = half::f16::from_f32(v);
2015            (h.to_f32() - v).abs() > f32::EPSILON
2016        })
2017    }
2018}
2019
2020// =============================================================================
2021// Tests
2022// =============================================================================
2023
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_from_vec() {
        let tensor = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3]).unwrap();
        assert_eq!(tensor.shape(), &[2, 3]);
        assert_eq!(tensor.numel(), 6);
    }

    #[test]
    fn test_get_set() {
        let tensor = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[2, 2]).unwrap();
        // Row-major layout: values read back in the order they were supplied.
        for (idx, want) in [([0, 0], 1.0), ([0, 1], 2.0), ([1, 0], 3.0), ([1, 1], 4.0)] {
            assert_eq!(tensor.get(&idx).unwrap(), want);
        }

        // `set` works through a shared (non-mut) binding.
        tensor.set(&[0, 0], 99.0).unwrap();
        assert_eq!(tensor.get(&[0, 0]).unwrap(), 99.0);
    }

    #[test]
    fn test_reshape() {
        let tensor = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3]).unwrap();

        let three_by_two = tensor.reshape(&[3, 2]).expect("reshape failed");
        assert_eq!(three_by_two.shape(), &[3, 2]);

        // A -1 dimension is inferred from the element count.
        let flat = tensor.reshape(&[-1]).expect("reshape failed");
        assert_eq!(flat.shape(), &[6]);
    }

    #[test]
    fn test_transpose() {
        let tensor = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3]).unwrap();
        let flipped = tensor.t().unwrap();
        assert_eq!(flipped.shape(), &[3, 2]);
        // Spot-check that rows and columns swapped.
        for (idx, want) in [([0, 0], 1.0), ([0, 1], 4.0), ([1, 0], 2.0)] {
            assert_eq!(flipped.get(&idx).unwrap(), want);
        }
    }

    #[test]
    fn test_arithmetic() {
        let lhs = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap();
        let rhs = Tensor::<f32>::from_vec(vec![4.0, 5.0, 6.0], &[3]).unwrap();

        assert_eq!((&lhs + &rhs).to_vec(), vec![5.0, 7.0, 9.0]);
        assert_eq!((&lhs * &rhs).to_vec(), vec![4.0, 10.0, 18.0]);
    }

    #[test]
    fn test_broadcasting() {
        let vector = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap();
        let single = Tensor::<f32>::from_vec(vec![10.0], &[1]).unwrap();

        // A [1] tensor broadcasts across a [3] tensor.
        assert_eq!((&vector + &single).to_vec(), vec![11.0, 12.0, 13.0]);
    }

    #[test]
    fn test_sum() {
        let tensor = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[4]).unwrap();
        assert_eq!(tensor.sum().item().unwrap(), 10.0);
    }

    #[test]
    fn test_matmul() {
        // [[1,2],[3,4]] @ [[5,6],[7,8]] = [[19,22],[43,50]]
        let lhs = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[2, 2]).unwrap();
        let rhs = Tensor::<f32>::from_vec(vec![5.0, 6.0, 7.0, 8.0], &[2, 2]).unwrap();
        let product = lhs.matmul(&rhs).unwrap();

        assert_eq!(product.shape(), &[2, 2]);
        assert_eq!(product.to_vec(), vec![19.0, 22.0, 43.0, 50.0]);
    }

    #[test]
    fn test_relu() {
        let tensor = Tensor::<f32>::from_vec(vec![-1.0, 0.0, 1.0, 2.0], &[4]).unwrap();
        // Negatives clamp to zero; non-negatives pass through.
        assert_eq!(tensor.relu().to_vec(), vec![0.0, 0.0, 1.0, 2.0]);
    }

    #[test]
    fn test_scalar() {
        let value = Tensor::<f32>::scalar(42.0);
        assert!(value.is_scalar());
        assert_eq!(value.numel(), 1);
        assert_eq!(value.item().unwrap(), 42.0);
    }
}