1use smallvec::SmallVec;
26
27use axonml_core::error::{Error, Result};
28
/// Dimension sizes of a tensor, one entry per axis.
///
/// Backed by a `SmallVec` with inline capacity 6, so shapes of up to six dimensions are
/// stored without a heap allocation.
pub type Shape = SmallVec<[usize; 6]>;

/// Per-axis element strides of a tensor, one entry per axis.
///
/// Signed (`isize`); presumably so that negative strides (e.g. reversed views) can be
/// represented — NOTE(review): confirm against the view code. A stride of 0 is used by
/// broadcasting to repeat one element along an axis. Inline capacity is 6, matching [`Shape`].
pub type Strides = SmallVec<[isize; 6]>;
39
/// Returns the total number of elements described by `shape`.
///
/// This is the product of every dimension size. An empty shape (a scalar) yields 1, and any
/// zero-sized dimension yields 0.
#[must_use]
pub fn numel(shape: &[usize]) -> usize {
    shape.iter().fold(1, |count, &dim| count * dim)
}
55
56#[must_use]
64pub fn contiguous_strides(shape: &[usize]) -> Strides {
65 if shape.is_empty() {
66 return Strides::new();
67 }
68
69 let mut strides = Strides::with_capacity(shape.len());
70 let mut stride = 1isize;
71
72 for &dim in shape.iter().rev() {
74 strides.push(stride);
75 stride *= dim as isize;
76 }
77
78 strides.reverse();
79 strides
80}
81
/// Returns `true` if a tensor with `shape` and `strides` is densely packed in row-major
/// (C) order.
///
/// Each axis's stride is compared with the product of the sizes of all later axis. Axes of
/// size 1 are skipped: the only valid index along them is 0, so their stride never
/// contributes to an element offset. A shape containing a zero-sized axis has no elements
/// and is treated as contiguous. An empty shape (a scalar) is always contiguous.
///
/// Returns `false` if a non-empty `shape` and `strides` have different lengths.
#[must_use]
pub fn is_contiguous(shape: &[usize], strides: &[isize]) -> bool {
    if shape.is_empty() {
        return true;
    }
    if shape.len() != strides.len() {
        return false;
    }
    // No elements means no offsets to get wrong.
    if shape.contains(&0) {
        return true;
    }

    let mut expected = 1isize;
    for (&dim, &stride) in shape.iter().zip(strides).rev() {
        if dim != 1 && stride != expected {
            return false;
        }
        expected *= dim as isize;
    }
    true
}
99
/// Converts a multi-dimensional index into an element offset using `strides`.
///
/// The offset is the sum of `index * stride` over every axis. It is accumulated as `isize`
/// and then cast to `usize`. `indices` and `strides` must have the same length; this is
/// checked only in debug builds. In release builds, surplus entries in the longer slice are
/// ignored.
#[must_use]
pub fn linear_index(indices: &[usize], strides: &[isize]) -> usize {
    debug_assert_eq!(indices.len(), strides.len());

    let offset: isize = indices
        .iter()
        .zip(strides)
        .map(|(&idx, &stride)| idx as isize * stride)
        .sum();
    offset as usize
}
118
/// Converts a row-major linear offset into per-axis indices for `shape`.
///
/// This is the inverse of indexing with `contiguous_strides(shape)`. The last axis varies
/// fastest. The index along the first axis is taken modulo its size, so offsets past the end
/// wrap around rather than erroring. An empty shape yields an empty index vector.
///
/// # Panics
///
/// Panics (division by zero) if any dimension in `shape` is 0.
#[must_use]
pub fn unravel_index(linear: usize, shape: &[usize]) -> Vec<usize> {
    let mut coords = vec![0; shape.len()];
    let mut remaining = linear;

    for (coord, &dim) in coords.iter_mut().zip(shape).rev() {
        *coord = remaining % dim;
        remaining /= dim;
    }

    coords
}
138
139pub fn broadcast_shape(shape1: &[usize], shape2: &[usize]) -> Result<Shape> {
157 let max_ndim = shape1.len().max(shape2.len());
158 let mut result = Shape::with_capacity(max_ndim);
159
160 for i in 0..max_ndim {
162 let d1 = if i < shape1.len() {
163 shape1[shape1.len() - 1 - i]
164 } else {
165 1
166 };
167
168 let d2 = if i < shape2.len() {
169 shape2[shape2.len() - 1 - i]
170 } else {
171 1
172 };
173
174 if d1 == d2 {
175 result.push(d1);
176 } else if d1 == 1 {
177 result.push(d2);
178 } else if d2 == 1 {
179 result.push(d1);
180 } else {
181 return Err(Error::BroadcastError {
182 shape1: shape1.to_vec(),
183 shape2: shape2.to_vec(),
184 });
185 }
186 }
187
188 result.reverse();
189 Ok(result)
190}
191
192#[must_use]
202pub fn broadcast_strides(shape: &[usize], strides: &[isize], target_shape: &[usize]) -> Strides {
203 let mut result = Strides::with_capacity(target_shape.len());
204 let shape_offset = target_shape.len() - shape.len();
205
206 for (i, &target_dim) in target_shape.iter().enumerate() {
207 if i < shape_offset {
208 result.push(0);
210 } else {
211 let orig_idx = i - shape_offset;
212 let orig_dim = shape[orig_idx];
213
214 if orig_dim == target_dim {
215 result.push(strides[orig_idx]);
216 } else if orig_dim == 1 {
217 result.push(0);
219 } else {
220 result.push(strides[orig_idx]);
222 }
223 }
224 }
225
226 result
227}
228
229#[must_use]
231pub fn can_broadcast(shape1: &[usize], shape2: &[usize]) -> bool {
232 broadcast_shape(shape1, shape2).is_ok()
233}
234
235pub fn reshape(old_shape: &[usize], new_shape: &[isize]) -> Result<Shape> {
250 let old_numel = numel(old_shape);
251 let mut result = Shape::with_capacity(new_shape.len());
252 let mut infer_idx = None;
253 let mut known_numel = 1usize;
254
255 for (i, &dim) in new_shape.iter().enumerate() {
256 if dim == -1 {
257 if infer_idx.is_some() {
258 return Err(Error::invalid_operation("Can only have one -1 in reshape"));
259 }
260 infer_idx = Some(i);
261 result.push(0); } else if dim < 0 {
263 return Err(Error::invalid_operation("Invalid dimension in reshape"));
264 } else {
265 let d = dim as usize;
266 known_numel *= d;
267 result.push(d);
268 }
269 }
270
271 if let Some(idx) = infer_idx {
272 if old_numel % known_numel != 0 {
273 return Err(Error::invalid_operation(
274 "Cannot infer dimension: not evenly divisible",
275 ));
276 }
277 result[idx] = old_numel / known_numel;
278 } else if known_numel != old_numel {
279 return Err(Error::shape_mismatch(old_shape, &result));
280 }
281
282 Ok(result)
283}
284
285#[must_use]
294pub fn squeeze(shape: &[usize], dim: Option<usize>) -> Shape {
295 match dim {
296 Some(d) => {
297 let mut result = Shape::from_slice(shape);
298 if d < shape.len() && shape[d] == 1 {
299 result.remove(d);
300 }
301 result
302 }
303 None => shape.iter().copied().filter(|&d| d != 1).collect(),
304 }
305}
306
307pub fn unsqueeze(shape: &[usize], dim: usize) -> Result<Shape> {
316 if dim > shape.len() {
317 return Err(Error::InvalidDimension {
318 index: dim as i64,
319 ndim: shape.len(),
320 });
321 }
322
323 let mut result = Shape::with_capacity(shape.len() + 1);
324 result.extend_from_slice(&shape[..dim]);
325 result.push(1);
326 result.extend_from_slice(&shape[dim..]);
327 Ok(result)
328}
329
330pub fn transpose_shape(shape: &[usize], dim0: usize, dim1: usize) -> Result<Shape> {
340 if dim0 >= shape.len() || dim1 >= shape.len() {
341 return Err(Error::InvalidDimension {
342 index: dim0.max(dim1) as i64,
343 ndim: shape.len(),
344 });
345 }
346
347 let mut result = Shape::from_slice(shape);
348 result.swap(dim0, dim1);
349 Ok(result)
350}
351
352#[must_use]
354pub fn transpose_strides(strides: &[isize], dim0: usize, dim1: usize) -> Strides {
355 let mut result = Strides::from_slice(strides);
356 result.swap(dim0, dim1);
357 result
358}
359
360pub fn normalize_dim(dim: i64, ndim: usize) -> Result<usize> {
373 let ndim_i64 = ndim as i64;
374
375 let normalized = if dim < 0 { dim + ndim_i64 } else { dim };
376
377 if normalized < 0 || normalized >= ndim_i64 {
378 return Err(Error::InvalidDimension { index: dim, ndim });
379 }
380
381 Ok(normalized as usize)
382}
383
384pub fn validate_indices(indices: &[usize], shape: &[usize]) -> Result<()> {
386 if indices.len() != shape.len() {
387 return Err(Error::invalid_operation(format!(
388 "Expected {} indices, got {}",
389 shape.len(),
390 indices.len()
391 )));
392 }
393
394 for (&idx, &dim) in indices.iter().zip(shape.iter()) {
395 if idx >= dim {
396 return Err(Error::IndexOutOfBounds {
397 index: idx,
398 size: dim,
399 });
400 }
401 }
402
403 Ok(())
404}
405
/// Unit tests for the shape and stride helpers in this module.
#[cfg(test)]
mod tests {
    use super::*;

