// solverforge-scoring 0.15.0
//
// Incremental constraint scoring for SolverForge.
// Documentation
use std::hash::Hash;
use std::marker::PhantomData;

use solverforge_core::score::Score;
use solverforge_core::{ConstraintRef, ImpactType};

use crate::api::analysis::DetailedConstraintMatch;
use crate::api::constraint_set::{ConstraintSet, IncrementalConstraint};
use crate::constraint::grouped::GroupedTerminalScorer;
use crate::stream::collection_extract::CollectionExtract;
use crate::stream::collector::{Accumulator, Collector};
use crate::stream::ConstraintWeight;

use super::shared_set::SharedCrossGroupedConstraintSet;
use super::state::CrossGroupedNodeState;

type Inner<S, A, B, JK, GK, EA, EB, KA, KB, F, GF, C, V, R, Acc, W, Sc> =
    SharedCrossGroupedConstraintSet<
        S,
        A,
        B,
        JK,
        GK,
        EA,
        EB,
        KA,
        KB,
        F,
        GF,
        C,
        V,
        R,
        Acc,
        GroupedTerminalScorer<GK, R, W, Sc>,
        Sc,
    >;

/// Incremental constraint over pairs drawn from two collections of a solution,
/// grouped by a key and scored per group.
///
/// Pipeline, as wired up by the constructor:
/// * `EA` / `EB` extract the `A` and `B` collections from the solution `S`.
/// * `KA` / `KB` map each side to a key `JK`. Presumably this key is used to match
///   `A`s with `B`s; the matching itself lives in `CrossGroupedNodeState`.
/// * `F` filters candidate `(a, b)` pairs.
/// * `GF` assigns each pair a group key `GK`.
/// * `C` collects the pairs of a group into a result `R` (via accumulator `Acc`).
/// * `W` maps a group's `(key, result)` to a score `Sc`.
pub struct CrossGroupedConstraint<S, A, B, JK, GK, EA, EB, KA, KB, F, GF, C, V, R, Acc, W, Sc>
where
    Acc: Accumulator<V, R>,
    Sc: Score,
{
    /// Shared node state plus the single `GroupedTerminalScorer` for this constraint.
    inner: Inner<S, A, B, JK, GK, EA, EB, KA, KB, F, GF, C, V, R, Acc, W, Sc>,
    /// Hardness flag supplied at construction and reported by `is_hard()`.
    is_hard: bool,
    /// Type-parameter marker. The `fn() -> (..)` form keeps the marker covariant
    /// and `Send + Sync` whatever the parameters are. It is likely redundant with
    /// `inner`, which already mentions these parameters.
    _phantom: PhantomData<fn() -> (A, B, V, R, Acc)>,
}

impl<S, A, B, JK, GK, EA, EB, KA, KB, F, GF, C, V, R, Acc, W, Sc>
    CrossGroupedConstraint<S, A, B, JK, GK, EA, EB, KA, KB, F, GF, C, V, R, Acc, W, Sc>
