axonml_tensor/
tensor.rs

1//! Tensor - Core N-Dimensional Array Type
2//!
3//! The `Tensor` struct is the fundamental data structure in Axonml. It represents
4//! an N-dimensional array of numeric values with support for automatic broadcasting,
5//! device placement, and efficient memory sharing through views.
6//!
7//! # Key Features
8//! - Generic over element type (f32, f64, i32, etc.)
9//! - Efficient views with shared storage
10//! - Device-agnostic operations
11//! - Broadcasting support
12//! - Lazy operations where possible
13//!
14//! @version 0.1.0
15//! @author `AutomataNexus` Development Team
16
17use core::fmt;
18use core::ops::{Add, Div, Mul, Neg, Sub};
19
20use axonml_core::backends::CpuBackend;
21use axonml_core::dtype::{Float, Numeric, Scalar};
22use num_traits::NumCast;
23use axonml_core::error::{Error, Result};
24use axonml_core::storage::Storage;
25use axonml_core::Device;
26
27use crate::shape::{
28    broadcast_shape, broadcast_strides, contiguous_strides, is_contiguous, linear_index,
29    normalize_dim, numel, reshape, squeeze, transpose_shape, transpose_strides, unsqueeze, Shape,
30    Strides,
31};
32
33// =============================================================================
34// Tensor Struct
35// =============================================================================
36
/// An N-dimensional array of numeric values.
///
/// Tensors are the core data structure for all computations in Axonml.
/// They support arbitrary dimensions, automatic broadcasting, and efficient
/// memory sharing between views.
///
/// A tensor is a view over `storage`: the element at multi-index `idx` is
/// read from `storage[offset + linear_index(idx, strides)]`. Operations such
/// as `reshape`, `transpose`, `permute`, `squeeze` and `unsqueeze` build new
/// tensors that reuse a clone of the same `storage` handle with different
/// `shape` / `strides` / `offset`. The derived `Clone` likewise clones the
/// handle (described below as reference-counted, so the data is presumably
/// shared); use `clone_deep` for an independent copy.
#[derive(Clone)]
pub struct Tensor<T: Scalar> {
    /// Underlying data storage (reference-counted).
    pub(crate) storage: Storage<T>,
    /// Shape of the tensor (dimensions).
    pub(crate) shape: Shape,
    /// Strides for each dimension, measured in elements of `storage`.
    /// They need not be contiguous (e.g. after `transpose` or `permute`).
    pub(crate) strides: Strides,
    /// Offset into storage (for views): index of the element at all-zero indices.
    pub(crate) offset: usize,
}
53
54impl<T: Scalar> Tensor<T> {
55    // =========================================================================
56    // Constructors
57    // =========================================================================
58
59    /// Creates a new tensor from storage with the given shape.
60    ///
61    /// # Arguments
62    /// * `storage` - The underlying data storage
63    /// * `shape` - Shape of the tensor
64    ///
65    /// # Returns
66    /// New tensor, or error if shape doesn't match storage size.
67    pub fn from_storage(storage: Storage<T>, shape: &[usize]) -> Result<Self> {
68        let total = numel(shape);
69        if total != storage.len() {
70            return Err(Error::shape_mismatch(&[storage.len()], shape));
71        }
72
73        let shape = Shape::from_slice(shape);
74        let strides = contiguous_strides(&shape);
75
76        Ok(Self {
77            storage,
78            shape,
79            strides,
80            offset: 0,
81        })
82    }
83
84    /// Creates a new tensor from a vector with the given shape.
85    ///
86    /// # Arguments
87    /// * `data` - Vector of data
88    /// * `shape` - Shape of the tensor
89    ///
90    /// # Returns
91    /// New tensor, or error if shape doesn't match data length.
92    pub fn from_vec(data: Vec<T>, shape: &[usize]) -> Result<Self> {
93        let storage = Storage::from_vec(data, Device::Cpu);
94        Self::from_storage(storage, shape)
95    }
96
97    /// Creates a new tensor from a slice with the given shape.
98    ///
99    /// # Arguments
100    /// * `data` - Slice of data to copy
101    /// * `shape` - Shape of the tensor
102    ///
103    /// # Returns
104    /// New tensor, or error if shape doesn't match data length.
105    pub fn from_slice(data: &[T], shape: &[usize]) -> Result<Self> {
106        let storage = Storage::from_slice(data, Device::Cpu);
107        Self::from_storage(storage, shape)
108    }
109
110    /// Creates a scalar tensor (0-dimensional).
111    ///
112    /// # Arguments
113    /// * `value` - The scalar value
114    ///
115    /// # Returns
116    /// New 0-dimensional tensor.
117    pub fn scalar(value: T) -> Self {
118        Self {
119            storage: Storage::from_vec(vec![value], Device::Cpu),
120            shape: Shape::new(),
121            strides: Strides::new(),
122            offset: 0,
123        }
124    }
125
126    /// Creates a tensor filled with zeros.
127    #[must_use]
128    pub fn zeros(shape: &[usize]) -> Self {
129        crate::creation::zeros(shape)
130    }
131
132    /// Creates a tensor filled with ones.
133    #[must_use]
134    pub fn ones(shape: &[usize]) -> Self
135    where
136        T: Numeric,
137    {
138        crate::creation::ones(shape)
139    }
140
141    /// Creates a tensor filled with a constant value.
142    #[must_use]
143    pub fn full(shape: &[usize], value: T) -> Self {
144        crate::creation::full(shape, value)
145    }
146
147    /// Creates a tensor with random values from standard normal distribution.
148    #[must_use]
149    pub fn randn(shape: &[usize]) -> Self
150    where
151        T: Float,
152        rand_distr::StandardNormal: rand::distributions::Distribution<T>,
153    {
154        crate::creation::randn(shape)
155    }
156
157    /// Creates a tensor with random values from uniform distribution [0, 1).
158    #[must_use]
159    pub fn rand(shape: &[usize]) -> Self
160    where
161        T: Float,
162        rand::distributions::Standard: rand::distributions::Distribution<T>,
163    {
164        crate::creation::rand(shape)
165    }
166
167    // =========================================================================
168    // Properties
169    // =========================================================================
170
171    /// Returns the shape of the tensor.
172    #[must_use]
173    pub fn shape(&self) -> &[usize] {
174        &self.shape
175    }
176
177    /// Returns the strides of the tensor.
178    #[must_use]
179    pub fn strides(&self) -> &[isize] {
180        &self.strides
181    }
182
183    /// Returns the number of dimensions.
184    #[must_use]
185    pub fn ndim(&self) -> usize {
186        self.shape.len()
187    }
188
189    /// Returns the total number of elements.
190    #[must_use]
191    pub fn numel(&self) -> usize {
192        numel(&self.shape)
193    }
194
195    /// Returns true if the tensor is empty (has zero elements).
196    #[must_use]
197    pub fn is_empty(&self) -> bool {
198        self.numel() == 0
199    }
200
201    /// Returns the size of a specific dimension.
202    ///
203    /// # Arguments
204    /// * `dim` - Dimension index (supports negative indexing)
205    pub fn size(&self, dim: i64) -> Result<usize> {
206        let idx = normalize_dim(dim, self.ndim())?;
207        Ok(self.shape[idx])
208    }
209
210    /// Returns the device this tensor is on.
211    #[must_use]
212    pub fn device(&self) -> Device {
213        self.storage.device()
214    }
215
216    /// Returns true if the tensor is contiguous in memory.
217    #[must_use]
218    pub fn is_contiguous(&self) -> bool {
219        is_contiguous(&self.shape, &self.strides)
220    }
221
222    /// Returns true if this tensor is a scalar (0-dimensional).
223    #[must_use]
224    pub fn is_scalar(&self) -> bool {
225        self.shape.is_empty()
226    }
227
228    // =========================================================================
229    // Data Access
230    // =========================================================================
231
232    /// Returns the element at the given indices.
233    ///
234    /// # Arguments
235    /// * `indices` - Multi-dimensional indices
236    pub fn get(&self, indices: &[usize]) -> Result<T> {
237        if indices.len() != self.ndim() {
238            return Err(Error::invalid_operation(format!(
239                "Expected {} indices, got {}",
240                self.ndim(),
241                indices.len()
242            )));
243        }
244
245        for (&idx, &dim) in indices.iter().zip(self.shape.iter()) {
246            if idx >= dim {
247                return Err(Error::IndexOutOfBounds {
248                    index: idx,
249                    size: dim,
250                });
251            }
252        }
253
254        let offset = self.offset + linear_index(indices, &self.strides);
255        Ok(self.storage.as_slice()[offset])
256    }
257
258    /// Sets the element at the given indices.
259    ///
260    /// # Arguments
261    /// * `indices` - Multi-dimensional indices
262    /// * `value` - Value to set
263    pub fn set(&self, indices: &[usize], value: T) -> Result<()> {
264        if indices.len() != self.ndim() {
265            return Err(Error::invalid_operation(format!(
266                "Expected {} indices, got {}",
267                self.ndim(),
268                indices.len()
269            )));
270        }
271
272        for (&idx, &dim) in indices.iter().zip(self.shape.iter()) {
273            if idx >= dim {
274                return Err(Error::IndexOutOfBounds {
275                    index: idx,
276                    size: dim,
277                });
278            }
279        }
280
281        let offset = self.offset + linear_index(indices, &self.strides);
282        self.storage.as_slice_mut()[offset] = value;
283        Ok(())
284    }
285
286    /// Returns the scalar value for a 0-dimensional tensor.
287    pub fn item(&self) -> Result<T> {
288        if self.numel() != 1 {
289            return Err(Error::invalid_operation(
290                "item() only works on single-element tensors",
291            ));
292        }
293
294        if self.is_scalar() {
295            Ok(self.storage.as_slice()[self.offset])
296        } else {
297            // Single element but not scalar shape
298            let indices = vec![0; self.ndim()];
299            self.get(&indices)
300        }
301    }
302
303    /// Returns the data as a contiguous vector.
304    ///
305    /// If the tensor is already contiguous, this returns a reference.
306    /// Otherwise, it copies the data into a new contiguous vector.
307    #[must_use]
308    pub fn to_vec(&self) -> Vec<T> {
309        if self.is_contiguous() {
310            let storage = self.storage.as_slice();
311            storage[self.offset..self.offset + self.numel()].to_vec()
312        } else {
313            let mut result = Vec::with_capacity(self.numel());
314            self.copy_data_to(&mut result);
315            result
316        }
317    }
318
319    /// Copies data to a slice, handling non-contiguous layouts.
320    fn copy_data_to(&self, dst: &mut Vec<T>) {
321        dst.clear();
322        let storage = self.storage.as_slice();
323
324        // Iterate through all indices
325        let total = self.numel();
326        for i in 0..total {
327            let indices = crate::shape::unravel_index(i, &self.shape);
328            let offset = self.offset + linear_index(&indices, &self.strides);
329            dst.push(storage[offset]);
330        }
331    }
332
333    // =========================================================================
334    // Shape Operations
335    // =========================================================================
336
337    /// Returns a new tensor with the specified shape.
338    ///
339    /// The total number of elements must remain the same.
340    /// Supports -1 in one dimension to infer the size.
341    ///
342    /// # Arguments
343    /// * `new_shape` - Target shape
344    pub fn reshape(&self, new_shape: &[isize]) -> Result<Self> {
345        let shape = reshape(&self.shape, new_shape)?;
346
347        if self.is_contiguous() {
348            // Can just change shape without copying
349            Ok(Self {
350                storage: self.storage.clone(),
351                strides: contiguous_strides(&shape),
352                shape,
353                offset: self.offset,
354            })
355        } else {
356            // Need to make contiguous first
357            let contig = self.contiguous();
358            Ok(Self {
359                storage: contig.storage,
360                strides: contiguous_strides(&shape),
361                shape,
362                offset: 0,
363            })
364        }
365    }
366
367    /// Returns a new tensor with a flattened shape.
368    #[must_use] pub fn flatten(&self) -> Self {
369        self.reshape(&[-1]).expect("Flatten should never fail")
370    }
371
372    /// Returns a new tensor with dimensions of size 1 removed.
373    ///
374    /// # Arguments
375    /// * `dim` - Optional specific dimension to squeeze
376    pub fn squeeze(&self, dim: Option<i64>) -> Result<Self> {
377        let dim = match dim {
378            Some(d) => Some(normalize_dim(d, self.ndim())?),
379            None => None,
380        };
381
382        let new_shape = squeeze(&self.shape, dim);
383        let new_strides: Strides = match dim {
384            Some(d) => {
385                let mut s = self.strides.clone();
386                if d < self.shape.len() && self.shape[d] == 1 {
387                    s.remove(d);
388                }
389                s
390            }
391            None => self
392                .shape
393                .iter()
394                .zip(self.strides.iter())
395                .filter(|(&dim, _)| dim != 1)
396                .map(|(_, &stride)| stride)
397                .collect(),
398        };
399
400        Ok(Self {
401            storage: self.storage.clone(),
402            shape: new_shape,
403            strides: new_strides,
404            offset: self.offset,
405        })
406    }
407
408    /// Returns a new tensor with a dimension of size 1 inserted.
409    ///
410    /// # Arguments
411    /// * `dim` - Position to insert the new dimension
412    pub fn unsqueeze(&self, dim: i64) -> Result<Self> {
413        let normalized = if dim < 0 {
414            (dim + self.ndim() as i64 + 1) as usize
415        } else {
416            dim as usize
417        };
418
419        let new_shape = unsqueeze(&self.shape, normalized)?;
420        let mut new_strides = Strides::with_capacity(new_shape.len());
421
422        for (i, _) in new_shape.iter().enumerate() {
423            if i < normalized {
424                new_strides.push(self.strides.get(i).copied().unwrap_or(1));
425            } else if i == normalized {
426                // Stride for new dimension (doesn't matter since size is 1)
427                new_strides.push(1);
428            } else {
429                new_strides.push(self.strides[i - 1]);
430            }
431        }
432
433        Ok(Self {
434            storage: self.storage.clone(),
435            shape: new_shape,
436            strides: new_strides,
437            offset: self.offset,
438        })
439    }
440
441    /// Transposes two dimensions.
442    ///
443    /// # Arguments
444    /// * `dim0` - First dimension
445    /// * `dim1` - Second dimension
446    pub fn transpose(&self, dim0: i64, dim1: i64) -> Result<Self> {
447        let d0 = normalize_dim(dim0, self.ndim())?;
448        let d1 = normalize_dim(dim1, self.ndim())?;
449
450        let new_shape = transpose_shape(&self.shape, d0, d1)?;
451        let new_strides = transpose_strides(&self.strides, d0, d1);
452
453        Ok(Self {
454            storage: self.storage.clone(),
455            shape: new_shape,
456            strides: new_strides,
457            offset: self.offset,
458        })
459    }
460
461    /// Returns the transpose of a 2D tensor.
462    pub fn t(&self) -> Result<Self> {
463        if self.ndim() != 2 {
464            return Err(Error::invalid_operation("t() only works on 2D tensors"));
465        }
466        self.transpose(0, 1)
467    }
468
469    /// Returns a permuted tensor with dimensions reordered.
470    ///
471    /// # Arguments
472    /// * `dims` - New order of dimensions
473    pub fn permute(&self, dims: &[usize]) -> Result<Self> {
474        if dims.len() != self.ndim() {
475            return Err(Error::invalid_operation(format!(
476                "Expected {} dimensions, got {}",
477                self.ndim(),
478                dims.len()
479            )));
480        }
481
482        // Check that dims is a permutation
483        let mut seen = vec![false; self.ndim()];
484        for &d in dims {
485            if d >= self.ndim() {
486                return Err(Error::InvalidDimension {
487                    index: d as i64,
488                    ndim: self.ndim(),
489                });
490            }
491            if seen[d] {
492                return Err(Error::invalid_operation("Duplicate dimension in permute"));
493            }
494            seen[d] = true;
495        }
496
497        let new_shape: Shape = dims.iter().map(|&d| self.shape[d]).collect();
498        let new_strides: Strides = dims.iter().map(|&d| self.strides[d]).collect();
499
500        Ok(Self {
501            storage: self.storage.clone(),
502            shape: new_shape,
503            strides: new_strides,
504            offset: self.offset,
505        })
506    }
507
508    /// Returns a contiguous copy of the tensor.
509    #[must_use]
510    pub fn contiguous(&self) -> Self {
511        if self.is_contiguous() && self.offset == 0 {
512            return self.clone();
513        }
514
515        let data = self.to_vec();
516        Self::from_vec(data, &self.shape).expect("Contiguous should never fail")
517    }
518
519    // =========================================================================
520    // Device Operations
521    // =========================================================================
522
523    /// Transfers the tensor to a different device.
524    ///
525    /// # Arguments
526    /// * `device` - Target device
527    pub fn to_device(&self, device: Device) -> Result<Self> {
528        if self.device() == device {
529            return Ok(self.clone());
530        }
531
532        let contig = self.contiguous();
533        let new_storage = contig.storage.to_device(device)?;
534
535        Ok(Self {
536            storage: new_storage,
537            shape: self.shape.clone(),
538            strides: self.strides.clone(),
539            offset: 0,
540        })
541    }
542
543    /// Transfers to CPU.
544    pub fn cpu(&self) -> Result<Self> {
545        self.to_device(Device::Cpu)
546    }
547
548    // =========================================================================
549    // Deep Copy
550    // =========================================================================
551
552    /// Creates a deep copy of this tensor with its own storage.
553    #[must_use]
554    pub fn clone_deep(&self) -> Self {
555        let data = self.to_vec();
556        Self::from_vec(data, &self.shape).expect("Deep clone should never fail")
557    }
558}
559
560// =============================================================================
561// Numeric Operations
562// =============================================================================
563
564impl<T: Numeric> Tensor<T> {
565    /// Fills the tensor with a value.
566    pub fn fill_(&self, value: T) {
567        let mut data = self.storage.as_slice_mut();
568        CpuBackend::fill(&mut data, value);
569    }
570
571    /// Fills the tensor with zeros.
572    pub fn zero_(&self) {
573        self.fill_(T::zero());
574    }
575
576    // =========================================================================
577    // Reduction Operations
578    // =========================================================================
579
580    /// Returns the sum of all elements.
581    #[must_use] pub fn sum(&self) -> Self {
582        let data = self.to_vec();
583        let result = CpuBackend::sum(&data);
584        Self::scalar(result)
585    }
586
587    /// Returns the product of all elements.
588    #[must_use] pub fn prod(&self) -> Self {
589        let data = self.to_vec();
590        let result = CpuBackend::prod(&data);
591        Self::scalar(result)
592    }
593
594    /// Returns the maximum element.
595    pub fn max(&self) -> Result<Self> {
596        if self.is_empty() {
597            return Err(Error::EmptyTensor);
598        }
599        let data = self.to_vec();
600        let result = CpuBackend::max(&data).unwrap();
601        Ok(Self::scalar(result))
602    }
603
604    /// Returns the minimum element.
605    pub fn min(&self) -> Result<Self> {
606        if self.is_empty() {
607            return Err(Error::EmptyTensor);
608        }
609        let data = self.to_vec();
610        let result = CpuBackend::min(&data).unwrap();
611        Ok(Self::scalar(result))
612    }
613
614    /// Returns the index of the maximum element.
615    pub fn argmax(&self) -> Result<usize> {
616        if self.is_empty() {
617            return Err(Error::EmptyTensor);
618        }
619        let data = self.to_vec();
620        Ok(CpuBackend::argmax(&data).unwrap())
621    }
622
623    /// Returns the index of the minimum element.
624    pub fn argmin(&self) -> Result<usize> {
625        if self.is_empty() {
626            return Err(Error::EmptyTensor);
627        }
628        let data = self.to_vec();
629        Ok(CpuBackend::argmin(&data).unwrap())
630    }
631}
632
633// =============================================================================
634// Float Operations
635// =============================================================================
636
637impl<T: Float> Tensor<T> {
638    /// Returns the mean of all elements.
639    pub fn mean(&self) -> Result<Self> {
640        if self.is_empty() {
641            return Err(Error::EmptyTensor);
642        }
643        let data = self.to_vec();
644        let result = CpuBackend::mean(&data).unwrap();
645        Ok(Self::scalar(result))
646    }
647
648    // =========================================================================
649    // Activation Functions
650    // =========================================================================
651
652    /// Applies `ReLU` activation: max(0, x).
653    #[must_use]
654    pub fn relu(&self) -> Self {
655        let data = self.to_vec();
656        let mut result = vec![T::zero(); data.len()];
657        CpuBackend::relu(&mut result, &data);
658        Self::from_vec(result, &self.shape).unwrap()
659    }
660
661    /// Applies sigmoid activation: 1 / (1 + exp(-x)).
662    #[must_use]
663    pub fn sigmoid(&self) -> Self {
664        let data = self.to_vec();
665        let mut result = vec![T::zero(); data.len()];
666        CpuBackend::sigmoid(&mut result, &data);
667        Self::from_vec(result, &self.shape).unwrap()
668    }
669
670    /// Applies tanh activation.
671    #[must_use]
672    pub fn tanh(&self) -> Self {
673        let data = self.to_vec();
674        let mut result = vec![T::zero(); data.len()];
675        CpuBackend::tanh(&mut result, &data);
676        Self::from_vec(result, &self.shape).unwrap()
677    }
678
679    /// Applies exponential function.
680    #[must_use]
681    pub fn exp(&self) -> Self {
682        let data = self.to_vec();
683        let mut result = vec![T::zero(); data.len()];
684        CpuBackend::exp(&mut result, &data);
685        Self::from_vec(result, &self.shape).unwrap()
686    }
687
688    /// Applies natural logarithm.
689    #[must_use]
690    pub fn ln(&self) -> Self {
691        let data = self.to_vec();
692        let mut result = vec![T::zero(); data.len()];
693        CpuBackend::ln(&mut result, &data);
694        Self::from_vec(result, &self.shape).unwrap()
695    }
696
697    /// Applies square root.
698    #[must_use]
699    pub fn sqrt(&self) -> Self {
700        let data = self.to_vec();
701        let mut result = vec![T::zero(); data.len()];
702        CpuBackend::sqrt(&mut result, &data);
703        Self::from_vec(result, &self.shape).unwrap()
704    }
705
706    /// Computes element-wise power.
707    #[must_use]
708    pub fn pow(&self, exp: T) -> Self {
709        let data = self.to_vec();
710        let result: Vec<T> = data.iter().map(|&x| x.pow_value(exp)).collect();
711        Self::from_vec(result, &self.shape).unwrap()
712    }
713
714    /// GELU activation function (Gaussian Error Linear Unit).
715    #[must_use]
716    pub fn gelu(&self) -> Self {
717        crate::ops::gelu(self)
718    }
719
720    /// SiLU/Swish activation function.
721    #[must_use]
722    pub fn silu(&self) -> Self {
723        crate::ops::silu(self)
724    }
725
726    /// Softmax along specified dimension.
727    #[must_use]
728    pub fn softmax(&self, dim: i32) -> Self {
729        crate::ops::softmax(self, dim as i64).unwrap_or_else(|_| self.clone())
730    }
731
732    /// Log softmax along specified dimension.
733    #[must_use]
734    pub fn log_softmax(&self, dim: i32) -> Self {
735        let softmax_result = self.softmax(dim);
736        softmax_result.ln()
737    }
738
739    /// Mean along a dimension.
740    #[must_use]
741    pub fn mean_dim(&self, dim: i32, keepdim: bool) -> Self {
742        let ndim = self.ndim();
743        let dim = if dim < 0 { (ndim as i32 + dim) as usize } else { dim as usize };
744
745        if dim >= ndim {
746            return self.clone();
747        }
748
749        let dim_size = self.shape[dim];
750        let data = self.to_vec();
751        let mut new_shape = self.shape.clone();
752
753        if keepdim {
754            new_shape[dim] = 1;
755        } else {
756            new_shape.remove(dim);
757        }
758
759        if new_shape.is_empty() {
760            new_shape = smallvec::smallvec![1];
761        }
762
763        let new_numel: usize = new_shape.iter().product();
764        let mut result = vec![T::zero(); new_numel];
765
766        // Compute strides for iteration
767        let outer_size: usize = self.shape[..dim].iter().product();
768        let inner_size: usize = self.shape[dim + 1..].iter().product();
769
770        for outer in 0..outer_size {
771            for inner in 0..inner_size {
772                let mut sum = T::zero();
773                for d in 0..dim_size {
774                    let idx = outer * dim_size * inner_size + d * inner_size + inner;
775                    sum = sum + data[idx];
776                }
777                let mean = sum / NumCast::from(dim_size).unwrap();
778                let result_idx = outer * inner_size + inner;
779                result[result_idx] = mean;
780            }
781        }
782
783        Self::from_vec(result, &new_shape).unwrap()
784    }
785
786    /// Variance along a dimension.
787    #[must_use]
788    pub fn var_dim(&self, dim: i32, keepdim: bool) -> Self {
789        let mean = self.mean_dim(dim, true);
790        let diff = self.sub(&mean).unwrap_or_else(|_| self.clone());
791        let squared = diff.mul(&diff).unwrap_or_else(|_| self.clone());
792        squared.mean_dim(dim, keepdim)
793    }
794
795    /// Broadcasts tensor to a new shape.
796    #[must_use]
797    pub fn broadcast_to(&self, shape: &[usize]) -> Self {
798        if self.shape.as_slice() == shape {
799            return self.clone();
800        }
801
802        let result_shape = broadcast_shape(&self.shape, shape).unwrap_or_else(|_| shape.into());
803        let self_strides = broadcast_strides(&self.shape, &self.strides, &result_shape);
804
805        let total = numel(&result_shape);
806        let mut result_data = vec![T::zero(); total];
807        let self_data = self.storage.as_slice();
808
809        for i in 0..total {
810            let indices = crate::shape::unravel_index(i, &result_shape);
811            let self_idx = self.offset + linear_index(&indices, &self_strides);
812            result_data[i] = self_data[self_idx];
813        }
814
815        Self::from_vec(result_data, &result_shape).unwrap()
816    }
817
818    /// Slices the tensor using ranges for each dimension.
819    #[must_use]
820    pub fn slice(&self, ranges: &[std::ops::Range<usize>]) -> Self {
821        let mut new_shape = Vec::with_capacity(self.ndim());
822        for (i, range) in ranges.iter().enumerate() {
823            if i < self.ndim() {
824                new_shape.push(range.end - range.start);
825            }
826        }
827        // Keep remaining dimensions unchanged
828        for i in ranges.len()..self.ndim() {
829            new_shape.push(self.shape[i]);
830        }
831
832        let new_numel: usize = new_shape.iter().product();
833        let mut result_data = vec![T::zero(); new_numel];
834        let self_data = self.to_vec();
835
836        // Copy data with proper indexing
837        let mut result_idx = 0;
838        Self::slice_recursive(
839            &self_data,
840            &self.shape,
841            ranges,
842            0,
843            0,
844            &mut result_data,
845            &mut result_idx,
846        );
847
848        Self::from_vec(result_data, &new_shape).unwrap()
849    }
850
851    fn slice_recursive(
852        data: &[T],
853        shape: &[usize],
854        ranges: &[std::ops::Range<usize>],
855        dim: usize,
856        offset: usize,
857        result: &mut [T],
858        result_idx: &mut usize,
859    ) {
860        if dim == shape.len() {
861            result[*result_idx] = data[offset];
862            *result_idx += 1;
863            return;
864        }
865
866        let stride: usize = shape[dim + 1..].iter().product();
867        let (start, end) = if dim < ranges.len() {
868            (ranges[dim].start, ranges[dim].end)
869        } else {
870            (0, shape[dim])
871        };
872
873        for i in start..end {
874            Self::slice_recursive(data, shape, ranges, dim + 1, offset + i * stride, result, result_idx);
875        }
876    }
877}
878
879// =============================================================================
880// Arithmetic Operator Implementations
881// =============================================================================
882
883impl<T: Numeric> Tensor<T> {
884    /// Element-wise addition with broadcasting.
885    pub fn add(&self, other: &Self) -> Result<Self> {
886        let result_shape = broadcast_shape(&self.shape, &other.shape)?;
887        let self_strides = broadcast_strides(&self.shape, &self.strides, &result_shape);
888        let other_strides = broadcast_strides(&other.shape, &other.strides, &result_shape);
889
890        let total = numel(&result_shape);
891        let mut result_data = vec![T::zero(); total];
892
893        let self_data = self.storage.as_slice();
894        let other_data = other.storage.as_slice();
895
896        for i in 0..total {
897            let indices = crate::shape::unravel_index(i, &result_shape);
898            let self_idx = self.offset + linear_index(&indices, &self_strides);
899            let other_idx = other.offset + linear_index(&indices, &other_strides);
900            result_data[i] = self_data[self_idx] + other_data[other_idx];
901        }
902
903        Self::from_vec(result_data, &result_shape)
904    }
905
906    /// Element-wise subtraction with broadcasting.
907    pub fn sub(&self, other: &Self) -> Result<Self> {
908        let result_shape = broadcast_shape(&self.shape, &other.shape)?;
909        let self_strides = broadcast_strides(&self.shape, &self.strides, &result_shape);
910        let other_strides = broadcast_strides(&other.shape, &other.strides, &result_shape);
911
912        let total = numel(&result_shape);
913        let mut result_data = vec![T::zero(); total];
914
915        let self_data = self.storage.as_slice();
916        let other_data = other.storage.as_slice();
917
918        for i in 0..total {
919            let indices = crate::shape::unravel_index(i, &result_shape);
920            let self_idx = self.offset + linear_index(&indices, &self_strides);
921            let other_idx = other.offset + linear_index(&indices, &other_strides);
922            result_data[i] = self_data[self_idx] - other_data[other_idx];
923        }
924
925        Self::from_vec(result_data, &result_shape)
926    }
927
928    /// Element-wise multiplication with broadcasting.
929    pub fn mul(&self, other: &Self) -> Result<Self> {
930        let result_shape = broadcast_shape(&self.shape, &other.shape)?;
931        let self_strides = broadcast_strides(&self.shape, &self.strides, &result_shape);
932        let other_strides = broadcast_strides(&other.shape, &other.strides, &result_shape);
933
934        let total = numel(&result_shape);
935        let mut result_data = vec![T::zero(); total];
936
937        let self_data = self.storage.as_slice();
938        let other_data = other.storage.as_slice();
939
940        for i in 0..total {
941            let indices = crate::shape::unravel_index(i, &result_shape);
942            let self_idx = self.offset + linear_index(&indices, &self_strides);
943            let other_idx = other.offset + linear_index(&indices, &other_strides);
944            result_data[i] = self_data[self_idx] * other_data[other_idx];
945        }
946
947        Self::from_vec(result_data, &result_shape)
948    }
949
950    /// Element-wise division with broadcasting.
951    pub fn div(&self, other: &Self) -> Result<Self> {
952        let result_shape = broadcast_shape(&self.shape, &other.shape)?;
953        let self_strides = broadcast_strides(&self.shape, &self.strides, &result_shape);
954        let other_strides = broadcast_strides(&other.shape, &other.strides, &result_shape);
955
956        let total = numel(&result_shape);
957        let mut result_data = vec![T::zero(); total];
958
959        let self_data = self.storage.as_slice();
960        let other_data = other.storage.as_slice();
961
962        for i in 0..total {
963            let indices = crate::shape::unravel_index(i, &result_shape);
964            let self_idx = self.offset + linear_index(&indices, &self_strides);
965            let other_idx = other.offset + linear_index(&indices, &other_strides);
966            result_data[i] = self_data[self_idx] / other_data[other_idx];
967        }
968
969        Self::from_vec(result_data, &result_shape)
970    }
971
972    /// Scalar addition.
973    #[must_use]
974    pub fn add_scalar(&self, scalar: T) -> Self {
975        let data = self.to_vec();
976        let mut result = vec![T::zero(); data.len()];
977        CpuBackend::add_scalar(&mut result, &data, scalar);
978        Self::from_vec(result, &self.shape).unwrap()
979    }
980
981    /// Scalar multiplication.
982    #[must_use]
983    pub fn mul_scalar(&self, scalar: T) -> Self {
984        let data = self.to_vec();
985        let mut result = vec![T::zero(); data.len()];
986        CpuBackend::mul_scalar(&mut result, &data, scalar);
987        Self::from_vec(result, &self.shape).unwrap()
988    }
989
990    /// Element-wise negation.
991    #[must_use]
992    pub fn neg(&self) -> Self {
993        let data = self.to_vec();
994        let mut result = vec![T::zero(); data.len()];
995        CpuBackend::neg(&mut result, &data);
996        Self::from_vec(result, &self.shape).unwrap()
997    }
998
999    /// Matrix multiplication with batching support.
1000    ///
1001    /// Supports:
1002    /// - 2D @ 2D: [m, k] @ [k, n] -> [m, n]
1003    /// - 3D @ 3D: [batch, m, k] @ [batch, k, n] -> [batch, m, n]
1004    /// - 4D @ 4D: [b1, b2, m, k] @ [b1, b2, k, n] -> [b1, b2, m, n]
1005    pub fn matmul(&self, other: &Self) -> Result<Self> {
1006        if self.ndim() < 2 || other.ndim() < 2 {
1007            return Err(Error::invalid_operation(
1008                "matmul requires at least 2D tensors",
1009            ));
1010        }
1011
1012        let m = self.shape[self.ndim() - 2];
1013        let k1 = self.shape[self.ndim() - 1];
1014        let k2 = other.shape[other.ndim() - 2];
1015        let n = other.shape[other.ndim() - 1];
1016
1017        if k1 != k2 {
1018            return Err(Error::invalid_operation(format!(
1019                "matmul inner dimensions must match: {k1} vs {k2}"
1020            )));
1021        }
1022
1023        // For 2D matrices, do simple matmul
1024        if self.ndim() == 2 && other.ndim() == 2 {
1025            let a_data = self.contiguous().to_vec();
1026            let b_data = other.contiguous().to_vec();
1027            let mut c_data = vec![T::zero(); m * n];
1028            CpuBackend::matmul(&mut c_data, &a_data, &b_data, m, n, k1);
1029            return Self::from_vec(c_data, &[m, n]);
1030        }
1031
1032        // For batched matmul, compute batch size
1033        let batch_dims_self: Vec<usize> = self.shape[..self.ndim() - 2].to_vec();
1034        let batch_dims_other: Vec<usize> = other.shape[..other.ndim() - 2].to_vec();
1035
1036        if batch_dims_self != batch_dims_other {
1037            return Err(Error::invalid_operation(format!(
1038                "matmul batch dimensions must match: {:?} vs {:?}",
1039                batch_dims_self, batch_dims_other
1040            )));
1041        }
1042
1043        let batch_size: usize = batch_dims_self.iter().product();
1044        let a_stride = m * k1;
1045        let b_stride = k1 * n;
1046        let c_stride = m * n;
1047
1048        let a_data = self.contiguous().to_vec();
1049        let b_data = other.contiguous().to_vec();
1050        let mut c_data = vec![T::zero(); batch_size * m * n];
1051
1052        // Loop over batches and compute matmul for each
1053        for batch in 0..batch_size {
1054            let a_slice = &a_data[batch * a_stride..(batch + 1) * a_stride];
1055            let b_slice = &b_data[batch * b_stride..(batch + 1) * b_stride];
1056            let c_slice = &mut c_data[batch * c_stride..(batch + 1) * c_stride];
1057            CpuBackend::matmul(c_slice, a_slice, b_slice, m, n, k1);
1058        }
1059
1060        // Build output shape: batch_dims + [m, n]
1061        let mut output_shape = batch_dims_self;
1062        output_shape.push(m);
1063        output_shape.push(n);
1064
1065        Self::from_vec(c_data, &output_shape)
1066    }
1067
1068    /// Dot product for 1D tensors.
1069    pub fn dot(&self, other: &Self) -> Result<Self> {
1070        if self.ndim() != 1 || other.ndim() != 1 {
1071            return Err(Error::invalid_operation("dot requires 1D tensors"));
1072        }
1073
1074        if self.shape[0] != other.shape[0] {
1075            return Err(Error::shape_mismatch(&self.shape, &other.shape));
1076        }
1077
1078        let a_data = self.to_vec();
1079        let b_data = other.to_vec();
1080        let result = CpuBackend::dot(&a_data, &b_data);
1081
1082        Ok(Self::scalar(result))
1083    }
1084}
1085
1086// =============================================================================
1087// Operator Trait Implementations
1088// =============================================================================
1089
1090impl<T: Numeric> Add for &Tensor<T> {
1091    type Output = Tensor<T>;
1092
1093    fn add(self, other: Self) -> Self::Output {
1094        self.add(other).expect("Addition failed")
1095    }
1096}
1097
1098impl<T: Numeric> Sub for &Tensor<T> {
1099    type Output = Tensor<T>;
1100
1101    fn sub(self, other: Self) -> Self::Output {
1102        self.sub(other).expect("Subtraction failed")
1103    }
1104}
1105
1106impl<T: Numeric> Mul for &Tensor<T> {
1107    type Output = Tensor<T>;
1108
1109    fn mul(self, other: Self) -> Self::Output {
1110        self.mul(other).expect("Multiplication failed")
1111    }
1112}
1113
1114impl<T: Numeric> Div for &Tensor<T> {
1115    type Output = Tensor<T>;
1116
1117    fn div(self, other: Self) -> Self::Output {
1118        self.div(other).expect("Division failed")
1119    }
1120}
1121
1122impl<T: Numeric> Neg for &Tensor<T> {
1123    type Output = Tensor<T>;
1124
1125    fn neg(self) -> Self::Output {
1126        self.neg()
1127    }
1128}
1129
1130// Scalar operations
1131impl<T: Numeric> Add<T> for &Tensor<T> {
1132    type Output = Tensor<T>;
1133
1134    fn add(self, scalar: T) -> Self::Output {
1135        self.add_scalar(scalar)
1136    }
1137}
1138
1139impl<T: Numeric> Mul<T> for &Tensor<T> {
1140    type Output = Tensor<T>;
1141
1142    fn mul(self, scalar: T) -> Self::Output {
1143        self.mul_scalar(scalar)
1144    }
1145}
1146
1147// =============================================================================
1148// Display Implementation
1149// =============================================================================
1150
1151impl<T: Scalar + fmt::Display> fmt::Debug for Tensor<T> {
1152    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
1153        write!(
1154            f,
1155            "Tensor(shape={:?}, device={}",
1156            self.shape(),
1157            self.device()
1158        )?;
1159        if self.numel() <= 10 {
1160            write!(f, ", data={:?}", self.to_vec())?;
1161        }
1162        write!(f, ")")
1163    }
1164}
1165
1166impl<T: Scalar + fmt::Display> fmt::Display for Tensor<T> {
1167    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
1168        if self.is_scalar() {
1169            write!(f, "{}", self.item().unwrap())
1170        } else if self.ndim() == 1 {
1171            write!(f, "[")?;
1172            let data = self.to_vec();
1173            for (i, val) in data.iter().enumerate() {
1174                if i > 0 {
1175                    write!(f, ", ")?;
1176                }
1177                write!(f, "{val}")?;
1178            }
1179            write!(f, "]")
1180        } else {
1181            write!(f, "Tensor(shape={:?})", self.shape())
1182        }
1183    }
1184}
1185
1186// =============================================================================
1187// Tests
1188// =============================================================================
1189
#[cfg(test)]
mod tests {
    //! Unit tests for construction, indexing, views, arithmetic (including
    //! broadcasting and strided operands), reductions, matmul and error paths.

