Skip to main content

ha_ndarray/
lib.rs

1use std::cmp::Ordering;
2use std::fmt;
3use std::ops::{Add, Div, Mul, Rem, Sub};
4
5use number_general as ng;
6use safecast::CastFrom;
7
8#[cfg(feature = "complex")]
9pub use num_complex as complex;
10pub use smallvec::smallvec as axes;
11pub use smallvec::smallvec as coord;
12pub use smallvec::smallvec as range;
13pub use smallvec::smallvec as slice;
14pub use smallvec::smallvec as shape;
15pub use smallvec::smallvec as stackvec;
16use smallvec::SmallVec;
17
18pub use access::*;
19pub use array::{
20    same_shape, MatrixDual, MatrixUnary, NDArray, NDArrayAbs, NDArrayBoolean, NDArrayBooleanScalar,
21    NDArrayCast, NDArrayCompare, NDArrayCompareScalar, NDArrayMath, NDArrayMathScalar,
22    NDArrayNumeric, NDArrayRead, NDArrayReduce, NDArrayReduceAll, NDArrayReduceBoolean,
23    NDArrayTransform, NDArrayTrig, NDArrayUnary, NDArrayUnaryBoolean, NDArrayWhere, NDArrayWrite,
24};
25#[cfg(feature = "complex")]
26pub use array::{MatrixUnaryComplex, NDArrayComplex, NDArrayFourier};
27pub use buffer::{Buffer, BufferConverter, BufferInstance, BufferMut};
28pub use host::StackVec;
29pub use platform::*;
30
31mod access;
32mod array;
33mod buffer;
34#[cfg(feature = "complex")]
35pub mod fft;
36pub mod host;
37#[cfg(feature = "opencl")]
38pub mod opencl;
39pub mod ops;
40mod platform;
41
/// The identity function: hands its argument back unchanged.
///
/// Passed as a no-op callback where a unary function is required,
/// e.g. as the absolute value of an unsigned integer.
fn id<T>(value: T) -> T {
    value
}
45
/// The common bounds required of every [`Number`].
///
/// This variant is compiled when the `opencl` feature is enabled
/// and additionally requires `opencl::CLElement`.
#[cfg(feature = "opencl")]
pub trait CLType:
    opencl::CLElement + Copy + PartialEq + fmt::Debug + fmt::Display + Send + Sync + 'static
{
}

/// The common bounds required of every [`Number`].
///
/// This variant is compiled when the `opencl` feature is disabled.
#[cfg(not(feature = "opencl"))]
pub trait CLType: Copy + PartialEq + fmt::Debug + fmt::Display + Send + Sync + 'static {}

// floating-point element types
impl CLType for f32 {}
impl CLType for f64 {}

// signed integer element types
impl CLType for i8 {}
impl CLType for i16 {}
impl CLType for i32 {}
impl CLType for i64 {}

// unsigned integer element types
impl CLType for u8 {}
impl CLType for u16 {}
impl CLType for u32 {}
impl CLType for u64 {}

