1#[cfg(feature = "serde")]
2use serde::{Deserialize, Serialize};
3use std::sync::Arc;
4
/// Per-dimension strides, stored in a reference-counted slice.
///
/// Cloning a `Strides` only bumps the `Arc` reference count; the stride
/// values themselves are shared rather than copied.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Default)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[repr(transparent)]
pub struct Strides(Arc<[usize]>);

impl Strides {
    /// Borrows every stride, in stored order, as a slice.
    pub fn as_slice(&self) -> &[usize] {
        &self.0[..]
    }

    /// Returns the stride stored at position `index`.
    ///
    /// # Panics
    ///
    /// Panics if `index` is not less than the number of stored strides.
    pub fn stride_at(&self, index: usize) -> usize {
        let all = self.as_slice();
        all[index]
    }

    /// Borrows every stride as a slice; same result as [`Strides::as_slice`].
    pub fn strides(&self) -> &[usize] {
        self.as_slice()
    }
}
23
24impl From<&[usize]> for Strides {
25 fn from(strides: &[usize]) -> Self {
26 Self(Arc::from(strides))
27 }
28}
29
30impl From<Vec<usize>> for Strides {
31 fn from(strides: Vec<usize>) -> Self {
32 Self(Arc::from(strides))
33 }
34}
35
36impl From<&Shape> for Strides {
37 fn from(shape: &Shape) -> Self {
38 let rank = shape.dimensions();
39 if rank == 0 {
40 return Self(Arc::from(Vec::<usize>::new()));
41 }
42
43 let mut strides = vec![1usize; rank];
44 if rank >= 2 {
45 for i in (0..rank - 1).rev() {
46 let next = shape.dim_at(i + 1);
47 strides[i] = strides[i + 1].saturating_mul(next);
48 }
49 }
50
51 Self(Arc::from(strides))
52 }
53}
54
/// The dimensions of a tensor, stored in a reference-counted slice.
///
/// Cloning a `Shape` only bumps the `Arc` reference count; the dimension
/// values themselves are shared rather than copied.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Default)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[repr(transparent)]
pub struct Shape(Arc<[usize]>);

impl Shape {
    /// Creates a shape from anything convertible into `Arc<[usize]>`,
    /// such as `Vec<usize>` or `&[usize]`.
    pub fn new(dims: impl Into<Arc<[usize]>>) -> Self {
        Shape(dims.into())
    }

    /// Returns the total number of elements, i.e. the product of all dimensions.
    ///
    /// A rank-0 shape has size 1. Multiplication saturates at `usize::MAX`
    /// instead of overflowing; use [`Shape::try_size`] when overflow must be
    /// detected.
    pub fn size(&self) -> usize {
        self.0.iter().fold(1usize, |acc, &d| acc.saturating_mul(d))
    }

    /// Returns the total number of elements, or `None` if it does not fit in
    /// a `usize`.
    ///
    /// A shape that contains a zero dimension always yields `Some(0)`, no
    /// matter how large the other dimensions are or in which order they appear.
    pub fn try_size(&self) -> Option<usize> {
        // Without this check the result would depend on dimension order:
        // `[usize::MAX, 2, 0]` overflows before reaching the zero and would
        // report `None`, while `[0, usize::MAX, 2]` reports `Some(0)`.
        if self.contains_dim(0) {
            return Some(0);
        }
        self.0.iter().try_fold(1usize, |acc, &d| acc.checked_mul(d))
    }

    /// Returns the number of dimensions; identical to [`Shape::rank`].
    pub fn dimensions(&self) -> usize {
        self.0.len()
    }

    /// Returns `true` if any dimension equals `dim`.
    pub fn contains_dim(&self, dim: usize) -> bool {
        self.0.contains(&dim)
    }

    /// Returns the size of dimension `index`.
    ///
    /// # Panics
    ///
    /// Panics if `index` is not less than [`Shape::rank`].
    pub fn dim_at(&self, index: usize) -> usize {
        self.0[index]
    }

    /// Returns the number of dimensions; identical to [`Shape::dimensions`].
    pub fn rank(&self) -> usize {
        self.0.len()
    }

    /// Returns `true` if the shape has no dimensions (rank 0).
    ///
    /// This is about rank, not element count: a rank-0 shape has size 1,
    /// while `[0]` is not empty but has size 0.
    pub fn is_empty(&self) -> bool {
        self.0.is_empty()
    }

