Skip to main content

axonml_tensor/
tensor.rs

1//! Tensor - Core N-Dimensional Array Type
2//!
3//! The `Tensor` struct is the fundamental data structure in Axonml. It represents
4//! an N-dimensional array of numeric values with support for automatic broadcasting,
5//! device placement, and efficient memory sharing through views.
6//!
7//! # Key Features
8//! - Generic over element type (f32, f64, i32, etc.)
9//! - Efficient views with shared storage
10//! - Device-agnostic operations
11//! - Broadcasting support
12//! - Lazy operations where possible
13//!
14//! @version 0.1.0
15//! @author `AutomataNexus` Development Team
16
17use core::fmt;
18use core::ops::{Add, Div, Mul, Neg, Sub};
19
20use axonml_core::backends::CpuBackend;
21#[cfg(feature = "cuda")]
22use axonml_core::backends::CudaBackend;
23use axonml_core::dtype::{Float, Numeric, Scalar};
24use axonml_core::error::{Error, Result};
25use axonml_core::storage::Storage;
26use axonml_core::Device;
27use num_traits::NumCast;
28
29// =============================================================================
30// CUDA Acceleration
31// =============================================================================
32
#[cfg(feature = "cuda")]
mod cuda_accel {
    use super::*;
    use axonml_core::backends::cuda::get_cuda_backend;