// complex element types (only with the `complex` feature)
#[cfg(feature = "complex")]
impl CLType for complex::Complex<f32> {}
#[cfg(feature = "complex")]
impl CLType for complex::Complex<f64> {}
69
70/// A numeric type supported by ha-ndarray
71pub trait Number: CLType + Into<ng::Number> + CastFrom<ng::Number> + Default {
72    /// The zero value of this data type.
73    const ZERO: Self;
74
75    /// The one value of this data type.
76    const ONE: Self;
77
78    /// The absolute value type of this [`Number`].
79    type Abs: Number;
80
81    // arithmetic
82
83    /// Construct an instance of this type from a [`f64`].
84    fn abs(self) -> Self::Abs;
85
86    /// Add two instances of this type.
87    fn add(self, other: Self) -> Self;
88
89    /// Divide two instances of this type.
90    fn div(self, other: Self) -> Self;
91
92    /// Multiply two instances of this type.
93    fn mul(self, other: Self) -> Self;
94
95    /// Subtract two instances of this type.
96    fn sub(self, other: Self) -> Self;
97
98    /// Raise this value to the power of the given `exp`onent.
99    fn pow(self, exp: Self) -> Self;
100}
101
// Implement [`Number`] for `$t` by forwarding each operation to the given callable.
//
// Arguments, in order: the type, its absolute value type, the constants one and zero,
// then the `abs`, `add`, `div`, `mul`, `sub` and `pow` callables.
macro_rules! number {
    ($t:ty, $abs_t:ty, $one:expr, $zero:expr, $abs:expr, $add:expr, $div:expr, $mul:expr, $sub:expr, $pow:expr) => {
        impl Number for $t {
            const ZERO: Self = $zero;

            const ONE: Self = $one;

            type Abs = $abs_t;

            fn abs(self) -> Self::Abs {
                ($abs)(self)
            }

            fn add(self, rhs: Self) -> Self {
                ($add)(self, rhs)
            }

            fn div(self, rhs: Self) -> Self {
                ($div)(self, rhs)
            }

            fn mul(self, rhs: Self) -> Self {
                ($mul)(self, rhs)
            }

            fn sub(self, rhs: Self) -> Self {
                ($sub)(self, rhs)
            }

            fn pow(self, exp: Self) -> Self {
                ($pow)(self, exp)
            }
        }
    };
}
137
138#[cfg(feature = "complex")]
139number!(
140    complex::Complex32,
141    f32,
142    complex::Complex32::ONE,
143    complex::Complex32::ZERO,
144    complex::Complex32::norm,
145    Add::add,
146    Div::div,
147    Mul::mul,
148    Sub::sub,
149    complex::Complex32::powc
150);
151
152#[cfg(feature = "complex")]
153number!(
154    complex::Complex64,
155    f64,
156    complex::Complex64::ONE,
157    complex::Complex64::ZERO,
158    complex::Complex64::norm,
159    Add::add,
160    Div::div,
161    Mul::mul,
162    Sub::sub,
163    complex::Complex64::powc
164);
165
166number!(
167    f32,
168    Self,
169    1.,
170    0.,
171    f32::abs,
172    Add::add,
173    Div::div,
174    Mul::mul,
175    Sub::sub,
176    f32::powf
177);
178
179number!(
180    f64,
181    Self,
182    1.,
183    0.,
184    f64::abs,
185    Add::add,
186    Div::div,
187    Mul::mul,
188    Sub::sub,
189    f64::powf
190);
191
192number!(
193    i8,
194    Self,
195    1,
196    0,
197    Self::wrapping_abs,
198    Self::wrapping_add,
199    |l, r| if r == 0 { 0 } else { Self::wrapping_div(l, r) },
200    Self::wrapping_mul,
201    Self::wrapping_sub,
202    |a, e| f32::powi(a as f32, e as i32) as i8
203);
204
205number!(
206    i16,
207    Self,
208    1,
209    0,
210    Self::wrapping_abs,
211    Self::wrapping_add,
212    |l, r| if r == 0 { 0 } else { Self::wrapping_div(l, r) },
213    Self::wrapping_mul,
214    Self::wrapping_sub,
215    |a, e| f32::powi(a as f32, e as i32) as i16
216);
217
218number!(
219    i32,
220    Self,
221    1,
222    0,
223    Self::wrapping_abs,
224    Self::wrapping_add,
225    |l, r| if r == 0 { 0 } else { Self::wrapping_div(l, r) },
226    Self::wrapping_mul,
227    Self::wrapping_sub,
228    |a, e| f32::powi(a as f32, e) as i32
229);
230
231number!(
232    i64,
233    Self,
234    1,
235    0,
236    Self::wrapping_abs,
237    Self::wrapping_add,
238    |l, r| if r == 0 { 0 } else { Self::wrapping_div(l, r) },
239    Self::wrapping_mul,
240    Self::wrapping_sub,
241    |a, e| f64::powi(
242        a as f64,
243        i32::try_from(e).unwrap_or(if e >= 0 { i32::MAX } else { i32::MIN })
244    ) as i64
245);
246
247number!(
248    u8,
249    Self,
250    1,
251    0,
252    id,
253    Self::wrapping_add,
254    |l, r| if r == 0 { 0 } else { Self::wrapping_div(l, r) },
255    Self::wrapping_mul,
256    Self::wrapping_sub,
257    |a, e| u8::pow(a, e as u32)
258);
259
260number!(
261    u16,
262    Self,
263    1,
264    0,
265    id,
266    Self::wrapping_add,
267    |l, r| if r == 0 { 0 } else { Self::wrapping_div(l, r) },
268    Self::wrapping_mul,
269    Self::wrapping_sub,
270    |a, e| u16::pow(a, e as u32)
271);
272
273number!(
274    u32,
275    Self,
276    1,
277    0,
278    id,
279    Self::wrapping_add,
280    |l, r| if r == 0 { 0 } else { Self::wrapping_div(l, r) },
281    Self::wrapping_mul,
282    Self::wrapping_sub,
283    u32::pow
284);
285
286number!(
287    u64,
288    Self,
289    1,
290    0,
291    id,
292    Self::wrapping_add,
293    |l, r| if r == 0 { 0 } else { Self::wrapping_div(l, r) },
294    Self::wrapping_mul,
295    Self::wrapping_sub,
296    |a, e| u64::pow(a, u32::try_from(e).unwrap_or(u32::MAX))
297);
298
299#[cfg(not(feature = "opencl"))]
300/// A real-valued [`Number`]
301pub trait Real: Number + PartialOrd {
302    /// The maximum value of this data type.
303    const MAX: Self;
304
305    /// The minimum value of this data type.
306    const MIN: Self;
307
308    /// Return the maximum of the given values.
309    fn max(l: Self, r: Self) -> Self;
310
311    /// Return the maximum of the given values.
312    fn min(l: Self, r: Self) -> Self;
313
314    /// Compute the remainder of `self.div(other)`.
315    fn rem(self, other: Self) -> Self;
316
317    /// Round this value to the nearest integer.
318    fn round(self) -> Self;
319}
320
321#[cfg(feature = "opencl")]
322/// A real-valued [`Number`]
323pub trait Real: Number + PartialOrd + opencl::CLElementReal {
324    /// The maximum value of this data type.
325    const MAX: Self;
326
327    /// The minimum value of this data type.
328    const MIN: Self;
329
330    /// Return the maximum of the given values.
331    fn max(l: Self, r: Self) -> Self;
332
333    /// Return the maximum of the given values.
334    fn min(l: Self, r: Self) -> Self;
335
336    /// Compute the remainder of `self.div(other)`.
337    fn rem(self, other: Self) -> Self;
338
339    /// Round this value to the nearest integer.
340    fn round(self) -> Self;
341}
342
// Implement [`Real`] for `$t` given its remainder, ordering and rounding callables.
macro_rules! real {
    ($t:ty, $rem:expr, $ord:expr, $round:expr) => {
        impl Real for $t {
            const MIN: Self = <$t>::MIN;

            const MAX: Self = <$t>::MAX;

            fn max(l: Self, r: Self) -> Self {
                // a tie resolves to the left operand
                if ($ord)(&l, &r) == Ordering::Less {
                    r
                } else {
                    l
                }
            }

            fn min(l: Self, r: Self) -> Self {
                // a tie resolves to the left operand
                if ($ord)(&l, &r) == Ordering::Greater {
                    r
                } else {
                    l
                }
            }

            fn rem(self, divisor: Self) -> Self {
                ($rem)(self, divisor)
            }

            fn round(self) -> Self {
                ($round)(self)
            }
        }
    };
}
374
375real!(f32, Rem::rem, f32::total_cmp, f32::round);
376real!(f64, Rem::rem, f64::total_cmp, f64::round);
377real!(i8, Self::wrapping_rem, Ord::cmp, id);
378real!(i16, Self::wrapping_rem, Ord::cmp, id);
379real!(i32, Self::wrapping_rem, Ord::cmp, id);
380real!(i64, Self::wrapping_rem, Ord::cmp, id);
381real!(u8, Self::wrapping_rem, Ord::cmp, id);
382real!(u16, Self::wrapping_rem, Ord::cmp, id);
383real!(u32, Self::wrapping_rem, Ord::cmp, id);
384real!(u64, Self::wrapping_rem, Ord::cmp, id);
385
386#[cfg(not(feature = "opencl"))]
387/// A floating-point [`Number`]
388pub trait Float: Number {
389    // numeric methods
390    /// Return `true` if this [`Float`] is infinite (positive or negative infinity).
391    fn is_inf(self) -> bool;
392
393    /// Return `true` if this [`Float`] is not a number (e.g. a float representation of `1.0 / 0.0`).
394    fn is_nan(self) -> bool;
395
396    // logarithms
397    /// Exponentiate this number (equivalent to `consts::E.pow(self)`).
398    fn exp(self) -> Self;
399
400    /// Return the natural logarithm of this [`Float`].
401    fn ln(self) -> Self;
402
403    /// Calculate the logarithm of this [`Float`] w/r/t the given `base`.
404    fn log(self, base: Self) -> Self;
405
406    // trigonometry
407    /// Return the sine of this [`Float`] (in radians).
408    fn sin(self) -> Self;
409
410    /// Return the arcsine of this [`Float`] (in radians).
411    fn asin(self) -> Self;
412
413    /// Return the hyperbolic sine of this [`Float`] (in radians).
414    fn sinh(self) -> Self;
415
416    /// Return the cosine of this [`Float`] (in radians).
417    fn cos(self) -> Self;
418
419    /// Return the arcsine of this [`Float`] (in radians).
420    fn acos(self) -> Self;
421
422    /// Return the hyperbolic cosine of this [`Float`] (in radians).
423    fn cosh(self) -> Self;
424
425    /// Return the tangent of this [`Float`] (in radians).
426    fn tan(self) -> Self;
427
428    /// Return the arctangent of this [`Float`] (in radians).
429    fn atan(self) -> Self;
430
431    /// Return the hyperbolic tangent of this [`Float`] (in radians).
432    fn tanh(self) -> Self;
433}
434
435#[cfg(feature = "opencl")]
436/// A floating-point [`Number`]
437pub trait Float: Number + opencl::CLElementTrig {
438    // numeric methods
439    /// Return `true` if this [`Float`] is infinite (positive or negative infinity).
440    fn is_inf(self) -> bool;
441
442    /// Return `true` if this [`Float`] is not a number (e.g. a float representation of `1.0 / 0.0`).
443    fn is_nan(self) -> bool;
444
445    // logarithms
446    /// Exponentiate this number (equivalent to `consts::E.pow(self)`).
447    fn exp(self) -> Self;
448
449    /// Return the natural logarithm of this [`Float`].
450    fn ln(self) -> Self;
451
452    /// Calculate the logarithm of this [`Float`] w/r/t the given `base`.
453    fn log(self, base: Self) -> Self;
454
455    // trigonometry
456    /// Return the sine of this [`Float`] (in radians).
457    fn sin(self) -> Self;
458
459    /// Return the arcsine of this [`Float`] (in radians).
460    fn asin(self) -> Self;
461
462    /// Return the hyperbolic sine of this [`Float`] (in radians).
463    fn sinh(self) -> Self;
464
465    /// Return the cosine of this [`Float`] (in radians).
466    fn cos(self) -> Self;
467
468    /// Return the arcsine of this [`Float`] (in radians).
469    fn acos(self) -> Self;
470
471    /// Return the hyperbolic cosine of this [`Float`] (in radians).
472    fn cosh(self) -> Self;
473
474    /// Return the tangent of this [`Float`] (in radians).
475    fn tan(self) -> Self;
476
477    /// Return the arctangent of this [`Float`] (in radians).
478    fn atan(self) -> Self;
479
480    /// Return the hyperbolic tangent of this [`Float`] (in radians).
481    fn tanh(self) -> Self;
482}
483
// Implement [`Float`] for `$t`: the infinity and NaN checks are supplied by the caller,
// every other method forwards to the function of the same name on `$t`.
macro_rules! float_type {
    ($t:ty, $inf:expr, $nan:expr) => {
        impl Float for $t {
            fn is_inf(self) -> bool {
                ($inf)(self)
            }

            fn is_nan(self) -> bool {
                ($nan)(self)
            }

            fn exp(self) -> Self {
                Self::exp(self)
            }

            fn ln(self) -> Self {
                Self::ln(self)
            }

            fn log(self, base: Self) -> Self {
                // change of base: log_b(x) = ln(x) / ln(b)
                let numerator = Self::ln(self);
                let denominator = Self::ln(base);
                numerator / denominator
            }

            fn sin(self) -> Self {
                Self::sin(self)
            }

            fn asin(self) -> Self {
                Self::asin(self)
            }

            fn sinh(self) -> Self {
                Self::sinh(self)
            }

            fn cos(self) -> Self {
                Self::cos(self)
            }

            fn acos(self) -> Self {
                Self::acos(self)
            }

            fn cosh(self) -> Self {
                Self::cosh(self)
            }

            fn tan(self) -> Self {
                Self::tan(self)
            }

            fn atan(self) -> Self {
                Self::atan(self)
            }

            fn tanh(self) -> Self {
                Self::tanh(self)
            }
        }
    };
}
545
546#[cfg(feature = "complex")]
547float_type!(complex::Complex32, |_| false, |_| false);
548#[cfg(feature = "complex")]
549float_type!(complex::Complex64, |_| false, |_| false);
550float_type!(f32, f32::is_infinite, f32::is_nan);
551float_type!(f64, f64::is_infinite, f64::is_nan);
552
#[cfg(all(feature = "complex", not(feature = "opencl")))]
/// A complex [`Number`] whose absolute value is an instance of its real component type
pub trait Complex: Float<Abs = Self::Real> {
    /// The type of the real and imaginary components of this [`Complex`] number.
    type Real: Real + Float;

