medmodels-core 0.4.9

Limebit MedModels Crate
use super::NodeOperand;
use crate::{
    errors::MedRecordResult,
    medrecord::querying::{
        edges::EdgeOperand,
        group_by::{GroupBy, GroupOperand, GroupedOperand, PartitionGroups, Ungroup},
        nodes::{NodeIndexOperand, NodeIndicesOperand, NodeIndicesOperandContext},
        wrapper::Wrapper,
        DeepClone, EvaluateBackward, EvaluateForward, EvaluateForwardGrouped, GroupedIterator,
    },
    prelude::MedRecordAttribute,
    MedRecord,
};

/// Context of a grouped node operand (`GroupOperand<NodeOperand>`): describes
/// where the grouped operand obtains its groups from.
#[derive(Debug, Clone)]
pub enum NodeOperandContext {
    /// Split the incoming node indices into groups using this group-by
    /// discriminator (see `NodeOperand::partition`).
    Discriminator(<NodeOperand as GroupBy>::Discriminator),
    /// Context backed by a grouped node operand.
    ///
    /// NOTE(review): `GroupOperand<NodeOperand>::evaluate_forward` treats this
    /// variant as unreachable; confirm where this context is evaluated.
    Nodes(Box<GroupOperand<NodeOperand>>),
    /// Context backed by a grouped edge operand.
    ///
    /// NOTE(review): like `Nodes`, forward evaluation treats this variant as
    /// unreachable; confirm where this context is evaluated.
    Edges(Box<GroupOperand<EdgeOperand>>),
}

impl DeepClone for NodeOperandContext {
    fn deep_clone(&self) -> Self {
        match self {
            Self::Discriminator(discriminator) => Self::Discriminator(discriminator.clone()),
            Self::Nodes(operand) => Self::Nodes(operand.deep_clone()),
            Self::Edges(operand) => Self::Edges(operand.deep_clone()),
        }
    }
}

impl From<<NodeOperand as GroupBy>::Discriminator> for NodeOperandContext {
    fn from(discriminator: <NodeOperand as GroupBy>::Discriminator) -> Self {
        Self::Discriminator(discriminator)
    }
}

impl From<GroupOperand<NodeOperand>> for NodeOperandContext {
    fn from(operand: GroupOperand<NodeOperand>) -> Self {
        Self::Nodes(Box::new(operand))
    }
}

impl From<GroupOperand<EdgeOperand>> for NodeOperandContext {
    fn from(operand: GroupOperand<EdgeOperand>) -> Self {
        Self::Edges(Box::new(operand))
    }
}

/// A grouped [`NodeOperand`] (`GroupOperand<NodeOperand>`) carries a
/// [`NodeOperandContext`], which describes where its groups come from.
impl GroupedOperand for NodeOperand {
    type Context = NodeOperandContext;
}

/// Discriminator used when grouping node operands.
///
/// NOTE(review): presumably this is `<NodeOperand as GroupBy>::Discriminator`;
/// the `GroupBy` impl is not in this file — confirm.
#[derive(Debug, Clone)]
pub enum NodeOperandGroupDiscriminator {
    /// Group by the given attribute.
    ///
    /// NOTE(review): presumably nodes are grouped by their value for this
    /// attribute; the partitioning logic lives elsewhere — confirm.
    Attribute(MedRecordAttribute),
}

impl DeepClone for NodeOperandGroupDiscriminator {
    /// Returns an independent copy of the discriminator; the wrapped attribute
    /// is copied with `Clone`.
    fn deep_clone(&self) -> Self {
        // The enum has a single variant, so this pattern is irrefutable. Adding a
        // variant turns it into a compile error, forcing this impl to be revisited.
        let Self::Attribute(attribute) = self;

        Self::Attribute(attribute.clone())
    }
}

impl<'a> EvaluateForward<'a> for GroupOperand<NodeOperand> {
    type InputValue = <NodeOperand as EvaluateForward<'a>>::InputValue;
    type ReturnValue = GroupedIterator<'a, <NodeOperand as EvaluateForward<'a>>::ReturnValue>;

    /// Splits `indices` into groups according to the discriminator context and
    /// evaluates the wrapped node operand once per group.
    ///
    /// # Errors
    ///
    /// Propagates any error from the grouped evaluation of the wrapped operand.
    ///
    /// # Panics
    ///
    /// Panics (via `unreachable!`) when the context is `Nodes` or `Edges`; only
    /// discriminator contexts are evaluated forward here.
    fn evaluate_forward(
        &self,
        medrecord: &'a MedRecord,
        indices: Self::InputValue,
    ) -> MedRecordResult<Self::ReturnValue> {
        // Kept as an exhaustive match so that new context variants must be
        // handled explicitly instead of silently falling into the panic.
        let discriminator = match &self.context {
            NodeOperandContext::Discriminator(discriminator) => discriminator,
            NodeOperandContext::Nodes(_) | NodeOperandContext::Edges(_) => unreachable!(),
        };

        let groups = NodeOperand::partition(medrecord, indices, discriminator);

        self.operand
            .evaluate_forward_grouped(medrecord, Box::new(groups))
    }
}