    /// Get the global CUDA backend (delegates to core singleton).
    ///
    /// Returns `None` when the core singleton reports no usable CUDA backend.
    pub fn get_cuda() -> Option<&'static CudaBackend> {
        get_cuda_backend()
    }

    /// GPU-accelerated matmul: copies data to GPU, runs cuBLAS GEMM, copies back.
    /// Returns None if GPU is unavailable or an error occurs.
    ///
    /// # Arguments
    /// * `a` - Left operand, row-major, `m * k` elements
    /// * `b` - Right operand, row-major, `k * n` elements
    /// * `m`, `n`, `k` - Dimensions of C(m,n) = A(m,k) @ B(k,n)
    pub fn cuda_matmul(a: &[f32], b: &[f32], m: usize, n: usize, k: usize) -> Option<Vec<f32>> {
        let cuda = get_cuda()?;

        // H2D copies and output allocation; any failure degrades to None so
        // callers can fall back to a CPU implementation.
        let a_gpu = cuda.htod_copy(a).ok()?;
        let b_gpu = cuda.htod_copy(b).ok()?;
        let mut c_gpu = cuda.alloc::<f32>(m * n).ok()?;

        // cuBLAS GEMM: C(m,n) = A(m,k) @ B(k,n) in row-major
        // In column-major terms: C^T(n,m) = B^T(n,k) @ A^T(k,m)
        // Hence the operands are passed swapped (B first) with no transposes,
        // and leading dimensions n / k / n; the result lands row-major in C.
        cuda.gemm_f32(
            false, false,
            n, m, k,
            1.0,
            &b_gpu, n,
            &a_gpu, k,
            0.0,
            &mut c_gpu, n,
        ).ok()?;

        // D2H copy of the result; None on transfer failure.
        cuda.dtoh_copy(&c_gpu).ok()
    }
}
67
68use crate::shape::{
69    broadcast_shape, broadcast_strides, contiguous_strides, is_contiguous, linear_index,
70    normalize_dim, numel, reshape, squeeze, transpose_shape, transpose_strides, unsqueeze, Shape,
71    Strides,
72};
73
74// =============================================================================
75// GPU Dispatch Helpers
76// =============================================================================
77//
78// These enable calling Tensor<f32> GPU methods from generic Tensor<T> code
79// when T is verified to be f32 via TypeId check at runtime.
80
/// Reinterprets `&Tensor<T>` as `&Tensor<f32>` for GPU dispatch.
///
/// # Safety
/// The caller must guarantee `T` is exactly `f32` (call sites check this via
/// `is_f32::<T>()` before calling). The cast is a same-layout pointer
/// reinterpretation; with any other `T` it is undefined behavior.
#[cfg(feature = "cuda")]
unsafe fn gpu_ref<T: Scalar>(t: &Tensor<T>) -> &Tensor<f32> {
    &*(t as *const Tensor<T> as *const Tensor<f32>)
}
85
/// Converts an owned `Tensor<f32>` into `Tensor<T>` without running `Drop` twice.
///
/// # Safety
/// The caller must guarantee `T` is exactly `f32` (checked via `is_f32::<T>()`
/// at call sites); the bitwise move is only valid for identical layouts.
#[cfg(feature = "cuda")]
unsafe fn gpu_into<T: Scalar>(t: Tensor<f32>) -> Tensor<T> {
    // Bitwise move out through a T-typed pointer, then forget the source so
    // its Drop (e.g. the storage refcount decrement) does not run a second time.
    let out = std::ptr::read(&t as *const Tensor<f32> as *const Tensor<T>);
    std::mem::forget(t);
    out
}
92
/// Runtime check that the generic parameter `T` is exactly `f32`.
///
/// Used to guard the unsafe `gpu_ref`/`gpu_into` reinterpretations.
#[cfg(feature = "cuda")]
fn is_f32<T: 'static>() -> bool {
    use std::any::TypeId;
    TypeId::of::<T>() == TypeId::of::<f32>()
}
97
98// =============================================================================
99// Tensor Struct
100// =============================================================================
101
/// An N-dimensional array of numeric values.
///
/// Tensors are the core data structure for all computations in Axonml.
/// They support arbitrary dimensions, automatic broadcasting, and efficient
/// memory sharing between views.
///
/// Layout is fully described by `(shape, strides, offset)` over a shared
/// `Storage`, so views (reshape/transpose/permute/squeeze) can alias the same
/// buffer without copying. `Clone` is shallow: the storage handle is shared.
#[derive(Clone)]
pub struct Tensor<T: Scalar> {
    /// Underlying data storage (reference-counted; shared between views).
    pub(crate) storage: Storage<T>,
    /// Shape of the tensor (dimensions).
    pub(crate) shape: Shape,
    /// Strides for each dimension, in elements; may be non-contiguous for views.
    pub(crate) strides: Strides,
    /// Offset into storage (for views), in elements.
    pub(crate) offset: usize,
}
118
119impl<T: Scalar> Tensor<T> {
120    // =========================================================================
121    // Constructors
122    // =========================================================================
123
124    /// Creates a new tensor from storage with the given shape.
125    ///
126    /// # Arguments
127    /// * `storage` - The underlying data storage
128    /// * `shape` - Shape of the tensor
129    ///
130    /// # Returns
131    /// New tensor, or error if shape doesn't match storage size.
132    pub fn from_storage(storage: Storage<T>, shape: &[usize]) -> Result<Self> {
133        let total = numel(shape);
134        if total != storage.len() {
135            return Err(Error::shape_mismatch(&[storage.len()], shape));
136        }
137
138        let shape = Shape::from_slice(shape);
139        let strides = contiguous_strides(&shape);
140
141        Ok(Self {
142            storage,
143            shape,
144            strides,
145            offset: 0,
146        })
147    }
148
149    /// Creates a new tensor from a vector with the given shape.
150    ///
151    /// # Arguments
152    /// * `data` - Vector of data
153    /// * `shape` - Shape of the tensor
154    ///
155    /// # Returns
156    /// New tensor, or error if shape doesn't match data length.
157    pub fn from_vec(data: Vec<T>, shape: &[usize]) -> Result<Self> {
158        let storage = Storage::from_vec(data, Device::Cpu);
159        Self::from_storage(storage, shape)
160    }
161
162    /// Creates a new tensor from a slice with the given shape.
163    ///
164    /// # Arguments
165    /// * `data` - Slice of data to copy
166    /// * `shape` - Shape of the tensor
167    ///
168    /// # Returns
169    /// New tensor, or error if shape doesn't match data length.
170    pub fn from_slice(data: &[T], shape: &[usize]) -> Result<Self> {
171        let storage = Storage::from_slice(data, Device::Cpu);
172        Self::from_storage(storage, shape)
173    }
174
175    /// Creates a scalar tensor (0-dimensional).
176    ///
177    /// # Arguments
178    /// * `value` - The scalar value
179    ///
180    /// # Returns
181    /// New 0-dimensional tensor.
182    pub fn scalar(value: T) -> Self {
183        Self {
184            storage: Storage::from_vec(vec![value], Device::Cpu),
185            shape: Shape::new(),
186            strides: Strides::new(),
187            offset: 0,
188        }
189    }
190
191    /// Creates a tensor filled with zeros.
192    #[must_use]
193    pub fn zeros(shape: &[usize]) -> Self {
194        crate::creation::zeros(shape)
195    }
196
197    /// Creates a tensor filled with ones.
198    #[must_use]
199    pub fn ones(shape: &[usize]) -> Self
200    where
201        T: Numeric,
202    {
203        crate::creation::ones(shape)
204    }
205
206    /// Creates a tensor filled with a constant value.
207    #[must_use]
208    pub fn full(shape: &[usize], value: T) -> Self {
209        crate::creation::full(shape, value)
210    }
211
212    /// Creates a tensor with random values from standard normal distribution.
213    #[must_use]
214    pub fn randn(shape: &[usize]) -> Self
215    where
216        T: Float,
217        rand_distr::StandardNormal: rand::distributions::Distribution<T>,
218    {
219        crate::creation::randn(shape)
220    }
221
222    /// Creates a tensor with random values from uniform distribution [0, 1).
223    #[must_use]
224    pub fn rand(shape: &[usize]) -> Self
225    where
226        T: Float,
227        rand::distributions::Standard: rand::distributions::Distribution<T>,
228    {
229        crate::creation::rand(shape)
230    }
231
232    // =========================================================================
233    // Properties
234    // =========================================================================
235
236    /// Returns the shape of the tensor.
237    #[must_use]
238    pub fn shape(&self) -> &[usize] {
239        &self.shape
240    }
241
242    /// Returns the strides of the tensor.
243    #[must_use]
244    pub fn strides(&self) -> &[isize] {
245        &self.strides
246    }
247
248    /// Returns the number of dimensions.
249    #[must_use]
250    pub fn ndim(&self) -> usize {
251        self.shape.len()
252    }
253
254    /// Returns the total number of elements.
255    #[must_use]
256    pub fn numel(&self) -> usize {
257        numel(&self.shape)
258    }
259
260    /// Returns true if the tensor is empty (has zero elements).
261    #[must_use]
262    pub fn is_empty(&self) -> bool {
263        self.numel() == 0
264    }
265
266    /// Returns the size of a specific dimension.
267    ///
268    /// # Arguments
269    /// * `dim` - Dimension index (supports negative indexing)
270    pub fn size(&self, dim: i64) -> Result<usize> {
271        let idx = normalize_dim(dim, self.ndim())?;
272        Ok(self.shape[idx])
273    }
274
275    /// Returns the device this tensor is on.
276    #[must_use]
277    pub fn device(&self) -> Device {
278        self.storage.device()
279    }
280
281    /// Returns true if the tensor is contiguous in memory.
282    #[must_use]
283    pub fn is_contiguous(&self) -> bool {
284        is_contiguous(&self.shape, &self.strides)
285    }
286
287    /// Returns true if this tensor is a scalar (0-dimensional).
288    #[must_use]
289    pub fn is_scalar(&self) -> bool {
290        self.shape.is_empty()
291    }
292
293    // =========================================================================
294    // Data Access
295    // =========================================================================
296
297    /// Returns the element at the given indices.
298    ///
299    /// # Arguments
300    /// * `indices` - Multi-dimensional indices
301    pub fn get(&self, indices: &[usize]) -> Result<T> {
302        if indices.len() != self.ndim() {
303            return Err(Error::invalid_operation(format!(
304                "Expected {} indices, got {}",
305                self.ndim(),
306                indices.len()
307            )));
308        }
309
310        for (&idx, &dim) in indices.iter().zip(self.shape.iter()) {
311            if idx >= dim {
312                return Err(Error::IndexOutOfBounds {
313                    index: idx,
314                    size: dim,
315                });
316            }
317        }
318
319        let offset = self.offset + linear_index(indices, &self.strides);
320        Ok(self.storage.as_slice()[offset])
321    }
322
323    /// Sets the element at the given indices.
324    ///
325    /// # Arguments
326    /// * `indices` - Multi-dimensional indices
327    /// * `value` - Value to set
328    pub fn set(&self, indices: &[usize], value: T) -> Result<()> {
329        if indices.len() != self.ndim() {
330            return Err(Error::invalid_operation(format!(
331                "Expected {} indices, got {}",
332                self.ndim(),
333                indices.len()
334            )));
335        }
336
337        for (&idx, &dim) in indices.iter().zip(self.shape.iter()) {
338            if idx >= dim {
339                return Err(Error::IndexOutOfBounds {
340                    index: idx,
341                    size: dim,
342                });
343            }
344        }
345
346        let offset = self.offset + linear_index(indices, &self.strides);
347        self.storage.as_slice_mut()[offset] = value;
348        Ok(())
349    }
350
351    /// Returns the scalar value for a 0-dimensional tensor.
352    pub fn item(&self) -> Result<T> {
353        if self.numel() != 1 {
354            return Err(Error::invalid_operation(
355                "item() only works on single-element tensors",
356            ));
357        }
358
359        if self.is_scalar() {
360            Ok(self.storage.as_slice()[self.offset])
361        } else {
362            // Single element but not scalar shape
363            let indices = vec![0; self.ndim()];
364            self.get(&indices)
365        }
366    }
367
368    /// Returns the data as a contiguous vector.
369    ///
370    /// If the tensor is already contiguous, this returns a reference.
371    /// Otherwise, it copies the data into a new contiguous vector.
372    /// For GPU tensors (f32 only), performs a D2H copy.
373    #[must_use]
374    pub fn to_vec(&self) -> Vec<T> {
375        // GPU path: GPU storage is always f32
376        #[cfg(feature = "cuda")]
377        if self.storage.is_gpu() {
378            assert!(is_f32::<T>(), "GPU tensors are only supported for f32");
379            let self_f32 = unsafe { gpu_ref(self) };
380            let f32_vec = self_f32.to_vec_gpu();
381            unsafe {
382                let mut v = std::mem::ManuallyDrop::new(f32_vec);
383                return Vec::from_raw_parts(v.as_mut_ptr() as *mut T, v.len(), v.capacity());
384            }
385        }
386
387        if self.is_contiguous() {
388            let storage = self.storage.as_slice();
389            storage[self.offset..self.offset + self.numel()].to_vec()
390        } else {
391            let mut result = Vec::with_capacity(self.numel());
392            self.copy_data_to(&mut result);
393            result
394        }
395    }
396
397    /// Copies data to a slice, handling non-contiguous layouts.
398    fn copy_data_to(&self, dst: &mut Vec<T>) {
399        dst.clear();
400        let storage = self.storage.as_slice();
401
402        // Iterate through all indices
403        let total = self.numel();
404        for i in 0..total {
405            let indices = crate::shape::unravel_index(i, &self.shape);
406            let offset = self.offset + linear_index(&indices, &self.strides);
407            dst.push(storage[offset]);
408        }
409    }
410
411    // =========================================================================
412    // Shape Operations
413    // =========================================================================
414
415    /// Returns a new tensor with the specified shape.
416    ///
417    /// The total number of elements must remain the same.
418    /// Supports -1 in one dimension to infer the size.
419    ///
420    /// # Arguments
421    /// * `new_shape` - Target shape
422    pub fn reshape(&self, new_shape: &[isize]) -> Result<Self> {
423        let shape = reshape(&self.shape, new_shape)?;
424
425        if self.is_contiguous() {
426            // Can just change shape without copying
427            Ok(Self {
428                storage: self.storage.clone(),
429                strides: contiguous_strides(&shape),
430                shape,
431                offset: self.offset,
432            })
433        } else {
434            // Need to make contiguous first
435            let contig = self.contiguous();
436            Ok(Self {
437                storage: contig.storage,
438                strides: contiguous_strides(&shape),
439                shape,
440                offset: 0,
441            })
442        }
443    }
444
445    /// Returns a new tensor with a flattened shape.
446    #[must_use]
447    pub fn flatten(&self) -> Self {
448        self.reshape(&[-1]).expect("Flatten should never fail")
449    }
450
451    /// Returns a new tensor with dimensions of size 1 removed.
452    ///
453    /// # Arguments
454    /// * `dim` - Optional specific dimension to squeeze
455    pub fn squeeze(&self, dim: Option<i64>) -> Result<Self> {
456        let dim = match dim {
457            Some(d) => Some(normalize_dim(d, self.ndim())?),
458            None => None,
459        };
460
461        let new_shape = squeeze(&self.shape, dim);
462        let new_strides: Strides = match dim {
463            Some(d) => {
464                let mut s = self.strides.clone();
465                if d < self.shape.len() && self.shape[d] == 1 {
466                    s.remove(d);
467                }
468                s
469            }
470            None => self
471                .shape
472                .iter()
473                .zip(self.strides.iter())
474                .filter(|(&dim, _)| dim != 1)
475                .map(|(_, &stride)| stride)
476                .collect(),
477        };
478
479        Ok(Self {
480            storage: self.storage.clone(),
481            shape: new_shape,
482            strides: new_strides,
483            offset: self.offset,
484        })
485    }
486
487    /// Returns a new tensor with a dimension of size 1 inserted.
488    ///
489    /// # Arguments
490    /// * `dim` - Position to insert the new dimension
491    pub fn unsqueeze(&self, dim: i64) -> Result<Self> {
492        let normalized = if dim < 0 {
493            (dim + self.ndim() as i64 + 1) as usize
494        } else {
495            dim as usize
496        };
497
498        let new_shape = unsqueeze(&self.shape, normalized)?;
499        let mut new_strides = Strides::with_capacity(new_shape.len());
500
501        for (i, _) in new_shape.iter().enumerate() {
502            if i < normalized {
503                new_strides.push(self.strides.get(i).copied().unwrap_or(1));
504            } else if i == normalized {
505                // Stride for new dimension (doesn't matter since size is 1)
506                new_strides.push(1);
507            } else {
508                new_strides.push(self.strides[i - 1]);
509            }
510        }
511
512        Ok(Self {
513            storage: self.storage.clone(),
514            shape: new_shape,
515            strides: new_strides,
516            offset: self.offset,
517        })
518    }
519
520    /// Transposes two dimensions.
521    ///
522    /// # Arguments
523    /// * `dim0` - First dimension
524    /// * `dim1` - Second dimension
525    pub fn transpose(&self, dim0: i64, dim1: i64) -> Result<Self> {
526        let d0 = normalize_dim(dim0, self.ndim())?;
527        let d1 = normalize_dim(dim1, self.ndim())?;
528
529        let new_shape = transpose_shape(&self.shape, d0, d1)?;
530        let new_strides = transpose_strides(&self.strides, d0, d1);
531
532        Ok(Self {
533            storage: self.storage.clone(),
534            shape: new_shape,
535            strides: new_strides,
536            offset: self.offset,
537        })
538    }
539
540    /// Returns the transpose of a 2D tensor.
541    pub fn t(&self) -> Result<Self> {
542        if self.ndim() != 2 {
543            return Err(Error::invalid_operation("t() only works on 2D tensors"));
544        }
545        self.transpose(0, 1)
546    }
547
548    /// Returns a permuted tensor with dimensions reordered.
549    ///
550    /// # Arguments
551    /// * `dims` - New order of dimensions
552    pub fn permute(&self, dims: &[usize]) -> Result<Self> {
553        if dims.len() != self.ndim() {
554            return Err(Error::invalid_operation(format!(
555                "Expected {} dimensions, got {}",
556                self.ndim(),
557                dims.len()
558            )));
559        }
560
561        // Check that dims is a permutation
562        let mut seen = vec![false; self.ndim()];
563        for &d in dims {
564            if d >= self.ndim() {
565                return Err(Error::InvalidDimension {
566                    index: d as i64,
567                    ndim: self.ndim(),
568                });
569            }
570            if seen[d] {
571                return Err(Error::invalid_operation("Duplicate dimension in permute"));
572            }
573            seen[d] = true;
574        }
575
576        let new_shape: Shape = dims.iter().map(|&d| self.shape[d]).collect();
577        let new_strides: Strides = dims.iter().map(|&d| self.strides[d]).collect();
578
579        Ok(Self {
580            storage: self.storage.clone(),
581            shape: new_shape,
582            strides: new_strides,
583            offset: self.offset,
584        })
585    }
586
587    /// Returns a contiguous copy of the tensor.
588    #[must_use]
589    pub fn contiguous(&self) -> Self {
590        if self.is_contiguous() && self.offset == 0 {
591            return self.clone();
592        }
593
594        #[cfg(feature = "cuda")]
595        if self.storage.is_gpu() {
596            assert!(is_f32::<T>(), "GPU tensors are only supported for f32");
597            let self_f32 = unsafe { gpu_ref(self) };
598            let result = self_f32.contiguous_gpu();
599            return unsafe { gpu_into(result) };
600        }
601
602        let data = self.to_vec();
603        Self::from_vec(data, &self.shape).expect("Contiguous should never fail")
604    }
605
606    // =========================================================================
607    // Device Operations
608    // =========================================================================
609
610    /// Transfers the tensor to a different device.
611    ///
612    /// # Arguments
613    /// * `device` - Target device
614    pub fn to_device(&self, device: Device) -> Result<Self> {
615        if self.device() == device {
616            return Ok(self.clone());
617        }
618
619        #[cfg(feature = "cuda")]
620        if self.storage.is_gpu() || device.is_gpu() {
621            assert!(is_f32::<T>(), "GPU tensors are only supported for f32");
622            let self_f32 = unsafe { gpu_ref(self) };
623            let result = self_f32.to_device_f32(device)?;
624            return Ok(unsafe { gpu_into(result) });
625        }
626
627        let contig = self.contiguous();
628        let new_storage = contig.storage.to_device(device)?;
629
630        Ok(Self {
631            storage: new_storage,
632            shape: self.shape.clone(),
633            strides: self.strides.clone(),
634            offset: 0,
635        })
636    }
637
638    /// Transfers to CPU.
639    pub fn cpu(&self) -> Result<Self> {
640        self.to_device(Device::Cpu)
641    }
642
643    // =========================================================================
644    // Deep Copy
645    // =========================================================================
646
647    /// Creates a deep copy of this tensor with its own storage.
648    #[must_use]
649    pub fn clone_deep(&self) -> Self {
650        let data = self.to_vec();
651        let cpu = Self::from_vec(data, &self.shape).expect("Deep clone should never fail");
652        #[cfg(feature = "cuda")]
653        if self.device().is_gpu() {
654            return cpu.to_device(self.device()).unwrap();
655        }
656        cpu
657    }
658}
659
660// =============================================================================
661// Numeric Operations
662// =============================================================================
663
impl<T: Numeric> Tensor<T> {
    /// Fills the tensor with a value.
    ///
    /// Writes through the shared storage handle, so all views sharing this
    /// storage observe the change.
    ///
    /// # Panics
    /// Panics if the tensor lives on a GPU (in-place fill is CPU-only).
    pub fn fill_(&self, value: T) {
        #[cfg(feature = "cuda")]
        if self.storage.is_gpu() {
            panic!("fill_() not supported on GPU tensors — create a new tensor instead");
        }
        let mut data = self.storage.as_slice_mut();
        CpuBackend::fill(&mut data, value);
    }

