1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright the Vortex contributors
use std::fmt;
use itertools::Either;
use vortex_error::VortexExpect;
use vortex_error::VortexResult;
use vortex_error::vortex_ensure;
use vortex_error::vortex_ensure_eq;
/// Describes the layout of a `FixedShapeTensor` extension type: its logical shape plus optional
/// dimension names and an optional logical-to-physical dimension permutation.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct FixedShapeTensorMetadata {
    /// Size of every logical dimension, i.e. `logical_shape[i]` is the extent of logical
    /// dimension `i`.
    ///
    /// The row-major physical shape is obtained from it through the permutation (when one is
    /// set): `physical_shape[permutation[i]] = logical_shape[i]`.
    ///
    /// An empty vector describes a 0D (scalar) tensor, and individual extents may be 0
    /// (degenerate tensor).
    logical_shape: Vec<usize>,

    /// Optional per-dimension names, positionally aligned with `logical_shape`.
    ///
    /// When present, there is exactly one name per logical dimension.
    dim_names: Option<Vec<String>>,

    /// Optional dimension permutation: `permutation[i]` is the physical dimension index that
    /// logical dimension `i` is stored at.
    ///
    /// `None` means the logical and physical layouts coincide, which is equivalent to the
    /// identity permutation `[0, 1, ..., N-1]`.
    permutation: Option<Vec<usize>>,
}
impl FixedShapeTensorMetadata {
    /// Builds metadata for a tensor with the given logical `shape`, without dimension names and
    /// without a permutation.
    ///
    /// Chain [`with_dim_names`](Self::with_dim_names) and
    /// [`with_permutation`](Self::with_permutation) to attach the optional parts.
    pub fn new(shape: Vec<usize>) -> Self {
        Self {
            permutation: None,
            dim_names: None,
            logical_shape: shape,
        }
    }

    /// Attaches dimension names to this tensor.
    ///
    /// An empty `names` vector is treated as "no names": the metadata is returned unchanged and
    /// no length check is performed.
    ///
    /// # Errors
    ///
    /// Fails when `names` is non-empty and its length differs from the number of logical
    /// dimensions.
    pub fn with_dim_names(mut self, names: Vec<String>) -> VortexResult<Self> {
        // Nothing to record; keep `dim_names` as it was.
        if names.is_empty() {
            return Ok(self);
        }

        vortex_ensure_eq!(
            names.len(),
            self.logical_shape.len(),
            "dim_names length ({}) must match logical_shape length ({})",
            names.len(),
            self.logical_shape.len()
        );

        self.dim_names = Some(names);
        Ok(self)
    }

    /// Attaches a dimension permutation to this tensor.
    ///
    /// An empty `permutation` vector is treated as "no permutation": the metadata is returned
    /// unchanged and no validation is performed.
    ///
    /// # Errors
    ///
    /// Fails when a non-empty `permutation` is not a rearrangement of `[0, 1, ..., N-1]`, where
    /// `N` is the number of logical dimensions: wrong length, an index `>= N`, or a repeated
    /// index.
    pub fn with_permutation(mut self, permutation: Vec<usize>) -> VortexResult<Self> {
        // Nothing to record; keep `permutation` as it was.
        if permutation.is_empty() {
            return Ok(self);
        }

        vortex_ensure_eq!(
            permutation.len(),
            self.logical_shape.len(),
            "permutation length ({}) must match logical_shape length ({})",
            permutation.len(),
            self.logical_shape.len()
        );

        // Every index must be in range and must occur at most once; together with the length
        // check above this means every index in [0..N) occurs exactly once.
        let mut seen = vec![false; permutation.len()];
        for p in permutation.iter().copied() {
            vortex_ensure!(
                p < permutation.len(),
                "permutation index {p} is out of range for {} dimensions",
                permutation.len()
            );
            vortex_ensure!(!seen[p], "permutation contains duplicate index {p}");
            seen[p] = true;
        }

        self.permutation = Some(permutation);
        Ok(self)
    }

    /// Rank of the tensor, i.e. how many logical dimensions it has.
    pub fn ndim(&self) -> usize {
        self.logical_shape.len()
    }

