polars-expr 0.54.1

Physical expression implementation of the Polars project.
Documentation
#![allow(unsafe_op_in_unsafe_fn)]
use polars_compute::moment::{CovState, PearsonState};
use polars_core::prelude::*;
use polars_core::utils::{align_chunks_binary, try_get_supertype};

use super::*;

fn out_dtype(dtype_x: &DataType, dtype_y: &DataType) -> DataType {
    let st = try_get_supertype(dtype_x, dtype_y).unwrap_or(DataType::Float64);
    match st {
        #[cfg(feature = "dtype-f16")]
        DataType::Float16 => DataType::Float16,
        DataType::Float32 => DataType::Float32,
        _ => DataType::Float64,
    }
}

pub fn new_cov_reduction(
    dtype_x: DataType,
    dtype_y: DataType,
    ddof: u8,
) -> PolarsResult<Box<dyn GroupedReduction>> {
    polars_ensure!(
        dtype_x.is_primitive_numeric(),
        InvalidOperation: "`cov` operation not supported for dtype `{dtype_x}`"
    );
    polars_ensure!(
        dtype_y.is_primitive_numeric(),
        InvalidOperation: "`cov` operation not supported for dtype `{dtype_y}`"
    );
    let out_dtype = out_dtype(&dtype_x, &dtype_y);
    Ok(Box::new(CovGroupedReduction {
        values: Vec::new(),
        evicted_values: Vec::new(),
        ddof,
        out_dtype,
    }))
}

struct CovGroupedReduction {
    values: Vec<CovState>,
    evicted_values: Vec<CovState>,
    ddof: u8,
    out_dtype: DataType,
}

impl GroupedReduction for CovGroupedReduction {
    fn new_empty(&self) -> Box<dyn GroupedReduction> {
        Box::new(Self {
            values: Vec::new(),
            evicted_values: Vec::new(),
            ddof: self.ddof,
            out_dtype: self.out_dtype.clone(),
        })
    }

    fn reserve(&mut self, additional: usize) {
        self.values.reserve(additional);
    }

    fn resize(&mut self, num_groups: IdxSize) {
        self.values.resize(num_groups as usize, CovState::default());
    }

    fn update_group(
        &mut self,
        values: &[&Column],
        group_idx: IdxSize,
        _seq_id: u64,
    ) -> PolarsResult<()> {
        assert!(values.len() == 2);
        let sx = values[0].cast(&DataType::Float64)?;
        let sy = values[1].cast(&DataType::Float64)?;
        let cx = sx.f64().unwrap();
        let cy = sy.f64().unwrap();
        let (cx, cy) = align_chunks_binary(cx, cy);
        let state = &mut self.values[group_idx as usize];
        for (ax, ay) in cx.downcast_iter().zip(cy.downcast_iter()) {
            state.combine(&polars_compute::moment::cov(ax, ay));
        }
        Ok(())
    }

    unsafe fn update_groups_while_evicting(
        &mut self,
        values: &[&Column],
        subset: &[IdxSize],
        group_idxs: &[EvictIdx],
        _seq_id: u64,
    ) -> PolarsResult<()> {
        assert!(values.len() == 2);
        assert!(subset.len() == group_idxs.len());
        let sx = values[0]
            .take_slice_unchecked(subset)
            .cast(&DataType::Float64)?;
        let sy = values[1]
            .take_slice_unchecked(subset)
            .cast(&DataType::Float64)?;
        let cx = sx.f64().unwrap();
        let cy = sy.f64().unwrap();
        let ax = cx.downcast_as_array();
        let ay = cy.downcast_as_array();
        if ax.has_nulls() || ay.has_nulls() {
            for ((ox, oy), g) in ax.iter().zip(ay.iter()).zip(group_idxs) {
                let grp = self.values.get_unchecked_mut(g.idx());
                if g.should_evict() {
                    let old = core::mem::take(grp);
                    self.evicted_values.push(old);
                }
                if let (Some(x), Some(y)) = (ox, oy) {
                    grp.insert_one(*x, *y);
                }
            }
        } else {
            for ((x, y), g) in ax.values().iter().zip(ay.values().iter()).zip(group_idxs) {
                let grp = self.values.get_unchecked_mut(g.idx());
                if g.should_evict() {
                    let old = core::mem::take(grp);
                    self.evicted_values.push(old);
                }
                grp.insert_one(*x, *y);
            }
        }
        Ok(())
    }

    unsafe fn combine_subset(
        &mut self,
        other: &dyn GroupedReduction,
        subset: &[IdxSize],
        group_idxs: &[IdxSize],
    ) -> PolarsResult<()> {
        let other = other.as_any().downcast_ref::<Self>().unwrap();
        assert!(subset.len() == group_idxs.len());
        for (i, g) in subset.iter().zip(group_idxs) {
            let v = other.values.get_unchecked(*i as usize);
            let grp = self.values.get_unchecked_mut(*g as usize);
            grp.combine(v);
        }
        Ok(())
    }

    fn take_evictions(&mut self) -> Box<dyn GroupedReduction> {
        Box::new(Self {
            values: core::mem::take(&mut self.evicted_values),
            evicted_values: Vec::new(),
            ddof: self.ddof,
            out_dtype: self.out_dtype.clone(),
        })
    }

