//! lumen_core/tensor/shape.rs — shape-manipulation operations for `Tensor`:
//! squeeze/unsqueeze, narrow/slice, reshape, transpose, permute, cat/stack,
//! split/chunk, flatten and repeat.

1use std::sync::Arc;
2use crate::{AutogradMetaT, Dim, Dims, Error, Layout, Result, Shape, Storage, WithDType, D};
3use super::{Tensor, TensorId, TensorImpl, Slice};
4
5impl<T: WithDType> Tensor<T> {
6    /// Creates a new tensor with the specified dimension removed if its size was one.
7    ///
8    /// ```rust
9    /// use lumen_core::{Tensor, DType, D};
10    /// let a = Tensor::<f32>::zeros((2, 3, 1)).unwrap();
11    ///
12    /// let c = a.squeeze(2).unwrap();
13    /// assert_eq!(c.shape().dims(), &[2, 3]);
14    /// ```
15    pub fn squeeze<D: Dim>(&self, dim: D) -> Result<Self> {
16        let dims = self.dims();
17        let dim = dim.to_index(self.shape(), "squeeze")?;
18        if dims[dim] == 1 {
19            let mut dims = dims.to_vec();
20            let mut strides = self.layout().stride().to_vec();
21            dims.remove(dim);
22            strides.remove(dim);
23            let tensor_ = TensorImpl {
24                id: TensorId::new(),
25                storage: self.0.storage.clone(),
26                layout: Layout::new(dims, strides, self.layout().start_offset()),
27                meta: T::AutogradMeta::on_reshape_op(self)
28            };
29            Ok(Self(Arc::new(tensor_)))
30        } else {
31            Err( Error::SqueezeDimNot1 { shape: self.shape().clone(), dim } )?
32        }
33    }
34
35    /// Creates a new tensor with a dimension of size one inserted at the specified position.
36    ///
37    /// ```rust
38    /// use lumen_core::{Tensor, DType, D};
39    /// let a = Tensor::<f32>::zeros((2, 3)).unwrap();
40    ///
41    /// let c = a.unsqueeze(0).unwrap();
42    /// assert_eq!(c.shape().dims(), &[1, 2, 3]);
43    ///
44    /// let c = a.unsqueeze(D::Minus1).unwrap();
45    /// assert_eq!(c.shape().dims(), &[2, 3, 1]);
46    /// ```
47    pub fn unsqueeze<D: Dim>(&self, dim: D) -> Result<Self> {
48        let mut dims = self.dims().to_vec();
49        let mut strides = self.layout().stride().to_vec();
50        let dim = dim.to_index_plus_one(self.shape(), "unsqueeze")?;
51        dims.insert(dim, 1);
52        let stride = if dim < strides.len() { strides[dim] } else { 1 };
53        strides.insert(dim, stride);
54        let tensor_ = TensorImpl {
55            id: TensorId::new(),
56            storage: self.0.storage.clone(),
57            layout: Layout::new(dims, strides, self.layout().start_offset()),
58            meta: T::AutogradMeta::on_reshape_op(self),
59        };
60        Ok(Self(Arc::new(tensor_)))
61    }
62
63    /// Returns a new tensor that is a narrowed version of the input, the dimension `dim`
64    /// ranges from `start` to `start + len`.
65    /// ```
66    /// use lumen_core::Tensor;
67    /// let a = Tensor::new(&[
68    ///     [0f32, 1., 2.],
69    ///     [3.  , 4., 5.],
70    ///     [6.  , 7., 8.]
71    /// ]).unwrap();
72    ///
73    /// let b = a.narrow(0, 1, 2).unwrap();
74    /// assert_eq!(b.shape().dims(), &[2, 3]);
75    ///
76    /// let c = a.narrow(1, 1, 1).unwrap();
77    /// assert_eq!(c.shape().dims(), &[3, 1]);
78    /// ```
79    pub fn narrow<D: Dim>(&self, dim: D, start: usize, len: usize) -> Result<Self> {
80        let dims = self.dims();
81        let dim = dim.to_index(self.shape(), "narrow")?;
82        let err = |msg| {
83            Err::<(), _>(Error::NarrowInvalidArgs {
84                shape: self.shape().clone(),
85                dim,
86                start,
87                len,
88                msg,
89            })
90        };
91
92        if start > dims[dim] {
93            err("start > dim_len")?;
94        }
95        if start.saturating_add(len) > dims[dim] {
96            err("start + len > dim_len")?
97        }
98        if start == 0 && dims[dim] == len {
99            Ok(self.clone())
100        } else {
101            let meta = T::AutogradMeta::on_narrow_op(self, dim, start, len);
102            let layout = self.layout().narrow(dim, start, len)?;
103            let tensor_ = TensorImpl {
104                id: TensorId::new(),
105                storage: self.0.storage.clone(),
106                layout,
107                meta
108            };
109            Ok(Self(Arc::new(tensor_)))
110        }
111    }
112
113    /// Returns a new tensor that is a narrowed version of the input, the dimension `dim`
114    /// slice from `start` to `start : end : step`.
115    /// 
116    /// ```
117    /// use lumen_core::{Tensor, DType, s, Slice};
118    /// let a = Tensor::<i32>::zeros((5, 5, 5)).unwrap();
119    ///
120    /// let b = a.narrow(0, 1, 2).unwrap();
121    /// assert_eq!(b.shape().dims(), &[2, 5, 5]);
122    ///
123    /// let c = a.slice(1, &s!(::2)).unwrap();
124    /// assert_eq!(c.shape().dims(), &[5, 3, 5]);
125    /// ```
126    pub fn slice<D: Dim>(&self, dim: D, slice: &Slice) -> Result<Self> {
127        let dims = self.dims();
128        let dim = dim.to_index(self.shape(), "narrow")?;
129        let err = |msg| {
130            Err::<(), _>(Error::SliceInvalidArgs {
131                shape: self.shape().clone(),
132                dim,
133                slice: slice.clone(),
134                msg,
135            })
136        };
137
138        let end = match slice.end {
139            Some(end) if end >= 0 => end as usize,
140            Some(end) => {
141                let dis = -end as usize;
142                if dis > dims[dim] {
143                    0
144                } else {
145                    dims[dim] - dis
146                }
147            }
148            None => dims[dim],
149        };
150        if slice.start > dims[dim] {
151            err("start > dim_len")?;
152        }
153        if end > dims[dim] {
154            err("end > dim_len")?
155        }
156        if slice.start == 0 && dims[dim] == end && slice.step == 1 {
157            Ok(self.clone())
158        } else {
159            let meta = T::AutogradMeta::on_slice_op(self, dim, slice.start, end, slice.step);
160            let layout = self.layout().slice(dim, slice.start, end, slice.step)?;
161            Ok(self.share_storage(layout, meta))
162        }
163    }
164
165    /// Reshape returns a tensor with the target shape provided that the number of elements of the
166    /// original tensor is the same.
167    /// If the input tensor is contiguous, this is a view on the original data. Otherwise this uses
168    /// a new storage and copies the data over, the returned tensor is always contiguous.
169    ///
170    /// The shape can be specified using a tuple of `usize` and at most one `()` in which case
171    /// the behavior is the same as when using `-1` in PyTorch: this dimension size is adjusted so
172    /// as to match the number of elements in the tensor.
173    /// 
174    /// ```rust
175    /// use lumen_core::{Tensor, DType, D};
176    /// let a = Tensor::<f32>::zeros((2, 3)).unwrap();
177    ///
178    /// let c = a.reshape((1, 6)).unwrap();
179    /// assert_eq!(c.shape().dims(), &[1, 6]);
180    /// ```
181    pub fn reshape<S: Into<Shape>>(&self, shape: S) -> Result<Self> {
182        let shape = shape.into();
183        if shape.element_count() != self.element_count() {
184            return Err(Error::ShapeMismatchBinaryOp {
185                lhs: self.shape().clone(),
186                rhs: shape,
187                op: "reshape",
188            })?;
189        }
190
191        let meta = T::AutogradMeta::on_reshape_op(self);
192        if self.is_contiguous() {
193            let layout = Layout::contiguous_with_offset(shape, self.layout().start_offset());
194            Ok(self.share_storage(layout, meta))
195        } else {
196            let storage = self.storage_read()?.copy(self.layout());
197            Ok(Self::from_storage(storage, shape, meta))
198        }
199    }
200    
201    /// Returns a Tensor that is a transposed version of the input, the given dimensions are
202    pub fn transpose<D1: Dim, D2: Dim>(&self, dim1: D1, dim2: D2) -> Result<Self> {
203        let dim1 = dim1.to_index(self.shape(), "transpose")?;
204        let dim2 = dim2.to_index(self.shape(), "transpose")?;
205        if dim1 == dim2 {
206            return Ok(self.clone());
207        }
208
209        let meta = T::AutogradMeta::on_transpose_op(self, dim1, dim2);
210        let layout = self.layout().transpose(dim1, dim2)?;
211        Ok(self.share_storage(layout, meta))
212    }
213
    /// Swaps the last two dimensions, i.e. `(…, m, n)` becomes `(…, n, m)`.
    /// Shorthand for `transpose(D::Minus1, D::Minus2)`.
    pub fn transpose_last(&self) -> Result<Self> {
        self.transpose(D::Minus1, D::Minus2)
    }
217
218    /// Returns a tensor with the same data as the input where the dimensions have been permuted.
219    /// dims must be a permutation, i.e. include each dimension index exactly once.
220    ///
221    /// ```rust
222    /// use lumen_core::Tensor;
223    /// let tensor = Tensor::<u32>::arange(0u32, 120u32).unwrap().reshape((2, 3, 4, 5)).unwrap();
224    /// assert_eq!(tensor.dims(), &[2, 3, 4, 5]);
225    /// let tensor = tensor.permute((2, 3, 1, 0)).unwrap();
226    /// assert_eq!(tensor.dims(), &[4, 5, 3, 2]);
227    /// ```
228    pub fn permute<D: Dims>(&self, dims: D) -> Result<Self> {
229        let dims = dims.to_indexes(self.shape(), "permute")?;
230        // O(n^2) permutation check but these arrays are small.
231        let is_permutation =
232            dims.len() == self.rank() && (0..dims.len()).all(|i| dims.contains(&i));
233        if !is_permutation {
234            crate::bail!(
235                "dimension mismatch in permute, tensor {:?}, dims: {:?}",
236                self.dims(),
237                dims
238            )
239        }
240        // let op = BackpropOp::new1(self, |t| Op::Permute(t, dims.clone()));
241        let layout = self.layout().permute(&dims)?;
242        let meta = T::AutogradMeta::on_permute_op(self, dims);
243        Ok(self.share_storage(layout, meta))
244    }
245
246    /// Concatenates two or more tensors along a particular dimension.
247    ///
248    /// All tensors must of the same rank, and the output will have
249    /// the same rank
250    ///
251    /// ```rust
252    /// use lumen_core::Tensor;
253    /// let a = Tensor::<f32>::zeros((2, 3)).unwrap();
254    /// let b = Tensor::<f32>::zeros((2, 3)).unwrap();
255    ///
256    /// let c = Tensor::cat(&[&a, &b], 0).unwrap();
257    /// assert_eq!(c.dims(), &[4, 3]);
258    ///
259    /// let c = Tensor::cat(&[&a, &b], 1).unwrap();
260    /// assert_eq!(c.dims(), &[2, 6]);
261    /// ```
262    pub fn cat<A: AsRef<Tensor<T>>, D: Dim>(arrs: &[A], dim: D) -> Result<Self> {
263        // check shape
264        if arrs.is_empty() {
265            Err(Error::OpRequiresAtLeastOneTensor { op: "cat" })?
266        }
267    
268        // first arr's infomation
269        let arr0 = &arrs[0];
270        let rank0 = arr0.as_ref().rank();
271
272        // cat_dim must be valid!
273        let cat_dim = dim.to_index(arr0.as_ref().shape(), "cat")?;
274        let mut target_dims = arr0.as_ref().dims().to_vec();
275        target_dims[cat_dim] = 0;
276        let mut dim_offsets = vec![];
277
278        for (_arr_index, arr) in arrs.iter().enumerate() {
279            // check shape 
280            let rank = arr.as_ref().rank();
281            if rank != rank0 {
282                Err(Error::UnexpectedNumberOfDims {
283                    expected: rank,
284                    got: arr.as_ref().rank(),
285                    shape: arr.as_ref().shape().clone(),
286                })?
287            }
288
289            // zip arr0's dims and arr's dims
290            for (dim_index, (v1, v2)) in arr0.as_ref().dims().iter()
291                                                                    .zip(arr.as_ref().dims().iter())
292                                                                    .enumerate()
293            {
294                // accumalte the cat dim
295                if dim_index == cat_dim {
296                    dim_offsets.push(target_dims[cat_dim]);
297                    target_dims[cat_dim] += v2;
298                }
299
300                // all other dims should be same
301                if dim_index != cat_dim && v1 != v2 {
302                    Err(Error::ShapeMismatchCat {
303                        dim: dim_index,
304                        first_shape: arr0.as_ref().shape().clone(),
305                        n: dim_index + 1,
306                        nth_shape: arr0.as_ref().shape().clone(),
307                    })?
308                }
309            }
310        }
311        
312        // Now, all arr in arrs has same rank, and except `cat_dim`, all dims are equal
313        // [ (a, n1, b, c), (a, n2, b, c), ... , (a, nk, b, c).... ]
314        // target_dims = (a, n1+n2+...+nk, b, c)
315
316        let target_shape: Shape = target_dims.into();
317        
318        // Create a new storgae and copy
319        let mut dst: Vec<T> = Vec::with_capacity(target_shape.element_count());
320        unsafe { dst.set_len(target_shape.element_count()) };
321        
322        let meta = T::AutogradMeta::on_cat_op(arrs, cat_dim);
323        let res_arr = Self::from_storage(Storage::new(dst), target_shape, meta);
324
325        for (arr_index, arr) in arrs.iter().enumerate() {
326            // Take sub Tensor 
327            let sub_res_arr = res_arr.narrow(cat_dim, dim_offsets[arr_index], arr.as_ref().dims()[cat_dim])?;
328            assert_eq!(sub_res_arr.shape(), arr.as_ref().shape());
329            // MARK: copy_from is no grad
330            sub_res_arr.copy_from(arr.as_ref())?;
331        }
332
333        Ok(res_arr)
334    }
335
336    /// Stacks two or more tensors along a particular dimension.
337    ///
338    /// All tensors must have the same rank, and the output has one additional rank
339    ///
340    /// ```rust
341    /// use lumen_core::Tensor;
342    /// let a = Tensor::<f32>::zeros((2, 3)).unwrap();
343    /// let b = Tensor::<f32>::zeros((2, 3)).unwrap();
344    ///
345    /// let c = Tensor::stack(&[&a, &b], 0).unwrap();
346    /// assert_eq!(c.dims(), &[2, 2, 3]);
347    ///
348    /// let c = Tensor::stack(&[&a, &b], 2).unwrap();
349    /// assert_eq!(c.dims(), &[2, 3, 2]);
350    /// ```
351    pub fn stack<A: AsRef<Tensor<T>>, D: Dim>(args: &[A], dim: D) -> Result<Self> {
352        if args.is_empty() {
353            Err(Error::OpRequiresAtLeastOneTensor { op: "stack" })?
354        }
355        let dim = dim.to_index_plus_one(args[0].as_ref().shape(), "stack")?;
356        let args = args
357            .iter()
358            .map(|t| t.as_ref().unsqueeze(dim))
359            .collect::<Result<Vec<_>>>()?;
360        Self::cat(&args, dim)
361    }
362
363    /// Splits a tensor along a specified dimension into multiple sub-tensors.
364    ///
365    /// The tensor is split along the given `dim` into as many sub-tensors as
366    /// the size of that dimension. Each sub-tensor has the same shape as the
367    /// original tensor, except the size along `dim` becomes 1.
368    ///
369    /// ```rust
370    /// use lumen_core::Tensor;
371    ///
372    /// let a = Tensor::new(&[[1, 2], [3, 4], [5, 6]]).unwrap();
373    ///
374    /// // Split along axis 0 (rows)
375    /// let splits = a.split(0).unwrap();
376    /// assert_eq!(splits.len(), 3);
377    /// assert_eq!(splits[0].to_vec().unwrap(), [1, 2]);
378    /// assert_eq!(splits[1].to_vec().unwrap(), [3, 4]);
379    /// assert_eq!(splits[2].to_vec().unwrap(), [5, 6]);
380    ///
381    /// // Split along axis 1 (columns)
382    /// let splits = a.split(1).unwrap();
383    /// assert_eq!(splits.len(), 2);
384    /// assert_eq!(splits[0].to_vec().unwrap(), [1, 3, 5]);
385    /// assert_eq!(splits[1].to_vec().unwrap(), [2, 4, 6]);
386    ///
387    /// // 1D array
388    /// let b = Tensor::new(&[10, 20, 30]).unwrap();
389    /// let splits = b.split(0).unwrap();
390    /// assert_eq!(splits.len(), 3);
391    /// assert_eq!(splits[0].to_vec().unwrap(), [10]);
392    /// assert_eq!(splits[1].to_vec().unwrap(), [20]);
393    /// assert_eq!(splits[2].to_vec().unwrap(), [30]);
394    /// ```
395    pub fn split<D: Dim>(&self, dim: D) -> Result<Vec<Self>> {
396        let split_index = dim.to_index(self.shape(), "split")?;
397        let split_dim_size = self.dims()[split_index];
398        let mut splited_shape = self.dims().to_vec();
399        splited_shape.remove(split_index);
400
401        let mut vec = vec![];  
402        for i in 0..split_dim_size {
403            let sub_tensor = self.narrow(split_index, i, 1)?.squeeze(split_index)?;
404            vec.push(sub_tensor);
405        }
406
407        Ok(vec)
408    }
409
410    /// Split a tensor into the specified number of chunks, this may return less chunks than
411    /// specified.
412    pub fn chunk<D: Dim>(&self, chunks: usize, dim: D) -> Result<Vec<Self>> {
413        let dim = dim.to_index(self.shape(), "chunk")?;
414        let size = self.dim(dim)?;
415        if size < chunks {
416            (0..size).map(|i| self.narrow(dim, i, 1)).collect()
417        } else {
418            let chunk_size = size / chunks;
419            let cnt_additional = size % chunks;
420            let mut tensors = vec![];
421            let mut sum_chunk_size = 0;
422            for i in 0..chunks {
423                let chunk_size = if i < cnt_additional {
424                    chunk_size + 1
425                } else {
426                    chunk_size
427                };
428                let tensor = self.narrow(dim, sum_chunk_size, chunk_size)?;
429                tensors.push(tensor);
430                sum_chunk_size += chunk_size
431            }
432            Ok(tensors)
433        }
434    }
435
    /// Flattens the input tensor on the dimension indexes from `start_dim` to `end_dim` (both
    /// inclusive).
    ///
    /// The covered axes are merged into a single axis whose size is the product of their sizes.
    pub fn flatten<D1: Dim, D2: Dim>(&self, start_dim: D1, end_dim: D2) -> Result<Self> {
        self.flatten_(Some(start_dim), Some(end_dim))
    }
441
    /// Flattens the input tensor on the dimension indexes from `0` to `end_dim` (inclusive).
    ///
    /// Equivalent to `flatten(0, end_dim)`.
    pub fn flatten_to<D: Dim>(&self, end_dim: D) -> Result<Self> {
        self.flatten_(None::<usize>, Some(end_dim))
    }
446
    /// Flattens the input tensor on the dimension indexes from `start_dim` (inclusive) to the last
    /// dimension.
    ///
    /// Equivalent to `flatten(start_dim, rank - 1)`.
    pub fn flatten_from<D: Dim>(&self, start_dim: D) -> Result<Self> {
        self.flatten_(Some(start_dim), None::<usize>)
    }
452
    /// Flattens the input tensor by reshaping it into a one dimension tensor.
    ///
    /// ```rust
    /// use lumen_core::Tensor;
    /// let arr = Tensor::new(&[[0f32, 1.], [2., 3.], [4., 5.]]).unwrap();
    /// let arr = arr.flatten_all().unwrap();
    /// let len = arr.dims1().unwrap();
    /// assert_eq!(len, 6);
    /// assert_eq!(arr.to_vec().unwrap(), [0., 1., 2., 3., 4., 5.]);
    /// ```
    pub fn flatten_all(&self) -> Result<Self> {
        // Flatten over the full axis range; scalars become a 1-element vector.
        self.flatten_(None::<usize>, None::<usize>)
    }
466
467    /// Repeat this tensor along the specified dimensions.
468    pub fn repeat<S: Into<Shape>>(&self, shape: S) -> Result<Self> {
469        let repeats: Shape = shape.into();
470        let mut repeats = repeats.dims().to_vec();
471
472        if repeats.len() > self.rank() {
473            Err(Error::RepeatRankOutOfRange { repeats: repeats.clone().into(), shape: self.shape().into() })?;
474        } else if repeats.len() > self.rank() {
475            for _ in 0..(repeats.len() - self.rank()) {
476                repeats.push(1);
477            }
478        }
479
480        let mut arr = self.clone();
481
482        for (idx, &repeat) in repeats.iter().enumerate() {
483            if repeat > 1 {
484                arr = Tensor::cat(&vec![&arr; repeat], idx)?
485            }
486        }
487        Ok(arr)
488    }
489
    /// Repeat this tensor along the specified dimension with specified times
    ///
    /// NOTE(review): `times == 0` squeezes `dim` — which errors unless that
    /// dimension's size is 1 — rather than producing an empty tensor along
    /// `dim`; confirm this is the intended semantics.
    pub fn repeat_dim<D: Dim>(&self, dim: D, times: usize) -> Result<Self> {
        if times == 0 {
            self.squeeze(dim)
        } else if times == 1 {
            // Repeating once is the identity.
            Ok(self.clone())
        } else {
            // Concatenate `times` references to self along `dim`.
            Tensor::cat(&vec![self; times], dim)
        }
    }
500
501    fn flatten_<D1: Dim, D2: Dim>(
502        &self,
503        start_dim: Option<D1>,
504        end_dim: Option<D2>,
505    ) -> Result<Self> {
506        if self.rank() == 0 {
507            self.reshape(1)
508        } else {
509            let start_dim = match start_dim {
510                None => 0,
511                Some(dim) => dim.to_index(self.shape(), "flatten")?,
512            };
513            let end_dim = match end_dim {
514                None => self.rank() - 1,
515                Some(dim) => dim.to_index(self.shape(), "flatten")?,
516            };
517            if start_dim < end_dim {
518                let dims = self.dims();
519                let mut dst_dims = dims[..start_dim].to_vec();
520                dst_dims.push(dims[start_dim..end_dim + 1].iter().product::<usize>());
521                if end_dim + 1 < dims.len() {
522                    dst_dims.extend(&dims[end_dim + 1..]);
523                }
524                self.reshape(dst_dims)
525            } else {
526                Ok(self.clone())
527            }
528        }
529    }
530}
531
/// Identity `AsRef` impl so APIs bounded on `A: AsRef<Tensor<T>>` (e.g.
/// [`Tensor::cat`] and [`Tensor::stack`]) accept slices of owned tensors as
/// well as slices of references.
impl<T: WithDType> AsRef<Tensor<T>> for Tensor<T> {
    fn as_ref(&self) -> &Tensor<T> {
        self
    }
}
537
#[cfg(test)]
#[allow(unused)]
mod test {
    //! Unit tests for the shape-manipulation ops in this module.
    use super::*;