    /// Borrowed view of the logical dimension sizes.
    pub fn logical_shape(&self) -> &[usize] {
        self.logical_shape.as_slice()
    }

    /// Borrowed view of the dimension names, or `None` when no names were set.
    pub fn dim_names(&self) -> Option<&[String]> {
        self.dim_names.as_ref().map(|names| names.as_slice())
    }

    /// Borrowed view of the permutation, or `None` when no permutation was set.
    pub fn permutation(&self) -> Option<&[usize]> {
        self.permutation.as_ref().map(|perm| perm.as_slice())
    }

    /// Lazily yields the physical (row-major memory layout) shape of the tensor.
    ///
    /// Each logical extent is placed at the physical position given by the permutation, i.e.
    /// `physical_shape[permutation[i]] = logical_shape[i]`. Without a permutation the logical
    /// shape is yielded unchanged.
    pub fn physical_shape(&self) -> impl Iterator<Item = usize> + '_ {
        let logical = self.logical_shape.as_slice();
        if let Some(perm) = self.permutation.as_deref() {
            // Walk the physical positions in order and look up which logical dimension is
            // stored at each one.
            Either::Right(
                (0..logical.len())
                    .map(move |physical| logical[Self::inverse_perm(perm, physical)]),
            )
        } else {
            Either::Left(logical.iter().copied())
        }
    }

    /// Lazily yields one stride per logical dimension, in logical dimension order.
    ///
    /// A stride is the number of elements to skip in the flat backing array to advance one step
    /// along that logical dimension.
    ///
    /// Without a permutation these are plain row-major strides of the logical shape. With a
    /// permutation, memory is row-major over the physical dimensions, so each stride is taken
    /// from that physical layout.
    pub fn strides(&self) -> impl Iterator<Item = usize> + '_ {
        let logical = self.logical_shape.as_slice();
        let rank = logical.len();
        match self.permutation.as_deref() {
            Some(perm) => Either::Right((0..rank).map(move |dim| self.permuted_stride(dim, perm))),
            // Row-major: the stride of a dimension is the product of all trailing extents.
            None => Either::Left(
                (0..rank).map(move |dim| logical[dim + 1..].iter().product::<usize>()),
            ),
        }
    }

    /// Stride of logical dimension `i` under the permutation `perm`.
    ///
    /// This is the product of `logical_shape[j]` over every logical dimension `j` whose physical
    /// position `perm[j]` lies after `perm[i]`.
    fn permuted_stride(&self, i: usize, perm: &[usize]) -> usize {
        let anchor = perm[i];
        // One pass over the whole permutation per call, so `strides()` is O(ndim^2) in total.
        // Rank is typically small, which makes this preferable to allocating a Vec.
        let mut stride = 1usize;
        for (logical_dim, &position) in perm.iter().enumerate() {
            if position > anchor {
                stride *= self.logical_shape[logical_dim];
            }
        }
        stride
    }

    /// Inverse permutation lookup: returns the logical dimension `i` for which `perm[i] == p`.
    ///
    /// This is a linear scan of `perm`, so calling it once per physical position costs
    /// O(ndim^2) overall. Rank is typically small (2–5), so this is preferred over allocating a
    /// Vec holding the full inverse permutation.
    fn inverse_perm(perm: &[usize], p: usize) -> usize {
        perm.iter()
            .enumerate()
            .find_map(|(logical, &physical)| (physical == p).then_some(logical))
            .vortex_expect("permutation must contain every physical position exactly once")
    }
}
impl fmt::Display for FixedShapeTensorMetadata {
    /// Formats the metadata as `Tensor(<dims>[, permutation=[<indices>]])`.
    ///
    /// * Without dimension names, `<dims>` is the comma-separated logical shape, e.g.
    ///   `Tensor(2, 3, 4)`.
    /// * With dimension names, each entry is rendered as `name: size`, e.g.
    ///   `Tensor(x: 2, y: 3)`.
    /// * When a permutation is set, it is appended as a bracketed, comma-separated list, e.g.
    ///   `Tensor(2, 3, 4, permutation=[2, 0, 1])`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "Tensor(")?;

