medmodels-core 0.4.9

Limebit MedModels Crate
use super::{MultipleValuesWithIndexOperand, SingleValueWithIndexOperand};
use crate::{
    errors::MedRecordResult,
    medrecord::querying::{
        attributes::{MultipleAttributesWithIndexOperand, MultipleAttributesWithIndexOperation},
        group_by::{GroupOperand, GroupedOperand, Ungroup},
        values::{
            operand::MultipleValuesWithoutIndexOperand,
            operation::MultipleValuesWithoutIndexOperation, MultipleValuesWithIndexContext,
            MultipleValuesWithoutIndexContext, SingleKindWithoutIndex,
            SingleValueWithoutIndexOperand,
        },
        wrapper::Wrapper,
        BoxedIterator, DeepClone, EvaluateBackward, EvaluateForwardGrouped, GroupedIterator,
        RootOperand,
    },
    MedRecord,
};
use medmodels_utils::traits::ReadWriteOrPanic;
use std::fmt::Debug;

/// Context a grouped [`MultipleValuesWithIndexOperand`] is evaluated from.
///
/// This is the `Context` associated type of the `GroupedOperand` implementation
/// for `MultipleValuesWithIndexOperand<O>`; each variant records which grouped
/// operand produces the partitions that the values are derived from.
#[derive(Debug, Clone)]
pub enum MultipleValuesWithIndexOperandContext<O: RootOperand> {
    /// The values are read from the partitions of a grouped root operand.
    RootOperand(GroupOperand<O>),
    /// The values are read from the partitions of a grouped
    /// multiple-attributes operand.
    MultipleAttributesOperand(GroupOperand<MultipleAttributesWithIndexOperand<O>>),
}

impl<O: RootOperand> DeepClone for MultipleValuesWithIndexOperandContext<O> {
    /// Returns a deep clone of this context: the grouped operand held by the
    /// active variant is deep-cloned and re-wrapped in the same variant.
    fn deep_clone(&self) -> Self {
        match self {
            Self::RootOperand(root) => {
                let cloned_root = root.deep_clone();

                Self::RootOperand(cloned_root)
            }
            Self::MultipleAttributesOperand(attributes) => {
                let cloned_attributes = attributes.deep_clone();

                Self::MultipleAttributesOperand(cloned_attributes)
            }
        }
    }
}

impl<O: RootOperand> From<GroupOperand<O>> for MultipleValuesWithIndexOperandContext<O> {
    fn from(operand: GroupOperand<O>) -> Self {
        Self::RootOperand(operand)
    }
}

impl<O: RootOperand> From<GroupOperand<MultipleAttributesWithIndexOperand<O>>>
    for MultipleValuesWithIndexOperandContext<O>
{
    fn from(operand: GroupOperand<MultipleAttributesWithIndexOperand<O>>) -> Self {
        Self::MultipleAttributesOperand(operand)
    }
}

// A grouped `MultipleValuesWithIndexOperand` can originate from either a grouped
// root operand or a grouped multiple-attributes operand, so its context is the
// enum covering both cases.
impl<O: RootOperand> GroupedOperand for MultipleValuesWithIndexOperand<O> {
    type Context = MultipleValuesWithIndexOperandContext<O>;
}