    #[test]
    fn test_unsqueeze_basic() -> Result<()> {
        let t = Tensor::<i32>::zeros((2, 3))?;

        // Unsqueeze axis 0
        let unsq0 = t.unsqueeze(0)?;
        assert_eq!(unsq0.dims(), &[1, 2, 3]);
        assert_eq!(unsq0.to_vec()?, t.to_vec()?);

        // Unsqueeze axis 1
        let unsq1 = t.unsqueeze(1)?;
        assert_eq!(unsq1.dims(), &[2, 1, 3]);
        assert_eq!(unsq1.to_vec()?, t.to_vec()?);

        // Unsqueeze last axis
        let unsq2 = t.unsqueeze(2)?;
        assert_eq!(unsq2.dims(), &[2, 3, 1]);
        assert_eq!(unsq2.to_vec()?, t.to_vec()?);

        Ok(())
    }

    #[test]
    fn test_squeeze_basic() -> Result<()> {
        let t = Tensor::<i32>::zeros((2, 1, 3))?;

        let sq = t.squeeze(1)?;
        assert_eq!(sq.dims(), &[2, 3]);
        assert_eq!(sq.to_vec()?, t.to_vec()?);

        // Squeezing a singleton dimension succeeds and removes it; squeezing a
        // non-1 dimension is an error (strict behavior).
        let t2 = Tensor::<i32>::zeros((1, 5))?;
        let sq2 = t2.squeeze(0)?;
        assert_eq!(sq2.dims(), &[5]);

        Ok(())
    }

