Skip to main content

axonml_tensor/
tensor.rs

1//! Tensor - Core N-Dimensional Array Type
2//!
3//! # File
4//! `crates/axonml-tensor/src/tensor.rs`
5//!
6//! # Author
7//! Andrew Jewell Sr - AutomataNexus
8//!
9//! # Updated
10//! March 8, 2026
11//!
12//! # Disclaimer
13//! Use at own risk. This software is provided "as is", without warranty of any
14//! kind, express or implied. The author and AutomataNexus shall not be held
15//! liable for any damages arising from the use of this software.
16
17use core::fmt;
18use core::ops::{Add, Div, Mul, Neg, Sub};
19
20use axonml_core::Device;
21use axonml_core::backends::CpuBackend;
22#[cfg(feature = "cuda")]
23use axonml_core::backends::CudaBackend;
24use axonml_core::dtype::{Float, Numeric, Scalar};
25use axonml_core::error::{Error, Result};
26use axonml_core::storage::Storage;
27use num_traits::NumCast;
28
29// =============================================================================
30// CUDA Acceleration
31// =============================================================================
32
#[cfg(feature = "cuda")]
mod cuda_accel {
    use super::*;
    use axonml_core::backends::cuda::get_cuda_backend;

    /// Get the global CUDA backend (delegates to core singleton).
    ///
    /// Returns `None` when no CUDA backend is available.
    pub fn get_cuda() -> Option<&'static CudaBackend> {
        get_cuda_backend()
    }

    /// GPU-accelerated matmul: copies data to GPU, runs cuBLAS GEMM, copies back.
    ///
    /// Computes `C(m,n) = A(m,k) @ B(k,n)` for row-major `a` and `b`.
    /// Returns None if GPU is unavailable or an error occurs — errors are
    /// deliberately swallowed (`.ok()?`) so callers can fall back to the CPU path.
    pub fn cuda_matmul(a: &[f32], b: &[f32], m: usize, n: usize, k: usize) -> Option<Vec<f32>> {
        let cuda = get_cuda()?;

        // Host-to-device copies of both operands, plus an output buffer of m*n.
        let a_gpu = cuda.htod_copy(a).ok()?;
        let b_gpu = cuda.htod_copy(b).ok()?;
        let mut c_gpu = cuda.alloc::<f32>(m * n).ok()?;

        // cuBLAS GEMM: C(m,n) = A(m,k) @ B(k,n) in row-major.
        // cuBLAS is column-major, and a row-major matrix is its own transpose
        // in column-major terms: C^T(n,m) = B^T(n,k) @ A^T(k,m).
        // Hence the swapped operand order (b before a) and the leading
        // dimensions n, k, n below.
        cuda.gemm_f32(
            false, false, n, m, k, 1.0, &b_gpu, n, &a_gpu, k, 0.0, &mut c_gpu, n,
        )
        .ok()?;

        // Device-to-host copy of the result.
        cuda.dtoh_copy(&c_gpu).ok()
    }
}
62
63use crate::shape::{
64    Shape, Strides, broadcast_shape, broadcast_strides, contiguous_strides, is_contiguous,
65    linear_index, normalize_dim, numel, reshape, squeeze, transpose_shape, transpose_strides,
66    unsqueeze,
67};
68
69// =============================================================================
70// GPU Dispatch Helpers
71// =============================================================================
72//
73// These enable calling Tensor<f32> GPU methods from generic Tensor<T> code
74// when T is verified to be f32 via TypeId check at runtime.
75
/// Reinterprets a `&Tensor<T>` as a `&Tensor<f32>`.
///
/// # Safety
/// The caller must have verified `T == f32` (e.g. via [`is_f32`]); otherwise
/// the returned reference points at data of the wrong type.
#[cfg(feature = "cuda")]
unsafe fn gpu_ref<T: Scalar>(t: &Tensor<T>) -> &Tensor<f32> {
    // SAFETY: caller guarantees T is f32, so Tensor<T> and Tensor<f32> are
    // the same type and this pointer cast is an identity conversion.
    &*(t as *const Tensor<T> as *const Tensor<f32>)
}
80
/// Converts an owned `Tensor<f32>` into a `Tensor<T>` by bitwise move.
///
/// # Safety
/// The caller must have verified `T == f32`; the bit-pattern move is only
/// sound when the two types are identical.
#[cfg(feature = "cuda")]
unsafe fn gpu_into<T: Scalar>(t: Tensor<f32>) -> Tensor<T> {
    // SAFETY: caller guarantees T is f32. `ptr::read` duplicates the value,
    // so the original must be forgotten to avoid a double drop of the
    // reference-counted storage.
    let out = std::ptr::read(&t as *const Tensor<f32> as *const Tensor<T>);
    std::mem::forget(t);
    out
}
87
/// Runtime check that the generic parameter `T` is exactly `f32`.
///
/// Used to guard the unsafe `gpu_ref`/`gpu_into` reinterpretation helpers.
#[cfg(feature = "cuda")]
fn is_f32<T: 'static>() -> bool {
    use std::any::TypeId;
    TypeId::of::<T>() == TypeId::of::<f32>()
}
92
93// =============================================================================
94// Tensor Struct
95// =============================================================================
96
/// An N-dimensional array of numeric values.
///
/// Tensors are the core data structure for all computations in Axonml.
/// They support arbitrary dimensions, automatic broadcasting, and efficient
/// memory sharing between views.
///
/// `Clone` is shallow: the storage is reference-counted, so clones and views
/// share the same buffer. Use `clone_deep` for an independent copy.
#[derive(Clone)]
pub struct Tensor<T: Scalar> {
    /// Underlying data storage (reference-counted; shared between views).
    pub(crate) storage: Storage<T>,
    /// Shape of the tensor (dimensions).
    pub(crate) shape: Shape,
    /// Strides (in elements) for each dimension; may be non-contiguous for views.
    pub(crate) strides: Strides,
    /// Offset (in elements) into storage where this tensor's data begins (for views).
    pub(crate) offset: usize,
}
113
impl<T: Scalar> Tensor<T> {
    // =========================================================================
    // Constructors
    // =========================================================================

    /// Creates a new tensor from storage with the given shape.
    ///
    /// The resulting tensor is contiguous with zero offset.
    ///
    /// # Arguments
    /// * `storage` - The underlying data storage
    /// * `shape` - Shape of the tensor
    ///
    /// # Returns
    /// New tensor, or error if the shape's element count doesn't match the
    /// storage size.
    pub fn from_storage(storage: Storage<T>, shape: &[usize]) -> Result<Self> {
        let total = numel(shape);
        if total != storage.len() {
            return Err(Error::shape_mismatch(&[storage.len()], shape));
        }

        let shape = Shape::from_slice(shape);
        let strides = contiguous_strides(&shape);

        Ok(Self {
            storage,
            shape,
            strides,
            offset: 0,
        })
    }

    /// Creates a new tensor from a vector with the given shape.
    ///
    /// Takes ownership of `data` and places it in CPU storage.
    ///
    /// # Arguments
    /// * `data` - Vector of data
    /// * `shape` - Shape of the tensor
    ///
    /// # Returns
    /// New tensor, or error if shape doesn't match data length.
    pub fn from_vec(data: Vec<T>, shape: &[usize]) -> Result<Self> {
        let storage = Storage::from_vec(data, Device::Cpu);
        Self::from_storage(storage, shape)
    }

    /// Creates a new tensor from a slice with the given shape.
    ///
    /// # Arguments
    /// * `data` - Slice of data to copy
    /// * `shape` - Shape of the tensor
    ///
    /// # Returns
    /// New tensor, or error if shape doesn't match data length.
    pub fn from_slice(data: &[T], shape: &[usize]) -> Result<Self> {
        let storage = Storage::from_slice(data, Device::Cpu);
        Self::from_storage(storage, shape)
    }

    /// Creates a scalar tensor (0-dimensional).
    ///
    /// # Arguments
    /// * `value` - The scalar value
    ///
    /// # Returns
    /// New 0-dimensional tensor (empty shape and strides) holding `value`.
    pub fn scalar(value: T) -> Self {
        Self {
            storage: Storage::from_vec(vec![value], Device::Cpu),
            shape: Shape::new(),
            strides: Strides::new(),
            offset: 0,
        }
    }

    /// Creates a tensor filled with zeros. Delegates to [`crate::creation::zeros`].
    #[must_use]
    pub fn zeros(shape: &[usize]) -> Self {
        crate::creation::zeros(shape)
    }

    /// Creates a tensor filled with ones. Delegates to [`crate::creation::ones`].
    #[must_use]
    pub fn ones(shape: &[usize]) -> Self
    where
        T: Numeric,
    {
        crate::creation::ones(shape)
    }

    /// Creates a tensor filled with a constant value.
    #[must_use]
    pub fn full(shape: &[usize], value: T) -> Self {
        crate::creation::full(shape, value)
    }

    /// Creates a tensor with random values from standard normal distribution.
    #[must_use]
    pub fn randn(shape: &[usize]) -> Self
    where
        T: Float,
        rand_distr::StandardNormal: rand::distributions::Distribution<T>,
    {
        crate::creation::randn(shape)
    }

    /// Creates a tensor with random values from uniform distribution [0, 1).
    #[must_use]
    pub fn rand(shape: &[usize]) -> Self
    where
        T: Float,
        rand::distributions::Standard: rand::distributions::Distribution<T>,
    {
        crate::creation::rand(shape)
    }

