use polars::lazy::dsl::Expr;
use polars::prelude::*;
use std::fmt::{Debug, Formatter};
use crate::core::{Metric, MetricSpace};
use crate::metrics::{
AbsoluteDistance, FrameDistance, L0InfDistance, L0PInfDistance, LInfDistance, LpDistance,
};
use crate::traits::ProductOrd;
use crate::transformations::traits::UnboundedMetric;
use crate::{core::Domain, error::Fallible};
use super::{Frame, FrameDomain, LazyFrameDomain, Margin, SeriesDomain};
// Submodule `ffi`, compiled only when the `ffi` cargo feature is enabled
// (presumably the foreign-function bindings for the types in this file — confirm).
#[cfg(feature = "ffi")]
mod ffi;
/// The setting in which an expression is evaluated.
///
/// The membership check of [`ExprDomain`] uses this to decide whether the
/// expression is applied through `select` or through `group_by(..).agg(..)`.
#[derive(Debug, Clone, PartialEq)]
pub enum Context {
    /// The expression is evaluated with a plain `select`, without grouping.
    RowByRow,
    /// The expression is evaluated inside a `group_by(..).agg(..)`.
    Aggregation {
        /// Margin whose `by` keys define the grouping.
        margin: Margin,
    },
}
impl Context {
pub fn aggregation(&self, operation: &str) -> Fallible<Margin> {
match self {
Context::RowByRow { .. } => fallible!(
MakeDomain,
"{} is not allowed in a row-by-row context",
operation
),
Context::Aggregation { margin } => Ok(margin.clone()),
}
}
}
/// A domain over query plans, described by per-column domains plus the
/// context in which expressions are evaluated.
#[derive(Debug, Clone, PartialEq)]
pub struct WildExprDomain {
    /// Domains of the individual columns.
    pub columns: Vec<SeriesDomain>,
    /// Whether expressions run row-by-row or inside an aggregation.
    pub context: Context,
}
impl Domain for WildExprDomain {
    type Carrier = DslPlan;

    /// Checks membership by converting this domain into a frame domain and
    /// testing the plan, wrapped in a `LazyFrame`, against it.
    ///
    /// # Errors
    /// Propagates failures from building the frame domain or from its own
    /// membership check.
    fn member(&self, val: &Self::Carrier) -> Fallible<bool> {
        let frame_domain = self.clone().to_frame_domain()?;
        let lazy_frame = LazyFrame::from(val.clone());
        frame_domain.member(&lazy_frame)
    }
}
impl WildExprDomain {
    /// Returns a copy of this domain with the same columns but a
    /// row-by-row context.
    pub fn as_row_by_row(&self) -> Self {
        let columns = self.columns.clone();
        Self {
            columns,
            context: Context::RowByRow,
        }
    }

    /// Converts this domain into a frame domain over the same columns.
    ///
    /// A row-by-row context contributes no margins; an aggregation context
    /// contributes exactly its one margin.
    ///
    /// # Errors
    /// Propagates any failure from `FrameDomain::new_with_margins`.
    fn to_frame_domain<F: Frame>(self) -> Fallible<FrameDomain<F>> {
        let Self { columns, context } = self;
        let margins = match context {
            Context::RowByRow => Vec::new(),
            Context::Aggregation { margin } => vec![margin],
        };
        FrameDomain::new_with_margins(columns, margins)
    }
}
/// A domain describing the single column produced by an expression,
/// together with the context the expression is evaluated in.
#[derive(Debug, Clone, PartialEq)]
pub struct ExprDomain {
    /// Domain of the column the expression produces.
    pub column: SeriesDomain,
    /// Whether the expression runs row-by-row or inside an aggregation.
    pub context: Context,
}
impl LazyFrameDomain {
    /// Builds an aggregation-context expression domain with no grouping keys.
    pub fn select(self) -> WildExprDomain {
        let no_keys: [Expr; 0] = [];
        self.aggregate(no_keys)
    }

    /// Builds an aggregation-context expression domain grouped by `by`.
    ///
    /// The margin stored in the context is whatever `get_margin` reports
    /// for the given grouping keys.
    pub fn aggregate<S: Into<Expr>, const P: usize>(self, by: [S; P]) -> WildExprDomain {
        let keys = by.map(|key| key.into()).into();
        let margin = self.get_margin(&keys);
        let context = Context::Aggregation { margin };
        WildExprDomain {
            columns: self.series_domains,
            context,
        }
    }

