//! Shape and Strides - Tensor Dimension Management
//!
//! # File
//! `crates/axonml-tensor/src/shape.rs`
//!
//! # Author
//! Andrew Jewell Sr - AutomataNexus
//!
//! # Updated
//! March 8, 2026
//!
//! # Disclaimer
//! Use at own risk. This software is provided "as is", without warranty of any
//! kind, express or implied. The author and AutomataNexus shall not be held
//! liable for any damages arising from the use of this software.

use smallvec::SmallVec;

use axonml_core::error::{Error, Result};
// =============================================================================
// Type Aliases
// =============================================================================

/// Shape type - dimensions of a tensor.
/// Uses `SmallVec` for stack allocation of small shapes (up to 6 dimensions);
/// shapes with more dimensions transparently spill to the heap.
pub type Shape = SmallVec<[usize; 6]>;

/// Strides type - step sizes for each dimension.
/// Strides are measured in elements, not bytes (the innermost contiguous
/// stride is 1), and are signed so non-standard layouts can be represented.
pub type Strides = SmallVec<[isize; 6]>;

// =============================================================================
// Shape Utilities
// =============================================================================

/// Computes the total number of elements from a shape.
///
/// # Arguments
/// * `shape` - The tensor shape
///
/// # Returns
/// Total number of elements (product of all dimensions). The empty, rank-0
/// shape yields 1, since the empty product is the multiplicative identity.
#[must_use]
pub fn numel(shape: &[usize]) -> usize {
    shape.iter().copied().fold(1, |count, dim| count * dim)
}

48/// Computes row-major (C-order) strides for a shape.
49///
50/// # Arguments
51/// * `shape` - The tensor shape
52///
53/// # Returns
54/// Strides for contiguous row-major layout.
55#[must_use]
56pub fn contiguous_strides(shape: &[usize]) -> Strides {
57    if shape.is_empty() {
58        return Strides::new();
59    }
60
61    let mut strides = Strides::with_capacity(shape.len());
62    let mut stride = 1isize;
63
64    // Compute strides from right to left
65    for &dim in shape.iter().rev() {
66        strides.push(stride);
67        stride *= dim as isize;
68    }
69
70    strides.reverse();
71    strides
72}
73
/// Checks if strides represent a contiguous row-major memory layout.
///
/// # Arguments
/// * `shape` - The tensor shape
/// * `strides` - The tensor strides
///
/// # Returns
/// True if the tensor is contiguous in memory.
#[must_use]
pub fn is_contiguous(shape: &[usize], strides: &[isize]) -> bool {
    // Rank-0 tensors are trivially contiguous (mirrors contiguous_strides,
    // which produces empty strides for an empty shape).
    if shape.is_empty() {
        return true;
    }

    // A strides slice of a different length can never match.
    if strides.len() != shape.len() {
        return false;
    }

    // Walk right-to-left, comparing each stride against the expected
    // row-major value — no temporary strides buffer needed.
    let mut expected = 1isize;
    for (&dim, &stride) in shape.iter().zip(strides).rev() {
        if stride != expected {
            return false;
        }
        expected *= dim as isize;
    }

    true
}

/// Computes the linear index from multi-dimensional indices.
///
/// # Arguments
/// * `indices` - Multi-dimensional indices
/// * `strides` - Tensor strides
///
/// # Returns
/// Linear offset into storage (dot product of indices and strides).
#[must_use]
pub fn linear_index(indices: &[usize], strides: &[isize]) -> usize {
    debug_assert_eq!(indices.len(), strides.len());

    let offset: isize = indices
        .iter()
        .zip(strides)
        .map(|(&idx, &stride)| idx as isize * stride)
        .sum();

    offset as usize
}

/// Converts a linear index to multi-dimensional indices.
///
/// # Arguments
/// * `linear` - Linear index
/// * `shape` - Tensor shape
///
/// # Returns
/// Multi-dimensional indices (row-major decomposition of `linear`).
///
/// NOTE(review): a zero-sized dimension would divide by zero here — callers
/// are expected to pass shapes with only positive extents.
#[must_use]
pub fn unravel_index(mut linear: usize, shape: &[usize]) -> Vec<usize> {
    let mut coords = vec![0usize; shape.len()];

    // Peel off the innermost dimension first: remainder is the coordinate,
    // quotient carries into the next-outer dimension.
    for (slot, &extent) in coords.iter_mut().zip(shape.iter()).rev() {
        *slot = linear % extent;
        linear /= extent;
    }

    coords
}

