// axonml_tensor/shape.rs
//! Shape and Strides - Tensor Dimension Management
//!
//! Provides types and functions for managing tensor shapes, strides, and
//! broadcasting rules. Shapes define the dimensions of a tensor, while
//! strides define how to traverse the underlying storage.
//!
//! # Key Features
//! - Efficient shape representation with small-vector optimization
//! - Stride computation for contiguous and transposed layouts
//! - Broadcasting support following `NumPy` rules
//! - Shape validation and manipulation
//!
//! @version 0.1.0
//! @author `AutomataNexus` Development Team

16use smallvec::SmallVec;
17
18use axonml_core::error::{Error, Result};
19
// =============================================================================
// Type Aliases
// =============================================================================
23
/// Shape type - dimensions of a tensor.
/// Uses `SmallVec` for stack allocation of small shapes (up to 6 dimensions);
/// higher-rank shapes spill to the heap transparently.
pub type Shape = SmallVec<[usize; 6]>;

/// Strides type - step sizes for each dimension, in elements.
/// Signed (`isize`); a stride of 0 marks a broadcast dimension (see
/// `broadcast_strides`), so all indices along it map to the same element.
pub type Strides = SmallVec<[isize; 6]>;
30
// =============================================================================
// Shape Utilities
// =============================================================================
34
/// Computes the total number of elements from a shape.
///
/// # Arguments
/// * `shape` - The tensor shape
///
/// # Returns
/// Total number of elements (product of dimensions); an empty (0-d) shape
/// yields 1, following the empty-product convention.
#[must_use]
pub fn numel(shape: &[usize]) -> usize {
    let mut total = 1usize;
    for &dim in shape {
        total *= dim;
    }
    total
}
46
47/// Computes row-major (C-order) strides for a shape.
48///
49/// # Arguments
50/// * `shape` - The tensor shape
51///
52/// # Returns
53/// Strides for contiguous row-major layout.
54#[must_use]
55pub fn contiguous_strides(shape: &[usize]) -> Strides {
56    if shape.is_empty() {
57        return Strides::new();
58    }
59
60    let mut strides = Strides::with_capacity(shape.len());
61    let mut stride = 1isize;
62
63    // Compute strides from right to left
64    for &dim in shape.iter().rev() {
65        strides.push(stride);
66        stride *= dim as isize;
67    }
68
69    strides.reverse();
70    strides
71}
72
/// Checks if strides represent a contiguous (row-major) memory layout.
///
/// # Arguments
/// * `shape` - The tensor shape
/// * `strides` - The tensor strides
///
/// # Returns
/// True if the tensor is contiguous in memory.
#[must_use]
pub fn is_contiguous(shape: &[usize], strides: &[isize]) -> bool {
    // A 0-d (scalar) tensor is trivially contiguous.
    if shape.is_empty() {
        return true;
    }

    // Mismatched rank can never compare equal to the expected strides.
    if strides.len() != shape.len() {
        return false;
    }

    // Walk right-to-left tracking the stride a contiguous layout would have.
    // Equivalent to comparing against contiguous_strides(shape), but without
    // allocating the expected stride vector.
    let mut expected = 1isize;
    for (&dim, &stride) in shape.iter().zip(strides.iter()).rev() {
        if stride != expected {
            return false;
        }
        expected *= dim as isize;
    }
    true
}
90
/// Computes the linear storage offset from multi-dimensional indices.
///
/// # Arguments
/// * `indices` - Multi-dimensional indices
/// * `strides` - Tensor strides
///
/// # Returns
/// Linear offset into storage (sum of index * stride over all dimensions).
#[must_use]
pub fn linear_index(indices: &[usize], strides: &[isize]) -> usize {
    debug_assert_eq!(indices.len(), strides.len());

    let offset: isize = indices
        .iter()
        .zip(strides.iter())
        .map(|(&index, &stride)| index as isize * stride)
        .sum();
    offset as usize
}
109
/// Converts a linear index to multi-dimensional indices.
///
/// # Arguments
/// * `linear` - Linear index
/// * `shape` - Tensor shape
///
/// # Returns
/// Multi-dimensional indices (row-major order).
#[must_use]
pub fn unravel_index(mut linear: usize, shape: &[usize]) -> Vec<usize> {
    let mut coords = Vec::with_capacity(shape.len());

    // Peel off the fastest-varying (rightmost) dimension first.
    for &dim in shape.iter().rev() {
        coords.push(linear % dim);
        linear /= dim;
    }

    coords.reverse();
    coords
}
129
// =============================================================================
// Broadcasting
// =============================================================================
133
134/// Computes the broadcast shape of two shapes.
135///
136/// Broadcasting follows `NumPy` rules:
137/// 1. Shapes are aligned from the right
138/// 2. Dimensions are compatible if equal or one of them is 1
139/// 3. Missing dimensions are treated as 1
140///
141/// # Arguments
142/// * `shape1` - First shape
143/// * `shape2` - Second shape
144///
145/// # Returns
146/// Broadcast shape, or error if shapes are incompatible.
147pub fn broadcast_shape(shape1: &[usize], shape2: &[usize]) -> Result<Shape> {
148    let max_ndim = shape1.len().max(shape2.len());
149    let mut result = Shape::with_capacity(max_ndim);
150
151    // Iterate from right to left
152    for i in 0..max_ndim {
153        let d1 = if i < shape1.len() {
154            shape1[shape1.len() - 1 - i]
155        } else {
156            1
157        };
158
159        let d2 = if i < shape2.len() {
160            shape2[shape2.len() - 1 - i]
161        } else {
162            1
163        };
164
165        if d1 == d2 {
166            result.push(d1);
167        } else if d1 == 1 {
168            result.push(d2);
169        } else if d2 == 1 {
170            result.push(d1);
171        } else {
172            return Err(Error::BroadcastError {
173                shape1: shape1.to_vec(),
174                shape2: shape2.to_vec(),
175            });
176        }
177    }
178
179    result.reverse();
180    Ok(result)
181}
182
183/// Computes broadcast strides for a shape to match a target shape.
184///
185/// # Arguments
186/// * `shape` - Original shape
187/// * `strides` - Original strides
188/// * `target_shape` - Target broadcast shape
189///
190/// # Returns
191/// New strides for broadcasting (0 stride for broadcast dimensions).
192#[must_use]
193pub fn broadcast_strides(shape: &[usize], strides: &[isize], target_shape: &[usize]) -> Strides {
194    let mut result = Strides::with_capacity(target_shape.len());
195    let shape_offset = target_shape.len() - shape.len();
196
197    for (i, &target_dim) in target_shape.iter().enumerate() {
198        if i < shape_offset {
199            // Dimension doesn't exist in original - broadcast
200            result.push(0);
201        } else {
202            let orig_idx = i - shape_offset;
203            let orig_dim = shape[orig_idx];
204
205            if orig_dim == target_dim {
206                result.push(strides[orig_idx]);
207            } else if orig_dim == 1 {
208                // Broadcast dimension
209                result.push(0);
210            } else {
211                // Should not happen if broadcast_shape was computed correctly
212                result.push(strides[orig_idx]);
213            }
214        }
215    }
216
217    result
218}
219
220/// Checks if two shapes are broadcastable.
221#[must_use]
222pub fn can_broadcast(shape1: &[usize], shape2: &[usize]) -> bool {
223    broadcast_shape(shape1, shape2).is_ok()
224}
225
// =============================================================================
// Shape Manipulation
// =============================================================================
229
230/// Reshapes a tensor shape, validating that total elements match.
231///
232/// Supports -1 in one dimension to infer the size.
233///
234/// # Arguments
235/// * `old_shape` - Current shape
236/// * `new_shape` - Target shape (can contain -1)
237///
238/// # Returns
239/// Resolved shape, or error if incompatible.
240pub fn reshape(old_shape: &[usize], new_shape: &[isize]) -> Result<Shape> {
241    let old_numel = numel(old_shape);
242    let mut result = Shape::with_capacity(new_shape.len());
243    let mut infer_idx = None;
244    let mut known_numel = 1usize;
245
246    for (i, &dim) in new_shape.iter().enumerate() {
247        if dim == -1 {
248            if infer_idx.is_some() {
249                return Err(Error::invalid_operation("Can only have one -1 in reshape"));
250            }
251            infer_idx = Some(i);
252            result.push(0); // Placeholder
253        } else if dim < 0 {
254            return Err(Error::invalid_operation("Invalid dimension in reshape"));
255        } else {
256            let d = dim as usize;
257            known_numel *= d;
258            result.push(d);
259        }
260    }
261
262    if let Some(idx) = infer_idx {
263        if old_numel % known_numel != 0 {
264            return Err(Error::invalid_operation(
265                "Cannot infer dimension: not evenly divisible",
266            ));
267        }
268        result[idx] = old_numel / known_numel;
269    } else if known_numel != old_numel {
270        return Err(Error::shape_mismatch(old_shape, &result));
271    }
272
273    Ok(result)
274}
275
276/// Computes the shape after squeezing (removing dimensions of size 1).
277///
278/// # Arguments
279/// * `shape` - Input shape
280/// * `dim` - Optional dimension to squeeze (None = all)
281///
282/// # Returns
283/// Squeezed shape.
284#[must_use]
285pub fn squeeze(shape: &[usize], dim: Option<usize>) -> Shape {
286    match dim {
287        Some(d) => {
288            let mut result = Shape::from_slice(shape);
289            if d < shape.len() && shape[d] == 1 {
290                result.remove(d);
291            }
292            result
293        }
294        None => shape.iter().copied().filter(|&d| d != 1).collect(),
295    }
296}
297
298/// Computes the shape after unsqueezing (adding a dimension of size 1).
299///
300/// # Arguments
301/// * `shape` - Input shape
302/// * `dim` - Dimension at which to insert
303///
304/// # Returns
305/// Unsqueezed shape, or error if dim is invalid.
306pub fn unsqueeze(shape: &[usize], dim: usize) -> Result<Shape> {
307    if dim > shape.len() {
308        return Err(Error::InvalidDimension {
309            index: dim as i64,
310            ndim: shape.len(),
311        });
312    }
313
314    let mut result = Shape::with_capacity(shape.len() + 1);
315    result.extend_from_slice(&shape[..dim]);
316    result.push(1);
317    result.extend_from_slice(&shape[dim..]);
318    Ok(result)
319}
320
321/// Computes the shape after transposing dimensions.
322///
323/// # Arguments
324/// * `shape` - Input shape
325/// * `dim0` - First dimension
326/// * `dim1` - Second dimension
327///
328/// # Returns
329/// Transposed shape and strides modifier.
330pub fn transpose_shape(shape: &[usize], dim0: usize, dim1: usize) -> Result<Shape> {
331    if dim0 >= shape.len() || dim1 >= shape.len() {
332        return Err(Error::InvalidDimension {
333            index: dim0.max(dim1) as i64,
334            ndim: shape.len(),
335        });
336    }
337
338    let mut result = Shape::from_slice(shape);
339    result.swap(dim0, dim1);
340    Ok(result)
341}
342
343/// Swaps two stride values.
344#[must_use]
345pub fn transpose_strides(strides: &[isize], dim0: usize, dim1: usize) -> Strides {
346    let mut result = Strides::from_slice(strides);
347    result.swap(dim0, dim1);
348    result
349}
350
// =============================================================================
// Validation
// =============================================================================
354
355/// Normalizes a dimension index, supporting negative indexing.
356///
357/// # Arguments
358/// * `dim` - Dimension index (can be negative)
359/// * `ndim` - Number of dimensions
360///
361/// # Returns
362/// Normalized positive index, or error if out of bounds.
363pub fn normalize_dim(dim: i64, ndim: usize) -> Result<usize> {
364    let ndim_i64 = ndim as i64;
365
366    let normalized = if dim < 0 { dim + ndim_i64 } else { dim };
367
368    if normalized < 0 || normalized >= ndim_i64 {
369        return Err(Error::InvalidDimension { index: dim, ndim });
370    }
371
372    Ok(normalized as usize)
373}
374
375/// Validates that indices are within bounds for a shape.
376pub fn validate_indices(indices: &[usize], shape: &[usize]) -> Result<()> {
377    if indices.len() != shape.len() {
378        return Err(Error::invalid_operation(format!(
379            "Expected {} indices, got {}",
380            shape.len(),
381            indices.len()
382        )));
383    }
384
385    for (&idx, &dim) in indices.iter().zip(shape.iter()) {
386        if idx >= dim {
387            return Err(Error::IndexOutOfBounds {
388                index: idx,
389                size: dim,
390            });
391        }
392    }
393
394    Ok(())
395}
396
// =============================================================================
// Tests
// =============================================================================
400
#[cfg(test)]
mod tests {
    use super::*;

