// polars-ops 0.54.4
//
// More operations on Polars data structures.
// Documentation
use arrow::array::builder::{ArrayBuilder, ShareStrategy, make_builder};
use arrow::array::{Array, IntoBoxedArray, ListArray, NullArray};
use arrow::bitmap::BitmapBuilder;
use arrow::offset::Offsets;
use arrow::pushable::Pushable;
use polars_core::prelude::*;
use polars_core::with_match_physical_numeric_polars_type;

/// Shorthand for a list array addressed with 64-bit (`i64`) offsets.
type LargeListArray = ListArray<i64>;

/// Checks that a column of `length_srs` rows can be combined with `length_by` repeat counts.
///
/// The lengths are accepted when they are equal or when either one is exactly 1.
///
/// # Errors
/// Returns a `ShapeMismatch` error reporting both lengths otherwise.
fn check_lengths(length_srs: usize, length_by: usize) -> PolarsResult<()> {
    let broadcastable = length_srs == length_by || length_by == 1 || length_srs == 1;
    polars_ensure!(
        broadcastable,
        ShapeMismatch: "repeat_by argument and the Series should have equal length, or at least one of them should have length 1. Series length {}, by length {}",
        length_srs, length_by
    );
    Ok(())
}

/// Builds an unnamed `IdxCa` of `len` rows, each holding the first element of `by`.
///
/// When that first element is null, the result is `len` nulls instead.
fn new_by(by: &IdxCa, len: usize) -> IdxCa {
    match by.get(0) {
        Some(count) => {
            let counts: Vec<IdxSize> = vec![count; len];
            IdxCa::new(PlSmallStr::EMPTY, counts)
        },
        None => IdxCa::full_null(PlSmallStr::EMPTY, len),
    }
}

/// Repeats each value of a numeric column according to the matching count in `by`.
///
/// If exactly one of `ca` / `by` has length 1 it is first expanded to the length of the
/// other and the function calls itself on the expanded inputs. For equal lengths, a row
/// whose count is null is handed to the list constructor as `None`; a null value is
/// repeated as `None` elements.
///
/// # Errors
/// Propagates the `ShapeMismatch` error raised by `check_lengths`.
fn repeat_by_primitive<T>(ca: &ChunkedArray<T>, by: &IdxCa) -> PolarsResult<ListChunked>
where
    T: PolarsNumericType,
{
    let (n_values, n_by) = (ca.len(), by.len());
    check_lengths(n_values, n_by)?;

    if n_values == n_by {
        let repeated: ListChunked = arity::binary(ca, by, |values, counts| {
            let rows = values.into_iter().zip(counts).map(|(opt_value, opt_count)| {
                let opt_value = opt_value.copied();
                opt_count.map(|&count| std::iter::repeat_n(opt_value, count as usize))
            });
            let inner_dtype = T::get_static_dtype().to_arrow(CompatLevel::newest());

            // SAFETY: length of iter is trusted.
            unsafe { LargeListArray::from_iter_primitive_trusted_len(rows, inner_dtype) }
        });
        Ok(repeated)
    } else if n_by == 1 {
        // A single count applies to every row of `ca`.
        let broadcast_by = new_by(by, n_values);
        repeat_by_primitive(ca, &broadcast_by)
    } else if n_values == 1 {
        // A single value is repeated once per count in `by`.
        let broadcast_ca = ca.new_from_index(0, n_by);
        repeat_by_primitive(&broadcast_ca, by)
    } else {
        // `check_lengths` has already rejected every other combination.
        unreachable!()
    }
}