    // =========================================================================
    // Properties
    // =========================================================================

    /// Returns the shape of the tensor.
    #[must_use]
    pub fn shape(&self) -> &[usize] {
        &self.shape
    }

    /// Returns the strides of the tensor (in elements, per dimension).
    #[must_use]
    pub fn strides(&self) -> &[isize] {
        &self.strides
    }

    /// Returns the number of dimensions (0 for scalars).
    #[must_use]
    pub fn ndim(&self) -> usize {
        self.shape.len()
    }

    /// Returns the total number of elements.
    #[must_use]
    pub fn numel(&self) -> usize {
        numel(&self.shape)
    }

    /// Returns true if the tensor is empty (has zero elements).
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.numel() == 0
    }

    /// Returns the size of a specific dimension.
    ///
    /// # Arguments
    /// * `dim` - Dimension index (supports negative indexing, e.g. -1 = last)
    ///
    /// # Errors
    /// Returns an error if `dim` is out of range for this tensor's rank.
    pub fn size(&self, dim: i64) -> Result<usize> {
        let idx = normalize_dim(dim, self.ndim())?;
        Ok(self.shape[idx])
    }

    /// Returns the device this tensor is on.
    #[must_use]
    pub fn device(&self) -> Device {
        self.storage.device()
    }

    /// Returns true if the tensor is contiguous in memory.
    #[must_use]
    pub fn is_contiguous(&self) -> bool {
        is_contiguous(&self.shape, &self.strides)
    }

    /// Returns true if this tensor is a scalar (0-dimensional).
    #[must_use]
    pub fn is_scalar(&self) -> bool {
        self.shape.is_empty()
    }

    // =========================================================================
    // Data Access
    // =========================================================================

    /// Returns the element at the given indices (bounds-checked).
    ///
    /// # Arguments
    /// * `indices` - Multi-dimensional indices, one per dimension
    ///
    /// # Errors
    /// Returns an error if the number of indices doesn't match `ndim()`, or
    /// if any index is out of bounds for its dimension.
    pub fn get(&self, indices: &[usize]) -> Result<T> {
        if indices.len() != self.ndim() {
            return Err(Error::invalid_operation(format!(
                "Expected {} indices, got {}",
                self.ndim(),
                indices.len()
            )));
        }

        for (&idx, &dim) in indices.iter().zip(self.shape.iter()) {
            if idx >= dim {
                return Err(Error::IndexOutOfBounds {
                    index: idx,
                    size: dim,
                });
            }
        }

        // Translate the multi-dimensional index to a flat storage offset.
        let offset = self.offset + linear_index(indices, &self.strides);
        Ok(self.storage.as_slice()[offset])
    }

    /// Sets the element at the given indices (bounds-checked).
    ///
    /// Note: takes `&self` — mutation goes through the storage's interior
    /// mutability, so writes are visible to all views sharing this storage.
    ///
    /// # Arguments
    /// * `indices` - Multi-dimensional indices, one per dimension
    /// * `value` - Value to set
    ///
    /// # Errors
    /// Returns an error if the number of indices doesn't match `ndim()`, or
    /// if any index is out of bounds for its dimension.
    pub fn set(&self, indices: &[usize], value: T) -> Result<()> {
        if indices.len() != self.ndim() {
            return Err(Error::invalid_operation(format!(
                "Expected {} indices, got {}",
                self.ndim(),
                indices.len()
            )));
        }

        for (&idx, &dim) in indices.iter().zip(self.shape.iter()) {
            if idx >= dim {
                return Err(Error::IndexOutOfBounds {
                    index: idx,
                    size: dim,
                });
            }
        }

        let offset = self.offset + linear_index(indices, &self.strides);
        self.storage.as_slice_mut()[offset] = value;
        Ok(())
    }

    /// Returns the value of a single-element tensor.
    ///
    /// Works for any tensor with exactly one element — 0-dimensional scalars
    /// as well as shapes like `[1]` or `[1, 1]`.
    ///
    /// # Errors
    /// Returns an error if the tensor does not have exactly one element.
    pub fn item(&self) -> Result<T> {
        if self.numel() != 1 {
            return Err(Error::invalid_operation(
                "item() only works on single-element tensors",
            ));
        }

        if self.is_scalar() {
            Ok(self.storage.as_slice()[self.offset])
        } else {
            // Single element but not scalar shape: read index [0, 0, ...].
            let indices = vec![0; self.ndim()];
            self.get(&indices)
        }
    }

    /// Returns the data as an owned, contiguous vector (always a copy).
    ///
    /// Contiguous tensors are copied with a single slice copy; non-contiguous
    /// tensors are gathered element by element in row-major order.
    /// For GPU tensors (f32 only), performs a D2H copy.
    #[must_use]
    pub fn to_vec(&self) -> Vec<T> {
        // GPU path: GPU storage is always f32
        #[cfg(feature = "cuda")]
        if self.storage.is_gpu() {
            assert!(is_f32::<T>(), "GPU tensors are only supported for f32");
            // SAFETY: the assert above guarantees T == f32, so viewing the
            // tensor as Tensor<f32> is an identity conversion.
            let self_f32 = unsafe { gpu_ref(self) };
            let f32_vec = self_f32.to_vec_gpu();
            unsafe {
                // SAFETY: T == f32, so this reinterprets Vec<f32> as Vec<T>
                // without dropping the original buffer (ManuallyDrop prevents
                // a double free).
                let mut v = std::mem::ManuallyDrop::new(f32_vec);
                return Vec::from_raw_parts(v.as_mut_ptr() as *mut T, v.len(), v.capacity());
            }
        }

        if self.is_contiguous() {
            let storage = self.storage.as_slice();
            storage[self.offset..self.offset + self.numel()].to_vec()
        } else {
            let mut result = Vec::with_capacity(self.numel());
            self.copy_data_to(&mut result);
            result
        }
    }

    /// Copies data to a vector in row-major order, handling non-contiguous
    /// layouts by resolving each element's storage offset individually.
    fn copy_data_to(&self, dst: &mut Vec<T>) {
        dst.clear();
        let storage = self.storage.as_slice();

        // Iterate through all logical indices in row-major order.
        let total = self.numel();
        for i in 0..total {
            let indices = crate::shape::unravel_index(i, &self.shape);
            let offset = self.offset + linear_index(&indices, &self.strides);
            dst.push(storage[offset]);
        }
    }

    // =========================================================================
    // Shape Operations
    // =========================================================================

    /// Returns a new tensor with the specified shape.
    ///
    /// The total number of elements must remain the same.
    /// Supports -1 in one dimension to infer the size.
    ///
    /// Contiguous tensors are reshaped without copying (storage is shared);
    /// non-contiguous tensors are first materialized contiguously.
    ///
    /// # Arguments
    /// * `new_shape` - Target shape
    pub fn reshape(&self, new_shape: &[isize]) -> Result<Self> {
        let shape = reshape(&self.shape, new_shape)?;

        if self.is_contiguous() {
            // Can just change shape without copying
            Ok(Self {
                storage: self.storage.clone(),
                strides: contiguous_strides(&shape),
                shape,
                offset: self.offset,
            })
        } else {
            // Need to make contiguous first
            let contig = self.contiguous();
            Ok(Self {
                storage: contig.storage,
                strides: contiguous_strides(&shape),
                shape,
                offset: 0,
            })
        }
    }

    /// Returns a new tensor flattened to one dimension.
    #[must_use]
    pub fn flatten(&self) -> Self {
        self.reshape(&[-1]).expect("Flatten should never fail")
    }

    /// Returns a new tensor with dimensions of size 1 removed.
    ///
    /// The result is a view sharing this tensor's storage.
    ///
    /// # Arguments
    /// * `dim` - Optional specific dimension to squeeze; `None` removes all
    ///   size-1 dimensions
    pub fn squeeze(&self, dim: Option<i64>) -> Result<Self> {
        let dim = match dim {
            Some(d) => Some(normalize_dim(d, self.ndim())?),
            None => None,
        };

        let new_shape = squeeze(&self.shape, dim);
        let new_strides: Strides = match dim {
            Some(d) => {
                // Only drop the stride if that dimension actually has size 1
                // (squeezing a non-size-1 dim is a no-op).
                let mut s = self.strides.clone();
                if d < self.shape.len() && self.shape[d] == 1 {
                    s.remove(d);
                }
                s
            }
            None => self
                .shape
                .iter()
                .zip(self.strides.iter())
                .filter(|(dim, _)| **dim != 1)
                .map(|(_, stride)| *stride)
                .collect(),
        };

        Ok(Self {
            storage: self.storage.clone(),
            shape: new_shape,
            strides: new_strides,
            offset: self.offset,
        })
    }

    /// Returns a new tensor with a dimension of size 1 inserted.
    ///
    /// The result is a view sharing this tensor's storage.
    ///
    /// # Arguments
    /// * `dim` - Position to insert the new dimension (negative values count
    ///   from the end, so -1 appends after the last dimension)
    pub fn unsqueeze(&self, dim: i64) -> Result<Self> {
        // ndim + 1 valid insertion points, hence the "+ 1" when normalizing.
        let normalized = if dim < 0 {
            (dim + self.ndim() as i64 + 1) as usize
        } else {
            dim as usize
        };

        let new_shape = unsqueeze(&self.shape, normalized)?;
        let mut new_strides = Strides::with_capacity(new_shape.len());

        for (i, _) in new_shape.iter().enumerate() {
            if i < normalized {
                new_strides.push(self.strides.get(i).copied().unwrap_or(1));
            } else if i == normalized {
                // Stride for new dimension (doesn't matter since size is 1)
                new_strides.push(1);
            } else {
                new_strides.push(self.strides[i - 1]);
            }
        }

        Ok(Self {
            storage: self.storage.clone(),
            shape: new_shape,
            strides: new_strides,
            offset: self.offset,
        })
    }

    /// Transposes two dimensions (returns a view; no data is moved).
    ///
    /// # Arguments
    /// * `dim0` - First dimension (supports negative indexing)
    /// * `dim1` - Second dimension (supports negative indexing)
    pub fn transpose(&self, dim0: i64, dim1: i64) -> Result<Self> {
        let d0 = normalize_dim(dim0, self.ndim())?;
        let d1 = normalize_dim(dim1, self.ndim())?;

        let new_shape = transpose_shape(&self.shape, d0, d1)?;
        let new_strides = transpose_strides(&self.strides, d0, d1);

        Ok(Self {
            storage: self.storage.clone(),
            shape: new_shape,
            strides: new_strides,
            offset: self.offset,
        })
    }

    /// Returns the transpose of a 2D tensor.
    ///
    /// # Errors
    /// Returns an error if the tensor is not exactly 2-dimensional.
    pub fn t(&self) -> Result<Self> {
        if self.ndim() != 2 {
            return Err(Error::invalid_operation("t() only works on 2D tensors"));
        }
        self.transpose(0, 1)
    }

    /// Returns a permuted tensor with dimensions reordered (a view).
    ///
    /// # Arguments
    /// * `dims` - New order of dimensions; must be a permutation of `0..ndim`
    ///
    /// # Errors
    /// Returns an error if `dims` has the wrong length, contains an
    /// out-of-range index, or repeats a dimension.
    pub fn permute(&self, dims: &[usize]) -> Result<Self> {
        if dims.len() != self.ndim() {
            return Err(Error::invalid_operation(format!(
                "Expected {} dimensions, got {}",
                self.ndim(),
                dims.len()
            )));
        }

        // Check that dims is a permutation
        let mut seen = vec![false; self.ndim()];
        for &d in dims {
            if d >= self.ndim() {
                return Err(Error::InvalidDimension {
                    index: d as i64,
                    ndim: self.ndim(),
                });
            }
            if seen[d] {
                return Err(Error::invalid_operation("Duplicate dimension in permute"));
            }
            seen[d] = true;
        }

        let new_shape: Shape = dims.iter().map(|&d| self.shape[d]).collect();
        let new_strides: Strides = dims.iter().map(|&d| self.strides[d]).collect();

        Ok(Self {
            storage: self.storage.clone(),
            shape: new_shape,
            strides: new_strides,
            offset: self.offset,
        })
    }

    /// Returns a contiguous version of the tensor.
    ///
    /// If the tensor is already contiguous with zero offset, this is a cheap
    /// clone that shares storage; otherwise the data is copied into fresh,
    /// contiguous storage.
    #[must_use]
    pub fn contiguous(&self) -> Self {
        if self.is_contiguous() && self.offset == 0 {
            return self.clone();
        }

        #[cfg(feature = "cuda")]
        if self.storage.is_gpu() {
            assert!(is_f32::<T>(), "GPU tensors are only supported for f32");
            // SAFETY: the assert above guarantees T == f32, making the
            // gpu_ref/gpu_into reinterpretations identity conversions.
            let self_f32 = unsafe { gpu_ref(self) };
            let result = self_f32.contiguous_gpu();
            return unsafe { gpu_into(result) };
        }

        let data = self.to_vec();
        Self::from_vec(data, &self.shape).expect("Contiguous should never fail")
    }

    // =========================================================================
    // Device Operations
    // =========================================================================

    /// Transfers the tensor to a different device.
    ///
    /// Returns a cheap clone when the tensor is already on `device`.
    /// Non-contiguous tensors are made contiguous before the transfer.
    ///
    /// # Arguments
    /// * `device` - Target device
    pub fn to_device(&self, device: Device) -> Result<Self> {
        if self.device() == device {
            return Ok(self.clone());
        }

        #[cfg(feature = "cuda")]
        if self.storage.is_gpu() || device.is_gpu() {
            assert!(is_f32::<T>(), "GPU tensors are only supported for f32");
            // SAFETY: the assert above guarantees T == f32.
            let self_f32 = unsafe { gpu_ref(self) };
            let result = self_f32.to_device_f32(device)?;
            return Ok(unsafe { gpu_into(result) });
        }

        let contig = self.contiguous();
        let new_storage = contig.storage.to_device(device)?;

        Ok(Self {
            storage: new_storage,
            shape: self.shape.clone(),
            strides: self.strides.clone(),
            offset: 0,
        })
    }

    /// Transfers to CPU.
    pub fn cpu(&self) -> Result<Self> {
        self.to_device(Device::Cpu)
    }

    // =========================================================================
    // Deep Copy
    // =========================================================================

    /// Creates a deep copy of this tensor with its own storage.
    ///
    /// The data is materialized contiguously on the CPU, then moved back to
    /// the source device if the source lived on the GPU.
    #[must_use]
    pub fn clone_deep(&self) -> Self {
        let data = self.to_vec();
        let cpu = Self::from_vec(data, &self.shape).expect("Deep clone should never fail");
        #[cfg(feature = "cuda")]
        if self.device().is_gpu() {
            return cpu.to_device(self.device()).unwrap();
        }
        cpu
    }
}
654
655// =============================================================================
656// Numeric Operations
657// =============================================================================
658
impl<T: Numeric> Tensor<T> {
    /// Fills the tensor with a value (in place, via storage interior mutability).
    ///
    /// # Panics
    /// Panics if the tensor lives on the GPU — GPU storage cannot be filled
    /// in place; create a new tensor instead.
    pub fn fill_(&self, value: T) {
        #[cfg(feature = "cuda")]
        if self.storage.is_gpu() {
            panic!("fill_() not supported on GPU tensors — create a new tensor instead");
        }
        let mut data = self.storage.as_slice_mut();
        CpuBackend::fill(&mut data, value);
    }