where
    // Solution and the two entity types.
    S: Send + Sync + 'static,
    A: Send + Sync + 'static,
    B: Send + Sync + 'static,
    // Extraction of each side from the solution.
    EA: CollectionExtract<S, Item = A> + Send + Sync,
    EB: CollectionExtract<S, Item = B> + Send + Sync,
    // Keys, pair filter and grouping.
    JK: Eq + Hash + Send + Sync + 'static,
    KA: Fn(&A) -> JK + Send + Sync,
    KB: Fn(&B) -> JK + Send + Sync,
    F: Fn(&S, &A, &B, usize, usize) -> bool + Send + Sync,
    GK: Eq + Hash + Send + Sync + 'static,
    GF: Fn(&A, &B) -> GK + Send + Sync,
    // Per-group collection.
    C: for<'i> Collector<(&'i A, &'i B), Value = V, Result = R, Accumulator = Acc> + Send + Sync,
    V: Send + Sync + 'static,
    R: Send + Sync + 'static,
    Acc: Accumulator<V, R> + Send + Sync + 'static,
    // Scoring.
    W: Fn(&GK, &R) -> Sc + Send + Sync,
    Sc: Score + 'static,
{
    /// Builds the constraint from its pipeline components.
    ///
    /// * `extractor_a`, `extractor_b`, `key_a`, `key_b`, `filter`, `group_key_fn` and
    ///   `collector` go to `CrossGroupedNodeState::new`. The two `usize` arguments
    ///   of `filter` are presumably the indices of `a` and `b` (TODO confirm).
    /// * `constraint_ref`, `impact_type`, `weight_fn` and `is_hard` go to
    ///   `GroupedTerminalScorer::new`.
    ///
    /// `is_hard` is also stored on the constraint itself so `is_hard()` can answer
    /// without consulting the scorer.
    // Each argument is a distinct generic pipeline component. Bundling them into a
    // struct would only move the long generic parameter list somewhere else.
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        constraint_ref: ConstraintRef,
        impact_type: ImpactType,
        extractor_a: EA,
        extractor_b: EB,
        key_a: KA,
        key_b: KB,
        filter: F,
        group_key_fn: GF,
        collector: C,
        weight_fn: W,
        is_hard: bool,
    ) -> Self {
        Self {
            is_hard,
            // Node state is built first, then the scorer, as before.
            inner: SharedCrossGroupedConstraintSet::new(
                CrossGroupedNodeState::new(
                    extractor_a,
                    extractor_b,
                    key_a,
                    key_b,
                    filter,
                    group_key_fn,
                    collector,
                ),
                GroupedTerminalScorer::new(constraint_ref, impact_type, weight_fn, is_hard),
            ),
            _phantom: PhantomData,
        }
    }

    /// Consumes the constraint and forwards `weight` to the shared set's `penalize`.
    ///
    /// The returned builder's type carries both the existing `GroupedTerminalScorer`
    /// and a weight function derived from `weight`. Presumably this lets a second,
    /// penalizing scorer share this node state; confirm in `shared_set`. The stored
    /// `is_hard` flag is not forwarded.
    pub fn penalize<W2>(
        self,
        weight: W2,
    ) -> super::shared_set::CrossGroupedConstraintSetBuilder<
        S,
        A,
        B,
        JK,
        GK,
        EA,
        EB,
        KA,
        KB,
        F,
        GF,
        C,
        V,
        R,
        Acc,
        GroupedTerminalScorer<GK, R, W, Sc>,
        impl Fn(&GK, &R) -> Sc + Send + Sync,
        Sc,
    >
    where
        W2: for<'w> ConstraintWeight<(&'w GK, &'w R), Sc> + Send + Sync,
    {
        let Self { inner, .. } = self;
        inner.penalize(weight)
    }

    /// Consumes the constraint and forwards `weight` to the shared set's `reward`.
    ///
    /// This mirrors [`Self::penalize`], but delegates to `reward`. As there, the
    /// stored `is_hard` flag is not forwarded.
    pub fn reward<W2>(
        self,
        weight: W2,
    ) -> super::shared_set::CrossGroupedConstraintSetBuilder<
        S,
        A,
        B,
        JK,
        GK,
        EA,
        EB,
        KA,
        KB,
        F,
        GF,
        C,
        V,
        R,
        Acc,
        GroupedTerminalScorer<GK, R, W, Sc>,
        impl Fn(&GK, &R) -> Sc + Send + Sync,
        Sc,
    >
    where
        W2: for<'w> ConstraintWeight<(&'w GK, &'w R), Sc> + Send + Sync,
    {
        let Self { inner, .. } = self;
        inner.reward(weight)
    }
}