/// Repeats each value of a boolean column according to the matching count in `by`.
///
/// If exactly one of `ca` / `by` has length 1 it is first expanded to the length of the
/// other and the function calls itself on the expanded inputs.
///
/// # Errors
/// Propagates the `ShapeMismatch` error raised by `check_lengths`.
fn repeat_by_bool(ca: &BooleanChunked, by: &IdxCa) -> PolarsResult<ListChunked> {
    let (n_values, n_by) = (ca.len(), by.len());
    check_lengths(n_values, n_by)?;

    if n_values == n_by {
        let repeated: ListChunked = arity::binary(ca, by, |values, counts| {
            // A null count yields `None` for the whole row.
            let rows = values.into_iter().zip(counts).map(|(opt_value, opt_count)| {
                opt_count.map(|&count| std::iter::repeat_n(opt_value, count as usize))
            });

            // SAFETY: length of iter is trusted.
            unsafe { LargeListArray::from_iter_bool_trusted_len(rows) }
        });
        Ok(repeated)
    } else if n_by == 1 {
        // A single count applies to every row of `ca`.
        let broadcast_by = new_by(by, n_values);
        repeat_by_bool(ca, &broadcast_by)
    } else if n_values == 1 {
        // A single value is repeated once per count in `by`.
        let broadcast_ca = ca.new_from_index(0, n_by);
        repeat_by_bool(&broadcast_ca, by)
    } else {
        // `check_lengths` has already rejected every other combination.
        unreachable!()
    }
}

/// Repeats each value of a binary column according to the matching count in `by`.
///
/// If exactly one of `ca` / `by` has length 1 it is first expanded to the length of the
/// other and the function calls itself on the expanded inputs.
///
/// # Errors
/// Propagates the `ShapeMismatch` error raised by `check_lengths`.
fn repeat_by_binary(ca: &BinaryChunked, by: &IdxCa) -> PolarsResult<ListChunked> {
    let (n_values, n_by) = (ca.len(), by.len());
    check_lengths(n_values, n_by)?;

    if n_values == n_by {
        let repeated: ListChunked = arity::binary(ca, by, |values, counts| {
            // A null count yields `None` for the whole row.
            let rows = values.into_iter().zip(counts).map(|(opt_value, opt_count)| {
                opt_count.map(|&count| std::iter::repeat_n(opt_value, count as usize))
            });

            // NOTE(review): the second argument is the length of the whole column `ca`,
            // not of the array passed to this closure — confirm that is the intended value.
            // SAFETY: length of iter is trusted.
            unsafe { LargeListArray::from_iter_binary_trusted_len(rows, ca.len()) }
        });
        Ok(repeated)
    } else if n_by == 1 {
        // A single count applies to every row of `ca`.
        let broadcast_by = new_by(by, n_values);
        repeat_by_binary(ca, &broadcast_by)
    } else if n_values == 1 {
        // A single value is repeated once per count in `by`.
        let broadcast_ca = ca.new_from_index(0, n_by);
        repeat_by_binary(&broadcast_ca, by)
    } else {
        // `check_lengths` has already rejected every other combination.
        unreachable!()
    }
}

/// Repeats each list of a list column according to the matching count in `by`.
///
/// Equal-length inputs are delegated to `repeat_by_generic_inner`; if exactly one input
/// has length 1 it is expanded to the other's length first.
///
/// # Errors
/// Propagates the `ShapeMismatch` error raised by `check_lengths`.
fn repeat_by_list(ca: &ListChunked, by: &IdxCa) -> PolarsResult<ListChunked> {
    let (n_values, n_by) = (ca.len(), by.len());
    check_lengths(n_values, n_by)?;

    if n_values == n_by {
        Ok(repeat_by_generic_inner(ca, by))
    } else if n_by == 1 {
        // A single count applies to every row of `ca`.
        let broadcast_by = new_by(by, n_values);
        repeat_by_list(ca, &broadcast_by)
    } else if n_values == 1 {
        // A single value is repeated once per count in `by`.
        let broadcast_ca = ca.new_from_index(0, n_by);
        repeat_by_list(&broadcast_ca, by)
    } else {
        // `check_lengths` has already rejected every other combination.
        unreachable!()
    }
}

