Skip to main content

axonml_tensor/
tensor.rs

1//! Core `Tensor<T>` struct — 2170 lines, 76 public methods.
2//!
3//! Constructors (from_vec, from_slice, scalar, zeros, ones, randn, rand, full),
4//! properties (shape, numel, ndim, device, is_contiguous, strides, offset),
5//! shape ops (reshape, transpose, t, squeeze, unsqueeze, expand, permute,
6//! contiguous, narrow, select, index_select, cat, chunk, split, flip, roll),
7//! arithmetic (add, sub, mul, div, neg, abs, pow, add_scalar, mul_scalar,
8//! where_cond, clamp, clamp_min), reductions (sum, mean, max, min, prod,
9//! argmax, argmin, sum_dim, mean_dim, var_dim), activations (relu, sigmoid,
10//! tanh, softmax, log_softmax, gelu, silu, elu, leaky_relu), matmul (CPU
11//! via CpuBackend::matmul with GEMV fast path, GPU via cuBLAS + quantized
12//! Q4_K/Q6_K dispatch), indexing (gather, scatter, nonzero, unique, sort,
13//! argsort, topk), comparison (eq, gt, lt), cast (to, to_device), data
14//! access (to_vec, get, item), map/zip_map/zip_map3, and Display.
15//!
16//! # File
17//! `crates/axonml-tensor/src/tensor.rs`
18//!
19//! # Author
20//! Andrew Jewell Sr. — AutomataNexus LLC
21//! ORCID: 0009-0005-2158-7060
22//!
23//! # Updated
24//! April 14, 2026 11:15 PM EST
25//!
26//! # Disclaimer
27//! Use at own risk. This software is provided "as is", without warranty of any
28//! kind, express or implied. The author and AutomataNexus shall not be held
29//! liable for any damages arising from the use of this software.
30
31use core::fmt;
32use core::ops::{Add, Div, Mul, Neg, Sub};
33
34use axonml_core::Device;
35use axonml_core::backends::CpuBackend;
36#[cfg(feature = "cuda")]
37use axonml_core::backends::CudaBackend;
38use axonml_core::dtype::{Float, Numeric, Scalar};
39use axonml_core::error::{Error, Result};
40use axonml_core::storage::Storage;
41use num_traits::NumCast;
42
43// =============================================================================
44// CUDA Acceleration
45// =============================================================================
46
#[cfg(feature = "cuda")]
mod cuda_accel {
    use super::*;
    use axonml_core::backends::cuda::get_cuda_backend;