    // numel is the product of dims; an empty shape is a scalar with 1 element.
    #[test]
    fn test_numel() {
        assert_eq!(numel(&[2, 3, 4]), 24);
        assert_eq!(numel(&[]), 1);
        assert_eq!(numel(&[5]), 5);
    }

    // Row-major layout: rightmost dimension is densest (stride 1).
    #[test]
    fn test_contiguous_strides() {
        let shape = [2, 3, 4];
        let strides = contiguous_strides(&shape);
        assert_eq!(strides.as_slice(), &[12, 4, 1]);
    }

    #[test]
    fn test_is_contiguous() {
        let shape = [2, 3];
        let strides = contiguous_strides(&shape);
        assert!(is_contiguous(&shape, &strides));

        // Column-major-like strides are not C-contiguous.
        let non_contig_strides: Strides = smallvec::smallvec![1, 2];
        assert!(!is_contiguous(&shape, &non_contig_strides));
    }

    #[test]
    fn test_broadcast_shape() {
        // Same shapes
        assert_eq!(
            broadcast_shape(&[2, 3], &[2, 3]).unwrap().as_slice(),
            &[2, 3]
        );

        // Broadcasting
        assert_eq!(broadcast_shape(&[2, 3], &[3]).unwrap().as_slice(), &[2, 3]);

        assert_eq!(
            broadcast_shape(&[2, 1], &[1, 3]).unwrap().as_slice(),
            &[2, 3]
        );

        assert_eq!(
            broadcast_shape(&[5, 1, 3], &[2, 3]).unwrap().as_slice(),
            &[5, 2, 3]
        );

        // Incompatible
        assert!(broadcast_shape(&[2, 3], &[2, 4]).is_err());
    }