    #[test]
    fn test_squeeze_unsqueeze_consistency() -> Result<()> {
        let t = Tensor::new(&[[1, 2, 3], [4, 5, 6]])?; // shape [2, 3]

        let unsq = t.unsqueeze(0)?; // [1, 2, 3]
        let sq = unsq.squeeze(0)?;  // [2, 3]

        assert_eq!(t.dims(), sq.dims());
        assert_eq!(t.to_vec()?, sq.to_vec()?);
        Ok(())
    }

    #[test]
    fn test_unsqueeze() -> Result<()> {
        // Squeeze followed by unsqueeze on a fresh axis.
        // (Debug `println!`s removed.)
        let t = Tensor::<i32>::zeros((2, 1, 3))?;
        let sq = t.squeeze(1)?;
        assert_eq!(sq.dims(), vec![2, 3]);

        let unsq = sq.unsqueeze(0)?;
        assert_eq!(unsq.dims(), vec![1, 2, 3]);

        Ok(())
    }

    #[test]
    fn test_cat_3d() -> Result<()> {
        let a = Tensor::full((2, 2, 2), 1)?;
        let b = Tensor::full((2, 2, 2), 2)?;

        let c = Tensor::cat(&[a, b], 0)?;
        assert_eq!(c.dims(), [4, 2, 2]);

        let c2 = Tensor::cat(&[c.clone(), c.clone()], 1)?;
        assert_eq!(c2.dims(), [4, 4, 2]);

        Ok(())
    }

