//! iter_cartesian 0.1.0
//!
//! A Cartesian product iterator with double-ended iteration and O(1) length queries.
#![feature(trusted_len)]
#![deny(unsafe_op_in_unsafe_fn)]

use std::fmt;
use std::iter::{FusedIterator, TrustedLen};

/// An iterator that yields the Cartesian product of all elements from two iterators.
///
/// This `struct` is created by the [`cartesian`] method on [`CartesianExt`].
///
/// Pairs are produced in row-major order: each primary element is paired with
/// every secondary element before moving on to the next primary element. The
/// iterator can be consumed from either end, or from both ends alternately.
///
/// # Performance
///
/// Caches the secondary length for O(1) [`len`] and [`size_hint`]. The
/// secondary iterator is cloned to start each row.
///
/// # Panics
///
/// [`len`] and [`count`] panic with `"capacity overflow"` if the number of
/// remaining pairs does not fit in a `usize`. [`size_hint`] and [`nth`] do
/// not panic in that case; `size_hint` reports `(usize::MAX, None)`.
///
/// [`cartesian`]: CartesianExt::cartesian
/// [`len`]: ExactSizeIterator::len
/// [`size_hint`]: Iterator::size_hint
/// [`count`]: Iterator::count
/// [`nth`]: Iterator::nth
///
/// # Examples
///
/// ```rust
/// use iter_cartesian::CartesianExt;
///
/// let v1 = vec![1, 2];
/// let v2 = vec![3, 4];
/// let mut it = v1.into_iter().cartesian(v2.into_iter());
///
/// assert_eq!(it.next(), Some((1, 3)));
/// assert_eq!(it.next(), Some((1, 4)));
/// assert_eq!(it.next(), Some((2, 3)));
/// assert_eq!(it.next(), Some((2, 4)));
/// assert_eq!(it.next(), None);
/// ```
#[must_use = "iterators are lazy and do nothing unless consumed"]
pub struct CartesianProduct<I, J>
where
    I: DoubleEndedIterator + ExactSizeIterator,
    J: DoubleEndedIterator + Clone + ExactSizeIterator,
{
    /// Primary elements whose rows have not been started from either end.
    primary: I,
    /// Unconsumed secondary items of the front row. Only meaningful while
    /// `front_primary` is `Some`; it goes stale when `next_back` takes the row.
    front_secondary: J,
    /// Unconsumed secondary items of the back row. Only meaningful while
    /// `back_primary` is `Some`; it goes stale when `next` takes the row.
    back_secondary: J,
    /// Untouched copy of the secondary iterator, cloned to start each new row.
    orig_secondary: J,
    /// Primary element of the row currently consumed from the front.
    front_primary: Option<I::Item>,
    /// Primary element of the row currently consumed from the back.
    back_primary: Option<I::Item>,
    /// Length of one full row (`secondary.len()` at construction).
    secondary_len: usize,
    /// Number of elements still in `primary`, i.e. full rows between the ends.
    primary_len: usize,
    /// Set once only one row is left and it is handed between the two ends.
    shared_row: bool,
}

impl<I, J> CartesianProduct<I, J>
where
    I: DoubleEndedIterator + ExactSizeIterator,
    J: DoubleEndedIterator + Clone + ExactSizeIterator,
{
    /// Creates the product `primary × secondary`.
    ///
    /// The first and last primary elements are pulled eagerly to seed the
    /// front and back rows; `primary_len` then counts only the rows between.
    pub(crate) fn new(mut primary: I, secondary: J) -> Self {
        let secondary_len = secondary.len();
        let front_primary = primary.next();
        // An empty primary leaves both ends empty.
        let back_primary = if front_primary.is_some() {
            primary.next_back()
        } else {
            None
        };
        let primary_len = primary.len();
        // Primary had exactly one element: that single row is shared by both ends.
        let shared_row = front_primary.is_some() && back_primary.is_none() && primary_len == 0;

        Self {
            orig_secondary: secondary.clone(),
            front_secondary: secondary.clone(),
            back_secondary: secondary,
            primary,
            front_primary,
            back_primary,
            secondary_len,
            primary_len,
            shared_row,
        }
    }

    /// Number of items left in the front row, or 0 if there is no live front row.
    ///
    /// `front_secondary` is not cleared when `next_back` takes over the front
    /// row, so its length must be ignored once `front_primary` is `None`.
    #[inline]
    fn front_remaining(&self) -> usize {
        if self.front_primary.is_some() {
            self.front_secondary.len()
        } else {
            0
        }
    }

    /// Exact number of remaining items, or `None` if it does not fit in a `usize`.
    #[inline]
    fn checked_len(&self) -> Option<usize> {
        let mid_len = self.primary_len.checked_mul(self.secondary_len)?;

        match (&self.front_primary, &self.back_primary) {
            (Some(_), Some(_)) => {
                let total = self
                    .front_secondary
                    .len()
                    .checked_add(mid_len)?
                    .checked_add(self.back_secondary.len())?;
                if self.shared_row {
                    Some(total.saturating_sub(self.secondary_len))
                } else {
                    Some(total)
                }
            }
            (Some(_), None) => self.front_secondary.len().checked_add(mid_len),
            (None, Some(_)) => self.back_secondary.len().checked_add(mid_len),
            (None, None) => Some(0),
        }
    }
}