    #[test]
    fn test_reshape() {
        let old_shape = [2, 3, 4];

        // Simple reshape
        let new = reshape(&old_shape, &[6, 4]).unwrap();
        assert_eq!(new.as_slice(), &[6, 4]);

        // With -1 inference
        let new = reshape(&old_shape, &[-1, 4]).unwrap();
        assert_eq!(new.as_slice(), &[6, 4]);

        // Invalid
        assert!(reshape(&old_shape, &[5, 5]).is_err());
    }

    #[test]
    fn test_squeeze() {
        let shape = [1, 2, 1, 3, 1];

        // Squeeze all
        let squeezed = squeeze(&shape, None);
        assert_eq!(squeezed.as_slice(), &[2, 3]);

        // Squeeze specific dimension
        let squeezed = squeeze(&shape, Some(0));
        assert_eq!(squeezed.as_slice(), &[2, 1, 3, 1]);
    }

    // Insertion is valid at every position 0..=ndim.
    #[test]
    fn test_unsqueeze() {
        let shape = [2, 3];

        let unsqueezed = unsqueeze(&shape, 0).unwrap();
        assert_eq!(unsqueezed.as_slice(), &[1, 2, 3]);

        let unsqueezed = unsqueeze(&shape, 1).unwrap();
        assert_eq!(unsqueezed.as_slice(), &[2, 1, 3]);

        let unsqueezed = unsqueeze(&shape, 2).unwrap();
        assert_eq!(unsqueezed.as_slice(), &[2, 3, 1]);
    }

