acme_tensor/impls/
reshape.rs1use crate::prelude::{TensorExpr, TensorId, TensorResult};
6use crate::shape::{Axis, IntoShape, ShapeError};
7use crate::tensor::TensorBase;
8
9impl<T> TensorBase<T>
10where
11 T: Clone,
12{
13 pub fn broadcast(&self, shape: impl IntoShape) -> Self {
16 let layout = self.layout().broadcast_as(shape).unwrap();
17 let op = TensorExpr::broadcast(self.clone(), layout.shape().clone());
18 Self {
19 id: TensorId::new(),
20 kind: self.kind(),
21 layout,
22 op: op.into(),
23 data: self.data().clone(),
24 }
25 }
26 #[doc(hidden)]
27 pub fn pad(&self, shape: impl IntoShape, _with: T) -> Self {
28 let shape = shape.into_shape();
29
30 let _diff = *self.shape().rank() - *shape.rank();
31
32 todo!()
33 }
34 pub fn swap_axes(&self, swap: Axis, with: Axis) -> Self {
36 let op = TensorExpr::swap_axes(self.clone(), swap, with);
37
38 let layout = self.layout().clone().swap_axes(swap, with);
39 let shape = self.layout.shape();
40 let mut data = self.data.to_vec();
41
42 for i in 0..shape[swap] {
43 for j in 0..shape[with] {
44 let scope = self.layout.index([i, j]);
45 let target = layout.index([j, i]);
46 data[target] = self.data()[scope].clone();
47 }
48 }
49
50 TensorBase {
51 id: TensorId::new(),
52 kind: self.kind,
53 layout,
54 op: op.into(),
55 data: data.clone(),
56 }
57 }
58 pub fn t(&self) -> Self {
60 let op = TensorExpr::transpose(self.clone());
61
62 let layout = self.layout().clone().reverse_axes();
63 TensorBase {
64 id: TensorId::new(),
65 kind: self.kind(),
66 layout,
67 op: op.into(),
68 data: self.data().clone(),
69 }
70 }
71 pub fn reshape(self, shape: impl IntoShape) -> TensorResult<Self> {
74 let shape = shape.into_shape();
75 if self.size() != shape.size() {
76 return Err(ShapeError::MismatchedElements.into());
77 }
78
79 let mut tensor = self;
80
81 tensor.layout.reshape(shape);
82
83 Ok(tensor)
84 }
85}