    /// Return the angle (argument) of this [`Complex`] number.
    fn angle(self) -> Self::Real;

    /// Return the complex conjugate of this [`Complex`] number.
    fn conj(self) -> Self;

    /// Return the real component of this [`Complex`] number.
    fn re(self) -> Self::Real;

    /// Return the imaginary component of this [`Complex`] number.
    fn im(self) -> Self::Real;
}

#[cfg(all(feature = "complex", feature = "opencl"))]
/// A complex [`Number`] whose absolute value is an instance of its real component type
pub trait Complex: opencl::CLElementComplex + Float<Abs = Self::Real> {
    /// The type of the real and imaginary components of this [`Complex`] number.
    type Real: Real + Float;

    /// Return the angle (argument) of this [`Complex`] number.
    fn angle(self) -> Self::Real;

    /// Return the complex conjugate of this [`Complex`] number.
    fn conj(self) -> Self;

    /// Return the real component of this [`Complex`] number.
    fn re(self) -> Self::Real;

    /// Return the imaginary component of this [`Complex`] number.
    fn im(self) -> Self::Real;
}
580
// Implement [`Complex`] for the complex type `$t` whose components have the type `$r`.
#[cfg(feature = "complex")]
macro_rules! complex_type {
    ($t:ty, $r:ty) => {
        impl Complex for $t {
            type Real = $r;

            fn angle(self) -> $r {
                <$t>::arg(self)
            }

            fn conj(self) -> Self {
                // name the library function explicitly, rather than using method syntax
                let conjugate = complex::Complex::<$r>::conj(&self);
                conjugate
            }

            fn re(self) -> $r {
                let real = self.re;
                real
            }

            fn im(self) -> $r {
                let imaginary = self.im;
                imaginary
            }
        }
    };
}

