1use smallvec::SmallVec;
17
18use axonml_core::error::{Error, Result};
19
20pub type Shape = SmallVec<[usize; 6]>;
27
28pub type Strides = SmallVec<[isize; 6]>;
30
/// Total number of elements in a tensor with the given `shape`.
///
/// The empty shape (a scalar) holds exactly one element. Any zero-sized axis
/// makes the count zero.
#[must_use]
pub fn numel(shape: &[usize]) -> usize {
    let mut count = 1;
    for &dim in shape {
        count *= dim;
    }
    count
}
46
47#[must_use]
55pub fn contiguous_strides(shape: &[usize]) -> Strides {
56 if shape.is_empty() {
57 return Strides::new();
58 }
59
60 let mut strides = Strides::with_capacity(shape.len());
61 let mut stride = 1isize;
62
63 for &dim in shape.iter().rev() {
65 strides.push(stride);
66 stride *= dim as isize;
67 }
68
69 strides.reverse();
70 strides
71}
72
73#[must_use]
82pub fn is_contiguous(shape: &[usize], strides: &[isize]) -> bool {
83 if shape.is_empty() {
84 return true;
85 }
86
87 let expected = contiguous_strides(shape);
88 strides == expected.as_slice()
89}
90
/// Element offset of the multi-index `indices` under `strides`.
///
/// The offset is the sum of `indices[k] * strides[k]` over all axes. It is
/// accumulated as `isize` and then cast to `usize`, so a negative total (possible
/// with negative strides) wraps rather than erroring. The two slices must have
/// equal length; this is only checked in debug builds.
#[must_use]
pub fn linear_index(indices: &[usize], strides: &[isize]) -> usize {
    debug_assert_eq!(indices.len(), strides.len());

    let offset: isize = indices
        .iter()
        .zip(strides)
        .map(|(&idx, &step)| idx as isize * step)
        .sum();
    offset as usize
}
109
/// Converts a flat row-major offset into a multi-index for `shape`.
///
/// For offsets below `numel(shape)` this inverts [`linear_index`] with
/// contiguous strides. Larger offsets are not rejected: the excess is silently
/// dropped at the outermost axis.
///
/// # Panics
///
/// Panics (division by zero) if `shape` contains a zero-sized axis.
#[must_use]
pub fn unravel_index(linear: usize, shape: &[usize]) -> Vec<usize> {
    let mut remaining = linear;
    // Peel coordinates off from the innermost axis, then restore axis order.
    let mut coords: Vec<usize> = shape
        .iter()
        .rev()
        .map(|&dim| {
            let coord = remaining % dim;
            remaining /= dim;
            coord
        })
        .collect();
    coords.reverse();
    coords
}
129
130pub fn broadcast_shape(shape1: &[usize], shape2: &[usize]) -> Result<Shape> {
148 let max_ndim = shape1.len().max(shape2.len());
149 let mut result = Shape::with_capacity(max_ndim);
150
151 for i in 0..max_ndim {
153 let d1 = if i < shape1.len() {
154 shape1[shape1.len() - 1 - i]
155 } else {
156 1
157 };
158
159 let d2 = if i < shape2.len() {
160 shape2[shape2.len() - 1 - i]
161 } else {
162 1
163 };
164
165 if d1 == d2 {
166 result.push(d1);
167 } else if d1 == 1 {
168 result.push(d2);
169 } else if d2 == 1 {
170 result.push(d1);
171 } else {
172 return Err(Error::BroadcastError {
173 shape1: shape1.to_vec(),
174 shape2: shape2.to_vec(),
175 });
176 }
177 }
178
179 result.reverse();
180 Ok(result)
181}
182
183#[must_use]
193pub fn broadcast_strides(shape: &[usize], strides: &[isize], target_shape: &[usize]) -> Strides {
194 let mut result = Strides::with_capacity(target_shape.len());
195 let shape_offset = target_shape.len() - shape.len();
196
197 for (i, &target_dim) in target_shape.iter().enumerate() {
198 if i < shape_offset {
199 result.push(0);
201 } else {
202 let orig_idx = i - shape_offset;
203 let orig_dim = shape[orig_idx];
204
205 if orig_dim == target_dim {
206 result.push(strides[orig_idx]);
207 } else if orig_dim == 1 {
208 result.push(0);
210 } else {
211 result.push(strides[orig_idx]);
213 }
214 }
215 }
216
217 result
218}
219
/// Returns `true` when `shape1` and `shape2` can be broadcast together.
///
/// Applies the same rule as [`broadcast_shape`] without building the result.
/// Right-aligned axis pairs must be equal, or one side must be 1. Leading axes
/// present in only one shape pair with an implicit 1 and always pass.
#[must_use]
pub fn can_broadcast(shape1: &[usize], shape2: &[usize]) -> bool {
    shape1
        .iter()
        .rev()
        .zip(shape2.iter().rev())
        .all(|(&a, &b)| a == b || a == 1 || b == 1)
}
225
226pub fn reshape(old_shape: &[usize], new_shape: &[isize]) -> Result<Shape> {
241 let old_numel = numel(old_shape);
242 let mut result = Shape::with_capacity(new_shape.len());
243 let mut infer_idx = None;
244 let mut known_numel = 1usize;
245
246 for (i, &dim) in new_shape.iter().enumerate() {
247 if dim == -1 {
248 if infer_idx.is_some() {
249 return Err(Error::invalid_operation("Can only have one -1 in reshape"));
250 }
251 infer_idx = Some(i);
252 result.push(0); } else if dim < 0 {
254 return Err(Error::invalid_operation("Invalid dimension in reshape"));
255 } else {
256 let d = dim as usize;
257 known_numel *= d;
258 result.push(d);
259 }
260 }
261
262 if let Some(idx) = infer_idx {
263 if old_numel % known_numel != 0 {
264 return Err(Error::invalid_operation(
265 "Cannot infer dimension: not evenly divisible",
266 ));
267 }
268 result[idx] = old_numel / known_numel;
269 } else if known_numel != old_numel {
270 return Err(Error::shape_mismatch(old_shape, &result));
271 }
272
273 Ok(result)
274}
275
276#[must_use]
285pub fn squeeze(shape: &[usize], dim: Option<usize>) -> Shape {
286 match dim {
287 Some(d) => {
288 let mut result = Shape::from_slice(shape);
289 if d < shape.len() && shape[d] == 1 {
290 result.remove(d);
291 }
292 result
293 }
294 None => shape.iter().copied().filter(|&d| d != 1).collect(),
295 }
296}
297
298pub fn unsqueeze(shape: &[usize], dim: usize) -> Result<Shape> {
307 if dim > shape.len() {
308 return Err(Error::InvalidDimension {
309 index: dim as i64,
310 ndim: shape.len(),
311 });
312 }
313
314 let mut result = Shape::with_capacity(shape.len() + 1);
315 result.extend_from_slice(&shape[..dim]);
316 result.push(1);
317 result.extend_from_slice(&shape[dim..]);
318 Ok(result)
319}
320
321pub fn transpose_shape(shape: &[usize], dim0: usize, dim1: usize) -> Result<Shape> {
331 if dim0 >= shape.len() || dim1 >= shape.len() {
332 return Err(Error::InvalidDimension {
333 index: dim0.max(dim1) as i64,
334 ndim: shape.len(),
335 });
336 }
337
338 let mut result = Shape::from_slice(shape);
339 result.swap(dim0, dim1);
340 Ok(result)
341}
342
343#[must_use]
345pub fn transpose_strides(strides: &[isize], dim0: usize, dim1: usize) -> Strides {
346 let mut result = Strides::from_slice(strides);
347 result.swap(dim0, dim1);
348 result
349}
350
351pub fn normalize_dim(dim: i64, ndim: usize) -> Result<usize> {
364 let ndim_i64 = ndim as i64;
365
366 let normalized = if dim < 0 { dim + ndim_i64 } else { dim };
367
368 if normalized < 0 || normalized >= ndim_i64 {
369 return Err(Error::InvalidDimension { index: dim, ndim });
370 }
371
372 Ok(normalized as usize)
373}
374
375pub fn validate_indices(indices: &[usize], shape: &[usize]) -> Result<()> {
377 if indices.len() != shape.len() {
378 return Err(Error::invalid_operation(format!(
379 "Expected {} indices, got {}",
380 shape.len(),
381 indices.len()
382 )));
383 }
384
385 for (&idx, &dim) in indices.iter().zip(shape.iter()) {
386 if idx >= dim {
387 return Err(Error::IndexOutOfBounds {
388 index: idx,
389 size: dim,
390 });
391 }
392 }
393
394 Ok(())
395}
396
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_numel() {
        let cases: &[(&[usize], usize)] = &[(&[2, 3, 4], 24), (&[], 1), (&[5], 5)];
        for &(shape, expected) in cases {
            assert_eq!(numel(shape), expected);
        }
    }

    #[test]
    fn test_contiguous_strides() {
        assert_eq!(contiguous_strides(&[2, 3, 4]).as_slice(), &[12, 4, 1]);
    }

    #[test]
    fn test_is_contiguous() {
        let shape = [2, 3];
        assert!(is_contiguous(&shape, &contiguous_strides(&shape)));

        let transposed: Strides = smallvec::smallvec![1, 2];
        assert!(!is_contiguous(&shape, &transposed));
    }

    #[test]
    fn test_broadcast_shape() {
        let compatible: &[(&[usize], &[usize], &[usize])] = &[
            (&[2, 3], &[2, 3], &[2, 3]),
            (&[2, 3], &[3], &[2, 3]),
            (&[2, 1], &[1, 3], &[2, 3]),
            (&[5, 1, 3], &[2, 3], &[5, 2, 3]),
        ];
        for &(lhs, rhs, expected) in compatible {
            assert_eq!(broadcast_shape(lhs, rhs).unwrap().as_slice(), expected);
        }

        assert!(broadcast_shape(&[2, 3], &[2, 4]).is_err());
    }

    #[test]
    fn test_reshape() {
        let old_shape = [2, 3, 4];

        assert_eq!(reshape(&old_shape, &[6, 4]).unwrap().as_slice(), &[6, 4]);
        assert_eq!(reshape(&old_shape, &[-1, 4]).unwrap().as_slice(), &[6, 4]);
        assert!(reshape(&old_shape, &[5, 5]).is_err());
    }

    #[test]
    fn test_squeeze() {
        let shape = [1, 2, 1, 3, 1];

        assert_eq!(squeeze(&shape, None).as_slice(), &[2, 3]);
        assert_eq!(squeeze(&shape, Some(0)).as_slice(), &[2, 1, 3, 1]);
    }

    #[test]
    fn test_unsqueeze() {
        let shape = [2, 3];
        let expected: [&[usize]; 3] = [&[1, 2, 3], &[2, 1, 3], &[2, 3, 1]];

        for (dim, want) in expected.iter().enumerate() {
            assert_eq!(unsqueeze(&shape, dim).unwrap().as_slice(), *want);
        }
    }

    #[test]
    fn test_normalize_dim() {
        for &(dim, expected) in &[(0i64, 0usize), (-1, 2), (-3, 0)] {
            assert_eq!(normalize_dim(dim, 3).unwrap(), expected);
        }
        for &dim in &[3i64, -4] {
            assert!(normalize_dim(dim, 3).is_err());
        }
    }

    #[test]
    fn test_linear_index() {
        let strides: Strides = smallvec::smallvec![3, 1];
        let cases: [([usize; 2], usize); 4] = [([0, 0], 0), ([0, 1], 1), ([1, 0], 3), ([1, 2], 5)];

        for (indices, want) in cases.iter() {
            assert_eq!(linear_index(indices, &strides), *want);
        }
    }

    #[test]
    fn test_unravel_index() {
        let shape = [2, 3, 4];
        let cases: [(usize, [usize; 3]); 4] = [
            (0, [0, 0, 0]),
            (1, [0, 0, 1]),
            (4, [0, 1, 0]),
            (12, [1, 0, 0]),
        ];

        for &(linear, expected) in cases.iter() {
            assert_eq!(unravel_index(linear, &shape), expected.to_vec());
        }
    }
}