    #[test]
    fn test_cat_1d() -> Result<()> {
        let a = Tensor::new(&[1, 2, 3])?;
        let b = Tensor::new(&[4, 5, 6])?;
        let c = Tensor::new(&[7])?;

        let res = Tensor::cat(&[a, b, c], 0)?;
        assert_eq!(res.dims(), &[7]);
        assert_eq!(res.to_vec()?, &[1, 2, 3, 4, 5, 6, 7]);

        Ok(())
    }

    #[test]
    fn test_cat_2d_axis0() -> Result<()> {
        let a = Tensor::new(&[[1, 2], [3, 4]])?; // [2, 2]
        let b = Tensor::new(&[[5, 6]])?;         // [1, 2]

        let c = Tensor::cat(&[a, b], 0)?;
        assert_eq!(c.dims(), &[3, 2]);
        assert_eq!(c.to_vec()?, &[1, 2, 3, 4, 5, 6]);

        Ok(())
    }

    #[test]
    fn test_cat_2d_axis1() -> Result<()> {
        let a = Tensor::new(&[[1, 2], [3, 4]])?; // [2, 2]
        let b = Tensor::new(&[[5], [6]])?;       // [2, 1]

        let c = Tensor::cat(&[a, b], 1)?;
        assert_eq!(c.dims(), &[2, 3]);
        // Row 1: 1, 2, 5. Row 2: 3, 4, 6.
        assert_eq!(c.to_vec()?, &[1, 2, 5, 3, 4, 6]);

        Ok(())
    }