    // Element count is the product of the dims; the empty (scalar) shape has one element.
    #[test]
    fn test_numel() {
        assert_eq!(numel(&[2, 3, 4]), 24);
        assert_eq!(numel(&[]), 1);
        assert_eq!(numel(&[5]), 5);
    }

    // Row-major strides: the last axis is 1, and each earlier axis is the product of later
    // sizes.
    #[test]
    fn test_contiguous_strides() {
        let shape = [2, 3, 4];
        let strides = contiguous_strides(&shape);
        assert_eq!(strides.as_slice(), &[12, 4, 1]);
    }

    // Canonical strides are contiguous; column-major-style strides for [2, 3] are not.
    #[test]
    fn test_is_contiguous() {
        let shape = [2, 3];
        let strides = contiguous_strides(&shape);
        assert!(is_contiguous(&shape, &strides));

        let non_contig_strides: Strides = smallvec::smallvec![1, 2];
        assert!(!is_contiguous(&shape, &non_contig_strides));
    }

    // Covers equal shapes, rank extension, size-1 expansion on both sides, and one
    // incompatible pair.
    #[test]
    fn test_broadcast_shape() {
        // Identical shapes broadcast to themselves.
        assert_eq!(
            broadcast_shape(&[2, 3], &[2, 3]).unwrap().as_slice(),
            &[2, 3]
        );

        // The shorter shape is aligned at the trailing axis.
        assert_eq!(broadcast_shape(&[2, 3], &[3]).unwrap().as_slice(), &[2, 3]);

        // Size-1 axes on either side expand to the other size.
        assert_eq!(
            broadcast_shape(&[2, 1], &[1, 3]).unwrap().as_slice(),
            &[2, 3]
        );

        // Mixed: rank extension plus size-1 expansion.
        assert_eq!(
            broadcast_shape(&[5, 1, 3], &[2, 3]).unwrap().as_slice(),
            &[5, 2, 3]
        );

        // 3 vs 4 on the last axis cannot broadcast.
        assert!(broadcast_shape(&[2, 3], &[2, 4]).is_err());
    }