    /// Fills the tensor with zeros (in place). See [`Self::fill_`] for panics.
    pub fn zero_(&self) {
        self.fill_(T::zero());
    }

    // =========================================================================
    // Reduction Operations
    // =========================================================================
    //
    // All reductions compute on the CPU (GPU tensors round-trip through host
    // memory via `to_vec`) and return a scalar tensor on the input's device.

    /// Returns the sum of all elements as a scalar tensor.
    #[must_use]
    pub fn sum(&self) -> Self {
        let data = self.to_vec();
        let result = CpuBackend::sum(&data);
        let s = Self::scalar(result);
        #[cfg(feature = "cuda")]
        if self.device().is_gpu() {
            // Keep the result on the same device as the input.
            return s.to_device(self.device()).unwrap();
        }
        s
    }

    /// Returns the product of all elements as a scalar tensor.
    #[must_use]
    pub fn prod(&self) -> Self {
        let data = self.to_vec();
        let result = CpuBackend::prod(&data);
        let s = Self::scalar(result);
        #[cfg(feature = "cuda")]
        if self.device().is_gpu() {
            return s.to_device(self.device()).unwrap();
        }
        s
    }

    /// Returns the maximum element as a scalar tensor.
    ///
    /// # Errors
    /// Returns [`Error::EmptyTensor`] if the tensor has no elements.
    pub fn max(&self) -> Result<Self> {
        if self.is_empty() {
            return Err(Error::EmptyTensor);
        }
        let data = self.to_vec();
        // Safe to unwrap: emptiness was checked above.
        let result = CpuBackend::max(&data).unwrap();
        let s = Self::scalar(result);
        #[cfg(feature = "cuda")]
        if self.device().is_gpu() {
            return Ok(s.to_device(self.device()).unwrap());
        }
        Ok(s)
    }

    /// Returns the minimum element as a scalar tensor.
    ///
    /// # Errors
    /// Returns [`Error::EmptyTensor`] if the tensor has no elements.
    pub fn min(&self) -> Result<Self> {
        if self.is_empty() {
            return Err(Error::EmptyTensor);
        }
        let data = self.to_vec();
        // Safe to unwrap: emptiness was checked above.
        let result = CpuBackend::min(&data).unwrap();
        let s = Self::scalar(result);
        #[cfg(feature = "cuda")]
        if self.device().is_gpu() {
            return Ok(s.to_device(self.device()).unwrap());
        }
        Ok(s)
    }

    /// Returns the flat (row-major) index of the maximum element.
    ///
    /// # Errors
    /// Returns [`Error::EmptyTensor`] if the tensor has no elements.
    pub fn argmax(&self) -> Result<usize> {
        if self.is_empty() {
            return Err(Error::EmptyTensor);
        }
        let data = self.to_vec();
        Ok(CpuBackend::argmax(&data).unwrap())
    }

    /// Returns the flat (row-major) index of the minimum element.
    ///
    /// # Errors
    /// Returns [`Error::EmptyTensor`] if the tensor has no elements.
    pub fn argmin(&self) -> Result<usize> {
        if self.is_empty() {
            return Err(Error::EmptyTensor);
        }
        let data = self.to_vec();
        Ok(CpuBackend::argmin(&data).unwrap())
    }