    #[test]
    fn test_cat_shape_mismatch() {
        // Mismatch on non-cat axis
        let a = Tensor::new(&[[1, 2], [3, 4]]).unwrap(); // [2, 2]
        let b = Tensor::new(&[[1, 2, 3]]).unwrap();      // [1, 3]

        // Try to cat on axis 0, axis 1 mismatch (2 vs 3)
        let res = Tensor::cat(&[a, b], 0);
        assert!(res.is_err());
    }

    #[test]
    fn test_cat_empty_list_error() {
        let res = Tensor::<i32>::cat::<Tensor<i32>, usize>(&[], 0);
        assert!(res.is_err(), "Concatenating an empty list should return an error");
    }

    #[test]
    fn test_cat_bool() -> Result<()> {
        let a = Tensor::new(&[[true, false]])?;
        let b = Tensor::new(&[[false, true]])?;

        let c = Tensor::cat(&[a, b], 0)?;
        assert_eq!(c.dims(), [2, 2]);
        assert_eq!(c.to_vec().unwrap(), [true, false, false, true]);

        Ok(())
    }

    #[test]
    fn test_stack_1d_axis0() -> Result<()> {
        let a = Tensor::new(&[1, 2, 3])?;
        let b = Tensor::new(&[4, 5, 6])?;

        let c = Tensor::stack(&[a, b], 0)?;
        assert_eq!(c.dims(), [2, 3]);
        assert_eq!(c.to_vec().unwrap(), [1, 2, 3, 4, 5, 6]);

        Ok(())
    }

