// axonml_tensor/tensor.rs

//! Tensor - Core N-Dimensional Array Type
//!
//! The `Tensor` struct is the fundamental data structure in Axonml. It represents
//! an N-dimensional array of numeric values with support for automatic broadcasting,
//! device placement, and efficient memory sharing through views.
//!
//! # Key Features
//! - Generic over element type (f32, f64, i32, etc.)
//! - Efficient views with shared storage
//! - Device-agnostic operations
//! - Broadcasting support
//! - Lazy operations where possible
//!
//! @version 0.1.0
//! @author `AutomataNexus` Development Team

17use core::fmt;
18use core::ops::{Add, Div, Mul, Neg, Sub};
19
20use axonml_core::backends::CpuBackend;
21#[cfg(feature = "cuda")]
22use axonml_core::backends::CudaBackend;
23use axonml_core::dtype::{Float, Numeric, Scalar};
24use axonml_core::error::{Error, Result};
25use axonml_core::storage::Storage;
26use axonml_core::Device;
27use num_traits::NumCast;
28
29// =============================================================================
30// CUDA Acceleration
31// =============================================================================
32
#[cfg(feature = "cuda")]
mod cuda_accel {
    use super::*;
    use std::sync::OnceLock;

    // Process-wide CUDA backend handle. A failed initialization is cached as
    // `None`, so it is never retried for the lifetime of the process.
    static CUDA_BACKEND: OnceLock<Option<CudaBackend>> = OnceLock::new();

    /// Get the global CUDA backend (initialized lazily on first call).
    ///
    /// Returns `None` when `CudaBackend::new(0)` yields no backend
    /// (e.g. no usable GPU); that outcome is cached permanently.
    pub fn get_cuda() -> Option<&'static CudaBackend> {
        CUDA_BACKEND
            .get_or_init(|| {
                let backend = CudaBackend::new(0);
                if backend.is_some() {
                    // One-time notice on stderr so users can tell GPU is active.
                    eprintln!("[AxonML] CUDA backend initialized (GPU 0)");
                }
                backend
            })
            .as_ref()
    }

    /// GPU-accelerated matmul: copies data to GPU, runs cuBLAS GEMM, copies back.
    /// Returns None if GPU is unavailable or an error occurs.
    ///
    /// `a` is row-major (m, k), `b` is row-major (k, n); the returned buffer is
    /// row-major (m, n).
    pub fn cuda_matmul(a: &[f32], b: &[f32], m: usize, n: usize, k: usize) -> Option<Vec<f32>> {
        let cuda = get_cuda()?;

        // cuBLAS uses column-major, so we compute C = B^T @ A^T in col-major
        // which gives us C in row-major (the layout we want).
        let a_gpu = cuda.htod_copy(a).ok()?;
        let b_gpu = cuda.htod_copy(b).ok()?;
        let mut c_gpu = cuda.alloc::<f32>(m * n).ok()?;

        // cuBLAS GEMM: C(m,n) = A(m,k) @ B(k,n) in row-major
        // In column-major terms: C^T(n,m) = B^T(n,k) @ A^T(k,m)
        cuda.gemm_f32(
            false, false,  // no transpose (we swap A and B)
            n, m, k,       // dimensions swapped for col-major
            1.0,
            &b_gpu, n,     // B is "A" in col-major, lda = n
            &a_gpu, k,     // A is "B" in col-major, ldb = k
            0.0,
            &mut c_gpu, n, // ldc = n
        ).ok()?;

        cuda.dtoh_copy(&c_gpu).ok()
    }
}
79
80use crate::shape::{
81    broadcast_shape, broadcast_strides, contiguous_strides, is_contiguous, linear_index,
82    normalize_dim, numel, reshape, squeeze, transpose_shape, transpose_strides, unsqueeze, Shape,
83    Strides,
84};
85
86// =============================================================================
87// Tensor Struct
88// =============================================================================
89
90/// An N-dimensional array of numeric values.
91///
92/// Tensors are the core data structure for all computations in Axonml.
93/// They support arbitrary dimensions, automatic broadcasting, and efficient
94/// memory sharing between views.
95#[derive(Clone)]
96pub struct Tensor<T: Scalar> {
97    /// Underlying data storage (reference-counted).
98    pub(crate) storage: Storage<T>,
99    /// Shape of the tensor (dimensions).
100    pub(crate) shape: Shape,
101    /// Strides for each dimension.
102    pub(crate) strides: Strides,
103    /// Offset into storage (for views).
104    pub(crate) offset: usize,
105}
106
107impl<T: Scalar> Tensor<T> {
108    // =========================================================================
109    // Constructors
110    // =========================================================================
111
112    /// Creates a new tensor from storage with the given shape.
113    ///
114    /// # Arguments
115    /// * `storage` - The underlying data storage
116    /// * `shape` - Shape of the tensor
117    ///
118    /// # Returns
119    /// New tensor, or error if shape doesn't match storage size.
120    pub fn from_storage(storage: Storage<T>, shape: &[usize]) -> Result<Self> {
121        let total = numel(shape);
122        if total != storage.len() {
123            return Err(Error::shape_mismatch(&[storage.len()], shape));
124        }
125
126        let shape = Shape::from_slice(shape);
127        let strides = contiguous_strides(&shape);
128
129        Ok(Self {
130            storage,
131            shape,
132            strides,
133            offset: 0,
134        })
135    }
136
137    /// Creates a new tensor from a vector with the given shape.
138    ///
139    /// # Arguments
140    /// * `data` - Vector of data
141    /// * `shape` - Shape of the tensor
142    ///
143    /// # Returns
144    /// New tensor, or error if shape doesn't match data length.
145    pub fn from_vec(data: Vec<T>, shape: &[usize]) -> Result<Self> {
146        let storage = Storage::from_vec(data, Device::Cpu);
147        Self::from_storage(storage, shape)
148    }
149
150    /// Creates a new tensor from a slice with the given shape.
151    ///
152    /// # Arguments
153    /// * `data` - Slice of data to copy
154    /// * `shape` - Shape of the tensor
155    ///
156    /// # Returns
157    /// New tensor, or error if shape doesn't match data length.
158    pub fn from_slice(data: &[T], shape: &[usize]) -> Result<Self> {
159        let storage = Storage::from_slice(data, Device::Cpu);
160        Self::from_storage(storage, shape)
161    }
162
163    /// Creates a scalar tensor (0-dimensional).
164    ///
165    /// # Arguments
166    /// * `value` - The scalar value
167    ///
168    /// # Returns
169    /// New 0-dimensional tensor.
170    pub fn scalar(value: T) -> Self {
171        Self {
172            storage: Storage::from_vec(vec![value], Device::Cpu),
173            shape: Shape::new(),
174            strides: Strides::new(),
175            offset: 0,
176        }
177    }
178
    /// Creates a tensor filled with zeros.
    ///
    /// Delegates to [`crate::creation::zeros`].
    #[must_use]
    pub fn zeros(shape: &[usize]) -> Self {
        crate::creation::zeros(shape)
    }

    /// Creates a tensor filled with ones.
    ///
    /// Delegates to [`crate::creation::ones`]; requires `T: Numeric`.
    #[must_use]
    pub fn ones(shape: &[usize]) -> Self
    where
        T: Numeric,
    {
        crate::creation::ones(shape)
    }

    /// Creates a tensor filled with a constant value.
    ///
    /// Delegates to [`crate::creation::full`].
    #[must_use]
    pub fn full(shape: &[usize], value: T) -> Self {
        crate::creation::full(shape, value)
    }

    /// Creates a tensor with random values from standard normal distribution.
    ///
    /// Delegates to [`crate::creation::randn`].
    #[must_use]
    pub fn randn(shape: &[usize]) -> Self
    where
        T: Float,
        rand_distr::StandardNormal: rand::distributions::Distribution<T>,
    {
        crate::creation::randn(shape)
    }

    /// Creates a tensor with random values from uniform distribution [0, 1).
    ///
    /// Delegates to [`crate::creation::rand`].
    #[must_use]
    pub fn rand(shape: &[usize]) -> Self
    where
        T: Float,
        rand::distributions::Standard: rand::distributions::Distribution<T>,
    {
        crate::creation::rand(shape)
    }
219
    // =========================================================================
    // Properties
    // =========================================================================

    /// Returns the shape of the tensor.
    #[must_use]
    pub fn shape(&self) -> &[usize] {
        &self.shape
    }

    /// Returns the strides of the tensor, one per dimension.
    #[must_use]
    pub fn strides(&self) -> &[isize] {
        &self.strides
    }

    /// Returns the number of dimensions.
    #[must_use]
    pub fn ndim(&self) -> usize {
        self.shape.len()
    }

    /// Returns the total number of elements.
    #[must_use]
    pub fn numel(&self) -> usize {
        numel(&self.shape)
    }

    /// Returns true if the tensor is empty (has zero elements).
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.numel() == 0
    }

    /// Returns the size of a specific dimension.
    ///
    /// # Arguments
    /// * `dim` - Dimension index (supports negative indexing)
    ///
    /// # Errors
    /// Returns an error if `dim` is out of range after normalization.
    pub fn size(&self, dim: i64) -> Result<usize> {
        let idx = normalize_dim(dim, self.ndim())?;
        Ok(self.shape[idx])
    }

    /// Returns the device this tensor is on.
    #[must_use]
    pub fn device(&self) -> Device {
        self.storage.device()
    }

    /// Returns true if the tensor is contiguous in memory.
    #[must_use]
    pub fn is_contiguous(&self) -> bool {
        is_contiguous(&self.shape, &self.strides)
    }

    /// Returns true if this tensor is a scalar (0-dimensional).
    #[must_use]
    pub fn is_scalar(&self) -> bool {
        self.shape.is_empty()
    }
