1use smallvec::SmallVec;
17
18use axonml_core::error::{Error, Result};
19
/// Dimension extents of a tensor, one entry per axis.
///
/// Backed by a `SmallVec` with inline room for 6 dimensions, so shapes of rank
/// 6 or less are stored without a heap allocation.
pub type Shape = SmallVec<[usize; 6]>;

/// Per-axis strides measured in elements, paired index-for-index with a [`Shape`].
///
/// A stride of `0` marks a broadcast (repeated) dimension. The type is signed;
/// NOTE(review): presumably to permit negative strides for reversed views — confirm.
/// Inline capacity matches `Shape`.
pub type Strides = SmallVec<[isize; 6]>;
30
/// Returns the total number of elements described by `shape`.
///
/// The empty product is `1`, so a 0-dimensional (scalar) shape has exactly one
/// element. Any zero-sized dimension makes the result `0`.
#[must_use]
pub fn numel(shape: &[usize]) -> usize {
    shape.iter().fold(1, |count, &extent| count * extent)
}
46
47#[must_use]
55pub fn contiguous_strides(shape: &[usize]) -> Strides {
56 if shape.is_empty() {
57 return Strides::new();
58 }
59
60 let mut strides = Strides::with_capacity(shape.len());
61 let mut stride = 1isize;
62
63 for &dim in shape.iter().rev() {
65 strides.push(stride);
66 stride *= dim as isize;
67 }
68
69 strides.reverse();
70 strides
71}
72
73#[must_use]
82pub fn is_contiguous(shape: &[usize], strides: &[isize]) -> bool {
83 if shape.is_empty() {
84 return true;
85 }
86
87 let expected = contiguous_strides(shape);
88 strides == expected.as_slice()
89}
90
/// Converts a multi-dimensional index into a flat storage offset.
///
/// Each coordinate is multiplied by the stride of its dimension, and the
/// products are summed. The signed total is cast to `usize` without a range
/// check, so a negative total would wrap. Callers are presumably expected to
/// supply in-bounds indices.
///
/// In debug builds, panics if `indices` and `strides` differ in length.
#[must_use]
pub fn linear_index(indices: &[usize], strides: &[isize]) -> usize {
    debug_assert_eq!(indices.len(), strides.len());

    let offset: isize = indices
        .iter()
        .zip(strides)
        .map(|(&coord, &step)| coord as isize * step)
        .sum();
    offset as usize
}
109
/// Converts a flat row-major index into per-dimension coordinates for `shape`.
///
/// The last dimension varies fastest. `linear` is not range-checked: any
/// quotient left over after the outermost dimension is discarded, so indices
/// past the end wrap around.
///
/// Panics if any extent in `shape` is `0` (remainder by zero).
#[must_use]
pub fn unravel_index(mut linear: usize, shape: &[usize]) -> Vec<usize> {
    let mut coords = vec![0; shape.len()];

    for (coord, &extent) in coords.iter_mut().zip(shape).rev() {
        *coord = linear % extent;
        linear /= extent;
    }

    coords
}
129
130pub fn broadcast_shape(shape1: &[usize], shape2: &[usize]) -> Result<Shape> {
148 let max_ndim = shape1.len().max(shape2.len());
149 let mut result = Shape::with_capacity(max_ndim);
150
151 for i in 0..max_ndim {
153 let d1 = if i < shape1.len() {
154 shape1[shape1.len() - 1 - i]
155 } else {
156 1
157 };
158
159 let d2 = if i < shape2.len() {
160 shape2[shape2.len() - 1 - i]
161 } else {
162 1
163 };
164
165 if d1 == d2 {
166 result.push(d1);
167 } else if d1 == 1 {
168 result.push(d2);
169 } else if d2 == 1 {
170 result.push(d1);
171 } else {
172 return Err(Error::BroadcastError {
173 shape1: shape1.to_vec(),
174 shape2: shape2.to_vec(),
175 });
176 }
177 }
178
179 result.reverse();
180 Ok(result)
181}
182
183#[must_use] pub fn broadcast_strides(shape: &[usize], strides: &[isize], target_shape: &[usize]) -> Strides {
193 let mut result = Strides::with_capacity(target_shape.len());
194 let shape_offset = target_shape.len() - shape.len();
195
196 for (i, &target_dim) in target_shape.iter().enumerate() {
197 if i < shape_offset {
198 result.push(0);
200 } else {
201 let orig_idx = i - shape_offset;
202 let orig_dim = shape[orig_idx];
203
204 if orig_dim == target_dim {
205 result.push(strides[orig_idx]);
206 } else if orig_dim == 1 {
207 result.push(0);
209 } else {
210 result.push(strides[orig_idx]);
212 }
213 }
214 }
215
216 result
217}
218
219#[must_use]
221pub fn can_broadcast(shape1: &[usize], shape2: &[usize]) -> bool {
222 broadcast_shape(shape1, shape2).is_ok()
223}
224
225pub fn reshape(old_shape: &[usize], new_shape: &[isize]) -> Result<Shape> {
240 let old_numel = numel(old_shape);
241 let mut result = Shape::with_capacity(new_shape.len());
242 let mut infer_idx = None;
243 let mut known_numel = 1usize;
244
245 for (i, &dim) in new_shape.iter().enumerate() {
246 if dim == -1 {
247 if infer_idx.is_some() {
248 return Err(Error::invalid_operation("Can only have one -1 in reshape"));
249 }
250 infer_idx = Some(i);
251 result.push(0); } else if dim < 0 {
253 return Err(Error::invalid_operation("Invalid dimension in reshape"));
254 } else {
255 let d = dim as usize;
256 known_numel *= d;
257 result.push(d);
258 }
259 }
260
261 if let Some(idx) = infer_idx {
262 if old_numel % known_numel != 0 {
263 return Err(Error::invalid_operation(
264 "Cannot infer dimension: not evenly divisible",
265 ));
266 }
267 result[idx] = old_numel / known_numel;
268 } else if known_numel != old_numel {
269 return Err(Error::shape_mismatch(old_shape, &result));
270 }
271
272 Ok(result)
273}
274
275#[must_use]
284pub fn squeeze(shape: &[usize], dim: Option<usize>) -> Shape {
285 match dim {
286 Some(d) => {
287 let mut result = Shape::from_slice(shape);
288 if d < shape.len() && shape[d] == 1 {
289 result.remove(d);
290 }
291 result
292 }
293 None => shape.iter().copied().filter(|&d| d != 1).collect(),
294 }
295}
296
297pub fn unsqueeze(shape: &[usize], dim: usize) -> Result<Shape> {
306 if dim > shape.len() {
307 return Err(Error::InvalidDimension {
308 index: dim as i64,
309 ndim: shape.len(),
310 });
311 }
312
313 let mut result = Shape::with_capacity(shape.len() + 1);
314 result.extend_from_slice(&shape[..dim]);
315 result.push(1);
316 result.extend_from_slice(&shape[dim..]);
317 Ok(result)
318}
319
320pub fn transpose_shape(shape: &[usize], dim0: usize, dim1: usize) -> Result<Shape> {
330 if dim0 >= shape.len() || dim1 >= shape.len() {
331 return Err(Error::InvalidDimension {
332 index: dim0.max(dim1) as i64,
333 ndim: shape.len(),
334 });
335 }
336
337 let mut result = Shape::from_slice(shape);
338 result.swap(dim0, dim1);
339 Ok(result)
340}
341
342#[must_use] pub fn transpose_strides(strides: &[isize], dim0: usize, dim1: usize) -> Strides {
344 let mut result = Strides::from_slice(strides);
345 result.swap(dim0, dim1);
346 result
347}
348
349pub fn normalize_dim(dim: i64, ndim: usize) -> Result<usize> {
362 let ndim_i64 = ndim as i64;
363
364 let normalized = if dim < 0 { dim + ndim_i64 } else { dim };
365
366 if normalized < 0 || normalized >= ndim_i64 {
367 return Err(Error::InvalidDimension { index: dim, ndim });
368 }
369
370 Ok(normalized as usize)
371}
372
373pub fn validate_indices(indices: &[usize], shape: &[usize]) -> Result<()> {
375 if indices.len() != shape.len() {
376 return Err(Error::invalid_operation(format!(
377 "Expected {} indices, got {}",
378 shape.len(),
379 indices.len()
380 )));
381 }
382
383 for (&idx, &dim) in indices.iter().zip(shape.iter()) {
384 if idx >= dim {
385 return Err(Error::IndexOutOfBounds {
386 index: idx,
387 size: dim,
388 });
389 }
390 }
391
392 Ok(())
393}
394
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_numel() {
        assert_eq!(numel(&[2, 3, 4]), 24);
        assert_eq!(numel(&[]), 1);
        assert_eq!(numel(&[5]), 5);
    }

    #[test]
    fn test_contiguous_strides() {
        let shape = [2, 3, 4];
        let strides = contiguous_strides(&shape);
        assert_eq!(strides.as_slice(), &[12, 4, 1]);
    }

    #[test]
    fn test_is_contiguous() {
        let shape = [2, 3];
        let strides = contiguous_strides(&shape);
        assert!(is_contiguous(&shape, &strides));

        let non_contig_strides: Strides = smallvec::smallvec![1, 2];
        assert!(!is_contiguous(&shape, &non_contig_strides));

        // A 0-dimensional shape is always contiguous.
        assert!(is_contiguous(&[], &[]));
    }

    #[test]
    fn test_broadcast_shape() {
        assert_eq!(
            broadcast_shape(&[2, 3], &[2, 3]).unwrap().as_slice(),
            &[2, 3]
        );

        assert_eq!(broadcast_shape(&[2, 3], &[3]).unwrap().as_slice(), &[2, 3]);

        assert_eq!(
            broadcast_shape(&[2, 1], &[1, 3]).unwrap().as_slice(),
            &[2, 3]
        );

        assert_eq!(
            broadcast_shape(&[5, 1, 3], &[2, 3]).unwrap().as_slice(),
            &[5, 2, 3]
        );

        assert!(broadcast_shape(&[2, 3], &[2, 4]).is_err());
    }

    #[test]
    fn test_can_broadcast() {
        assert!(can_broadcast(&[5, 1, 3], &[2, 3]));
        assert!(can_broadcast(&[], &[4, 5]));
        assert!(!can_broadcast(&[2, 3], &[2, 4]));
    }

    #[test]
    fn test_broadcast_strides() {
        // Stretched size-1 dims and new leading dims get stride 0.
        let shape = [3, 1];
        let strides = contiguous_strides(&shape);
        assert_eq!(
            broadcast_strides(&shape, &strides, &[2, 3, 4]).as_slice(),
            &[0, 1, 0]
        );

        // Identical shapes keep their strides.
        assert_eq!(
            broadcast_strides(&[2, 3], &[3, 1], &[2, 3]).as_slice(),
            &[3, 1]
        );

        // A size-1 dim that stays size 1 keeps its original stride.
        assert_eq!(
            broadcast_strides(&[1, 3], &[3, 1], &[1, 3]).as_slice(),
            &[3, 1]
        );
    }

    #[test]
    fn test_reshape() {
        let old_shape = [2, 3, 4];

        let new = reshape(&old_shape, &[6, 4]).unwrap();
        assert_eq!(new.as_slice(), &[6, 4]);

        let new = reshape(&old_shape, &[-1, 4]).unwrap();
        assert_eq!(new.as_slice(), &[6, 4]);

        assert!(reshape(&old_shape, &[5, 5]).is_err());
    }

    #[test]
    fn test_reshape_errors_and_scalar() {
        let old_shape = [2, 3, 4];

        assert!(reshape(&old_shape, &[-1, -1]).is_err());
        assert!(reshape(&old_shape, &[-2, 12]).is_err());
        assert!(reshape(&old_shape, &[-1, 5]).is_err());
        assert!(reshape(&[2, 3], &[]).is_err());

        // Scalar to scalar keeps a single element.
        assert!(reshape(&[], &[]).unwrap().is_empty());
    }

    #[test]
    fn test_squeeze() {
        let shape = [1, 2, 1, 3, 1];

        let squeezed = squeeze(&shape, None);
        assert_eq!(squeezed.as_slice(), &[2, 3]);

        let squeezed = squeeze(&shape, Some(0));
        assert_eq!(squeezed.as_slice(), &[2, 1, 3, 1]);
    }

    #[test]
    fn test_squeeze_noop() {
        // Non-size-1 and out-of-range dims leave the shape unchanged.
        assert_eq!(squeeze(&[1, 2], Some(1)).as_slice(), &[1, 2]);
        assert_eq!(squeeze(&[1, 2], Some(5)).as_slice(), &[1, 2]);
    }

    #[test]
    fn test_unsqueeze() {
        let shape = [2, 3];

        let unsqueezed = unsqueeze(&shape, 0).unwrap();
        assert_eq!(unsqueezed.as_slice(), &[1, 2, 3]);

        let unsqueezed = unsqueeze(&shape, 1).unwrap();
        assert_eq!(unsqueezed.as_slice(), &[2, 1, 3]);

        let unsqueezed = unsqueeze(&shape, 2).unwrap();
        assert_eq!(unsqueezed.as_slice(), &[2, 3, 1]);

        assert!(unsqueeze(&shape, 3).is_err());
    }

    #[test]
    fn test_transpose() {
        let shape = [2, 3, 4];

        assert_eq!(
            transpose_shape(&shape, 0, 2).unwrap().as_slice(),
            &[4, 3, 2]
        );
        assert!(transpose_shape(&shape, 0, 3).is_err());

        let strides = contiguous_strides(&shape);
        assert_eq!(transpose_strides(&strides, 0, 2).as_slice(), &[1, 4, 12]);
    }

    #[test]
    fn test_normalize_dim() {
        assert_eq!(normalize_dim(0, 3).unwrap(), 0);
        assert_eq!(normalize_dim(-1, 3).unwrap(), 2);
        assert_eq!(normalize_dim(-3, 3).unwrap(), 0);

        assert!(normalize_dim(3, 3).is_err());
        assert!(normalize_dim(-4, 3).is_err());
    }

    #[test]
    fn test_validate_indices() {
        let shape = [2, 3];

        assert!(validate_indices(&[1, 2], &shape).is_ok());
        assert!(validate_indices(&[2, 0], &shape).is_err());
        assert!(validate_indices(&[0], &shape).is_err());
    }

    #[test]
    fn test_linear_index() {
        let strides: Strides = smallvec::smallvec![3, 1];

        assert_eq!(linear_index(&[0, 0], &strides), 0);
        assert_eq!(linear_index(&[0, 1], &strides), 1);
        assert_eq!(linear_index(&[1, 0], &strides), 3);
        assert_eq!(linear_index(&[1, 2], &strides), 5);
    }

    #[test]
    fn test_unravel_index() {
        let shape = [2, 3, 4];

        assert_eq!(unravel_index(0, &shape), vec![0, 0, 0]);
        assert_eq!(unravel_index(1, &shape), vec![0, 0, 1]);
        assert_eq!(unravel_index(4, &shape), vec![0, 1, 0]);
        assert_eq!(unravel_index(12, &shape), vec![1, 0, 0]);
    }
}