    fn finalize(&mut self) -> PolarsResult<Series> {
        let v = core::mem::take(&mut self.values);
        let ddof = self.ddof;
        let ca: Float64Chunked = v
            .into_iter()
            .map(|s| s.finalize(ddof))
            .collect_ca(PlSmallStr::EMPTY);
        ca.into_series().cast(&self.out_dtype)
    }

    fn as_any(&self) -> &dyn std::any::Any {
        self
    }
}

pub fn new_pearson_corr_reduction(
    dtype_x: DataType,
    dtype_y: DataType,
) -> PolarsResult<Box<dyn GroupedReduction>> {
    polars_ensure!(
        dtype_x.is_primitive_numeric(),
        InvalidOperation: "`corr` operation not supported for dtype `{dtype_x}`"
    );
    polars_ensure!(
        dtype_y.is_primitive_numeric(),
        InvalidOperation: "`corr` operation not supported for dtype `{dtype_y}`"
    );
    let out_dtype = out_dtype(&dtype_x, &dtype_y);
    Ok(Box::new(PearsonCorrGroupedReduction {
        values: Vec::new(),
        evicted_values: Vec::new(),
        out_dtype,
    }))
}

struct PearsonCorrGroupedReduction {
    values: Vec<PearsonState>,
    evicted_values: Vec<PearsonState>,
    out_dtype: DataType,
}

impl GroupedReduction for PearsonCorrGroupedReduction {
    fn new_empty(&self) -> Box<dyn GroupedReduction> {
        Box::new(Self {
            values: Vec::new(),
            evicted_values: Vec::new(),
            out_dtype: self.out_dtype.clone(),
        })
    }

    fn reserve(&mut self, additional: usize) {
        self.values.reserve(additional);
    }

    fn resize(&mut self, num_groups: IdxSize) {
        self.values
            .resize(num_groups as usize, PearsonState::default());
    }

    fn update_group(
        &mut self,
        values: &[&Column],
        group_idx: IdxSize,
        _seq_id: u64,
    ) -> PolarsResult<()> {
        assert!(values.len() == 2);
        let sx = values[0].cast(&DataType::Float64)?;
        let sy = values[1].cast(&DataType::Float64)?;
        let cx = sx.f64().unwrap();
        let cy = sy.f64().unwrap();
        let (cx, cy) = align_chunks_binary(cx, cy);
        let state = &mut self.values[group_idx as usize];
        for (ax, ay) in cx.downcast_iter().zip(cy.downcast_iter()) {
            state.combine(&polars_compute::moment::pearson_corr(ax, ay));
        }
        Ok(())
    }

    unsafe fn update_groups_while_evicting(
        &mut self,
        values: &[&Column],
        subset: &[IdxSize],
        group_idxs: &[EvictIdx],
        _seq_id: u64,
    ) -> PolarsResult<()> {
        assert!(values.len() == 2);
        assert!(subset.len() == group_idxs.len());
        let sx = values[0]
            .take_slice_unchecked(subset)
            .cast(&DataType::Float64)?;
        let sy = values[1]
            .take_slice_unchecked(subset)
            .cast(&DataType::Float64)?;
        let cx = sx.f64().unwrap();
        let cy = sy.f64().unwrap();
        let ax = cx.downcast_as_array();
        let ay = cy.downcast_as_array();
        if ax.has_nulls() || ay.has_nulls() {
            for ((ox, oy), g) in ax.iter().zip(ay.iter()).zip(group_idxs) {
                let grp = self.values.get_unchecked_mut(g.idx());
                if g.should_evict() {
                    let old = core::mem::take(grp);
                    self.evicted_values.push(old);
                }
                if let (Some(x), Some(y)) = (ox, oy) {
                    grp.insert_one(*x, *y);
                }
            }
        } else {
            for ((x, y), g) in ax.values().iter().zip(ay.values().iter()).zip(group_idxs) {
                let grp = self.values.get_unchecked_mut(g.idx());
                if g.should_evict() {
                    let old = core::mem::take(grp);
                    self.evicted_values.push(old);
                }
                grp.insert_one(*x, *y);
            }
        }
        Ok(())
    }

    unsafe fn combine_subset(
        &mut self,
        other: &dyn GroupedReduction,
        subset: &[IdxSize],
        group_idxs: &[IdxSize],
    ) -> PolarsResult<()> {
        let other = other.as_any().downcast_ref::<Self>().unwrap();
        assert!(subset.len() == group_idxs.len());
        for (i, g) in subset.iter().zip(group_idxs) {
            let v = other.values.get_unchecked(*i as usize);
            let grp = self.values.get_unchecked_mut(*g as usize);
            grp.combine(v);
        }
        Ok(())
    }

    fn take_evictions(&mut self) -> Box<dyn GroupedReduction> {
        Box::new(Self {
            values: core::mem::take(&mut self.evicted_values),
            evicted_values: Vec::new(),
            out_dtype: self.out_dtype.clone(),
        })
    }

    fn finalize(&mut self) -> PolarsResult<Series> {
        let v = core::mem::take(&mut self.values);
        let ca: Float64Chunked = v
            .into_iter()
            .map(|s| {
                if s.weight() == 0.0 {
                    None
                } else {
                    Some(s.finalize())
                }
            })
            .collect_ca(PlSmallStr::EMPTY);
        ca.into_series().cast(&self.out_dtype)
    }

    fn as_any(&self) -> &dyn std::any::Any {
        self
    }
}