    /// Returns the process-wide CUDA backend by delegating to
    /// `axonml_core::backends::cuda::get_cuda_backend`.
    ///
    /// # Returns
    /// `None` when that function reports no backend.
    pub fn get_cuda() -> Option<&'static CudaBackend> {
        get_cuda_backend()
    }

    /// GPU-accelerated matmul: uploads both operands, runs a single GEMM and
    /// downloads the product.
    ///
    /// # Arguments
    /// * `a` - Row-major left operand holding at least `m * k` elements
    /// * `b` - Row-major right operand holding at least `k * n` elements
    /// * `m`, `n`, `k` - Matrix dimensions: `C(m,n) = A(m,k) @ B(k,n)`
    ///
    /// # Returns
    /// The row-major product, or `None` when the dimensions overflow `usize`,
    /// an operand is shorter than its dimensions require, the GPU is
    /// unavailable, or any backend call fails.
    pub fn cuda_matmul(a: &[f32], b: &[f32], m: usize, n: usize, k: usize) -> Option<Vec<f32>> {
        // Validate sizes up front: GEMM trusts the dimensions it is given, so
        // an undersized operand would make it read past the device buffer.
        let a_required = m.checked_mul(k)?;
        let b_required = k.checked_mul(n)?;
        let c_len = m.checked_mul(n)?;
        if a.len() < a_required || b.len() < b_required {
            return None;
        }

        let cuda = get_cuda()?;

        let a_gpu = cuda.htod_copy(a).ok()?;
        let b_gpu = cuda.htod_copy(b).ok()?;
        let mut c_gpu = cuda.alloc::<f32>(c_len).ok()?;

        // cuBLAS GEMM: C(m,n) = A(m,k) @ B(k,n) in row-major
        // In column-major terms: C^T(n,m) = B^T(n,k) @ A^T(k,m)
        cuda.gemm_f32(
            false, false, n, m, k, 1.0, &b_gpu, n, &a_gpu, k, 0.0, &mut c_gpu, n,
        )
        .ok()?;

        cuda.dtoh_copy(&c_gpu).ok()
    }
}
76
77use crate::shape::{
78    Shape, Strides, broadcast_shape, broadcast_strides, contiguous_strides, is_contiguous,
79    linear_index, normalize_dim, numel, reshape, squeeze, transpose_shape, transpose_strides,
80    unsqueeze,
81};
82
83// =============================================================================
84// GPU Dispatch Helpers
85// =============================================================================
86//
87// These enable calling Tensor<f32> GPU methods from generic Tensor<T> code
88// when T is verified to be f32 via TypeId check at runtime.
89
/// Reinterprets a borrowed `Tensor<T>` as a borrowed `Tensor<f32>`.
///
/// # Panics
/// Panics when `T` is not `f32`.
///
/// # Safety
/// The cast is only meaningful when `T` is exactly `f32`; this is checked at
/// runtime before the pointer is reinterpreted.
#[cfg(feature = "cuda")]
unsafe fn gpu_ref<T: Scalar>(t: &Tensor<T>) -> &Tensor<f32> {
    assert!(
        is_f32::<T>(),
        "gpu_ref: only Tensor<f32> can be used for GPU operations, got {:?}",
        T::DTYPE
    );
    let reinterpreted = (t as *const Tensor<T>).cast::<Tensor<f32>>();
    // SAFETY: T is f32 (asserted above), so source and target are the same
    // type and the pointer stays valid for the lifetime of `t`.
    unsafe { &*reinterpreted }
}
100
/// Converts an owned `Tensor<f32>` into an owned `Tensor<T>` without copying.
///
/// # Panics
/// Panics when `T` is not `f32`.
///
/// # Safety
/// The conversion is only meaningful when `T` is exactly `f32`; this is
/// checked at runtime before ownership is transferred.
#[cfg(feature = "cuda")]
unsafe fn gpu_into<T: Scalar>(t: Tensor<f32>) -> Tensor<T> {
    assert!(
        is_f32::<T>(),
        "gpu_into: only Tensor<f32> can be produced from GPU operations, got {:?}",
        T::DTYPE
    );
    // Suppress the destructor of the source so the value is owned exactly once
    // after the bitwise read below.
    let source = std::mem::ManuallyDrop::new(t);
    let source_ptr: *const Tensor<f32> = &*source;
    // SAFETY: T is f32 (asserted above), so the read yields the same type, and
    // the original is never dropped.
    unsafe { std::ptr::read(source_ptr.cast::<Tensor<T>>()) }
}
115
116#[cfg(feature = "cuda")]
/// Reports whether the type parameter `T` is exactly `f32`, decided by
/// comparing `TypeId`s at runtime.
fn is_f32<T: 'static>() -> bool {
    use std::any::TypeId;

    let candidate = TypeId::of::<T>();
    candidate == TypeId::of::<f32>()
}
120
121// =============================================================================
122// Tensor Struct
123// =============================================================================
124
125/// An N-dimensional array of numeric values.
126///
127/// Tensors are the core data structure for all computations in Axonml.
128/// They support arbitrary dimensions, automatic broadcasting, and efficient
129/// memory sharing between views.
130#[derive(Clone)]
131pub struct Tensor<T: Scalar> {
132    /// Underlying data storage (reference-counted).
133    pub(crate) storage: Storage<T>,
134    /// Shape of the tensor (dimensions).
135    pub(crate) shape: Shape,
136    /// Strides for each dimension.
137    pub(crate) strides: Strides,
138    /// Offset into storage (for views).
139    pub(crate) offset: usize,
140}
141
142impl<T: Scalar> Tensor<T> {
143    // =========================================================================
144    // Constructors
145    // =========================================================================
146
147    /// Creates a new tensor from storage with the given shape.
148    ///
149    /// # Arguments
150    /// * `storage` - The underlying data storage
151    /// * `shape` - Shape of the tensor
152    ///
153    /// # Returns
154    /// New tensor, or error if shape doesn't match storage size.
155    pub fn from_storage(storage: Storage<T>, shape: &[usize]) -> Result<Self> {
156        let total = numel(shape);
157        if total != storage.len() {
158            return Err(Error::shape_mismatch(&[storage.len()], shape));
159        }
160
161        let shape = Shape::from_slice(shape);
162        let strides = contiguous_strides(&shape);
163
164        Ok(Self {
165            storage,
166            shape,
167            strides,
168            offset: 0,
169        })
170    }
171
172    /// Creates a new tensor from a vector with the given shape.
173    ///
174    /// # Arguments
175    /// * `data` - Vector of data
176    /// * `shape` - Shape of the tensor
177    ///
178    /// # Returns
179    /// New tensor, or error if shape doesn't match data length.
180    pub fn from_vec(data: Vec<T>, shape: &[usize]) -> Result<Self> {
181        let storage = Storage::from_vec(data, Device::Cpu);
182        Self::from_storage(storage, shape)
183    }
184
185    /// Creates a new tensor from a slice with the given shape.
186    ///
187    /// # Arguments
188    /// * `data` - Slice of data to copy
189    /// * `shape` - Shape of the tensor
190    ///
191    /// # Returns
192    /// New tensor, or error if shape doesn't match data length.
193    pub fn from_slice(data: &[T], shape: &[usize]) -> Result<Self> {
194        let storage = Storage::from_slice(data, Device::Cpu);
195        Self::from_storage(storage, shape)
196    }
197
198    /// Creates a scalar tensor (0-dimensional).
199    ///
200    /// # Arguments
201    /// * `value` - The scalar value
202    ///
203    /// # Returns
204    /// New 0-dimensional tensor.
205    pub fn scalar(value: T) -> Self {
206        Self {
207            storage: Storage::from_vec(vec![value], Device::Cpu),
208            shape: Shape::new(),
209            strides: Strides::new(),
210            offset: 0,
211        }
212    }
213
214    /// Creates a tensor filled with zeros.
215    #[must_use]
216    pub fn zeros(shape: &[usize]) -> Self {
217        crate::creation::zeros(shape)
218    }
219
220    /// Creates a tensor filled with ones.
221    #[must_use]
222    pub fn ones(shape: &[usize]) -> Self
223    where
224        T: Numeric,
225    {
226        crate::creation::ones(shape)
227    }
228
229    /// Creates a tensor filled with a constant value.
230    #[must_use]
231    pub fn full(shape: &[usize], value: T) -> Self {
232        crate::creation::full(shape, value)
233    }
234
235    /// Creates a tensor with random values from standard normal distribution.
236    #[must_use]
237    pub fn randn(shape: &[usize]) -> Self
238    where
239        T: Float,
240        rand_distr::StandardNormal: rand::distributions::Distribution<T>,
241    {
242        crate::creation::randn(shape)
243    }
244
245    /// Creates a tensor with random values from uniform distribution [0, 1).
246    #[must_use]
247    pub fn rand(shape: &[usize]) -> Self
248    where
249        T: Float,
250        rand::distributions::Standard: rand::distributions::Distribution<T>,
251    {
252        crate::creation::rand(shape)
253    }
254
255    // =========================================================================
256    // Properties
257    // =========================================================================
258
259    /// Returns the shape of the tensor.
260    #[must_use]
261    pub fn shape(&self) -> &[usize] {
262        &self.shape
263    }
264
265    /// Returns the strides of the tensor.
266    #[must_use]
267    pub fn strides(&self) -> &[isize] {
268        &self.strides
269    }
270
271    /// Returns the number of dimensions.
272    #[must_use]
273    pub fn ndim(&self) -> usize {
274        self.shape.len()
275    }
276
277    /// Returns the total number of elements.
278    #[must_use]
279    pub fn numel(&self) -> usize {
280        numel(&self.shape)
281    }
282
283    /// Returns true if the tensor is empty (has zero elements).
284    #[must_use]
285    pub fn is_empty(&self) -> bool {
286        self.numel() == 0
287    }
288
289    /// Returns the size of a specific dimension.
290    ///
291    /// # Arguments
292    /// * `dim` - Dimension index (supports negative indexing)
293    pub fn size(&self, dim: i64) -> Result<usize> {
294        let idx = normalize_dim(dim, self.ndim())?;
295        Ok(self.shape[idx])
296    }
297
298    /// Returns the device this tensor is on.
299    #[must_use]
300    pub fn device(&self) -> Device {
301        self.storage.device()
302    }
303
304    /// Returns true if the tensor is contiguous in memory.
305    #[must_use]
306    pub fn is_contiguous(&self) -> bool {
307        is_contiguous(&self.shape, &self.strides)
308    }
309
310    /// Returns true if this tensor is a scalar (0-dimensional).
311    #[must_use]
312    pub fn is_scalar(&self) -> bool {
313        self.shape.is_empty()
314    }
315
316    // =========================================================================
317    // Data Access
318    // =========================================================================
319
320    /// Returns the element at the given indices.
321    ///
322    /// # Arguments
323    /// * `indices` - Multi-dimensional indices
324    pub fn get(&self, indices: &[usize]) -> Result<T> {
325        if indices.len() != self.ndim() {
326            return Err(Error::invalid_operation(format!(
327                "Expected {} indices, got {}",
328                self.ndim(),
329                indices.len()
330            )));
331        }
332
333        for (&idx, &dim) in indices.iter().zip(self.shape.iter()) {
334            if idx >= dim {
335                return Err(Error::IndexOutOfBounds {
336                    index: idx,
337                    size: dim,
338                });
339            }
340        }
341
342        let offset = self.offset + linear_index(indices, &self.strides);
343        Ok(self.storage.as_slice()[offset])
344    }
345
346    /// Sets the element at the given indices.
347    ///
348    /// # Arguments
349    /// * `indices` - Multi-dimensional indices
350    /// * `value` - Value to set
351    pub fn set(&self, indices: &[usize], value: T) -> Result<()> {
352        if indices.len() != self.ndim() {
353            return Err(Error::invalid_operation(format!(
354                "Expected {} indices, got {}",
355                self.ndim(),
356                indices.len()
357            )));
358        }
359
360        for (&idx, &dim) in indices.iter().zip(self.shape.iter()) {
361            if idx >= dim {
362                return Err(Error::IndexOutOfBounds {
363                    index: idx,
364                    size: dim,
365                });
366            }
367        }
368
369        let offset = self.offset + linear_index(indices, &self.strides);
370        self.storage.as_slice_mut()[offset] = value;
371        Ok(())
372    }
373
374    /// Returns the scalar value for a 0-dimensional tensor.
375    pub fn item(&self) -> Result<T> {
376        if self.numel() != 1 {
377            return Err(Error::invalid_operation(
378                "item() only works on single-element tensors",
379            ));
380        }
381
382        // Use to_vec() which handles both CPU and GPU tensors safely
383        let data = self.to_vec();
384        if data.is_empty() {
385            Err(Error::invalid_operation("item() on empty tensor"))
386        } else {
387            Ok(data[0])
388        }
389    }
390
391    /// Returns the data as a contiguous vector.
392    ///
393    /// If the tensor is already contiguous, this returns a reference.
394    /// Otherwise, it copies the data into a new contiguous vector.
395    /// For GPU tensors (f32 only), performs a D2H copy.
396    #[must_use]
397    pub fn to_vec(&self) -> Vec<T> {
398        // GPU path: GPU storage is always f32
399        #[cfg(feature = "cuda")]
400        if self.storage.is_gpu() {
401            assert!(is_f32::<T>(), "GPU tensors are only supported for f32");
402            let self_f32 = unsafe { gpu_ref(self) };
403            let f32_vec = self_f32.to_vec_gpu();
404            unsafe {
405                let mut v = std::mem::ManuallyDrop::new(f32_vec);
406                return Vec::from_raw_parts(v.as_mut_ptr() as *mut T, v.len(), v.capacity());
407            }
408        }
409
410        if self.is_contiguous() {
411            let storage = self.storage.as_slice();
412            storage[self.offset..self.offset + self.numel()].to_vec()
413        } else {
414            let mut result = Vec::with_capacity(self.numel());
415            self.copy_data_to(&mut result);
416            result
417        }
418    }
419
420    /// Copies data to a slice, handling non-contiguous layouts.
421    fn copy_data_to(&self, dst: &mut Vec<T>) {
422        dst.clear();
423        let storage = self.storage.as_slice();
424
425        // Iterate through all indices
426        let total = self.numel();
427        for i in 0..total {
428            let indices = crate::shape::unravel_index(i, &self.shape);
429            let offset = self.offset + linear_index(&indices, &self.strides);
430            dst.push(storage[offset]);
431        }
432    }
433
434    // =========================================================================
435    // Shape Operations
436    // =========================================================================
437
438    /// Returns a new tensor with the specified shape.
439    ///
440    /// The total number of elements must remain the same.
441    /// Supports -1 in one dimension to infer the size.
442    ///
443    /// # Arguments
444    /// * `new_shape` - Target shape
445    pub fn reshape(&self, new_shape: &[isize]) -> Result<Self> {
446        let shape = reshape(&self.shape, new_shape)?;
447
448        if self.is_contiguous() {
449            // Can just change shape without copying
450            Ok(Self {
451                storage: self.storage.clone(),
452                strides: contiguous_strides(&shape),
453                shape,
454                offset: self.offset,
455            })
456        } else {
457            // Need to make contiguous first
458            let contig = self.contiguous();
459            Ok(Self {
460                storage: contig.storage,
461                strides: contiguous_strides(&shape),
462                shape,
463                offset: 0,
464            })
465        }
466    }
467
468    /// Returns a new tensor with a flattened shape.
469    #[must_use]
470    pub fn flatten(&self) -> Self {
471        self.reshape(&[-1]).expect("Flatten should never fail")
472    }
473
474    /// Returns a new tensor with dimensions of size 1 removed.
475    ///
476    /// # Arguments
477    /// * `dim` - Optional specific dimension to squeeze
478    pub fn squeeze(&self, dim: Option<i64>) -> Result<Self> {
479        let dim = match dim {
480            Some(d) => Some(normalize_dim(d, self.ndim())?),
481            None => None,
482        };
483
484        let new_shape = squeeze(&self.shape, dim);
485        let new_strides: Strides = match dim {
486            Some(d) => {
487                let mut s = self.strides.clone();
488                if d < self.shape.len() && self.shape[d] == 1 {
489                    s.remove(d);
490                }
491                s
492            }
493            None => self
494                .shape
495                .iter()
496                .zip(self.strides.iter())
497                .filter(|(dim, _)| **dim != 1)
498                .map(|(_, stride)| *stride)
499                .collect(),
500        };
501
502        Ok(Self {
503            storage: self.storage.clone(),
504            shape: new_shape,
505            strides: new_strides,
506            offset: self.offset,
507        })
508    }
509
510    /// Returns a new tensor with a dimension of size 1 inserted.
511    ///
512    /// # Arguments
513    /// * `dim` - Position to insert the new dimension
514    pub fn unsqueeze(&self, dim: i64) -> Result<Self> {
515        let normalized = if dim < 0 {
516            (dim + self.ndim() as i64 + 1) as usize
517        } else {
518            dim as usize
519        };
520
521        let new_shape = unsqueeze(&self.shape, normalized)?;
522        let mut new_strides = Strides::with_capacity(new_shape.len());
523
524        for (i, _) in new_shape.iter().enumerate() {
525            if i < normalized {
526                new_strides.push(self.strides.get(i).copied().unwrap_or(1));
527            } else if i == normalized {
528                // Stride for new dimension (doesn't matter since size is 1)
529                new_strides.push(1);
530            } else {
531                new_strides.push(self.strides[i - 1]);
532            }
533        }
534
535        Ok(Self {
536            storage: self.storage.clone(),
537            shape: new_shape,
538            strides: new_strides,
539            offset: self.offset,
540        })
541    }
542
543    /// Transposes two dimensions.
544    ///
545    /// # Arguments
546    /// * `dim0` - First dimension
547    /// * `dim1` - Second dimension
548    pub fn transpose(&self, dim0: i64, dim1: i64) -> Result<Self> {
549        let d0 = normalize_dim(dim0, self.ndim())?;
550        let d1 = normalize_dim(dim1, self.ndim())?;
551
552        let new_shape = transpose_shape(&self.shape, d0, d1)?;
553        let new_strides = transpose_strides(&self.strides, d0, d1);
554
555        Ok(Self {
556            storage: self.storage.clone(),
557            shape: new_shape,
558            strides: new_strides,
559            offset: self.offset,
560        })
561    }
562
563    /// Returns the transpose of a 2D tensor.
564    pub fn t(&self) -> Result<Self> {
565        if self.ndim() != 2 {
566            return Err(Error::invalid_operation("t() only works on 2D tensors"));
567        }
568        self.transpose(0, 1)
569    }
570
571    /// Returns a permuted tensor with dimensions reordered.
572    ///
573    /// # Arguments
574    /// * `dims` - New order of dimensions
575    pub fn permute(&self, dims: &[usize]) -> Result<Self> {
576        if dims.len() != self.ndim() {
577            return Err(Error::invalid_operation(format!(
578                "Expected {} dimensions, got {}",
579                self.ndim(),
580                dims.len()
581            )));
582        }
583
584        // Check that dims is a permutation
585        let mut seen = vec![false; self.ndim()];
586        for &d in dims {
587            if d >= self.ndim() {
588                return Err(Error::InvalidDimension {
589                    index: d as i64,
590                    ndim: self.ndim(),
591                });
592            }
593            if seen[d] {
594                return Err(Error::invalid_operation("Duplicate dimension in permute"));
595            }
596            seen[d] = true;
597        }
598
599        let new_shape: Shape = dims.iter().map(|&d| self.shape[d]).collect();
600        let new_strides: Strides = dims.iter().map(|&d| self.strides[d]).collect();
601
602        Ok(Self {
603            storage: self.storage.clone(),
604            shape: new_shape,
605            strides: new_strides,
606            offset: self.offset,
607        })
608    }
609
610    /// Returns a contiguous copy of the tensor.
611    #[must_use]
612    pub fn contiguous(&self) -> Self {
613        if self.is_contiguous() && self.offset == 0 {
614            return self.clone();
615        }
616
617        #[cfg(feature = "cuda")]
618        if self.storage.is_gpu() {
619            assert!(is_f32::<T>(), "GPU tensors are only supported for f32");
620            let self_f32 = unsafe { gpu_ref(self) };
621            let result = self_f32.contiguous_gpu();
622            return unsafe { gpu_into(result) };
623        }
624
625        let data = self.to_vec();
626        Self::from_vec(data, &self.shape).expect("Contiguous should never fail")
627    }
628
629    // =========================================================================
    // Functional Map Operations (inputs are materialised with to_vec())
631    // =========================================================================
632
633    /// Apply a function element-wise, producing a new tensor with the same shape.
634    ///
635    /// Avoids the to_vec() → map → from_vec() pattern by operating directly
636    /// on contiguous storage.
637    #[must_use]
638    pub fn map<F: Fn(T) -> T>(&self, f: F) -> Self {
639        let data = self.to_vec(); // contiguous read
640        let result: Vec<T> = data.into_iter().map(f).collect();
641        Self::from_vec(result, &self.shape).unwrap()
642    }
643
644    /// Apply a binary function element-wise with another tensor of the same shape.
645    ///
646    /// This is the primary zero-allocation pattern for backward functions:
647    /// instead of `a.to_vec()` + `b.to_vec()` + zip + `from_vec()`,
648    /// use `a.zip_map(&b, |x, y| ...)` which does a single allocation.
649    #[must_use]
650    pub fn zip_map<F: Fn(T, T) -> T>(&self, other: &Self, f: F) -> Self {
651        let a = self.to_vec();
652        let b = other.to_vec();
653        debug_assert_eq!(
654            a.len(),
655            b.len(),
656            "zip_map requires same number of elements: {} vs {}",
657            a.len(),
658            b.len()
659        );
660        let result: Vec<T> = a.into_iter().zip(b).map(|(x, y)| f(x, y)).collect();
661        Self::from_vec(result, &self.shape).unwrap()
662    }
663
664    /// Apply a ternary function element-wise with two other tensors.
665    #[must_use]
666    pub fn zip_map3<F: Fn(T, T, T) -> T>(&self, b: &Self, c: &Self, f: F) -> Self {
667        let a_data = self.to_vec();
668        let b_data = b.to_vec();
669        let c_data = c.to_vec();
670        debug_assert_eq!(a_data.len(), b_data.len());
671        debug_assert_eq!(a_data.len(), c_data.len());
672        let result: Vec<T> = a_data
673            .into_iter()
674            .zip(b_data)
675            .zip(c_data)
676            .map(|((a, b), c)| f(a, b, c))
677            .collect();
678        Self::from_vec(result, &self.shape).unwrap()
679    }
680
681    // =========================================================================
682    // Device Operations
683    // =========================================================================
684
685    /// Transfers the tensor to a different device.
686    ///
687    /// # Arguments
688    /// * `device` - Target device
689    pub fn to_device(&self, device: Device) -> Result<Self> {
690        if self.device() == device {
691            return Ok(self.clone());
692        }
693
694        #[cfg(feature = "cuda")]
695        if self.storage.is_gpu() || device.is_gpu() {
696            assert!(is_f32::<T>(), "GPU tensors are only supported for f32");
697            let self_f32 = unsafe { gpu_ref(self) };
698            let result = self_f32.to_device_f32(device)?;
699            return Ok(unsafe { gpu_into(result) });
700        }
701
702        let contig = self.contiguous();
703        let new_storage = contig.storage.to_device(device)?;
704
705        Ok(Self {
706            storage: new_storage,
707            shape: self.shape.clone(),
708            strides: self.strides.clone(),
709            offset: 0,
710        })
711    }
712
713    /// Transfers to CPU.
714    pub fn cpu(&self) -> Result<Self> {
715        self.to_device(Device::Cpu)
716    }
717
718    // =========================================================================
719    // Deep Copy
720    // =========================================================================
721
722    /// Creates a deep copy of this tensor with its own storage.
723    #[must_use]
724    pub fn clone_deep(&self) -> Self {
725        let data = self.to_vec();
726        let cpu = Self::from_vec(data, &self.shape).expect("Deep clone should never fail");
727        #[cfg(feature = "cuda")]
728        if self.device().is_gpu() {
729            return cpu.to_device(self.device()).unwrap();
730        }
731        cpu
732    }
733}
734
735// =============================================================================
736// Numeric Operations
737// =============================================================================
738
739impl<T: Numeric> Tensor<T> {
740    /// Fills the tensor with a value.
741    ///
742    /// # Panics
743    /// Panics on GPU tensors. Use `Tensor::from_vec(vec![value; n], shape)`
744    /// followed by `.to_device()` instead.
745    pub fn fill_(&self, value: T) {
746        assert!(
747            self.storage.is_cpu(),
748            "fill_() not supported on GPU tensors — create a new tensor and transfer instead"
749        );
750        let mut data = self.storage.as_slice_mut();
751        CpuBackend::fill(&mut data, value);
752    }
753
754    /// Fills the tensor with zeros.
755    pub fn zero_(&self) {
756        self.fill_(T::zero());
757    }
758
759    // =========================================================================
760    // Reduction Operations
761    // =========================================================================
762
763    /// Returns the sum of all elements as a scalar tensor.
764    ///
765    /// On GPU, uses native CUDA reduction kernels (no CPU round-trip).
766    #[must_use]
767    pub fn sum(&self) -> Self {
768        #[cfg(feature = "cuda")]
769        if self.device().is_gpu() {
770            assert!(is_f32::<T>(), "GPU tensors are only supported for f32");
771            let self_f32 = unsafe { gpu_ref(self) };
772            let mut t = self_f32.clone();
773            while t.ndim() > 1 {
774                t = t.sum_dim_cuda(0);
775            }
776            if t.numel() > 1 {
777                t = t.sum_dim_cuda(0);
778            }
779            return unsafe { gpu_into(t) };
780        }
781
782        let data = self.to_vec();
783        let result = CpuBackend::sum(&data);
784        Self::scalar(result)
785    }
786
787    /// Returns the product of all elements.
788    ///
789    /// GPU: D2H round-trip (no CUDA prod reduction kernel yet).
790    #[must_use]
791    pub fn prod(&self) -> Self {
792        let data = self.to_vec();
793        let result = CpuBackend::prod(&data);
794        let s = Self::scalar(result);
795        #[cfg(feature = "cuda")]
796        if self.device().is_gpu() {
797            return s
798                .to_device(self.device())
799                .expect("prod: device transfer failed");
800        }
801        s
802    }
803
804    /// Returns the maximum element.
805    ///
806    /// GPU: D2H round-trip (no CUDA max reduction kernel yet).
807    pub fn max(&self) -> Result<Self> {
808        if self.is_empty() {
809            return Err(Error::EmptyTensor);
810        }
811        let data = self.to_vec();
812        let result = CpuBackend::max(&data).expect("max on non-empty tensor");
813        let s = Self::scalar(result);
814        #[cfg(feature = "cuda")]
815        if self.device().is_gpu() {
816            return Ok(s
817                .to_device(self.device())
818                .expect("max: device transfer failed"));
819        }
820        Ok(s)
821    }
822
823    /// Returns the minimum element.
824    ///
825    /// GPU: D2H round-trip (no CUDA min reduction kernel yet).
826    pub fn min(&self) -> Result<Self> {
827        if self.is_empty() {
828            return Err(Error::EmptyTensor);
829        }
830        let data = self.to_vec();
831        let result = CpuBackend::min(&data).expect("min on non-empty tensor");
832        let s = Self::scalar(result);
833        #[cfg(feature = "cuda")]
834        if self.device().is_gpu() {
835            return Ok(s
836                .to_device(self.device())
837                .expect("min: device transfer failed"));
838        }
839        Ok(s)
840    }
841
842    /// Returns the index of the maximum element.
843    pub fn argmax(&self) -> Result<usize> {
844        if self.is_empty() {
845            return Err(Error::EmptyTensor);
846        }
847        let data = self.to_vec();
848        Ok(CpuBackend::argmax(&data).unwrap())
849    }
850
851    /// Returns the index of the minimum element.
852    pub fn argmin(&self) -> Result<usize> {
853        if self.is_empty() {
854            return Err(Error::EmptyTensor);
855        }
856        let data = self.to_vec();
857        Ok(CpuBackend::argmin(&data).unwrap())
858    }
859
860    /// Concatenates tensors along a dimension.
861    ///
862    /// All tensors must have the same shape except along the cat dimension.
863    pub fn cat(tensors: &[&Self], dim: usize) -> Result<Self> {
864        if tensors.is_empty() {
865            return Err(Error::invalid_operation("cat requires at least one tensor"));
866        }
867        let ndim = tensors[0].ndim();
868        if dim >= ndim {
869            return Err(Error::invalid_operation("cat dimension out of range"));
870        }
871
872        for t in &tensors[1..] {
873            if t.ndim() != ndim {
874                return Err(Error::invalid_operation(
875                    "cat: all tensors must have same ndim",
876                ));
877            }
878            for d in 0..ndim {
879                if d != dim && t.shape[d] != tensors[0].shape[d] {
880                    return Err(Error::invalid_operation(
881                        "cat: shapes must match on non-cat dims",
882                    ));
883                }
884            }
885        }
886
887        let total_dim_size: usize = tensors.iter().map(|t| t.shape[dim]).sum();
888        let mut out_shape: Vec<usize> = tensors[0].shape.to_vec();
889        out_shape[dim] = total_dim_size;
890
891        let outer_size: usize = out_shape[..dim].iter().product();
892        let inner_size: usize = out_shape[dim + 1..].iter().product();
893        let total_numel: usize = out_shape.iter().product();
894        let mut result = vec![T::zero(); total_numel];
895
896        let mut dim_offset = 0;
897        for t in tensors {
898            let t_data = t.contiguous().to_vec();
899            let t_dim_size = t.shape[dim];
900            for outer in 0..outer_size {
901                for d in 0..t_dim_size {
902                    let src_base = outer * t_dim_size * inner_size + d * inner_size;
903                    let dst_base =
904                        outer * total_dim_size * inner_size + (dim_offset + d) * inner_size;
905                    result[dst_base..dst_base + inner_size]
906                        .copy_from_slice(&t_data[src_base..src_base + inner_size]);
907                }
908            }
909            dim_offset += t_dim_size;
910        }
911
912        let out = Self::from_vec(result, &out_shape)?;
913        #[cfg(feature = "cuda")]
914        if tensors[0].device().is_gpu() {
915            return Ok(out.to_device(tensors[0].device()).unwrap());
916        }
917        Ok(out)
918    }
919}
920
921// =============================================================================
922// Float Operations
923// =============================================================================
924
925impl<T: Float> Tensor<T> {
926    /// Returns the mean of all elements.
927    /// Returns the mean of all elements.
928    ///
929    /// On GPU, uses native CUDA sum reduction then divides by numel.
930    pub fn mean(&self) -> Result<Self> {
931        if self.is_empty() {
932            return Err(Error::EmptyTensor);
933        }
934        #[cfg(feature = "cuda")]
935        if self.device().is_gpu() {
936            let s = self.sum(); // uses CUDA sum_dim chain
937            let n = self.numel() as f32;
938            // mul_scalar stays on GPU
939            return Ok(s.mul_scalar(T::from(1.0 / n as f64).unwrap_or(T::zero())));
940        }
941
942        let data = self.to_vec();
943        let result = CpuBackend::mean(&data).expect("mean on non-empty tensor");
944        Ok(Self::scalar(result))
945    }
946
947    // =========================================================================
948    // Activation Functions
949    // =========================================================================
950
951    /// Applies `ReLU` activation: max(0, x).
952    #[must_use]
953    pub fn relu(&self) -> Self {
954        #[cfg(feature = "cuda")]
955        if self.device().is_gpu() {
956            assert!(is_f32::<T>(), "GPU tensors are only supported for f32");
957            return unsafe { gpu_into(gpu_ref(self).relu_cuda()) };
958        }
959        let data = self.to_vec();
960        let mut result = vec![T::zero(); data.len()];
961        CpuBackend::relu(&mut result, &data);
962        Self::from_vec(result, &self.shape).unwrap()
963    }
964
965    /// Applies sigmoid activation: 1 / (1 + exp(-x)).
966    #[must_use]
967    pub fn sigmoid(&self) -> Self {
968        #[cfg(feature = "cuda")]
969        if self.device().is_gpu() {
970            assert!(is_f32::<T>(), "GPU tensors are only supported for f32");
971            return unsafe { gpu_into(gpu_ref(self).sigmoid_cuda()) };
972        }
973        let data = self.to_vec();
974        let mut result = vec![T::zero(); data.len()];
975        CpuBackend::sigmoid(&mut result, &data);
976        Self::from_vec(result, &self.shape).unwrap()
977    }
978
979    /// Applies tanh activation.
980    #[must_use]
981    pub fn tanh(&self) -> Self {
982        #[cfg(feature = "cuda")]
983        if self.device().is_gpu() {
984            assert!(is_f32::<T>(), "GPU tensors are only supported for f32");
985            return unsafe { gpu_into(gpu_ref(self).tanh_cuda()) };
986        }
987        let data = self.to_vec();
988        let mut result = vec![T::zero(); data.len()];
989        CpuBackend::tanh(&mut result, &data);
990        Self::from_vec(result, &self.shape).unwrap()
991    }
992
993    /// Applies exponential function.
994    #[must_use]
995    pub fn exp(&self) -> Self {
996        #[cfg(feature = "cuda")]
997        if self.device().is_gpu() {
998            assert!(is_f32::<T>(), "GPU tensors are only supported for f32");
999            return unsafe { gpu_into(gpu_ref(self).exp_cuda()) };
1000        }
1001        let data = self.to_vec();
1002        let mut result = vec![T::zero(); data.len()];
1003        CpuBackend::exp(&mut result, &data);
1004        Self::from_vec(result, &self.shape).unwrap()
1005    }
1006
1007    /// Applies natural logarithm.
1008    #[must_use]
1009    pub fn ln(&self) -> Self {
1010        #[cfg(feature = "cuda")]
1011        if self.device().is_gpu() {
1012            assert!(is_f32::<T>(), "GPU tensors are only supported for f32");
1013            return unsafe { gpu_into(gpu_ref(self).ln_cuda()) };
1014        }
1015        let data = self.to_vec();
1016        let mut result = vec![T::zero(); data.len()];
1017        CpuBackend::ln(&mut result, &data);
1018        Self::from_vec(result, &self.shape).unwrap()
1019    }
1020
1021    /// Applies square root.
1022    #[must_use]
1023    pub fn sqrt(&self) -> Self {
1024        #[cfg(feature = "cuda")]
1025        if self.device().is_gpu() {
1026            assert!(is_f32::<T>(), "GPU tensors are only supported for f32");
1027            return unsafe { gpu_into(gpu_ref(self).sqrt_cuda()) };
1028        }
1029        let data = self.to_vec();
1030        let mut result = vec![T::zero(); data.len()];
1031        CpuBackend::sqrt(&mut result, &data);
1032        Self::from_vec(result, &self.shape).unwrap()
1033    }
1034
1035    /// Computes element-wise power.
1036    #[must_use]
1037    pub fn pow(&self, exp: T) -> Self {
1038        #[cfg(feature = "cuda")]
1039        if self.device().is_gpu() {
1040            assert!(is_f32::<T>(), "GPU tensors are only supported for f32");
1041            let exp_f32: f32 = unsafe { *(&exp as *const T as *const f32) };
1042            return unsafe { gpu_into(gpu_ref(self).pow_cuda(exp_f32)) };
1043        }
1044        let data = self.to_vec();
1045        let result: Vec<T> = data.iter().map(|&x| x.pow_value(exp)).collect();
1046        Self::from_vec(result, &self.shape).unwrap()
1047    }
1048
1049    /// GELU activation function (Gaussian Error Linear Unit).
1050    #[must_use]
1051    pub fn gelu(&self) -> Self {
1052        #[cfg(feature = "cuda")]
1053        if self.device().is_gpu() {
1054            assert!(is_f32::<T>(), "GPU tensors are only supported for f32");
1055            return unsafe { gpu_into(gpu_ref(self).gelu_cuda()) };
1056        }
1057        crate::ops::gelu(self)
1058    }
1059
1060    /// SiLU/Swish activation function.
1061    #[must_use]
1062    pub fn silu(&self) -> Self {
1063        #[cfg(feature = "cuda")]
1064        if self.device().is_gpu() {
1065            assert!(is_f32::<T>(), "GPU tensors are only supported for f32");
1066            return unsafe { gpu_into(gpu_ref(self).silu_cuda()) };
1067        }
1068        crate::ops::silu(self)
1069    }
1070
1071    /// Softmax along specified dimension.
1072    #[must_use]
1073    pub fn softmax(&self, dim: i32) -> Self {
1074        #[cfg(feature = "cuda")]
1075        if self.device().is_gpu() {
1076            assert!(is_f32::<T>(), "GPU tensors are only supported for f32");
1077            let self_f32 = unsafe { gpu_ref(self) };
1078            return unsafe { gpu_into(self_f32.softmax_cuda(dim).expect("CUDA softmax failed")) };
1079        }
1080        crate::ops::softmax(self, dim as i64).unwrap_or_else(|_| self.clone())
1081    }
1082
1083    /// Log softmax along specified dimension.
1084    #[must_use]
1085    pub fn log_softmax(&self, dim: i32) -> Self {
1086        let softmax_result = self.softmax(dim);
1087        softmax_result.ln()
1088    }
1089
1090    /// Mean along a dimension.
1091    #[must_use]
1092    pub fn mean_dim(&self, dim: i32, keepdim: bool) -> Self {
1093        let ndim = self.ndim();
1094        let dim = if dim < 0 {
1095            (ndim as i32 + dim) as usize
1096        } else {
1097            dim as usize
1098        };
1099
1100        if dim >= ndim {
1101            return self.clone();
1102        }
1103
1104        // GPU fast path: sum_dim then divide by dim_size (all on GPU)
1105        #[cfg(feature = "cuda")]
1106        if self.device().is_gpu() {
1107            assert!(is_f32::<T>(), "GPU tensors are only supported for f32");
1108            let self_f32 = unsafe { gpu_ref(self) };
1109            let summed = if keepdim {
1110                self_f32.sum_dim_keepdim_cuda(dim)
1111            } else {
1112                self_f32.sum_dim_cuda(dim)
1113            };
1114            let dim_size = self.shape[dim];
1115            let result = summed.mul_scalar_cuda(1.0 / dim_size as f32);
1116            return unsafe { gpu_into(result) };
1117        }
1118
1119        let dim_size = self.shape[dim];
1120        let data = self.to_vec();
1121        let mut new_shape = self.shape.clone();
1122
1123        if keepdim {
1124            new_shape[dim] = 1;
1125        } else {
1126            new_shape.remove(dim);
1127        }
1128
1129        if new_shape.is_empty() {
1130            new_shape = smallvec::smallvec![1];
1131        }
1132
1133        let new_numel: usize = new_shape.iter().product();
1134        let mut result = vec![T::zero(); new_numel];
1135
1136        let outer_size: usize = self.shape[..dim].iter().product();
1137        let inner_size: usize = self.shape[dim + 1..].iter().product();
1138
1139        for outer in 0..outer_size {
1140            for inner in 0..inner_size {
1141                let mut sum = T::zero();
1142                for d in 0..dim_size {
1143                    let idx = outer * dim_size * inner_size + d * inner_size + inner;
1144                    sum = sum + data[idx];
1145                }
1146                let mean = sum / NumCast::from(dim_size).unwrap();
1147                let result_idx = outer * inner_size + inner;
1148                result[result_idx] = mean;
1149            }
1150        }
1151
1152        Self::from_vec(result, &new_shape).unwrap()
1153    }
1154
1155    /// Sum along a dimension.
1156    #[must_use]
1157    pub fn sum_dim(&self, dim: i32, keepdim: bool) -> Self {
1158        let ndim = self.ndim();
1159        let dim = if dim < 0 {
1160            (ndim as i32 + dim) as usize
1161        } else {
1162            dim as usize
1163        };
1164
1165        if dim >= ndim {
1166            return self.clone();
1167        }
1168
1169        // GPU fast path: use CUDA sum_dim kernel (no CPU copies)
1170        #[cfg(feature = "cuda")]
1171        if self.device().is_gpu() {
1172            assert!(is_f32::<T>(), "GPU tensors are only supported for f32");
1173            let self_f32 = unsafe { gpu_ref(self) };
1174            let result = if keepdim {
1175                self_f32.sum_dim_keepdim_cuda(dim)
1176            } else {
1177                self_f32.sum_dim_cuda(dim)
1178            };
1179            return unsafe { gpu_into(result) };
1180        }
1181
1182        let dim_size = self.shape[dim];
1183        let data = self.to_vec();
1184        let mut new_shape = self.shape.clone();
1185
1186        if keepdim {
1187            new_shape[dim] = 1;
1188        } else {
1189            new_shape.remove(dim);
1190        }
1191
1192        if new_shape.is_empty() {
1193            new_shape = smallvec::smallvec![1];
1194        }
1195
1196        let new_numel: usize = new_shape.iter().product();
1197        let mut result = vec![T::zero(); new_numel];
1198
1199        let outer_size: usize = self.shape[..dim].iter().product();
1200        let inner_size: usize = self.shape[dim + 1..].iter().product();
1201
1202        for outer in 0..outer_size {
1203            for inner in 0..inner_size {
1204                let mut sum = T::zero();
1205                for d in 0..dim_size {
1206                    let idx = outer * dim_size * inner_size + d * inner_size + inner;
1207                    sum = sum + data[idx];
1208                }
1209                let result_idx = outer * inner_size + inner;
1210                result[result_idx] = sum;
1211            }
1212        }
1213
1214        Self::from_vec(result, &new_shape).unwrap()
1215    }
1216
1217    /// Variance along a dimension.
1218    #[must_use]
1219    pub fn var_dim(&self, dim: i32, keepdim: bool) -> Self {
1220        // variance = E[x²] - E[x]²  (saves one full-size intermediate allocation)
1221        let mean = self.mean_dim(dim, true);
1222        let sq = self.mul(self).unwrap_or_else(|_| self.clone());
1223        let mean_sq = sq.mean_dim(dim, keepdim);
1224        let mean_keepdim = if keepdim {
1225            mean.clone()
1226        } else {
1227            self.mean_dim(dim, keepdim)
1228        };
1229        let mean_squared = mean_keepdim
1230            .mul(&mean_keepdim)
1231            .unwrap_or_else(|_| mean_keepdim.clone());
1232        mean_sq
1233            .sub(&mean_squared)
1234            .unwrap_or_else(|_| mean_sq.clone())
1235    }
1236
1237    /// Broadcasts tensor to a new shape.
1238    #[must_use]
1239    pub fn broadcast_to(&self, shape: &[usize]) -> Self {
1240        if self.shape.as_slice() == shape {
1241            return self.clone();
1242        }
1243
1244        #[cfg(feature = "cuda")]
1245        if self.device().is_gpu() {
1246            assert!(is_f32::<T>(), "GPU tensors are only supported for f32");
1247            let self_f32 = unsafe { gpu_ref(self) };
1248            return unsafe {
1249                gpu_into(
1250                    self_f32
1251                        .broadcast_to_cuda(shape)
1252                        .expect("CUDA broadcast_to failed"),
1253                )
1254            };
1255        }
1256
1257        let result_shape = broadcast_shape(&self.shape, shape).unwrap_or_else(|_| shape.into());
1258        let self_strides = broadcast_strides(&self.shape, &self.strides, &result_shape);
1259
1260        let total = numel(&result_shape);
1261        let mut result_data = vec![T::zero(); total];
1262        let self_data = self.storage.as_slice();
1263
1264        for i in 0..total {
1265            let indices = crate::shape::unravel_index(i, &result_shape);
1266            let self_idx = self.offset + linear_index(&indices, &self_strides);
1267            result_data[i] = self_data[self_idx];
1268        }
1269
1270        Self::from_vec(result_data, &result_shape).unwrap()
1271    }
1272
1273    /// Slices the tensor using ranges for each dimension.
1274    #[must_use]
1275    pub fn slice(&self, ranges: &[std::ops::Range<usize>]) -> Self {
1276        let mut new_shape = Vec::with_capacity(self.ndim());
1277        for (i, range) in ranges.iter().enumerate() {
1278            if i < self.ndim() {
1279                new_shape.push(range.end - range.start);
1280            }
1281        }
1282        // Keep remaining dimensions unchanged
1283        for i in ranges.len()..self.ndim() {
1284            new_shape.push(self.shape[i]);
1285        }
1286
1287        let new_numel: usize = new_shape.iter().product();
1288        let mut result_data = vec![T::zero(); new_numel];
1289        let self_data = self.to_vec();
1290
1291        // Copy data with proper indexing
1292        let mut result_idx = 0;
1293        Self::slice_recursive(
1294            &self_data,
1295            &self.shape,
1296            ranges,
1297            0,
1298            0,
1299            &mut result_data,
1300            &mut result_idx,
1301        );
1302
1303        let out = Self::from_vec(result_data, &new_shape).unwrap();
1304        #[cfg(feature = "cuda")]
1305        if self.device().is_gpu() {
1306            return out.to_device(self.device()).unwrap();
1307        }
1308        out
1309    }
1310
1311    fn slice_recursive(
1312        data: &[T],
1313        shape: &[usize],
1314        ranges: &[std::ops::Range<usize>],
1315        dim: usize,
1316        offset: usize,
1317        result: &mut [T],
1318        result_idx: &mut usize,
1319    ) {
1320        if dim == shape.len() {
1321            result[*result_idx] = data[offset];
1322            *result_idx += 1;
1323            return;
1324        }
1325
1326        let stride: usize = shape[dim + 1..].iter().product();
1327        let (start, end) = if dim < ranges.len() {
1328            (ranges[dim].start, ranges[dim].end)
1329        } else {
1330            (0, shape[dim])
1331        };
1332
1333        for i in start..end {
1334            Self::slice_recursive(
1335                data,
1336                shape,
1337                ranges,
1338                dim + 1,
1339                offset + i * stride,
1340                result,
1341                result_idx,
1342            );
1343        }
1344    }
1345}
1346
1347// =============================================================================
1348// Arithmetic Operator Implementations
1349// =============================================================================
1350
1351impl<T: Numeric> Tensor<T> {
1352    /// Element-wise addition with broadcasting.
1353    pub fn add(&self, other: &Self) -> Result<Self> {
1354        #[cfg(feature = "cuda")]
1355        {
1356            let self_gpu = self.device().is_gpu();
1357            let other_gpu = other.device().is_gpu();
1358            if self_gpu || other_gpu {
1359                assert!(is_f32::<T>(), "GPU tensors are only supported for f32");
1360                if self_gpu && other_gpu {
1361                    let (s, o) = unsafe { (gpu_ref(self), gpu_ref(other)) };
1362                    if self.shape == other.shape {
1363                        return Ok(unsafe { gpu_into(s.add_cuda(o)?) });
1364                    } else {
1365                        return Ok(unsafe { gpu_into(s.broadcast_add_cuda(o)?) });
1366                    }
1367                }
1368                // Mixed device — move to GPU, then operate
1369                let target_device = if self_gpu {
1370                    self.device()
1371                } else {
1372                    other.device()
1373                };
1374                let a_gpu = if self_gpu {
1375                    self.clone()
1376                } else {
1377                    self.to_device(target_device)?
1378                };
1379                let b_gpu = if other_gpu {
1380                    other.clone()
1381                } else {
1382                    other.to_device(target_device)?
1383                };
1384                return a_gpu.add(&b_gpu);
1385            }
1386        }
1387        // Fast path: same shape, both contiguous — no index arithmetic needed
1388        if self.shape == other.shape && self.is_contiguous() && other.is_contiguous() {
1389            let a = self.storage.as_slice();
1390            let b = other.storage.as_slice();
1391            let ao = self.offset;
1392            let bo = other.offset;
1393            let n = numel(&self.shape);
1394            let mut result_data = vec![T::zero(); n];
1395            for i in 0..n {
1396                result_data[i] = a[ao + i] + b[bo + i];
1397            }
1398            return Self::from_vec(result_data, &self.shape);
1399        }
1400
1401        let result_shape = broadcast_shape(&self.shape, &other.shape)?;
1402        let self_strides = broadcast_strides(&self.shape, &self.strides, &result_shape);
1403        let other_strides = broadcast_strides(&other.shape, &other.strides, &result_shape);
1404
1405        let total = numel(&result_shape);
1406        let mut result_data = vec![T::zero(); total];
1407
1408        let self_data = self.storage.as_slice();
1409        let other_data = other.storage.as_slice();
1410
1411        for i in 0..total {
1412            let indices = crate::shape::unravel_index(i, &result_shape);
1413            let self_idx = self.offset + linear_index(&indices, &self_strides);
1414            let other_idx = other.offset + linear_index(&indices, &other_strides);
1415            result_data[i] = self_data[self_idx] + other_data[other_idx];
1416        }
1417
1418        Self::from_vec(result_data, &result_shape)
1419    }
1420
1421    /// Element-wise subtraction with broadcasting.
1422    pub fn sub(&self, other: &Self) -> Result<Self> {
1423        #[cfg(feature = "cuda")]
1424        {
1425            let self_gpu = self.device().is_gpu();
1426            let other_gpu = other.device().is_gpu();
1427            if self_gpu || other_gpu {
1428                assert!(is_f32::<T>(), "GPU tensors are only supported for f32");
1429                if self_gpu && other_gpu {
1430                    let (s, o) = unsafe { (gpu_ref(self), gpu_ref(other)) };
1431                    if self.shape == other.shape {
1432                        return Ok(unsafe { gpu_into(s.sub_cuda(o)?) });
1433                    } else {
1434                        return Ok(unsafe { gpu_into(s.broadcast_sub_cuda(o)?) });
1435                    }
1436                }
1437                let target = if self_gpu {
1438                    self.device()
1439                } else {
1440                    other.device()
1441                };
1442                let a_gpu = if self_gpu {
1443                    self.clone()
1444                } else {
1445                    self.to_device(target)?
1446                };
1447                let b_gpu = if other_gpu {
1448                    other.clone()
1449                } else {
1450                    other.to_device(target)?
1451                };
1452                return a_gpu.sub(&b_gpu);
1453            }
1454        }
1455        // Fast path: same shape, contiguous
1456        if self.shape == other.shape && self.is_contiguous() && other.is_contiguous() {
1457            let a = self.storage.as_slice();
1458            let b = other.storage.as_slice();
1459            let (ao, bo) = (self.offset, other.offset);
1460            let n = numel(&self.shape);
1461            let mut r = vec![T::zero(); n];
1462            for i in 0..n {
1463                r[i] = a[ao + i] - b[bo + i];
1464            }
1465            return Self::from_vec(r, &self.shape);
1466        }
1467
1468        let result_shape = broadcast_shape(&self.shape, &other.shape)?;
1469        let self_strides = broadcast_strides(&self.shape, &self.strides, &result_shape);
1470        let other_strides = broadcast_strides(&other.shape, &other.strides, &result_shape);
1471
1472        let total = numel(&result_shape);
1473        let mut result_data = vec![T::zero(); total];
1474
1475        let self_data = self.storage.as_slice();
1476        let other_data = other.storage.as_slice();
1477
1478        for i in 0..total {
1479            let indices = crate::shape::unravel_index(i, &result_shape);
1480            let self_idx = self.offset + linear_index(&indices, &self_strides);
1481            let other_idx = other.offset + linear_index(&indices, &other_strides);
1482            result_data[i] = self_data[self_idx] - other_data[other_idx];
1483        }
1484
1485        Self::from_vec(result_data, &result_shape)
1486    }
1487
1488    /// Element-wise multiplication with broadcasting.
1489    pub fn mul(&self, other: &Self) -> Result<Self> {
1490        #[cfg(feature = "cuda")]
1491        {
1492            let self_gpu = self.device().is_gpu();
1493            let other_gpu = other.device().is_gpu();
1494            if self_gpu || other_gpu {
1495                assert!(is_f32::<T>(), "GPU tensors are only supported for f32");
1496                if self_gpu && other_gpu {
1497                    let (s, o) = unsafe { (gpu_ref(self), gpu_ref(other)) };
1498                    if self.shape == other.shape {
1499                        return Ok(unsafe { gpu_into(s.mul_cuda(o)?) });
1500                    } else {
1501                        return Ok(unsafe { gpu_into(s.broadcast_mul_cuda(o)?) });
1502                    }
1503                }
1504                let target = if self_gpu {
1505                    self.device()
1506                } else {
1507                    other.device()
1508                };
1509                let a_gpu = if self_gpu {
1510                    self.clone()
1511                } else {
1512                    self.to_device(target)?
1513                };
1514                let b_gpu = if other_gpu {
1515                    other.clone()
1516                } else {
1517                    other.to_device(target)?
1518                };
1519                return a_gpu.mul(&b_gpu);
1520            }
1521        }
1522        // Fast path: same shape, contiguous
1523        if self.shape == other.shape && self.is_contiguous() && other.is_contiguous() {
1524            let a = self.storage.as_slice();
1525            let b = other.storage.as_slice();
1526            let (ao, bo) = (self.offset, other.offset);
1527            let n = numel(&self.shape);
1528            let mut r = vec![T::zero(); n];
1529            for i in 0..n {
1530                r[i] = a[ao + i] * b[bo + i];
1531            }
1532            return Self::from_vec(r, &self.shape);
1533        }
1534
1535        let result_shape = broadcast_shape(&self.shape, &other.shape)?;
1536        let self_strides = broadcast_strides(&self.shape, &self.strides, &result_shape);
1537        let other_strides = broadcast_strides(&other.shape, &other.strides, &result_shape);
1538
1539        let total = numel(&result_shape);
1540        let mut result_data = vec![T::zero(); total];
1541
1542        let self_data = self.storage.as_slice();
1543        let other_data = other.storage.as_slice();
1544
1545        for i in 0..total {
1546            let indices = crate::shape::unravel_index(i, &result_shape);
1547            let self_idx = self.offset + linear_index(&indices, &self_strides);
1548            let other_idx = other.offset + linear_index(&indices, &other_strides);
1549            result_data[i] = self_data[self_idx] * other_data[other_idx];
1550        }
1551
1552        Self::from_vec(result_data, &result_shape)
1553    }
1554
1555    /// Element-wise division with broadcasting.
1556    pub fn div(&self, other: &Self) -> Result<Self> {
1557        #[cfg(feature = "cuda")]
1558        {
1559            let self_gpu = self.device().is_gpu();
1560            let other_gpu = other.device().is_gpu();
1561            if self_gpu || other_gpu {
1562                assert!(is_f32::<T>(), "GPU tensors are only supported for f32");
1563                if self_gpu && other_gpu {
1564                    let (s, o) = unsafe { (gpu_ref(self), gpu_ref(other)) };
1565                    if self.shape == other.shape {
1566                        return Ok(unsafe { gpu_into(s.div_cuda(o)?) });
1567                    } else {
1568                        return Ok(unsafe { gpu_into(s.broadcast_div_cuda(o)?) });
1569                    }
1570                }
1571                let target = if self_gpu {
1572                    self.device()
1573                } else {
1574                    other.device()
1575                };
1576                let a_gpu = if self_gpu {
1577                    self.clone()
1578                } else {
1579                    self.to_device(target)?
1580                };
1581                let b_gpu = if other_gpu {
1582                    other.clone()
1583                } else {
1584                    other.to_device(target)?
1585                };
1586                return a_gpu.div(&b_gpu);
1587            }
1588        }
1589        // Fast path: same shape, contiguous
1590        if self.shape == other.shape && self.is_contiguous() && other.is_contiguous() {
1591            let a = self.storage.as_slice();
1592            let b = other.storage.as_slice();
1593            let (ao, bo) = (self.offset, other.offset);
1594            let n = numel(&self.shape);
1595            let mut r = vec![T::zero(); n];
1596            for i in 0..n {
1597                r[i] = a[ao + i] / b[bo + i];
1598            }
1599            return Self::from_vec(r, &self.shape);
1600        }
1601
1602        let result_shape = broadcast_shape(&self.shape, &other.shape)?;
1603        let self_strides = broadcast_strides(&self.shape, &self.strides, &result_shape);
1604        let other_strides = broadcast_strides(&other.shape, &other.strides, &result_shape);
1605
1606        let total = numel(&result_shape);
1607        let mut result_data = vec![T::zero(); total];
1608
1609        let self_data = self.storage.as_slice();
1610        let other_data = other.storage.as_slice();
1611
1612        for i in 0..total {
1613            let indices = crate::shape::unravel_index(i, &result_shape);
1614            let self_idx = self.offset + linear_index(&indices, &self_strides);
1615            let other_idx = other.offset + linear_index(&indices, &other_strides);
1616            result_data[i] = self_data[self_idx] / other_data[other_idx];
1617        }
1618
1619        Self::from_vec(result_data, &result_shape)
1620    }
1621
1622    /// Scalar addition.
1623    #[must_use]
1624    pub fn add_scalar(&self, scalar: T) -> Self {
1625        #[cfg(feature = "cuda")]
1626        if self.device().is_gpu() {
1627            assert!(is_f32::<T>(), "GPU tensors are only supported for f32");
1628            let self_f32 = unsafe { gpu_ref(self) };
1629            let scalar_f32: f32 = unsafe { *(&scalar as *const T as *const f32) };
1630            return unsafe { gpu_into(self_f32.add_scalar_cuda(scalar_f32)) };
1631        }
1632        let data = self.to_vec();
1633        let mut result = vec![T::zero(); data.len()];
1634        CpuBackend::add_scalar(&mut result, &data, scalar);
1635        Self::from_vec(result, &self.shape).unwrap()
1636    }
1637
1638    /// Scalar multiplication.
1639    #[must_use]
1640    pub fn mul_scalar(&self, scalar: T) -> Self {
1641        #[cfg(feature = "cuda")]
1642        if self.device().is_gpu() {
1643            assert!(is_f32::<T>(), "GPU tensors are only supported for f32");
1644            let self_f32 = unsafe { gpu_ref(self) };
1645            let scalar_f32: f32 = unsafe { *(&scalar as *const T as *const f32) };
1646            return unsafe { gpu_into(self_f32.mul_scalar_cuda(scalar_f32)) };
1647        }
1648        let data = self.to_vec();
1649        let mut result = vec![T::zero(); data.len()];
1650        CpuBackend::mul_scalar(&mut result, &data, scalar);
1651        Self::from_vec(result, &self.shape).unwrap()
1652    }
1653
1654    /// Element-wise negation.
1655    #[must_use]
1656    pub fn neg(&self) -> Self {
1657        #[cfg(feature = "cuda")]
1658        if self.device().is_gpu() {
1659            assert!(is_f32::<T>(), "GPU tensors are only supported for f32");
1660            let self_f32 = unsafe { gpu_ref(self) };
1661            return unsafe { gpu_into(self_f32.neg_cuda()) };
1662        }
1663        let data = self.to_vec();
1664        let mut result = vec![T::zero(); data.len()];
1665        CpuBackend::neg(&mut result, &data);
1666        Self::from_vec(result, &self.shape).unwrap()
1667    }
1668
1669    /// Matrix multiplication with batching support.
1670    ///
1671    /// Supports:
1672    /// - 2D @ 2D: [m, k] @ [k, n] -> [m, n]
1673    /// - 3D @ 3D: [batch, m, k] @ [batch, k, n] -> [batch, m, n]
1674    /// - 4D @ 4D: [b1, b2, m, k] @ [b1, b2, k, n] -> [b1, b2, m, n]
1675    pub fn matmul(&self, other: &Self) -> Result<Self> {
1676        #[cfg(feature = "cuda")]
1677        if self.device().is_gpu() {
1678            assert!(is_f32::<T>(), "GPU tensors are only supported for f32");
1679            let (s, o) = unsafe { (gpu_ref(self), gpu_ref(other)) };
1680            return Ok(unsafe { gpu_into(s.matmul_cuda(o)?) });
1681        }
1682        if self.ndim() < 2 || other.ndim() < 2 {
1683            return Err(Error::invalid_operation(
1684                "matmul requires at least 2D tensors",
1685            ));
1686        }
1687
1688        let m = self.shape[self.ndim() - 2];
1689        let k1 = self.shape[self.ndim() - 1];
1690        let k2 = other.shape[other.ndim() - 2];
1691        let n = other.shape[other.ndim() - 1];
1692
1693        if k1 != k2 {
1694            return Err(Error::invalid_operation(format!(
1695                "matmul inner dimensions must match: {k1} vs {k2}"
1696            )));
1697        }
1698
1699        // For 2D matrices, do simple matmul.
1700        //
1701        // Fast path: when both tensors are already contiguous with offset 0
1702        // (the common case for pre-loaded weights and intermediate activations),
1703        // read the storage slices directly and skip the `.to_vec()` allocation-
1704        // and-copy. This is critical for LLM inference where the same weight
1705        // matrix is multiplied by a different input for every decoded token:
1706        // without this path we pay a full weight-matrix memcpy per matmul
1707        // (~10 MB × 7 matmuls × 42 layers = ~3 GB of memcpy per decode token
1708        // on an 8B model).
1709        if self.ndim() == 2 && other.ndim() == 2 {
1710            let self_fast = self.is_contiguous() && self.offset == 0;
1711            let other_fast = other.is_contiguous() && other.offset == 0;
1712
1713            // We still need owned buffers for CPU matmul (matrixmultiply + our
1714            // gemv reads from slices, but the cuda fallback owns Vec<f32>).
1715            // Whenever possible, avoid materializing the slow side — at worst
1716            // one allocation, not two.
1717            let self_storage = self.storage.as_slice();
1718            let other_storage = other.storage.as_slice();
1719
1720            let a_slice: &[T] = if self_fast {
1721                &self_storage[..m * k1]
1722            } else {
1723                // Fall back to materializing — keep the Vec alive for the call.
1724                // Hoisted into an Option so the borrow lives long enough.
1725                &[]
1726            };
1727            let b_slice: &[T] = if other_fast {
1728                &other_storage[..k1 * n]
1729            } else {
1730                &[]
1731            };
1732
1733            // Materialization fallbacks (only when we couldn't use a direct slice).
1734            let a_owned: Option<Vec<T>> = if self_fast {
1735                None
1736            } else {
1737                Some(self.contiguous().to_vec())
1738            };
1739            let b_owned: Option<Vec<T>> = if other_fast {
1740                None
1741            } else {
1742                Some(other.contiguous().to_vec())
1743            };
1744            let a: &[T] = a_owned.as_deref().unwrap_or(a_slice);
1745            let b: &[T] = b_owned.as_deref().unwrap_or(b_slice);
1746
1747            // GPU-accelerated matmul for CPU tensors: only for very large matrices
1748            // where transfer overhead is negligible relative to compute.
1749            // For GPU-resident tensors, the dispatch at the top of matmul() handles it.
1750            #[cfg(feature = "cuda")]
1751            {
1752                let flops = m * n * k1;
1753                if std::any::TypeId::of::<T>() == std::any::TypeId::of::<f32>()
1754                    && flops >= 4_000_000
1755                {
1756                    debug_assert!(std::mem::size_of::<T>() == std::mem::size_of::<f32>());
1757                    // SAFETY: T is f32 (checked by TypeId above), same size and layout
1758                    let a_f32: &[f32] = unsafe { std::mem::transmute(a) };
1759                    let b_f32: &[f32] = unsafe { std::mem::transmute(b) };
1760                    if let Some(c_f32) = cuda_accel::cuda_matmul(a_f32, b_f32, m, n, k1) {
1761                        // SAFETY: T is f32, Vec<f32> → Vec<T> is a no-op transmute
1762                        let c_t: Vec<T> = unsafe {
1763                            let mut v = std::mem::ManuallyDrop::new(c_f32);
1764                            Vec::from_raw_parts(v.as_mut_ptr() as *mut T, v.len(), v.capacity())
1765                        };
1766                        return Self::from_vec(c_t, &[m, n]);
1767                    }
1768                }
1769            }
1770
1771            // CpuBackend::matmul overwrites every element via sgemm/gemv
1772            // with beta=0, so the zero-init memset here is never observed.
1773            // Previously used `Vec::with_capacity + set_len` to avoid the
1774            // memset entirely, but clippy::uninit_vec flags that pattern —
1775            // and for any matmul big enough to matter, the memset cost is
1776            // dominated by the O(m*n*k) FMA work that follows. For tiny
1777            // matmuls the memset is ~μs.
1778            let mut c_data: Vec<T> = vec![T::zero(); m * n];
1779            CpuBackend::matmul(&mut c_data, a, b, m, n, k1);
1780            return Self::from_vec(c_data, &[m, n]);
1781        }
1782
1783        // For batched matmul, compute batch size
1784        let batch_dims_self: Vec<usize> = self.shape[..self.ndim() - 2].to_vec();
1785        let batch_dims_other: Vec<usize> = other.shape[..other.ndim() - 2].to_vec();
1786
1787        // Broadcast batch dimensions (PyTorch parity)
1788        let broadcast_batch = if batch_dims_self == batch_dims_other {
1789            None
1790        } else {
1791            // Pad to same length
1792            let max_len = batch_dims_self.len().max(batch_dims_other.len());
1793            let pad_a = vec![1usize; max_len - batch_dims_self.len()];
1794            let pad_b = vec![1usize; max_len - batch_dims_other.len()];
1795            let a_dims: Vec<usize> = pad_a
1796                .iter()
1797                .chain(batch_dims_self.iter())
1798                .copied()
1799                .collect();
1800            let b_dims: Vec<usize> = pad_b
1801                .iter()
1802                .chain(batch_dims_other.iter())
1803                .copied()
1804                .collect();
1805
1806            let mut out_dims = Vec::with_capacity(max_len);
1807            for i in 0..max_len {
1808                if a_dims[i] == b_dims[i] {
1809                    out_dims.push(a_dims[i]);
1810                } else if a_dims[i] == 1 {
1811                    out_dims.push(b_dims[i]);
1812                } else if b_dims[i] == 1 {
1813                    out_dims.push(a_dims[i]);
1814                } else {
1815                    return Err(Error::invalid_operation(format!(
1816                        "matmul batch dimensions not broadcastable: {:?} vs {:?}",
1817                        batch_dims_self, batch_dims_other
1818                    )));
1819                }
1820            }
1821            Some((a_dims, b_dims, out_dims))
1822        };
1823
1824        let (batch_size, a_batch_idx, b_batch_idx) =
1825            if let Some((a_dims, b_dims, out_dims)) = &broadcast_batch {
1826                let bs: usize = out_dims.iter().product();
1827                // Build index mapping: for each output batch, which a and b batch to use
1828                let mut a_idx = Vec::with_capacity(bs);
1829                let mut b_idx = Vec::with_capacity(bs);
1830                for flat in 0..bs {
1831                    let mut remaining = flat;
1832                    let mut ai = 0usize;
1833                    let mut bi = 0usize;
1834                    let mut a_stride_acc = 1usize;
1835                    let mut b_stride_acc = 1usize;
1836                    for d in (0..out_dims.len()).rev() {
1837                        let out_d = out_dims[d];
1838                        let idx = remaining % out_d;
1839                        remaining /= out_d;
1840                        let a_d = a_dims[d];
1841                        let b_d = b_dims[d];
1842                        ai += (idx % a_d) * a_stride_acc;
1843                        bi += (idx % b_d) * b_stride_acc;
1844                        a_stride_acc *= a_d;
1845                        b_stride_acc *= b_d;
1846                    }
1847                    a_idx.push(ai);
1848                    b_idx.push(bi);
1849                }
1850                (bs, a_idx, b_idx)
1851            } else {
1852                let bs: usize = batch_dims_self.iter().product();
1853                let idx: Vec<usize> = (0..bs).collect();
1854                (bs, idx.clone(), idx)
1855            };
1856
1857        let a_stride = m * k1;
1858        let b_stride = k1 * n;
1859        let c_stride = m * n;
1860
1861        let a_data = self.contiguous().to_vec();
1862        let b_data = other.contiguous().to_vec();
1863        let mut c_data = vec![T::zero(); batch_size * m * n];
1864
1865        // Try GPU acceleration for f32 batched matmul (only for large enough matrices)
1866        #[cfg(feature = "cuda")]
1867        {
1868            let flops = m * n * k1;
1869            if std::any::TypeId::of::<T>() == std::any::TypeId::of::<f32>() && flops >= 4_000_000 {
1870                let a_f32: &[f32] = unsafe { std::mem::transmute(a_data.as_slice()) };
1871                let b_f32: &[f32] = unsafe { std::mem::transmute(b_data.as_slice()) };
1872                let mut gpu_ok = true;
1873                for batch in 0..batch_size {
1874                    let ai = a_batch_idx[batch];
1875                    let bi = b_batch_idx[batch];
1876                    let a_slice = &a_f32[ai * a_stride..(ai + 1) * a_stride];
1877                    let b_slice = &b_f32[bi * b_stride..(bi + 1) * b_stride];
1878                    if let Some(c_batch) = cuda_accel::cuda_matmul(a_slice, b_slice, m, n, k1) {
1879                        c_data[batch * c_stride..(batch + 1) * c_stride]
1880                            .copy_from_slice(unsafe { std::mem::transmute(c_batch.as_slice()) });
1881                    } else {
1882                        gpu_ok = false;
1883                        break;
1884                    }
1885                }
1886                if gpu_ok {
1887                    let mut output_shape = batch_dims_self;
1888                    output_shape.push(m);
1889                    output_shape.push(n);
1890                    return Self::from_vec(c_data, &output_shape);
1891                }
1892                // Fall through to CPU if GPU failed
1893                c_data = vec![T::zero(); batch_size * m * n];
1894            }
1895        }
1896
1897        // CPU fallback: loop over batches
1898        for batch in 0..batch_size {
1899            let ai = a_batch_idx[batch];
1900            let bi = b_batch_idx[batch];
1901            let a_slice = &a_data[ai * a_stride..(ai + 1) * a_stride];
1902            let b_slice = &b_data[bi * b_stride..(bi + 1) * b_stride];
1903            let c_slice = &mut c_data[batch * c_stride..(batch + 1) * c_stride];
1904            CpuBackend::matmul(c_slice, a_slice, b_slice, m, n, k1);
1905        }
1906
1907        // Build output shape: broadcast batch dims + [m, n]
1908        let mut output_shape = if let Some((_, _, ref out_dims)) = broadcast_batch {
1909            out_dims.clone()
1910        } else {
1911            batch_dims_self
1912        };
1913        output_shape.push(m);
1914        output_shape.push(n);
1915
1916        Self::from_vec(c_data, &output_shape)
1917    }
1918
1919    /// Dot product for 1D tensors.
1920    pub fn dot(&self, other: &Self) -> Result<Self> {
1921        if self.ndim() != 1 || other.ndim() != 1 {
1922            return Err(Error::invalid_operation("dot requires 1D tensors"));
1923        }
1924
1925        if self.shape[0] != other.shape[0] {
1926            return Err(Error::shape_mismatch(&self.shape, &other.shape));
1927        }
1928
1929        let a_data = self.to_vec();
1930        let b_data = other.to_vec();
1931        let result = CpuBackend::dot(&a_data, &b_data);
1932
1933        Ok(Self::scalar(result))
1934    }
1935}
1936
1937// =============================================================================
1938// Operator Trait Implementations
1939// =============================================================================
1940
1941impl<T: Numeric> Add for &Tensor<T> {
1942    type Output = Tensor<T>;
1943
1944    fn add(self, other: Self) -> Self::Output {
1945        self.add(other).expect("Addition failed")
1946    }
1947}
1948
1949impl<T: Numeric> Sub for &Tensor<T> {
1950    type Output = Tensor<T>;
1951
1952    fn sub(self, other: Self) -> Self::Output {
1953        self.sub(other).expect("Subtraction failed")
1954    }
1955}
1956
1957impl<T: Numeric> Mul for &Tensor<T> {
1958    type Output = Tensor<T>;
1959
1960    fn mul(self, other: Self) -> Self::Output {
1961        self.mul(other).expect("Multiplication failed")
1962    }
1963}
1964
1965impl<T: Numeric> Div for &Tensor<T> {
1966    type Output = Tensor<T>;
1967
1968    fn div(self, other: Self) -> Self::Output {
1969        self.div(other).expect("Division failed")
1970    }
1971}
1972
1973impl<T: Numeric> Neg for &Tensor<T> {
1974    type Output = Tensor<T>;
1975
1976    fn neg(self) -> Self::Output {
1977        self.neg()
1978    }
1979}
1980
1981// Scalar operations
1982impl<T: Numeric> Add<T> for &Tensor<T> {
1983    type Output = Tensor<T>;
1984
1985    fn add(self, scalar: T) -> Self::Output {
1986        self.add_scalar(scalar)
1987    }
1988}
1989
1990impl<T: Numeric> Mul<T> for &Tensor<T> {
1991    type Output = Tensor<T>;
1992
1993    fn mul(self, scalar: T) -> Self::Output {
1994        self.mul_scalar(scalar)
1995    }
1996}
1997
1998// =============================================================================
1999// Display Implementation
2000// =============================================================================
2001
2002impl<T: Scalar + fmt::Display> fmt::Debug for Tensor<T> {
2003    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
2004        write!(
2005            f,
2006            "Tensor(shape={:?}, device={}",
2007            self.shape(),
2008            self.device()
2009        )?;
2010        if self.numel() <= 10 {
2011            write!(f, ", data={:?}", self.to_vec())?;
2012        }
2013        write!(f, ")")
2014    }
2015}
2016
2017impl<T: Scalar + fmt::Display> fmt::Display for Tensor<T> {
2018    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
2019        if self.is_scalar() {
2020            write!(f, "{}", self.item().unwrap())
2021        } else if self.ndim() == 1 {
2022            write!(f, "[")?;
2023            let data = self.to_vec();
2024            for (i, val) in data.iter().enumerate() {
2025                if i > 0 {
2026                    write!(f, ", ")?;
2027                }
2028                write!(f, "{val}")?;
2029            }
2030            write!(f, "]")
2031        } else {
2032            write!(f, "Tensor(shape={:?})", self.shape())
2033        }
2034    }
2035}
2036
2037// =============================================================================
2038// f32 ↔ f16 Casting for AMP (Automatic Mixed Precision)
2039// =============================================================================
2040
2041impl Tensor<f32> {
2042    /// Cast this f32 tensor to f16 values stored as f32.
2043    ///
2044    /// Each value is rounded to f16 precision. This simulates half-precision
2045    /// computation while keeping the tensor type as f32, which is how AMP
2046    /// works — the autograd graph stays f32 but computation uses f16 precision.
2047    ///
2048    /// On GPU, this uses a CUDA kernel for fast conversion.
2049    /// On CPU, this uses the `half` crate.
2050    #[must_use]
2051    pub fn to_f16_precision(&self) -> Self {
2052        let data = self.to_vec();
2053        let f16_data: Vec<f32> = data
2054            .iter()
2055            .map(|&v| {
2056                let h = half::f16::from_f32(v);
2057                h.to_f32()
2058            })
2059            .collect();
2060        Self::from_vec(f16_data, self.shape()).unwrap()
2061    }
2062
2063    /// Cast f16-precision values back to full f32 precision.
2064    ///
2065    /// This is a no-op since the data is already stored as f32.
2066    /// Included for API symmetry with `to_f16_precision()`.
2067    #[must_use]
2068    pub fn to_f32_precision(&self) -> Self {
2069        self.clone()
2070    }
2071
2072    /// Returns true if applying f16 precision would change any values.
2073    /// Useful for debugging AMP-related numerical issues.
2074    #[must_use]
2075    pub fn has_f16_rounding_error(&self) -> bool {
2076        let data = self.to_vec();
2077        data.iter().any(|&v| {
2078            let h = half::f16::from_f32(v);
2079            (h.to_f32() - v).abs() > f32::EPSILON
2080        })
2081    }
2082}
2083
2084// =============================================================================
2085// Tests
2086// =============================================================================
2087
#[cfg(test)]
mod tests {
    use super::*;