    #[test]
    fn test_stack_1d_axis1() -> Result<()> {
        let a = Tensor::new(&[1, 2, 3])?;
        let b = Tensor::new(&[4, 5, 6])?;

        let c = Tensor::stack(&[a, b], 1)?;
        assert_eq!(c.dims(), [3, 2]);
        assert_eq!(c.to_vec().unwrap(), [1, 4, 2, 5, 3, 6]);

        Ok(())
    }

    #[test]
    fn test_stack_2d_axis0() -> Result<()> {
        let a = Tensor::new(&[[1, 2], [3, 4]])?;
        let b = Tensor::new(&[[5, 6], [7, 8]])?;

        let c = Tensor::stack(&[a, b], 0)?;
        assert_eq!(c.dims(), [2, 2, 2]);
        assert_eq!(c.to_vec().unwrap(), [1, 2, 3, 4, 5, 6, 7, 8]);

        Ok(())
    }

    #[test]
    fn test_stack_2d_axis1() -> Result<()> {
        let a = Tensor::new(&[[1, 2], [3, 4]])?;
        let b = Tensor::new(&[[5, 6], [7, 8]])?;

        let c = Tensor::stack(&[a, b], 1)?;
        assert_eq!(c.dims(), [2, 2, 2]);
        assert_eq!(c.to_vec().unwrap(), [1, 2, 5, 6, 3, 4, 7, 8]);

        Ok(())
    }