        match &self.dim_names {
            Some(names) => {
                for (i, (dim, name)) in self.logical_shape.iter().zip(names.iter()).enumerate() {
                    if i > 0 {
                        write!(f, ", ")?;
                    }
                    write!(f, "{name}: {dim}")?;
                }
            }
            None => {
                for (i, dim) in self.logical_shape.iter().enumerate() {
                    if i > 0 {
                        write!(f, ", ")?;
                    }
                    write!(f, "{dim}")?;
                }
            }
        }

        if let Some(perm) = &self.permutation {
            // Open the permutation list explicitly: without this separator and opening bracket
            // the indices run straight into the last dimension and the closing `]` below is
            // left unmatched.
            write!(f, ", permutation=[")?;
            for (i, p) in perm.iter().enumerate() {
                if i > 0 {
                    write!(f, ", ")?;
                }
                write!(f, "{p}")?;
            }
            write!(f, "]")?;
        }

        write!(f, ")")
    }
}
#[cfg(test)]
mod tests {
    use rstest::rstest;

    use super::*;

    /// Deliberately naive oracle for permuted strides, spelled out in three explicit steps.
    ///
    /// 1. Scatter the logical extents into the physical shape:
    ///    `physical_shape[perm[i]] = logical_shape[i]`.
    /// 2. Compute row-major strides over that physical shape.
    /// 3. Gather them back into logical order:
    ///    `logical_stride[i] = physical_strides[perm[i]]`.
    fn slow_strides(shape: &[usize], perm: &[usize]) -> Vec<usize> {
        let rank = shape.len();

        // Step 1: physical shape from logical shape + permutation.
        let mut physical_shape = vec![0usize; rank];
        for (logical_dim, &extent) in shape.iter().enumerate() {
            physical_shape[perm[logical_dim]] = extent;
        }

        // Step 2: row-major strides, filled from the innermost dimension outwards.
        let mut physical_strides = vec![1usize; rank];
        for axis in (1..rank).rev() {
            physical_strides[axis - 1] = physical_strides[axis] * physical_shape[axis];
        }

        // Step 3: reorder the physical strides into logical dimension order.
        perm[..rank]
            .iter()
            .map(|&position| physical_strides[position])
            .collect()
    }

    // -- Row-major strides (no permutation) --

    #[rstest]
    #[case::scalar_0d(vec![], vec![])]
    #[case::vector_1d(vec![5], vec![1])]
    #[case::matrix_2d(vec![3, 4], vec![4, 1])]
    #[case::tensor_3d(vec![2, 3, 4], vec![12, 4, 1])]
    #[case::zero_dim(vec![3, 0, 4], vec![0, 4, 1])]
    fn strides_row_major(#[case] shape: Vec<usize>, #[case] expected: Vec<usize>) {
        let metadata = FixedShapeTensorMetadata::new(shape);
        let strides: Vec<usize> = metadata.strides().collect();
        assert_eq!(strides, expected);
    }

    // -- Permuted strides --
    //
    // Every case is compared both with a hand-computed expectation and with the `slow_strides`
    // oracle.

    #[rstest]
    // 2D transpose: physical shape = [4, 3].
    #[case::transpose_2d(vec![3, 4], vec![1, 0], vec![1, 3])]
    // 3D: physical shape = [3, 4, 2].
    #[case::perm_3d_201(vec![2, 3, 4], vec![2, 0, 1], vec![1, 8, 2])]
    // 3D with zero-sized dimension: physical shape = [4, 3, 0].
    #[case::zero_dim(vec![3, 0, 4], vec![1, 2, 0], vec![0, 1, 0])]
    fn strides_permuted(
        #[case] shape: Vec<usize>,
        #[case] perm: Vec<usize>,
        #[case] expected: Vec<usize>,
    ) -> VortexResult<()> {
        let metadata =
            FixedShapeTensorMetadata::new(shape.clone()).with_permutation(perm.clone())?;
        let actual: Vec<usize> = metadata.strides().collect();

        assert_eq!(actual, expected);
        assert_eq!(actual, slow_strides(&shape, &perm));
        Ok(())
    }