    /// Builds a row-by-row expression domain over this frame's columns.
    pub fn row_by_row(self) -> WildExprDomain {
        let columns = self.series_domains;
        WildExprDomain {
            columns,
            context: Context::RowByRow,
        }
    }
}
/// A query plan paired with an expression to evaluate against it.
#[derive(Clone)]
pub struct ExprPlan {
/// The plan the expression is applied to.
pub plan: DslPlan,
/// The expression evaluated against `plan`.
pub expr: Expr,
/// Optional companion expression that `then` transforms alongside `expr`.
/// NOTE(review): presumably a fill/default value for missing results — confirm.
pub fill: Option<Expr>,
}
impl Debug for ExprPlan {
    /// Formats the expression and whether a fill expression is present.
    ///
    /// The plan itself is not printed.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        let has_fill = self.fill.is_some();
        let mut builder = f.debug_struct("ExprPlan");
        builder.field("expr", &self.expr);
        // NOTE(review): the label says "default" although the field is
        // named `fill`; kept as-is because it is part of the Debug output.
        builder.field("default", &has_fill);
        builder.finish()
    }
}
impl ExprPlan {
    /// Returns a new plan in which `function` has been applied to the
    /// expression and, when present, to the fill expression.
    ///
    /// The underlying query plan is cloned unchanged.
    pub fn then(&self, function: impl Fn(Expr) -> Expr) -> Self {
        let expr = function(self.expr.clone());
        let fill = self.fill.as_ref().map(|fill| function(fill.clone()));
        Self {
            plan: self.plan.clone(),
            expr,
            fill,
        }
    }
}
impl From<DslPlan> for ExprPlan {
    /// Wraps a plan with the `all()` selector as its expression and no
    /// fill expression.
    fn from(value: DslPlan) -> Self {
        let expr = Expr::Selector(all());
        Self {
            plan: value,
            expr,
            fill: None,
        }
    }
}
impl From<LazyFrame> for ExprPlan {
fn from(value: LazyFrame) -> Self {
ExprPlan::from(value.logical_plan)
}
}
impl Domain for ExprDomain {
    type Carrier = ExprPlan;

    /// Checks whether evaluating the plan's expression yields a column that
    /// belongs to this domain.
    ///
    /// The query is actually executed (`collect`). In a row-by-row context
    /// the expression is applied with `select`; in an aggregation context it
    /// is applied with `group_by(margin.by).agg(..)`, and the collected
    /// result is additionally checked against the margin.
    ///
    /// # Errors
    /// Propagates failures from collecting the query, from looking up the
    /// column named by this domain, and from the nested membership checks.
    fn member(&self, val: &Self::Carrier) -> Fallible<bool> {
        let lazy = LazyFrame::from(val.plan.clone());
        let expr = val.expr.clone();

        // Apply the expression in the way the context prescribes.
        let evaluated = match &self.context {
            Context::RowByRow => lazy.select([expr]),
            Context::Aggregation { margin } => {
                let keys = margin.by.iter().cloned().collect::<Vec<_>>();
                lazy.group_by(keys).agg([expr])
            }
        };
        let frame = evaluated.collect()?;

        // The produced column must satisfy the series domain.
        let series = frame.column(&self.column.name)?.as_materialized_series();
        let column_ok = self.column.member(series)?;
        if !column_ok {
            return Ok(false);
        }

        // Under aggregation, the grouped result must also satisfy the margin.
        if let Context::Aggregation { margin } = &self.context {
            let keys = Vec::from_iter(margin.by.clone());
            let within_margin = margin.member(frame.lazy().group_by(&keys))?;
            if !within_margin {
                return Ok(false);
            }
        }
        Ok(true)
    }
}
/// A metric that is associated with an inner metric.
pub trait OuterMetric: Metric + 'static {
    /// Type of the associated inner metric.
    type InnerMetric: Metric;

    /// Returns an owned instance of the inner metric.
    fn inner_metric(&self) -> Self::InnerMetric;
}
impl<M: UnboundedMetric> OuterMetric for FrameDistance<M> {
    type InnerMetric = M;