    /// Returns `true` only for the one-dimensional shape `[1]`.
    ///
    /// A rank-0 shape is *not* reported as scalar by this method.
    pub fn is_scalar(&self) -> bool {
        matches!(self.as_slice(), [1])
    }

    /// Returns `true` for rank-1 shapes, including `[1]`.
    pub fn is_vector(&self) -> bool {
        self.rank() == 1
    }

    /// Returns `true` for rank-2 shapes.
    pub fn is_matrix(&self) -> bool {
        self.rank() == 2
    }

    /// Returns `true` for shapes whose rank is greater than 2.
    pub fn is_tensor(&self) -> bool {
        self.rank() > 2
    }

    /// Returns `true` for rank-2 shapes whose two dimensions are equal.
    pub fn is_square(&self) -> bool {
        matches!(self.as_slice(), [rows, cols] if rows == cols)
    }

    /// Iterates over the dimensions in stored order.
    pub fn iter(&self) -> impl Iterator<Item = &usize> {
        self.0.iter()
    }

    /// Borrows the dimensions as a slice.
    pub fn as_slice(&self) -> &[usize] {
        &self.0
    }
}
130
131impl AsRef<[usize]> for Shape {
132 fn as_ref(&self) -> &[usize] {
133 self.as_slice()
134 }
135}
136
137impl AsRef<[usize]> for Strides {
138 fn as_ref(&self) -> &[usize] {
139 self.as_slice()
140 }
141}
142
143impl From<&Shape> for Shape {
144 fn from(shape: &Shape) -> Self {
145 Shape::new(Arc::clone(&shape.0))
146 }
147}
148
149impl From<Vec<i32>> for Shape {
150 fn from(dims: Vec<i32>) -> Self {
151 Shape::new(dims.into_iter().map(|d| d as usize).collect::<Vec<usize>>())
152 }
153}
154
155impl From<Vec<usize>> for Shape {
156 fn from(dims: Vec<usize>) -> Self {
157 Shape::new(dims)
158 }
159}
160
161impl From<usize> for Shape {
162 fn from(value: usize) -> Shape {
163 Shape::new(vec![value])
164 }
165}
166
167impl From<(usize, usize)> for Shape {
168 fn from(value: (usize, usize)) -> Shape {
169 Shape::new(vec![value.0, value.1])
170 }
171}
172
173impl From<(usize, usize, usize)> for Shape {
174 fn from(value: (usize, usize, usize)) -> Shape {
175 Shape::new(vec![value.0, value.1, value.2])
176 }
177}
178
179impl From<(usize, usize, usize, usize)> for Shape {
180 fn from(value: (usize, usize, usize, usize)) -> Shape {
181 Shape::new(vec![value.0, value.1, value.2, value.3])
182 }
183}
184
185impl From<(usize, usize, usize, usize, usize)> for Shape {
186 fn from(value: (usize, usize, usize, usize, usize)) -> Shape {
187 Shape::new(vec![value.0, value.1, value.2, value.3, value.4])
188 }
189}
190
191impl From<(usize, usize, usize, usize, usize, usize)> for Shape {
192 fn from(value: (usize, usize, usize, usize, usize, usize)) -> Shape {
193 Shape::new(vec![value.0, value.1, value.2, value.3, value.4, value.5])
194 }
195}
196
197impl From<(usize, usize, usize, usize, usize, usize, usize)> for Shape {
198 fn from(value: (usize, usize, usize, usize, usize, usize, usize)) -> Shape {
199 Shape::new(vec![
200 value.0, value.1, value.2, value.3, value.4, value.5, value.6,
201 ])
202 }
203}
204
205impl From<(usize, usize, usize, usize, usize, usize, usize, usize)> for Shape {
206 fn from(value: (usize, usize, usize, usize, usize, usize, usize, usize)) -> Shape {
207 Shape::new(vec![
208 value.0, value.1, value.2, value.3, value.4, value.5, value.6, value.7,
209 ])
210 }
211}
212
213impl From<&[usize]> for Shape {
214 fn from(dims: &[usize]) -> Self {
215 Shape::new(dims.to_vec())
216 }
217}
218
219