/// Grouped node indices take their context from a grouped [`NodeOperand`].
/// Backward evaluation first evaluates that context to obtain the groups.
impl GroupedOperand for NodeIndicesOperand {
    type Context = GroupOperand<NodeOperand>;
}

impl<'a> EvaluateBackward<'a> for GroupOperand<NodeIndicesOperand> {
    type ReturnValue =
        GroupedIterator<'a, <NodeIndicesOperand as EvaluateBackward<'a>>::ReturnValue>;

    /// Evaluates the grouped node context, reduces each group to the input this
    /// indices operand expects, and then evaluates the operand per group.
    ///
    /// # Errors
    ///
    /// Returns the first error produced while evaluating the context, while
    /// reducing a group's input, or while evaluating the grouped operand.
    fn evaluate_backward(&self, medrecord: &'a MedRecord) -> MedRecordResult<Self::ReturnValue> {
        let partitions = self.context.evaluate_backward(medrecord)?;

        // Collecting into `MedRecordResult<Vec<_>>` stops at the first failing
        // reduction and propagates it via `?`. The iterator is collected directly:
        // boxing it first only added a pointless heap allocation.
        let indices = partitions
            .map(|(key, partition)| {
                let reduced_partition = self.operand.reduce_input(partition)?;

                Ok((key, reduced_partition))
            })
            .collect::<MedRecordResult<Vec<_>>>()?;

        self.operand
            .evaluate_forward_grouped(medrecord, Box::new(indices.into_iter()))
    }
}

impl Ungroup for GroupOperand<NodeIndicesOperand> {
    type OutputOperand = NodeIndicesOperand;

    /// Produces an ungrouped [`NodeIndicesOperand`] built from this grouped one.
    ///
    /// The returned operand's context holds a deep clone of `self`. A clone of
    /// the returned wrapper is also registered as a merge operation on the
    /// grouped operand.
    fn ungroup(&self) -> Wrapper<Self::OutputOperand> {
        // Snapshot taken before the merge operation is pushed below — presumably
        // so the snapshot does not contain that operation; keep this order.
        let snapshot = self.deep_clone();
        let ungrouped = Wrapper::<NodeIndicesOperand>::new(
            NodeIndicesOperandContext::NodeIndicesGroupByOperand(snapshot),
        );

        self.operand.push_merge_operation(ungrouped.clone());

        ungrouped
    }
}

/// A grouped single node index takes its context from grouped node indices.
/// Backward evaluation first evaluates those grouped indices to obtain the groups.
impl GroupedOperand for NodeIndexOperand {
    type Context = GroupOperand<NodeIndicesOperand>;
}

impl<'a> EvaluateBackward<'a> for GroupOperand<NodeIndexOperand> {
    type ReturnValue = GroupedIterator<'a, <NodeIndexOperand as EvaluateBackward<'a>>::ReturnValue>;

    /// Evaluates the grouped indices context, reduces every group to the input
    /// this index operand expects, and then evaluates the operand per group.
    ///
    /// # Errors
    ///
    /// Returns the first error produced while evaluating the context, while
    /// reducing a group's input, or while evaluating the grouped operand.
    fn evaluate_backward(&self, medrecord: &'a MedRecord) -> MedRecordResult<Self::ReturnValue> {
        let mut reduced_groups = Vec::new();

        for (key, partition) in self.context.evaluate_backward(medrecord)? {
            // Fail fast: the first group whose input cannot be reduced aborts
            // the whole evaluation.
            let reduced = self.operand.reduce_input(partition)?;
            reduced_groups.push((key, reduced));
        }

        self.operand
            .evaluate_forward_grouped(medrecord, Box::new(reduced_groups.into_iter()))
    }
}

impl Ungroup for GroupOperand<NodeIndexOperand> {
    type OutputOperand = NodeIndicesOperand;

    /// Collapses the per-group node indices into one ungrouped
    /// [`NodeIndicesOperand`].
    ///
    /// The returned operand's context holds a deep clone of `self`. A clone of
    /// the returned wrapper is also registered as a merge operation on the
    /// grouped operand.
    fn ungroup(&self) -> Wrapper<Self::OutputOperand> {
        // Snapshot taken before the merge operation is pushed below — presumably
        // so the snapshot does not contain that operation; keep this order.
        let snapshot = self.deep_clone();
        let ungrouped = Wrapper::<NodeIndicesOperand>::new(
            NodeIndicesOperandContext::NodeIndexGroupByOperand(snapshot),
        );

        self.operand.push_merge_operation(ungrouped.clone());

        ungrouped
    }
}