1use crate::Shape;
4use crate::indexing::AsIndex;
5use alloc::format;
6use alloc::vec::Vec;
7use core::fmt::{Display, Formatter};
8use core::ops::{Range, RangeFrom, RangeFull, RangeInclusive, RangeTo, RangeToInclusive};
9use core::str::FromStr;
10
11pub trait SliceArg {
15 fn into_slices(self, shape: &Shape) -> Vec<Slice>;
19}
20
21impl<S: Into<Slice> + Clone> SliceArg for &[S] {
22 fn into_slices(self, shape: &Shape) -> Vec<Slice> {
23 assert!(
24 self.len() <= shape.num_dims(),
25 "Too many slices provided for shape, got {} but expected at most {}",
26 self.len(),
27 shape.num_dims()
28 );
29
30 shape
31 .iter()
32 .enumerate()
33 .map(|(i, dim_size)| {
34 let slice = if i >= self.len() {
35 Slice::full()
36 } else {
37 self[i].clone().into()
38 };
39 let clamped_range = slice.to_range(*dim_size);
41 Slice::new(
42 clamped_range.start as isize,
43 Some(clamped_range.end as isize),
44 slice.step(),
45 )
46 })
47 .collect::<Vec<_>>()
48 }
49}
50
51impl SliceArg for &Vec<Slice> {
52 fn into_slices(self, shape: &Shape) -> Vec<Slice> {
53 self.as_slice().into_slices(shape)
54 }
55}
56
57impl<const R: usize, T> SliceArg for [T; R]
58where
59 T: Into<Slice> + Clone,
60{
61 fn into_slices(self, shape: &Shape) -> Vec<Slice> {
62 self.as_slice().into_slices(shape)
63 }
64}
65
66impl<T> SliceArg for T
67where
68 T: Into<Slice>,
69{
70 fn into_slices(self, shape: &Shape) -> Vec<Slice> {
71 let slice: Slice = self.into();
72 [slice].as_slice().into_slices(shape)
73 }
74}
75
/// Builds slice specifications for tensor slicing.
///
/// Each comma-separated entry is any expression convertible into `Slice` (an integer index,
/// or a range such as `1..4`, `2..`, `..3`, `..=3`, `..`), optionally followed by `; step`.
///
/// - A single entry expands to a single `Slice`.
/// - Several entries expand to an array `[Slice; N]`.
/// - Every step expression is cast with `as isize`, in the first position as well as in
///   later ones, so steps of any integer type are accepted uniformly.
///
/// # Examples
///
/// ```ignore
/// s![1];               // index 1 of the first dimension
/// s![.., 0..2];        // all of dim 0, the first two elements of dim 1
/// s![..;-1, 1..5;2];   // dim 0 with step -1, every second element of 1..5 on dim 1
/// ```
///
/// # Panics
///
/// The generated code panics at runtime if a step evaluates to zero.
#[macro_export]
macro_rules! s {
    [] => {
        compile_error!("Empty slice specification")
    };

    // Single entry with an explicit step: expands to one `Slice`.
    [$range:expr; $step:expr] => {
        {
            #[allow(clippy::reversed_empty_ranges)]
            {
                $crate::tensor::Slice::from_range_stepped($range, $step as isize)
            }
        }
    };

    // Single entry without a step: expands to one `Slice`.
    [$range:expr] => {
        {
            #[allow(clippy::reversed_empty_ranges)]
            {
                $crate::tensor::Slice::from($range)
            }
        }
    };

    // First of several entries, with a step: start accumulating into an array.
    [$range:expr; $step:expr, $($rest:tt)*] => {
        {
            #[allow(clippy::reversed_empty_ranges)]
            {
                $crate::s!(@internal [$crate::tensor::Slice::from_range_stepped($range, $step as isize)] $($rest)*)
            }
        }
    };

    // First of several entries, without a step: start accumulating into an array.
    [$range:expr, $($rest:tt)*] => {
        {
            #[allow(clippy::reversed_empty_ranges)]
            {
                $crate::s!(@internal [$crate::tensor::Slice::from($range)] $($rest)*)
            }
        }
    };

    // Nothing left to consume: emit the accumulated array.
    (@internal [$($acc:expr),*]) => {
        [$($acc),*]
    };

    (@internal [$($acc:expr),*] $range:expr; $step:expr, $($rest:tt)*) => {
        $crate::s!(@internal [$($acc,)* $crate::tensor::Slice::from_range_stepped($range, $step as isize)] $($rest)*)
    };

    (@internal [$($acc:expr),*] $range:expr; $step:expr) => {
        $crate::s!(@internal [$($acc,)* $crate::tensor::Slice::from_range_stepped($range, $step as isize)])
    };

    (@internal [$($acc:expr),*] $range:expr, $($rest:tt)*) => {
        $crate::s!(@internal [$($acc,)* $crate::tensor::Slice::from($range)] $($rest)*)
    };

    (@internal [$($acc:expr),*] $range:expr) => {
        $crate::s!(@internal [$($acc,)* $crate::tensor::Slice::from($range)])
    };
}
293
294#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
337pub struct Slice {
338 pub start: isize,
340 pub end: Option<isize>,
342 pub step: isize,
344}
345
346#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
348pub struct SliceIter {
349 slice: Slice,
350 current: isize,
351}
352
353impl Iterator for SliceIter {
354 type Item = isize;
355
356 fn next(&mut self) -> Option<Self::Item> {
357 let next = self.current;
358 self.current += self.slice.step;
359
360 if let Some(end) = self.slice.end {
361 if self.slice.is_reversed() {
362 if next <= end {
363 return None;
364 }
365 } else if next >= end {
366 return None;
367 }
368 }
369
370 Some(next)
371 }
372}
373
374impl IntoIterator for Slice {
376 type Item = isize;
377 type IntoIter = SliceIter;
378
379 fn into_iter(self) -> Self::IntoIter {
380 SliceIter {
381 slice: self,
382 current: self.start,
383 }
384 }
385}
386
387impl Default for Slice {
388 fn default() -> Self {
389 Self::full()
390 }
391}
392
393impl Slice {
394 pub const fn new(start: isize, end: Option<isize>, step: isize) -> Self {
396 assert!(step != 0, "Step cannot be zero");
397 Self { start, end, step }
398 }
399
400 pub const fn full() -> Self {
402 Self::new(0, None, 1)
403 }
404
405 pub fn index(idx: isize) -> Self {
407 Self {
408 start: idx,
409 end: handle_signed_inclusive_end(idx),
410 step: 1,
411 }
412 }
413
414 pub fn into_vec(self) -> Vec<isize> {
416 assert!(
417 self.end.is_some(),
418 "Slice must have an end to convert to a vector: {self:?}"
419 );
420 self.into_iter().collect()
421 }
422
423 pub fn bound_to(self, size: usize) -> Self {
442 let mut bounds = size as isize;
443
444 if let Some(end) = self.end {
445 if end > 0 {
446 bounds = end.min(bounds);
447 } else {
448 bounds = end.max(-(bounds + 1));
449 }
450 } else if self.is_reversed() {
451 bounds = -(bounds + 1);
452 }
453
454 Self {
455 end: Some(bounds),
456 ..self
457 }
458 }
459
460 pub fn with_step(start: isize, end: Option<isize>, step: isize) -> Self {
462 assert!(step != 0, "Step cannot be zero");
463 Self { start, end, step }
464 }
465
466 pub fn from_range_stepped<R: Into<Slice>>(range: R, step: isize) -> Self {
468 assert!(step != 0, "Step cannot be zero");
469 let mut slice = range.into();
470 slice.step = step;
471 slice
472 }
473
474 pub fn step(&self) -> isize {
476 self.step
477 }
478
479 pub fn range(&self, size: usize) -> Range<usize> {
481 self.to_range(size)
482 }
483
484 pub fn to_range(&self, size: usize) -> Range<usize> {
494 let start = convert_signed_index(self.start, size);
497 let end = match self.end {
498 Some(end) => convert_signed_index(end, size),
499 None => size,
500 };
501 start..end
502 }
503
504 pub fn to_range_and_step(&self, size: usize) -> (Range<usize>, isize) {
506 let range = self.to_range(size);
507 (range, self.step)
508 }
509
510 pub fn is_reversed(&self) -> bool {
512 self.step < 0
513 }
514
515 pub fn output_size(&self, dim_size: usize) -> usize {
517 let range = self.to_range(dim_size);
518 if range.start >= range.end {
520 return 0;
521 }
522 let len = range.end - range.start;
523 if self.step.unsigned_abs() == 1 {
524 len
525 } else {
526 len.div_ceil(self.step.unsigned_abs())
527 }
528 }
529}
530
/// Resolves a possibly-negative `index` against a dimension of `size`, clamping into `0..=size`.
///
/// Non-negative indices are capped at `size`. Negative indices count back from the end
/// (`-1` is `size - 1`) and saturate at `0`.
fn convert_signed_index(index: isize, size: usize) -> usize {
    if index.is_negative() {
        let offset_from_start = size as isize + index;
        return offset_from_start.max(0) as usize;
    }
    size.min(index as usize)
}
538
/// Converts an inclusive end index into the exclusive `end` stored in a [`Slice`].
///
/// Returns `None` (an open end) when the inclusive end already reaches the last possible
/// element:
/// - `-1` is the last element of any dimension;
/// - `isize::MAX` is at or past the end of every dimension, so adding one would only
///   overflow. That would panic in debug builds, or wrap to `isize::MIN` in release builds
///   and turn the slice into an empty one.
fn handle_signed_inclusive_end(end: isize) -> Option<isize> {
    match end {
        -1 => None,
        end => end.checked_add(1),
    }
}
545
546impl<I: AsIndex> From<Range<I>> for Slice {
547 fn from(r: Range<I>) -> Self {
548 Self {
549 start: r.start.index(),
550 end: Some(r.end.index()),
551 step: 1,
552 }
553 }
554}
555
556impl<I: AsIndex + Copy> From<RangeInclusive<I>> for Slice {
557 fn from(r: RangeInclusive<I>) -> Self {
558 Self {
559 start: (*r.start()).index(),
560 end: handle_signed_inclusive_end((*r.end()).index()),
561 step: 1,
562 }
563 }
564}
565
566impl<I: AsIndex> From<RangeFrom<I>> for Slice {
567 fn from(r: RangeFrom<I>) -> Self {
568 Self {
569 start: r.start.index(),
570 end: None,
571 step: 1,
572 }
573 }
574}
575
576impl<I: AsIndex> From<RangeTo<I>> for Slice {
577 fn from(r: RangeTo<I>) -> Self {
578 Self {
579 start: 0,
580 end: Some(r.end.index()),
581 step: 1,
582 }
583 }
584}
585
586impl<I: AsIndex> From<RangeToInclusive<I>> for Slice {
587 fn from(r: RangeToInclusive<I>) -> Self {
588 Self {
589 start: 0,
590 end: handle_signed_inclusive_end(r.end.index()),
591 step: 1,
592 }
593 }
594}
595
596impl From<RangeFull> for Slice {
597 fn from(_: RangeFull) -> Self {
598 Self {
599 start: 0,
600 end: None,
601 step: 1,
602 }
603 }
604}
605
606impl From<usize> for Slice {
607 fn from(i: usize) -> Self {
608 Slice::index(i as isize)
609 }
610}
611
612impl From<isize> for Slice {
613 fn from(i: isize) -> Self {
614 Slice::index(i)
615 }
616}
617
618impl From<i32> for Slice {
619 fn from(i: i32) -> Self {
620 Slice::index(i as isize)
621 }
622}
623
624impl From<i64> for Slice {
625 fn from(i: i64) -> Self {
626 Slice::index(i as isize)
627 }
628}
629
630impl Display for Slice {
631 fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result {
632 if self.step == 1
633 && let Some(end) = self.end
634 && self.start == end - 1
635 {
636 f.write_fmt(format_args!("{}", self.start))
637 } else {
638 if self.start != 0 {
639 f.write_fmt(format_args!("{}", self.start))?;
640 }
641 f.write_str("..")?;
642 if let Some(end) = self.end {
643 f.write_fmt(format_args!("{}", end))?;
644 }
645 if self.step != 1 {
646 f.write_fmt(format_args!(";{}", self.step))?;
647 }
648 Ok(())
649 }
650 }
651}
652
653impl FromStr for Slice {
654 type Err = crate::ExpressionError;
655
656 fn from_str(source: &str) -> Result<Self, Self::Err> {
657 let mut s = source.trim();
658
659 let parse_int = |v: &str| -> Result<isize, Self::Err> {
660 v.parse::<isize>().map_err(|e| {
661 crate::ExpressionError::parse_error(
662 format!("Invalid integer: '{v}': {}", e),
663 source,
664 )
665 })
666 };
667
668 let mut start: isize = 0;
669 let mut end: Option<isize> = None;
670 let mut step: isize = 1;
671
672 if let Some((head, tail)) = s.split_once(";") {
673 step = parse_int(tail)?;
674 s = head;
675 }
676
677 if s.is_empty() {
678 return Err(crate::ExpressionError::parse_error(
679 "Empty expression",
680 source,
681 ));
682 }
683
684 if let Some((start_s, end_s)) = s.split_once("..") {
685 if !start_s.is_empty() {
686 start = parse_int(start_s)?;
687 }
688 if !end_s.is_empty() {
689 if let Some(end_s) = end_s.strip_prefix('=') {
690 end = Some(parse_int(end_s)? + 1);
691 } else {
692 end = Some(parse_int(end_s)?);
693 }
694 }
695 } else {
696 start = parse_int(s)?;
697 end = Some(start + 1);
698 }
699
700 if step == 0 {
701 return Err(crate::ExpressionError::invalid_expression(
702 "Step cannot be zero",
703 source,
704 ));
705 }
706
707 Ok(Slice::new(start, end, step))
708 }
709}
710
#[cfg(test)]
mod tests {
    use super::*;
    use alloc::string::ToString;
    use alloc::vec;