/// Repeats each row of an all-null column according to the matching count in `by`.
///
/// For equal lengths the result is assembled directly: a `NullArray` whose length is the
/// sum of all non-null counts, one offset per row, and a validity bit that is unset for
/// rows whose count is null. If exactly one input has length 1 it is expanded first.
///
/// # Errors
/// Propagates the `ShapeMismatch` error raised by `check_lengths`.
fn repeat_by_null(ca: &NullChunked, by: &IdxCa) -> PolarsResult<ListChunked> {
    let (n_values, n_by) = (ca.len(), by.len());
    check_lengths(n_values, n_by)?;

    if n_values == n_by {
        // Total number of inner null values; null counts contribute nothing.
        let total_repeats: usize = by.iter().flatten().map(|count| count as usize).sum();
        let inner_values = NullArray::new(ArrowDataType::Null, total_repeats);

        let mut validity = BitmapBuilder::with_capacity(n_by);
        let mut offsets = Offsets::<i64>::with_capacity(n_by);
        for opt_count in by.iter() {
            match opt_count {
                Some(count) => {
                    validity.push(true);
                    offsets.push(count as usize);
                },
                None => {
                    validity.push(false);
                    offsets.push_null();
                },
            }
        }

        let list_dtype = LargeListArray::default_datatype(inner_values.dtype().clone());
        let list_array = LargeListArray::new(
            list_dtype,
            offsets.into(),
            inner_values.into_boxed(),
            validity.into_opt_validity(),
        );

        // SAFETY: the only chunk is a large list over a `NullArray`, which is what
        // `List(Null)` describes.
        let repeated = unsafe {
            ListChunked::from_chunks_and_dtype(
                ca.name().clone(),
                vec![list_array.into_boxed()],
                DataType::List(Box::new(DataType::Null)),
            )
        };
        Ok(repeated)
    } else if n_by == 1 {
        // A single count applies to every row of `ca`.
        let broadcast_by = new_by(by, n_values);
        repeat_by_null(ca, &broadcast_by)
    } else if n_values == 1 {
        // Expanding yields a new value that is viewed as a `NullChunked` again.
        let expanded = ca.new_from_index(0, n_by);
        let broadcast_ca = expanded.null().unwrap();
        repeat_by_null(broadcast_ca, by)
    } else {
        // `check_lengths` has already rejected every other combination.
        unreachable!()
    }
}

/// Repeats each row of a fixed-size array column according to the matching count in `by`.
///
/// Equal-length inputs are delegated to `repeat_by_generic_inner`; if exactly one input
/// has length 1 it is expanded to the other's length first.
///
/// # Errors
/// Propagates the `ShapeMismatch` error raised by `check_lengths`.
#[cfg(feature = "dtype-array")]
fn repeat_by_array(ca: &ArrayChunked, by: &IdxCa) -> PolarsResult<ListChunked> {
    let (n_values, n_by) = (ca.len(), by.len());
    check_lengths(n_values, n_by)?;

    if n_values == n_by {
        Ok(repeat_by_generic_inner(ca, by))
    } else if n_by == 1 {
        // A single count applies to every row of `ca`.
        let broadcast_by = new_by(by, n_values);
        repeat_by_array(ca, &broadcast_by)
    } else if n_values == 1 {
        // A single value is repeated once per count in `by`.
        let broadcast_ca = ca.new_from_index(0, n_by);
        repeat_by_array(&broadcast_ca, by)
    } else {
        // `check_lengths` has already rejected every other combination.
        unreachable!()
    }
}

/// Repeats each row of a struct column according to the matching count in `by`.
///
/// Equal-length inputs are delegated to `repeat_by_generic_inner`; if exactly one input
/// has length 1 it is expanded to the other's length first.
///
/// # Errors
/// Propagates the `ShapeMismatch` error raised by `check_lengths`.
#[cfg(feature = "dtype-struct")]
fn repeat_by_struct(ca: &StructChunked, by: &IdxCa) -> PolarsResult<ListChunked> {
    let (n_values, n_by) = (ca.len(), by.len());
    check_lengths(n_values, n_by)?;

    if n_values == n_by {
        Ok(repeat_by_generic_inner(ca, by))
    } else if n_by == 1 {
        // A single count applies to every row of `ca`.
        let broadcast_by = new_by(by, n_values);
        repeat_by_struct(ca, &broadcast_by)
    } else if n_values == 1 {
        // A single value is repeated once per count in `by`.
        let broadcast_ca = ca.new_from_index(0, n_by);
        repeat_by_struct(&broadcast_ca, by)
    } else {
        // `check_lengths` has already rejected every other combination.
        unreachable!()
    }
}

