axonml_tensor/shape.rs

//! Shape and Strides - Tensor Dimension Management
//!
//! Provides types and functions for managing tensor shapes, strides, and
//! broadcasting rules. Shapes define the dimensions of a tensor, while
//! strides define how to traverse the underlying storage.
//!
//! # Key Features
//! - Efficient shape representation with small-vector optimization
//! - Stride computation for contiguous and transposed layouts
//! - Broadcasting support following `NumPy` rules
//! - Shape validation and manipulation
//!
//! @version 0.1.0
//! @author `AutomataNexus` Development Team

use smallvec::SmallVec;

use axonml_core::error::{Error, Result};

// =============================================================================
// Type Aliases
// =============================================================================

/// Shape type - dimensions of a tensor.
/// Uses `SmallVec` for stack allocation of small shapes (up to 6 dimensions).
pub type Shape = SmallVec<[usize; 6]>;

/// Strides type - step sizes for each dimension.
pub type Strides = SmallVec<[isize; 6]>;

// =============================================================================
// Shape Utilities
// =============================================================================

/// Computes the total number of elements from a shape.
///
/// # Arguments
/// * `shape` - The tensor shape
///
/// # Returns
/// Total number of elements (product of dimensions). An empty shape
/// (a scalar) yields 1, since the empty product is 1.
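///
/// # Examples
/// A minimal doctest sketch (assuming this module is exported as
/// `axonml_tensor::shape`):
/// ```
/// use axonml_tensor::shape::numel;
///
/// assert_eq!(numel(&[2, 3, 4]), 24);
/// assert_eq!(numel(&[]), 1); // empty shape = scalar
/// ```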
#[must_use]
pub fn numel(shape: &[usize]) -> usize {
    shape.iter().product()
}

/// Computes row-major (C-order) strides for a shape.
///
/// # Arguments
/// * `shape` - The tensor shape
///
/// # Returns
/// Strides for contiguous row-major layout.
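///
/// # Examples
/// A small sketch of the row-major rule (module path assumed to be
/// `axonml_tensor::shape`):
/// ```
/// use axonml_tensor::shape::contiguous_strides;
///
/// // The last dimension varies fastest, so it gets stride 1.
/// let s = contiguous_strides(&[2, 3, 4]);
/// assert_eq!(s.as_slice(), &[12, 4, 1]);
/// ```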
#[must_use]
pub fn contiguous_strides(shape: &[usize]) -> Strides {
    if shape.is_empty() {
        return Strides::new();
    }

    let mut strides = Strides::with_capacity(shape.len());
    let mut stride = 1isize;

    // Compute strides from right to left
    for &dim in shape.iter().rev() {
        strides.push(stride);
        stride *= dim as isize;
    }

    strides.reverse();
    strides
}

/// Checks if strides represent a contiguous memory layout.
///
/// # Arguments
/// * `shape` - The tensor shape
/// * `strides` - The tensor strides
///
/// # Returns
/// True if the tensor is contiguous in memory.
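///
/// # Examples
/// A brief sketch (module path assumed):
/// ```
/// use axonml_tensor::shape::{contiguous_strides, is_contiguous};
///
/// let strides = contiguous_strides(&[2, 3]);
/// assert!(is_contiguous(&[2, 3], &strides));
/// assert!(!is_contiguous(&[2, 3], &[1, 2])); // column-major layout
/// ```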
#[must_use]
pub fn is_contiguous(shape: &[usize], strides: &[isize]) -> bool {
    if shape.is_empty() {
        return true;
    }

    let expected = contiguous_strides(shape);
    strides == expected.as_slice()
}

/// Computes the linear index from multi-dimensional indices.
///
/// # Arguments
/// * `indices` - Multi-dimensional indices
/// * `strides` - Tensor strides
///
/// # Returns
/// Linear offset into storage.
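///
/// # Examples
/// A short sketch (module path assumed):
/// ```
/// use axonml_tensor::shape::{contiguous_strides, linear_index};
///
/// // Element [1, 2] of a row-major 2x3 matrix sits at 1 * 3 + 2 = 5.
/// let strides = contiguous_strides(&[2, 3]);
/// assert_eq!(linear_index(&[1, 2], &strides), 5);
/// ```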
#[must_use]
pub fn linear_index(indices: &[usize], strides: &[isize]) -> usize {
    debug_assert_eq!(indices.len(), strides.len());

    let mut offset = 0isize;
    for (&idx, &stride) in indices.iter().zip(strides.iter()) {
        offset += idx as isize * stride;
    }
    offset as usize
}

/// Converts a linear index to multi-dimensional indices.
///
/// Assumes row-major (C-order) ordering of linear indices.
///
/// # Arguments
/// * `linear` - Linear index
/// * `shape` - Tensor shape
///
/// # Returns
/// Multi-dimensional indices.
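///
/// # Examples
/// A quick sketch (module path assumed):
/// ```
/// use axonml_tensor::shape::unravel_index;
///
/// // Offset 5 in a row-major 2x3 matrix is row 1, column 2.
/// assert_eq!(unravel_index(5, &[2, 3]), vec![1, 2]);
/// ```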
#[must_use]
pub fn unravel_index(mut linear: usize, shape: &[usize]) -> Vec<usize> {
    let mut indices = vec![0; shape.len()];

    for (i, &dim) in shape.iter().enumerate().rev() {
        indices[i] = linear % dim;
        linear /= dim;
    }

    indices
}

// =============================================================================
// Broadcasting
// =============================================================================

/// Computes the broadcast shape of two shapes.
///
/// Broadcasting follows `NumPy` rules:
/// 1. Shapes are aligned from the right
/// 2. Dimensions are compatible if equal or one of them is 1
/// 3. Missing dimensions are treated as 1
///
/// # Arguments
/// * `shape1` - First shape
/// * `shape2` - Second shape
///
/// # Returns
/// Broadcast shape, or error if shapes are incompatible.
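///
/// # Examples
/// A small sketch of the rules above (module path assumed):
/// ```
/// use axonml_tensor::shape::broadcast_shape;
///
/// // Aligned from the right: 1 vs 3 -> 3; 2 vs (implicit) 1 -> 2.
/// let out = broadcast_shape(&[2, 1], &[3]).unwrap();
/// assert_eq!(out.as_slice(), &[2, 3]);
///
/// // 3 vs 4: neither equal nor 1, so broadcasting fails.
/// assert!(broadcast_shape(&[2, 3], &[2, 4]).is_err());
/// ```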
pub fn broadcast_shape(shape1: &[usize], shape2: &[usize]) -> Result<Shape> {
    let max_ndim = shape1.len().max(shape2.len());
    let mut result = Shape::with_capacity(max_ndim);

    // Iterate from right to left
    for i in 0..max_ndim {
        let d1 = if i < shape1.len() {
            shape1[shape1.len() - 1 - i]
        } else {
            1
        };

        let d2 = if i < shape2.len() {
            shape2[shape2.len() - 1 - i]
        } else {
            1
        };

        if d1 == d2 {
            result.push(d1);
        } else if d1 == 1 {
            result.push(d2);
        } else if d2 == 1 {
            result.push(d1);
        } else {
            return Err(Error::BroadcastError {
                shape1: shape1.to_vec(),
                shape2: shape2.to_vec(),
            });
        }
    }

    result.reverse();
    Ok(result)
}

/// Computes broadcast strides for a shape to match a target shape.
///
/// # Arguments
/// * `shape` - Original shape
/// * `strides` - Original strides
/// * `target_shape` - Target broadcast shape (as produced by [`broadcast_shape`])
///
/// # Returns
/// New strides for broadcasting (0 stride for broadcast dimensions).
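///
/// # Examples
/// A brief sketch (module path assumed):
/// ```
/// use axonml_tensor::shape::broadcast_strides;
///
/// // A [3] vector broadcast to [2, 3]: the new leading axis gets stride 0,
/// // so every "row" re-reads the same three elements.
/// let s = broadcast_strides(&[3], &[1], &[2, 3]);
/// assert_eq!(s.as_slice(), &[0, 1]);
/// ```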
#[must_use]
pub fn broadcast_strides(shape: &[usize], strides: &[isize], target_shape: &[usize]) -> Strides {
    // The target must have at least as many dimensions as the original;
    // `broadcast_shape` guarantees this for its outputs.
    debug_assert!(shape.len() <= target_shape.len());

    let mut result = Strides::with_capacity(target_shape.len());
    let shape_offset = target_shape.len() - shape.len();

    for (i, &target_dim) in target_shape.iter().enumerate() {
        if i < shape_offset {
            // Dimension doesn't exist in original - broadcast
            result.push(0);
        } else {
            let orig_idx = i - shape_offset;
            let orig_dim = shape[orig_idx];

            if orig_dim == target_dim {
                result.push(strides[orig_idx]);
            } else if orig_dim == 1 {
                // Broadcast dimension
                result.push(0);
            } else {
                // Should not happen if broadcast_shape was computed correctly
                result.push(strides[orig_idx]);
            }
        }
    }

    result
}

/// Checks if two shapes are broadcastable.
#[must_use]
pub fn can_broadcast(shape1: &[usize], shape2: &[usize]) -> bool {
    broadcast_shape(shape1, shape2).is_ok()
}

// =============================================================================
// Shape Manipulation
// =============================================================================

/// Reshapes a tensor shape, validating that total elements match.
///
/// Supports -1 in one dimension to infer the size.
///
/// # Arguments
/// * `old_shape` - Current shape
/// * `new_shape` - Target shape (can contain -1)
///
/// # Returns
/// Resolved shape, or error if incompatible.
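///
/// # Examples
/// A short sketch of -1 inference (module path assumed):
/// ```
/// use axonml_tensor::shape::reshape;
///
/// // 2 * 3 * 4 = 24 elements; -1 is inferred as 24 / 6 = 4.
/// let new = reshape(&[2, 3, 4], &[-1, 6]).unwrap();
/// assert_eq!(new.as_slice(), &[4, 6]);
///
/// // 5 * 5 = 25 != 24, so this fails.
/// assert!(reshape(&[2, 3, 4], &[5, 5]).is_err());
/// ```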
pub fn reshape(old_shape: &[usize], new_shape: &[isize]) -> Result<Shape> {
    let old_numel = numel(old_shape);
    let mut result = Shape::with_capacity(new_shape.len());
    let mut infer_idx = None;
    let mut known_numel = 1usize;

    for (i, &dim) in new_shape.iter().enumerate() {
        if dim == -1 {
            if infer_idx.is_some() {
                return Err(Error::invalid_operation("Can only have one -1 in reshape"));
            }
            infer_idx = Some(i);
            result.push(0); // Placeholder
        } else if dim < 0 {
            return Err(Error::invalid_operation("Invalid dimension in reshape"));
        } else {
            let d = dim as usize;
            known_numel *= d;
            result.push(d);
        }
    }

    if let Some(idx) = infer_idx {
        // Guard against division by zero when a 0-sized dimension is combined
        // with -1; inference is ambiguous in that case.
        if known_numel == 0 || old_numel % known_numel != 0 {
            return Err(Error::invalid_operation(
                "Cannot infer dimension: not evenly divisible",
            ));
        }
        result[idx] = old_numel / known_numel;
    } else if known_numel != old_numel {
        return Err(Error::shape_mismatch(old_shape, &result));
    }

    Ok(result)
}

/// Computes the shape after squeezing (removing dimensions of size 1).
///
/// Squeezing a specific dimension whose size is not 1 is a no-op.
///
/// # Arguments
/// * `shape` - Input shape
/// * `dim` - Optional dimension to squeeze (None = all)
///
/// # Returns
/// Squeezed shape.
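///
/// # Examples
/// A small sketch (module path assumed):
/// ```
/// use axonml_tensor::shape::squeeze;
///
/// assert_eq!(squeeze(&[1, 2, 1], None).as_slice(), &[2]);
/// // Dimension 1 has size 2, so nothing is removed.
/// assert_eq!(squeeze(&[1, 2, 1], Some(1)).as_slice(), &[1, 2, 1]);
/// ```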
#[must_use]
pub fn squeeze(shape: &[usize], dim: Option<usize>) -> Shape {
    match dim {
        Some(d) => {
            let mut result = Shape::from_slice(shape);
            if d < shape.len() && shape[d] == 1 {
                result.remove(d);
            }
            result
        }
        None => shape.iter().copied().filter(|&d| d != 1).collect(),
    }
}

/// Computes the shape after unsqueezing (adding a dimension of size 1).
///
/// # Arguments
/// * `shape` - Input shape
/// * `dim` - Dimension at which to insert
///
/// # Returns
/// Unsqueezed shape, or error if dim is invalid.
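///
/// # Examples
/// A quick sketch (module path assumed):
/// ```
/// use axonml_tensor::shape::unsqueeze;
///
/// assert_eq!(unsqueeze(&[2, 3], 1).unwrap().as_slice(), &[2, 1, 3]);
/// assert!(unsqueeze(&[2, 3], 5).is_err()); // past the end
/// ```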
pub fn unsqueeze(shape: &[usize], dim: usize) -> Result<Shape> {
    if dim > shape.len() {
        return Err(Error::InvalidDimension {
            index: dim as i64,
            ndim: shape.len(),
        });
    }

    let mut result = Shape::with_capacity(shape.len() + 1);
    result.extend_from_slice(&shape[..dim]);
    result.push(1);
    result.extend_from_slice(&shape[dim..]);
    Ok(result)
}

/// Computes the shape after transposing two dimensions.
///
/// # Arguments
/// * `shape` - Input shape
/// * `dim0` - First dimension
/// * `dim1` - Second dimension
///
/// # Returns
/// Transposed shape, or error if either dimension is out of range.
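///
/// # Examples
/// A brief sketch (module path assumed):
/// ```
/// use axonml_tensor::shape::transpose_shape;
///
/// let t = transpose_shape(&[2, 3, 4], 0, 2).unwrap();
/// assert_eq!(t.as_slice(), &[4, 3, 2]);
/// ```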
pub fn transpose_shape(shape: &[usize], dim0: usize, dim1: usize) -> Result<Shape> {
    if dim0 >= shape.len() || dim1 >= shape.len() {
        return Err(Error::InvalidDimension {
            index: dim0.max(dim1) as i64,
            ndim: shape.len(),
        });
    }

    let mut result = Shape::from_slice(shape);
    result.swap(dim0, dim1);
    Ok(result)
}

/// Swaps two stride values, producing the strides for a transposed view.
///
/// Callers must ensure `dim0` and `dim1` are in bounds; this function
/// panics otherwise.
#[must_use]
pub fn transpose_strides(strides: &[isize], dim0: usize, dim1: usize) -> Strides {
    let mut result = Strides::from_slice(strides);
    result.swap(dim0, dim1);
    result
}

// =============================================================================
// Validation
// =============================================================================

/// Normalizes a dimension index, supporting negative indexing.
///
/// # Arguments
/// * `dim` - Dimension index (can be negative)
/// * `ndim` - Number of dimensions
///
/// # Returns
/// Normalized positive index, or error if out of bounds.
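///
/// # Examples
/// A short sketch of negative indexing (module path assumed):
/// ```
/// use axonml_tensor::shape::normalize_dim;
///
/// assert_eq!(normalize_dim(-1, 3).unwrap(), 2); // last dimension
/// assert!(normalize_dim(3, 3).is_err()); // out of range
/// ```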
pub fn normalize_dim(dim: i64, ndim: usize) -> Result<usize> {
    let ndim_i64 = ndim as i64;

    let normalized = if dim < 0 { dim + ndim_i64 } else { dim };

    if normalized < 0 || normalized >= ndim_i64 {
        return Err(Error::InvalidDimension { index: dim, ndim });
    }

    Ok(normalized as usize)
}

/// Validates that indices are within bounds for a shape.
pub fn validate_indices(indices: &[usize], shape: &[usize]) -> Result<()> {
    if indices.len() != shape.len() {
        return Err(Error::invalid_operation(format!(
            "Expected {} indices, got {}",
            shape.len(),
            indices.len()
        )));
    }

    for (&idx, &dim) in indices.iter().zip(shape.iter()) {
        if idx >= dim {
            return Err(Error::IndexOutOfBounds {
                index: idx,
                size: dim,
            });
        }
    }

    Ok(())
}

// =============================================================================
// Tests
// =============================================================================

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_numel() {
        assert_eq!(numel(&[2, 3, 4]), 24);
        assert_eq!(numel(&[]), 1);
        assert_eq!(numel(&[5]), 5);
    }

    #[test]
    fn test_contiguous_strides() {
        let shape = [2, 3, 4];
        let strides = contiguous_strides(&shape);
        assert_eq!(strides.as_slice(), &[12, 4, 1]);
    }

    #[test]
    fn test_is_contiguous() {
        let shape = [2, 3];
        let strides = contiguous_strides(&shape);
        assert!(is_contiguous(&shape, &strides));

        let non_contig_strides: Strides = smallvec::smallvec![1, 2];
        assert!(!is_contiguous(&shape, &non_contig_strides));
    }

    #[test]
    fn test_broadcast_shape() {
        // Same shapes
        assert_eq!(
            broadcast_shape(&[2, 3], &[2, 3]).unwrap().as_slice(),
            &[2, 3]
        );

        // Broadcasting
        assert_eq!(broadcast_shape(&[2, 3], &[3]).unwrap().as_slice(), &[2, 3]);

        assert_eq!(
            broadcast_shape(&[2, 1], &[1, 3]).unwrap().as_slice(),
            &[2, 3]
        );

        assert_eq!(
            broadcast_shape(&[5, 1, 3], &[2, 3]).unwrap().as_slice(),
            &[5, 2, 3]
        );

        // Incompatible
        assert!(broadcast_shape(&[2, 3], &[2, 4]).is_err());
    }

    #[test]
    fn test_reshape() {
        let old_shape = [2, 3, 4];

        // Simple reshape
        let new = reshape(&old_shape, &[6, 4]).unwrap();
        assert_eq!(new.as_slice(), &[6, 4]);

        // With -1 inference
        let new = reshape(&old_shape, &[-1, 4]).unwrap();
        assert_eq!(new.as_slice(), &[6, 4]);

        // Invalid
        assert!(reshape(&old_shape, &[5, 5]).is_err());
    }

    #[test]
    fn test_squeeze() {
        let shape = [1, 2, 1, 3, 1];

        // Squeeze all
        let squeezed = squeeze(&shape, None);
        assert_eq!(squeezed.as_slice(), &[2, 3]);

        // Squeeze specific dimension
        let squeezed = squeeze(&shape, Some(0));
        assert_eq!(squeezed.as_slice(), &[2, 1, 3, 1]);
    }

    #[test]
    fn test_unsqueeze() {
        let shape = [2, 3];

        let unsqueezed = unsqueeze(&shape, 0).unwrap();
        assert_eq!(unsqueezed.as_slice(), &[1, 2, 3]);

        let unsqueezed = unsqueeze(&shape, 1).unwrap();
        assert_eq!(unsqueezed.as_slice(), &[2, 1, 3]);

        let unsqueezed = unsqueeze(&shape, 2).unwrap();
        assert_eq!(unsqueezed.as_slice(), &[2, 3, 1]);
    }

    #[test]
    fn test_normalize_dim() {
        assert_eq!(normalize_dim(0, 3).unwrap(), 0);
        assert_eq!(normalize_dim(-1, 3).unwrap(), 2);
        assert_eq!(normalize_dim(-3, 3).unwrap(), 0);

        assert!(normalize_dim(3, 3).is_err());
        assert!(normalize_dim(-4, 3).is_err());
    }

    #[test]
    fn test_linear_index() {
        // 2x3 matrix, row-major
        let strides: Strides = smallvec::smallvec![3, 1];

        assert_eq!(linear_index(&[0, 0], &strides), 0);
        assert_eq!(linear_index(&[0, 1], &strides), 1);
        assert_eq!(linear_index(&[1, 0], &strides), 3);
        assert_eq!(linear_index(&[1, 2], &strides), 5);
    }

    #[test]
    fn test_unravel_index() {
        let shape = [2, 3, 4];

        assert_eq!(unravel_index(0, &shape), vec![0, 0, 0]);
        assert_eq!(unravel_index(1, &shape), vec![0, 0, 1]);
        assert_eq!(unravel_index(4, &shape), vec![0, 1, 0]);
        assert_eq!(unravel_index(12, &shape), vec![1, 0, 0]);
    }
}
525}