131// =============================================================================
132// Broadcasting
133// =============================================================================
134
135/// Computes the broadcast shape of two shapes.
136///
137/// Broadcasting follows `NumPy` rules:
138/// 1. Shapes are aligned from the right
139/// 2. Dimensions are compatible if equal or one of them is 1
140/// 3. Missing dimensions are treated as 1
141///
142/// # Arguments
143/// * `shape1` - First shape
144/// * `shape2` - Second shape
145///
146/// # Returns
147/// Broadcast shape, or error if shapes are incompatible.
148pub fn broadcast_shape(shape1: &[usize], shape2: &[usize]) -> Result<Shape> {
149    let max_ndim = shape1.len().max(shape2.len());
150    let mut result = Shape::with_capacity(max_ndim);
151
152    // Iterate from right to left
153    for i in 0..max_ndim {
154        let d1 = if i < shape1.len() {
155            shape1[shape1.len() - 1 - i]
156        } else {
157            1
158        };
159
160        let d2 = if i < shape2.len() {
161            shape2[shape2.len() - 1 - i]
162        } else {
163            1
164        };
165
166        if d1 == d2 {
167            result.push(d1);
168        } else if d1 == 1 {
169            result.push(d2);
170        } else if d2 == 1 {
171            result.push(d1);
172        } else {
173            return Err(Error::BroadcastError {
174                shape1: shape1.to_vec(),
175                shape2: shape2.to_vec(),
176            });
177        }
178    }
179
180    result.reverse();
181    Ok(result)
182}
183
184/// Computes broadcast strides for a shape to match a target shape.
185///
186/// # Arguments
187/// * `shape` - Original shape
188/// * `strides` - Original strides
189/// * `target_shape` - Target broadcast shape
190///
191/// # Returns
192/// New strides for broadcasting (0 stride for broadcast dimensions).
193#[must_use]
194pub fn broadcast_strides(shape: &[usize], strides: &[isize], target_shape: &[usize]) -> Strides {
195    let mut result = Strides::with_capacity(target_shape.len());
196    let shape_offset = target_shape.len() - shape.len();
197
198    for (i, &target_dim) in target_shape.iter().enumerate() {
199        if i < shape_offset {
200            // Dimension doesn't exist in original - broadcast
201            result.push(0);
202        } else {
203            let orig_idx = i - shape_offset;
204            let orig_dim = shape[orig_idx];
205
206            if orig_dim == target_dim {
207                result.push(strides[orig_idx]);
208            } else if orig_dim == 1 {
209                // Broadcast dimension
210                result.push(0);
211            } else {
212                // Should not happen if broadcast_shape was computed correctly
213                result.push(strides[orig_idx]);
214            }
215        }
216    }
217
218    result
219}
220
221/// Checks if two shapes are broadcastable.
222#[must_use]
223pub fn can_broadcast(shape1: &[usize], shape2: &[usize]) -> bool {
224    broadcast_shape(shape1, shape2).is_ok()
225}
226
227// =============================================================================
228// Shape Manipulation
229// =============================================================================
230
231/// Reshapes a tensor shape, validating that total elements match.
232///
233/// Supports -1 in one dimension to infer the size.
234///
235/// # Arguments
236/// * `old_shape` - Current shape
237/// * `new_shape` - Target shape (can contain -1)
238///
239/// # Returns
240/// Resolved shape, or error if incompatible.
241pub fn reshape(old_shape: &[usize], new_shape: &[isize]) -> Result<Shape> {
242    let old_numel = numel(old_shape);
243    let mut result = Shape::with_capacity(new_shape.len());
244    let mut infer_idx = None;
245    let mut known_numel = 1usize;
246
247    for (i, &dim) in new_shape.iter().enumerate() {
248        if dim == -1 {
249            if infer_idx.is_some() {
250                return Err(Error::invalid_operation("Can only have one -1 in reshape"));
251            }
252            infer_idx = Some(i);
253            result.push(0); // Placeholder
254        } else if dim < 0 {
255            return Err(Error::invalid_operation("Invalid dimension in reshape"));
256        } else {
257            let d = dim as usize;
258            known_numel *= d;
259            result.push(d);
260        }
261    }
262
263    if let Some(idx) = infer_idx {
264        if old_numel % known_numel != 0 {
265            return Err(Error::invalid_operation(
266                "Cannot infer dimension: not evenly divisible",
267            ));
268        }
269        result[idx] = old_numel / known_numel;
270    } else if known_numel != old_numel {
271        return Err(Error::shape_mismatch(old_shape, &result));
272    }
273
274    Ok(result)
275}
276
277/// Computes the shape after squeezing (removing dimensions of size 1).
278///
279/// # Arguments
280/// * `shape` - Input shape
281/// * `dim` - Optional dimension to squeeze (None = all)
282///
283/// # Returns
284/// Squeezed shape.
285#[must_use]
286pub fn squeeze(shape: &[usize], dim: Option<usize>) -> Shape {
287    match dim {
288        Some(d) => {
289            let mut result = Shape::from_slice(shape);
290            if d < shape.len() && shape[d] == 1 {
291                result.remove(d);
292            }
293            result
294        }
295        None => shape.iter().copied().filter(|&d| d != 1).collect(),
296    }
297}
298
299/// Computes the shape after unsqueezing (adding a dimension of size 1).
300///
301/// # Arguments
302/// * `shape` - Input shape
303/// * `dim` - Dimension at which to insert
304///
305/// # Returns
306/// Unsqueezed shape, or error if dim is invalid.
307pub fn unsqueeze(shape: &[usize], dim: usize) -> Result<Shape> {
308    if dim > shape.len() {
309        return Err(Error::InvalidDimension {
310            index: dim as i64,
311            ndim: shape.len(),
312        });
313    }
314
315    let mut result = Shape::with_capacity(shape.len() + 1);
316    result.extend_from_slice(&shape[..dim]);
317    result.push(1);
318    result.extend_from_slice(&shape[dim..]);
319    Ok(result)
320}
321
322/// Computes the shape after transposing dimensions.
323///
324/// # Arguments
325/// * `shape` - Input shape
326/// * `dim0` - First dimension
327/// * `dim1` - Second dimension
328///
329/// # Returns
330/// Transposed shape and strides modifier.
331pub fn transpose_shape(shape: &[usize], dim0: usize, dim1: usize) -> Result<Shape> {
332    if dim0 >= shape.len() || dim1 >= shape.len() {
333        return Err(Error::InvalidDimension {
334            index: dim0.max(dim1) as i64,
335            ndim: shape.len(),
336        });
337    }
338
339    let mut result = Shape::from_slice(shape);
340    result.swap(dim0, dim1);
341    Ok(result)
342}
343
344/// Swaps two stride values.
345#[must_use]
346pub fn transpose_strides(strides: &[isize], dim0: usize, dim1: usize) -> Strides {
347    let mut result = Strides::from_slice(strides);
348    result.swap(dim0, dim1);
349    result
350}
351
352// =============================================================================
353// Validation
354// =============================================================================
355
356/// Normalizes a dimension index, supporting negative indexing.
357///
358/// # Arguments
359/// * `dim` - Dimension index (can be negative)
360/// * `ndim` - Number of dimensions
361///
362/// # Returns
363/// Normalized positive index, or error if out of bounds.
364pub fn normalize_dim(dim: i64, ndim: usize) -> Result<usize> {
365    let ndim_i64 = ndim as i64;
366
367    let normalized = if dim < 0 { dim + ndim_i64 } else { dim };
368
369    if normalized < 0 || normalized >= ndim_i64 {
370        return Err(Error::InvalidDimension { index: dim, ndim });
371    }
372
373    Ok(normalized as usize)
374}
375
376/// Validates that indices are within bounds for a shape.
377pub fn validate_indices(indices: &[usize], shape: &[usize]) -> Result<()> {
378    if indices.len() != shape.len() {
379        return Err(Error::invalid_operation(format!(
380            "Expected {} indices, got {}",
381            shape.len(),
382            indices.len()
383        )));
384    }
385
386    for (&idx, &dim) in indices.iter().zip(shape.iter()) {
387        if idx >= dim {
388            return Err(Error::IndexOutOfBounds {
389                index: idx,
390                size: dim,
391            });
392        }
393    }
394
395    Ok(())
396}
397
// =============================================================================
// Tests
// =============================================================================

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_numel() {
        assert_eq!(numel(&[2, 3, 4]), 24);
        // Rank-0 (scalar) shape: the empty product is 1.
        assert_eq!(numel(&[]), 1);
        assert_eq!(numel(&[5]), 5);
    }

    #[test]
    fn test_contiguous_strides() {
        let shape = [2, 3, 4];
        let strides = contiguous_strides(&shape);
        // Row-major: innermost stride 1, outer strides are running products.
        assert_eq!(strides.as_slice(), &[12, 4, 1]);
    }

    #[test]
    fn test_is_contiguous() {
        let shape = [2, 3];
        let strides = contiguous_strides(&shape);
        assert!(is_contiguous(&shape, &strides));

        let non_contig_strides: Strides = smallvec::smallvec![1, 2];
        assert!(!is_contiguous(&shape, &non_contig_strides));
    }

    #[test]
    fn test_broadcast_shape() {
        // Same shapes
        assert_eq!(
            broadcast_shape(&[2, 3], &[2, 3]).unwrap().as_slice(),
            &[2, 3]
        );

        // Broadcasting
        assert_eq!(broadcast_shape(&[2, 3], &[3]).unwrap().as_slice(), &[2, 3]);

        assert_eq!(
            broadcast_shape(&[2, 1], &[1, 3]).unwrap().as_slice(),
            &[2, 3]
        );

        assert_eq!(
            broadcast_shape(&[5, 1, 3], &[2, 3]).unwrap().as_slice(),
            &[5, 2, 3]
        );

        // Incompatible
        assert!(broadcast_shape(&[2, 3], &[2, 4]).is_err());
    }

    #[test]
    fn test_reshape() {
        let old_shape = [2, 3, 4];

        // Simple reshape
        let new = reshape(&old_shape, &[6, 4]).unwrap();
        assert_eq!(new.as_slice(), &[6, 4]);

        // With -1 inference
        let new = reshape(&old_shape, &[-1, 4]).unwrap();
        assert_eq!(new.as_slice(), &[6, 4]);

        // Invalid
        assert!(reshape(&old_shape, &[5, 5]).is_err());
    }

    #[test]
    fn test_squeeze() {
        let shape = [1, 2, 1, 3, 1];

        // Squeeze all
        let squeezed = squeeze(&shape, None);
        assert_eq!(squeezed.as_slice(), &[2, 3]);

        // Squeeze specific dimension
        let squeezed = squeeze(&shape, Some(0));
        assert_eq!(squeezed.as_slice(), &[2, 1, 3, 1]);
    }

    #[test]
    fn test_unsqueeze() {
        let shape = [2, 3];

        let unsqueezed = unsqueeze(&shape, 0).unwrap();
        assert_eq!(unsqueezed.as_slice(), &[1, 2, 3]);

        let unsqueezed = unsqueeze(&shape, 1).unwrap();
        assert_eq!(unsqueezed.as_slice(), &[2, 1, 3]);

        // dim == shape.len() appends a trailing axis.
        let unsqueezed = unsqueeze(&shape, 2).unwrap();
        assert_eq!(unsqueezed.as_slice(), &[2, 3, 1]);
    }

    #[test]
    fn test_normalize_dim() {
        assert_eq!(normalize_dim(0, 3).unwrap(), 0);
        // Negative indexing counts from the end.
        assert_eq!(normalize_dim(-1, 3).unwrap(), 2);
        assert_eq!(normalize_dim(-3, 3).unwrap(), 0);

        assert!(normalize_dim(3, 3).is_err());
        assert!(normalize_dim(-4, 3).is_err());
    }

    #[test]
    fn test_linear_index() {
        // 2x3 matrix, row-major
        let strides: Strides = smallvec::smallvec![3, 1];

        assert_eq!(linear_index(&[0, 0], &strides), 0);
        assert_eq!(linear_index(&[0, 1], &strides), 1);
        assert_eq!(linear_index(&[1, 0], &strides), 3);
        assert_eq!(linear_index(&[1, 2], &strides), 5);
    }

    #[test]
    fn test_unravel_index() {
        let shape = [2, 3, 4];

        assert_eq!(unravel_index(0, &shape), vec![0, 0, 0]);
        assert_eq!(unravel_index(1, &shape), vec![0, 0, 1]);
        assert_eq!(unravel_index(4, &shape), vec![0, 1, 0]);
        assert_eq!(unravel_index(12, &shape), vec![1, 0, 0]);
    }
}