280
281    // =========================================================================
282    // Data Access
283    // =========================================================================
284
285    /// Returns the element at the given indices.
286    ///
287    /// # Arguments
288    /// * `indices` - Multi-dimensional indices
289    pub fn get(&self, indices: &[usize]) -> Result<T> {
290        if indices.len() != self.ndim() {
291            return Err(Error::invalid_operation(format!(
292                "Expected {} indices, got {}",
293                self.ndim(),
294                indices.len()
295            )));
296        }
297
298        for (&idx, &dim) in indices.iter().zip(self.shape.iter()) {
299            if idx >= dim {
300                return Err(Error::IndexOutOfBounds {
301                    index: idx,
302                    size: dim,
303                });
304            }
305        }
306
307        let offset = self.offset + linear_index(indices, &self.strides);
308        Ok(self.storage.as_slice()[offset])
309    }
310
311    /// Sets the element at the given indices.
312    ///
313    /// # Arguments
314    /// * `indices` - Multi-dimensional indices
315    /// * `value` - Value to set
316    pub fn set(&self, indices: &[usize], value: T) -> Result<()> {
317        if indices.len() != self.ndim() {
318            return Err(Error::invalid_operation(format!(
319                "Expected {} indices, got {}",
320                self.ndim(),
321                indices.len()
322            )));
323        }
324
325        for (&idx, &dim) in indices.iter().zip(self.shape.iter()) {
326            if idx >= dim {
327                return Err(Error::IndexOutOfBounds {
328                    index: idx,
329                    size: dim,
330                });
331            }
332        }
333
334        let offset = self.offset + linear_index(indices, &self.strides);
335        self.storage.as_slice_mut()[offset] = value;
336        Ok(())
337    }
338
    /// Returns the scalar value for a single-element tensor.
    ///
    /// Accepts a true 0-dimensional scalar as well as any shape with exactly
    /// one element (e.g. `[1, 1]`).
    ///
    /// # Errors
    /// Fails unless `numel() == 1`.
    pub fn item(&self) -> Result<T> {
        if self.numel() != 1 {
            return Err(Error::invalid_operation(
                "item() only works on single-element tensors",
            ));
        }

        if self.is_scalar() {
            Ok(self.storage.as_slice()[self.offset])
        } else {
            // Single element but not scalar shape
            let indices = vec![0; self.ndim()];
            self.get(&indices)
        }
    }

    /// Returns the data as a `Vec` in logical (row-major) order.
    ///
    /// Always produces an owned copy: contiguous tensors are sliced out of
    /// storage directly, non-contiguous views are gathered element by element.
    #[must_use]
    pub fn to_vec(&self) -> Vec<T> {
        if self.is_contiguous() {
            let storage = self.storage.as_slice();
            storage[self.offset..self.offset + self.numel()].to_vec()
        } else {
            let mut result = Vec::with_capacity(self.numel());
            self.copy_data_to(&mut result);
            result
        }
    }

    /// Gathers this tensor's elements into `dst` in logical order, honoring
    /// `offset` and `strides` so non-contiguous views are handled correctly.
    fn copy_data_to(&self, dst: &mut Vec<T>) {
        dst.clear();
        let storage = self.storage.as_slice();

        // Iterate through all indices
        let total = self.numel();
        for i in 0..total {
            let indices = crate::shape::unravel_index(i, &self.shape);
            let offset = self.offset + linear_index(&indices, &self.strides);
            dst.push(storage[offset]);
        }
    }