    // Explicit reshape, -1 inference, and an element-count mismatch.
    #[test]
    fn test_reshape() {
        let old_shape = [2, 3, 4];

        // 24 elements -> [6, 4].
        let new = reshape(&old_shape, &[6, 4]).unwrap();
        assert_eq!(new.as_slice(), &[6, 4]);

        // -1 is inferred as 24 / 4 = 6.
        let new = reshape(&old_shape, &[-1, 4]).unwrap();
        assert_eq!(new.as_slice(), &[6, 4]);

        // 25 != 24 elements.
        assert!(reshape(&old_shape, &[5, 5]).is_err());
    }

    // Squeezing all size-1 axes vs. a single named axis.
    #[test]
    fn test_squeeze() {
        let shape = [1, 2, 1, 3, 1];

        // Every size-1 axis is removed.
        let squeezed = squeeze(&shape, None);
        assert_eq!(squeezed.as_slice(), &[2, 3]);

        // Only axis 0 is removed.
        let squeezed = squeeze(&shape, Some(0));
        assert_eq!(squeezed.as_slice(), &[2, 1, 3, 1]);
    }

    // Inserting a size-1 axis at the front, middle, and end (dim == len).
    #[test]
    fn test_unsqueeze() {
        let shape = [2, 3];

        let unsqueezed = unsqueeze(&shape, 0).unwrap();
        assert_eq!(unsqueezed.as_slice(), &[1, 2, 3]);

        let unsqueezed = unsqueeze(&shape, 1).unwrap();
        assert_eq!(unsqueezed.as_slice(), &[2, 1, 3]);

        let unsqueezed = unsqueeze(&shape, 2).unwrap();
        assert_eq!(unsqueezed.as_slice(), &[2, 3, 1]);
    }

    // Non-negative and negative axes in range, plus one out-of-range value on each side.
    #[test]
    fn test_normalize_dim() {
        assert_eq!(normalize_dim(0, 3).unwrap(), 0);
        assert_eq!(normalize_dim(-1, 3).unwrap(), 2);
        assert_eq!(normalize_dim(-3, 3).unwrap(), 0);

        assert!(normalize_dim(3, 3).is_err());
        assert!(normalize_dim(-4, 3).is_err());
    }

    // Offsets for a [2, 3] row-major layout (strides [3, 1]).
    #[test]
    fn test_linear_index() {
        let strides: Strides = smallvec::smallvec![3, 1];

        assert_eq!(linear_index(&[0, 0], &strides), 0);
        assert_eq!(linear_index(&[0, 1], &strides), 1);
        assert_eq!(linear_index(&[1, 0], &strides), 3);
        assert_eq!(linear_index(&[1, 2], &strides), 5);
    }

    // Inverse of row-major indexing for shape [2, 3, 4] (strides [12, 4, 1]).
    #[test]
    fn test_unravel_index() {
        let shape = [2, 3, 4];

        assert_eq!(unravel_index(0, &shape), vec![0, 0, 0]);
        assert_eq!(unravel_index(1, &shape), vec![0, 0, 1]);
        assert_eq!(unravel_index(4, &shape), vec![0, 1, 0]);
        assert_eq!(unravel_index(12, &shape), vec![1, 0, 0]);
    }
}