impl<I, J> Clone for CartesianProduct<I, J>
where
    I: DoubleEndedIterator + Clone + ExactSizeIterator,
    J: DoubleEndedIterator + Clone + ExactSizeIterator,
    I::Item: Clone,
{
    fn clone(&self) -> Self {
        Self {
            primary: self.primary.clone(),
            front_secondary: self.front_secondary.clone(),
            back_secondary: self.back_secondary.clone(),
            orig_secondary: self.orig_secondary.clone(),
            front_primary: self.front_primary.clone(),
            back_primary: self.back_primary.clone(),
            secondary_len: self.secondary_len,
            primary_len: self.primary_len,
            shared_row: self.shared_row,
        }
    }
}

impl<I, J> fmt::Debug for CartesianProduct<I, J>
where
    I: DoubleEndedIterator + fmt::Debug + ExactSizeIterator,
    J: DoubleEndedIterator + Clone + fmt::Debug + ExactSizeIterator,
    I::Item: fmt::Debug,
{
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("CartesianProduct")
            .field("front_primary", &self.front_primary)
            .field("back_primary", &self.back_primary)
            .field("secondary_len", &self.secondary_len)
            .field("primary_len", &self.primary_len)
            .field("shared_row", &self.shared_row)
            .finish()
    }
}

impl<I, J> Iterator for CartesianProduct<I, J>
where
    I: DoubleEndedIterator + ExactSizeIterator,
    J: DoubleEndedIterator + Clone + ExactSizeIterator,
    I::Item: Clone,
{
    type Item = (I::Item, J::Item);

    #[inline]
    fn next(&mut self) -> Option<Self::Item> {
        loop {
            if let Some(p) = &self.front_primary {
                if let Some(s) = self.front_secondary.next() {
                    // If both ends ever hold the same row, stop once together
                    // they have consumed a full row.
                    if self.shared_row && self.back_primary.is_some() {
                        let yielded = (self.secondary_len - self.front_secondary.len())
                            + (self.secondary_len - self.back_secondary.len());
                        if yielded >= self.secondary_len {
                            self.front_primary = None;
                            return None;
                        }
                    }
                    return Some((p.clone(), s));
                }
            }

            // The front row is exhausted (or absent): move to the next row.
            if self.primary_len > 0 {
                self.front_primary = self.primary.next();
                self.front_secondary = self.orig_secondary.clone();
                self.primary_len -= 1;
            } else if let Some(bp) = self.back_primary.take() {
                // Only the back row is left: take it over, including whatever
                // the back end already consumed from it.
                self.front_primary = Some(bp);
                self.front_secondary = self.back_secondary.clone();
                self.shared_row = true;
            } else {
                self.front_primary = None;
                return None;
            }
        }
    }

    /// Exact when the remaining length fits in a `usize`. Otherwise returns
    /// `(usize::MAX, None)`, following the `TrustedLen` convention, instead
    /// of panicking.
    #[inline]
    fn size_hint(&self) -> (usize, Option<usize>) {
        match self.checked_len() {
            Some(len) => (len, Some(len)),
            None => (usize::MAX, None),
        }
    }

    // TODO: Override `fold` and `for_each` to walk each row directly instead
    // of going through `next()` per item (re-checking row state every time),
    // and to move rather than clone the primary element into a row's last item.
    #[inline]
    fn count(self) -> usize {
        self.len()
    }

    #[inline]
    fn last(mut self) -> Option<Self::Item> {
        self.next_back()
    }

    // TODO: Implement `advance_by` once it stabilises as it would allow `nth` to delegate
    // to the inner iterators' own skip logic rather than doing it manually.
    #[inline]
    fn nth(&mut self, n: usize) -> Option<Self::Item> {
        // `checked_len` is `None` only when more than `usize::MAX` items
        // remain, in which case every `n` is in range.
        if matches!(self.checked_len(), Some(len) if n >= len) {
            self.front_primary = None;
            self.back_primary = None;
            self.primary_len = 0;
            self.shared_row = false;
            return None;
        }

        // Target lies in the live front row.
        let front_avail = self.front_remaining();
        if n < front_avail {
            let p = self.front_primary.as_ref().unwrap_or_else(|| {
                unreachable!("front_remaining() is non-zero only while front_primary is set")
            });
            return self.front_secondary.nth(n).map(|s| (p.clone(), s));
        }

        // Split the offset past the front row into whole rows and a column.
        // `secondary_len` is non-zero here: with empty rows nothing remains,
        // so the range check above has already returned.
        let n_after_first = n - front_avail;
        let rows_to_skip = n_after_first / self.secondary_len;
        let rem_in_target = n_after_first % self.secondary_len;

        if rows_to_skip < self.primary_len {
            // Target is in a middle row: drop the skipped rows and start a fresh one.
            self.front_primary = self.primary.nth(rows_to_skip);
            self.front_secondary = self.orig_secondary.clone();
            self.primary_len -= rows_to_skip + 1;
        } else {
            // Target is in the back row (`rows_to_skip == primary_len`), whose
            // unconsumed items start at the front of `back_secondary`.
            self.primary_len = 0;
            self.front_primary = self.back_primary.take();
            self.front_secondary = self.back_secondary.clone();
            self.shared_row = true;
        }

        self.front_secondary
            .nth(rem_in_target)
            .and_then(|s| self.front_primary.as_ref().map(|p| (p.clone(), s)))
    }
}