    #[test]
    fn test_slice_to_str() {
        let cases = [
            (Slice::new(0, None, 1), ".."),
            (Slice::new(0, Some(1), 1), "0"),
            (Slice::new(0, Some(10), 1), "..10"),
            (Slice::new(1, Some(10), 1), "1..10"),
            (Slice::new(-3, Some(10), -2), "-3..10;-2"),
        ];
        for (slice, expected) in cases {
            assert_eq!(slice.to_string(), expected, "formatting {slice:?}");
        }
    }

    #[test]
    fn test_slice_from_str() {
        let accepted = [
            ("1", Slice::new(1, Some(2), 1)),
            ("..", Slice::new(0, None, 1)),
            ("..3", Slice::new(0, Some(3), 1)),
            ("..=3", Slice::new(0, Some(4), 1)),
            ("-12..3", Slice::new(-12, Some(3), 1)),
            ("..;-1", Slice::new(0, None, -1)),
            ("..=3;-2", Slice::new(0, Some(4), -2)),
        ];
        for (text, expected) in accepted {
            assert_eq!(text.parse::<Slice>(), Ok(expected), "parsing {text:?}");
        }

        assert_eq!(
            "..;0".parse::<Slice>(),
            Err(crate::ExpressionError::invalid_expression(
                "Step cannot be zero",
                "..;0"
            ))
        );
        assert_eq!(
            "".parse::<Slice>(),
            Err(crate::ExpressionError::parse_error("Empty expression", ""))
        );
        for bad in ["a", "..a"] {
            assert_eq!(
                bad.parse::<Slice>(),
                Err(crate::ExpressionError::parse_error(
                    "Invalid integer: 'a': invalid digit found in string",
                    bad
                ))
            );
        }
        assert_eq!(
            "a:b:c".parse::<Slice>(),
            Err(crate::ExpressionError::parse_error(
                "Invalid integer: 'a:b:c': invalid digit found in string",
                "a:b:c"
            ))
        );
    }