/// `IncrementalConstraint` adapter: every scoring operation is delegated to the
/// shared cross-grouped constraint set held in `inner`.
impl<S, A, B, JK, GK, EA, EB, KA, KB, F, GF, C, V, R, Acc, W, Sc> IncrementalConstraint<S, Sc>
    for CrossGroupedConstraint<S, A, B, JK, GK, EA, EB, KA, KB, F, GF, C, V, R, Acc, W, Sc>
where
    // Solution and the two entity types.
    S: Send + Sync + 'static,
    A: Send + Sync + 'static,
    B: Send + Sync + 'static,
    // Extraction of each side from the solution.
    EA: CollectionExtract<S, Item = A> + Send + Sync,
    EB: CollectionExtract<S, Item = B> + Send + Sync,
    // Keys, pair filter and grouping.
    JK: Eq + Hash + Send + Sync + 'static,
    KA: Fn(&A) -> JK + Send + Sync,
    KB: Fn(&B) -> JK + Send + Sync,
    F: Fn(&S, &A, &B, usize, usize) -> bool + Send + Sync,
    GK: Eq + Hash + Send + Sync + 'static,
    GF: Fn(&A, &B) -> GK + Send + Sync,
    // Per-group collection.
    C: for<'i> Collector<(&'i A, &'i B), Value = V, Result = R, Accumulator = Acc> + Send + Sync,
    V: Send + Sync + 'static,
    R: Send + Sync + 'static,
    Acc: Accumulator<V, R> + Send + Sync + 'static,
    // Scoring.
    W: Fn(&GK, &R) -> Sc + Send + Sync,
    Sc: Score + 'static,
{
    // ----- identity / metadata -------------------------------------------

    /// Constraint identity, taken from the shared set's primary scorer.
    fn constraint_ref(&self) -> &ConstraintRef {
        self.inner.primary_constraint_ref()
    }

    /// Hardness flag recorded by the constructor.
    fn is_hard(&self) -> bool {
        self.is_hard
    }

    /// Always `Sc::zero()`. Presumably this is because the score comes from the
    /// per-group weight function rather than from a single fixed weight.
    fn weight(&self) -> Sc {
        Sc::zero()
    }

    // ----- full (non-incremental) evaluation ------------------------------

    /// Full score of the shared set for `solution`, via `evaluate_all`.
    fn evaluate(&self, solution: &S) -> Sc {
        self.inner.evaluate_all(solution)
    }

    /// Match count of the first entry returned by `evaluate_each`, or 0 when
    /// there are no entries.
    ///
    /// The first entry presumably belongs to the primary scorer, which is the only
    /// one installed by `new`.
    fn match_count(&self, solution: &S) -> usize {
        match self.inner.evaluate_each(solution).first() {
            Some(primary) => primary.match_count,
            None => 0,
        }
    }

    /// Detailed matches are not reported by this constraint; always empty.
    fn get_matches<'a>(&'a self, _solution: &S) -> Vec<DetailedConstraintMatch<'a, Sc>> {
        vec![]
    }

    // ----- incremental lifecycle -------------------------------------------

    /// Builds the shared set's state from `solution` and returns its score.
    fn initialize(&mut self, solution: &S) -> Sc {
        self.inner.initialize_all(solution)
    }

    /// Forwards an insertion of `entity_index` in `descriptor_index` to the shared
    /// set and returns its result (presumably the score delta; TODO confirm).
    fn on_insert(&mut self, solution: &S, entity_index: usize, descriptor_index: usize) -> Sc {
        self.inner
            .on_insert_all(solution, entity_index, descriptor_index)
    }

    /// Forwards a retraction of `entity_index` in `descriptor_index` to the shared
    /// set and returns its result (presumably the score delta; TODO confirm).
    fn on_retract(&mut self, solution: &S, entity_index: usize, descriptor_index: usize) -> Sc {
        self.inner
            .on_retract_all(solution, entity_index, descriptor_index)
    }

    /// Clears the shared set's incremental state.
    fn reset(&mut self) {
        self.inner.reset_all();
    }
}