385
386    // =========================================================================
387    // Shape Operations
388    // =========================================================================
389
    /// Returns a new tensor with the specified shape.
    ///
    /// The total number of elements must remain the same.
    /// Supports -1 in one dimension to infer the size.
    ///
    /// Contiguous tensors are reshaped as a zero-copy view sharing storage;
    /// non-contiguous tensors are materialized into a fresh buffer first.
    ///
    /// # Arguments
    /// * `new_shape` - Target shape
    ///
    /// # Errors
    /// Propagates shape-inference errors from [`crate::shape::reshape`].
    pub fn reshape(&self, new_shape: &[isize]) -> Result<Self> {
        let shape = reshape(&self.shape, new_shape)?;

        if self.is_contiguous() {
            // Can just change shape without copying
            Ok(Self {
                storage: self.storage.clone(),
                strides: contiguous_strides(&shape),
                shape,
                offset: self.offset,
            })
        } else {
            // Need to make contiguous first
            let contig = self.contiguous();
            Ok(Self {
                storage: contig.storage,
                strides: contiguous_strides(&shape),
                shape,
                // `contiguous()` built fresh storage starting at element 0.
                offset: 0,
            })
        }
    }

    /// Returns this tensor flattened to one dimension.
    #[must_use]
    pub fn flatten(&self) -> Self {
        // `[-1]` infers the full length; a 1-D target can never mismatch.
        self.reshape(&[-1]).expect("Flatten should never fail")
    }
425
    /// Returns a new tensor with dimensions of size 1 removed.
    ///
    /// # Arguments
    /// * `dim` - Optional specific dimension to squeeze; when `None`, all
    ///   size-1 dimensions are removed.
    ///
    /// # Errors
    /// Returns an error if `dim` is out of range after normalization.
    pub fn squeeze(&self, dim: Option<i64>) -> Result<Self> {
        let dim = match dim {
            Some(d) => Some(normalize_dim(d, self.ndim())?),
            None => None,
        };

        let new_shape = squeeze(&self.shape, dim);
        // Strides must be dropped in lockstep with the removed dimensions.
        let new_strides: Strides = match dim {
            Some(d) => {
                let mut s = self.strides.clone();
                // Only drop the stride when that dim really had size 1;
                // otherwise `squeeze` left the shape unchanged.
                if d < self.shape.len() && self.shape[d] == 1 {
                    s.remove(d);
                }
                s
            }
            None => self
                .shape
                .iter()
                .zip(self.strides.iter())
                .filter(|(&dim, _)| dim != 1)
                .map(|(_, &stride)| stride)
                .collect(),
        };

        Ok(Self {
            storage: self.storage.clone(),
            shape: new_shape,
            strides: new_strides,
            offset: self.offset,
        })
    }

    /// Returns a new tensor with a dimension of size 1 inserted.
    ///
    /// # Arguments
    /// * `dim` - Position to insert the new dimension; negative values count
    ///   from the back (`-1` appends a trailing dimension).
    ///
    /// # Errors
    /// Propagates errors from [`crate::shape::unsqueeze`] for invalid positions.
    pub fn unsqueeze(&self, dim: i64) -> Result<Self> {
        // There are ndim + 1 valid insertion positions, hence the `+ 1`.
        let normalized = if dim < 0 {
            (dim + self.ndim() as i64 + 1) as usize
        } else {
            dim as usize
        };

        let new_shape = unsqueeze(&self.shape, normalized)?;
        let mut new_strides = Strides::with_capacity(new_shape.len());

        for (i, _) in new_shape.iter().enumerate() {
            if i < normalized {
                new_strides.push(self.strides.get(i).copied().unwrap_or(1));
            } else if i == normalized {
                // Stride for new dimension (doesn't matter since size is 1)
                new_strides.push(1);
            } else {
                // Dimensions after the insertion point shift right by one.
                new_strides.push(self.strides[i - 1]);
            }
        }

        Ok(Self {
            storage: self.storage.clone(),
            shape: new_shape,
            strides: new_strides,
            offset: self.offset,
        })
    }
