1use smallvec::SmallVec;
18
19use axonml_core::error::{Error, Result};
20
/// Dimension sizes of a tensor, one entry per axis.
///
/// Backed by a `SmallVec` with inline capacity for 6 entries, so shapes of
/// rank 6 or less need no separate heap allocation.
pub type Shape = SmallVec<[usize; 6]>;

/// Per-axis strides of a tensor, measured in elements (the innermost
/// contiguous axis has stride 1).
///
/// Signed (`isize`). NOTE(review): presumably this is to allow negative strides
/// for reversed views; no such construction appears in this module. Uses the
/// same inline capacity of 6 as [`Shape`].
pub type Strides = SmallVec<[isize; 6]>;
31
/// Returns the total number of elements described by `shape`.
///
/// This is the product of all dimension sizes. An empty shape (a scalar)
/// holds exactly one element, and any zero-size dimension makes the count 0.
#[must_use]
pub fn numel(shape: &[usize]) -> usize {
    let mut count = 1usize;
    for &extent in shape {
        count *= extent;
    }
    count
}
47
48#[must_use]
56pub fn contiguous_strides(shape: &[usize]) -> Strides {
57 if shape.is_empty() {
58 return Strides::new();
59 }
60
61 let mut strides = Strides::with_capacity(shape.len());
62 let mut stride = 1isize;
63
64 for &dim in shape.iter().rev() {
66 strides.push(stride);
67 stride *= dim as isize;
68 }
69
70 strides.reverse();
71 strides
72}
73
/// Returns `true` if a tensor with `shape` and `strides` is laid out densely in
/// row-major order, i.e. the element at row-major position `k` sits at offset
/// `k` relative to the tensor's start.
///
/// - A scalar (empty `shape`) is always contiguous.
/// - A `strides` slice whose length differs from a non-empty `shape` is
///   reported as non-contiguous.
/// - A tensor with any zero-size dimension has no elements and is trivially
///   contiguous.
/// - Dimensions of size 1 are ignored. Their only valid index is 0, so their
///   stride never contributes to an offset. Such strides are often
///   non-canonical after `transpose`/`unsqueeze`-style view operations.
#[must_use]
pub fn is_contiguous(shape: &[usize], strides: &[isize]) -> bool {
    if shape.is_empty() {
        return true;
    }
    if shape.len() != strides.len() {
        return false;
    }
    if shape.contains(&0) {
        return true;
    }

    // Check each non-trivial axis against the canonical row-major stride,
    // working from the innermost axis outwards.
    let mut expected = 1isize;
    for (&dim, &stride) in shape.iter().zip(strides).rev() {
        if dim != 1 && stride != expected {
            return false;
        }
        expected *= dim as isize;
    }
    true
}
91
/// Computes the element offset addressed by `indices` under `strides`.
///
/// The offset is the dot product of `indices` and `strides`, accumulated as
/// `isize` and then cast to `usize`. A negative total wraps through that `as`
/// cast, so callers are expected to supply strides that produce a non-negative
/// offset.
///
/// Debug builds assert that `indices` and `strides` have the same length.
#[must_use]
pub fn linear_index(indices: &[usize], strides: &[isize]) -> usize {
    debug_assert_eq!(indices.len(), strides.len());

    let total: isize = indices
        .iter()
        .zip(strides)
        .map(|(&position, &step)| position as isize * step)
        .sum();
    total as usize
}
110
/// Converts a row-major linear position into per-axis indices for `shape`.
///
/// This is the inverse of computing a row-major offset. If `linear` is not
/// below the element count, the overflow past the outermost axis is silently
/// discarded, so the result equals that of `linear % numel(shape)`.
///
/// # Panics
///
/// Panics (division by zero) if any dimension of `shape` is 0.
#[must_use]
pub fn unravel_index(linear: usize, shape: &[usize]) -> Vec<usize> {
    let mut remaining = linear;
    let mut coords = vec![0; shape.len()];

    // Peel off the innermost axis first: its index is the remainder.
    for (coord, &extent) in coords.iter_mut().zip(shape).rev() {
        *coord = remaining % extent;
        remaining /= extent;
    }
    coords
}
130
131pub fn broadcast_shape(shape1: &[usize], shape2: &[usize]) -> Result<Shape> {
149 let max_ndim = shape1.len().max(shape2.len());
150 let mut result = Shape::with_capacity(max_ndim);
151
152 for i in 0..max_ndim {
154 let d1 = if i < shape1.len() {
155 shape1[shape1.len() - 1 - i]
156 } else {
157 1
158 };
159
160 let d2 = if i < shape2.len() {
161 shape2[shape2.len() - 1 - i]
162 } else {
163 1
164 };
165
166 if d1 == d2 {
167 result.push(d1);
168 } else if d1 == 1 {
169 result.push(d2);
170 } else if d2 == 1 {
171 result.push(d1);
172 } else {
173 return Err(Error::BroadcastError {
174 shape1: shape1.to_vec(),
175 shape2: shape2.to_vec(),
176 });
177 }
178 }
179
180 result.reverse();
181 Ok(result)
182}
183
184#[must_use]
194pub fn broadcast_strides(shape: &[usize], strides: &[isize], target_shape: &[usize]) -> Strides {
195 let mut result = Strides::with_capacity(target_shape.len());
196 let shape_offset = target_shape.len() - shape.len();
197
198 for (i, &target_dim) in target_shape.iter().enumerate() {
199 if i < shape_offset {
200 result.push(0);
202 } else {
203 let orig_idx = i - shape_offset;
204 let orig_dim = shape[orig_idx];
205
206 if orig_dim == target_dim {
207 result.push(strides[orig_idx]);
208 } else if orig_dim == 1 {
209 result.push(0);
211 } else {
212 result.push(strides[orig_idx]);
214 }
215 }
216 }
217
218 result
219}
220
221#[must_use]
223pub fn can_broadcast(shape1: &[usize], shape2: &[usize]) -> bool {
224 broadcast_shape(shape1, shape2).is_ok()
225}
226
227pub fn reshape(old_shape: &[usize], new_shape: &[isize]) -> Result<Shape> {
242 let old_numel = numel(old_shape);
243 let mut result = Shape::with_capacity(new_shape.len());
244 let mut infer_idx = None;
245 let mut known_numel = 1usize;
246
247 for (i, &dim) in new_shape.iter().enumerate() {
248 if dim == -1 {
249 if infer_idx.is_some() {
250 return Err(Error::invalid_operation("Can only have one -1 in reshape"));
251 }
252 infer_idx = Some(i);
253 result.push(0); } else if dim < 0 {
255 return Err(Error::invalid_operation("Invalid dimension in reshape"));
256 } else {
257 let d = dim as usize;
258 known_numel *= d;
259 result.push(d);
260 }
261 }
262
263 if let Some(idx) = infer_idx {
264 if old_numel % known_numel != 0 {
265 return Err(Error::invalid_operation(
266 "Cannot infer dimension: not evenly divisible",
267 ));
268 }
269 result[idx] = old_numel / known_numel;
270 } else if known_numel != old_numel {
271 return Err(Error::shape_mismatch(old_shape, &result));
272 }
273
274 Ok(result)
275}
276
277#[must_use]
286pub fn squeeze(shape: &[usize], dim: Option<usize>) -> Shape {
287 match dim {
288 Some(d) => {
289 let mut result = Shape::from_slice(shape);
290 if d < shape.len() && shape[d] == 1 {
291 result.remove(d);
292 }
293 result
294 }
295 None => shape.iter().copied().filter(|&d| d != 1).collect(),
296 }
297}
298
299pub fn unsqueeze(shape: &[usize], dim: usize) -> Result<Shape> {
308 if dim > shape.len() {
309 return Err(Error::InvalidDimension {
310 index: dim as i64,
311 ndim: shape.len(),
312 });
313 }
314
315 let mut result = Shape::with_capacity(shape.len() + 1);
316 result.extend_from_slice(&shape[..dim]);
317 result.push(1);
318 result.extend_from_slice(&shape[dim..]);
319 Ok(result)
320}
321
322pub fn transpose_shape(shape: &[usize], dim0: usize, dim1: usize) -> Result<Shape> {
332 if dim0 >= shape.len() || dim1 >= shape.len() {
333 return Err(Error::InvalidDimension {
334 index: dim0.max(dim1) as i64,
335 ndim: shape.len(),
336 });
337 }
338
339 let mut result = Shape::from_slice(shape);
340 result.swap(dim0, dim1);
341 Ok(result)
342}
343
344#[must_use]
346pub fn transpose_strides(strides: &[isize], dim0: usize, dim1: usize) -> Strides {
347 let mut result = Strides::from_slice(strides);
348 result.swap(dim0, dim1);
349 result
350}
351
352pub fn normalize_dim(dim: i64, ndim: usize) -> Result<usize> {
365 let ndim_i64 = ndim as i64;
366
367 let normalized = if dim < 0 { dim + ndim_i64 } else { dim };
368
369 if normalized < 0 || normalized >= ndim_i64 {
370 return Err(Error::InvalidDimension { index: dim, ndim });
371 }
372
373 Ok(normalized as usize)
374}
375
376pub fn validate_indices(indices: &[usize], shape: &[usize]) -> Result<()> {
378 if indices.len() != shape.len() {
379 return Err(Error::invalid_operation(format!(
380 "Expected {} indices, got {}",
381 shape.len(),
382 indices.len()
383 )));
384 }
385
386 for (&idx, &dim) in indices.iter().zip(shape.iter()) {
387 if idx >= dim {
388 return Err(Error::IndexOutOfBounds {
389 index: idx,
390 size: dim,
391 });
392 }
393 }
394
395 Ok(())
396}
397
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_numel() {
        assert_eq!(numel(&[2, 3, 4]), 24);
        assert_eq!(numel(&[]), 1);
        assert_eq!(numel(&[5]), 5);
        assert_eq!(numel(&[3, 0, 2]), 0);
    }

    #[test]
    fn test_contiguous_strides() {
        let shape = [2, 3, 4];
        let strides = contiguous_strides(&shape);
        assert_eq!(strides.as_slice(), &[12, 4, 1]);
        assert!(contiguous_strides(&[]).is_empty());
    }

    #[test]
    fn test_is_contiguous() {
        let shape = [2, 3];
        let strides = contiguous_strides(&shape);
        assert!(is_contiguous(&shape, &strides));

        let non_contig_strides: Strides = smallvec::smallvec![1, 2];
        assert!(!is_contiguous(&shape, &non_contig_strides));

        // A scalar is always contiguous; a stride-count mismatch never is.
        assert!(is_contiguous(&[], &[]));
        assert!(!is_contiguous(&shape, &[1]));
    }

    #[test]
    fn test_broadcast_shape() {
        // Identical shapes.
        assert_eq!(
            broadcast_shape(&[2, 3], &[2, 3]).unwrap().as_slice(),
            &[2, 3]
        );

        // Missing leading dimension.
        assert_eq!(broadcast_shape(&[2, 3], &[3]).unwrap().as_slice(), &[2, 3]);

        // Size-1 dimensions on both sides.
        assert_eq!(
            broadcast_shape(&[2, 1], &[1, 3]).unwrap().as_slice(),
            &[2, 3]
        );

        assert_eq!(
            broadcast_shape(&[5, 1, 3], &[2, 3]).unwrap().as_slice(),
            &[5, 2, 3]
        );

        // Scalar against a vector.
        assert_eq!(broadcast_shape(&[], &[4]).unwrap().as_slice(), &[4]);

        // Incompatible sizes.
        assert!(broadcast_shape(&[2, 3], &[2, 4]).is_err());
    }

    #[test]
    fn test_broadcast_strides() {
        // Missing leading dims get stride 0.
        let strides = broadcast_strides(&[3], &[1], &[2, 3]);
        assert_eq!(strides.as_slice(), &[0, 1]);

        // Size-1 dims expanded to a larger size get stride 0.
        let strides = broadcast_strides(&[2, 1], &[1, 1], &[2, 4]);
        assert_eq!(strides.as_slice(), &[1, 0]);
    }

    #[test]
    fn test_can_broadcast() {
        assert!(can_broadcast(&[4, 1], &[3]));
        assert!(!can_broadcast(&[2, 3], &[4]));
    }

    #[test]
    fn test_reshape() {
        let old_shape = [2, 3, 4];

        let new = reshape(&old_shape, &[6, 4]).unwrap();
        assert_eq!(new.as_slice(), &[6, 4]);

        let new = reshape(&old_shape, &[-1, 4]).unwrap();
        assert_eq!(new.as_slice(), &[6, 4]);

        let new = reshape(&old_shape, &[-1]).unwrap();
        assert_eq!(new.as_slice(), &[24]);

        // Element-count mismatch.
        assert!(reshape(&old_shape, &[5, 5]).is_err());
        // More than one inferred dimension.
        assert!(reshape(&old_shape, &[-1, -1]).is_err());
        // Negative sizes other than -1.
        assert!(reshape(&old_shape, &[-2, 12]).is_err());
        // Not evenly divisible.
        assert!(reshape(&old_shape, &[-1, 5]).is_err());
    }

    #[test]
    fn test_squeeze() {
        let shape = [1, 2, 1, 3, 1];

        let squeezed = squeeze(&shape, None);
        assert_eq!(squeezed.as_slice(), &[2, 3]);

        let squeezed = squeeze(&shape, Some(0));
        assert_eq!(squeezed.as_slice(), &[2, 1, 3, 1]);

        // Non-unit or out-of-range axes leave the shape unchanged.
        assert_eq!(squeeze(&shape, Some(1)).as_slice(), &shape);
        assert_eq!(squeeze(&shape, Some(10)).as_slice(), &shape);
    }

    #[test]
    fn test_unsqueeze() {
        let shape = [2, 3];

        let unsqueezed = unsqueeze(&shape, 0).unwrap();
        assert_eq!(unsqueezed.as_slice(), &[1, 2, 3]);

        let unsqueezed = unsqueeze(&shape, 1).unwrap();
        assert_eq!(unsqueezed.as_slice(), &[2, 1, 3]);

        let unsqueezed = unsqueeze(&shape, 2).unwrap();
        assert_eq!(unsqueezed.as_slice(), &[2, 3, 1]);

        assert!(unsqueeze(&shape, 3).is_err());
    }

    #[test]
    fn test_transpose() {
        let shape = transpose_shape(&[2, 3, 4], 0, 2).unwrap();
        assert_eq!(shape.as_slice(), &[4, 3, 2]);
        assert!(transpose_shape(&[2, 3], 0, 2).is_err());

        let strides = transpose_strides(&[12, 4, 1], 0, 2);
        assert_eq!(strides.as_slice(), &[1, 4, 12]);
    }

    #[test]
    fn test_normalize_dim() {
        assert_eq!(normalize_dim(0, 3).unwrap(), 0);
        assert_eq!(normalize_dim(-1, 3).unwrap(), 2);
        assert_eq!(normalize_dim(-3, 3).unwrap(), 0);

        assert!(normalize_dim(3, 3).is_err());
        assert!(normalize_dim(-4, 3).is_err());
    }

    #[test]
    fn test_validate_indices() {
        assert!(validate_indices(&[1, 2], &[2, 3]).is_ok());
        // Index equal to the dimension size is out of bounds.
        assert!(validate_indices(&[2, 0], &[2, 3]).is_err());
        // Wrong number of indices.
        assert!(validate_indices(&[0], &[2, 3]).is_err());
    }

    #[test]
    fn test_linear_index() {
        let strides: Strides = smallvec::smallvec![3, 1];

        assert_eq!(linear_index(&[0, 0], &strides), 0);
        assert_eq!(linear_index(&[0, 1], &strides), 1);
        assert_eq!(linear_index(&[1, 0], &strides), 3);
        assert_eq!(linear_index(&[1, 2], &strides), 5);
    }

    #[test]
    fn test_unravel_index() {
        let shape = [2, 3, 4];

        assert_eq!(unravel_index(0, &shape), vec![0, 0, 0]);
        assert_eq!(unravel_index(1, &shape), vec![0, 0, 1]);
        assert_eq!(unravel_index(4, &shape), vec![0, 1, 0]);
        assert_eq!(unravel_index(12, &shape), vec![1, 0, 0]);
        assert_eq!(unravel_index(23, &shape), vec![1, 2, 3]);
    }
}