    use super::*;

    #[test]
    fn test_from_vec() {
        let t = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3]).unwrap();
        assert_eq!(t.shape(), &[2, 3]);
        assert_eq!(t.numel(), 6);
    }

    #[test]
    fn test_get_set() {
        let t = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[2, 2]).unwrap();
        assert_eq!(t.get(&[0, 0]).unwrap(), 1.0);
        assert_eq!(t.get(&[0, 1]).unwrap(), 2.0);
        assert_eq!(t.get(&[1, 0]).unwrap(), 3.0);
        assert_eq!(t.get(&[1, 1]).unwrap(), 4.0);

        t.set(&[0, 0], 99.0).unwrap();
        assert_eq!(t.get(&[0, 0]).unwrap(), 99.0);
    }

    #[test]
    fn test_reshape() {
        let t = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3]).unwrap();
        let r = t.reshape(&[3, 2]).unwrap();
        assert_eq!(r.shape(), &[3, 2]);

        let r = t.reshape(&[-1]).unwrap();
        assert_eq!(r.shape(), &[6]);
    }

    #[test]
    fn test_transpose() {
        let t = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3]).unwrap();
        let r = t.t().unwrap();
        assert_eq!(r.shape(), &[3, 2]);
        assert_eq!(r.get(&[0, 0]).unwrap(), 1.0);
        assert_eq!(r.get(&[0, 1]).unwrap(), 4.0);
        assert_eq!(r.get(&[1, 0]).unwrap(), 2.0);
    }

    #[test]
    fn test_arithmetic() {
        let a = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap();
        let b = Tensor::<f32>::from_vec(vec![4.0, 5.0, 6.0], &[3]).unwrap();

        let c = &a + &b;
        assert_eq!(c.to_vec(), vec![5.0, 7.0, 9.0]);

        let d = &a * &b;
        assert_eq!(d.to_vec(), vec![4.0, 10.0, 18.0]);
    }

    #[test]
    fn test_sub_div_neg() {
        let a = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap();
        let b = Tensor::<f32>::from_vec(vec![4.0, 5.0, 6.0], &[3]).unwrap();

        assert_eq!((&b - &a).to_vec(), vec![3.0, 3.0, 3.0]);
        assert_eq!((&b / &a).to_vec(), vec![4.0, 2.5, 2.0]);
        assert_eq!((-&a).to_vec(), vec![-1.0, -2.0, -3.0]);
    }

    #[test]
    fn test_scalar_ops() {
        let a = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap();
        assert_eq!((&a + 1.0_f32).to_vec(), vec![2.0, 3.0, 4.0]);
        assert_eq!((&a * 2.0_f32).to_vec(), vec![2.0, 4.0, 6.0]);
    }

    #[test]
    fn test_broadcasting() {
        let a = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap();
        let b = Tensor::<f32>::from_vec(vec![10.0], &[1]).unwrap();

        let c = &a + &b;
        assert_eq!(c.to_vec(), vec![11.0, 12.0, 13.0]);
    }

    #[test]
    fn test_broadcasting_rank_mismatch() {
        // [2, 3] - [3]: the row vector is subtracted from every row.
        let m = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3]).unwrap();
        let r = Tensor::<f32>::from_vec(vec![1.0, 1.0, 1.0], &[3]).unwrap();

        let c = &m - &r;
        assert_eq!(c.shape(), &[2, 3]);
        assert_eq!(c.to_vec(), vec![0.0, 1.0, 2.0, 3.0, 4.0, 5.0]);
    }

    #[test]
    fn test_arithmetic_on_transposed_view() {
        // The transpose is a strided view: [[1, 4], [2, 5], [3, 6]].
        let t = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3]).unwrap();
        let tt = t.t().unwrap();

        let s = &tt + &tt;
        assert_eq!(s.shape(), &[3, 2]);
        assert_eq!(s.to_vec(), vec![2.0, 8.0, 4.0, 10.0, 6.0, 12.0]);
    }

    #[test]
    fn test_sum() {
        let t = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[4]).unwrap();
        let s = t.sum();
        assert_eq!(s.item().unwrap(), 10.0);
    }

    #[test]
    fn test_matmul() {
        // 2x2 @ 2x2
        let a = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[2, 2]).unwrap();
        let b = Tensor::<f32>::from_vec(vec![5.0, 6.0, 7.0, 8.0], &[2, 2]).unwrap();
        let c = a.matmul(&b).unwrap();

        assert_eq!(c.shape(), &[2, 2]);
        assert_eq!(c.to_vec(), vec![19.0, 22.0, 43.0, 50.0]);
    }

    #[test]
    fn test_matmul_non_square() {
        // [2, 3] @ [3, 1] -> [2, 1]: row sums.
        let a = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3]).unwrap();
        let b = Tensor::<f32>::from_vec(vec![1.0, 1.0, 1.0], &[3, 1]).unwrap();
        let c = a.matmul(&b).unwrap();

        assert_eq!(c.shape(), &[2, 1]);
        assert_eq!(c.to_vec(), vec![6.0, 15.0]);
    }

    #[test]
    fn test_matmul_batched() {
        // Batch 0 multiplies by the identity, batch 1 by 2 * identity.
        let a = Tensor::<f32>::from_vec(
            vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0],
            &[2, 2, 2],
        )
        .unwrap();
        let b = Tensor::<f32>::from_vec(
            vec![1.0, 0.0, 0.0, 1.0, 2.0, 0.0, 0.0, 2.0],
            &[2, 2, 2],
        )
        .unwrap();
        let c = a.matmul(&b).unwrap();

        assert_eq!(c.shape(), &[2, 2, 2]);
        assert_eq!(
            c.to_vec(),
            vec![1.0, 2.0, 3.0, 4.0, 10.0, 12.0, 14.0, 16.0]
        );
    }

    #[test]
    fn test_dot() {
        let a = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap();
        let b = Tensor::<f32>::from_vec(vec![4.0, 5.0, 6.0], &[3]).unwrap();
        assert_eq!(a.dot(&b).unwrap().item().unwrap(), 32.0);
    }

    #[test]
    fn test_shape_errors() {
        let a = Tensor::<f32>::from_vec(vec![1.0, 2.0], &[2]).unwrap();
        let b = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap();
        assert!(a.add(&b).is_err()); // [2] and [3] do not broadcast
        assert!(a.dot(&b).is_err()); // length mismatch
        assert!(a.matmul(&a).is_err()); // 1-D operands

        let m23 = Tensor::<f32>::from_vec(vec![0.0; 6], &[2, 3]).unwrap();
        assert!(m23.matmul(&m23).is_err()); // inner dims 3 vs 2

        let x = Tensor::<f32>::from_vec(vec![0.0; 8], &[2, 2, 2]).unwrap();
        let y = Tensor::<f32>::from_vec(vec![0.0; 12], &[3, 2, 2]).unwrap();
        assert!(x.matmul(&y).is_err()); // batch dims 2 vs 3
    }

    #[test]
    fn test_relu() {
        let t = Tensor::<f32>::from_vec(vec![-1.0, 0.0, 1.0, 2.0], &[4]).unwrap();
        let r = t.relu();
        assert_eq!(r.to_vec(), vec![0.0, 0.0, 1.0, 2.0]);
    }

    #[test]
    fn test_scalar() {
        let s = Tensor::<f32>::scalar(42.0);
        assert!(s.is_scalar());
        assert_eq!(s.numel(), 1);
        assert_eq!(s.item().unwrap(), 42.0);
    }
}