494
495    /// Transposes two dimensions.
496    ///
497    /// # Arguments
498    /// * `dim0` - First dimension
499    /// * `dim1` - Second dimension
500    pub fn transpose(&self, dim0: i64, dim1: i64) -> Result<Self> {
501        let d0 = normalize_dim(dim0, self.ndim())?;
502        let d1 = normalize_dim(dim1, self.ndim())?;
503
504        let new_shape = transpose_shape(&self.shape, d0, d1)?;
505        let new_strides = transpose_strides(&self.strides, d0, d1);
506
507        Ok(Self {
508            storage: self.storage.clone(),
509            shape: new_shape,
510            strides: new_strides,
511            offset: self.offset,
512        })
513    }
514
515    /// Returns the transpose of a 2D tensor.
516    pub fn t(&self) -> Result<Self> {
517        if self.ndim() != 2 {
518            return Err(Error::invalid_operation("t() only works on 2D tensors"));
519        }
520        self.transpose(0, 1)
521    }
522
523    /// Returns a permuted tensor with dimensions reordered.
524    ///
525    /// # Arguments
526    /// * `dims` - New order of dimensions
527    pub fn permute(&self, dims: &[usize]) -> Result<Self> {
528        if dims.len() != self.ndim() {
529            return Err(Error::invalid_operation(format!(
530                "Expected {} dimensions, got {}",
531                self.ndim(),
532                dims.len()
533            )));
534        }
535
536        // Check that dims is a permutation
537        let mut seen = vec![false; self.ndim()];
538        for &d in dims {
539            if d >= self.ndim() {
540                return Err(Error::InvalidDimension {
541                    index: d as i64,
542                    ndim: self.ndim(),
543                });
544            }
545            if seen[d] {
546                return Err(Error::invalid_operation("Duplicate dimension in permute"));
547            }
548            seen[d] = true;
549        }
550
551        let new_shape: Shape = dims.iter().map(|&d| self.shape[d]).collect();
552        let new_strides: Strides = dims.iter().map(|&d| self.strides[d]).collect();
553
554        Ok(Self {
555            storage: self.storage.clone(),
556            shape: new_shape,
557            strides: new_strides,
558            offset: self.offset,
559        })
560    }
561
562    /// Returns a contiguous copy of the tensor.
563    #[must_use]
564    pub fn contiguous(&self) -> Self {
565        if self.is_contiguous() && self.offset == 0 {
566            return self.clone();
567        }
568
569        let data = self.to_vec();
570        Self::from_vec(data, &self.shape).expect("Contiguous should never fail")
571    }
572
573    // =========================================================================
574    // Device Operations
575    // =========================================================================
576
577    /// Transfers the tensor to a different device.
578    ///
579    /// # Arguments
580    /// * `device` - Target device
581    pub fn to_device(&self, device: Device) -> Result<Self> {
582        if self.device() == device {
583            return Ok(self.clone());
584        }
585
586        let contig = self.contiguous();
587        let new_storage = contig.storage.to_device(device)?;
588
589        Ok(Self {
590            storage: new_storage,
591            shape: self.shape.clone(),
592            strides: self.strides.clone(),
593            offset: 0,
594        })
595    }
596
597    /// Transfers to CPU.
598    pub fn cpu(&self) -> Result<Self> {
599        self.to_device(Device::Cpu)
600    }
601
602    // =========================================================================
603    // Deep Copy
604    // =========================================================================
605
606    /// Creates a deep copy of this tensor with its own storage.
607    #[must_use]
608    pub fn clone_deep(&self) -> Self {
609        let data = self.to_vec();
610        Self::from_vec(data, &self.shape).expect("Deep clone should never fail")
611    }
612}
613
614// =============================================================================
615// Numeric Operations
616// =============================================================================
617
618impl<T: Numeric> Tensor<T> {
619    /// Fills the tensor with a value.
620    pub fn fill_(&self, value: T) {
621        let mut data = self.storage.as_slice_mut();
622        CpuBackend::fill(&mut data, value);
623    }
624
625    /// Fills the tensor with zeros.
626    pub fn zero_(&self) {
627        self.fill_(T::zero());
628    }
629
630    // =========================================================================
631    // Reduction Operations
632    // =========================================================================
633
634    /// Returns the sum of all elements.
635    #[must_use]
636    pub fn sum(&self) -> Self {
637        let data = self.to_vec();
638        let result = CpuBackend::sum(&data);
639        Self::scalar(result)
640    }
641
642    /// Returns the product of all elements.
643    #[must_use]
644    pub fn prod(&self) -> Self {
645        let data = self.to_vec();
646        let result = CpuBackend::prod(&data);
647        Self::scalar(result)
648    }
649
650    /// Returns the maximum element.
651    pub fn max(&self) -> Result<Self> {
652        if self.is_empty() {
653            return Err(Error::EmptyTensor);
654        }
655        let data = self.to_vec();
656        let result = CpuBackend::max(&data).unwrap();
657        Ok(Self::scalar(result))
658    }
659
660    /// Returns the minimum element.
661    pub fn min(&self) -> Result<Self> {
662        if self.is_empty() {
663            return Err(Error::EmptyTensor);
664        }
665        let data = self.to_vec();
666        let result = CpuBackend::min(&data).unwrap();
667        Ok(Self::scalar(result))
668    }
669
670    /// Returns the index of the maximum element.
671    pub fn argmax(&self) -> Result<usize> {
672        if self.is_empty() {
673            return Err(Error::EmptyTensor);
674        }
675        let data = self.to_vec();
676        Ok(CpuBackend::argmax(&data).unwrap())
677    }
678
679    /// Returns the index of the minimum element.
680    pub fn argmin(&self) -> Result<usize> {
681        if self.is_empty() {
682            return Err(Error::EmptyTensor);
683        }
684        let data = self.to_vec();
685        Ok(CpuBackend::argmin(&data).unwrap())
686    }
687
    /// Concatenates tensors along a dimension.
    ///
    /// All tensors must have the same shape except along the cat dimension.
    ///
    /// # Errors
    /// Fails when `tensors` is empty, `dim` is out of range, ranks differ,
    /// or the shapes disagree on any non-cat dimension.
    pub fn cat(tensors: &[&Self], dim: usize) -> Result<Self> {
        if tensors.is_empty() {
            return Err(Error::invalid_operation("cat requires at least one tensor"));
        }
        let ndim = tensors[0].ndim();
        if dim >= ndim {
            return Err(Error::invalid_operation("cat dimension out of range"));
        }

        // Validate every tensor against the first: same rank, identical
        // sizes on every dimension except `dim`.
        for t in &tensors[1..] {
            if t.ndim() != ndim {
                return Err(Error::invalid_operation("cat: all tensors must have same ndim"));
            }
            for d in 0..ndim {
                if d != dim && t.shape[d] != tensors[0].shape[d] {
                    return Err(Error::invalid_operation("cat: shapes must match on non-cat dims"));
                }
            }
        }

        // Output shape matches the inputs, with the cat dimension summed.
        let total_dim_size: usize = tensors.iter().map(|t| t.shape[dim]).sum();
        let mut out_shape: Vec<usize> = tensors[0].shape.to_vec();
        out_shape[dim] = total_dim_size;

        // View everything as [outer, dim, inner] for block-wise copying.
        let outer_size: usize = out_shape[..dim].iter().product();
        let inner_size: usize = out_shape[dim + 1..].iter().product();
        let total_numel: usize = out_shape.iter().product();
        let mut result = vec![T::zero(); total_numel];

        // `dim_offset` is where the current tensor's slice begins in the
        // output's cat dimension.
        let mut dim_offset = 0;
        for t in tensors {
            let t_data = t.contiguous().to_vec();
            let t_dim_size = t.shape[dim];
            for outer in 0..outer_size {
                for d in 0..t_dim_size {
                    for inner in 0..inner_size {
                        let src_idx = outer * t_dim_size * inner_size + d * inner_size + inner;
                        let dst_idx = outer * total_dim_size * inner_size
                            + (dim_offset + d) * inner_size
                            + inner;
                        result[dst_idx] = t_data[src_idx];
                    }
                }
            }
            dim_offset += t_dim_size;
        }

        Self::from_vec(result, &out_shape)
    }
740}
741
742// =============================================================================
743// Float Operations
744// =============================================================================
745
746impl<T: Float> Tensor<T> {
747    /// Returns the mean of all elements.
748    pub fn mean(&self) -> Result<Self> {
749        if self.is_empty() {
750            return Err(Error::EmptyTensor);
751        }
752        let data = self.to_vec();
753        let result = CpuBackend::mean(&data).unwrap();
754        Ok(Self::scalar(result))
755    }
756
757    // =========================================================================
758    // Activation Functions
759    // =========================================================================
760
761    /// Applies `ReLU` activation: max(0, x).
762    #[must_use]
763    pub fn relu(&self) -> Self {
764        let data = self.to_vec();
765        let mut result = vec![T::zero(); data.len()];
766        CpuBackend::relu(&mut result, &data);
767        Self::from_vec(result, &self.shape).unwrap()
768    }
769
770    /// Applies sigmoid activation: 1 / (1 + exp(-x)).
771    #[must_use]
772    pub fn sigmoid(&self) -> Self {
773        let data = self.to_vec();
774        let mut result = vec![T::zero(); data.len()];
775        CpuBackend::sigmoid(&mut result, &data);
776        Self::from_vec(result, &self.shape).unwrap()
777    }
778
779    /// Applies tanh activation.
780    #[must_use]
781    pub fn tanh(&self) -> Self {
782        let data = self.to_vec();
783        let mut result = vec![T::zero(); data.len()];
784        CpuBackend::tanh(&mut result, &data);
785        Self::from_vec(result, &self.shape).unwrap()
786    }
787
788    /// Applies exponential function.
789    #[must_use]
790    pub fn exp(&self) -> Self {
791        let data = self.to_vec();
792        let mut result = vec![T::zero(); data.len()];
793        CpuBackend::exp(&mut result, &data);
794        Self::from_vec(result, &self.shape).unwrap()
795    }
796
797    /// Applies natural logarithm.
798    #[must_use]
799    pub fn ln(&self) -> Self {
800        let data = self.to_vec();
801        let mut result = vec![T::zero(); data.len()];
802        CpuBackend::ln(&mut result, &data);
803        Self::from_vec(result, &self.shape).unwrap()
804    }
805
806    /// Applies square root.
807    #[must_use]
808    pub fn sqrt(&self) -> Self {
809        let data = self.to_vec();
810        let mut result = vec![T::zero(); data.len()];
811        CpuBackend::sqrt(&mut result, &data);
812        Self::from_vec(result, &self.shape).unwrap()
813    }
814
815    /// Computes element-wise power.
816    #[must_use]
817    pub fn pow(&self, exp: T) -> Self {
818        let data = self.to_vec();
819        let result: Vec<T> = data.iter().map(|&x| x.pow_value(exp)).collect();
820        Self::from_vec(result, &self.shape).unwrap()
821    }
822
    /// GELU activation function (Gaussian Error Linear Unit).
    ///
    /// Delegates to [`crate::ops::gelu`].
    #[must_use]
    pub fn gelu(&self) -> Self {
        crate::ops::gelu(self)
    }

    /// SiLU/Swish activation function.
    ///
    /// Delegates to [`crate::ops::silu`].
    #[must_use]
    pub fn silu(&self) -> Self {
        crate::ops::silu(self)
    }

    /// Softmax along specified dimension.
    ///
    /// NOTE(review): errors from `crate::ops::softmax` (e.g. an invalid `dim`)
    /// are silently swallowed and an *unmodified* clone of the input is
    /// returned instead — confirm callers expect this fallback.
    #[must_use]
    pub fn softmax(&self, dim: i32) -> Self {
        crate::ops::softmax(self, dim as i64).unwrap_or_else(|_| self.clone())
    }
840
    /// Log softmax along specified dimension.
    ///
    /// Computed literally as `ln(softmax(x))`.
    ///
    /// NOTE(review): this formulation can underflow to `-inf` when a softmax
    /// output rounds to zero; a fused, max-shifted log-softmax would be more
    /// numerically stable — confirm before relying on extreme logits.
    #[must_use]
    pub fn log_softmax(&self, dim: i32) -> Self {
        let softmax_result = self.softmax(dim);
        softmax_result.ln()
    }
847
848    /// Mean along a dimension.
849    #[must_use]
850    pub fn mean_dim(&self, dim: i32, keepdim: bool) -> Self {
851        let ndim = self.ndim();
852        let dim = if dim < 0 {
853            (ndim as i32 + dim) as usize
854        } else {
855            dim as usize
856        };
857
858        if dim >= ndim {
859            return self.clone();
860        }
861
862        let dim_size = self.shape[dim];
863        let data = self.to_vec();
864        let mut new_shape = self.shape.clone();
865
866        if keepdim {
867            new_shape[dim] = 1;
868        } else {
869            new_shape.remove(dim);
870        }
871
872        if new_shape.is_empty() {
873            new_shape = smallvec::smallvec![1];
874        }
875
876        let new_numel: usize = new_shape.iter().product();
877        let mut result = vec![T::zero(); new_numel];
878
879        // Compute strides for iteration
880        let outer_size: usize = self.shape[..dim].iter().product();
881        let inner_size: usize = self.shape[dim + 1..].iter().product();
882
883        for outer in 0..outer_size {
884            for inner in 0..inner_size {
885                let mut sum = T::zero();
886                for d in 0..dim_size {
887                    let idx = outer * dim_size * inner_size + d * inner_size + inner;
888                    sum = sum + data[idx];
889                }
890                let mean = sum / NumCast::from(dim_size).unwrap();
891                let result_idx = outer * inner_size + inner;
892                result[result_idx] = mean;
893            }
894        }
895
896        Self::from_vec(result, &new_shape).unwrap()
897    }
898
899    /// Sum along a dimension.
900    #[must_use]
901    pub fn sum_dim(&self, dim: i32, keepdim: bool) -> Self {
902        let ndim = self.ndim();
903        let dim = if dim < 0 {
904            (ndim as i32 + dim) as usize
905        } else {
906            dim as usize
907        };
908
909        if dim >= ndim {
910            return self.clone();
911        }
912
913        let dim_size = self.shape[dim];
914        let data = self.to_vec();
915        let mut new_shape = self.shape.clone();
916
917        if keepdim {
918            new_shape[dim] = 1;
919        } else {
920            new_shape.remove(dim);
921        }
922
923        if new_shape.is_empty() {
924            new_shape = smallvec::smallvec![1];
925        }
926
927        let new_numel: usize = new_shape.iter().product();
928        let mut result = vec![T::zero(); new_numel];
929
930        let outer_size: usize = self.shape[..dim].iter().product();
931        let inner_size: usize = self.shape[dim + 1..].iter().product();
932
933        for outer in 0..outer_size {
934            for inner in 0..inner_size {
935                let mut sum = T::zero();
936                for d in 0..dim_size {
937                    let idx = outer * dim_size * inner_size + d * inner_size + inner;
938                    sum = sum + data[idx];
939                }
940                let result_idx = outer * inner_size + inner;
941                result[result_idx] = sum;
942            }
943        }
944
945        Self::from_vec(result, &new_shape).unwrap()
946    }
947
948    /// Variance along a dimension.
949    #[must_use]
950    pub fn var_dim(&self, dim: i32, keepdim: bool) -> Self {
951        let mean = self.mean_dim(dim, true);
952        let diff = self.sub(&mean).unwrap_or_else(|_| self.clone());
953        let squared = diff.mul(&diff).unwrap_or_else(|_| self.clone());
954        squared.mean_dim(dim, keepdim)
955    }
956
957    /// Broadcasts tensor to a new shape.
958    #[must_use]
959    pub fn broadcast_to(&self, shape: &[usize]) -> Self {
960        if self.shape.as_slice() == shape {
961            return self.clone();
962        }
963
964        let result_shape = broadcast_shape(&self.shape, shape).unwrap_or_else(|_| shape.into());
965        let self_strides = broadcast_strides(&self.shape, &self.strides, &result_shape);
966
967        let total = numel(&result_shape);
968        let mut result_data = vec![T::zero(); total];
969        let self_data = self.storage.as_slice();
970
971        for i in 0..total {
972            let indices = crate::shape::unravel_index(i, &result_shape);
973            let self_idx = self.offset + linear_index(&indices, &self_strides);
974            result_data[i] = self_data[self_idx];
975        }
976
977        Self::from_vec(result_data, &result_shape).unwrap()
978    }
979
980    /// Slices the tensor using ranges for each dimension.
981    #[must_use]
982    pub fn slice(&self, ranges: &[std::ops::Range<usize>]) -> Self {
983        let mut new_shape = Vec::with_capacity(self.ndim());
984        for (i, range) in ranges.iter().enumerate() {
985            if i < self.ndim() {
986                new_shape.push(range.end - range.start);
987            }
988        }
989        // Keep remaining dimensions unchanged
990        for i in ranges.len()..self.ndim() {
991            new_shape.push(self.shape[i]);
992        }
993
994        let new_numel: usize = new_shape.iter().product();
995        let mut result_data = vec![T::zero(); new_numel];
996        let self_data = self.to_vec();
997
998        // Copy data with proper indexing
999        let mut result_idx = 0;
1000        Self::slice_recursive(
1001            &self_data,
1002            &self.shape,
1003            ranges,
1004            0,
1005            0,
1006            &mut result_data,
1007            &mut result_idx,
1008        );
1009
1010        Self::from_vec(result_data, &new_shape).unwrap()
1011    }
1012
1013    fn slice_recursive(
1014        data: &[T],
1015        shape: &[usize],
1016        ranges: &[std::ops::Range<usize>],
1017        dim: usize,
1018        offset: usize,
1019        result: &mut [T],
1020        result_idx: &mut usize,
1021    ) {
1022        if dim == shape.len() {
1023            result[*result_idx] = data[offset];
1024            *result_idx += 1;
1025            return;
1026        }
1027
1028        let stride: usize = shape[dim + 1..].iter().product();
1029        let (start, end) = if dim < ranges.len() {
1030            (ranges[dim].start, ranges[dim].end)
1031        } else {
1032            (0, shape[dim])
1033        };
1034
1035        for i in start..end {
1036            Self::slice_recursive(
1037                data,
1038                shape,
1039                ranges,
1040                dim + 1,
1041                offset + i * stride,
1042                result,
1043                result_idx,
1044            );
1045        }
1046    }
1047}
1048
1049// =============================================================================
1050// Arithmetic Operator Implementations
1051// =============================================================================
1052
1053impl<T: Numeric> Tensor<T> {
1054    /// Element-wise addition with broadcasting.
1055    pub fn add(&self, other: &Self) -> Result<Self> {
1056        let result_shape = broadcast_shape(&self.shape, &other.shape)?;
1057        let self_strides = broadcast_strides(&self.shape, &self.strides, &result_shape);
1058        let other_strides = broadcast_strides(&other.shape, &other.strides, &result_shape);
1059
1060        let total = numel(&result_shape);
1061        let mut result_data = vec![T::zero(); total];
1062
1063        let self_data = self.storage.as_slice();
1064        let other_data = other.storage.as_slice();
1065
1066        for i in 0..total {
1067            let indices = crate::shape::unravel_index(i, &result_shape);
1068            let self_idx = self.offset + linear_index(&indices, &self_strides);
1069            let other_idx = other.offset + linear_index(&indices, &other_strides);
1070            result_data[i] = self_data[self_idx] + other_data[other_idx];
1071        }
1072
1073        Self::from_vec(result_data, &result_shape)
1074    }
1075
1076    /// Element-wise subtraction with broadcasting.
1077    pub fn sub(&self, other: &Self) -> Result<Self> {
1078        let result_shape = broadcast_shape(&self.shape, &other.shape)?;
1079        let self_strides = broadcast_strides(&self.shape, &self.strides, &result_shape);
1080        let other_strides = broadcast_strides(&other.shape, &other.strides, &result_shape);
1081
1082        let total = numel(&result_shape);
1083        let mut result_data = vec![T::zero(); total];
1084
1085        let self_data = self.storage.as_slice();
1086        let other_data = other.storage.as_slice();
1087
1088        for i in 0..total {
1089            let indices = crate::shape::unravel_index(i, &result_shape);
1090            let self_idx = self.offset + linear_index(&indices, &self_strides);
1091            let other_idx = other.offset + linear_index(&indices, &other_strides);
1092            result_data[i] = self_data[self_idx] - other_data[other_idx];
1093        }
1094
1095        Self::from_vec(result_data, &result_shape)
1096    }
1097
1098    /// Element-wise multiplication with broadcasting.
1099    pub fn mul(&self, other: &Self) -> Result<Self> {
1100        let result_shape = broadcast_shape(&self.shape, &other.shape)?;
1101        let self_strides = broadcast_strides(&self.shape, &self.strides, &result_shape);
1102        let other_strides = broadcast_strides(&other.shape, &other.strides, &result_shape);
1103
1104        let total = numel(&result_shape);
1105        let mut result_data = vec![T::zero(); total];
1106
1107        let self_data = self.storage.as_slice();
1108        let other_data = other.storage.as_slice();
1109
1110        for i in 0..total {
1111            let indices = crate::shape::unravel_index(i, &result_shape);
1112            let self_idx = self.offset + linear_index(&indices, &self_strides);
1113            let other_idx = other.offset + linear_index(&indices, &other_strides);
1114            result_data[i] = self_data[self_idx] * other_data[other_idx];
1115        }
1116
1117        Self::from_vec(result_data, &result_shape)
1118    }
1119
1120    /// Element-wise division with broadcasting.
1121    pub fn div(&self, other: &Self) -> Result<Self> {
1122        let result_shape = broadcast_shape(&self.shape, &other.shape)?;
1123        let self_strides = broadcast_strides(&self.shape, &self.strides, &result_shape);
1124        let other_strides = broadcast_strides(&other.shape, &other.strides, &result_shape);
1125
1126        let total = numel(&result_shape);
1127        let mut result_data = vec![T::zero(); total];
1128
1129        let self_data = self.storage.as_slice();
1130        let other_data = other.storage.as_slice();
1131
1132        for i in 0..total {
1133            let indices = crate::shape::unravel_index(i, &result_shape);
1134            let self_idx = self.offset + linear_index(&indices, &self_strides);
1135            let other_idx = other.offset + linear_index(&indices, &other_strides);
1136            result_data[i] = self_data[self_idx] / other_data[other_idx];
1137        }
1138
1139        Self::from_vec(result_data, &result_shape)
1140    }
1141
1142    /// Scalar addition.
1143    #[must_use]
1144    pub fn add_scalar(&self, scalar: T) -> Self {
1145        let data = self.to_vec();
1146        let mut result = vec![T::zero(); data.len()];
1147        CpuBackend::add_scalar(&mut result, &data, scalar);
1148        Self::from_vec(result, &self.shape).unwrap()
1149    }
1150
1151    /// Scalar multiplication.
1152    #[must_use]
1153    pub fn mul_scalar(&self, scalar: T) -> Self {
1154        let data = self.to_vec();
1155        let mut result = vec![T::zero(); data.len()];
1156        CpuBackend::mul_scalar(&mut result, &data, scalar);
1157        Self::from_vec(result, &self.shape).unwrap()
1158    }
1159
1160    /// Element-wise negation.
1161    #[must_use]
1162    pub fn neg(&self) -> Self {
1163        let data = self.to_vec();
1164        let mut result = vec![T::zero(); data.len()];
1165        CpuBackend::neg(&mut result, &data);
1166        Self::from_vec(result, &self.shape).unwrap()
1167    }
1168
1169    /// Matrix multiplication with batching support.
1170    ///
1171    /// Supports:
1172    /// - 2D @ 2D: [m, k] @ [k, n] -> [m, n]
1173    /// - 3D @ 3D: [batch, m, k] @ [batch, k, n] -> [batch, m, n]
1174    /// - 4D @ 4D: [b1, b2, m, k] @ [b1, b2, k, n] -> [b1, b2, m, n]
1175    pub fn matmul(&self, other: &Self) -> Result<Self> {
1176        if self.ndim() < 2 || other.ndim() < 2 {
1177            return Err(Error::invalid_operation(
1178                "matmul requires at least 2D tensors",
1179            ));
1180        }
1181
1182        let m = self.shape[self.ndim() - 2];
1183        let k1 = self.shape[self.ndim() - 1];
1184        let k2 = other.shape[other.ndim() - 2];
1185        let n = other.shape[other.ndim() - 1];
1186
1187        if k1 != k2 {
1188            return Err(Error::invalid_operation(format!(
1189                "matmul inner dimensions must match: {k1} vs {k2}"
1190            )));
1191        }
1192
1193        // For 2D matrices, do simple matmul
1194        if self.ndim() == 2 && other.ndim() == 2 {
1195            let a_data = self.contiguous().to_vec();
1196            let b_data = other.contiguous().to_vec();
1197
1198            // Try GPU acceleration for f32 tensors
1199            #[cfg(feature = "cuda")]
1200            {
1201                if std::any::TypeId::of::<T>() == std::any::TypeId::of::<f32>() {
1202                    // Safety: we verified T == f32
1203                    let a_f32: &[f32] = unsafe { std::mem::transmute(a_data.as_slice()) };
1204                    let b_f32: &[f32] = unsafe { std::mem::transmute(b_data.as_slice()) };
1205                    if let Some(c_f32) = cuda_accel::cuda_matmul(a_f32, b_f32, m, n, k1) {
1206                        let c_t: Vec<T> = unsafe {
1207                            let mut v = std::mem::ManuallyDrop::new(c_f32);
1208                            Vec::from_raw_parts(v.as_mut_ptr() as *mut T, v.len(), v.capacity())
1209                        };
1210                        return Self::from_vec(c_t, &[m, n]);
1211                    }
1212                }
1213            }
1214
1215            let mut c_data = vec![T::zero(); m * n];
1216            CpuBackend::matmul(&mut c_data, &a_data, &b_data, m, n, k1);
1217            return Self::from_vec(c_data, &[m, n]);
1218        }
1219
1220        // For batched matmul, compute batch size
1221        let batch_dims_self: Vec<usize> = self.shape[..self.ndim() - 2].to_vec();
1222        let batch_dims_other: Vec<usize> = other.shape[..other.ndim() - 2].to_vec();
1223
1224        if batch_dims_self != batch_dims_other {
1225            return Err(Error::invalid_operation(format!(
1226                "matmul batch dimensions must match: {:?} vs {:?}",
1227                batch_dims_self, batch_dims_other
1228            )));
1229        }
1230
1231        let batch_size: usize = batch_dims_self.iter().product();
1232        let a_stride = m * k1;
1233        let b_stride = k1 * n;
1234        let c_stride = m * n;
1235
1236        let a_data = self.contiguous().to_vec();
1237        let b_data = other.contiguous().to_vec();
1238        let mut c_data = vec![T::zero(); batch_size * m * n];
1239
1240        // Try GPU acceleration for f32 batched matmul
1241        #[cfg(feature = "cuda")]
1242        {
1243            if std::any::TypeId::of::<T>() == std::any::TypeId::of::<f32>() {
1244                let a_f32: &[f32] = unsafe { std::mem::transmute(a_data.as_slice()) };
1245                let b_f32: &[f32] = unsafe { std::mem::transmute(b_data.as_slice()) };
1246                let mut gpu_ok = true;
1247                for batch in 0..batch_size {
1248                    let a_slice = &a_f32[batch * a_stride..(batch + 1) * a_stride];
1249                    let b_slice = &b_f32[batch * b_stride..(batch + 1) * b_stride];
1250                    if let Some(c_batch) = cuda_accel::cuda_matmul(a_slice, b_slice, m, n, k1) {
1251                        c_data[batch * c_stride..(batch + 1) * c_stride]
1252                            .copy_from_slice(unsafe { std::mem::transmute(c_batch.as_slice()) });
1253                    } else {
1254                        gpu_ok = false;
1255                        break;
1256                    }
1257                }
1258                if gpu_ok {
1259                    let mut output_shape = batch_dims_self;
1260                    output_shape.push(m);
1261                    output_shape.push(n);
1262                    return Self::from_vec(c_data, &output_shape);
1263                }
1264                // Fall through to CPU if GPU failed
1265                c_data = vec![T::zero(); batch_size * m * n];
1266            }
1267        }
1268
1269        // CPU fallback: loop over batches
1270        for batch in 0..batch_size {
1271            let a_slice = &a_data[batch * a_stride..(batch + 1) * a_stride];
1272            let b_slice = &b_data[batch * b_stride..(batch + 1) * b_stride];
1273            let c_slice = &mut c_data[batch * c_stride..(batch + 1) * c_stride];
1274            CpuBackend::matmul(c_slice, a_slice, b_slice, m, n, k1);
1275        }
1276
1277        // Build output shape: batch_dims + [m, n]
1278        let mut output_shape = batch_dims_self;
1279        output_shape.push(m);
1280        output_shape.push(n);
1281
1282        Self::from_vec(c_data, &output_shape)
1283    }
1284
1285    /// Dot product for 1D tensors.
1286    pub fn dot(&self, other: &Self) -> Result<Self> {
1287        if self.ndim() != 1 || other.ndim() != 1 {
1288            return Err(Error::invalid_operation("dot requires 1D tensors"));
1289        }
1290
1291        if self.shape[0] != other.shape[0] {
1292            return Err(Error::shape_mismatch(&self.shape, &other.shape));
1293        }
1294
1295        let a_data = self.to_vec();
1296        let b_data = other.to_vec();
1297        let result = CpuBackend::dot(&a_data, &b_data);
1298
1299        Ok(Self::scalar(result))
1300    }
1301}
1302
1303// =============================================================================
1304// Operator Trait Implementations
1305// =============================================================================
1306
1307impl<T: Numeric> Add for &Tensor<T> {
1308    type Output = Tensor<T>;
1309
1310    fn add(self, other: Self) -> Self::Output {
1311        self.add(other).expect("Addition failed")
1312    }
1313}
1314
1315impl<T: Numeric> Sub for &Tensor<T> {
1316    type Output = Tensor<T>;
1317
1318    fn sub(self, other: Self) -> Self::Output {
1319        self.sub(other).expect("Subtraction failed")
1320    }
1321}
1322
1323impl<T: Numeric> Mul for &Tensor<T> {
1324    type Output = Tensor<T>;
1325
1326    fn mul(self, other: Self) -> Self::Output {
1327        self.mul(other).expect("Multiplication failed")
1328    }
1329}
1330
1331impl<T: Numeric> Div for &Tensor<T> {
1332    type Output = Tensor<T>;
1333
1334    fn div(self, other: Self) -> Self::Output {
1335        self.div(other).expect("Division failed")
1336    }
1337}
1338
1339impl<T: Numeric> Neg for &Tensor<T> {
1340    type Output = Tensor<T>;
1341
1342    fn neg(self) -> Self::Output {
1343        self.neg()
1344    }
1345}
1346
1347// Scalar operations
1348impl<T: Numeric> Add<T> for &Tensor<T> {
1349    type Output = Tensor<T>;
1350
1351    fn add(self, scalar: T) -> Self::Output {
1352        self.add_scalar(scalar)
1353    }
1354}
1355
1356impl<T: Numeric> Mul<T> for &Tensor<T> {
1357    type Output = Tensor<T>;
1358
1359    fn mul(self, scalar: T) -> Self::Output {
1360        self.mul_scalar(scalar)
1361    }
1362}
1363
1364// =============================================================================
1365// Display Implementation
1366// =============================================================================
1367
1368impl<T: Scalar + fmt::Display> fmt::Debug for Tensor<T> {
1369    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
1370        write!(
1371            f,
1372            "Tensor(shape={:?}, device={}",
1373            self.shape(),
1374            self.device()
1375        )?;
1376        if self.numel() <= 10 {
1377            write!(f, ", data={:?}", self.to_vec())?;
1378        }
1379        write!(f, ")")
1380    }
1381}
1382
1383impl<T: Scalar + fmt::Display> fmt::Display for Tensor<T> {
1384    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
1385        if self.is_scalar() {
1386            write!(f, "{}", self.item().unwrap())
1387        } else if self.ndim() == 1 {
1388            write!(f, "[")?;
1389            let data = self.to_vec();
1390            for (i, val) in data.iter().enumerate() {
1391                if i > 0 {
1392                    write!(f, ", ")?;
1393                }
1394                write!(f, "{val}")?;
1395            }
1396            write!(f, "]")
1397        } else {
1398            write!(f, "Tensor(shape={:?})", self.shape())
1399        }
1400    }
1401}
1402
1403// =============================================================================
1404// Tests
1405// =============================================================================
1406
#[cfg(test)]
mod tests {
    use super::*;