    /// Fills the tensor with zeros.
    ///
    /// # Panics
    /// Panics if the tensor lives on a GPU (see [`Tensor::fill_`]).
    pub fn zero_(&self) {
        self.fill_(T::zero());
    }

    // =========================================================================
    // Reduction Operations
    // =========================================================================

    /// Returns the sum of all elements as a scalar tensor.
    ///
    /// The reduction itself runs on CPU (`to_vec` performs D2H for GPU
    /// tensors); the scalar result is moved back to the source device.
    #[must_use]
    pub fn sum(&self) -> Self {
        let data = self.to_vec();
        let result = CpuBackend::sum(&data);
        let s = Self::scalar(result);
        #[cfg(feature = "cuda")]
        if self.device().is_gpu() {
            return s.to_device(self.device()).unwrap();
        }
        s
    }

    /// Returns the product of all elements as a scalar tensor.
    ///
    /// Runs on CPU; the scalar result is moved back to the source device.
    #[must_use]
    pub fn prod(&self) -> Self {
        let data = self.to_vec();
        let result = CpuBackend::prod(&data);
        let s = Self::scalar(result);
        #[cfg(feature = "cuda")]
        if self.device().is_gpu() {
            return s.to_device(self.device()).unwrap();
        }
        s
    }

    /// Returns the maximum element as a scalar tensor.
    ///
    /// # Errors
    /// Returns [`Error::EmptyTensor`] if the tensor has no elements.
    pub fn max(&self) -> Result<Self> {
        if self.is_empty() {
            return Err(Error::EmptyTensor);
        }
        let data = self.to_vec();
        // unwrap is safe: non-empty checked above.
        let result = CpuBackend::max(&data).unwrap();
        let s = Self::scalar(result);
        #[cfg(feature = "cuda")]
        if self.device().is_gpu() {
            return Ok(s.to_device(self.device()).unwrap());
        }
        Ok(s)
    }

    /// Returns the minimum element as a scalar tensor.
    ///
    /// # Errors
    /// Returns [`Error::EmptyTensor`] if the tensor has no elements.
    pub fn min(&self) -> Result<Self> {
        if self.is_empty() {
            return Err(Error::EmptyTensor);
        }
        let data = self.to_vec();
        // unwrap is safe: non-empty checked above.
        let result = CpuBackend::min(&data).unwrap();
        let s = Self::scalar(result);
        #[cfg(feature = "cuda")]
        if self.device().is_gpu() {
            return Ok(s.to_device(self.device()).unwrap());
        }
        Ok(s)
    }