impl<'a, O: RootOperand> EvaluateBackward<'a> for GroupOperand<MultipleValuesWithIndexOperand<O>>
where
    O: 'a,
{
    type ReturnValue = GroupedIterator<
        'a,
        <MultipleValuesWithIndexOperand<O> as EvaluateBackward<'a>>::ReturnValue,
    >;

    /// Evaluates the grouped context, turns every partition into its values and
    /// then runs the wrapped operand's grouped forward evaluation on the result.
    ///
    /// * For a `RootOperand` context, each partition of indices is mapped to its
    ///   values via `O::get_values_from_indices`, using the attribute stored in
    ///   the wrapped operand's own context.
    /// * For a `MultipleAttributesOperand` context, each partition is mapped via
    ///   `MultipleAttributesWithIndexOperation::<O>::get_values`.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by evaluating the context, by
    /// `get_values` (attributes branch), or by the forward evaluation.
    ///
    /// # Panics
    ///
    /// In the `RootOperand` branch, panics via `unreachable!()` if the wrapped
    /// operand's context is not `MultipleValuesWithIndexContext::Operand`.
    fn evaluate_backward(&self, medrecord: &'a MedRecord) -> MedRecordResult<Self::ReturnValue> {
        match &self.context {
            MultipleValuesWithIndexOperandContext::RootOperand(context) => {
                let partitions = context.evaluate_backward(medrecord)?;

                // The attribute does not depend on the partition, so read it once
                // instead of re-acquiring the operand lock for every partition.
                // The inner scope makes sure the read guard is released before
                // the partitions are processed and the operand is evaluated.
                let attribute = {
                    let operand = self.operand.0.read_or_panic();

                    let MultipleValuesWithIndexContext::Operand((_, attribute)) = &operand.context
                    else {
                        unreachable!()
                    };

                    attribute.clone()
                };

                let values: Vec<_> = partitions
                    .map(|(key, partition)| {
                        let reduced_partition: BoxedIterator<_> = Box::new(
                            O::get_values_from_indices(medrecord, attribute.clone(), partition),
                        );

                        (key, reduced_partition)
                    })
                    .collect();

                self.operand
                    .evaluate_forward_grouped(medrecord, Box::new(values.into_iter()))
            }
            MultipleValuesWithIndexOperandContext::MultipleAttributesOperand(context) => {
                let partitions = context.evaluate_backward(medrecord)?;

                // Collect eagerly so that the first failing partition aborts the
                // whole evaluation with its error.
                let values: Vec<_> = partitions
                    .map(|(key, partition)| {
                        let reduced_partition: BoxedIterator<_> =
                            Box::new(MultipleAttributesWithIndexOperation::<O>::get_values(
                                medrecord, partition,
                            )?);

                        Ok((key, reduced_partition))
                    })
                    .collect::<MedRecordResult<_>>()?;

                self.operand
                    .evaluate_forward_grouped(medrecord, Box::new(values.into_iter()))
            }
        }
    }
}

impl<O: RootOperand> Ungroup for GroupOperand<MultipleValuesWithIndexOperand<O>> {
    type OutputOperand = MultipleValuesWithIndexOperand<O>;

    /// Creates an ungrouped operand whose context is a deep clone of this
    /// grouped operand, registers it as a merge operation on the wrapped
    /// operand, and returns it.
    fn ungroup(&self) -> Wrapper<Self::OutputOperand> {
        let group_by_context =
            MultipleValuesWithIndexContext::MultipleValuesWithIndexGroupByOperand(
                self.deep_clone(),
            );
        let ungrouped = Wrapper::<Self::OutputOperand>::new(group_by_context);

        // The wrapper is cloned so that the merge operation and the caller both
        // hold the new operand.
        self.operand.push_merge_operation(ungrouped.clone());

        ungrouped
    }
}

// A grouped `SingleValueWithIndexOperand` is always derived from a grouped
// `MultipleValuesWithIndexOperand`, which therefore serves as its context.
impl<O: RootOperand> GroupedOperand for SingleValueWithIndexOperand<O> {
    type Context = GroupOperand<MultipleValuesWithIndexOperand<O>>;
}

impl<'a, O: 'a + RootOperand> EvaluateBackward<'a>
    for GroupOperand<SingleValueWithIndexOperand<O>>
{
    type ReturnValue =
        GroupedIterator<'a, <SingleValueWithIndexOperand<O> as EvaluateBackward<'a>>::ReturnValue>;

    /// Evaluates the grouped multiple-values context, reduces every partition
    /// with the wrapped operand's `reduce_input`, and then runs the wrapped
    /// operand's grouped forward evaluation on the reduced partitions.
    ///
    /// # Errors
    ///
    /// Returns the first error produced by evaluating the context, by reducing
    /// a partition, or by the forward evaluation.
    fn evaluate_backward(&self, medrecord: &'a MedRecord) -> MedRecordResult<Self::ReturnValue> {
        let mut reduced_groups = Vec::new();

        for (group_key, group_values) in self.context.evaluate_backward(medrecord)? {
            // Stop at the first partition that fails to reduce.
            let reduced = self.operand.reduce_input(group_values)?;

            reduced_groups.push((group_key, reduced));
        }

        self.operand
            .evaluate_forward_grouped(medrecord, Box::new(reduced_groups.into_iter()))
    }
}