/// Equal-length repeat kernel shared by the dtypes that have no dedicated list constructor.
///
/// One array builder, created from the arrow dtype of `ca`, is reused for every pair of
/// arrays handed over by `arity::binary`. For each row with a non-null count the row is
/// appended to the builder that many times; a null count only clears the validity bit.
/// Callers in this file check that `ca` and `by` have equal length before calling.
fn repeat_by_generic_inner<T: PolarsDataType>(ca: &ChunkedArray<T>, by: &IdxCa) -> ListChunked {
    let inner_dtype = ca.dtype().to_arrow(CompatLevel::newest());
    let mut builder = make_builder(&inner_dtype);

    arity::binary(ca, by, |values, counts| {
        // Reserve room for every repeated value up front.
        let total_repeats: usize = counts.iter().flatten().map(|&count| count as usize).sum();
        builder.reserve(total_repeats);

        let mut validity = BitmapBuilder::with_capacity(counts.len());
        let mut offsets = Offsets::<i64>::with_capacity(counts.len());
        for (row, opt_count) in counts.iter().enumerate() {
            match opt_count {
                Some(&count) => {
                    let count = count as usize;
                    validity.push(true);
                    offsets.push(count);
                    // Append the single row at `row` `count` times.
                    builder.subslice_extend_repeated(values, row, 1, count, ShareStrategy::Always);
                },
                None => {
                    validity.push(false);
                    offsets.push_null();
                },
            }
        }

        // Hand out the accumulated values and leave the builder empty for the next call.
        let repeated_values = builder.freeze_reset();
        let list_dtype = LargeListArray::default_datatype(values.dtype().clone());
        LargeListArray::new(
            list_dtype,
            offsets.into(),
            repeated_values,
            validity.into_opt_validity(),
        )
    })
}

/// Repeats every element of `s` by the matching count in `by`, returning a list column.
///
/// The series is converted to its physical representation and dispatched on that dtype
/// to one of the typed helpers. Afterwards the inner values of the resulting lists are
/// converted back to the logical dtype of `s`.
///
/// # Errors
/// * `ShapeMismatch` (from `check_lengths`) when the lengths of `s` and `by` differ and
///   neither is 1.
/// * An unsupported-operation error when the physical dtype has no matching helper.
/// * Any error returned while converting the inner values.
pub fn repeat_by(s: &Series, by: &IdxCa) -> PolarsResult<ListChunked> {
    use DataType as D;

    let physical = s.to_physical_repr();
    let repeated = match physical.dtype() {
        D::Null => repeat_by_null(physical.null().unwrap(), by),
        D::Boolean => repeat_by_bool(physical.bool().unwrap(), by),
        D::String => {
            // Strings are repeated through the binary kernel and cast back afterwards.
            let strings = physical.str().unwrap();
            let bytes = strings.as_binary();
            repeat_by_binary(&bytes, by).and_then(|lists| {
                // SAFETY: the bytes come from a string column, so they are valid UTF-8.
                lists.apply_to_inner(&|inner| unsafe { inner.cast_unchecked(&D::String) })
            })
        },
        D::Binary => repeat_by_binary(physical.binary().unwrap(), by),
        dt if dt.is_primitive_numeric() => {
            with_match_physical_numeric_polars_type!(dt, |$T| {
                let ca: &ChunkedArray<$T> = physical.as_ref().as_ref().as_ref();
                repeat_by_primitive(ca, by)
            })
        },
        D::List(_) => repeat_by_list(physical.list().unwrap(), by),
        #[cfg(feature = "dtype-struct")]
        D::Struct(_) => repeat_by_struct(physical.struct_().unwrap(), by),
        #[cfg(feature = "dtype-array")]
        D::Array(_, _) => repeat_by_array(physical.array().unwrap(), by),
        _ => polars_bail!(opq = repeat_by, s.dtype()),
    }?;

    let logical_dtype = s.dtype();
    // SAFETY: the inner values were produced from the physical representation of `s`.
    repeated.apply_to_inner(&|inner| unsafe { inner.from_physical_unchecked(logical_dtype) })
}