    /// Returns the flat (row-major) index of the maximum element.
    ///
    /// # Errors
    /// Returns [`Error::EmptyTensor`] if the tensor has no elements.
    pub fn argmax(&self) -> Result<usize> {
        if self.is_empty() {
            return Err(Error::EmptyTensor);
        }
        let data = self.to_vec();
        Ok(CpuBackend::argmax(&data).unwrap())
    }

    /// Returns the flat (row-major) index of the minimum element.
    ///
    /// # Errors
    /// Returns [`Error::EmptyTensor`] if the tensor has no elements.
    pub fn argmin(&self) -> Result<usize> {
        if self.is_empty() {
            return Err(Error::EmptyTensor);
        }
        let data = self.to_vec();
        Ok(CpuBackend::argmin(&data).unwrap())
    }

    /// Concatenates tensors along a dimension.
    ///
    /// All tensors must have the same shape except along the cat dimension.
    /// The result inherits the device of the first tensor.
    ///
    /// # Errors
    /// Returns an error if `tensors` is empty, `dim` is out of range, or the
    /// shapes disagree on any non-cat dimension.
    pub fn cat(tensors: &[&Self], dim: usize) -> Result<Self> {
        if tensors.is_empty() {
            return Err(Error::invalid_operation("cat requires at least one tensor"));
        }
        let ndim = tensors[0].ndim();
        if dim >= ndim {
            return Err(Error::invalid_operation("cat dimension out of range"));
        }

        // Validate that every tensor matches the first on all non-cat dims.
        for t in &tensors[1..] {
            if t.ndim() != ndim {
                return Err(Error::invalid_operation("cat: all tensors must have same ndim"));
            }
            for d in 0..ndim {
                if d != dim && t.shape[d] != tensors[0].shape[d] {
                    return Err(Error::invalid_operation("cat: shapes must match on non-cat dims"));
                }
            }
        }

        let total_dim_size: usize = tensors.iter().map(|t| t.shape[dim]).sum();
        let mut out_shape: Vec<usize> = tensors[0].shape.to_vec();
        out_shape[dim] = total_dim_size;

        // View every tensor as (outer, cat_dim, inner) where outer/inner are
        // the flattened products of the dims before/after `dim`.
        let outer_size: usize = out_shape[..dim].iter().product();
        let inner_size: usize = out_shape[dim + 1..].iter().product();
        let total_numel: usize = out_shape.iter().product();
        let mut result = vec![T::zero(); total_numel];

        // Copy each tensor's rows into the output at its running offset
        // along the cat dimension.
        let mut dim_offset = 0;
        for t in tensors {
            let t_data = t.contiguous().to_vec();
            let t_dim_size = t.shape[dim];
            for outer in 0..outer_size {
                for d in 0..t_dim_size {
                    let src_base = outer * t_dim_size * inner_size + d * inner_size;
                    let dst_base = outer * total_dim_size * inner_size
                        + (dim_offset + d) * inner_size;
                    result[dst_base..dst_base + inner_size]
                        .copy_from_slice(&t_data[src_base..src_base + inner_size]);
                }
            }
            dim_offset += t_dim_size;
        }

        let out = Self::from_vec(result, &out_shape)?;
        #[cfg(feature = "cuda")]
        if tensors[0].device().is_gpu() {
            return Ok(out.to_device(tensors[0].device()).unwrap());
        }
        Ok(out)
    }
}
814
815// =============================================================================
816// Float Operations
817// =============================================================================
818
819impl<T: Float> Tensor<T> {
820    /// Returns the mean of all elements.
821    pub fn mean(&self) -> Result<Self> {
822        if self.is_empty() {
823            return Err(Error::EmptyTensor);
824        }
825        let data = self.to_vec();
826        let result = CpuBackend::mean(&data).unwrap();
827        let s = Self::scalar(result);
828        #[cfg(feature = "cuda")]
829        if self.device().is_gpu() {
830            return Ok(s.to_device(self.device()).unwrap());
831        }
832        Ok(s)
833    }
834
835    // =========================================================================
836    // Activation Functions
837    // =========================================================================
838
839    /// Applies `ReLU` activation: max(0, x).
840    #[must_use]
841    pub fn relu(&self) -> Self {
842        #[cfg(feature = "cuda")]
843        if self.device().is_gpu() {
844            assert!(is_f32::<T>(), "GPU tensors are only supported for f32");
845            return unsafe { gpu_into(gpu_ref(self).relu_cuda()) };
846        }
847        let data = self.to_vec();
848        let mut result = vec![T::zero(); data.len()];
849        CpuBackend::relu(&mut result, &data);
850        Self::from_vec(result, &self.shape).unwrap()
851    }
852
853    /// Applies sigmoid activation: 1 / (1 + exp(-x)).
854    #[must_use]
855    pub fn sigmoid(&self) -> Self {
856        #[cfg(feature = "cuda")]
857        if self.device().is_gpu() {
858            assert!(is_f32::<T>(), "GPU tensors are only supported for f32");
859            return unsafe { gpu_into(gpu_ref(self).sigmoid_cuda()) };
860        }
861        let data = self.to_vec();
862        let mut result = vec![T::zero(); data.len()];
863        CpuBackend::sigmoid(&mut result, &data);
864        Self::from_vec(result, &self.shape).unwrap()
865    }
866
867    /// Applies tanh activation.
868    #[must_use]
869    pub fn tanh(&self) -> Self {
870        #[cfg(feature = "cuda")]
871        if self.device().is_gpu() {
872            assert!(is_f32::<T>(), "GPU tensors are only supported for f32");
873            return unsafe { gpu_into(gpu_ref(self).tanh_cuda()) };
874        }
875        let data = self.to_vec();
876        let mut result = vec![T::zero(); data.len()];
877        CpuBackend::tanh(&mut result, &data);
878        Self::from_vec(result, &self.shape).unwrap()
879    }
880
881    /// Applies exponential function.
882    #[must_use]
883    pub fn exp(&self) -> Self {
884        #[cfg(feature = "cuda")]
885        if self.device().is_gpu() {
886            assert!(is_f32::<T>(), "GPU tensors are only supported for f32");
887            return unsafe { gpu_into(gpu_ref(self).exp_cuda()) };
888        }
889        let data = self.to_vec();
890        let mut result = vec![T::zero(); data.len()];
891        CpuBackend::exp(&mut result, &data);
892        Self::from_vec(result, &self.shape).unwrap()
893    }
894
895    /// Applies natural logarithm.
896    #[must_use]
897    pub fn ln(&self) -> Self {
898        #[cfg(feature = "cuda")]
899        if self.device().is_gpu() {
900            assert!(is_f32::<T>(), "GPU tensors are only supported for f32");
901            return unsafe { gpu_into(gpu_ref(self).ln_cuda()) };
902        }
903        let data = self.to_vec();
904        let mut result = vec![T::zero(); data.len()];
905        CpuBackend::ln(&mut result, &data);
906        Self::from_vec(result, &self.shape).unwrap()
907    }
908
909    /// Applies square root.
910    #[must_use]
911    pub fn sqrt(&self) -> Self {
912        #[cfg(feature = "cuda")]
913        if self.device().is_gpu() {
914            assert!(is_f32::<T>(), "GPU tensors are only supported for f32");
915            return unsafe { gpu_into(gpu_ref(self).sqrt_cuda()) };
916        }
917        let data = self.to_vec();
918        let mut result = vec![T::zero(); data.len()];
919        CpuBackend::sqrt(&mut result, &data);
920        Self::from_vec(result, &self.shape).unwrap()
921    }
922
923    /// Computes element-wise power.
924    #[must_use]
925    pub fn pow(&self, exp: T) -> Self {
926        #[cfg(feature = "cuda")]
927        if self.device().is_gpu() {
928            assert!(is_f32::<T>(), "GPU tensors are only supported for f32");
929            let exp_f32: f32 = unsafe { *(&exp as *const T as *const f32) };
930            return unsafe { gpu_into(gpu_ref(self).pow_cuda(exp_f32)) };
931        }
932        let data = self.to_vec();
933        let result: Vec<T> = data.iter().map(|&x| x.pow_value(exp)).collect();
934        Self::from_vec(result, &self.shape).unwrap()
935    }
936
937    /// GELU activation function (Gaussian Error Linear Unit).
938    #[must_use]
939    pub fn gelu(&self) -> Self {
940        #[cfg(feature = "cuda")]
941        if self.device().is_gpu() {
942            assert!(is_f32::<T>(), "GPU tensors are only supported for f32");
943            return unsafe { gpu_into(gpu_ref(self).gelu_cuda()) };
944        }
945        crate::ops::gelu(self)
946    }
947
948    /// SiLU/Swish activation function.
949    #[must_use]
950    pub fn silu(&self) -> Self {
951        #[cfg(feature = "cuda")]
952        if self.device().is_gpu() {
953            assert!(is_f32::<T>(), "GPU tensors are only supported for f32");
954            return unsafe { gpu_into(gpu_ref(self).silu_cuda()) };
955        }
956        crate::ops::silu(self)
957    }
958
959    /// Softmax along specified dimension.
960    #[must_use]
961    pub fn softmax(&self, dim: i32) -> Self {
962        #[cfg(feature = "cuda")]
963        if self.device().is_gpu() {
964            assert!(is_f32::<T>(), "GPU tensors are only supported for f32");
965            let self_f32 = unsafe { gpu_ref(self) };
966            return unsafe { gpu_into(self_f32.softmax_cuda(dim).expect("CUDA softmax failed")) };
967        }
968        crate::ops::softmax(self, dim as i64).unwrap_or_else(|_| self.clone())
969    }
970
971    /// Log softmax along specified dimension.
972    #[must_use]
973    pub fn log_softmax(&self, dim: i32) -> Self {
974        let softmax_result = self.softmax(dim);
975        softmax_result.ln()
976    }
977
978    /// Mean along a dimension.
979    #[must_use]
980    pub fn mean_dim(&self, dim: i32, keepdim: bool) -> Self {
981        let ndim = self.ndim();
982        let dim = if dim < 0 {
983            (ndim as i32 + dim) as usize
984        } else {
985            dim as usize
986        };
987
988        if dim >= ndim {
989            return self.clone();
990        }
991
992        // GPU fast path: sum_dim then divide by dim_size (all on GPU)
993        #[cfg(feature = "cuda")]
994        if self.device().is_gpu() {
995            assert!(is_f32::<T>(), "GPU tensors are only supported for f32");
996            let self_f32 = unsafe { gpu_ref(self) };
997            let summed = if keepdim {
998                self_f32.sum_dim_keepdim_cuda(dim)
999            } else {
1000                self_f32.sum_dim_cuda(dim)
1001            };
1002            let dim_size = self.shape[dim];
1003            let result = summed.mul_scalar_cuda(1.0 / dim_size as f32);
1004            return unsafe { gpu_into(result) };
1005        }
1006
1007        let dim_size = self.shape[dim];
1008        let data = self.to_vec();
1009        let mut new_shape = self.shape.clone();
1010
1011        if keepdim {
1012            new_shape[dim] = 1;
1013        } else {
1014            new_shape.remove(dim);
1015        }
1016
1017        if new_shape.is_empty() {
1018            new_shape = smallvec::smallvec![1];
1019        }
1020
1021        let new_numel: usize = new_shape.iter().product();
1022        let mut result = vec![T::zero(); new_numel];
1023
1024        let outer_size: usize = self.shape[..dim].iter().product();
1025        let inner_size: usize = self.shape[dim + 1..].iter().product();
1026
1027        for outer in 0..outer_size {
1028            for inner in 0..inner_size {
1029                let mut sum = T::zero();
1030                for d in 0..dim_size {
1031                    let idx = outer * dim_size * inner_size + d * inner_size + inner;
1032                    sum = sum + data[idx];
1033                }
1034                let mean = sum / NumCast::from(dim_size).unwrap();
1035                let result_idx = outer * inner_size + inner;
1036                result[result_idx] = mean;
1037            }
1038        }
1039
1040        Self::from_vec(result, &new_shape).unwrap()
1041    }
1042
1043    /// Sum along a dimension.
1044    #[must_use]
1045    pub fn sum_dim(&self, dim: i32, keepdim: bool) -> Self {
1046        let ndim = self.ndim();
1047        let dim = if dim < 0 {
1048            (ndim as i32 + dim) as usize
1049        } else {
1050            dim as usize
1051        };
1052
1053        if dim >= ndim {
1054            return self.clone();
1055        }
1056
1057        // GPU fast path: use CUDA sum_dim kernel (no CPU copies)
1058        #[cfg(feature = "cuda")]
1059        if self.device().is_gpu() {
1060            assert!(is_f32::<T>(), "GPU tensors are only supported for f32");
1061            let self_f32 = unsafe { gpu_ref(self) };
1062            let result = if keepdim {
1063                self_f32.sum_dim_keepdim_cuda(dim)
1064            } else {
1065                self_f32.sum_dim_cuda(dim)
1066            };
1067            return unsafe { gpu_into(result) };
1068        }
1069
1070        let dim_size = self.shape[dim];
1071        let data = self.to_vec();
1072        let mut new_shape = self.shape.clone();
1073
1074        if keepdim {
1075            new_shape[dim] = 1;
1076        } else {
1077            new_shape.remove(dim);
1078        }
1079
1080        if new_shape.is_empty() {
1081            new_shape = smallvec::smallvec![1];
1082        }
1083
1084        let new_numel: usize = new_shape.iter().product();
1085        let mut result = vec![T::zero(); new_numel];
1086
1087        let outer_size: usize = self.shape[..dim].iter().product();
1088        let inner_size: usize = self.shape[dim + 1..].iter().product();
1089
1090        for outer in 0..outer_size {
1091            for inner in 0..inner_size {
1092                let mut sum = T::zero();
1093                for d in 0..dim_size {
1094                    let idx = outer * dim_size * inner_size + d * inner_size + inner;
1095                    sum = sum + data[idx];
1096                }
1097                let result_idx = outer * inner_size + inner;
1098                result[result_idx] = sum;
1099            }
1100        }
1101
1102        Self::from_vec(result, &new_shape).unwrap()
1103    }
1104
1105    /// Variance along a dimension.
1106    #[must_use]
1107    pub fn var_dim(&self, dim: i32, keepdim: bool) -> Self {
1108        // variance = E[x²] - E[x]²  (saves one full-size intermediate allocation)
1109        let mean = self.mean_dim(dim, true);
1110        let sq = self.mul(self).unwrap_or_else(|_| self.clone());
1111        let mean_sq = sq.mean_dim(dim, keepdim);
1112        let mean_keepdim = if keepdim { mean.clone() } else { self.mean_dim(dim, keepdim) };
1113        let mean_squared = mean_keepdim.mul(&mean_keepdim).unwrap_or_else(|_| mean_keepdim.clone());
1114        mean_sq.sub(&mean_squared).unwrap_or_else(|_| mean_sq.clone())
1115    }
1116
1117    /// Broadcasts tensor to a new shape.
1118    #[must_use]
1119    pub fn broadcast_to(&self, shape: &[usize]) -> Self {
1120        if self.shape.as_slice() == shape {
1121            return self.clone();
1122        }
1123
1124        #[cfg(feature = "cuda")]
1125        if self.device().is_gpu() {
1126            assert!(is_f32::<T>(), "GPU tensors are only supported for f32");
1127            let self_f32 = unsafe { gpu_ref(self) };
1128            return unsafe { gpu_into(self_f32.broadcast_to_cuda(shape).expect("CUDA broadcast_to failed")) };
1129        }
1130
1131        let result_shape = broadcast_shape(&self.shape, shape).unwrap_or_else(|_| shape.into());
1132        let self_strides = broadcast_strides(&self.shape, &self.strides, &result_shape);
1133
1134        let total = numel(&result_shape);
1135        let mut result_data = vec![T::zero(); total];
1136        let self_data = self.storage.as_slice();
1137
1138        for i in 0..total {
1139            let indices = crate::shape::unravel_index(i, &result_shape);
1140            let self_idx = self.offset + linear_index(&indices, &self_strides);
1141            result_data[i] = self_data[self_idx];
1142        }
1143
1144        Self::from_vec(result_data, &result_shape).unwrap()
1145    }
1146
1147    /// Slices the tensor using ranges for each dimension.
1148    #[must_use]
1149    pub fn slice(&self, ranges: &[std::ops::Range<usize>]) -> Self {
1150        let mut new_shape = Vec::with_capacity(self.ndim());
1151        for (i, range) in ranges.iter().enumerate() {
1152            if i < self.ndim() {
1153                new_shape.push(range.end - range.start);
1154            }
1155        }
1156        // Keep remaining dimensions unchanged
1157        for i in ranges.len()..self.ndim() {
1158            new_shape.push(self.shape[i]);
1159        }
1160
1161        let new_numel: usize = new_shape.iter().product();
1162        let mut result_data = vec![T::zero(); new_numel];
1163        let self_data = self.to_vec();
1164
1165        // Copy data with proper indexing
1166        let mut result_idx = 0;
1167        Self::slice_recursive(
1168            &self_data,
1169            &self.shape,
1170            ranges,
1171            0,
1172            0,
1173            &mut result_data,
1174            &mut result_idx,
1175        );
1176
1177        let out = Self::from_vec(result_data, &new_shape).unwrap();
1178        #[cfg(feature = "cuda")]
1179        if self.device().is_gpu() {
1180            return out.to_device(self.device()).unwrap();
1181        }
1182        out
1183    }
1184
1185    fn slice_recursive(
1186        data: &[T],
1187        shape: &[usize],
1188        ranges: &[std::ops::Range<usize>],
1189        dim: usize,
1190        offset: usize,
1191        result: &mut [T],
1192        result_idx: &mut usize,
1193    ) {
1194        if dim == shape.len() {
1195            result[*result_idx] = data[offset];
1196            *result_idx += 1;
1197            return;
1198        }
1199
1200        let stride: usize = shape[dim + 1..].iter().product();
1201        let (start, end) = if dim < ranges.len() {
1202            (ranges[dim].start, ranges[dim].end)
1203        } else {
1204            (0, shape[dim])
1205        };
1206
1207        for i in start..end {
1208            Self::slice_recursive(
1209                data,
1210                shape,
1211                ranges,
1212                dim + 1,
1213                offset + i * stride,
1214                result,
1215                result_idx,
1216            );
1217        }
1218    }
1219}
1220
1221// =============================================================================
1222// Arithmetic Operator Implementations
1223// =============================================================================
1224
1225impl<T: Numeric> Tensor<T> {
1226    /// Element-wise addition with broadcasting.
1227    pub fn add(&self, other: &Self) -> Result<Self> {
1228        #[cfg(feature = "cuda")]
1229        {
1230            let self_gpu = self.device().is_gpu();
1231            let other_gpu = other.device().is_gpu();
1232            if self_gpu || other_gpu {
1233                assert!(is_f32::<T>(), "GPU tensors are only supported for f32");
1234                if self_gpu && other_gpu {
1235                    let (s, o) = unsafe { (gpu_ref(self), gpu_ref(other)) };
1236                    if self.shape == other.shape {
1237                        return Ok(unsafe { gpu_into(s.add_cuda(o)?) });
1238                    } else {
1239                        return Ok(unsafe { gpu_into(s.broadcast_add_cuda(o)?) });
1240                    }
1241                }
1242                // Mixed device — move to GPU, then operate
1243                let target_device = if self_gpu { self.device() } else { other.device() };
1244                let a_gpu = if self_gpu { self.clone() } else { self.to_device(target_device)? };
1245                let b_gpu = if other_gpu { other.clone() } else { other.to_device(target_device)? };
1246                return a_gpu.add(&b_gpu);
1247            }
1248        }
1249        let result_shape = broadcast_shape(&self.shape, &other.shape)?;
1250        let self_strides = broadcast_strides(&self.shape, &self.strides, &result_shape);
1251        let other_strides = broadcast_strides(&other.shape, &other.strides, &result_shape);
1252
1253        let total = numel(&result_shape);
1254        let mut result_data = vec![T::zero(); total];
1255
1256        let self_data = self.storage.as_slice();
1257        let other_data = other.storage.as_slice();
1258
1259        for i in 0..total {
1260            let indices = crate::shape::unravel_index(i, &result_shape);
1261            let self_idx = self.offset + linear_index(&indices, &self_strides);
1262            let other_idx = other.offset + linear_index(&indices, &other_strides);
1263            result_data[i] = self_data[self_idx] + other_data[other_idx];
1264        }
1265
1266        Self::from_vec(result_data, &result_shape)
1267    }
1268
1269    /// Element-wise subtraction with broadcasting.
1270    pub fn sub(&self, other: &Self) -> Result<Self> {
1271        #[cfg(feature = "cuda")]
1272        {
1273            let self_gpu = self.device().is_gpu();
1274            let other_gpu = other.device().is_gpu();
1275            if self_gpu || other_gpu {
1276                assert!(is_f32::<T>(), "GPU tensors are only supported for f32");
1277                if self_gpu && other_gpu {
1278                    let (s, o) = unsafe { (gpu_ref(self), gpu_ref(other)) };
1279                    if self.shape == other.shape {
1280                        return Ok(unsafe { gpu_into(s.sub_cuda(o)?) });
1281                    } else {
1282                        return Ok(unsafe { gpu_into(s.broadcast_sub_cuda(o)?) });
1283                    }
1284                }
1285                let target = if self_gpu { self.device() } else { other.device() };
1286                let a_gpu = if self_gpu { self.clone() } else { self.to_device(target)? };
1287                let b_gpu = if other_gpu { other.clone() } else { other.to_device(target)? };
1288                return a_gpu.sub(&b_gpu);
1289            }
1290        }
1291        let result_shape = broadcast_shape(&self.shape, &other.shape)?;
1292        let self_strides = broadcast_strides(&self.shape, &self.strides, &result_shape);
1293        let other_strides = broadcast_strides(&other.shape, &other.strides, &result_shape);
1294
1295        let total = numel(&result_shape);
1296        let mut result_data = vec![T::zero(); total];
1297
1298        let self_data = self.storage.as_slice();
1299        let other_data = other.storage.as_slice();
1300
1301        for i in 0..total {
1302            let indices = crate::shape::unravel_index(i, &result_shape);
1303            let self_idx = self.offset + linear_index(&indices, &self_strides);
1304            let other_idx = other.offset + linear_index(&indices, &other_strides);
1305            result_data[i] = self_data[self_idx] - other_data[other_idx];
1306        }
1307
1308        Self::from_vec(result_data, &result_shape)
1309    }
1310
1311    /// Element-wise multiplication with broadcasting.
1312    pub fn mul(&self, other: &Self) -> Result<Self> {
1313        #[cfg(feature = "cuda")]
1314        {
1315            let self_gpu = self.device().is_gpu();
1316            let other_gpu = other.device().is_gpu();
1317            if self_gpu || other_gpu {
1318                assert!(is_f32::<T>(), "GPU tensors are only supported for f32");
1319                if self_gpu && other_gpu {
1320                    let (s, o) = unsafe { (gpu_ref(self), gpu_ref(other)) };
1321                    if self.shape == other.shape {
1322                        return Ok(unsafe { gpu_into(s.mul_cuda(o)?) });
1323                    } else {
1324                        return Ok(unsafe { gpu_into(s.broadcast_mul_cuda(o)?) });
1325                    }
1326                }
1327                let target = if self_gpu { self.device() } else { other.device() };
1328                let a_gpu = if self_gpu { self.clone() } else { self.to_device(target)? };
1329                let b_gpu = if other_gpu { other.clone() } else { other.to_device(target)? };
1330                return a_gpu.mul(&b_gpu);
1331            }
1332        }
1333        let result_shape = broadcast_shape(&self.shape, &other.shape)?;
1334        let self_strides = broadcast_strides(&self.shape, &self.strides, &result_shape);
1335        let other_strides = broadcast_strides(&other.shape, &other.strides, &result_shape);
1336
1337        let total = numel(&result_shape);
1338        let mut result_data = vec![T::zero(); total];
1339
1340        let self_data = self.storage.as_slice();
1341        let other_data = other.storage.as_slice();
1342
1343        for i in 0..total {
1344            let indices = crate::shape::unravel_index(i, &result_shape);
1345            let self_idx = self.offset + linear_index(&indices, &self_strides);
1346            let other_idx = other.offset + linear_index(&indices, &other_strides);
1347            result_data[i] = self_data[self_idx] * other_data[other_idx];
1348        }
1349
1350        Self::from_vec(result_data, &result_shape)
1351    }
1352
1353    /// Element-wise division with broadcasting.
1354    pub fn div(&self, other: &Self) -> Result<Self> {
1355        #[cfg(feature = "cuda")]
1356        {
1357            let self_gpu = self.device().is_gpu();
1358            let other_gpu = other.device().is_gpu();
1359            if self_gpu || other_gpu {
1360                assert!(is_f32::<T>(), "GPU tensors are only supported for f32");
1361                if self_gpu && other_gpu {
1362                    let (s, o) = unsafe { (gpu_ref(self), gpu_ref(other)) };
1363                    if self.shape == other.shape {
1364                        return Ok(unsafe { gpu_into(s.div_cuda(o)?) });
1365                    } else {
1366                        return Ok(unsafe { gpu_into(s.broadcast_div_cuda(o)?) });
1367                    }
1368                }
1369                let target = if self_gpu { self.device() } else { other.device() };
1370                let a_gpu = if self_gpu { self.clone() } else { self.to_device(target)? };
1371                let b_gpu = if other_gpu { other.clone() } else { other.to_device(target)? };
1372                return a_gpu.div(&b_gpu);
1373            }
1374        }
1375        let result_shape = broadcast_shape(&self.shape, &other.shape)?;
1376        let self_strides = broadcast_strides(&self.shape, &self.strides, &result_shape);
1377        let other_strides = broadcast_strides(&other.shape, &other.strides, &result_shape);
1378
1379        let total = numel(&result_shape);
1380        let mut result_data = vec![T::zero(); total];
1381
1382        let self_data = self.storage.as_slice();
1383        let other_data = other.storage.as_slice();
1384
1385        for i in 0..total {
1386            let indices = crate::shape::unravel_index(i, &result_shape);
1387            let self_idx = self.offset + linear_index(&indices, &self_strides);
1388            let other_idx = other.offset + linear_index(&indices, &other_strides);
1389            result_data[i] = self_data[self_idx] / other_data[other_idx];
1390        }
1391
1392        Self::from_vec(result_data, &result_shape)
1393    }
1394
1395    /// Scalar addition.
1396    #[must_use]
1397    pub fn add_scalar(&self, scalar: T) -> Self {
1398        #[cfg(feature = "cuda")]
1399        if self.device().is_gpu() {
1400            assert!(is_f32::<T>(), "GPU tensors are only supported for f32");
1401            let self_f32 = unsafe { gpu_ref(self) };
1402            let scalar_f32: f32 = unsafe { *(&scalar as *const T as *const f32) };
1403            return unsafe { gpu_into(self_f32.add_scalar_cuda(scalar_f32)) };
1404        }
1405        let data = self.to_vec();
1406        let mut result = vec![T::zero(); data.len()];
1407        CpuBackend::add_scalar(&mut result, &data, scalar);
1408        Self::from_vec(result, &self.shape).unwrap()
1409    }
1410
1411    /// Scalar multiplication.
1412    #[must_use]
1413    pub fn mul_scalar(&self, scalar: T) -> Self {
1414        #[cfg(feature = "cuda")]
1415        if self.device().is_gpu() {
1416            assert!(is_f32::<T>(), "GPU tensors are only supported for f32");
1417            let self_f32 = unsafe { gpu_ref(self) };
1418            let scalar_f32: f32 = unsafe { *(&scalar as *const T as *const f32) };
1419            return unsafe { gpu_into(self_f32.mul_scalar_cuda(scalar_f32)) };
1420        }
1421        let data = self.to_vec();
1422        let mut result = vec![T::zero(); data.len()];
1423        CpuBackend::mul_scalar(&mut result, &data, scalar);
1424        Self::from_vec(result, &self.shape).unwrap()
1425    }
1426
1427    /// Element-wise negation.
1428    #[must_use]
1429    pub fn neg(&self) -> Self {
1430        #[cfg(feature = "cuda")]
1431        if self.device().is_gpu() {
1432            assert!(is_f32::<T>(), "GPU tensors are only supported for f32");
1433            let self_f32 = unsafe { gpu_ref(self) };
1434            return unsafe { gpu_into(self_f32.neg_cuda()) };
1435        }
1436        let data = self.to_vec();
1437        let mut result = vec![T::zero(); data.len()];
1438        CpuBackend::neg(&mut result, &data);
1439        Self::from_vec(result, &self.shape).unwrap()
1440    }
1441
1442    /// Matrix multiplication with batching support.
1443    ///
1444    /// Supports:
1445    /// - 2D @ 2D: [m, k] @ [k, n] -> [m, n]
1446    /// - 3D @ 3D: [batch, m, k] @ [batch, k, n] -> [batch, m, n]
1447    /// - 4D @ 4D: [b1, b2, m, k] @ [b1, b2, k, n] -> [b1, b2, m, n]
1448    pub fn matmul(&self, other: &Self) -> Result<Self> {
1449        #[cfg(feature = "cuda")]
1450        if self.device().is_gpu() {
1451            assert!(is_f32::<T>(), "GPU tensors are only supported for f32");
1452            let (s, o) = unsafe { (gpu_ref(self), gpu_ref(other)) };
1453            return Ok(unsafe { gpu_into(s.matmul_cuda(o)?) });
1454        }
1455        if self.ndim() < 2 || other.ndim() < 2 {
1456            return Err(Error::invalid_operation(
1457                "matmul requires at least 2D tensors",
1458            ));
1459        }
1460
1461        let m = self.shape[self.ndim() - 2];
1462        let k1 = self.shape[self.ndim() - 1];
1463        let k2 = other.shape[other.ndim() - 2];
1464        let n = other.shape[other.ndim() - 1];
1465
1466        if k1 != k2 {
1467            return Err(Error::invalid_operation(format!(
1468                "matmul inner dimensions must match: {k1} vs {k2}"
1469            )));
1470        }
1471
1472        // For 2D matrices, do simple matmul
1473        if self.ndim() == 2 && other.ndim() == 2 {
1474            let a_data = self.contiguous().to_vec();
1475            let b_data = other.contiguous().to_vec();
1476
1477            // GPU-accelerated matmul for CPU tensors: only for very large matrices
1478            // where transfer overhead is negligible relative to compute.
1479            // For GPU-resident tensors, the dispatch at the top of matmul() handles it.
1480            #[cfg(feature = "cuda")]
1481            {
1482                let flops = m * n * k1;
1483                if std::any::TypeId::of::<T>() == std::any::TypeId::of::<f32>()
1484                    && flops >= 4_000_000
1485                {
1486                    let a_f32: &[f32] = unsafe { std::mem::transmute(a_data.as_slice()) };
1487                    let b_f32: &[f32] = unsafe { std::mem::transmute(b_data.as_slice()) };
1488                    if let Some(c_f32) = cuda_accel::cuda_matmul(a_f32, b_f32, m, n, k1) {
1489                        let c_t: Vec<T> = unsafe {
1490                            let mut v = std::mem::ManuallyDrop::new(c_f32);
1491                            Vec::from_raw_parts(v.as_mut_ptr() as *mut T, v.len(), v.capacity())
1492                        };
1493                        return Self::from_vec(c_t, &[m, n]);
1494                    }
1495                }
1496            }
1497
1498            let mut c_data = vec![T::zero(); m * n];
1499            CpuBackend::matmul(&mut c_data, &a_data, &b_data, m, n, k1);
1500            return Self::from_vec(c_data, &[m, n]);
1501        }
1502
1503        // For batched matmul, compute batch size
1504        let batch_dims_self: Vec<usize> = self.shape[..self.ndim() - 2].to_vec();
1505        let batch_dims_other: Vec<usize> = other.shape[..other.ndim() - 2].to_vec();
1506
1507        if batch_dims_self != batch_dims_other {
1508            return Err(Error::invalid_operation(format!(
1509                "matmul batch dimensions must match: {:?} vs {:?}",
1510                batch_dims_self, batch_dims_other
1511            )));
1512        }
1513
1514        let batch_size: usize = batch_dims_self.iter().product();
1515        let a_stride = m * k1;
1516        let b_stride = k1 * n;
1517        let c_stride = m * n;
1518
1519        let a_data = self.contiguous().to_vec();
1520        let b_data = other.contiguous().to_vec();
1521        let mut c_data = vec![T::zero(); batch_size * m * n];
1522
1523        // Try GPU acceleration for f32 batched matmul (only for large enough matrices)
1524        #[cfg(feature = "cuda")]
1525        {
1526            let flops = m * n * k1;
1527            if std::any::TypeId::of::<T>() == std::any::TypeId::of::<f32>()
1528                && flops >= 4_000_000
1529            {
1530                let a_f32: &[f32] = unsafe { std::mem::transmute(a_data.as_slice()) };
1531                let b_f32: &[f32] = unsafe { std::mem::transmute(b_data.as_slice()) };
1532                let mut gpu_ok = true;
1533                for batch in 0..batch_size {
1534                    let a_slice = &a_f32[batch * a_stride..(batch + 1) * a_stride];
1535                    let b_slice = &b_f32[batch * b_stride..(batch + 1) * b_stride];
1536                    if let Some(c_batch) = cuda_accel::cuda_matmul(a_slice, b_slice, m, n, k1) {
1537                        c_data[batch * c_stride..(batch + 1) * c_stride]
1538                            .copy_from_slice(unsafe { std::mem::transmute(c_batch.as_slice()) });
1539                    } else {
1540                        gpu_ok = false;
1541                        break;
1542                    }
1543                }
1544                if gpu_ok {
1545                    let mut output_shape = batch_dims_self;
1546                    output_shape.push(m);
1547                    output_shape.push(n);
1548                    return Self::from_vec(c_data, &output_shape);
1549                }
1550                // Fall through to CPU if GPU failed
1551                c_data = vec![T::zero(); batch_size * m * n];
1552            }
1553        }
1554
1555        // CPU fallback: loop over batches
1556        for batch in 0..batch_size {
1557            let a_slice = &a_data[batch * a_stride..(batch + 1) * a_stride];
1558            let b_slice = &b_data[batch * b_stride..(batch + 1) * b_stride];
1559            let c_slice = &mut c_data[batch * c_stride..(batch + 1) * c_stride];
1560            CpuBackend::matmul(c_slice, a_slice, b_slice, m, n, k1);
1561        }
1562
1563        // Build output shape: batch_dims + [m, n]
1564        let mut output_shape = batch_dims_self;
1565        output_shape.push(m);
1566        output_shape.push(n);
1567
1568        Self::from_vec(c_data, &output_shape)
1569    }
1570
1571    /// Dot product for 1D tensors.
1572    pub fn dot(&self, other: &Self) -> Result<Self> {
1573        if self.ndim() != 1 || other.ndim() != 1 {
1574            return Err(Error::invalid_operation("dot requires 1D tensors"));
1575        }
1576
1577        if self.shape[0] != other.shape[0] {
1578            return Err(Error::shape_mismatch(&self.shape, &other.shape));
1579        }
1580
1581        let a_data = self.to_vec();
1582        let b_data = other.to_vec();
1583        let result = CpuBackend::dot(&a_data, &b_data);
1584
1585        Ok(Self::scalar(result))
1586    }
1587}
1588
1589// =============================================================================
1590// Operator Trait Implementations
1591// =============================================================================
1592
1593impl<T: Numeric> Add for &Tensor<T> {
1594    type Output = Tensor<T>;
1595
1596    fn add(self, other: Self) -> Self::Output {
1597        self.add(other).expect("Addition failed")
1598    }
1599}
1600
1601impl<T: Numeric> Sub for &Tensor<T> {
1602    type Output = Tensor<T>;
1603
1604    fn sub(self, other: Self) -> Self::Output {
1605        self.sub(other).expect("Subtraction failed")
1606    }
1607}
1608
1609impl<T: Numeric> Mul for &Tensor<T> {
1610    type Output = Tensor<T>;
1611
1612    fn mul(self, other: Self) -> Self::Output {
1613        self.mul(other).expect("Multiplication failed")
1614    }
1615}
1616
1617impl<T: Numeric> Div for &Tensor<T> {
1618    type Output = Tensor<T>;
1619
1620    fn div(self, other: Self) -> Self::Output {
1621        self.div(other).expect("Division failed")
1622    }
1623}
1624
1625impl<T: Numeric> Neg for &Tensor<T> {
1626    type Output = Tensor<T>;
1627
1628    fn neg(self) -> Self::Output {
1629        self.neg()
1630    }
1631}
1632
1633// Scalar operations
1634impl<T: Numeric> Add<T> for &Tensor<T> {
1635    type Output = Tensor<T>;
1636
1637    fn add(self, scalar: T) -> Self::Output {
1638        self.add_scalar(scalar)
1639    }
1640}
1641
1642impl<T: Numeric> Mul<T> for &Tensor<T> {
1643    type Output = Tensor<T>;
1644
1645    fn mul(self, scalar: T) -> Self::Output {
1646        self.mul_scalar(scalar)
1647    }
1648}
1649
1650// =============================================================================
1651// Display Implementation
1652// =============================================================================
1653
1654impl<T: Scalar + fmt::Display> fmt::Debug for Tensor<T> {
1655    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
1656        write!(
1657            f,
1658            "Tensor(shape={:?}, device={}",
1659            self.shape(),
1660            self.device()
1661        )?;
1662        if self.numel() <= 10 {
1663            write!(f, ", data={:?}", self.to_vec())?;
1664        }
1665        write!(f, ")")
1666    }
1667}
1668
1669impl<T: Scalar + fmt::Display> fmt::Display for Tensor<T> {
1670    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
1671        if self.is_scalar() {
1672            write!(f, "{}", self.item().unwrap())
1673        } else if self.ndim() == 1 {
1674            write!(f, "[")?;
1675            let data = self.to_vec();
1676            for (i, val) in data.iter().enumerate() {
1677                if i > 0 {
1678                    write!(f, ", ")?;
1679                }
1680                write!(f, "{val}")?;
1681            }
1682            write!(f, "]")
1683        } else {
1684            write!(f, "Tensor(shape={:?})", self.shape())
1685        }
1686    }
1687}
1688
1689// =============================================================================
1690// Tests
1691// =============================================================================
1692
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_from_vec() {
        // A flat buffer plus an explicit shape yields a 2x3 tensor.
        let tensor =
            Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3]).unwrap();
        assert_eq!(tensor.shape(), &[2, 3]);
        assert_eq!(tensor.numel(), 6);
    }

    #[test]
    fn test_get_set() {
        let tensor = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[2, 2]).unwrap();
        // Row-major layout: [0,0] [0,1] / [1,0] [1,1].
        assert_eq!(tensor.get(&[0, 0]).unwrap(), 1.0);
        assert_eq!(tensor.get(&[0, 1]).unwrap(), 2.0);
        assert_eq!(tensor.get(&[1, 0]).unwrap(), 3.0);
        assert_eq!(tensor.get(&[1, 1]).unwrap(), 4.0);

        // Writes through `set` are visible to subsequent `get`s.
        tensor.set(&[0, 0], 99.0).unwrap();
        assert_eq!(tensor.get(&[0, 0]).unwrap(), 99.0);
    }

    #[test]
    fn test_reshape() {
        let tensor =
            Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3]).unwrap();
        let reshaped = tensor.reshape(&[3, 2]).unwrap();
        assert_eq!(reshaped.shape(), &[3, 2]);

        // -1 infers the remaining dimension (flatten to length 6).
        let flat = tensor.reshape(&[-1]).unwrap();
        assert_eq!(flat.shape(), &[6]);
    }

    #[test]
    fn test_transpose() {
        let tensor =
            Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3]).unwrap();
        let transposed = tensor.t().unwrap();
        // Transposing a 2x3 gives a 3x2 with rows and columns swapped.
        assert_eq!(transposed.shape(), &[3, 2]);
        assert_eq!(transposed.get(&[0, 0]).unwrap(), 1.0);
        assert_eq!(transposed.get(&[0, 1]).unwrap(), 4.0);
        assert_eq!(transposed.get(&[1, 0]).unwrap(), 2.0);
    }

    #[test]
    fn test_arithmetic() {
        let lhs = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap();
        let rhs = Tensor::<f32>::from_vec(vec![4.0, 5.0, 6.0], &[3]).unwrap();

        // Element-wise addition through the `+` operator.
        let sum = &lhs + &rhs;
        assert_eq!(sum.to_vec(), vec![5.0, 7.0, 9.0]);

        // Element-wise multiplication through the `*` operator.
        let product = &lhs * &rhs;
        assert_eq!(product.to_vec(), vec![4.0, 10.0, 18.0]);
    }

    #[test]
    fn test_broadcasting() {
        // A length-1 tensor broadcasts against a length-3 tensor.
        let row = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap();
        let one = Tensor::<f32>::from_vec(vec![10.0], &[1]).unwrap();

        let result = &row + &one;
        assert_eq!(result.to_vec(), vec![11.0, 12.0, 13.0]);
    }

    #[test]
    fn test_sum() {
        let tensor = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[4]).unwrap();
        // Full reduction returns a scalar tensor holding the total.
        let total = tensor.sum();
        assert_eq!(total.item().unwrap(), 10.0);
    }

    #[test]
    fn test_matmul() {
        // [[1,2],[3,4]] @ [[5,6],[7,8]] = [[19,22],[43,50]]
        let lhs = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[2, 2]).unwrap();
        let rhs = Tensor::<f32>::from_vec(vec![5.0, 6.0, 7.0, 8.0], &[2, 2]).unwrap();
        let out = lhs.matmul(&rhs).unwrap();

        assert_eq!(out.shape(), &[2, 2]);
        assert_eq!(out.to_vec(), vec![19.0, 22.0, 43.0, 50.0]);
    }

    #[test]
    fn test_relu() {
        // ReLU zeroes negatives and passes non-negatives through unchanged.
        let input = Tensor::<f32>::from_vec(vec![-1.0, 0.0, 1.0, 2.0], &[4]).unwrap();
        let activated = input.relu();
        assert_eq!(activated.to_vec(), vec![0.0, 0.0, 1.0, 2.0]);
    }

    #[test]
    fn test_scalar() {
        let value = Tensor::<f32>::scalar(42.0);
        assert!(value.is_scalar());
        assert_eq!(value.numel(), 1);
        assert_eq!(value.item().unwrap(), 42.0);
    }
}