impl<I, J> DoubleEndedIterator for CartesianProduct<I, J>
where
    I: DoubleEndedIterator + ExactSizeIterator,
    J: DoubleEndedIterator + Clone + ExactSizeIterator,
    I::Item: Clone,
{
    #[inline]
    fn next_back(&mut self) -> Option<Self::Item> {
        loop {
            if let Some(p) = &self.back_primary {
                if let Some(s) = self.back_secondary.next_back() {
                    // If both ends ever hold the same row, stop once together
                    // they have consumed a full row.
                    if self.shared_row && self.front_primary.is_some() {
                        let yielded = (self.secondary_len - self.front_secondary.len())
                            + (self.secondary_len - self.back_secondary.len());
                        if yielded >= self.secondary_len {
                            self.back_primary = None;
                            return None;
                        }
                    }
                    return Some((p.clone(), s));
                }
            }

            // The back row is exhausted (or absent): move to the previous row.
            if self.primary_len > 0 {
                self.back_primary = self.primary.next_back();
                self.back_secondary = self.orig_secondary.clone();
                self.primary_len -= 1;
            } else if let Some(fp) = self.front_primary.take() {
                // Only the front row is left: take it over, including whatever
                // the front end already consumed from it.
                self.back_primary = Some(fp);
                self.back_secondary = self.front_secondary.clone();
                self.shared_row = true;
            } else {
                self.back_primary = None;
                return None;
            }
        }
    }
}

impl<I, J> ExactSizeIterator for CartesianProduct<I, J>
where
    I: DoubleEndedIterator + ExactSizeIterator,
    J: DoubleEndedIterator + Clone + ExactSizeIterator,
    I::Item: Clone,
{
    /// Number of pairs left, computed in O(1).
    ///
    /// # Panics
    ///
    /// Panics with `"capacity overflow"` if the count exceeds `usize::MAX`.
    #[inline]
    fn len(&self) -> usize {
        self.checked_len().expect("capacity overflow")
    }
}

impl<I, J> FusedIterator for CartesianProduct<I, J>
where
    I: DoubleEndedIterator + ExactSizeIterator,
    J: DoubleEndedIterator + Clone + ExactSizeIterator,
    I::Item: Clone,
{
}

// SAFETY: `TrustedLen` requires `size_hint` to report the exact number of
// remaining items, or `(usize::MAX, None)` if that number exceeds
// `usize::MAX`. This iterator's length is derived from `primary_len`,
// `secondary_len` and the current front/back secondary iterators. It is
// computed with checked arithmetic, so an overflowing count is never
// reported as a wrong exact value. Requiring `TrustedLen` on both `I` and
// `J` is what makes those component lengths trustworthy.
// NOTE(review): the bookkeeping reads `len()` (from `ExactSizeIterator`),
// while `TrustedLen` only vouches for `size_hint`; this assumes the two
// agree for `I` and `J` — confirm.
unsafe impl<I, J> TrustedLen for CartesianProduct<I, J>
where
    I: TrustedLen + DoubleEndedIterator + ExactSizeIterator,
    J: TrustedLen + DoubleEndedIterator + Clone + ExactSizeIterator,
    I::Item: Clone,
{
}

// TODO: Relax the `ExactSizeIterator` bound on the primary once a weaker
// double-ended length query is available. The secondary bound is load-bearing
// for `TrustedLen`. The primary bound only feeds `primary_len`, which `len()`
// and `nth()` need and which `next`/`next_back` also use to tell when the
// middle rows are exhausted.

/// Adds the [`cartesian`] adaptor to every iterator that is both a
/// [`DoubleEndedIterator`] and an [`ExactSizeIterator`].
///
/// [`cartesian`]: CartesianExt::cartesian
pub trait CartesianExt: DoubleEndedIterator + ExactSizeIterator + Sized {
    /// Pairs every item of `self` with every item of `other`, yielding the
    /// Cartesian product `self × other`.
    ///
    /// `other` must be [`Clone`], because a fresh copy of it is used for each
    /// item of `self`.
    fn cartesian<J: DoubleEndedIterator + Clone + ExactSizeIterator>(
        self,
        other: J,
    ) -> CartesianProduct<Self, J>;
}

// TODO: Consider a `cartesian_product` free function for iterators that do not
// meet the `DoubleEndedIterator + ExactSizeIterator` bounds, to allow more
// generic usage.
impl<T> CartesianExt for T
where
    T: DoubleEndedIterator + ExactSizeIterator,
{
    fn cartesian<J: DoubleEndedIterator + Clone + ExactSizeIterator>(
        self,
        other: J,
    ) -> CartesianProduct<Self, J> {
        CartesianProduct::new(self, other)
    }
}