    // Negative dims count from the right; out-of-range either way is an error.
    #[test]
    fn test_normalize_dim() {
        assert_eq!(normalize_dim(0, 3).unwrap(), 0);
        assert_eq!(normalize_dim(-1, 3).unwrap(), 2);
        assert_eq!(normalize_dim(-3, 3).unwrap(), 0);

        assert!(normalize_dim(3, 3).is_err());
        assert!(normalize_dim(-4, 3).is_err());
    }

    #[test]
    fn test_linear_index() {
        // 2x3 matrix, row-major
        let strides: Strides = smallvec::smallvec![3, 1];

        assert_eq!(linear_index(&[0, 0], &strides), 0);
        assert_eq!(linear_index(&[0, 1], &strides), 1);
        assert_eq!(linear_index(&[1, 0], &strides), 3);
        assert_eq!(linear_index(&[1, 2], &strides), 5);
    }

    // Inverse of linear_index for contiguous layouts.
    #[test]
    fn test_unravel_index() {
        let shape = [2, 3, 4];

        assert_eq!(unravel_index(0, &shape), vec![0, 0, 0]);
        assert_eq!(unravel_index(1, &shape), vec![0, 0, 1]);
        assert_eq!(unravel_index(4, &shape), vec![0, 1, 0]);
        assert_eq!(unravel_index(12, &shape), vec![1, 0, 0]);
    }
}