    /// `from_vec` keeps the requested shape and element count.
    #[test]
    fn test_from_vec() {
        let values = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let tensor = Tensor::<f32>::from_vec(values, &[2, 3]).unwrap();

        assert_eq!(tensor.shape(), &[2, 3]);
        assert_eq!(tensor.numel(), 6);
    }

    /// `get` reads row-major elements and `set` overwrites one in place.
    #[test]
    fn test_get_set() {
        let tensor = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[2, 2]).unwrap();

        // Row 0, then row 1.
        assert_eq!(tensor.get(&[0, 0]).unwrap(), 1.0);
        assert_eq!(tensor.get(&[0, 1]).unwrap(), 2.0);
        assert_eq!(tensor.get(&[1, 0]).unwrap(), 3.0);
        assert_eq!(tensor.get(&[1, 1]).unwrap(), 4.0);

        // A write must be visible to a subsequent read.
        tensor.set(&[0, 0], 99.0).unwrap();
        assert_eq!(tensor.get(&[0, 0]).unwrap(), 99.0);
    }

    /// `reshape` accepts explicit dimensions and a `-1` placeholder.
    #[test]
    fn test_reshape() {
        let values = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let tensor = Tensor::<f32>::from_vec(values, &[2, 3]).unwrap();

        let explicit = tensor.reshape(&[3, 2]).expect("reshape failed");
        assert_eq!(explicit.shape(), &[3, 2]);

        let flattened = tensor.reshape(&[-1]).expect("reshape failed");
        assert_eq!(flattened.shape(), &[6]);
    }

    /// `t()` swaps the two dimensions of a 2D tensor.
    #[test]
    fn test_transpose() {
        let values = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let tensor = Tensor::<f32>::from_vec(values, &[2, 3]).unwrap();

        let transposed = tensor.t().unwrap();

        assert_eq!(transposed.shape(), &[3, 2]);
        assert_eq!(transposed.get(&[0, 0]).unwrap(), 1.0);
        assert_eq!(transposed.get(&[0, 1]).unwrap(), 4.0);
        assert_eq!(transposed.get(&[1, 0]).unwrap(), 2.0);
    }

    /// The `+` and `*` operators work element-wise on equal shapes.
    #[test]
    fn test_arithmetic() {
        let lhs = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap();
        let rhs = Tensor::<f32>::from_vec(vec![4.0, 5.0, 6.0], &[3]).unwrap();

        let sum = &lhs + &rhs;
        assert_eq!(sum.to_vec(), vec![5.0, 7.0, 9.0]);

        let product = &lhs * &rhs;
        assert_eq!(product.to_vec(), vec![4.0, 10.0, 18.0]);
    }

    /// A length-1 tensor is broadcast across a length-3 tensor.
    #[test]
    fn test_broadcasting() {
        let vector = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap();
        let single = Tensor::<f32>::from_vec(vec![10.0], &[1]).unwrap();

        let sum = &vector + &single;

        assert_eq!(sum.to_vec(), vec![11.0, 12.0, 13.0]);
    }

    /// `sum` reduces all elements to one value.
    #[test]
    fn test_sum() {
        let tensor = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[4]).unwrap();

        let total = tensor.sum();

        assert_eq!(total.item().unwrap(), 10.0);
    }

    /// 2x2 @ 2x2 matrix product against hand-computed values.
    #[test]
    fn test_matmul() {
        let lhs = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[2, 2]).unwrap();
        let rhs = Tensor::<f32>::from_vec(vec![5.0, 6.0, 7.0, 8.0], &[2, 2]).unwrap();

        let product = lhs.matmul(&rhs).unwrap();

        assert_eq!(product.shape(), &[2, 2]);
        assert_eq!(product.to_vec(), vec![19.0, 22.0, 43.0, 50.0]);
    }

    /// `relu` zeroes negatives and leaves non-negatives untouched.
    #[test]
    fn test_relu() {
        let tensor = Tensor::<f32>::from_vec(vec![-1.0, 0.0, 1.0, 2.0], &[4]).unwrap();

        let activated = tensor.relu();

        assert_eq!(activated.to_vec(), vec![0.0, 0.0, 1.0, 2.0]);
    }

    /// A scalar tensor holds exactly one element retrievable via `item`.
    #[test]
    fn test_scalar() {
        let scalar = Tensor::<f32>::scalar(42.0);

        assert!(scalar.is_scalar());
        assert_eq!(scalar.numel(), 1);
        assert_eq!(scalar.item().unwrap(), 42.0);
    }
}