// crates/axonml-tensor/src/shape.rs

//! Shape and stride utilities for tensor dimension management.
//!
//! Provides the `Shape` (SmallVec-backed dim list) and `Strides` types,
//! contiguous stride computation, `broadcast_shape` / `broadcast_strides`
//! (NumPy rules), `reshape` (with -1 inference), `transpose_shape` /
//! `transpose_strides`, `squeeze` / `unsqueeze`, `normalize_dim` (negative
//! index), `linear_index` (multi-dim → flat offset via strides), `numel`,
//! and an `is_contiguous` check.
//!
//! # File
//! `crates/axonml-tensor/src/shape.rs`
//!
//! # Author
//! Andrew Jewell Sr. — AutomataNexus LLC
//! ORCID: 0009-0005-2158-7060
//!
//! # Updated
//! April 14, 2026 11:15 PM EST
//!
//! # Disclaimer
//! Use at own risk. This software is provided "as is", without warranty of any
//! kind, express or implied. The author and AutomataNexus shall not be held
//! liable for any damages arising from the use of this software.

use smallvec::SmallVec;

use axonml_core::error::{Error, Result};
28
29// =============================================================================
30// Type Aliases
31// =============================================================================
32
33/// Shape type - dimensions of a tensor.
34/// Uses `SmallVec` for stack allocation of small shapes (up to 6 dimensions).
35pub type Shape = SmallVec<[usize; 6]>;
36
37/// Strides type - step sizes for each dimension.
38pub type Strides = SmallVec<[isize; 6]>;
39
// =============================================================================
// Shape Utilities
// =============================================================================