    /// Concatenates tensors along a dimension.
    ///
    /// All tensors must have the same number of dimensions and the same size
    /// on every dimension except `dim`.
    ///
    /// # Arguments
    /// * `tensors` - Non-empty list of tensors to concatenate
    /// * `dim` - Dimension along which to concatenate (0-based; no negatives)
    ///
    /// # Returns
    /// New tensor on the first input's device, or an error if the list is
    /// empty, `dim` is out of range, or shapes are incompatible.
    pub fn cat(tensors: &[&Self], dim: usize) -> Result<Self> {
        if tensors.is_empty() {
            return Err(Error::invalid_operation("cat requires at least one tensor"));
        }
        let ndim = tensors[0].ndim();
        if dim >= ndim {
            return Err(Error::invalid_operation("cat dimension out of range"));
        }

        // Validate ranks and non-cat dimension sizes against the first tensor.
        for t in &tensors[1..] {
            if t.ndim() != ndim {
                return Err(Error::invalid_operation(
                    "cat: all tensors must have same ndim",
                ));
            }
            for d in 0..ndim {
                if d != dim && t.shape[d] != tensors[0].shape[d] {
                    return Err(Error::invalid_operation(
                        "cat: shapes must match on non-cat dims",
                    ));
                }
            }
        }

        let total_dim_size: usize = tensors.iter().map(|t| t.shape[dim]).sum();
        let mut out_shape: Vec<usize> = tensors[0].shape.to_vec();
        out_shape[dim] = total_dim_size;

        // View every tensor as (outer, dim, inner): `outer` spans dimensions
        // before `dim`, and `inner` is the contiguous run of elements after it.
        let outer_size: usize = out_shape[..dim].iter().product();
        let inner_size: usize = out_shape[dim + 1..].iter().product();
        let total_numel: usize = out_shape.iter().product();
        let mut result = vec![T::zero(); total_numel];

        // Copy each tensor's rows into the output at its running offset
        // along the cat dimension.
        let mut dim_offset = 0;
        for t in tensors {
            let t_data = t.contiguous().to_vec();
            let t_dim_size = t.shape[dim];
            for outer in 0..outer_size {
                for d in 0..t_dim_size {
                    let src_base = outer * t_dim_size * inner_size + d * inner_size;
                    let dst_base =
                        outer * total_dim_size * inner_size + (dim_offset + d) * inner_size;
                    result[dst_base..dst_base + inner_size]
                        .copy_from_slice(&t_data[src_base..src_base + inner_size]);
                }
            }
            dim_offset += t_dim_size;
        }

        let out = Self::from_vec(result, &out_shape)?;
        #[cfg(feature = "cuda")]
        if tensors[0].device().is_gpu() {
            return Ok(out.to_device(tensors[0].device()).unwrap());
        }
        Ok(out)
    }
}
813
814// =============================================================================
815// Float Operations
816// =============================================================================
817
818impl<T: Float> Tensor<T> {
    /// Returns the mean of all elements as a scalar tensor.
    ///
    /// Computed on the CPU (GPU tensors round-trip through host memory); the
    /// result is returned on the input's device.
    ///
    /// # Errors
    /// Returns [`Error::EmptyTensor`] if the tensor has no elements.
    pub fn mean(&self) -> Result<Self> {
        if self.is_empty() {
            return Err(Error::EmptyTensor);
        }
        let data = self.to_vec();
        // Safe to unwrap: emptiness was checked above.
        let result = CpuBackend::mean(&data).unwrap();
        let s = Self::scalar(result);
        #[cfg(feature = "cuda")]
        if self.device().is_gpu() {
            return Ok(s.to_device(self.device()).unwrap());
        }
        Ok(s)
    }
833
834    // =========================================================================
835    // Activation Functions
836    // =========================================================================
837
838    /// Applies `ReLU` activation: max(0, x).
839    #[must_use]
840    pub fn relu(&self) -> Self {
841        #[cfg(feature = "cuda")]
842        if self.device().is_gpu() {
843            assert!(is_f32::<T>(), "GPU tensors are only supported for f32");
844            return unsafe { gpu_into(gpu_ref(self).relu_cuda()) };
845        }
846        let data = self.to_vec();
847        let mut result = vec![T::zero(); data.len()];
848        CpuBackend::relu(&mut result, &data);
849        Self::from_vec(result, &self.shape).unwrap()
850    }
851
852    /// Applies sigmoid activation: 1 / (1 + exp(-x)).
853    #[must_use]
854    pub fn sigmoid(&self) -> Self {
855        #[cfg(feature = "cuda")]
856        if self.device().is_gpu() {
857            assert!(is_f32::<T>(), "GPU tensors are only supported for f32");
858            return unsafe { gpu_into(gpu_ref(self).sigmoid_cuda()) };
859        }
860        let data = self.to_vec();
861        let mut result = vec![T::zero(); data.len()];
862        CpuBackend::sigmoid(&mut result, &data);
863        Self::from_vec(result, &self.shape).unwrap()
864    }
865
    /// Applies tanh activation element-wise.
    ///
    /// GPU f32 tensors use the CUDA kernel; otherwise the CPU backend is used.
    #[must_use]
    pub fn tanh(&self) -> Self {
        #[cfg(feature = "cuda")]
        if self.device().is_gpu() {
            assert!(is_f32::<T>(), "GPU tensors are only supported for f32");
            // SAFETY: the assert above guarantees T == f32.
            return unsafe { gpu_into(gpu_ref(self).tanh_cuda()) };
        }
        let data = self.to_vec();
        let mut result = vec![T::zero(); data.len()];
        CpuBackend::tanh(&mut result, &data);
        Self::from_vec(result, &self.shape).unwrap()
    }
879
    /// Applies the exponential function element-wise.
    ///
    /// GPU f32 tensors use the CUDA kernel; otherwise the CPU backend is used.
    #[must_use]
    pub fn exp(&self) -> Self {
        #[cfg(feature = "cuda")]
        if self.device().is_gpu() {
            assert!(is_f32::<T>(), "GPU tensors are only supported for f32");
            // SAFETY: the assert above guarantees T == f32.
            return unsafe { gpu_into(gpu_ref(self).exp_cuda()) };
        }
        let data = self.to_vec();
        let mut result = vec![T::zero(); data.len()];
        CpuBackend::exp(&mut result, &data);
        Self::from_vec(result, &self.shape).unwrap()
    }
893
    /// Applies the natural logarithm element-wise.
    ///
    /// GPU f32 tensors use the CUDA kernel; otherwise the CPU backend is used.
    #[must_use]
    pub fn ln(&self) -> Self {
        #[cfg(feature = "cuda")]
        if self.device().is_gpu() {
            assert!(is_f32::<T>(), "GPU tensors are only supported for f32");
            // SAFETY: the assert above guarantees T == f32.
            return unsafe { gpu_into(gpu_ref(self).ln_cuda()) };
        }
        let data = self.to_vec();
        let mut result = vec![T::zero(); data.len()];
        CpuBackend::ln(&mut result, &data);
        Self::from_vec(result, &self.shape).unwrap()
    }
907
    /// Applies the square root element-wise.
    ///
    /// GPU f32 tensors use the CUDA kernel; otherwise the CPU backend is used.
    #[must_use]
    pub fn sqrt(&self) -> Self {
        #[cfg(feature = "cuda")]
        if self.device().is_gpu() {
            assert!(is_f32::<T>(), "GPU tensors are only supported for f32");
            // SAFETY: the assert above guarantees T == f32.
            return unsafe { gpu_into(gpu_ref(self).sqrt_cuda()) };
        }
        let data = self.to_vec();
        let mut result = vec![T::zero(); data.len()];
        CpuBackend::sqrt(&mut result, &data);
        Self::from_vec(result, &self.shape).unwrap()
    }
921
922    /// Computes element-wise power.
923    #[must_use]
924    pub fn pow(&self, exp: T) -> Self {
925        #[cfg(feature = "cuda")]
926        if self.device().is_gpu() {
927            assert!(is_f32::<T>(), "GPU tensors are only supported for f32");
928            let exp_f32: f32 = unsafe { *(&exp as *const T as *const f32) };
929            return unsafe { gpu_into(gpu_ref(self).pow_cuda(exp_f32)) };
930        }
931        let data = self.to_vec();
932        let result: Vec<T> = data.iter().map(|&x| x.pow_value(exp)).collect();
933        Self::from_vec(result, &self.shape).unwrap()
934    }
935
936    /// GELU activation function (Gaussian Error Linear Unit).
937    #[must_use]
938    pub fn gelu(&self) -> Self {
939        #[cfg(feature = "cuda")]
940        if self.device().is_gpu() {
941            assert!(is_f32::<T>(), "GPU tensors are only supported for f32");
942            return unsafe { gpu_into(gpu_ref(self).gelu_cuda()) };
943        }
944        crate::ops::gelu(self)
945    }
946
947    /// SiLU/Swish activation function.
948    #[must_use]
949    pub fn silu(&self) -> Self {
950        #[cfg(feature = "cuda")]
951        if self.device().is_gpu() {
952            assert!(is_f32::<T>(), "GPU tensors are only supported for f32");
953            return unsafe { gpu_into(gpu_ref(self).silu_cuda()) };
954        }
955        crate::ops::silu(self)
956    }
957
958    /// Softmax along specified dimension.
959    #[must_use]
960    pub fn softmax(&self, dim: i32) -> Self {
961        #[cfg(feature = "cuda")]
962        if self.device().is_gpu() {
963            assert!(is_f32::<T>(), "GPU tensors are only supported for f32");
964            let self_f32 = unsafe { gpu_ref(self) };
965            return unsafe { gpu_into(self_f32.softmax_cuda(dim).expect("CUDA softmax failed")) };
966        }
967        crate::ops::softmax(self, dim as i64).unwrap_or_else(|_| self.clone())
968    }
969
    /// Log softmax along specified dimension.
    ///
    /// Computed as `ln(softmax(x))`.
    /// NOTE(review): this formulation can yield `-inf` when a softmax
    /// probability underflows to zero; a fused `x - max - ln(sum(exp(x - max)))`
    /// form would be more numerically stable — confirm before relying on this
    /// for extreme inputs.
    #[must_use]
    pub fn log_softmax(&self, dim: i32) -> Self {
        let softmax_result = self.softmax(dim);
        softmax_result.ln()
    }
976
    /// Mean along a dimension.
    ///
    /// `dim` may be negative (counted from the last dimension). A `dim` that
    /// is still out of range after normalisation returns the tensor unchanged.
    /// When `keepdim` is true the reduced dimension is kept with size 1,
    /// otherwise it is removed (a fully reduced tensor becomes shape `[1]`).
    #[must_use]
    pub fn mean_dim(&self, dim: i32, keepdim: bool) -> Self {
        let ndim = self.ndim();
        // Normalise negative dims. A value below -ndim wraps to a huge usize
        // here and is caught by the range check just below.
        let dim = if dim < 0 {
            (ndim as i32 + dim) as usize
        } else {
            dim as usize
        };

        if dim >= ndim {
            return self.clone();
        }

        // GPU fast path: sum_dim then divide by dim_size (all on GPU)
        #[cfg(feature = "cuda")]
        if self.device().is_gpu() {
            assert!(is_f32::<T>(), "GPU tensors are only supported for f32");
            let self_f32 = unsafe { gpu_ref(self) };
            let summed = if keepdim {
                self_f32.sum_dim_keepdim_cuda(dim)
            } else {
                self_f32.sum_dim_cuda(dim)
            };
            let dim_size = self.shape[dim];
            // mean = sum * (1 / dim_size), done on-device.
            let result = summed.mul_scalar_cuda(1.0 / dim_size as f32);
            return unsafe { gpu_into(result) };
        }

        let dim_size = self.shape[dim];
        let data = self.to_vec();
        let mut new_shape = self.shape.clone();

        if keepdim {
            new_shape[dim] = 1;
        } else {
            new_shape.remove(dim);
        }

        // Reducing the last remaining dimension yields a 1-element tensor.
        if new_shape.is_empty() {
            new_shape = smallvec::smallvec![1];
        }

        let new_numel: usize = new_shape.iter().product();
        let mut result = vec![T::zero(); new_numel];

        // View the tensor as [outer, dim, inner] and reduce the middle axis.
        let outer_size: usize = self.shape[..dim].iter().product();
        let inner_size: usize = self.shape[dim + 1..].iter().product();

        for outer in 0..outer_size {
            for inner in 0..inner_size {
                let mut sum = T::zero();
                for d in 0..dim_size {
                    let idx = outer * dim_size * inner_size + d * inner_size + inner;
                    sum = sum + data[idx];
                }
                let mean = sum / NumCast::from(dim_size).unwrap();
                let result_idx = outer * inner_size + inner;
                result[result_idx] = mean;
            }
        }

        Self::from_vec(result, &new_shape).unwrap()
    }
1041
    /// Sum along a dimension.
    ///
    /// `dim` may be negative (counted from the last dimension). A `dim` that
    /// is still out of range after normalisation returns the tensor unchanged.
    /// When `keepdim` is true the reduced dimension is kept with size 1,
    /// otherwise it is removed (a fully reduced tensor becomes shape `[1]`).
    #[must_use]
    pub fn sum_dim(&self, dim: i32, keepdim: bool) -> Self {
        let ndim = self.ndim();
        // Normalise negative dims; out-of-range wraps are caught below.
        let dim = if dim < 0 {
            (ndim as i32 + dim) as usize
        } else {
            dim as usize
        };

        if dim >= ndim {
            return self.clone();
        }

        // GPU fast path: use CUDA sum_dim kernel (no CPU copies)
        #[cfg(feature = "cuda")]
        if self.device().is_gpu() {
            assert!(is_f32::<T>(), "GPU tensors are only supported for f32");
            let self_f32 = unsafe { gpu_ref(self) };
            let result = if keepdim {
                self_f32.sum_dim_keepdim_cuda(dim)
            } else {
                self_f32.sum_dim_cuda(dim)
            };
            return unsafe { gpu_into(result) };
        }

        let dim_size = self.shape[dim];
        let data = self.to_vec();
        let mut new_shape = self.shape.clone();

        if keepdim {
            new_shape[dim] = 1;
        } else {
            new_shape.remove(dim);
        }

        // Reducing the last remaining dimension yields a 1-element tensor.
        if new_shape.is_empty() {
            new_shape = smallvec::smallvec![1];
        }

        let new_numel: usize = new_shape.iter().product();
        let mut result = vec![T::zero(); new_numel];

        // View the tensor as [outer, dim, inner] and reduce the middle axis.
        let outer_size: usize = self.shape[..dim].iter().product();
        let inner_size: usize = self.shape[dim + 1..].iter().product();

        for outer in 0..outer_size {
            for inner in 0..inner_size {
                let mut sum = T::zero();
                for d in 0..dim_size {
                    let idx = outer * dim_size * inner_size + d * inner_size + inner;
                    sum = sum + data[idx];
                }
                let result_idx = outer * inner_size + inner;
                result[result_idx] = sum;
            }
        }

        Self::from_vec(result, &new_shape).unwrap()
    }
1103
1104    /// Variance along a dimension.
1105    #[must_use]
1106    pub fn var_dim(&self, dim: i32, keepdim: bool) -> Self {
1107        // variance = E[x²] - E[x]²  (saves one full-size intermediate allocation)
1108        let mean = self.mean_dim(dim, true);
1109        let sq = self.mul(self).unwrap_or_else(|_| self.clone());
1110        let mean_sq = sq.mean_dim(dim, keepdim);
1111        let mean_keepdim = if keepdim {
1112            mean.clone()
1113        } else {
1114            self.mean_dim(dim, keepdim)
1115        };
1116        let mean_squared = mean_keepdim
1117            .mul(&mean_keepdim)
1118            .unwrap_or_else(|_| mean_keepdim.clone());
1119        mean_sq
1120            .sub(&mean_squared)
1121            .unwrap_or_else(|_| mean_sq.clone())
1122    }
1123
    /// Broadcasts tensor to a new shape.
    ///
    /// Returns an owned copy (not a view). If `shape` equals the current
    /// shape the tensor is simply cloned.
    /// NOTE(review): on the CPU path a failed `broadcast_shape` silently falls
    /// back to using `shape` directly — confirm callers never pass shapes that
    /// are incompatible with `self.shape`.
    #[must_use]
    pub fn broadcast_to(&self, shape: &[usize]) -> Self {
        if self.shape.as_slice() == shape {
            return self.clone();
        }

        #[cfg(feature = "cuda")]
        if self.device().is_gpu() {
            assert!(is_f32::<T>(), "GPU tensors are only supported for f32");
            let self_f32 = unsafe { gpu_ref(self) };
            return unsafe {
                gpu_into(
                    self_f32
                        .broadcast_to_cuda(shape)
                        .expect("CUDA broadcast_to failed"),
                )
            };
        }

        let result_shape = broadcast_shape(&self.shape, shape).unwrap_or_else(|_| shape.into());
        let self_strides = broadcast_strides(&self.shape, &self.strides, &result_shape);

        let total = numel(&result_shape);
        let mut result_data = vec![T::zero(); total];
        let self_data = self.storage.as_slice();

        // Walk every output element and map it back into the source through
        // the broadcast strides, honouring this view's storage offset.
        for i in 0..total {
            let indices = crate::shape::unravel_index(i, &result_shape);
            let self_idx = self.offset + linear_index(&indices, &self_strides);
            result_data[i] = self_data[self_idx];
        }

        Self::from_vec(result_data, &result_shape).unwrap()
    }
1159
    /// Slices the tensor using ranges for each dimension.
    ///
    /// Dimensions beyond `ranges.len()` are kept in full. The result is an
    /// owned copy; for GPU tensors the slice is assembled on the CPU and the
    /// result is moved back to the original device.
    /// NOTE(review): ranges are not validated — `end < start` underflows and
    /// an out-of-bounds end panics during the copy; confirm callers always
    /// pass in-bounds ranges.
    #[must_use]
    pub fn slice(&self, ranges: &[std::ops::Range<usize>]) -> Self {
        // Output shape: sliced extents first, untouched trailing dims after.
        let mut new_shape = Vec::with_capacity(self.ndim());
        for (i, range) in ranges.iter().enumerate() {
            if i < self.ndim() {
                new_shape.push(range.end - range.start);
            }
        }
        // Keep remaining dimensions unchanged
        for i in ranges.len()..self.ndim() {
            new_shape.push(self.shape[i]);
        }

        let new_numel: usize = new_shape.iter().product();
        let mut result_data = vec![T::zero(); new_numel];
        let self_data = self.to_vec();

        // Copy data with proper indexing
        let mut result_idx = 0;
        Self::slice_recursive(
            &self_data,
            &self.shape,
            ranges,
            0,
            0,
            &mut result_data,
            &mut result_idx,
        );

        let out = Self::from_vec(result_data, &new_shape).unwrap();
        #[cfg(feature = "cuda")]
        if self.device().is_gpu() {
            return out.to_device(self.device()).unwrap();
        }
        out
    }
1197
    /// Depth-first copy helper for `slice`: fixes one dimension per call,
    /// restricted to that dimension's range (or its full extent when no range
    /// was supplied), and appends elements to `result` in row-major order via
    /// the running `result_idx` cursor.
    fn slice_recursive(
        data: &[T],
        shape: &[usize],
        ranges: &[std::ops::Range<usize>],
        dim: usize,
        offset: usize,
        result: &mut [T],
        result_idx: &mut usize,
    ) {
        // Base case: every dimension fixed — `offset` names one element.
        if dim == shape.len() {
            result[*result_idx] = data[offset];
            *result_idx += 1;
            return;
        }

        // Row-major stride of the current dimension.
        let stride: usize = shape[dim + 1..].iter().product();
        let (start, end) = if dim < ranges.len() {
            (ranges[dim].start, ranges[dim].end)
        } else {
            (0, shape[dim])
        };

        for i in start..end {
            Self::slice_recursive(
                data,
                shape,
                ranges,
                dim + 1,
                offset + i * stride,
                result,
                result_idx,
            );
        }
    }
1232}
1233
1234// =============================================================================
1235// Arithmetic Operator Implementations
1236// =============================================================================
1237
1238impl<T: Numeric> Tensor<T> {
    /// Element-wise addition with broadcasting.
    ///
    /// # Errors
    /// Returns an error if the shapes cannot be broadcast together, or if a
    /// CUDA operation / cross-device transfer fails.
    pub fn add(&self, other: &Self) -> Result<Self> {
        #[cfg(feature = "cuda")]
        {
            let self_gpu = self.device().is_gpu();
            let other_gpu = other.device().is_gpu();
            if self_gpu || other_gpu {
                assert!(is_f32::<T>(), "GPU tensors are only supported for f32");
                if self_gpu && other_gpu {
                    let (s, o) = unsafe { (gpu_ref(self), gpu_ref(other)) };
                    // Same-shape uses the plain kernel; otherwise the
                    // broadcasting variant.
                    if self.shape == other.shape {
                        return Ok(unsafe { gpu_into(s.add_cuda(o)?) });
                    } else {
                        return Ok(unsafe { gpu_into(s.broadcast_add_cuda(o)?) });
                    }
                }
                // Mixed device — move to GPU, then operate
                let target_device = if self_gpu {
                    self.device()
                } else {
                    other.device()
                };
                let a_gpu = if self_gpu {
                    self.clone()
                } else {
                    self.to_device(target_device)?
                };
                let b_gpu = if other_gpu {
                    other.clone()
                } else {
                    other.to_device(target_device)?
                };
                // Both operands now live on the same GPU; re-dispatch.
                return a_gpu.add(&b_gpu);
            }
        }
        // CPU path: compute the broadcast result shape, then translate each
        // output index back into each input through its broadcast strides.
        let result_shape = broadcast_shape(&self.shape, &other.shape)?;
        let self_strides = broadcast_strides(&self.shape, &self.strides, &result_shape);
        let other_strides = broadcast_strides(&other.shape, &other.strides, &result_shape);

        let total = numel(&result_shape);
        let mut result_data = vec![T::zero(); total];

        let self_data = self.storage.as_slice();
        let other_data = other.storage.as_slice();

        for i in 0..total {
            let indices = crate::shape::unravel_index(i, &result_shape);
            let self_idx = self.offset + linear_index(&indices, &self_strides);
            let other_idx = other.offset + linear_index(&indices, &other_strides);
            result_data[i] = self_data[self_idx] + other_data[other_idx];
        }

        Self::from_vec(result_data, &result_shape)
    }
1293
    /// Element-wise subtraction with broadcasting.
    ///
    /// # Errors
    /// Returns an error if the shapes cannot be broadcast together, or if a
    /// CUDA operation / cross-device transfer fails.
    pub fn sub(&self, other: &Self) -> Result<Self> {
        #[cfg(feature = "cuda")]
        {
            let self_gpu = self.device().is_gpu();
            let other_gpu = other.device().is_gpu();
            if self_gpu || other_gpu {
                assert!(is_f32::<T>(), "GPU tensors are only supported for f32");
                if self_gpu && other_gpu {
                    let (s, o) = unsafe { (gpu_ref(self), gpu_ref(other)) };
                    if self.shape == other.shape {
                        return Ok(unsafe { gpu_into(s.sub_cuda(o)?) });
                    } else {
                        return Ok(unsafe { gpu_into(s.broadcast_sub_cuda(o)?) });
                    }
                }
                // Mixed device: move the CPU operand onto the GPU, then retry.
                let target = if self_gpu {
                    self.device()
                } else {
                    other.device()
                };
                let a_gpu = if self_gpu {
                    self.clone()
                } else {
                    self.to_device(target)?
                };
                let b_gpu = if other_gpu {
                    other.clone()
                } else {
                    other.to_device(target)?
                };
                return a_gpu.sub(&b_gpu);
            }
        }
        // CPU path: broadcast shapes, then index both inputs per output slot.
        let result_shape = broadcast_shape(&self.shape, &other.shape)?;
        let self_strides = broadcast_strides(&self.shape, &self.strides, &result_shape);
        let other_strides = broadcast_strides(&other.shape, &other.strides, &result_shape);

        let total = numel(&result_shape);
        let mut result_data = vec![T::zero(); total];

        let self_data = self.storage.as_slice();
        let other_data = other.storage.as_slice();

        for i in 0..total {
            let indices = crate::shape::unravel_index(i, &result_shape);
            let self_idx = self.offset + linear_index(&indices, &self_strides);
            let other_idx = other.offset + linear_index(&indices, &other_strides);
            result_data[i] = self_data[self_idx] - other_data[other_idx];
        }

        Self::from_vec(result_data, &result_shape)
    }
1347
    /// Element-wise multiplication with broadcasting.
    ///
    /// # Errors
    /// Returns an error if the shapes cannot be broadcast together, or if a
    /// CUDA operation / cross-device transfer fails.
    pub fn mul(&self, other: &Self) -> Result<Self> {
        #[cfg(feature = "cuda")]
        {
            let self_gpu = self.device().is_gpu();
            let other_gpu = other.device().is_gpu();
            if self_gpu || other_gpu {
                assert!(is_f32::<T>(), "GPU tensors are only supported for f32");
                if self_gpu && other_gpu {
                    let (s, o) = unsafe { (gpu_ref(self), gpu_ref(other)) };
                    if self.shape == other.shape {
                        return Ok(unsafe { gpu_into(s.mul_cuda(o)?) });
                    } else {
                        return Ok(unsafe { gpu_into(s.broadcast_mul_cuda(o)?) });
                    }
                }
                // Mixed device: move the CPU operand onto the GPU, then retry.
                let target = if self_gpu {
                    self.device()
                } else {
                    other.device()
                };
                let a_gpu = if self_gpu {
                    self.clone()
                } else {
                    self.to_device(target)?
                };
                let b_gpu = if other_gpu {
                    other.clone()
                } else {
                    other.to_device(target)?
                };
                return a_gpu.mul(&b_gpu);
            }
        }
        // CPU path: broadcast shapes, then index both inputs per output slot.
        let result_shape = broadcast_shape(&self.shape, &other.shape)?;
        let self_strides = broadcast_strides(&self.shape, &self.strides, &result_shape);
        let other_strides = broadcast_strides(&other.shape, &other.strides, &result_shape);

        let total = numel(&result_shape);
        let mut result_data = vec![T::zero(); total];

        let self_data = self.storage.as_slice();
        let other_data = other.storage.as_slice();

        for i in 0..total {
            let indices = crate::shape::unravel_index(i, &result_shape);
            let self_idx = self.offset + linear_index(&indices, &self_strides);
            let other_idx = other.offset + linear_index(&indices, &other_strides);
            result_data[i] = self_data[self_idx] * other_data[other_idx];
        }

        Self::from_vec(result_data, &result_shape)
    }
1401
    /// Element-wise division with broadcasting.
    ///
    /// # Errors
    /// Returns an error if the shapes cannot be broadcast together, or if a
    /// CUDA operation / cross-device transfer fails.
    pub fn div(&self, other: &Self) -> Result<Self> {
        #[cfg(feature = "cuda")]
        {
            let self_gpu = self.device().is_gpu();
            let other_gpu = other.device().is_gpu();
            if self_gpu || other_gpu {
                assert!(is_f32::<T>(), "GPU tensors are only supported for f32");
                if self_gpu && other_gpu {
                    let (s, o) = unsafe { (gpu_ref(self), gpu_ref(other)) };
                    if self.shape == other.shape {
                        return Ok(unsafe { gpu_into(s.div_cuda(o)?) });
                    } else {
                        return Ok(unsafe { gpu_into(s.broadcast_div_cuda(o)?) });
                    }
                }
                // Mixed device: move the CPU operand onto the GPU, then retry.
                let target = if self_gpu {
                    self.device()
                } else {
                    other.device()
                };
                let a_gpu = if self_gpu {
                    self.clone()
                } else {
                    self.to_device(target)?
                };
                let b_gpu = if other_gpu {
                    other.clone()
                } else {
                    other.to_device(target)?
                };
                return a_gpu.div(&b_gpu);
            }
        }
        // CPU path: broadcast shapes, then index both inputs per output slot.
        let result_shape = broadcast_shape(&self.shape, &other.shape)?;
        let self_strides = broadcast_strides(&self.shape, &self.strides, &result_shape);
        let other_strides = broadcast_strides(&other.shape, &other.strides, &result_shape);

        let total = numel(&result_shape);
        let mut result_data = vec![T::zero(); total];

        let self_data = self.storage.as_slice();
        let other_data = other.storage.as_slice();

        for i in 0..total {
            let indices = crate::shape::unravel_index(i, &result_shape);
            let self_idx = self.offset + linear_index(&indices, &self_strides);
            let other_idx = other.offset + linear_index(&indices, &other_strides);
            result_data[i] = self_data[self_idx] / other_data[other_idx];
        }

        Self::from_vec(result_data, &result_shape)
    }
1455
1456    /// Scalar addition.
1457    #[must_use]
1458    pub fn add_scalar(&self, scalar: T) -> Self {
1459        #[cfg(feature = "cuda")]
1460        if self.device().is_gpu() {
1461            assert!(is_f32::<T>(), "GPU tensors are only supported for f32");
1462            let self_f32 = unsafe { gpu_ref(self) };
1463            let scalar_f32: f32 = unsafe { *(&scalar as *const T as *const f32) };
1464            return unsafe { gpu_into(self_f32.add_scalar_cuda(scalar_f32)) };
1465        }
1466        let data = self.to_vec();
1467        let mut result = vec![T::zero(); data.len()];
1468        CpuBackend::add_scalar(&mut result, &data, scalar);
1469        Self::from_vec(result, &self.shape).unwrap()
1470    }
1471
1472    /// Scalar multiplication.
1473    #[must_use]
1474    pub fn mul_scalar(&self, scalar: T) -> Self {
1475        #[cfg(feature = "cuda")]
1476        if self.device().is_gpu() {
1477            assert!(is_f32::<T>(), "GPU tensors are only supported for f32");
1478            let self_f32 = unsafe { gpu_ref(self) };
1479            let scalar_f32: f32 = unsafe { *(&scalar as *const T as *const f32) };
1480            return unsafe { gpu_into(self_f32.mul_scalar_cuda(scalar_f32)) };
1481        }
1482        let data = self.to_vec();
1483        let mut result = vec![T::zero(); data.len()];
1484        CpuBackend::mul_scalar(&mut result, &data, scalar);
1485        Self::from_vec(result, &self.shape).unwrap()
1486    }
1487
1488    /// Element-wise negation.
1489    #[must_use]
1490    pub fn neg(&self) -> Self {
1491        #[cfg(feature = "cuda")]
1492        if self.device().is_gpu() {
1493            assert!(is_f32::<T>(), "GPU tensors are only supported for f32");
1494            let self_f32 = unsafe { gpu_ref(self) };
1495            return unsafe { gpu_into(self_f32.neg_cuda()) };
1496        }
1497        let data = self.to_vec();
1498        let mut result = vec![T::zero(); data.len()];
1499        CpuBackend::neg(&mut result, &data);
1500        Self::from_vec(result, &self.shape).unwrap()
1501    }
1502
1503    /// Matrix multiplication with batching support.
1504    ///
1505    /// Supports:
1506    /// - 2D @ 2D: [m, k] @ [k, n] -> [m, n]
1507    /// - 3D @ 3D: [batch, m, k] @ [batch, k, n] -> [batch, m, n]
1508    /// - 4D @ 4D: [b1, b2, m, k] @ [b1, b2, k, n] -> [b1, b2, m, n]
1509    pub fn matmul(&self, other: &Self) -> Result<Self> {
1510        #[cfg(feature = "cuda")]
1511        if self.device().is_gpu() {
1512            assert!(is_f32::<T>(), "GPU tensors are only supported for f32");
1513            let (s, o) = unsafe { (gpu_ref(self), gpu_ref(other)) };
1514            return Ok(unsafe { gpu_into(s.matmul_cuda(o)?) });
1515        }
1516        if self.ndim() < 2 || other.ndim() < 2 {
1517            return Err(Error::invalid_operation(
1518                "matmul requires at least 2D tensors",
1519            ));
1520        }
1521
1522        let m = self.shape[self.ndim() - 2];
1523        let k1 = self.shape[self.ndim() - 1];
1524        let k2 = other.shape[other.ndim() - 2];
1525        let n = other.shape[other.ndim() - 1];
1526
1527        if k1 != k2 {
1528            return Err(Error::invalid_operation(format!(
1529                "matmul inner dimensions must match: {k1} vs {k2}"
1530            )));
1531        }
1532
1533        // For 2D matrices, do simple matmul
1534        if self.ndim() == 2 && other.ndim() == 2 {
1535            let a_data = self.contiguous().to_vec();
1536            let b_data = other.contiguous().to_vec();
1537
1538            // GPU-accelerated matmul for CPU tensors: only for very large matrices
1539            // where transfer overhead is negligible relative to compute.
1540            // For GPU-resident tensors, the dispatch at the top of matmul() handles it.
1541            #[cfg(feature = "cuda")]
1542            {
1543                let flops = m * n * k1;
1544                if std::any::TypeId::of::<T>() == std::any::TypeId::of::<f32>()
1545                    && flops >= 4_000_000
1546                {
1547                    let a_f32: &[f32] = unsafe { std::mem::transmute(a_data.as_slice()) };
1548                    let b_f32: &[f32] = unsafe { std::mem::transmute(b_data.as_slice()) };
1549                    if let Some(c_f32) = cuda_accel::cuda_matmul(a_f32, b_f32, m, n, k1) {
1550                        let c_t: Vec<T> = unsafe {
1551                            let mut v = std::mem::ManuallyDrop::new(c_f32);
1552                            Vec::from_raw_parts(v.as_mut_ptr() as *mut T, v.len(), v.capacity())
1553                        };
1554                        return Self::from_vec(c_t, &[m, n]);
1555                    }
1556                }
1557            }
1558
1559            let mut c_data = vec![T::zero(); m * n];
1560            CpuBackend::matmul(&mut c_data, &a_data, &b_data, m, n, k1);
1561            return Self::from_vec(c_data, &[m, n]);
1562        }
1563
1564        // For batched matmul, compute batch size
1565        let batch_dims_self: Vec<usize> = self.shape[..self.ndim() - 2].to_vec();
1566        let batch_dims_other: Vec<usize> = other.shape[..other.ndim() - 2].to_vec();
1567
1568        if batch_dims_self != batch_dims_other {
1569            return Err(Error::invalid_operation(format!(
1570                "matmul batch dimensions must match: {:?} vs {:?}",
1571                batch_dims_self, batch_dims_other
1572            )));
1573        }
1574
1575        let batch_size: usize = batch_dims_self.iter().product();
1576        let a_stride = m * k1;
1577        let b_stride = k1 * n;
1578        let c_stride = m * n;
1579
1580        let a_data = self.contiguous().to_vec();
1581        let b_data = other.contiguous().to_vec();
1582        let mut c_data = vec![T::zero(); batch_size * m * n];
1583
1584        // Try GPU acceleration for f32 batched matmul (only for large enough matrices)
1585        #[cfg(feature = "cuda")]
1586        {
1587            let flops = m * n * k1;
1588            if std::any::TypeId::of::<T>() == std::any::TypeId::of::<f32>() && flops >= 4_000_000 {
1589                let a_f32: &[f32] = unsafe { std::mem::transmute(a_data.as_slice()) };
1590                let b_f32: &[f32] = unsafe { std::mem::transmute(b_data.as_slice()) };
1591                let mut gpu_ok = true;
1592                for batch in 0..batch_size {
1593                    let a_slice = &a_f32[batch * a_stride..(batch + 1) * a_stride];
1594                    let b_slice = &b_f32[batch * b_stride..(batch + 1) * b_stride];
1595                    if let Some(c_batch) = cuda_accel::cuda_matmul(a_slice, b_slice, m, n, k1) {
1596                        c_data[batch * c_stride..(batch + 1) * c_stride]
1597                            .copy_from_slice(unsafe { std::mem::transmute(c_batch.as_slice()) });
1598                    } else {
1599                        gpu_ok = false;
1600                        break;
1601                    }
1602                }
1603                if gpu_ok {
1604                    let mut output_shape = batch_dims_self;
1605                    output_shape.push(m);
1606                    output_shape.push(n);
1607                    return Self::from_vec(c_data, &output_shape);
1608                }
1609                // Fall through to CPU if GPU failed
1610                c_data = vec![T::zero(); batch_size * m * n];
1611            }
1612        }
1613
1614        // CPU fallback: loop over batches
1615        for batch in 0..batch_size {
1616            let a_slice = &a_data[batch * a_stride..(batch + 1) * a_stride];
1617            let b_slice = &b_data[batch * b_stride..(batch + 1) * b_stride];
1618            let c_slice = &mut c_data[batch * c_stride..(batch + 1) * c_stride];
1619            CpuBackend::matmul(c_slice, a_slice, b_slice, m, n, k1);
1620        }
1621
1622        // Build output shape: batch_dims + [m, n]
1623        let mut output_shape = batch_dims_self;
1624        output_shape.push(m);
1625        output_shape.push(n);
1626
1627        Self::from_vec(c_data, &output_shape)
1628    }
1629
1630    /// Dot product for 1D tensors.
1631    pub fn dot(&self, other: &Self) -> Result<Self> {
1632        if self.ndim() != 1 || other.ndim() != 1 {
1633            return Err(Error::invalid_operation("dot requires 1D tensors"));
1634        }
1635
1636        if self.shape[0] != other.shape[0] {
1637            return Err(Error::shape_mismatch(&self.shape, &other.shape));
1638        }
1639
1640        let a_data = self.to_vec();
1641        let b_data = other.to_vec();
1642        let result = CpuBackend::dot(&a_data, &b_data);
1643
1644        Ok(Self::scalar(result))
1645    }
1646}
1647
1648// =============================================================================
1649// Operator Trait Implementations
1650// =============================================================================
1651
1652impl<T: Numeric> Add for &Tensor<T> {
1653    type Output = Tensor<T>;
1654
1655    fn add(self, other: Self) -> Self::Output {
1656        self.add(other).expect("Addition failed")
1657    }
1658}
1659
1660impl<T: Numeric> Sub for &Tensor<T> {
1661    type Output = Tensor<T>;
1662
1663    fn sub(self, other: Self) -> Self::Output {
1664        self.sub(other).expect("Subtraction failed")
1665    }
1666}
1667
1668impl<T: Numeric> Mul for &Tensor<T> {
1669    type Output = Tensor<T>;
1670
1671    fn mul(self, other: Self) -> Self::Output {
1672        self.mul(other).expect("Multiplication failed")
1673    }
1674}
1675
1676impl<T: Numeric> Div for &Tensor<T> {
1677    type Output = Tensor<T>;
1678
1679    fn div(self, other: Self) -> Self::Output {
1680        self.div(other).expect("Division failed")
1681    }
1682}
1683
1684impl<T: Numeric> Neg for &Tensor<T> {
1685    type Output = Tensor<T>;
1686
1687    fn neg(self) -> Self::Output {
1688        self.neg()
1689    }
1690}
1691
1692// Scalar operations
1693impl<T: Numeric> Add<T> for &Tensor<T> {
1694    type Output = Tensor<T>;
1695
1696    fn add(self, scalar: T) -> Self::Output {
1697        self.add_scalar(scalar)
1698    }
1699}
1700
1701impl<T: Numeric> Mul<T> for &Tensor<T> {
1702    type Output = Tensor<T>;
1703
1704    fn mul(self, scalar: T) -> Self::Output {
1705        self.mul_scalar(scalar)
1706    }
1707}
1708
1709// =============================================================================
1710// Display Implementation
1711// =============================================================================
1712
1713impl<T: Scalar + fmt::Display> fmt::Debug for Tensor<T> {
1714    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
1715        write!(
1716            f,
1717            "Tensor(shape={:?}, device={}",
1718            self.shape(),
1719            self.device()
1720        )?;
1721        if self.numel() <= 10 {
1722            write!(f, ", data={:?}", self.to_vec())?;
1723        }
1724        write!(f, ")")
1725    }
1726}
1727
1728impl<T: Scalar + fmt::Display> fmt::Display for Tensor<T> {
1729    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
1730        if self.is_scalar() {
1731            write!(f, "{}", self.item().unwrap())
1732        } else if self.ndim() == 1 {
1733            write!(f, "[")?;
1734            let data = self.to_vec();
1735            for (i, val) in data.iter().enumerate() {
1736                if i > 0 {
1737                    write!(f, ", ")?;
1738                }
1739                write!(f, "{val}")?;
1740            }
1741            write!(f, "]")
1742        } else {
1743            write!(f, "Tensor(shape={:?})", self.shape())
1744        }
1745    }
1746}
1747
1748// =============================================================================
1749// Tests
1750// =============================================================================
1751
#[cfg(test)]
mod tests {
    use super::*;