    #[test]
    fn test_stack_2d_axis2() -> Result<()> {
        let a = Tensor::new(&[[1, 2], [3, 4]])?;
        let b = Tensor::new(&[[5, 6], [7, 8]])?;

        let c = Tensor::stack(&[a, b], 2)?;
        assert_eq!(c.dims(), [2, 2, 2]);
        assert_eq!(c.to_vec().unwrap(), [1, 5, 2, 6, 3, 7, 4, 8]);

        Ok(())
    }

    #[test]
    fn test_stack_shape_mismatch() {
        let a = Tensor::new(&[1, 2, 3]).unwrap();
        let b = Tensor::new(&[4, 5]).unwrap();

        let res = Tensor::stack(&[a, b], 0);
        assert!(res.is_err());
    }

    #[test]
    fn test_split_1d() -> Result<()> {
        let a = Tensor::new(&[10, 20, 30, 40])?;
        let splits = a.split(0)?; // axis 0

        assert_eq!(splits.len(), 4);
        assert_eq!(splits[0].to_vec().unwrap(), [10]);
        assert_eq!(splits[1].to_vec().unwrap(), [20]);
        assert_eq!(splits[2].to_vec().unwrap(), [30]);
        assert_eq!(splits[3].to_vec().unwrap(), [40]);

        Ok(())
    }

