// Source: crates/axonml-tensor/src/tensor.rs (axonml_tensor)

1//! Tensor - Core N-Dimensional Array Type
2//!
3//! # File
4//! `crates/axonml-tensor/src/tensor.rs`
5//!
6//! # Author
7//! Andrew Jewell Sr - AutomataNexus
8//!
9//! # Updated
10//! March 8, 2026
11//!
12//! # Disclaimer
13//! Use at own risk. This software is provided "as is", without warranty of any
14//! kind, express or implied. The author and AutomataNexus shall not be held
15//! liable for any damages arising from the use of this software.
16
17use core::fmt;
18use core::ops::{Add, Div, Mul, Neg, Sub};
19
20use axonml_core::Device;
21use axonml_core::backends::CpuBackend;
22#[cfg(feature = "cuda")]
23use axonml_core::backends::CudaBackend;
24use axonml_core::dtype::{Float, Numeric, Scalar};
25use axonml_core::error::{Error, Result};
26use axonml_core::storage::Storage;
27use num_traits::NumCast;
28
29// =============================================================================
30// CUDA Acceleration
31// =============================================================================
32
#[cfg(feature = "cuda")]
mod cuda_accel {
    use super::*;
    use axonml_core::backends::cuda::get_cuda_backend;

    /// Get the global CUDA backend (delegates to core singleton).
    ///
    /// Returns `None` when no backend is available.
    pub fn get_cuda() -> Option<&'static CudaBackend> {
        get_cuda_backend()
    }

    /// GPU-accelerated matmul: copies data to GPU, runs cuBLAS GEMM, copies back.
    /// Returns None if GPU is unavailable or an error occurs.
    ///
    /// Row-major shapes: `a` is (m, k), `b` is (k, n), result is (m, n).
    pub fn cuda_matmul(a: &[f32], b: &[f32], m: usize, n: usize, k: usize) -> Option<Vec<f32>> {
        let cuda = get_cuda()?;

        // Host-to-device staging; any alloc/copy failure degrades to None so
        // the caller can fall back to a CPU path.
        let a_gpu = cuda.htod_copy(a).ok()?;
        let b_gpu = cuda.htod_copy(b).ok()?;
        let mut c_gpu = cuda.alloc::<f32>(m * n).ok()?;

        // cuBLAS GEMM: C(m,n) = A(m,k) @ B(k,n) in row-major
        // In column-major terms: C^T(n,m) = B^T(n,k) @ A^T(k,m)
        // Hence the operands are passed swapped (b before a) with m and n
        // exchanged, and no explicit transpose flags are needed.
        cuda.gemm_f32(
            false, false, n, m, k, 1.0, &b_gpu, n, &a_gpu, k, 0.0, &mut c_gpu, n,
        )
        .ok()?;

        // Device-to-host copy of the m*n result elements.
        cuda.dtoh_copy(&c_gpu).ok()
    }
}
62
63use crate::shape::{
64    Shape, Strides, broadcast_shape, broadcast_strides, contiguous_strides, is_contiguous,
65    linear_index, normalize_dim, numel, reshape, squeeze, transpose_shape, transpose_strides,
66    unsqueeze,
67};
68
69// =============================================================================
70// GPU Dispatch Helpers
71// =============================================================================
72//
73// These enable calling Tensor<f32> GPU methods from generic Tensor<T> code
74// when T is verified to be f32 via TypeId check at runtime.
75
/// Reinterprets `&Tensor<T>` as `&Tensor<f32>` for GPU dispatch.
///
/// # Safety
/// The caller must only invoke this when `T` is `f32`; the runtime `TypeId`
/// assertion below turns misuse into a panic rather than an invalid cast.
#[cfg(feature = "cuda")]
unsafe fn gpu_ref<T: Scalar>(t: &Tensor<T>) -> &Tensor<f32> {
    assert!(
        is_f32::<T>(),
        "gpu_ref: only Tensor<f32> can be used for GPU operations, got {:?}",
        T::DTYPE
    );
    // SAFETY: T is f32 (asserted above), Tensor<f32> and Tensor<T> have identical layout
    unsafe { &*(t as *const Tensor<T> as *const Tensor<f32>) }
}
86
/// Converts an owned `Tensor<f32>` into `Tensor<T>` without copying.
///
/// # Safety
/// The caller must only invoke this when `T` is `f32`; the runtime `TypeId`
/// assertion below turns misuse into a panic rather than an invalid cast.
#[cfg(feature = "cuda")]
unsafe fn gpu_into<T: Scalar>(t: Tensor<f32>) -> Tensor<T> {
    assert!(
        is_f32::<T>(),
        "gpu_into: only Tensor<f32> can be produced from GPU operations, got {:?}",
        T::DTYPE
    );
    // SAFETY: T is f32 (asserted above), ownership transfer via ptr::read + forget.
    // `forget` prevents `t`'s destructor from running after its bits were moved.
    unsafe {
        let out = std::ptr::read(&t as *const Tensor<f32> as *const Tensor<T>);
        std::mem::forget(t);
        out
    }
}
101
/// Runtime check that the generic parameter `T` is exactly `f32`.
#[cfg(feature = "cuda")]
fn is_f32<T: 'static>() -> bool {
    use std::any::TypeId;
    TypeId::of::<T>() == TypeId::of::<f32>()
}
106
107// =============================================================================
108// Tensor Struct
109// =============================================================================
110
/// An N-dimensional array of numeric values.
///
/// Tensors are the core data structure for all computations in Axonml.
/// They support arbitrary dimensions, automatic broadcasting, and efficient
/// memory sharing between views.
///
/// NOTE(review): `gpu_ref`/`gpu_into` cast between `Tensor<T>` and
/// `Tensor<f32>` assuming identical layout — confirm this still holds if
/// fields are added or reordered.
#[derive(Clone)]
pub struct Tensor<T: Scalar> {
    /// Underlying data storage (reference-counted).
    pub(crate) storage: Storage<T>,
    /// Shape of the tensor (dimensions).
    pub(crate) shape: Shape,
    /// Strides for each dimension, in elements (`isize` per dimension).
    pub(crate) strides: Strides,
    /// Offset into storage (for views).
    pub(crate) offset: usize,
}
127
128impl<T: Scalar> Tensor<T> {
129    // =========================================================================
130    // Constructors
131    // =========================================================================
132
133    /// Creates a new tensor from storage with the given shape.
134    ///
135    /// # Arguments
136    /// * `storage` - The underlying data storage
137    /// * `shape` - Shape of the tensor
138    ///
139    /// # Returns
140    /// New tensor, or error if shape doesn't match storage size.
141    pub fn from_storage(storage: Storage<T>, shape: &[usize]) -> Result<Self> {
142        let total = numel(shape);
143        if total != storage.len() {
144            return Err(Error::shape_mismatch(&[storage.len()], shape));
145        }
146
147        let shape = Shape::from_slice(shape);
148        let strides = contiguous_strides(&shape);
149
150        Ok(Self {
151            storage,
152            shape,
153            strides,
154            offset: 0,
155        })
156    }
157
158    /// Creates a new tensor from a vector with the given shape.
159    ///
160    /// # Arguments
161    /// * `data` - Vector of data
162    /// * `shape` - Shape of the tensor
163    ///
164    /// # Returns
165    /// New tensor, or error if shape doesn't match data length.
166    pub fn from_vec(data: Vec<T>, shape: &[usize]) -> Result<Self> {
167        let storage = Storage::from_vec(data, Device::Cpu);
168        Self::from_storage(storage, shape)
169    }
170
171    /// Creates a new tensor from a slice with the given shape.
172    ///
173    /// # Arguments
174    /// * `data` - Slice of data to copy
175    /// * `shape` - Shape of the tensor
176    ///
177    /// # Returns
178    /// New tensor, or error if shape doesn't match data length.
179    pub fn from_slice(data: &[T], shape: &[usize]) -> Result<Self> {
180        let storage = Storage::from_slice(data, Device::Cpu);
181        Self::from_storage(storage, shape)
182    }
183
184    /// Creates a scalar tensor (0-dimensional).
185    ///
186    /// # Arguments
187    /// * `value` - The scalar value
188    ///
189    /// # Returns
190    /// New 0-dimensional tensor.
191    pub fn scalar(value: T) -> Self {
192        Self {
193            storage: Storage::from_vec(vec![value], Device::Cpu),
194            shape: Shape::new(),
195            strides: Strides::new(),
196            offset: 0,
197        }
198    }
199
200    /// Creates a tensor filled with zeros.
201    #[must_use]
202    pub fn zeros(shape: &[usize]) -> Self {
203        crate::creation::zeros(shape)
204    }
205
206    /// Creates a tensor filled with ones.
207    #[must_use]
208    pub fn ones(shape: &[usize]) -> Self
209    where
210        T: Numeric,
211    {
212        crate::creation::ones(shape)
213    }
214
215    /// Creates a tensor filled with a constant value.
216    #[must_use]
217    pub fn full(shape: &[usize], value: T) -> Self {
218        crate::creation::full(shape, value)
219    }
220
221    /// Creates a tensor with random values from standard normal distribution.
222    #[must_use]
223    pub fn randn(shape: &[usize]) -> Self
224    where
225        T: Float,
226        rand_distr::StandardNormal: rand::distributions::Distribution<T>,
227    {
228        crate::creation::randn(shape)
229    }
230
231    /// Creates a tensor with random values from uniform distribution [0, 1).
232    #[must_use]
233    pub fn rand(shape: &[usize]) -> Self
234    where
235        T: Float,
236        rand::distributions::Standard: rand::distributions::Distribution<T>,
237    {
238        crate::creation::rand(shape)
239    }
240
241    // =========================================================================
242    // Properties
243    // =========================================================================
244
245    /// Returns the shape of the tensor.
246    #[must_use]
247    pub fn shape(&self) -> &[usize] {
248        &self.shape
249    }
250
251    /// Returns the strides of the tensor.
252    #[must_use]
253    pub fn strides(&self) -> &[isize] {
254        &self.strides
255    }
256
257    /// Returns the number of dimensions.
258    #[must_use]
259    pub fn ndim(&self) -> usize {
260        self.shape.len()
261    }
262
263    /// Returns the total number of elements.
264    #[must_use]
265    pub fn numel(&self) -> usize {
266        numel(&self.shape)
267    }
268
269    /// Returns true if the tensor is empty (has zero elements).
270    #[must_use]
271    pub fn is_empty(&self) -> bool {
272        self.numel() == 0
273    }
274
275    /// Returns the size of a specific dimension.
276    ///
277    /// # Arguments
278    /// * `dim` - Dimension index (supports negative indexing)
279    pub fn size(&self, dim: i64) -> Result<usize> {
280        let idx = normalize_dim(dim, self.ndim())?;
281        Ok(self.shape[idx])
282    }
283
284    /// Returns the device this tensor is on.
285    #[must_use]
286    pub fn device(&self) -> Device {
287        self.storage.device()
288    }
289
290    /// Returns true if the tensor is contiguous in memory.
291    #[must_use]
292    pub fn is_contiguous(&self) -> bool {
293        is_contiguous(&self.shape, &self.strides)
294    }
295
296    /// Returns true if this tensor is a scalar (0-dimensional).
297    #[must_use]
298    pub fn is_scalar(&self) -> bool {
299        self.shape.is_empty()
300    }
301
302    // =========================================================================
303    // Data Access
304    // =========================================================================
305
306    /// Returns the element at the given indices.
307    ///
308    /// # Arguments
309    /// * `indices` - Multi-dimensional indices
310    pub fn get(&self, indices: &[usize]) -> Result<T> {
311        if indices.len() != self.ndim() {
312            return Err(Error::invalid_operation(format!(
313                "Expected {} indices, got {}",
314                self.ndim(),
315                indices.len()
316            )));
317        }
318
319        for (&idx, &dim) in indices.iter().zip(self.shape.iter()) {
320            if idx >= dim {
321                return Err(Error::IndexOutOfBounds {
322                    index: idx,
323                    size: dim,
324                });
325            }
326        }
327
328        let offset = self.offset + linear_index(indices, &self.strides);
329        Ok(self.storage.as_slice()[offset])
330    }
331
332    /// Sets the element at the given indices.
333    ///
334    /// # Arguments
335    /// * `indices` - Multi-dimensional indices
336    /// * `value` - Value to set
337    pub fn set(&self, indices: &[usize], value: T) -> Result<()> {
338        if indices.len() != self.ndim() {
339            return Err(Error::invalid_operation(format!(
340                "Expected {} indices, got {}",
341                self.ndim(),
342                indices.len()
343            )));
344        }
345
346        for (&idx, &dim) in indices.iter().zip(self.shape.iter()) {
347            if idx >= dim {
348                return Err(Error::IndexOutOfBounds {
349                    index: idx,
350                    size: dim,
351                });
352            }
353        }
354
355        let offset = self.offset + linear_index(indices, &self.strides);
356        self.storage.as_slice_mut()[offset] = value;
357        Ok(())
358    }
359
360    /// Returns the scalar value for a 0-dimensional tensor.
361    pub fn item(&self) -> Result<T> {
362        if self.numel() != 1 {
363            return Err(Error::invalid_operation(
364                "item() only works on single-element tensors",
365            ));
366        }
367
368        if self.is_scalar() {
369            Ok(self.storage.as_slice()[self.offset])
370        } else {
371            // Single element but not scalar shape
372            let indices = vec![0; self.ndim()];
373            self.get(&indices)
374        }
375    }
376
377    /// Returns the data as a contiguous vector.
378    ///
379    /// If the tensor is already contiguous, this returns a reference.
380    /// Otherwise, it copies the data into a new contiguous vector.
381    /// For GPU tensors (f32 only), performs a D2H copy.
382    #[must_use]
383    pub fn to_vec(&self) -> Vec<T> {
384        // GPU path: GPU storage is always f32
385        #[cfg(feature = "cuda")]
386        if self.storage.is_gpu() {
387            assert!(is_f32::<T>(), "GPU tensors are only supported for f32");
388            let self_f32 = unsafe { gpu_ref(self) };
389            let f32_vec = self_f32.to_vec_gpu();
390            unsafe {
391                let mut v = std::mem::ManuallyDrop::new(f32_vec);
392                return Vec::from_raw_parts(v.as_mut_ptr() as *mut T, v.len(), v.capacity());
393            }
394        }
395
396        if self.is_contiguous() {
397            let storage = self.storage.as_slice();
398            storage[self.offset..self.offset + self.numel()].to_vec()
399        } else {
400            let mut result = Vec::with_capacity(self.numel());
401            self.copy_data_to(&mut result);
402            result
403        }
404    }
405
406    /// Copies data to a slice, handling non-contiguous layouts.
407    fn copy_data_to(&self, dst: &mut Vec<T>) {
408        dst.clear();
409        let storage = self.storage.as_slice();
410
411        // Iterate through all indices
412        let total = self.numel();
413        for i in 0..total {
414            let indices = crate::shape::unravel_index(i, &self.shape);
415            let offset = self.offset + linear_index(&indices, &self.strides);
416            dst.push(storage[offset]);
417        }
418    }
419
420    // =========================================================================
421    // Shape Operations
422    // =========================================================================
423
424    /// Returns a new tensor with the specified shape.
425    ///
426    /// The total number of elements must remain the same.
427    /// Supports -1 in one dimension to infer the size.
428    ///
429    /// # Arguments
430    /// * `new_shape` - Target shape
431    pub fn reshape(&self, new_shape: &[isize]) -> Result<Self> {
432        let shape = reshape(&self.shape, new_shape)?;
433
434        if self.is_contiguous() {
435            // Can just change shape without copying
436            Ok(Self {
437                storage: self.storage.clone(),
438                strides: contiguous_strides(&shape),
439                shape,
440                offset: self.offset,
441            })
442        } else {
443            // Need to make contiguous first
444            let contig = self.contiguous();
445            Ok(Self {
446                storage: contig.storage,
447                strides: contiguous_strides(&shape),
448                shape,
449                offset: 0,
450            })
451        }
452    }
453
454    /// Returns a new tensor with a flattened shape.
455    #[must_use]
456    pub fn flatten(&self) -> Self {
457        self.reshape(&[-1]).expect("Flatten should never fail")
458    }
459
    /// Returns a new tensor with dimensions of size 1 removed.
    ///
    /// # Arguments
    /// * `dim` - Optional specific dimension to squeeze; when `None`, all
    ///   size-1 dimensions are removed.
    pub fn squeeze(&self, dim: Option<i64>) -> Result<Self> {
        // Normalize a negative dim index before using it for strides below.
        let dim = match dim {
            Some(d) => Some(normalize_dim(d, self.ndim())?),
            None => None,
        };

        let new_shape = squeeze(&self.shape, dim);
        // Strides must be filtered the same way the shape was, so the two
        // stay aligned dimension-for-dimension.
        let new_strides: Strides = match dim {
            Some(d) => {
                let mut s = self.strides.clone();
                // Only drop the stride when the dimension really has size 1
                // (mirrors what `squeeze` does to the shape).
                if d < self.shape.len() && self.shape[d] == 1 {
                    s.remove(d);
                }
                s
            }
            None => self
                .shape
                .iter()
                .zip(self.strides.iter())
                .filter(|(dim, _)| **dim != 1)
                .map(|(_, stride)| *stride)
                .collect(),
        };

        Ok(Self {
            storage: self.storage.clone(),
            shape: new_shape,
            strides: new_strides,
            offset: self.offset,
        })
    }

    /// Returns a new tensor with a dimension of size 1 inserted.
    ///
    /// # Arguments
    /// * `dim` - Position to insert the new dimension (negative counts from
    ///   the end, with `-1` meaning "append after the last dimension")
    pub fn unsqueeze(&self, dim: i64) -> Result<Self> {
        // ndim()+1 valid insert positions, hence the `+ 1` when wrapping
        // negative indices.
        let normalized = if dim < 0 {
            (dim + self.ndim() as i64 + 1) as usize
        } else {
            dim as usize
        };

        let new_shape = unsqueeze(&self.shape, normalized)?;
        let mut new_strides = Strides::with_capacity(new_shape.len());

        // Rebuild strides: old strides shift right by one past the insert
        // point; the inserted size-1 dimension gets a placeholder stride.
        for (i, _) in new_shape.iter().enumerate() {
            if i < normalized {
                new_strides.push(self.strides.get(i).copied().unwrap_or(1));
            } else if i == normalized {
                // Stride for new dimension (doesn't matter since size is 1)
                new_strides.push(1);
            } else {
                new_strides.push(self.strides[i - 1]);
            }
        }

        Ok(Self {
            storage: self.storage.clone(),
            shape: new_shape,
            strides: new_strides,
            offset: self.offset,
        })
    }
528
529    /// Transposes two dimensions.
530    ///
531    /// # Arguments
532    /// * `dim0` - First dimension
533    /// * `dim1` - Second dimension
534    pub fn transpose(&self, dim0: i64, dim1: i64) -> Result<Self> {
535        let d0 = normalize_dim(dim0, self.ndim())?;
536        let d1 = normalize_dim(dim1, self.ndim())?;
537
538        let new_shape = transpose_shape(&self.shape, d0, d1)?;
539        let new_strides = transpose_strides(&self.strides, d0, d1);
540
541        Ok(Self {
542            storage: self.storage.clone(),
543            shape: new_shape,
544            strides: new_strides,
545            offset: self.offset,
546        })
547    }
548
549    /// Returns the transpose of a 2D tensor.
550    pub fn t(&self) -> Result<Self> {
551        if self.ndim() != 2 {
552            return Err(Error::invalid_operation("t() only works on 2D tensors"));
553        }
554        self.transpose(0, 1)
555    }
556
557    /// Returns a permuted tensor with dimensions reordered.
558    ///
559    /// # Arguments
560    /// * `dims` - New order of dimensions
561    pub fn permute(&self, dims: &[usize]) -> Result<Self> {
562        if dims.len() != self.ndim() {
563            return Err(Error::invalid_operation(format!(
564                "Expected {} dimensions, got {}",
565                self.ndim(),
566                dims.len()
567            )));
568        }
569
570        // Check that dims is a permutation
571        let mut seen = vec![false; self.ndim()];
572        for &d in dims {
573            if d >= self.ndim() {
574                return Err(Error::InvalidDimension {
575                    index: d as i64,
576                    ndim: self.ndim(),
577                });
578            }
579            if seen[d] {
580                return Err(Error::invalid_operation("Duplicate dimension in permute"));
581            }
582            seen[d] = true;
583        }
584
585        let new_shape: Shape = dims.iter().map(|&d| self.shape[d]).collect();
586        let new_strides: Strides = dims.iter().map(|&d| self.strides[d]).collect();
587
588        Ok(Self {
589            storage: self.storage.clone(),
590            shape: new_shape,
591            strides: new_strides,
592            offset: self.offset,
593        })
594    }
595
    /// Returns a contiguous copy of the tensor.
    ///
    /// Cheap (a `Clone` of the view) when the tensor is already contiguous
    /// and starts at offset 0; otherwise the data is gathered into fresh
    /// storage.
    #[must_use]
    pub fn contiguous(&self) -> Self {
        if self.is_contiguous() && self.offset == 0 {
            return self.clone();
        }

        #[cfg(feature = "cuda")]
        if self.storage.is_gpu() {
            // GPU storage is f32-only; dispatch through the f32 implementation.
            assert!(is_f32::<T>(), "GPU tensors are only supported for f32");
            let self_f32 = unsafe { gpu_ref(self) };
            let result = self_f32.contiguous_gpu();
            return unsafe { gpu_into(result) };
        }

        // to_vec() already gathers through strides, yielding contiguous data.
        let data = self.to_vec();
        Self::from_vec(data, &self.shape).expect("Contiguous should never fail")
    }
614
615    // =========================================================================
616    // Device Operations
617    // =========================================================================
618
    /// Transfers the tensor to a different device.
    ///
    /// Returns a cheap clone when the tensor is already on `device`.
    ///
    /// # Arguments
    /// * `device` - Target device
    pub fn to_device(&self, device: Device) -> Result<Self> {
        if self.device() == device {
            return Ok(self.clone());
        }

        #[cfg(feature = "cuda")]
        if self.storage.is_gpu() || device.is_gpu() {
            // Any transfer touching a GPU goes through the f32 implementation.
            assert!(is_f32::<T>(), "GPU tensors are only supported for f32");
            let self_f32 = unsafe { gpu_ref(self) };
            let result = self_f32.to_device_f32(device)?;
            return Ok(unsafe { gpu_into(result) });
        }

        // CPU path: compact into contiguous storage before moving devices,
        // so the new tensor can use fresh contiguous strides with offset 0.
        let contig = self.contiguous();
        let new_storage = contig.storage.to_device(device)?;

        Ok(Self {
            storage: new_storage,
            shape: self.shape.clone(),
            strides: self.strides.clone(),
            offset: 0,
        })
    }

    /// Transfers to CPU.
    pub fn cpu(&self) -> Result<Self> {
        self.to_device(Device::Cpu)
    }
651
652    // =========================================================================
653    // Deep Copy
654    // =========================================================================
655
    /// Creates a deep copy of this tensor with its own storage.
    ///
    /// Data is staged through a CPU vector; GPU tensors are then transferred
    /// back to their original device.
    #[must_use]
    pub fn clone_deep(&self) -> Self {
        let data = self.to_vec();
        let cpu = Self::from_vec(data, &self.shape).expect("Deep clone should never fail");
        #[cfg(feature = "cuda")]
        if self.device().is_gpu() {
            return cpu.to_device(self.device()).unwrap();
        }
        cpu
    }
667}
668
669// =============================================================================
670// Numeric Operations
671// =============================================================================
672
673impl<T: Numeric> Tensor<T> {
674    /// Fills the tensor with a value.
675    ///
676    /// # Panics
677    /// Panics on GPU tensors. Use `Tensor::from_vec(vec![value; n], shape)`
678    /// followed by `.to_device()` instead.
679    pub fn fill_(&self, value: T) {
680        assert!(
681            self.storage.is_cpu(),
682            "fill_() not supported on GPU tensors — create a new tensor and transfer instead"
683        );
684        let mut data = self.storage.as_slice_mut();
685        CpuBackend::fill(&mut data, value);
686    }
687
688    /// Fills the tensor with zeros.
689    pub fn zero_(&self) {
690        self.fill_(T::zero());
691    }
692
693    // =========================================================================
694    // Reduction Operations
695    // =========================================================================
696
    /// Returns the sum of all elements as a scalar tensor.
    ///
    /// On GPU, uses native CUDA reduction kernels (no CPU round-trip).
    #[must_use]
    pub fn sum(&self) -> Self {
        #[cfg(feature = "cuda")]
        if self.device().is_gpu() {
            // Reduce one axis at a time with the existing sum_dim kernel
            // until a single axis remains.
            let mut t = self.clone();
            while t.ndim() > 1 {
                t = t.sum_dim_cuda(0);
            }
            // Final 1D reduction down to a single element.
            if t.numel() > 1 {
                t = t.sum_dim_cuda(0);
            }
            return t;
        }

        let data = self.to_vec();
        let result = CpuBackend::sum(&data);
        Self::scalar(result)
    }
720
721    /// Returns the product of all elements.
722    ///
723    /// GPU: D2H round-trip (no CUDA prod reduction kernel yet).
724    #[must_use]
725    pub fn prod(&self) -> Self {
726        let data = self.to_vec();
727        let result = CpuBackend::prod(&data);
728        let s = Self::scalar(result);
729        #[cfg(feature = "cuda")]
730        if self.device().is_gpu() {
731            return s.to_device(self.device()).expect("prod: device transfer failed");
732        }
733        s
734    }
735
736    /// Returns the maximum element.
737    ///
738    /// GPU: D2H round-trip (no CUDA max reduction kernel yet).
739    pub fn max(&self) -> Result<Self> {
740        if self.is_empty() {
741            return Err(Error::EmptyTensor);
742        }
743        let data = self.to_vec();
744        let result = CpuBackend::max(&data).expect("max on non-empty tensor");
745        let s = Self::scalar(result);
746        #[cfg(feature = "cuda")]
747        if self.device().is_gpu() {
748            return Ok(s.to_device(self.device()).expect("max: device transfer failed"));
749        }
750        Ok(s)
751    }
752
753    /// Returns the minimum element.
754    ///
755    /// GPU: D2H round-trip (no CUDA min reduction kernel yet).
756    pub fn min(&self) -> Result<Self> {
757        if self.is_empty() {
758            return Err(Error::EmptyTensor);
759        }
760        let data = self.to_vec();
761        let result = CpuBackend::min(&data).expect("min on non-empty tensor");
762        let s = Self::scalar(result);
763        #[cfg(feature = "cuda")]
764        if self.device().is_gpu() {
765            return Ok(s.to_device(self.device()).expect("min: device transfer failed"));
766        }
767        Ok(s)
768    }
769
770    /// Returns the index of the maximum element.
771    pub fn argmax(&self) -> Result<usize> {
772        if self.is_empty() {
773            return Err(Error::EmptyTensor);
774        }
775        let data = self.to_vec();
776        Ok(CpuBackend::argmax(&data).unwrap())
777    }
778
779    /// Returns the index of the minimum element.
780    pub fn argmin(&self) -> Result<usize> {
781        if self.is_empty() {
782            return Err(Error::EmptyTensor);
783        }
784        let data = self.to_vec();
785        Ok(CpuBackend::argmin(&data).unwrap())
786    }
787
    /// Concatenates tensors along a dimension.
    ///
    /// All tensors must have the same shape except along the cat dimension.
    ///
    /// # Arguments
    /// * `tensors` - Tensors to concatenate (at least one)
    /// * `dim` - Dimension along which to concatenate
    pub fn cat(tensors: &[&Self], dim: usize) -> Result<Self> {
        if tensors.is_empty() {
            return Err(Error::invalid_operation("cat requires at least one tensor"));
        }
        let ndim = tensors[0].ndim();
        if dim >= ndim {
            return Err(Error::invalid_operation("cat dimension out of range"));
        }

        // All inputs must agree with tensors[0] on every non-cat dimension.
        for t in &tensors[1..] {
            if t.ndim() != ndim {
                return Err(Error::invalid_operation(
                    "cat: all tensors must have same ndim",
                ));
            }
            for d in 0..ndim {
                if d != dim && t.shape[d] != tensors[0].shape[d] {
                    return Err(Error::invalid_operation(
                        "cat: shapes must match on non-cat dims",
                    ));
                }
            }
        }

        // Output shape: cat dimension grows to the sum of the inputs' sizes.
        let total_dim_size: usize = tensors.iter().map(|t| t.shape[dim]).sum();
        let mut out_shape: Vec<usize> = tensors[0].shape.to_vec();
        out_shape[dim] = total_dim_size;

        // View each tensor as (outer, cat_dim, inner): outer is the product
        // of dims before `dim`, inner the product of dims after it.
        let outer_size: usize = out_shape[..dim].iter().product();
        let inner_size: usize = out_shape[dim + 1..].iter().product();
        let total_numel: usize = out_shape.iter().product();
        let mut result = vec![T::zero(); total_numel];

        // Copy each input's contiguous rows into its slot along the cat
        // dimension; `dim_offset` tracks where the current tensor starts.
        let mut dim_offset = 0;
        for t in tensors {
            let t_data = t.contiguous().to_vec();
            let t_dim_size = t.shape[dim];
            for outer in 0..outer_size {
                for d in 0..t_dim_size {
                    let src_base = outer * t_dim_size * inner_size + d * inner_size;
                    let dst_base =
                        outer * total_dim_size * inner_size + (dim_offset + d) * inner_size;
                    result[dst_base..dst_base + inner_size]
                        .copy_from_slice(&t_data[src_base..src_base + inner_size]);
                }
            }
            dim_offset += t_dim_size;
        }

        // Assemble on CPU; move back to the inputs' device if they were on GPU.
        let out = Self::from_vec(result, &out_shape)?;
        #[cfg(feature = "cuda")]
        if tensors[0].device().is_gpu() {
            return Ok(out.to_device(tensors[0].device()).unwrap());
        }
        Ok(out)
    }
847}
848
849// =============================================================================
850// Float Operations
851// =============================================================================
852
853impl<T: Float> Tensor<T> {
854    /// Returns the mean of all elements.
855    /// Returns the mean of all elements.
856    ///
857    /// On GPU, uses native CUDA sum reduction then divides by numel.
858    pub fn mean(&self) -> Result<Self> {
859        if self.is_empty() {
860            return Err(Error::EmptyTensor);
861        }
862        #[cfg(feature = "cuda")]
863        if self.device().is_gpu() {
864            let s = self.sum(); // uses CUDA sum_dim chain
865            let n = self.numel() as f32;
866            // mul_scalar stays on GPU
867            return Ok(s.mul_scalar(T::from(1.0 / n as f64).unwrap_or(T::zero())));
868        }
869
870        let data = self.to_vec();
871        let result = CpuBackend::mean(&data).expect("mean on non-empty tensor");
872        Ok(Self::scalar(result))
873    }
874
875    // =========================================================================
876    // Activation Functions
877    // =========================================================================
878
879    /// Applies `ReLU` activation: max(0, x).
880    #[must_use]
881    pub fn relu(&self) -> Self {
882        #[cfg(feature = "cuda")]
883        if self.device().is_gpu() {
884            assert!(is_f32::<T>(), "GPU tensors are only supported for f32");
885            return unsafe { gpu_into(gpu_ref(self).relu_cuda()) };
886        }
887        let data = self.to_vec();
888        let mut result = vec![T::zero(); data.len()];
889        CpuBackend::relu(&mut result, &data);
890        Self::from_vec(result, &self.shape).unwrap()
891    }
892
893    /// Applies sigmoid activation: 1 / (1 + exp(-x)).
894    #[must_use]
895    pub fn sigmoid(&self) -> Self {
896        #[cfg(feature = "cuda")]
897        if self.device().is_gpu() {
898            assert!(is_f32::<T>(), "GPU tensors are only supported for f32");
899            return unsafe { gpu_into(gpu_ref(self).sigmoid_cuda()) };
900        }
901        let data = self.to_vec();
902        let mut result = vec![T::zero(); data.len()];
903        CpuBackend::sigmoid(&mut result, &data);
904        Self::from_vec(result, &self.shape).unwrap()
905    }
906
907    /// Applies tanh activation.
908    #[must_use]
909    pub fn tanh(&self) -> Self {
910        #[cfg(feature = "cuda")]
911        if self.device().is_gpu() {
912            assert!(is_f32::<T>(), "GPU tensors are only supported for f32");
913            return unsafe { gpu_into(gpu_ref(self).tanh_cuda()) };
914        }
915        let data = self.to_vec();
916        let mut result = vec![T::zero(); data.len()];
917        CpuBackend::tanh(&mut result, &data);
918        Self::from_vec(result, &self.shape).unwrap()
919    }
920
921    /// Applies exponential function.
922    #[must_use]
923    pub fn exp(&self) -> Self {
924        #[cfg(feature = "cuda")]
925        if self.device().is_gpu() {
926            assert!(is_f32::<T>(), "GPU tensors are only supported for f32");
927            return unsafe { gpu_into(gpu_ref(self).exp_cuda()) };
928        }
929        let data = self.to_vec();
930        let mut result = vec![T::zero(); data.len()];
931        CpuBackend::exp(&mut result, &data);
932        Self::from_vec(result, &self.shape).unwrap()
933    }
934
935    /// Applies natural logarithm.
936    #[must_use]
937    pub fn ln(&self) -> Self {
938        #[cfg(feature = "cuda")]
939        if self.device().is_gpu() {
940            assert!(is_f32::<T>(), "GPU tensors are only supported for f32");
941            return unsafe { gpu_into(gpu_ref(self).ln_cuda()) };
942        }
943        let data = self.to_vec();
944        let mut result = vec![T::zero(); data.len()];
945        CpuBackend::ln(&mut result, &data);
946        Self::from_vec(result, &self.shape).unwrap()
947    }
948
949    /// Applies square root.
950    #[must_use]
951    pub fn sqrt(&self) -> Self {
952        #[cfg(feature = "cuda")]
953        if self.device().is_gpu() {
954            assert!(is_f32::<T>(), "GPU tensors are only supported for f32");
955            return unsafe { gpu_into(gpu_ref(self).sqrt_cuda()) };
956        }
957        let data = self.to_vec();
958        let mut result = vec![T::zero(); data.len()];
959        CpuBackend::sqrt(&mut result, &data);
960        Self::from_vec(result, &self.shape).unwrap()
961    }
962
    /// Computes element-wise power.
    ///
    /// Raises every element to the scalar exponent `exp`.
    #[must_use]
    pub fn pow(&self, exp: T) -> Self {
        #[cfg(feature = "cuda")]
        if self.device().is_gpu() {
            assert!(is_f32::<T>(), "GPU tensors are only supported for f32");
            // SAFETY: T is f32 (asserted above), so reading the scalar's bytes
            // through an f32 pointer reinterprets the same 4-byte value.
            let exp_f32: f32 = unsafe { *(&exp as *const T as *const f32) };
            return unsafe { gpu_into(gpu_ref(self).pow_cuda(exp_f32)) };
        }
        // CPU path: per-element pow via the Numeric trait.
        let data = self.to_vec();
        let result: Vec<T> = data.iter().map(|&x| x.pow_value(exp)).collect();
        Self::from_vec(result, &self.shape).unwrap()
    }
976
977    /// GELU activation function (Gaussian Error Linear Unit).
978    #[must_use]
979    pub fn gelu(&self) -> Self {
980        #[cfg(feature = "cuda")]
981        if self.device().is_gpu() {
982            assert!(is_f32::<T>(), "GPU tensors are only supported for f32");
983            return unsafe { gpu_into(gpu_ref(self).gelu_cuda()) };
984        }
985        crate::ops::gelu(self)
986    }
987
988    /// SiLU/Swish activation function.
989    #[must_use]
990    pub fn silu(&self) -> Self {
991        #[cfg(feature = "cuda")]
992        if self.device().is_gpu() {
993            assert!(is_f32::<T>(), "GPU tensors are only supported for f32");
994            return unsafe { gpu_into(gpu_ref(self).silu_cuda()) };
995        }
996        crate::ops::silu(self)
997    }
998
    /// Softmax along specified dimension.
    ///
    /// NOTE(review): failure handling diverges between paths — the GPU path
    /// panics via `expect` while the CPU path silently returns `self.clone()`
    /// (i.e. the un-normalized input) when `crate::ops::softmax` errors.
    /// Confirm which behavior is intended; unifying them would change the
    /// public contract, so it is only flagged here.
    #[must_use]
    pub fn softmax(&self, dim: i32) -> Self {
        #[cfg(feature = "cuda")]
        if self.device().is_gpu() {
            assert!(is_f32::<T>(), "GPU tensors are only supported for f32");
            let self_f32 = unsafe { gpu_ref(self) };
            return unsafe { gpu_into(self_f32.softmax_cuda(dim).expect("CUDA softmax failed")) };
        }
        crate::ops::softmax(self, dim as i64).unwrap_or_else(|_| self.clone())
    }
1010
    /// Log softmax along specified dimension.
    ///
    /// Computed as `ln(softmax(x))`.
    /// NOTE(review): this composition can underflow — softmax outputs near
    /// zero turn into `-inf` under `ln`. A fused, max-shifted formulation
    /// (`x - max - ln(sum(exp(x - max)))`) would be numerically stabler, but
    /// needs a max-reduction primitive not available in this file.
    #[must_use]
    pub fn log_softmax(&self, dim: i32) -> Self {
        let softmax_result = self.softmax(dim);
        softmax_result.ln()
    }
1017
    /// Mean along a dimension.
    ///
    /// `dim` may be negative to index from the last dimension. An
    /// out-of-range `dim` returns the tensor unchanged. When `keepdim` is
    /// true the reduced dimension is kept with size 1; otherwise it is
    /// removed (a fully-reduced result becomes shape `[1]`).
    #[must_use]
    pub fn mean_dim(&self, dim: i32, keepdim: bool) -> Self {
        let ndim = self.ndim();
        // Normalize negative dims (PyTorch convention: -1 is the last dim).
        let dim = if dim < 0 {
            (ndim as i32 + dim) as usize
        } else {
            dim as usize
        };

        // Out-of-range dim is a silent no-op (dim < -ndim also lands here
        // because the subtraction above wraps to a huge usize).
        if dim >= ndim {
            return self.clone();
        }

        // GPU fast path: sum_dim then divide by dim_size (all on GPU)
        #[cfg(feature = "cuda")]
        if self.device().is_gpu() {
            assert!(is_f32::<T>(), "GPU tensors are only supported for f32");
            let self_f32 = unsafe { gpu_ref(self) };
            let summed = if keepdim {
                self_f32.sum_dim_keepdim_cuda(dim)
            } else {
                self_f32.sum_dim_cuda(dim)
            };
            let dim_size = self.shape[dim];
            let result = summed.mul_scalar_cuda(1.0 / dim_size as f32);
            return unsafe { gpu_into(result) };
        }

        let dim_size = self.shape[dim];
        // Assumes to_vec yields elements in contiguous row-major order —
        // TODO(review): confirm for strided/offset views.
        let data = self.to_vec();
        let mut new_shape = self.shape.clone();

        if keepdim {
            new_shape[dim] = 1;
        } else {
            new_shape.remove(dim);
        }

        // Reducing the only dimension yields a rank-1 tensor of shape [1].
        if new_shape.is_empty() {
            new_shape = smallvec::smallvec![1];
        }

        let new_numel: usize = new_shape.iter().product();
        let mut result = vec![T::zero(); new_numel];

        // View the data as [outer, dim, inner] and reduce the middle axis.
        let outer_size: usize = self.shape[..dim].iter().product();
        let inner_size: usize = self.shape[dim + 1..].iter().product();

        for outer in 0..outer_size {
            for inner in 0..inner_size {
                let mut sum = T::zero();
                for d in 0..dim_size {
                    let idx = outer * dim_size * inner_size + d * inner_size + inner;
                    sum = sum + data[idx];
                }
                let mean = sum / NumCast::from(dim_size).unwrap();
                let result_idx = outer * inner_size + inner;
                result[result_idx] = mean;
            }
        }

        Self::from_vec(result, &new_shape).unwrap()
    }
1082
    /// Sum along a dimension.
    ///
    /// `dim` may be negative to index from the last dimension; an
    /// out-of-range `dim` returns the tensor unchanged. `keepdim` keeps the
    /// reduced dimension with size 1 instead of removing it.
    #[must_use]
    pub fn sum_dim(&self, dim: i32, keepdim: bool) -> Self {
        let ndim = self.ndim();
        // Normalize negative dims (PyTorch convention).
        let dim = if dim < 0 {
            (ndim as i32 + dim) as usize
        } else {
            dim as usize
        };

        // Out-of-range dim is a silent no-op (dim < -ndim wraps to a huge
        // usize above and also lands here).
        if dim >= ndim {
            return self.clone();
        }

        // GPU fast path: use CUDA sum_dim kernel (no CPU copies)
        #[cfg(feature = "cuda")]
        if self.device().is_gpu() {
            assert!(is_f32::<T>(), "GPU tensors are only supported for f32");
            let self_f32 = unsafe { gpu_ref(self) };
            let result = if keepdim {
                self_f32.sum_dim_keepdim_cuda(dim)
            } else {
                self_f32.sum_dim_cuda(dim)
            };
            return unsafe { gpu_into(result) };
        }

        let dim_size = self.shape[dim];
        // Assumes to_vec yields contiguous row-major data — TODO(review):
        // confirm for strided/offset views.
        let data = self.to_vec();
        let mut new_shape = self.shape.clone();

        if keepdim {
            new_shape[dim] = 1;
        } else {
            new_shape.remove(dim);
        }

        // Reducing the only dimension yields shape [1].
        if new_shape.is_empty() {
            new_shape = smallvec::smallvec![1];
        }

        let new_numel: usize = new_shape.iter().product();
        let mut result = vec![T::zero(); new_numel];

        // View the data as [outer, dim, inner] and reduce the middle axis.
        let outer_size: usize = self.shape[..dim].iter().product();
        let inner_size: usize = self.shape[dim + 1..].iter().product();

        for outer in 0..outer_size {
            for inner in 0..inner_size {
                let mut sum = T::zero();
                for d in 0..dim_size {
                    let idx = outer * dim_size * inner_size + d * inner_size + inner;
                    sum = sum + data[idx];
                }
                let result_idx = outer * inner_size + inner;
                result[result_idx] = sum;
            }
        }

        Self::from_vec(result, &new_shape).unwrap()
    }
1144
1145    /// Variance along a dimension.
1146    #[must_use]
1147    pub fn var_dim(&self, dim: i32, keepdim: bool) -> Self {
1148        // variance = E[x²] - E[x]²  (saves one full-size intermediate allocation)
1149        let mean = self.mean_dim(dim, true);
1150        let sq = self.mul(self).unwrap_or_else(|_| self.clone());
1151        let mean_sq = sq.mean_dim(dim, keepdim);
1152        let mean_keepdim = if keepdim {
1153            mean.clone()
1154        } else {
1155            self.mean_dim(dim, keepdim)
1156        };
1157        let mean_squared = mean_keepdim
1158            .mul(&mean_keepdim)
1159            .unwrap_or_else(|_| mean_keepdim.clone());
1160        mean_sq
1161            .sub(&mean_squared)
1162            .unwrap_or_else(|_| mean_sq.clone())
1163    }
1164
    /// Broadcasts tensor to a new shape.
    ///
    /// Returns a clone when the shape already matches. The CPU result is a
    /// new contiguous tensor (data is materialized, not a strided view).
    #[must_use]
    pub fn broadcast_to(&self, shape: &[usize]) -> Self {
        if self.shape.as_slice() == shape {
            return self.clone();
        }

        #[cfg(feature = "cuda")]
        if self.device().is_gpu() {
            assert!(is_f32::<T>(), "GPU tensors are only supported for f32");
            let self_f32 = unsafe { gpu_ref(self) };
            return unsafe {
                gpu_into(
                    self_f32
                        .broadcast_to_cuda(shape)
                        .expect("CUDA broadcast_to failed"),
                )
            };
        }

        // If the shapes are not broadcast-compatible, fall back to using the
        // requested shape directly instead of returning an error.
        let result_shape = broadcast_shape(&self.shape, shape).unwrap_or_else(|_| shape.into());
        // broadcast_strides presumably maps broadcast dimensions to stride 0
        // so the same source element is reread — TODO(review): confirm.
        let self_strides = broadcast_strides(&self.shape, &self.strides, &result_shape);

        let total = numel(&result_shape);
        let mut result_data = vec![T::zero(); total];
        let self_data = self.storage.as_slice();

        // Gather: map each output coordinate back to its source element.
        for i in 0..total {
            let indices = crate::shape::unravel_index(i, &result_shape);
            let self_idx = self.offset + linear_index(&indices, &self_strides);
            result_data[i] = self_data[self_idx];
        }

        Self::from_vec(result_data, &result_shape).unwrap()
    }
1200
    /// Slices the tensor using ranges for each dimension.
    ///
    /// Dimensions beyond `ranges.len()` are kept in full; extra ranges past
    /// `ndim` are ignored. Ranges are not bounds-checked here — an
    /// out-of-range end panics during the element copy.
    #[must_use]
    pub fn slice(&self, ranges: &[std::ops::Range<usize>]) -> Self {
        // Output shape: the length of each provided range, then any
        // trailing dimensions unchanged.
        let mut new_shape = Vec::with_capacity(self.ndim());
        for (i, range) in ranges.iter().enumerate() {
            if i < self.ndim() {
                new_shape.push(range.end - range.start);
            }
        }
        // Keep remaining dimensions unchanged
        for i in ranges.len()..self.ndim() {
            new_shape.push(self.shape[i]);
        }

        let new_numel: usize = new_shape.iter().product();
        let mut result_data = vec![T::zero(); new_numel];
        let self_data = self.to_vec();

        // Copy data with proper indexing
        let mut result_idx = 0;
        Self::slice_recursive(
            &self_data,
            &self.shape,
            ranges,
            0,
            0,
            &mut result_data,
            &mut result_idx,
        );

        let out = Self::from_vec(result_data, &new_shape).unwrap();
        // Slicing runs on the host; move the result back to the source
        // device when the input lived on the GPU.
        #[cfg(feature = "cuda")]
        if self.device().is_gpu() {
            return out.to_device(self.device()).unwrap();
        }
        out
    }
1238
    /// Recursive helper for [`Self::slice`]: walks the dimensions depth-first,
    /// copying each selected element from `data` into `result` in row-major
    /// order.
    ///
    /// * `dim` — current dimension being iterated.
    /// * `offset` — flat index into `data` accumulated from outer dimensions.
    /// * `result_idx` — next write position in `result` (advanced in place).
    fn slice_recursive(
        data: &[T],
        shape: &[usize],
        ranges: &[std::ops::Range<usize>],
        dim: usize,
        offset: usize,
        result: &mut [T],
        result_idx: &mut usize,
    ) {
        // Base case: every dimension is fixed — copy a single element.
        if dim == shape.len() {
            result[*result_idx] = data[offset];
            *result_idx += 1;
            return;
        }

        // Row-major stride of the current dimension.
        let stride: usize = shape[dim + 1..].iter().product();
        // A missing range selects the whole dimension.
        let (start, end) = if dim < ranges.len() {
            (ranges[dim].start, ranges[dim].end)
        } else {
            (0, shape[dim])
        };

        for i in start..end {
            Self::slice_recursive(
                data,
                shape,
                ranges,
                dim + 1,
                offset + i * stride,
                result,
                result_idx,
            );
        }
    }
1273}
1274
1275// =============================================================================
1276// Arithmetic Operator Implementations
1277// =============================================================================
1278
1279impl<T: Numeric> Tensor<T> {
1280    /// Element-wise addition with broadcasting.
1281    pub fn add(&self, other: &Self) -> Result<Self> {
1282        #[cfg(feature = "cuda")]
1283        {
1284            let self_gpu = self.device().is_gpu();
1285            let other_gpu = other.device().is_gpu();
1286            if self_gpu || other_gpu {
1287                assert!(is_f32::<T>(), "GPU tensors are only supported for f32");
1288                if self_gpu && other_gpu {
1289                    let (s, o) = unsafe { (gpu_ref(self), gpu_ref(other)) };
1290                    if self.shape == other.shape {
1291                        return Ok(unsafe { gpu_into(s.add_cuda(o)?) });
1292                    } else {
1293                        return Ok(unsafe { gpu_into(s.broadcast_add_cuda(o)?) });
1294                    }
1295                }
1296                // Mixed device — move to GPU, then operate
1297                let target_device = if self_gpu {
1298                    self.device()
1299                } else {
1300                    other.device()
1301                };
1302                let a_gpu = if self_gpu {
1303                    self.clone()
1304                } else {
1305                    self.to_device(target_device)?
1306                };
1307                let b_gpu = if other_gpu {
1308                    other.clone()
1309                } else {
1310                    other.to_device(target_device)?
1311                };
1312                return a_gpu.add(&b_gpu);
1313            }
1314        }
1315        let result_shape = broadcast_shape(&self.shape, &other.shape)?;
1316        let self_strides = broadcast_strides(&self.shape, &self.strides, &result_shape);
1317        let other_strides = broadcast_strides(&other.shape, &other.strides, &result_shape);
1318
1319        let total = numel(&result_shape);
1320        let mut result_data = vec![T::zero(); total];
1321
1322        let self_data = self.storage.as_slice();
1323        let other_data = other.storage.as_slice();
1324
1325        for i in 0..total {
1326            let indices = crate::shape::unravel_index(i, &result_shape);
1327            let self_idx = self.offset + linear_index(&indices, &self_strides);
1328            let other_idx = other.offset + linear_index(&indices, &other_strides);
1329            result_data[i] = self_data[self_idx] + other_data[other_idx];
1330        }
1331
1332        Self::from_vec(result_data, &result_shape)
1333    }
1334
1335    /// Element-wise subtraction with broadcasting.
1336    pub fn sub(&self, other: &Self) -> Result<Self> {
1337        #[cfg(feature = "cuda")]
1338        {
1339            let self_gpu = self.device().is_gpu();
1340            let other_gpu = other.device().is_gpu();
1341            if self_gpu || other_gpu {
1342                assert!(is_f32::<T>(), "GPU tensors are only supported for f32");
1343                if self_gpu && other_gpu {
1344                    let (s, o) = unsafe { (gpu_ref(self), gpu_ref(other)) };
1345                    if self.shape == other.shape {
1346                        return Ok(unsafe { gpu_into(s.sub_cuda(o)?) });
1347                    } else {
1348                        return Ok(unsafe { gpu_into(s.broadcast_sub_cuda(o)?) });
1349                    }
1350                }
1351                let target = if self_gpu {
1352                    self.device()
1353                } else {
1354                    other.device()
1355                };
1356                let a_gpu = if self_gpu {
1357                    self.clone()
1358                } else {
1359                    self.to_device(target)?
1360                };
1361                let b_gpu = if other_gpu {
1362                    other.clone()
1363                } else {
1364                    other.to_device(target)?
1365                };
1366                return a_gpu.sub(&b_gpu);
1367            }
1368        }
1369        let result_shape = broadcast_shape(&self.shape, &other.shape)?;
1370        let self_strides = broadcast_strides(&self.shape, &self.strides, &result_shape);
1371        let other_strides = broadcast_strides(&other.shape, &other.strides, &result_shape);
1372
1373        let total = numel(&result_shape);
1374        let mut result_data = vec![T::zero(); total];
1375
1376        let self_data = self.storage.as_slice();
1377        let other_data = other.storage.as_slice();
1378
1379        for i in 0..total {
1380            let indices = crate::shape::unravel_index(i, &result_shape);
1381            let self_idx = self.offset + linear_index(&indices, &self_strides);
1382            let other_idx = other.offset + linear_index(&indices, &other_strides);
1383            result_data[i] = self_data[self_idx] - other_data[other_idx];
1384        }
1385
1386        Self::from_vec(result_data, &result_shape)
1387    }
1388
1389    /// Element-wise multiplication with broadcasting.
1390    pub fn mul(&self, other: &Self) -> Result<Self> {
1391        #[cfg(feature = "cuda")]
1392        {
1393            let self_gpu = self.device().is_gpu();
1394            let other_gpu = other.device().is_gpu();
1395            if self_gpu || other_gpu {
1396                assert!(is_f32::<T>(), "GPU tensors are only supported for f32");
1397                if self_gpu && other_gpu {
1398                    let (s, o) = unsafe { (gpu_ref(self), gpu_ref(other)) };
1399                    if self.shape == other.shape {
1400                        return Ok(unsafe { gpu_into(s.mul_cuda(o)?) });
1401                    } else {
1402                        return Ok(unsafe { gpu_into(s.broadcast_mul_cuda(o)?) });
1403                    }
1404                }
1405                let target = if self_gpu {
1406                    self.device()
1407                } else {
1408                    other.device()
1409                };
1410                let a_gpu = if self_gpu {
1411                    self.clone()
1412                } else {
1413                    self.to_device(target)?
1414                };
1415                let b_gpu = if other_gpu {
1416                    other.clone()
1417                } else {
1418                    other.to_device(target)?
1419                };
1420                return a_gpu.mul(&b_gpu);
1421            }
1422        }
1423        let result_shape = broadcast_shape(&self.shape, &other.shape)?;
1424        let self_strides = broadcast_strides(&self.shape, &self.strides, &result_shape);
1425        let other_strides = broadcast_strides(&other.shape, &other.strides, &result_shape);
1426
1427        let total = numel(&result_shape);
1428        let mut result_data = vec![T::zero(); total];
1429
1430        let self_data = self.storage.as_slice();
1431        let other_data = other.storage.as_slice();
1432
1433        for i in 0..total {
1434            let indices = crate::shape::unravel_index(i, &result_shape);
1435            let self_idx = self.offset + linear_index(&indices, &self_strides);
1436            let other_idx = other.offset + linear_index(&indices, &other_strides);
1437            result_data[i] = self_data[self_idx] * other_data[other_idx];
1438        }
1439
1440        Self::from_vec(result_data, &result_shape)
1441    }
1442
1443    /// Element-wise division with broadcasting.
1444    pub fn div(&self, other: &Self) -> Result<Self> {
1445        #[cfg(feature = "cuda")]
1446        {
1447            let self_gpu = self.device().is_gpu();
1448            let other_gpu = other.device().is_gpu();
1449            if self_gpu || other_gpu {
1450                assert!(is_f32::<T>(), "GPU tensors are only supported for f32");
1451                if self_gpu && other_gpu {
1452                    let (s, o) = unsafe { (gpu_ref(self), gpu_ref(other)) };
1453                    if self.shape == other.shape {
1454                        return Ok(unsafe { gpu_into(s.div_cuda(o)?) });
1455                    } else {
1456                        return Ok(unsafe { gpu_into(s.broadcast_div_cuda(o)?) });
1457                    }
1458                }
1459                let target = if self_gpu {
1460                    self.device()
1461                } else {
1462                    other.device()
1463                };
1464                let a_gpu = if self_gpu {
1465                    self.clone()
1466                } else {
1467                    self.to_device(target)?
1468                };
1469                let b_gpu = if other_gpu {
1470                    other.clone()
1471                } else {
1472                    other.to_device(target)?
1473                };
1474                return a_gpu.div(&b_gpu);
1475            }
1476        }
1477        let result_shape = broadcast_shape(&self.shape, &other.shape)?;
1478        let self_strides = broadcast_strides(&self.shape, &self.strides, &result_shape);
1479        let other_strides = broadcast_strides(&other.shape, &other.strides, &result_shape);
1480
1481        let total = numel(&result_shape);
1482        let mut result_data = vec![T::zero(); total];
1483
1484        let self_data = self.storage.as_slice();
1485        let other_data = other.storage.as_slice();
1486
1487        for i in 0..total {
1488            let indices = crate::shape::unravel_index(i, &result_shape);
1489            let self_idx = self.offset + linear_index(&indices, &self_strides);
1490            let other_idx = other.offset + linear_index(&indices, &other_strides);
1491            result_data[i] = self_data[self_idx] / other_data[other_idx];
1492        }
1493
1494        Self::from_vec(result_data, &result_shape)
1495    }
1496
1497    /// Scalar addition.
1498    #[must_use]
1499    pub fn add_scalar(&self, scalar: T) -> Self {
1500        #[cfg(feature = "cuda")]
1501        if self.device().is_gpu() {
1502            assert!(is_f32::<T>(), "GPU tensors are only supported for f32");
1503            let self_f32 = unsafe { gpu_ref(self) };
1504            let scalar_f32: f32 = unsafe { *(&scalar as *const T as *const f32) };
1505            return unsafe { gpu_into(self_f32.add_scalar_cuda(scalar_f32)) };
1506        }
1507        let data = self.to_vec();
1508        let mut result = vec![T::zero(); data.len()];
1509        CpuBackend::add_scalar(&mut result, &data, scalar);
1510        Self::from_vec(result, &self.shape).unwrap()
1511    }
1512
1513    /// Scalar multiplication.
1514    #[must_use]
1515    pub fn mul_scalar(&self, scalar: T) -> Self {
1516        #[cfg(feature = "cuda")]
1517        if self.device().is_gpu() {
1518            assert!(is_f32::<T>(), "GPU tensors are only supported for f32");
1519            let self_f32 = unsafe { gpu_ref(self) };
1520            let scalar_f32: f32 = unsafe { *(&scalar as *const T as *const f32) };
1521            return unsafe { gpu_into(self_f32.mul_scalar_cuda(scalar_f32)) };
1522        }
1523        let data = self.to_vec();
1524        let mut result = vec![T::zero(); data.len()];
1525        CpuBackend::mul_scalar(&mut result, &data, scalar);
1526        Self::from_vec(result, &self.shape).unwrap()
1527    }
1528
1529    /// Element-wise negation.
1530    #[must_use]
1531    pub fn neg(&self) -> Self {
1532        #[cfg(feature = "cuda")]
1533        if self.device().is_gpu() {
1534            assert!(is_f32::<T>(), "GPU tensors are only supported for f32");
1535            let self_f32 = unsafe { gpu_ref(self) };
1536            return unsafe { gpu_into(self_f32.neg_cuda()) };
1537        }
1538        let data = self.to_vec();
1539        let mut result = vec![T::zero(); data.len()];
1540        CpuBackend::neg(&mut result, &data);
1541        Self::from_vec(result, &self.shape).unwrap()
1542    }
1543
1544    /// Matrix multiplication with batching support.
1545    ///
1546    /// Supports:
1547    /// - 2D @ 2D: [m, k] @ [k, n] -> [m, n]
1548    /// - 3D @ 3D: [batch, m, k] @ [batch, k, n] -> [batch, m, n]
1549    /// - 4D @ 4D: [b1, b2, m, k] @ [b1, b2, k, n] -> [b1, b2, m, n]
1550    pub fn matmul(&self, other: &Self) -> Result<Self> {
1551        #[cfg(feature = "cuda")]
1552        if self.device().is_gpu() {
1553            assert!(is_f32::<T>(), "GPU tensors are only supported for f32");
1554            let (s, o) = unsafe { (gpu_ref(self), gpu_ref(other)) };
1555            return Ok(unsafe { gpu_into(s.matmul_cuda(o)?) });
1556        }
1557        if self.ndim() < 2 || other.ndim() < 2 {
1558            return Err(Error::invalid_operation(
1559                "matmul requires at least 2D tensors",
1560            ));
1561        }
1562
1563        let m = self.shape[self.ndim() - 2];
1564        let k1 = self.shape[self.ndim() - 1];
1565        let k2 = other.shape[other.ndim() - 2];
1566        let n = other.shape[other.ndim() - 1];
1567
1568        if k1 != k2 {
1569            return Err(Error::invalid_operation(format!(
1570                "matmul inner dimensions must match: {k1} vs {k2}"
1571            )));
1572        }
1573
1574        // For 2D matrices, do simple matmul
1575        if self.ndim() == 2 && other.ndim() == 2 {
1576            let a_data = self.contiguous().to_vec();
1577            let b_data = other.contiguous().to_vec();
1578
1579            // GPU-accelerated matmul for CPU tensors: only for very large matrices
1580            // where transfer overhead is negligible relative to compute.
1581            // For GPU-resident tensors, the dispatch at the top of matmul() handles it.
1582            #[cfg(feature = "cuda")]
1583            {
1584                let flops = m * n * k1;
1585                if std::any::TypeId::of::<T>() == std::any::TypeId::of::<f32>()
1586                    && flops >= 4_000_000
1587                {
1588                    debug_assert!(std::mem::size_of::<T>() == std::mem::size_of::<f32>());
1589                    // SAFETY: T is f32 (checked by TypeId above), same size and layout
1590                    let a_f32: &[f32] = unsafe { std::mem::transmute(a_data.as_slice()) };
1591                    let b_f32: &[f32] = unsafe { std::mem::transmute(b_data.as_slice()) };
1592                    if let Some(c_f32) = cuda_accel::cuda_matmul(a_f32, b_f32, m, n, k1) {
1593                        // SAFETY: T is f32, Vec<f32> → Vec<T> is a no-op transmute
1594                        let c_t: Vec<T> = unsafe {
1595                            let mut v = std::mem::ManuallyDrop::new(c_f32);
1596                            Vec::from_raw_parts(v.as_mut_ptr() as *mut T, v.len(), v.capacity())
1597                        };
1598                        return Self::from_vec(c_t, &[m, n]);
1599                    }
1600                }
1601            }
1602
1603            let mut c_data = vec![T::zero(); m * n];
1604            CpuBackend::matmul(&mut c_data, &a_data, &b_data, m, n, k1);
1605            return Self::from_vec(c_data, &[m, n]);
1606        }
1607
1608        // For batched matmul, compute batch size
1609        let batch_dims_self: Vec<usize> = self.shape[..self.ndim() - 2].to_vec();
1610        let batch_dims_other: Vec<usize> = other.shape[..other.ndim() - 2].to_vec();
1611
1612        // Broadcast batch dimensions (PyTorch parity)
1613        let broadcast_batch = if batch_dims_self != batch_dims_other {
1614            // Pad to same length
1615            let max_len = batch_dims_self.len().max(batch_dims_other.len());
1616            let pad_a = vec![1usize; max_len - batch_dims_self.len()];
1617            let pad_b = vec![1usize; max_len - batch_dims_other.len()];
1618            let a_dims: Vec<usize> = pad_a.iter().chain(batch_dims_self.iter()).copied().collect();
1619            let b_dims: Vec<usize> = pad_b.iter().chain(batch_dims_other.iter()).copied().collect();
1620
1621            let mut out_dims = Vec::with_capacity(max_len);
1622            for i in 0..max_len {
1623                if a_dims[i] == b_dims[i] {
1624                    out_dims.push(a_dims[i]);
1625                } else if a_dims[i] == 1 {
1626                    out_dims.push(b_dims[i]);
1627                } else if b_dims[i] == 1 {
1628                    out_dims.push(a_dims[i]);
1629                } else {
1630                    return Err(Error::invalid_operation(format!(
1631                        "matmul batch dimensions not broadcastable: {:?} vs {:?}",
1632                        batch_dims_self, batch_dims_other
1633                    )));
1634                }
1635            }
1636            Some((a_dims, b_dims, out_dims))
1637        } else {
1638            None
1639        };
1640
1641        let (batch_size, a_batch_idx, b_batch_idx) = if let Some((a_dims, b_dims, out_dims)) = &broadcast_batch {
1642            let bs: usize = out_dims.iter().product();
1643            // Build index mapping: for each output batch, which a and b batch to use
1644            let mut a_idx = Vec::with_capacity(bs);
1645            let mut b_idx = Vec::with_capacity(bs);
1646            for flat in 0..bs {
1647                let mut remaining = flat;
1648                let mut ai = 0usize;
1649                let mut bi = 0usize;
1650                let mut a_stride_acc = 1usize;
1651                let mut b_stride_acc = 1usize;
1652                for d in (0..out_dims.len()).rev() {
1653                    let out_d = out_dims[d];
1654                    let idx = remaining % out_d;
1655                    remaining /= out_d;
1656                    let a_d = a_dims[d];
1657                    let b_d = b_dims[d];
1658                    ai += (idx % a_d) * a_stride_acc;
1659                    bi += (idx % b_d) * b_stride_acc;
1660                    a_stride_acc *= a_d;
1661                    b_stride_acc *= b_d;
1662                }
1663                a_idx.push(ai);
1664                b_idx.push(bi);
1665            }
1666            (bs, a_idx, b_idx)
1667        } else {
1668            let bs: usize = batch_dims_self.iter().product();
1669            let idx: Vec<usize> = (0..bs).collect();
1670            (bs, idx.clone(), idx)
1671        };
1672
1673        let a_stride = m * k1;
1674        let b_stride = k1 * n;
1675        let c_stride = m * n;
1676
1677        let a_data = self.contiguous().to_vec();
1678        let b_data = other.contiguous().to_vec();
1679        let mut c_data = vec![T::zero(); batch_size * m * n];
1680
1681        // Try GPU acceleration for f32 batched matmul (only for large enough matrices)
1682        #[cfg(feature = "cuda")]
1683        {
1684            let flops = m * n * k1;
1685            if std::any::TypeId::of::<T>() == std::any::TypeId::of::<f32>() && flops >= 4_000_000 {
1686                let a_f32: &[f32] = unsafe { std::mem::transmute(a_data.as_slice()) };
1687                let b_f32: &[f32] = unsafe { std::mem::transmute(b_data.as_slice()) };
1688                let mut gpu_ok = true;
1689                for batch in 0..batch_size {
1690                    let ai = a_batch_idx[batch];
1691                    let bi = b_batch_idx[batch];
1692                    let a_slice = &a_f32[ai * a_stride..(ai + 1) * a_stride];
1693                    let b_slice = &b_f32[bi * b_stride..(bi + 1) * b_stride];
1694                    if let Some(c_batch) = cuda_accel::cuda_matmul(a_slice, b_slice, m, n, k1) {
1695                        c_data[batch * c_stride..(batch + 1) * c_stride]
1696                            .copy_from_slice(unsafe { std::mem::transmute(c_batch.as_slice()) });
1697                    } else {
1698                        gpu_ok = false;
1699                        break;
1700                    }
1701                }
1702                if gpu_ok {
1703                    let mut output_shape = batch_dims_self;
1704                    output_shape.push(m);
1705                    output_shape.push(n);
1706                    return Self::from_vec(c_data, &output_shape);
1707                }
1708                // Fall through to CPU if GPU failed
1709                c_data = vec![T::zero(); batch_size * m * n];
1710            }
1711        }
1712
1713        // CPU fallback: loop over batches
1714        for batch in 0..batch_size {
1715            let ai = a_batch_idx[batch];
1716            let bi = b_batch_idx[batch];
1717            let a_slice = &a_data[ai * a_stride..(ai + 1) * a_stride];
1718            let b_slice = &b_data[bi * b_stride..(bi + 1) * b_stride];
1719            let c_slice = &mut c_data[batch * c_stride..(batch + 1) * c_stride];
1720            CpuBackend::matmul(c_slice, a_slice, b_slice, m, n, k1);
1721        }
1722
1723        // Build output shape: broadcast batch dims + [m, n]
1724        let mut output_shape = if let Some((_, _, ref out_dims)) = broadcast_batch {
1725            out_dims.clone()
1726        } else {
1727            batch_dims_self
1728        };
1729        output_shape.push(m);
1730        output_shape.push(n);
1731
1732        Self::from_vec(c_data, &output_shape)
1733    }
1734
1735    /// Dot product for 1D tensors.
1736    pub fn dot(&self, other: &Self) -> Result<Self> {
1737        if self.ndim() != 1 || other.ndim() != 1 {
1738            return Err(Error::invalid_operation("dot requires 1D tensors"));
1739        }
1740
1741        if self.shape[0] != other.shape[0] {
1742            return Err(Error::shape_mismatch(&self.shape, &other.shape));
1743        }
1744
1745        let a_data = self.to_vec();
1746        let b_data = other.to_vec();
1747        let result = CpuBackend::dot(&a_data, &b_data);
1748
1749        Ok(Self::scalar(result))
1750    }
1751}
1752
1753// =============================================================================
1754// Operator Trait Implementations
1755// =============================================================================
1756
1757impl<T: Numeric> Add for &Tensor<T> {
1758    type Output = Tensor<T>;
1759
1760    fn add(self, other: Self) -> Self::Output {
1761        self.add(other).expect("Addition failed")
1762    }
1763}
1764
1765impl<T: Numeric> Sub for &Tensor<T> {
1766    type Output = Tensor<T>;
1767
1768    fn sub(self, other: Self) -> Self::Output {
1769        self.sub(other).expect("Subtraction failed")
1770    }
1771}
1772
1773impl<T: Numeric> Mul for &Tensor<T> {
1774    type Output = Tensor<T>;
1775
1776    fn mul(self, other: Self) -> Self::Output {
1777        self.mul(other).expect("Multiplication failed")
1778    }
1779}
1780
1781impl<T: Numeric> Div for &Tensor<T> {
1782    type Output = Tensor<T>;
1783
1784    fn div(self, other: Self) -> Self::Output {
1785        self.div(other).expect("Division failed")
1786    }
1787}
1788
1789impl<T: Numeric> Neg for &Tensor<T> {
1790    type Output = Tensor<T>;
1791
1792    fn neg(self) -> Self::Output {
1793        self.neg()
1794    }
1795}
1796
1797// Scalar operations
1798impl<T: Numeric> Add<T> for &Tensor<T> {
1799    type Output = Tensor<T>;
1800
1801    fn add(self, scalar: T) -> Self::Output {
1802        self.add_scalar(scalar)
1803    }
1804}
1805
1806impl<T: Numeric> Mul<T> for &Tensor<T> {
1807    type Output = Tensor<T>;
1808
1809    fn mul(self, scalar: T) -> Self::Output {
1810        self.mul_scalar(scalar)
1811    }
1812}
1813
1814// =============================================================================
1815// Display Implementation
1816// =============================================================================
1817
1818impl<T: Scalar + fmt::Display> fmt::Debug for Tensor<T> {
1819    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
1820        write!(
1821            f,
1822            "Tensor(shape={:?}, device={}",
1823            self.shape(),
1824            self.device()
1825        )?;
1826        if self.numel() <= 10 {
1827            write!(f, ", data={:?}", self.to_vec())?;
1828        }
1829        write!(f, ")")
1830    }
1831}
1832
1833impl<T: Scalar + fmt::Display> fmt::Display for Tensor<T> {
1834    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
1835        if self.is_scalar() {
1836            write!(f, "{}", self.item().unwrap())
1837        } else if self.ndim() == 1 {
1838            write!(f, "[")?;
1839            let data = self.to_vec();
1840            for (i, val) in data.iter().enumerate() {
1841                if i > 0 {
1842                    write!(f, ", ")?;
1843                }
1844                write!(f, "{val}")?;
1845            }
1846            write!(f, "]")
1847        } else {
1848            write!(f, "Tensor(shape={:?})", self.shape())
1849        }
1850    }
1851}
1852
1853// =============================================================================
1854// f32 ↔ f16 Casting for AMP (Automatic Mixed Precision)
1855// =============================================================================
1856
1857impl Tensor<f32> {
1858    /// Cast this f32 tensor to f16 values stored as f32.
1859    ///
1860    /// Each value is rounded to f16 precision. This simulates half-precision
1861    /// computation while keeping the tensor type as f32, which is how AMP
1862    /// works — the autograd graph stays f32 but computation uses f16 precision.
1863    ///
1864    /// On GPU, this uses a CUDA kernel for fast conversion.
1865    /// On CPU, this uses the `half` crate.
1866    #[must_use]
1867    pub fn to_f16_precision(&self) -> Self {
1868        let data = self.to_vec();
1869        let f16_data: Vec<f32> = data
1870            .iter()
1871            .map(|&v| {
1872                let h = half::f16::from_f32(v);
1873                h.to_f32()
1874            })
1875            .collect();
1876        Self::from_vec(f16_data, self.shape()).unwrap()
1877    }
1878
1879    /// Cast f16-precision values back to full f32 precision.
1880    ///
1881    /// This is a no-op since the data is already stored as f32.
1882    /// Included for API symmetry with `to_f16_precision()`.
1883    #[must_use]
1884    pub fn to_f32_precision(&self) -> Self {
1885        self.clone()
1886    }
1887
1888    /// Returns true if applying f16 precision would change any values.
1889    /// Useful for debugging AMP-related numerical issues.
1890    #[must_use]
1891    pub fn has_f16_rounding_error(&self) -> bool {
1892        let data = self.to_vec();
1893        data.iter().any(|&v| {
1894            let h = half::f16::from_f32(v);
1895            (h.to_f32() - v).abs() > f32::EPSILON
1896        })
1897    }
1898}
1899
1900// =============================================================================
1901// Tests
1902// =============================================================================
1903
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_from_vec() {
        let tensor =
            Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3]).unwrap();
        assert_eq!(tensor.shape(), &[2, 3]);
        assert_eq!(tensor.numel(), 6);
    }

    #[test]
    fn test_get_set() {
        let tensor = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[2, 2]).unwrap();
        // Row-major layout: element [r, c] holds r * 2 + c + 1.
        for (expected, idx) in [(1.0, [0, 0]), (2.0, [0, 1]), (3.0, [1, 0]), (4.0, [1, 1])] {
            assert_eq!(tensor.get(&idx).unwrap(), expected);
        }

        tensor.set(&[0, 0], 99.0).unwrap();
        assert_eq!(tensor.get(&[0, 0]).unwrap(), 99.0);
    }

    #[test]
    fn test_reshape() {
        let tensor =
            Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3]).unwrap();

        let reshaped = tensor.reshape(&[3, 2]).expect("reshape failed");
        assert_eq!(reshaped.shape(), &[3, 2]);

        // -1 asks reshape to infer the dimension.
        let flattened = tensor.reshape(&[-1]).expect("reshape failed");
        assert_eq!(flattened.shape(), &[6]);
    }

    #[test]
    fn test_transpose() {
        let tensor =
            Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3]).unwrap();
        let transposed = tensor.t().unwrap();
        assert_eq!(transposed.shape(), &[3, 2]);
        // out[i, j] == in[j, i]
        assert_eq!(transposed.get(&[0, 0]).unwrap(), 1.0);
        assert_eq!(transposed.get(&[0, 1]).unwrap(), 4.0);
        assert_eq!(transposed.get(&[1, 0]).unwrap(), 2.0);
    }

    #[test]
    fn test_arithmetic() {
        let lhs = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap();
        let rhs = Tensor::<f32>::from_vec(vec![4.0, 5.0, 6.0], &[3]).unwrap();

        assert_eq!((&lhs + &rhs).to_vec(), vec![5.0, 7.0, 9.0]);
        assert_eq!((&lhs * &rhs).to_vec(), vec![4.0, 10.0, 18.0]);
    }

    #[test]
    fn test_broadcasting() {
        let vector = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap();
        let single = Tensor::<f32>::from_vec(vec![10.0], &[1]).unwrap();
        // A length-1 tensor broadcasts against any length.
        assert_eq!((&vector + &single).to_vec(), vec![11.0, 12.0, 13.0]);
    }

    #[test]
    fn test_sum() {
        let tensor = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[4]).unwrap();
        assert_eq!(tensor.sum().item().unwrap(), 10.0);
    }

    #[test]
    fn test_matmul() {
        // [[1, 2], [3, 4]] @ [[5, 6], [7, 8]] == [[19, 22], [43, 50]]
        let lhs = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[2, 2]).unwrap();
        let rhs = Tensor::<f32>::from_vec(vec![5.0, 6.0, 7.0, 8.0], &[2, 2]).unwrap();
        let product = lhs.matmul(&rhs).unwrap();

        assert_eq!(product.shape(), &[2, 2]);
        assert_eq!(product.to_vec(), vec![19.0, 22.0, 43.0, 50.0]);
    }

    #[test]
    fn test_relu() {
        let tensor = Tensor::<f32>::from_vec(vec![-1.0, 0.0, 1.0, 2.0], &[4]).unwrap();
        // Negatives clamp to zero; non-negatives pass through.
        assert_eq!(tensor.relu().to_vec(), vec![0.0, 0.0, 1.0, 2.0]);
    }

    #[test]
    fn test_scalar() {
        let scalar = Tensor::<f32>::scalar(42.0);
        assert!(scalar.is_scalar());
        assert_eq!(scalar.numel(), 1);
        assert_eq!(scalar.item().unwrap(), 42.0);
    }
}