    #[test]
    fn test_slice_output_size() {
        let cases = [
            (Slice::new(0, Some(10), 1), 10),
            (Slice::new(0, Some(10), 2), 5),
            (Slice::new(0, Some(10), 3), 4),
            (Slice::new(0, Some(10), -1), 10),
            (Slice::new(0, Some(10), -2), 5),
            (Slice::new(2, Some(8), -3), 2),
            (Slice::new(5, Some(5), 1), 0),
        ];
        for (slice, expected) in cases {
            assert_eq!(slice.output_size(10), expected, "output size of {slice:?}");
        }
    }

    #[test]
    fn test_bound_to() {
        let cases = [
            (Slice::new(0, None, 1), Slice::new(0, Some(10), 1)),
            (Slice::new(0, Some(5), 1), Slice::new(0, Some(5), 1)),
            (Slice::new(0, None, -1), Slice::new(0, Some(-11), -1)),
            (Slice::new(0, Some(-5), -1), Slice::new(0, Some(-5), -1)),
        ];
        for (slice, expected) in cases {
            assert_eq!(slice.bound_to(10), expected, "bounding {slice:?}");
        }
    }

    #[test]
    fn test_slice_iter() {
        let collect = |slice: Slice| slice.into_iter().collect::<Vec<_>>();

        assert_eq!(collect(Slice::new(2, Some(3), 1)), vec![2]);
        assert_eq!(collect(Slice::new(3, Some(-1), -1)), vec![3, 2, 1, 0]);
        assert_eq!(Slice::new(3, Some(-1), -1).into_vec(), vec![3, 2, 1, 0]);

        let unbounded = Slice::new(3, None, 2);
        assert_eq!(
            unbounded.into_iter().take(3).collect::<Vec<_>>(),
            vec![3, 5, 7]
        );
        assert_eq!(collect(unbounded.bound_to(8)), vec![3, 5, 7]);
    }