    #[test]
    fn strides_identity_permutation_matches_row_major() -> VortexResult<()> {
        let shape = vec![2, 3, 4];
        let plain = FixedShapeTensorMetadata::new(shape.clone());
        let with_identity = FixedShapeTensorMetadata::new(shape).with_permutation(vec![0, 1, 2])?;

        let plain_strides: Vec<usize> = plain.strides().collect();
        let identity_strides: Vec<usize> = with_identity.strides().collect();
        assert_eq!(plain_strides, identity_strides);
        Ok(())
    }

    /// Checks `strides()` against the `slow_strides` oracle over a wider range of shapes and
    /// permutations.
    #[rstest]
    #[case::perm_3d_120(vec![2, 3, 4], vec![1, 2, 0])]
    #[case::perm_3d_021(vec![2, 3, 4], vec![0, 2, 1])]
    #[case::identity_3d(vec![2, 3, 4], vec![0, 1, 2])]
    #[case::zero_lead(vec![0, 3, 4], vec![2, 0, 1])]
    #[case::rev_4d(vec![2, 3, 4, 5], vec![3, 2, 1, 0])]
    #[case::swap_4d(vec![2, 3, 4, 5], vec![1, 0, 3, 2])]
    #[case::half_4d(vec![2, 3, 4, 5], vec![2, 3, 0, 1])]
    fn strides_match_slow_reference(
        #[case] shape: Vec<usize>,
        #[case] perm: Vec<usize>,
    ) -> VortexResult<()> {
        let metadata =
            FixedShapeTensorMetadata::new(shape.clone()).with_permutation(perm.clone())?;
        let fast: Vec<usize> = metadata.strides().collect();
        assert_eq!(fast, slow_strides(&shape, &perm));
        Ok(())
    }

    // -- Physical shape --

    #[test]
    fn physical_shape_no_permutation() {
        let metadata = FixedShapeTensorMetadata::new(vec![2, 3, 4]);
        let physical: Vec<usize> = metadata.physical_shape().collect();
        assert_eq!(physical, vec![2, 3, 4]);
    }

    #[rstest]
    // Logical [3, 4] with perm [1, 0] → physical [4, 3].
    #[case::transpose_2d(vec![3, 4], vec![1, 0], vec![4, 3])]
    // Logical [2, 3, 4] with perm [2, 0, 1] → physical [3, 4, 2].
    #[case::perm_3d(vec![2, 3, 4], vec![2, 0, 1], vec![3, 4, 2])]
    // Identity: physical = logical.
    #[case::identity(vec![2, 3, 4], vec![0, 1, 2], vec![2, 3, 4])]
    // Logical [3, 0, 4] with perm [1, 2, 0] → physical [4, 3, 0].
    #[case::zero_dim(vec![3, 0, 4], vec![1, 2, 0], vec![4, 3, 0])]
    fn physical_shape_permuted(
        #[case] shape: Vec<usize>,
        #[case] perm: Vec<usize>,
        #[case] expected: Vec<usize>,
    ) -> VortexResult<()> {
        let metadata = FixedShapeTensorMetadata::new(shape).with_permutation(perm)?;
        let physical: Vec<usize> = metadata.physical_shape().collect();
        assert_eq!(physical, expected);
        Ok(())
    }

    // -- Builder validation --

    #[test]
    fn dim_names_wrong_length() {
        // One name for two dimensions must be rejected.
        let outcome = FixedShapeTensorMetadata::new(vec![2, 3]).with_dim_names(vec!["x".into()]);
        assert!(outcome.is_err());
    }

    #[test]
    fn permutation_wrong_length() {
        // One index for two dimensions must be rejected.
        let outcome = FixedShapeTensorMetadata::new(vec![2, 3]).with_permutation(vec![0]);
        assert!(outcome.is_err());
    }

    #[test]
    fn permutation_out_of_range() {
        // Index 5 does not exist in a rank-2 tensor.
        let outcome = FixedShapeTensorMetadata::new(vec![2, 3]).with_permutation(vec![0, 5]);
        assert!(outcome.is_err());
    }

    #[test]
    fn permutation_duplicate_index() {
        // Index 0 appears twice, so this is not a permutation.
        let outcome = FixedShapeTensorMetadata::new(vec![2, 3]).with_permutation(vec![0, 0]);
        assert!(outcome.is_err());
    }
}