1use crate::tensors::views::{DataLayout, TensorMut, TensorRef};
2use crate::tensors::{Dimension, InvalidDimensionsError, InvalidShapeError};
3use std::error::Error;
4use std::fmt;
5use std::marker::PhantomData;
6
7pub use crate::matrices::views::IndexRange;
8
9#[derive(Clone, Debug)]
79pub struct TensorRange<T, S, const D: usize> {
80 source: S,
81 range: [IndexRange; D],
82 _type: PhantomData<T>,
83}
84
85#[derive(Clone, Debug)]
155pub struct TensorMask<T, S, const D: usize> {
156 source: S,
157 mask: [IndexRange; D],
158 _type: PhantomData<T>,
159}
160
161#[derive(Clone, Debug, Eq, PartialEq)]
165pub enum IndexRangeValidationError<const D: usize, const P: usize> {
166 InvalidShape(InvalidShapeError<D>),
170 InvalidDimensions(InvalidDimensionsError<D, P>),
175}
176
177impl<const D: usize, const P: usize> fmt::Display for IndexRangeValidationError<D, P> {
178 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
179 match self {
180 IndexRangeValidationError::InvalidShape(error) => write!(f, "{:?}", error),
181 IndexRangeValidationError::InvalidDimensions(error) => write!(f, "{:?}", error),
182 }
183 }
184}
185
186impl<const D: usize, const P: usize> Error for IndexRangeValidationError<D, P> {
187 fn source(&self) -> Option<&(dyn Error + 'static)> {
188 match self {
189 IndexRangeValidationError::InvalidShape(error) => Some(error),
190 IndexRangeValidationError::InvalidDimensions(error) => Some(error),
191 }
192 }
193}
194
195#[derive(Clone, Debug, Eq, PartialEq)]
200pub enum StrictIndexRangeValidationError<const D: usize, const P: usize> {
201 OutsideShape {
208 shape: [(Dimension, usize); D],
209 index_range: [Option<IndexRange>; D],
210 },
211 Error(IndexRangeValidationError<D, P>),
212}
213
214impl<const D: usize, const P: usize> fmt::Display for StrictIndexRangeValidationError<D, P> {
215 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
216 use StrictIndexRangeValidationError as S;
217 match self {
218 S::OutsideShape { shape, index_range } => write!(
219 f,
220 "IndexRange array {:?} is out of bounds of shape {:?}",
221 index_range, shape
222 ),
223 S::Error(error) => write!(f, "{:?}", error),
224 }
225 }
226}
227
228impl<const D: usize, const P: usize> Error for StrictIndexRangeValidationError<D, P> {
229 fn source(&self) -> Option<&(dyn Error + 'static)> {
230 use StrictIndexRangeValidationError as S;
231 match self {
232 S::OutsideShape {
233 shape: _,
234 index_range: _,
235 } => None,
236 S::Error(error) => Some(error),
237 }
238 }
239}
240
241fn from_named_to_all<T, S, R, const D: usize, const P: usize>(
242 source: &S,
243 ranges: [(Dimension, R); P],
244) -> Result<[Option<IndexRange>; D], IndexRangeValidationError<D, P>>
245where
246 S: TensorRef<T, D>,
247 R: Into<IndexRange>,
248{
249 let shape = source.view_shape();
250 let ranges = ranges.map(|(d, r)| (d, r.into()));
251 let dimensions = InvalidDimensionsError {
252 provided: ranges.clone().map(|(d, _)| d),
253 valid: shape.map(|(d, _)| d),
254 };
255 if dimensions.has_duplicates() {
256 return Err(IndexRangeValidationError::InvalidDimensions(dimensions));
257 }
258 let mut all_ranges: [Option<IndexRange>; D] = std::array::from_fn(|_| None);
261 for (name, range) in ranges.into_iter() {
262 match crate::tensors::dimensions::position_of(&shape, name) {
263 Some(d) => all_ranges[d] = Some(range),
264 None => return Err(IndexRangeValidationError::InvalidDimensions(dimensions)),
265 };
266 }
267 Ok(all_ranges)
268}
269
270impl<T, S, const D: usize> TensorRange<T, S, D>
271where
272 S: TensorRef<T, D>,
273{
274 pub fn from<R, const P: usize>(
282 source: S,
283 ranges: [(Dimension, R); P],
284 ) -> Result<TensorRange<T, S, D>, IndexRangeValidationError<D, P>>
285 where
286 R: Into<IndexRange>,
287 {
288 let all_ranges = from_named_to_all(&source, ranges)?;
289 match TensorRange::from_all(source, all_ranges) {
290 Ok(tensor_range) => Ok(tensor_range),
291 Err(invalid_shape) => Err(IndexRangeValidationError::InvalidShape(invalid_shape)),
292 }
293 }
294
295 pub fn from_strict<R, const P: usize>(
303 source: S,
304 ranges: [(Dimension, R); P],
305 ) -> Result<TensorRange<T, S, D>, StrictIndexRangeValidationError<D, P>>
306 where
307 R: Into<IndexRange>,
308 {
309 use StrictIndexRangeValidationError as S;
310 let all_ranges = match from_named_to_all(&source, ranges) {
311 Ok(all_ranges) => all_ranges,
312 Err(error) => return Err(S::Error(error)),
313 };
314 match TensorRange::from_all_strict(source, all_ranges) {
315 Ok(tensor_range) => Ok(tensor_range),
316 Err(S::OutsideShape { shape, index_range }) => {
317 Err(S::OutsideShape { shape, index_range })
318 }
319 Err(S::Error(IndexRangeValidationError::InvalidShape(error))) => {
320 Err(S::Error(IndexRangeValidationError::InvalidShape(error)))
321 }
322 Err(S::Error(IndexRangeValidationError::InvalidDimensions(_))) => panic!(
323 "Unexpected InvalidDimensions error case after validating for InvalidDimensions already"
324 ),
325 }
326 }
327
328 pub fn from_all<R>(
335 source: S,
336 ranges: [Option<R>; D],
337 ) -> Result<TensorRange<T, S, D>, InvalidShapeError<D>>
338 where
339 R: Into<IndexRange>,
340 {
341 TensorRange::clip_from(
342 source,
343 ranges.map(|option| option.map(|range| range.into())),
344 )
345 }
346
347 fn clip_from(
348 source: S,
349 ranges: [Option<IndexRange>; D],
350 ) -> Result<TensorRange<T, S, D>, InvalidShapeError<D>> {
351 let shape = source.view_shape();
352 let mut ranges = {
353 let mut d = 0;
356 ranges.map(|option| {
357 let range = option.unwrap_or_else(|| IndexRange::new(0, shape[d].1));
359 d += 1;
360 range
361 })
362 };
363 let shape = InvalidShapeError {
364 shape: clip_range_shape(&shape, &mut ranges),
365 };
366 if !shape.is_valid() {
367 return Err(shape);
368 }
369
370 Ok(TensorRange {
371 source,
372 range: ranges,
373 _type: PhantomData,
374 })
375 }
376
377 pub fn from_all_strict<R>(
386 source: S,
387 range: [Option<R>; D],
388 ) -> Result<TensorRange<T, S, D>, StrictIndexRangeValidationError<D, D>>
389 where
390 R: Into<IndexRange>,
391 {
392 let shape = source.view_shape();
393 let range = range.map(|option| option.map(|range| range.into()));
394 if range_exceeds_bounds(&shape, &range) {
395 return Err(StrictIndexRangeValidationError::OutsideShape {
396 shape,
397 index_range: range,
398 });
399 }
400
401 match TensorRange::clip_from(source, range) {
402 Ok(tensor_range) => Ok(tensor_range),
403 Err(invalid_shape) => Err(StrictIndexRangeValidationError::Error(
404 IndexRangeValidationError::InvalidShape(invalid_shape),
405 )),
406 }
407 }
408
409 #[allow(dead_code)]
413 pub fn source(self) -> S {
414 self.source
415 }
416
417 #[allow(dead_code)]
427 pub fn source_ref(&self) -> &S {
428 &self.source
429 }
430}
431
432fn range_exceeds_bounds<const D: usize>(
433 source: &[(Dimension, usize); D],
434 range: &[Option<IndexRange>; D],
435) -> bool {
436 for (d, (_, end)) in source.iter().enumerate() {
437 let end = *end;
438 match &range[d] {
439 None => continue,
440 Some(range) => {
441 let range_end = range.start + range.length;
442 match range_end > end {
443 true => return true,
444 false => (),
445 };
446 }
447 }
448 }
449 false
450}
451
452fn clip_range_shape<const D: usize>(
455 source: &[(Dimension, usize); D],
456 range: &mut [IndexRange; D],
457) -> [(Dimension, usize); D] {
458 let mut shape = *source;
459 for (d, (_, length)) in shape.iter_mut().enumerate() {
460 let range = &mut range[d];
461 range.clip(*length);
462 *length = range.length;
464 }
465 shape
466}
467
468impl<T, S, const D: usize> TensorMask<T, S, D>
469where
470 S: TensorRef<T, D>,
471{
472 pub fn from<R, const P: usize>(
479 source: S,
480 masks: [(Dimension, R); P],
481 ) -> Result<TensorMask<T, S, D>, IndexRangeValidationError<D, P>>
482 where
483 R: Into<IndexRange>,
484 {
485 let all_masks = from_named_to_all(&source, masks)?;
486 match TensorMask::from_all(source, all_masks) {
487 Ok(tensor_mask) => Ok(tensor_mask),
488 Err(invalid_shape) => Err(IndexRangeValidationError::InvalidShape(invalid_shape)),
489 }
490 }
491
492 pub fn from_strict<R, const P: usize>(
500 source: S,
501 masks: [(Dimension, R); P],
502 ) -> Result<TensorMask<T, S, D>, StrictIndexRangeValidationError<D, P>>
503 where
504 R: Into<IndexRange>,
505 {
506 use StrictIndexRangeValidationError as S;
507 let all_masks = match from_named_to_all(&source, masks) {
508 Ok(all_masks) => all_masks,
509 Err(error) => return Err(S::Error(error)),
510 };
511 match TensorMask::from_all_strict(source, all_masks) {
512 Ok(tensor_mask) => Ok(tensor_mask),
513 Err(S::OutsideShape { shape, index_range }) => {
514 Err(S::OutsideShape { shape, index_range })
515 }
516 Err(S::Error(IndexRangeValidationError::InvalidShape(error))) => {
517 Err(S::Error(IndexRangeValidationError::InvalidShape(error)))
518 }
519 Err(S::Error(IndexRangeValidationError::InvalidDimensions(_))) => panic!(
520 "Unexpected InvalidDimensions error case after validating for InvalidDimensions already"
521 ),
522 }
523 }
524
525 pub fn from_all<R>(
532 source: S,
533 mask: [Option<R>; D],
534 ) -> Result<TensorMask<T, S, D>, InvalidShapeError<D>>
535 where
536 R: Into<IndexRange>,
537 {
538 TensorMask::clip_from(source, mask.map(|option| option.map(|mask| mask.into())))
539 }
540
541 fn clip_from(
542 source: S,
543 masks: [Option<IndexRange>; D],
544 ) -> Result<TensorMask<T, S, D>, InvalidShapeError<D>> {
545 let shape = source.view_shape();
546 let mut masks = masks.map(|option| option.unwrap_or_else(|| IndexRange::new(0, 0)));
547 let shape = InvalidShapeError {
548 shape: clip_masked_shape(&shape, &mut masks),
549 };
550 if !shape.is_valid() {
551 return Err(shape);
552 }
553
554 Ok(TensorMask {
555 source,
556 mask: masks,
557 _type: PhantomData,
558 })
559 }
560
561 pub fn from_all_strict<R>(
570 source: S,
571 masks: [Option<R>; D],
572 ) -> Result<TensorMask<T, S, D>, StrictIndexRangeValidationError<D, D>>
573 where
574 R: Into<IndexRange>,
575 {
576 let shape = source.view_shape();
577 let masks = masks.map(|option| option.map(|mask| mask.into()));
578 if mask_exceeds_bounds(&shape, &masks) {
579 return Err(StrictIndexRangeValidationError::OutsideShape {
580 shape,
581 index_range: masks,
582 });
583 }
584
585 match TensorMask::clip_from(source, masks) {
586 Ok(tensor_mask) => Ok(tensor_mask),
587 Err(invalid_shape) => Err(StrictIndexRangeValidationError::Error(
588 IndexRangeValidationError::InvalidShape(invalid_shape),
589 )),
590 }
591 }
592
593 #[allow(dead_code)]
597 pub fn source(self) -> S {
598 self.source
599 }
600
601 #[allow(dead_code)]
611 pub fn source_ref(&self) -> &S {
612 &self.source
613 }
614}
615
616fn clip_masked_shape<const D: usize>(
619 source: &[(Dimension, usize); D],
620 mask: &mut [IndexRange; D],
621) -> [(Dimension, usize); D] {
622 let mut shape = *source;
623 for (d, (_, length)) in shape.iter_mut().enumerate() {
624 let mask = &mut mask[d];
625 mask.clip(*length);
626 *length -= mask.length;
628 }
629 shape
630}
631
632fn mask_exceeds_bounds<const D: usize>(
633 source: &[(Dimension, usize); D],
634 mask: &[Option<IndexRange>; D],
635) -> bool {
636 range_exceeds_bounds(source, mask)
638}
639
640fn map_indexes_by_range<const D: usize>(
641 indexes: [usize; D],
642 ranges: &[IndexRange; D],
643) -> Option<[usize; D]> {
644 let mut mapped = [0; D];
645 for (d, (r, i)) in ranges.iter().zip(indexes.into_iter()).enumerate() {
646 mapped[d] = r.map(i)?;
647 }
648 Some(mapped)
649}
650
651unsafe impl<T, S, const D: usize> TensorRef<T, D> for TensorRange<T, S, D>
660where
661 S: TensorRef<T, D>,
662{
663 fn get_reference(&self, indexes: [usize; D]) -> Option<&T> {
664 self.source
665 .get_reference(map_indexes_by_range(indexes, &self.range)?)
666 }
667
668 fn view_shape(&self) -> [(Dimension, usize); D] {
669 let mut shape = self.source.view_shape();
672 for (pair, range) in shape.iter_mut().zip(self.range.iter()) {
674 pair.1 = range.length;
675 }
676 shape
677 }
678
679 unsafe fn get_reference_unchecked(&self, indexes: [usize; D]) -> &T {
680 unsafe {
681 self.source
686 .get_reference_unchecked(map_indexes_by_range(indexes, &self.range).unwrap())
687 }
688 }
689
690 fn data_layout(&self) -> DataLayout<D> {
691 DataLayout::NonLinear
695 }
696}
697
698unsafe impl<T, S, const D: usize> TensorMut<T, D> for TensorRange<T, S, D>
707where
708 S: TensorMut<T, D>,
709{
710 fn get_reference_mut(&mut self, indexes: [usize; D]) -> Option<&mut T> {
711 self.source
712 .get_reference_mut(map_indexes_by_range(indexes, &self.range)?)
713 }
714
715 unsafe fn get_reference_unchecked_mut(&mut self, indexes: [usize; D]) -> &mut T {
716 unsafe {
717 self.source
722 .get_reference_unchecked_mut(map_indexes_by_range(indexes, &self.range).unwrap())
723 }
724 }
725}
726
727fn map_indexes_by_mask<const D: usize>(indexes: [usize; D], masks: &[IndexRange; D]) -> [usize; D] {
728 let mut mapped = [0; D];
729 for (d, (r, i)) in masks.iter().zip(indexes.into_iter()).enumerate() {
730 mapped[d] = r.mask(i);
731 }
732 mapped
733}
734
735unsafe impl<T, S, const D: usize> TensorRef<T, D> for TensorMask<T, S, D>
744where
745 S: TensorRef<T, D>,
746{
747 fn get_reference(&self, indexes: [usize; D]) -> Option<&T> {
748 self.source
749 .get_reference(map_indexes_by_mask(indexes, &self.mask))
750 }
751
752 fn view_shape(&self) -> [(Dimension, usize); D] {
753 let mut shape = self.source.view_shape();
756 for (pair, mask) in shape.iter_mut().zip(self.mask.iter()) {
758 pair.1 -= mask.length;
759 }
760 shape
761 }
762
763 unsafe fn get_reference_unchecked(&self, indexes: [usize; D]) -> &T {
764 unsafe {
765 self.source
768 .get_reference_unchecked(map_indexes_by_mask(indexes, &self.mask))
769 }
770 }
771
772 fn data_layout(&self) -> DataLayout<D> {
773 DataLayout::NonLinear
776 }
777}
778
779unsafe impl<T, S, const D: usize> TensorMut<T, D> for TensorMask<T, S, D>
788where
789 S: TensorMut<T, D>,
790{
791 fn get_reference_mut(&mut self, indexes: [usize; D]) -> Option<&mut T> {
792 self.source
793 .get_reference_mut(map_indexes_by_mask(indexes, &self.mask))
794 }
795
796 unsafe fn get_reference_unchecked_mut(&mut self, indexes: [usize; D]) -> &mut T {
797 unsafe {
798 self.source
801 .get_reference_unchecked_mut(map_indexes_by_mask(indexes, &self.mask))
802 }
803 }
804}
805
#[test]
#[rustfmt::skip]
fn test_constructors() {
    use crate::tensors::Tensor;
    use crate::tensors::views::TensorView;
    use IndexRangeValidationError as IRVError;
    use InvalidShapeError as ShapeError;
    use StrictIndexRangeValidationError::Error as SError;
    use StrictIndexRangeValidationError::OutsideShape as OutsideShape;
    use InvalidDimensionsError as DError;

    // 3x3 tensor laid out as:
    // 0, 1, 2
    // 3, 4, 5
    // 6, 7, 8
    let tensor = Tensor::from([("rows", 3), ("columns", 3)], (0..9).collect());

    // TensorRange keeps only the indexes inside each range.
    let bottom_rows = TensorRange::from(&tensor, [("rows", IndexRange::new(1, 2))]).unwrap();
    assert_eq!(
        TensorView::from(bottom_rows),
        Tensor::from([("rows", 2), ("columns", 3)], vec![
            3, 4, 5,
            6, 7, 8
        ])
    );
    let last_column = TensorRange::from(&tensor, [("columns", 2..3)]).unwrap();
    assert_eq!(
        TensorView::from(last_column),
        Tensor::from([("rows", 3), ("columns", 1)], vec![
            2,
            5,
            8
        ])
    );
    let single_element = TensorRange::from(&tensor, [("rows", (1, 1)), ("columns", (2, 1))]).unwrap();
    assert_eq!(
        TensorView::from(single_element),
        Tensor::from([("rows", 1), ("columns", 1)], vec![5])
    );
    let right_columns = TensorRange::from(&tensor, [("columns", 1..3)]).unwrap();
    assert_eq!(
        TensorView::from(right_columns),
        Tensor::from([("rows", 3), ("columns", 2)], vec![
            1, 2,
            4, 5,
            7, 8
        ])
    );

    // TensorMask hides the indexes inside each mask.
    let without_middle_row = TensorMask::from(&tensor, [("rows", IndexRange::new(1, 1))]).unwrap();
    assert_eq!(
        TensorView::from(without_middle_row),
        Tensor::from([("rows", 2), ("columns", 3)], vec![
            0, 1, 2,
            6, 7, 8
        ])
    );
    let top_right_corner = TensorMask::from(&tensor, [("rows", 2..3), ("columns", 0..1)]).unwrap();
    assert_eq!(
        TensorView::from(top_right_corner),
        Tensor::from([("rows", 2), ("columns", 2)], vec![
            1, 2,
            4, 5
        ])
    );

    // Unknown dimension names are rejected by every constructor.
    assert_eq!(
        TensorRange::from(&tensor, [("invalid", 1..2)]).unwrap_err(),
        IRVError::InvalidDimensions(DError::new(["invalid"], ["rows", "columns"]))
    );
    assert_eq!(
        TensorMask::from(&tensor, [("wrong", 0..1)]).unwrap_err(),
        IRVError::InvalidDimensions(DError::new(["wrong"], ["rows", "columns"]))
    );
    assert_eq!(
        TensorRange::from_strict(&tensor, [("invalid", 1..2)]).unwrap_err(),
        SError(IRVError::InvalidDimensions(DError::new(["invalid"], ["rows", "columns"])))
    );
    assert_eq!(
        TensorMask::from_strict(&tensor, [("wrong", 0..1)]).unwrap_err(),
        SError(IRVError::InvalidDimensions(DError::new(["wrong"], ["rows", "columns"])))
    );

    // Views that would leave a dimension with length 0 are rejected.
    assert_eq!(
        TensorRange::from(&tensor, [("rows", 0..0)]).unwrap_err(),
        IRVError::InvalidShape(ShapeError::new([("rows", 0), ("columns", 3)]))
    );
    assert_eq!(
        TensorMask::from(&tensor, [("columns", 0..3)]).unwrap_err(),
        IRVError::InvalidShape(ShapeError::new([("rows", 3), ("columns", 0)]))
    );
    assert_eq!(
        TensorRange::from_strict(&tensor, [("rows", 0..0)]).unwrap_err(),
        SError(IRVError::InvalidShape(ShapeError::new([("rows", 0), ("columns", 3)])))
    );
    assert_eq!(
        TensorMask::from_strict(&tensor, [("columns", 0..3)]).unwrap_err(),
        SError(IRVError::InvalidShape(ShapeError::new([("rows", 3), ("columns", 0)])))
    );

    // Repeating a dimension name is rejected.
    assert_eq!(
        TensorRange::from(&tensor, [("rows", 1..2), ("rows", 2..3)]).unwrap_err(),
        IRVError::InvalidDimensions(DError::new(["rows", "rows"], ["rows", "columns"]))
    );
    assert_eq!(
        TensorMask::from(&tensor, [("columns", 1..2), ("columns", 2..3)]).unwrap_err(),
        IRVError::InvalidDimensions(DError::new(["columns", "columns"], ["rows", "columns"]))
    );
    assert_eq!(
        TensorRange::from_strict(&tensor, [("rows", 1..2), ("rows", 2..3)]).unwrap_err(),
        SError(IRVError::InvalidDimensions(DError::new(["rows", "rows"], ["rows", "columns"])))
    );
    assert_eq!(
        TensorMask::from_strict(&tensor, [("columns", 1..2), ("columns", 2..3)]).unwrap_err(),
        SError(IRVError::InvalidDimensions(DError::new(["columns", "columns"], ["rows", "columns"])))
    );

    // Out-of-bounds ranges/masks are clipped by the lenient constructors
    // and rejected by the strict ones.
    let clipped_rows = TensorRange::from(&tensor, [("rows", 0..4)]).unwrap();
    assert!(TensorView::from(clipped_rows).eq(&tensor));
    assert_eq!(
        TensorRange::from_strict(&tensor, [("rows", 0..4)]).unwrap_err(),
        OutsideShape {
            shape: [("rows", 3), ("columns", 3)],
            index_range: [Some(IndexRange::new(0, 4)), None],
        }
    );
    let clipped_mask = TensorMask::from(&tensor, [("columns", 1..4)]).unwrap();
    assert_eq!(
        TensorView::from(clipped_mask),
        Tensor::from([("rows", 3), ("columns", 1)], vec![
            0,
            3,
            6,
        ])
    );
    assert_eq!(
        TensorMask::from_strict(&tensor, [("columns", 1..4)]).unwrap_err(),
        OutsideShape {
            shape: [("rows", 3), ("columns", 3)],
            index_range: [None, Some(IndexRange::new(1, 3))],
        }
    );
}