impl<O: RootOperand> Ungroup for GroupOperand<SingleValueWithIndexOperand<O>> {
    type OutputOperand = MultipleValuesWithIndexOperand<O>;

    /// Creates an ungrouped multiple-values operand whose context is a deep
    /// clone of this grouped single-value operand, registers it as a merge
    /// operation on the wrapped operand, and returns it.
    fn ungroup(&self) -> Wrapper<Self::OutputOperand> {
        let group_by_context =
            MultipleValuesWithIndexContext::SingleValueWithIndexGroupByOperand(self.deep_clone());
        let ungrouped = Wrapper::<Self::OutputOperand>::new(group_by_context);

        // The wrapper is cloned so that the merge operation and the caller both
        // hold the new operand.
        self.operand.push_merge_operation(ungrouped.clone());

        ungrouped
    }
}

// A grouped `SingleValueWithoutIndexOperand` is derived from a grouped
// `MultipleValuesWithIndexOperand`; the indices are dropped when the partitions
// are reduced.
impl<O: RootOperand> GroupedOperand for SingleValueWithoutIndexOperand<O> {
    type Context = GroupOperand<MultipleValuesWithIndexOperand<O>>;
}

impl<'a, O: 'a + RootOperand> EvaluateBackward<'a>
    for GroupOperand<SingleValueWithoutIndexOperand<O>>
{
    type ReturnValue = GroupedIterator<
        'a,
        <SingleValueWithoutIndexOperand<O> as EvaluateBackward<'a>>::ReturnValue,
    >;

    fn evaluate_backward(&self, medrecord: &'a MedRecord) -> MedRecordResult<Self::ReturnValue> {
        let partitions = self.context.evaluate_backward(medrecord)?;

        let values: Vec<_> = partitions
            .map(|(key, partition)| {
                let partition = partition.map(|(_, value)| value);

                let reduced_partition = match self.operand.0.read_or_panic().kind {
                    SingleKindWithoutIndex::Max => {
                        MultipleValuesWithoutIndexOperation::<O>::get_max(partition)?
                    }
                    SingleKindWithoutIndex::Min => {
                        MultipleValuesWithoutIndexOperation::<O>::get_min(partition)?
                    }
                    SingleKindWithoutIndex::Mean => {
                        MultipleValuesWithoutIndexOperation::<O>::get_mean(partition)?
                    }
                    SingleKindWithoutIndex::Median => {
                        MultipleValuesWithoutIndexOperation::<O>::get_median(partition)?
                    }
                    SingleKindWithoutIndex::Mode => {
                        MultipleValuesWithoutIndexOperation::<O>::get_mode(partition)?
                    }
                    SingleKindWithoutIndex::Std => {
                        MultipleValuesWithoutIndexOperation::<O>::get_std(partition)?
                    }
                    SingleKindWithoutIndex::Var => {
                        MultipleValuesWithoutIndexOperation::<O>::get_var(partition)?
                    }
                    SingleKindWithoutIndex::Count => Some(
                        MultipleValuesWithoutIndexOperation::<O>::get_count(partition),
                    ),
                    SingleKindWithoutIndex::Sum => {
                        MultipleValuesWithoutIndexOperation::<O>::get_sum(partition)?
                    }
                    SingleKindWithoutIndex::Random => {
                        MultipleValuesWithoutIndexOperation::<O>::get_random(partition)
                    }
                };

                Ok((key, reduced_partition))
            })
            .collect::<MedRecordResult<_>>()?;

        self.operand
            .evaluate_forward_grouped(medrecord, Box::new(values.into_iter()))
    }
}

impl<O: RootOperand> Ungroup for GroupOperand<SingleValueWithoutIndexOperand<O>> {
    type OutputOperand = MultipleValuesWithoutIndexOperand<O>;

    /// Creates an ungrouped multiple-values (without index) operand whose
    /// context is a deep clone of this grouped single-value operand, registers
    /// it as a merge operation on the wrapped operand, and returns it.
    fn ungroup(&self) -> Wrapper<Self::OutputOperand> {
        let group_by_context = MultipleValuesWithoutIndexContext::GroupByOperand(self.deep_clone());
        let ungrouped = Wrapper::<Self::OutputOperand>::new(group_by_context);

        // The wrapper is cloned so that the merge operation and the caller both
        // hold the new operand.
        self.operand.push_merge_operation(ungrouped.clone());

        ungrouped
    }
}