    /// Returns a clone of the wrapped metric.
    fn inner_metric(&self) -> Self::InnerMetric {
        let wrapped = &self.0;
        wrapped.clone()
    }
}
impl<const P: usize, M: 'static + Metric> OuterMetric for L0PInfDistance<P, M> {
    type InnerMetric = M;

    /// Returns a clone of the wrapped metric.
    fn inner_metric(&self) -> Self::InnerMetric {
        let wrapped = &self.0;
        wrapped.clone()
    }
}
impl<M: 'static + Metric> OuterMetric for L0InfDistance<M> {
    type InnerMetric = M;

    /// Returns a clone of the wrapped metric.
    fn inner_metric(&self) -> Self::InnerMetric {
        let wrapped = &self.0;
        wrapped.clone()
    }
}
impl<const P: usize, Q: 'static> OuterMetric for LpDistance<P, Q> {
    type InnerMetric = AbsoluteDistance<Q>;

    /// Returns a default-constructed `AbsoluteDistance<Q>`; nothing is read
    /// from `self`.
    fn inner_metric(&self) -> Self::InnerMetric {
        AbsoluteDistance::<Q>::default()
    }
}
impl<M: UnboundedMetric> MetricSpace for (WildExprDomain, FrameDistance<M>) {
    /// Delegates to the metric-space check of the equivalent frame domain
    /// (over `DslPlan`) paired with the same frame distance.
    ///
    /// # Errors
    /// Propagates failures from the frame-domain conversion and from the
    /// delegated check.
    fn check_space(&self) -> Fallible<()> {
        let frame_domain = self.0.clone().to_frame_domain::<DslPlan>()?;
        let frame_metric = self.1.clone();
        (frame_domain, frame_metric).check_space()
    }
}
impl<const P: usize, M: UnboundedMetric> MetricSpace for (WildExprDomain, L0PInfDistance<P, M>) {
    /// Delegates to the metric-space check of the equivalent frame domain
    /// (over `DslPlan`) paired with the metric wrapped by the
    /// `L0PInfDistance`.
    ///
    /// # Errors
    /// Propagates failures from the frame-domain conversion and from the
    /// delegated check.
    fn check_space(&self) -> Fallible<()> {
        let (wild_domain, outer_metric) = self;
        let frame_domain = wild_domain.clone().to_frame_domain::<DslPlan>()?;
        let inner_metric = outer_metric.0.clone();
        (frame_domain, inner_metric).check_space()
    }
}
/// Metric-space check for an expression domain paired with a frame distance.
impl<M: UnboundedMetric> MetricSpace for (ExprDomain, FrameDistance<M>) {
/// Performs no validation: every such pairing is accepted.
fn check_space(&self) -> Fallible<()> {
Ok(())
}
}
impl<Q: ProductOrd, const P: usize> MetricSpace for (ExprDomain, LpDistance<P, Q>) {
    /// Validates that an `LpDistance` can be used on this expression domain.
    ///
    /// # Errors
    /// Fails with a `MetricSpace` error, checked in this order, when
    /// - `P` is neither 1 nor 2,
    /// - the column is nullable,
    /// - the column dtype is not a primitive numeric type.
    fn check_space(&self) -> Fallible<()> {
        if !matches!(P, 1 | 2) {
            return fallible!(MetricSpace, "P must be 1 or 2");
        }

        let (domain, _) = self;
        let column = &domain.column;

        if column.nullable {
            return fallible!(
                MetricSpace,
                "LpDistance between vectors with nulls is undefined"
            );
        }

        let is_numeric = column.dtype().is_primitive_numeric();
        if !is_numeric {
            return fallible!(
                MetricSpace,
                "LpDistance is only well defined for numeric data"
            );
        }

        Ok(())
    }
}
impl<Q: ProductOrd> MetricSpace for (ExprDomain, LInfDistance<Q>) {
    /// Validates that an `LInfDistance` can be used on this expression domain.
    ///
    /// # Errors
    /// Fails with a `MetricSpace` error when
    /// - the column is nullable,
    /// - the column dtype is not an array type,
    /// - the array's inner dtype is not a primitive numeric type.
    fn check_space(&self) -> Fallible<()> {
        let (domain, _) = self;
        let column = &domain.column;

        if column.nullable {
            return fallible!(
                MetricSpace,
                "LInfDistance between vectors with nulls is undefined"
            );
        }

        match column.dtype() {
            DataType::Array(inner_dtype, _) if inner_dtype.is_primitive_numeric() => Ok(()),
            DataType::Array(..) => fallible!(
                MetricSpace,
                "LInfDistance is only well defined for numeric array data"
            ),
            _ => fallible!(
                MetricSpace,
                "LInfDistance is only well defined for array data"
            ),
        }
    }
}
impl<Q: ProductOrd> MetricSpace for (ExprDomain, L0InfDistance<LInfDistance<Q>>) {
    /// Delegates to the metric-space check of the same expression domain
    /// paired with the wrapped `LInfDistance`.
    ///
    /// # Errors
    /// Propagates any failure from the delegated check.
    fn check_space(&self) -> Fallible<()> {
        let (expr_domain, outer_metric) = self;
        let inner_space = (expr_domain.clone(), outer_metric.0.clone());
        inner_space.check_space()
    }
}
/// Metric-space check for an expression domain paired with an `L0PInfDistance`.
impl<const P: usize, M: UnboundedMetric> MetricSpace for (ExprDomain, L0PInfDistance<P, M>) {
/// Performs no validation: every such pairing is accepted.
fn check_space(&self) -> Fallible<()> {
Ok(())
}
}