    /// Construction from a flat vec with an explicit shape.
    #[test]
    fn test_from_vec() {
        let t = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3]).unwrap();
        assert_eq!(t.shape(), &[2, 3]);
        assert_eq!(t.numel(), 6);
    }

    /// Indexed reads are row-major; `set` writes through a shared binding.
    #[test]
    fn test_get_set() {
        let t = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[2, 2]).unwrap();
        assert_eq!(t.get(&[0, 0]).unwrap(), 1.0);
        assert_eq!(t.get(&[0, 1]).unwrap(), 2.0);
        assert_eq!(t.get(&[1, 0]).unwrap(), 3.0);
        assert_eq!(t.get(&[1, 1]).unwrap(), 4.0);

        t.set(&[0, 0], 99.0).unwrap();
        assert_eq!(t.get(&[0, 0]).unwrap(), 99.0);
    }

    /// Reshape changes the shape, preserves the flat element order, and
    /// infers a dimension from `-1`.
    #[test]
    fn test_reshape() {
        let t = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3]).unwrap();
        let r = t.reshape(&[3, 2]).unwrap();
        assert_eq!(r.shape(), &[3, 2]);
        // Reshape must keep row-major order: first and last elements land
        // at the corners of the new shape.
        assert_eq!(r.get(&[0, 0]).unwrap(), 1.0);
        assert_eq!(r.get(&[2, 1]).unwrap(), 6.0);

        let r = t.reshape(&[-1]).unwrap();
        assert_eq!(r.shape(), &[6]);
    }

    /// 2-D transpose: verify the shape and ALL six elements (r[i][j] == t[j][i]),
    /// not just a sample.
    #[test]
    fn test_transpose() {
        let t = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3]).unwrap();
        let r = t.t().unwrap();
        assert_eq!(r.shape(), &[3, 2]);
        assert_eq!(r.get(&[0, 0]).unwrap(), 1.0);
        assert_eq!(r.get(&[0, 1]).unwrap(), 4.0);
        assert_eq!(r.get(&[1, 0]).unwrap(), 2.0);
        assert_eq!(r.get(&[1, 1]).unwrap(), 5.0);
        assert_eq!(r.get(&[2, 0]).unwrap(), 3.0);
        assert_eq!(r.get(&[2, 1]).unwrap(), 6.0);
    }

    /// Element-wise operator overloads: +, *, and the previously untested
    /// -, /, unary negation, and scalar forms (all exact in f32).
    #[test]
    fn test_arithmetic() {
        let a = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap();
        let b = Tensor::<f32>::from_vec(vec![4.0, 5.0, 6.0], &[3]).unwrap();

        let c = &a + &b;
        assert_eq!(c.to_vec(), vec![5.0, 7.0, 9.0]);

        let d = &a * &b;
        assert_eq!(d.to_vec(), vec![4.0, 10.0, 18.0]);

        let e = &b - &a;
        assert_eq!(e.to_vec(), vec![3.0, 3.0, 3.0]);

        let q = &b / &a;
        assert_eq!(q.to_vec(), vec![4.0, 2.5, 2.0]);

        let n = -&a;
        assert_eq!(n.to_vec(), vec![-1.0, -2.0, -3.0]);

        let s = &a + 2.0;
        assert_eq!(s.to_vec(), vec![3.0, 4.0, 5.0]);

        let m = &a * 2.0;
        assert_eq!(m.to_vec(), vec![2.0, 4.0, 6.0]);
    }

    /// A length-1 tensor broadcasts across a length-3 tensor.
    #[test]
    fn test_broadcasting() {
        let a = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap();
        let b = Tensor::<f32>::from_vec(vec![10.0], &[1]).unwrap();

        let c = &a + &b;
        assert_eq!(c.to_vec(), vec![11.0, 12.0, 13.0]);
    }

    /// Full reduction to a scalar.
    #[test]
    fn test_sum() {
        let t = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[4]).unwrap();
        let s = t.sum();
        assert_eq!(s.item().unwrap(), 10.0);
    }

    /// 2x2 @ 2x2 matrix product with a hand-computed expected result.
    #[test]
    fn test_matmul() {
        let a = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[2, 2]).unwrap();
        let b = Tensor::<f32>::from_vec(vec![5.0, 6.0, 7.0, 8.0], &[2, 2]).unwrap();
        let c = a.matmul(&b).unwrap();

        assert_eq!(c.shape(), &[2, 2]);
        assert_eq!(c.to_vec(), vec![19.0, 22.0, 43.0, 50.0]);
    }

    /// ReLU clamps negatives to zero and passes non-negatives through.
    #[test]
    fn test_relu() {
        let t = Tensor::<f32>::from_vec(vec![-1.0, 0.0, 1.0, 2.0], &[4]).unwrap();
        let r = t.relu();
        assert_eq!(r.to_vec(), vec![0.0, 0.0, 1.0, 2.0]);
    }

    /// Scalar tensors report is_scalar, hold one element, and yield it via item.
    #[test]
    fn test_scalar() {
        let s = Tensor::<f32>::scalar(42.0);
        assert!(s.is_scalar());
        assert_eq!(s.numel(), 1);
        assert_eq!(s.item().unwrap(), 42.0);
    }
}