#[cfg(feature = "complex")]
complex_type!(complex::Complex64, f64);
#[cfg(feature = "complex")]
complex_type!(complex::Complex32, f32);
610
/// An array math error
pub enum Error {
    /// A bounds error with a description, e.g. shapes which cannot be broadcast together.
    Bounds(String),
    /// An error describing something which is not supported.
    Unsupported(String),
    /// An error reported by OpenCL, reference-counted so that [`Error`] can be cloned.
    #[cfg(feature = "opencl")]
    OCL(std::sync::Arc<ocl::Error>),
}

impl Error {
    /// Construct an [`Error::Bounds`] with the given message.
    ///
    /// Panics with the message instead if the `debug_crash` feature is enabled.
    pub fn bounds(msg: String) -> Self {
        if cfg!(feature = "debug_crash") {
            panic!("{}", msg);
        }

        Error::Bounds(msg)
    }

    /// Construct an [`Error::Unsupported`] with the given message.
    ///
    /// Panics with the message instead if the `debug_crash` feature is enabled.
    pub fn unsupported(msg: String) -> Self {
        if cfg!(feature = "debug_crash") {
            panic!("{}", msg);
        }

        Error::Unsupported(msg)
    }
}

// Clone is required to support memoizing OpenCL programs
// since constructing an [`ocl::Program`] may return an error
impl Clone for Error {
    fn clone(&self) -> Self {
        match self {
            Error::Bounds(description) => Error::Bounds(description.to_owned()),
            Error::Unsupported(description) => Error::Unsupported(description.to_owned()),
            #[cfg(feature = "opencl")]
            Error::OCL(source) => Error::OCL(std::sync::Arc::clone(source)),
        }
    }
}