    // Construction from a flat vec plus shape metadata.
    #[test]
    fn test_from_vec() {
        let t = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3]).unwrap();
        assert_eq!(t.shape(), &[2, 3]);
        assert_eq!(t.numel(), 6);
    }

    // Indexed reads and writes in row-major order.
    #[test]
    fn test_get_set() {
        let t = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[2, 2]).unwrap();
        assert_eq!(t.get(&[0, 0]).unwrap(), 1.0);
        assert_eq!(t.get(&[0, 1]).unwrap(), 2.0);
        assert_eq!(t.get(&[1, 0]).unwrap(), 3.0);
        assert_eq!(t.get(&[1, 1]).unwrap(), 4.0);

        // NOTE(review): `set` is called on a non-`mut` binding, so it
        // presumably writes through shared interior storage — confirm.
        t.set(&[0, 0], 99.0).unwrap();
        assert_eq!(t.get(&[0, 0]).unwrap(), 99.0);
    }

    // Reshape to an explicit shape and via -1 (infer the dimension).
    #[test]
    fn test_reshape() {
        let t = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3]).unwrap();
        let r = t.reshape(&[3, 2]).unwrap();
        assert_eq!(r.shape(), &[3, 2]);

        let r = t.reshape(&[-1]).unwrap();
        assert_eq!(r.shape(), &[6]);
    }

    // 2D transpose swaps the axes and the element layout.
    #[test]
    fn test_transpose() {
        let t = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3]).unwrap();
        let r = t.t().unwrap();
        assert_eq!(r.shape(), &[3, 2]);
        assert_eq!(r.get(&[0, 0]).unwrap(), 1.0);
        assert_eq!(r.get(&[0, 1]).unwrap(), 4.0);
        assert_eq!(r.get(&[1, 0]).unwrap(), 2.0);
    }

    // Element-wise + and * through the operator trait impls.
    #[test]
    fn test_arithmetic() {
        let a = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap();
        let b = Tensor::<f32>::from_vec(vec![4.0, 5.0, 6.0], &[3]).unwrap();

        let c = &a + &b;
        assert_eq!(c.to_vec(), vec![5.0, 7.0, 9.0]);

        let d = &a * &b;
        assert_eq!(d.to_vec(), vec![4.0, 10.0, 18.0]);
    }

    // A [1]-shaped operand broadcasts against a [3]-shaped one.
    #[test]
    fn test_broadcasting() {
        let a = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap();
        let b = Tensor::<f32>::from_vec(vec![10.0], &[1]).unwrap();

        let c = &a + &b;
        assert_eq!(c.to_vec(), vec![11.0, 12.0, 13.0]);
    }

    // Full reduction to a scalar tensor.
    #[test]
    fn test_sum() {
        let t = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[4]).unwrap();
        let s = t.sum();
        assert_eq!(s.item().unwrap(), 10.0);
    }

    // 2x2 matrix product against hand-computed values.
    #[test]
    fn test_matmul() {
        // 2x2 @ 2x2
        let a = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[2, 2]).unwrap();
        let b = Tensor::<f32>::from_vec(vec![5.0, 6.0, 7.0, 8.0], &[2, 2]).unwrap();
        let c = a.matmul(&b).unwrap();

        assert_eq!(c.shape(), &[2, 2]);
        assert_eq!(c.to_vec(), vec![19.0, 22.0, 43.0, 50.0]);
    }

    // ReLU clamps negatives to zero and passes non-negatives through.
    #[test]
    fn test_relu() {
        let t = Tensor::<f32>::from_vec(vec![-1.0, 0.0, 1.0, 2.0], &[4]).unwrap();
        let r = t.relu();
        assert_eq!(r.to_vec(), vec![0.0, 0.0, 1.0, 2.0]);
    }

    // Scalar tensors report is_scalar, have one element, and yield item().
    #[test]
    fn test_scalar() {
        let s = Tensor::<f32>::scalar(42.0);
        assert!(s.is_scalar());
        assert_eq!(s.numel(), 1);
        assert_eq!(s.item().unwrap(), 42.0);
    }
}