    #[test]
    #[should_panic(
        expected = "Slice must have an end to convert to a vector: Slice { start: 0, end: None, step: 1 }"
    )]
    fn test_unbound_slice_into_vec() {
        Slice::new(0, None, 1).into_vec();
    }

    #[test]
    fn into_slices_should_return_for_all_shape_dims() {
        let shape = Shape::new([2, 3, 1]);

        let single_index = s![1].into_slices(&shape);
        assert_eq!(single_index.len(), shape.len());
        assert_eq!(
            single_index,
            vec![
                Slice::new(1, Some(2), 1),
                Slice::new(0, Some(3), 1),
                Slice::new(0, Some(1), 1),
            ]
        );

        let index_and_range = s![1, 0..2].into_slices(&shape);
        assert_eq!(index_and_range.len(), shape.len());
        assert_eq!(
            index_and_range,
            vec![
                Slice::new(1, Some(2), 1),
                Slice::new(0, Some(2), 1),
                Slice::new(0, Some(1), 1),
            ]
        );

        let full = s![..].into_slices(&shape);
        assert_eq!(full.len(), shape.len());
        assert_eq!(
            full,
            vec![
                Slice::new(0, Some(2), 1),
                Slice::new(0, Some(3), 1),
                Slice::new(0, Some(1), 1),
            ]
        );
    }

    #[test]
    fn into_slices_all_dimensions() {
        let shape = Shape::new([2, 3, 1]);
        let slices = s![1, ..2, ..].into_slices(&shape);

        assert_eq!(slices.len(), shape.len());
        assert_eq!(
            slices,
            vec![
                Slice::new(1, Some(2), 1),
                Slice::new(0, Some(2), 1),
                Slice::new(0, Some(1), 1),
            ]
        );
    }

    #[test]
    fn into_slices_supports_empty_dimensions() {
        let shape = Shape::new([0, 3, 1]);
        let slices = s![.., 1, ..].into_slices(&shape);

        assert_eq!(slices.len(), shape.len());
        assert_eq!(
            slices,
            vec![
                Slice::new(0, Some(0), 1),
                Slice::new(1, Some(2), 1),
                Slice::new(0, Some(1), 1),
            ]
        );
    }

    #[test]
    #[should_panic = "Too many slices provided for shape"]
    fn into_slices_should_match_shape_rank() {
        let shape = Shape::new([3, 1]);
        let _ = s![.., 1, ..].into_slices(&shape);
    }
}