    #[test]
    fn test_split_2d_axis0() -> Result<()> {
        let a = Tensor::new(&[[1, 2], [3, 4], [5, 6], [7, 8]])?;
        let splits = a.split(0)?;

        assert_eq!(splits.len(), 4);
        assert_eq!(splits[0].to_vec().unwrap(), [1, 2]);
        assert_eq!(splits[1].to_vec().unwrap(), [3, 4]);
        assert_eq!(splits[2].to_vec().unwrap(), [5, 6]);
        assert_eq!(splits[3].to_vec().unwrap(), [7, 8]);

        Ok(())
    }

    #[test]
    fn test_split_2d_axis1() -> Result<()> {
        let a = Tensor::new(&[[1, 2, 3], [4, 5, 6]])?;
        let splits = a.split(1)?;

        assert_eq!(splits.len(), 3);
        assert_eq!(splits[0].to_vec().unwrap(), [1, 4]);
        assert_eq!(splits[1].to_vec().unwrap(), [2, 5]);
        assert_eq!(splits[2].to_vec().unwrap(), [3, 6]);

        Ok(())
    }

    #[test]
    fn test_split_3d_axis2() -> Result<()> {
        let a = Tensor::new(&[
            [[1, 2], [3, 4]],
            [[5, 6], [7, 8]]
        ])?;
        let splits = a.split(2)?;

        assert_eq!(splits.len(), 2);
        assert_eq!(splits[0].to_vec().unwrap(), [1, 3, 5, 7]);
        assert_eq!(splits[1].to_vec().unwrap(), [2, 4, 6, 8]);

        Ok(())
    }

    #[test]
    fn test_split_single_element() -> Result<()> {
        let a = Tensor::new(&[42])?;
        let splits = a.split(0)?;

        assert_eq!(splits.len(), 1);
        assert_eq!(splits[0].to_vec().unwrap(), [42]);

        Ok(())
    }

    #[test]
    fn test_split_empty_array() -> Result<()> {
        let a = Tensor::<i32>::zeros((0, 2))?;
        let splits = a.split(0)?;

        assert!(splits.is_empty());
        Ok(())
    }

    #[test]
    fn test_repeat_1d() -> Result<()> {
        let a = Tensor::new(&[1, 2, 3])?;
        let b = a.repeat(3)?; // repeat each dimension 3 times
        assert_eq!(b.dims(), [3 * 3]); // shape: [9]
        assert_eq!(b.to_vec().unwrap(), [1, 2, 3, 1, 2, 3, 1, 2, 3]);
        Ok(())
    }

    #[test]
    fn test_repeat_2d() -> Result<()> {
        let a = Tensor::new(&[[1, 2], [3, 4]])?;
        let b = a.repeat((2, 3))?; // repeat 2 times along axis 0, 3 times along axis 1
        assert_eq!(b.dims(), [4, 6]);
        assert_eq!(
            b.to_vec().unwrap(),
            [
                1, 2, 1, 2, 1, 2,
                3, 4, 3, 4, 3, 4,
                1, 2, 1, 2, 1, 2,
                3, 4, 3, 4, 3, 4
            ]
        );
        Ok(())
    }

    #[test]
    fn test_repeat_dim() -> Result<()> {
        let a = Tensor::new(&[1, 2, 3])?;
        let b = a.repeat_dim(0, 2)?; // repeat along axis 0 two times
        assert_eq!(b.dims(), [6]);
        assert_eq!(b.to_vec().unwrap(), [1, 2, 3, 1, 2, 3]);

        let c = a.repeat_dim(0, 1)?; // repeat once -> same as clone
        assert_eq!(c.dims(), [3]);
        assert_eq!(c.to_vec().unwrap(), [1, 2, 3]);

        Ok(())
    }

    #[test]
    fn test_repeat_high_dim() -> Result<()> {
        // Repeat counts matching the tensor's rank exactly.
        let a = Tensor::new(&[[1, 2], [3, 4]])?;
        let b = a.repeat((2, 3))?;
        assert_eq!(b.dims(), [4, 6]);
        Ok(())
    }

    #[test]
    fn test_narrow_1d_basic() -> Result<()> {
        let a = Tensor::new(&[0, 1, 2, 3, 4, 5])?;

        let b = a.narrow(0, 2, 3)?;

        assert_eq!(b.dims(), &[3]);
        assert_eq!(b.to_vec().unwrap(), &[2, 3, 4]);
        Ok(())
    }

    #[test]
    fn test_narrow_2d_rows() -> Result<()> {
        // Shape: [3, 3]
        // [ 0,  1,  2 ]
        // [ 3,  4,  5 ]
        // [ 6,  7,  8 ]
        let a = Tensor::new(&[
            [0, 1, 2],
            [3, 4, 5],
            [6, 7, 8]
        ])?;

        let b = a.narrow(0, 1, 1)?;

        assert_eq!(b.dims(), &[1, 3]);
        assert_eq!(b.to_vec().unwrap(), &[3, 4, 5]);
        Ok(())
    }
}