/// Returns the total number of elements described by `shape`.
///
/// The count is the product of all dimension sizes; an empty (scalar)
/// shape yields 1, and any zero-sized dimension yields 0.
#[must_use]
pub fn numel(shape: &[usize]) -> usize {
    shape.iter().copied().product()
}
55
56/// Computes row-major (C-order) strides for a shape.
57///
58/// # Arguments
59/// * `shape` - The tensor shape
60///
61/// # Returns
62/// Strides for contiguous row-major layout.
63#[must_use]
64pub fn contiguous_strides(shape: &[usize]) -> Strides {
65    if shape.is_empty() {
66        return Strides::new();
67    }
68
69    let mut strides = Strides::with_capacity(shape.len());
70    let mut stride = 1isize;
71
72    // Compute strides from right to left
73    for &dim in shape.iter().rev() {
74        strides.push(stride);
75        stride *= dim as isize;
76    }
77
78    strides.reverse();
79    strides
80}
81
82/// Checks if strides represent a contiguous memory layout.
83///
84/// # Arguments
85/// * `shape` - The tensor shape
86/// * `strides` - The tensor strides
87///
88/// # Returns
89/// True if the tensor is contiguous in memory.
90#[must_use]
91pub fn is_contiguous(shape: &[usize], strides: &[isize]) -> bool {
92    if shape.is_empty() {
93        return true;
94    }
95
96    let expected = contiguous_strides(shape);
97    strides == expected.as_slice()
98}
99
/// Maps multi-dimensional `indices` to a flat storage offset via `strides`.
///
/// The offset is the dot product of indices and strides. Both slices must
/// have the same length (checked in debug builds only). The signed
/// accumulator tolerates negative strides as long as the final offset is
/// non-negative.
#[must_use]
pub fn linear_index(indices: &[usize], strides: &[isize]) -> usize {
    debug_assert_eq!(indices.len(), strides.len());

    let offset: isize = indices
        .iter()
        .zip(strides)
        .map(|(&idx, &stride)| idx as isize * stride)
        .sum();
    offset as usize
}
118
/// Converts a flat index into per-dimension indices for `shape`.
///
/// Inverse of `linear_index` for a contiguous row-major layout.
///
/// # Arguments
/// * `linear` - Flat index; should be `< numel(shape)` for a meaningful
///   result (larger values wrap modulo each dimension)
/// * `shape` - Tensor shape; must not contain a zero-sized dimension
///
/// # Returns
/// Multi-dimensional indices, one per dimension of `shape`.
///
/// # Panics
/// Panics (remainder by zero) if `shape` contains a 0; in debug builds the
/// assertion below reports this precondition explicitly instead.
#[must_use]
pub fn unravel_index(mut linear: usize, shape: &[usize]) -> Vec<usize> {
    // A zero-sized dimension has no valid indices; without this check the
    // `linear % dim` below dies with a cryptic divide-by-zero panic.
    debug_assert!(
        shape.iter().all(|&d| d != 0),
        "unravel_index: shape must not contain zero-sized dimensions"
    );

    let mut indices = vec![0; shape.len()];

    // Peel off the fastest-varying (rightmost) dimension first.
    for (i, &dim) in shape.iter().enumerate().rev() {
        indices[i] = linear % dim;
        linear /= dim;
    }

    indices
}
138
139// =============================================================================
140// Broadcasting
141// =============================================================================
142
143/// Computes the broadcast shape of two shapes.
144///
145/// Broadcasting follows `NumPy` rules:
146/// 1. Shapes are aligned from the right
147/// 2. Dimensions are compatible if equal or one of them is 1
148/// 3. Missing dimensions are treated as 1
149///
150/// # Arguments
151/// * `shape1` - First shape
152/// * `shape2` - Second shape
153///
154/// # Returns
155/// Broadcast shape, or error if shapes are incompatible.
156pub fn broadcast_shape(shape1: &[usize], shape2: &[usize]) -> Result<Shape> {
157    let max_ndim = shape1.len().max(shape2.len());
158    let mut result = Shape::with_capacity(max_ndim);
159
160    // Iterate from right to left
161    for i in 0..max_ndim {
162        let d1 = if i < shape1.len() {
163            shape1[shape1.len() - 1 - i]
164        } else {
165            1
166        };
167
168        let d2 = if i < shape2.len() {
169            shape2[shape2.len() - 1 - i]
170        } else {
171            1
172        };
173
174        if d1 == d2 {
175            result.push(d1);
176        } else if d1 == 1 {
177            result.push(d2);
178        } else if d2 == 1 {
179            result.push(d1);
180        } else {
181            return Err(Error::BroadcastError {
182                shape1: shape1.to_vec(),
183                shape2: shape2.to_vec(),
184            });
185        }
186    }
187
188    result.reverse();
189    Ok(result)
190}
191
192/// Computes broadcast strides for a shape to match a target shape.
193///
194/// # Arguments
195/// * `shape` - Original shape
196/// * `strides` - Original strides
197/// * `target_shape` - Target broadcast shape
198///
199/// # Returns
200/// New strides for broadcasting (0 stride for broadcast dimensions).
201#[must_use]
202pub fn broadcast_strides(shape: &[usize], strides: &[isize], target_shape: &[usize]) -> Strides {
203    let mut result = Strides::with_capacity(target_shape.len());
204    let shape_offset = target_shape.len() - shape.len();
205
206    for (i, &target_dim) in target_shape.iter().enumerate() {
207        if i < shape_offset {
208            // Dimension doesn't exist in original - broadcast
209            result.push(0);
210        } else {
211            let orig_idx = i - shape_offset;
212            let orig_dim = shape[orig_idx];
213
214            if orig_dim == target_dim {
215                result.push(strides[orig_idx]);
216            } else if orig_dim == 1 {
217                // Broadcast dimension
218                result.push(0);
219            } else {
220                // Should not happen if broadcast_shape was computed correctly
221                result.push(strides[orig_idx]);
222            }
223        }
224    }
225
226    result
227}
228
229/// Checks if two shapes are broadcastable.
230#[must_use]
231pub fn can_broadcast(shape1: &[usize], shape2: &[usize]) -> bool {
232    broadcast_shape(shape1, shape2).is_ok()
233}
234
235// =============================================================================
236// Shape Manipulation
237// =============================================================================
238
239/// Reshapes a tensor shape, validating that total elements match.
240///
241/// Supports -1 in one dimension to infer the size.
242///
243/// # Arguments
244/// * `old_shape` - Current shape
245/// * `new_shape` - Target shape (can contain -1)
246///
247/// # Returns
248/// Resolved shape, or error if incompatible.
249pub fn reshape(old_shape: &[usize], new_shape: &[isize]) -> Result<Shape> {
250    let old_numel = numel(old_shape);
251    let mut result = Shape::with_capacity(new_shape.len());
252    let mut infer_idx = None;
253    let mut known_numel = 1usize;
254
255    for (i, &dim) in new_shape.iter().enumerate() {
256        if dim == -1 {
257            if infer_idx.is_some() {
258                return Err(Error::invalid_operation("Can only have one -1 in reshape"));
259            }
260            infer_idx = Some(i);
261            result.push(0); // Will be inferred on line 269
262        } else if dim < 0 {
263            return Err(Error::invalid_operation("Invalid dimension in reshape"));
264        } else {
265            let d = dim as usize;
266            known_numel *= d;
267            result.push(d);
268        }
269    }
270
271    if let Some(idx) = infer_idx {
272        if old_numel % known_numel != 0 {
273            return Err(Error::invalid_operation(
274                "Cannot infer dimension: not evenly divisible",
275            ));
276        }
277        result[idx] = old_numel / known_numel;
278    } else if known_numel != old_numel {
279        return Err(Error::shape_mismatch(old_shape, &result));
280    }
281
282    Ok(result)
283}
284
285/// Computes the shape after squeezing (removing dimensions of size 1).
286///
287/// # Arguments
288/// * `shape` - Input shape
289/// * `dim` - Optional dimension to squeeze (None = all)
290///
291/// # Returns
292/// Squeezed shape.
293#[must_use]
294pub fn squeeze(shape: &[usize], dim: Option<usize>) -> Shape {
295    match dim {
296        Some(d) => {
297            let mut result = Shape::from_slice(shape);
298            if d < shape.len() && shape[d] == 1 {
299                result.remove(d);
300            }
301            result
302        }
303        None => shape.iter().copied().filter(|&d| d != 1).collect(),
304    }
305}
306
307/// Computes the shape after unsqueezing (adding a dimension of size 1).
308///
309/// # Arguments
310/// * `shape` - Input shape
311/// * `dim` - Dimension at which to insert
312///
313/// # Returns
314/// Unsqueezed shape, or error if dim is invalid.
315pub fn unsqueeze(shape: &[usize], dim: usize) -> Result<Shape> {
316    if dim > shape.len() {
317        return Err(Error::InvalidDimension {
318            index: dim as i64,
319            ndim: shape.len(),
320        });
321    }
322
323    let mut result = Shape::with_capacity(shape.len() + 1);
324    result.extend_from_slice(&shape[..dim]);
325    result.push(1);
326    result.extend_from_slice(&shape[dim..]);
327    Ok(result)
328}
329
330/// Computes the shape after transposing dimensions.
331///
332/// # Arguments
333/// * `shape` - Input shape
334/// * `dim0` - First dimension
335/// * `dim1` - Second dimension
336///
337/// # Returns
338/// Transposed shape and strides modifier.
339pub fn transpose_shape(shape: &[usize], dim0: usize, dim1: usize) -> Result<Shape> {
340    if dim0 >= shape.len() || dim1 >= shape.len() {
341        return Err(Error::InvalidDimension {
342            index: dim0.max(dim1) as i64,
343            ndim: shape.len(),
344        });
345    }
346
347    let mut result = Shape::from_slice(shape);
348    result.swap(dim0, dim1);
349    Ok(result)
350}
351
352/// Swaps two stride values.
353#[must_use]
354pub fn transpose_strides(strides: &[isize], dim0: usize, dim1: usize) -> Strides {
355    let mut result = Strides::from_slice(strides);
356    result.swap(dim0, dim1);
357    result
358}
359
360// =============================================================================
361// Validation
362// =============================================================================
363
364/// Normalizes a dimension index, supporting negative indexing.
365///
366/// # Arguments
367/// * `dim` - Dimension index (can be negative)
368/// * `ndim` - Number of dimensions
369///
370/// # Returns
371/// Normalized positive index, or error if out of bounds.
372pub fn normalize_dim(dim: i64, ndim: usize) -> Result<usize> {
373    let ndim_i64 = ndim as i64;
374
375    let normalized = if dim < 0 { dim + ndim_i64 } else { dim };
376
377    if normalized < 0 || normalized >= ndim_i64 {
378        return Err(Error::InvalidDimension { index: dim, ndim });
379    }
380
381    Ok(normalized as usize)
382}
383
384/// Validates that indices are within bounds for a shape.
385pub fn validate_indices(indices: &[usize], shape: &[usize]) -> Result<()> {
386    if indices.len() != shape.len() {
387        return Err(Error::invalid_operation(format!(
388            "Expected {} indices, got {}",
389            shape.len(),
390            indices.len()
391        )));
392    }
393
394    for (&idx, &dim) in indices.iter().zip(shape.iter()) {
395        if idx >= dim {
396            return Err(Error::IndexOutOfBounds {
397                index: idx,
398                size: dim,
399            });
400        }
401    }
402
403    Ok(())
404}
405
// =============================================================================
// Tests
// =============================================================================

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_numel() {
        assert_eq!(numel(&[2, 3, 4]), 24);
        assert_eq!(numel(&[]), 1); // scalar: empty product
        assert_eq!(numel(&[5]), 5);
    }

    #[test]
    fn test_contiguous_strides() {
        assert_eq!(contiguous_strides(&[2, 3, 4]).as_slice(), &[12, 4, 1]);
    }

    #[test]
    fn test_is_contiguous() {
        let shape = [2, 3];
        assert!(is_contiguous(&shape, &contiguous_strides(&shape)));

        let column_major: Strides = smallvec::smallvec![1, 2];
        assert!(!is_contiguous(&shape, &column_major));
    }

    #[test]
    fn test_broadcast_shape() {
        // (lhs, rhs, expected broadcast result)
        let ok_cases: [(&[usize], &[usize], &[usize]); 4] = [
            (&[2, 3], &[2, 3], &[2, 3]),
            (&[2, 3], &[3], &[2, 3]),
            (&[2, 1], &[1, 3], &[2, 3]),
            (&[5, 1, 3], &[2, 3], &[5, 2, 3]),
        ];
        for (lhs, rhs, expected) in ok_cases {
            assert_eq!(broadcast_shape(lhs, rhs).unwrap().as_slice(), expected);
        }

        // Mismatched non-1 dimensions cannot broadcast.
        assert!(broadcast_shape(&[2, 3], &[2, 4]).is_err());
    }

    #[test]
    fn test_reshape() {
        let old = [2, 3, 4];

        assert_eq!(reshape(&old, &[6, 4]).unwrap().as_slice(), &[6, 4]);
        // -1 infers the missing dimension.
        assert_eq!(reshape(&old, &[-1, 4]).unwrap().as_slice(), &[6, 4]);
        // Element counts must match.
        assert!(reshape(&old, &[5, 5]).is_err());
    }

    #[test]
    fn test_squeeze() {
        let shape = [1, 2, 1, 3, 1];

        assert_eq!(squeeze(&shape, None).as_slice(), &[2, 3]);
        assert_eq!(squeeze(&shape, Some(0)).as_slice(), &[2, 1, 3, 1]);
    }

    #[test]
    fn test_unsqueeze() {
        let shape = [2, 3];

        assert_eq!(unsqueeze(&shape, 0).unwrap().as_slice(), &[1, 2, 3]);
        assert_eq!(unsqueeze(&shape, 1).unwrap().as_slice(), &[2, 1, 3]);
        assert_eq!(unsqueeze(&shape, 2).unwrap().as_slice(), &[2, 3, 1]);
    }

    #[test]
    fn test_normalize_dim() {
        assert_eq!(normalize_dim(0, 3).unwrap(), 0);
        assert_eq!(normalize_dim(-1, 3).unwrap(), 2);
        assert_eq!(normalize_dim(-3, 3).unwrap(), 0);

        assert!(normalize_dim(3, 3).is_err());
        assert!(normalize_dim(-4, 3).is_err());
    }

    #[test]
    fn test_linear_index() {
        // Row-major 2x3 matrix.
        let strides: Strides = smallvec::smallvec![3, 1];

        assert_eq!(linear_index(&[0, 0], &strides), 0);
        assert_eq!(linear_index(&[0, 1], &strides), 1);
        assert_eq!(linear_index(&[1, 0], &strides), 3);
        assert_eq!(linear_index(&[1, 2], &strides), 5);
    }

    #[test]
    fn test_unravel_index() {
        let shape = [2, 3, 4];

        assert_eq!(unravel_index(0, &shape), vec![0, 0, 0]);
        assert_eq!(unravel_index(1, &shape), vec![0, 0, 1]);
        assert_eq!(unravel_index(4, &shape), vec![0, 1, 0]);
        assert_eq!(unravel_index(12, &shape), vec![1, 0, 0]);
    }
}