acme_tensor/impls/
reshape.rs

/*
    Appellation: reshape <impls>
    Contrib: FL03 <jo3mccain@icloud.com>
*/
5use crate::prelude::{TensorExpr, TensorId, TensorResult};
6use crate::shape::{Axis, IntoShape, ShapeError};
7use crate::tensor::TensorBase;
8
9impl<T> TensorBase<T>
10where
11    T: Clone,
12{
13    /// coerce the tensor to act like a larger shape.
14    /// This method doesn't change the underlying data, but it does change the layout.
15    pub fn broadcast(&self, shape: impl IntoShape) -> Self {
16        let layout = self.layout().broadcast_as(shape).unwrap();
17        let op = TensorExpr::broadcast(self.clone(), layout.shape().clone());
18        Self {
19            id: TensorId::new(),
20            kind: self.kind(),
21            layout,
22            op: op.into(),
23            data: self.data().clone(),
24        }
25    }
26    #[doc(hidden)]
27    pub fn pad(&self, shape: impl IntoShape, _with: T) -> Self {
28        let shape = shape.into_shape();
29
30        let _diff = *self.shape().rank() - *shape.rank();
31
32        todo!()
33    }
34    /// Swap two axes in the tensor.
35    pub fn swap_axes(&self, swap: Axis, with: Axis) -> Self {
36        let op = TensorExpr::swap_axes(self.clone(), swap, with);
37
38        let layout = self.layout().clone().swap_axes(swap, with);
39        let shape = self.layout.shape();
40        let mut data = self.data.to_vec();
41
42        for i in 0..shape[swap] {
43            for j in 0..shape[with] {
44                let scope = self.layout.index([i, j]);
45                let target = layout.index([j, i]);
46                data[target] = self.data()[scope].clone();
47            }
48        }
49
50        TensorBase {
51            id: TensorId::new(),
52            kind: self.kind,
53            layout,
54            op: op.into(),
55            data: data.clone(),
56        }
57    }
58    /// Transpose the tensor.
59    pub fn t(&self) -> Self {
60        let op = TensorExpr::transpose(self.clone());
61
62        let layout = self.layout().clone().reverse_axes();
63        TensorBase {
64            id: TensorId::new(),
65            kind: self.kind(),
66            layout,
67            op: op.into(),
68            data: self.data().clone(),
69        }
70    }
71    /// Reshape the tensor
72    /// returns an error if the new shape specifies a different number of elements.
73    pub fn reshape(self, shape: impl IntoShape) -> TensorResult<Self> {
74        let shape = shape.into_shape();
75        if self.size() != shape.size() {
76            return Err(ShapeError::MismatchedElements.into());
77        }
78
79        let mut tensor = self;
80
81        tensor.layout.reshape(shape);
82
83        Ok(tensor)
84    }
85}