#[cfg(feature = "opencl")]
impl From<ocl::Error> for Error {
    /// Wrap an OpenCL error (or panic if the `debug_crash` feature is enabled).
    fn from(cause: ocl::Error) -> Self {
        if cfg!(feature = "debug_crash") {
            panic!("OpenCL error: {:?}", cause);
        }

        Error::OCL(std::sync::Arc::new(cause))
    }
}

impl fmt::Debug for Error {
    /// Write the bare error message, without the name of the variant.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        match self {
            Self::Bounds(description) | Self::Unsupported(description) => {
                f.write_str(description)
            }
            #[cfg(feature = "opencl")]
            Self::OCL(cause) => cause.fmt(f),
        }
    }
}

impl fmt::Display for Error {
    /// Write the bare error message, without the name of the variant.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        match self {
            Self::Bounds(description) | Self::Unsupported(description) => {
                f.write_str(description)
            }
            #[cfg(feature = "opencl")]
            Self::OCL(cause) => cause.fmt(f),
        }
    }
}

impl std::error::Error for Error {}
684
685/// A list of n-dimensional array axes
686pub type Axes = SmallVec<[usize; 8]>;
687
688/// An n-dimensional selection range, used to slice an array
689pub type Range = SmallVec<[AxisRange; 8]>;
690
691/// The shape of an n-dimensional array
692pub type Shape = SmallVec<[usize; 8]>;
693
694/// The strides used to access an n-dimensional array
695pub type Strides = SmallVec<[usize; 8]>;
696
697/// An n-dimensional array on the top-level [`Platform`]
698pub type Array<T, A> = array::Array<T, A, Platform>;
699
700/// An n-dimensional array backed by a buffer on the top-level [`Platform`]
701pub type ArrayBuf<T, B> = array::Array<T, AccessBuf<B>, Platform>;
702
703/// The result of an n-dimensional array operation
704pub type ArrayOp<T, Op> = array::Array<T, AccessOp<Op>, Platform>;
705
706/// A general type of n-dimensional array used to elide recursive types
707pub type ArrayAccess<'a, T> = array::Array<T, Accessor<'a, T>, Platform>;
708
709/// An accessor for the result of an n-dimensional array operation on the top-level [`Platform`]
710pub type AccessOp<Op> = access::AccessOp<Op, Platform>;
711
712/// Bounds on an individual array axis
713#[derive(Clone, Eq, PartialEq, Hash)]
714pub enum AxisRange {
715    At(usize),
716    In(usize, usize, usize),
717    Of(SmallVec<[usize; 8]>),
718}
719
720impl AxisRange {
721    /// Return `true` if this is an index bound (i.e. not a slice)
722    pub fn is_index(&self) -> bool {
723        matches!(self, Self::At(_))
724    }
725
726    /// Return the number of elements contained within this bound.
727    /// Returns `None` for an index bound.
728    pub fn size(&self) -> Option<usize> {
729        match self {
730            Self::At(_) => None,
731            Self::In(start, stop, step) => Some((stop - start) / step),
732            Self::Of(indices) => Some(indices.len()),
733        }
734    }
735}
736
737impl From<usize> for AxisRange {
738    fn from(i: usize) -> Self {
739        Self::At(i)
740    }
741}
742
743impl From<std::ops::Range<usize>> for AxisRange {
744    fn from(range: std::ops::Range<usize>) -> Self {
745        Self::In(range.start, range.end, 1)
746    }
747}
748
749impl From<SmallVec<[usize; 8]>> for AxisRange {
750    fn from(indices: SmallVec<[usize; 8]>) -> Self {
751        Self::Of(indices)
752    }
753}
754
755impl fmt::Debug for AxisRange {
756    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
757        match self {
758            Self::At(i) => write!(f, "{}", i),
759            Self::In(start, stop, 1) => write!(f, "{}:{}", start, stop),
760            Self::In(start, stop, step) => write!(f, "{}:{}:{}", start, stop, step),
761            Self::Of(indices) => write!(f, "{:?}", indices),
762        }
763    }
764}
765
766/// Compute the shape which results from broadcasting the `left` and `right` shapes, if possible.
767#[inline]
768pub fn broadcast_shape(left: &[usize], right: &[usize]) -> Result<Shape, Error> {
769    let ndim = usize::max(left.len(), right.len());
770    let mut shape = Shape::with_capacity(ndim);
771
772    let mut left = left.iter().rev().copied();
773    let mut right = right.iter().rev().copied();
774
775    while let Some(dim) = broadcast_dim(left.next(), right.next())? {
776        shape.push(dim)
777    }
778
779    shape.reverse();
780
781    Ok(shape)
782}
783
784/// Compute the shapes needed to multiply the `left` and `right` matrices, if possible.
785#[inline]
786pub fn broadcast_matmul_shape(left: &[usize], right: &[usize]) -> Result<(Shape, Shape), Error> {
787    let (left_ndim, right_ndim) = (left.len(), right.len());
788    let ndim = usize::max(left_ndim, right_ndim);
789
790    let mut left = left.iter().rev().copied();
791    let mut right = right.iter().rev().copied();
792
793    let k = right.next().unwrap_or(1);
794    let j = match (left.next(), right.next()) {
795        (Some(jl), Some(jr)) => match (jl, jr) {
796            (jl, jr) if jl == jr => Ok(jl),
797            (1, jr) => Ok(jr),
798            (jl, 1) => Ok(jl),
799            _ => Err(Error::bounds(format!(
800                "cannot matrix-multiply shapes {left:?} and {right:?}"
801            ))),
802        },
803        (Some(jl), None) => Ok(jl),
804        (None, Some(jr)) => Ok(jr),
805        (None, None) => Ok(1),
806    }?;
807    let i = left.next().unwrap_or(1);
808
809    let mut broadcast_shape = Shape::with_capacity(ndim);
810    while let Some(dim) = broadcast_dim(left.next(), right.next())? {
811        broadcast_shape.push(dim);
812    }
813
814    broadcast_shape.reverse();
815
816    let left = broadcast_shape.iter().copied().chain([i, j]).collect();
817    let right = broadcast_shape.into_iter().chain([j, k]).collect();
818    Ok((left, right))
819}
820
821#[inline]
822fn broadcast_dim(left: Option<usize>, right: Option<usize>) -> Result<Option<usize>, Error> {
823    match (left, right) {
824        (Some(l), Some(r)) if l == r => Ok(Some(l)),
825        (Some(1), Some(r)) => Ok(Some(r)),
826        (Some(l), Some(1)) => Ok(Some(l)),
827        (None, Some(r)) => Ok(Some(r)),
828        (Some(l), None) => Ok(Some(l)),
829        (None, None) => Ok(None),
830        (l, r) => Err(Error::bounds(format!(
831            "cannot broadcast dimensions {l:?} and {r:?}"
832        ))),
833    }
834}
835
836#[inline]
837fn range_shape(source_shape: &[usize], range: &[AxisRange]) -> Shape {
838    debug_assert_eq!(source_shape.len(), range.len());
839    range.iter().filter_map(|ar| ar.size()).collect()
840}
841
/// Construct an iterator over the strides for the given shape and number of dimensions.
///
/// The stride of an axis is the product of the sizes of all the axes after it, except that
/// an axis of size one has a stride of zero. If `ndim` exceeds the length of `shape`,
/// the strides are padded on the left with zeros.
#[inline]
pub fn strides_for<'a>(shape: &'a [usize], ndim: usize) -> impl Iterator<Item = usize> + 'a {
    debug_assert!(ndim >= shape.len());

    let padding = std::iter::repeat_n(0, ndim - shape.len());

    let strides = shape.iter().enumerate().map(move |(axis, &dim)| match dim {
        1 => 0,
        // multiply the trailing dimensions, beginning with the last one
        _ => shape[axis + 1..].iter().rev().product::<usize>(),
    });

    padding.chain(strides)
}