solverforge-scoring 0.14.1

Incremental constraint scoring for SolverForge
Documentation
use std::collections::{hash_map::Entry, HashMap};
use std::hash::Hash;
use std::marker::PhantomData;

use solverforge_core::score::Score;
use solverforge_core::{ConstraintRef, ImpactType};

use crate::stream::collection_extract::{ChangeSource, CollectionExtract};
use crate::stream::collector::{Accumulator, Collector};

/// Token returned by an accumulator when it absorbs a value; it is later handed
/// back to `Accumulator::retract` when that value leaves its group.
type CollectorRetraction<Accum, Val, Res> = <Accum as Accumulator<Val, Res>>::Retraction;

/// Aggregation state kept for one group key.
///
/// `count` is incremented on every accumulate and decremented (saturating) on
/// every retract; the owning constraint discards the group when it hits zero.
pub(super) struct GroupState<Acc> {
    /// Number of values accumulated into this group and not yet retracted.
    count: usize,
    /// Collector state that folds the group's values into a result.
    accumulator: Acc,
}

/// Bookkeeping record for one joined (A, B) match.
pub(super) struct MatchRow<GK, Retraction> {
    /// Identifying index pair of the match; presumably `(a_idx, b_idx)`, the
    /// same key used by the constraint's `matches` map (TODO confirm).
    pub(super) pair: (usize, usize),
    /// Position bookkeeping on the A side; presumably this row's slot in the
    /// per-A match list, enabling O(1) removal (verify against callers).
    pub(super) a_pos: usize,
    /// Position bookkeeping on the B side; presumably this row's slot in the
    /// per-B match list (verify against callers).
    pub(super) b_pos: usize,
    /// Group key the match's value was accumulated under.
    pub(super) group_key: GK,
    /// Retraction token needed to undo this match's contribution.
    pub(super) retraction: Retraction,
}

/// Incremental constraint that joins A with B on a shared key, groups the
/// joined pairs, and scores every T ("complement") entity by its group key.
///
/// A T whose key has a live group is scored from that group's collected
/// result; a T without one is scored from `default_fn`.
pub struct CrossComplementedGroupedConstraint<
    S,
    A,
    B,
    T,
    JK,
    GK,
    EA,
    EB,
    ET,
    KA,
    KB,
    F,
    GF,
    KT,
    C,
    V,
    R,
    Acc,
    D,
    W,
    Sc,
> where
    Acc: Accumulator<V, R>,
    Sc: Score,
{
    // ---- configuration ----
    /// Reference identifying this constraint.
    pub(super) constraint_ref: ConstraintRef,
    /// Penalty negates the weighted score; reward keeps its sign.
    pub(super) impact_type: ImpactType,
    /// Extractor for the A collection (also used for `contains` membership).
    pub(super) extractor_a: EA,
    /// Extractor for the B collection (also used for `contains` membership).
    pub(super) extractor_b: EB,
    /// Extractor for the T (complement) collection.
    pub(super) extractor_t: ET,
    /// Join key for A entities.
    pub(super) key_a: KA,
    /// Join key for B entities.
    pub(super) key_b: KB,
    /// Predicate over a candidate (a, b) pair and its indices; presumably
    /// decides which joined pairs are collected (usage not visible here).
    pub(super) filter: F,
    /// Maps a joined (a, b) pair to its group key.
    pub(super) group_key_fn: GF,
    /// Maps a T entity to its group key.
    pub(super) key_t: KT,
    /// Creates a fresh accumulator for each new group.
    pub(super) collector: C,
    /// Result used for a T whose key has no live group.
    pub(super) default_fn: D,
    /// Converts a (group key, result) pair into a raw score.
    pub(super) weight_fn: W,
    /// Hard/soft flag; not read in this file's visible code.
    pub(super) is_hard: bool,
    /// Change sources captured from the extractors at construction time.
    pub(super) a_source: ChangeSource,
    pub(super) b_source: ChangeSource,
    pub(super) t_source: ChangeSource,

    // ---- match bookkeeping ----
    /// Index pair → slot in `match_rows` (presumably; verify against callers).
    pub(super) matches: HashMap<(usize, usize), usize>,
    /// Storage for every live match.
    pub(super) match_rows: Vec<MatchRow<GK, CollectorRetraction<Acc, V, R>>>,
    /// Entity index → match-row slots touching it (presumed).
    pub(super) a_to_matches: HashMap<usize, Vec<usize>>,
    pub(super) b_to_matches: HashMap<usize, Vec<usize>>,

    // ---- join indexes ----
    /// Join key → indices of contained A / B entities.
    pub(super) a_by_key: HashMap<JK, Vec<usize>>,
    pub(super) b_by_key: HashMap<JK, Vec<usize>>,
    /// Reverse maps: entity index → its join key.
    pub(super) a_index_to_key: HashMap<usize, JK>,
    pub(super) b_index_to_key: HashMap<usize, JK>,

    // ---- complement indexes ----
    /// Group key → indices of T entities carrying that key.
    pub(super) t_by_key: HashMap<GK, Vec<usize>>,
    /// Reverse map: T index → its group key.
    pub(super) t_index_to_key: HashMap<usize, GK>,

    // ---- aggregation ----
    /// Live groups with at least one accumulated value.
    groups: HashMap<GK, GroupState<Acc>>,
    /// Marks type parameters not otherwise stored; the `fn() -> X` form keeps
    /// the struct covariant without claiming ownership of any `X`.
    pub(super) _phantom: PhantomData<(
        fn() -> S,
        fn() -> A,
        fn() -> B,
        fn() -> T,
        fn() -> V,
        fn() -> R,
        fn() -> Acc,
        fn() -> Sc,
    )>,
}

impl<S, A, B, T, JK, GK, EA, EB, ET, KA, KB, F, GF, KT, C, V, R, Acc, D, W, Sc>
    CrossComplementedGroupedConstraint<
        S,
        A,
        B,
        T,
        JK,
        GK,
        EA,
        EB,
        ET,
        KA,
        KB,
        F,
        GF,
        KT,
        C,
        V,
        R,
        Acc,
        D,
        W,
        Sc,
    >
where
    S: 'static,
    A: Clone + 'static,
    B: Clone + 'static,
    T: Clone + 'static,
    JK: Clone + Eq + Hash,
    GK: Clone + Eq + Hash,
    EA: CollectionExtract<S, Item = A>,
    EB: CollectionExtract<S, Item = B>,
    ET: CollectionExtract<S, Item = T>,
    KA: Fn(&A) -> JK,
    KB: Fn(&B) -> JK,
    F: Fn(&S, &A, &B, usize, usize) -> bool,
    GF: Fn(&A, &B) -> GK,
    KT: Fn(&T) -> GK,
    C: for<'i> Collector<(&'i A, &'i B), Value = V, Result = R, Accumulator = Acc>,
    Acc: Accumulator<V, R>,
    D: Fn(&T) -> R,
    W: Fn(&GK, &R) -> Sc,
    Sc: Score,
{
    /// Creates the constraint with all incremental state empty.
    ///
    /// Each extractor's change source is queried once here and stored in
    /// `a_source` / `b_source` / `t_source`.
    // The constructor takes every configuration field one-to-one, so the
    // argument count is inherent to the type rather than a design smell.
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        constraint_ref: ConstraintRef,
        impact_type: ImpactType,
        extractor_a: EA,
        extractor_b: EB,
        extractor_t: ET,
        key_a: KA,
        key_b: KB,
        filter: F,
        group_key_fn: GF,
        key_t: KT,
        collector: C,
        default_fn: D,
        weight_fn: W,
        is_hard: bool,
    ) -> Self {
        let a_source = extractor_a.change_source();
        let b_source = extractor_b.change_source();
        let t_source = extractor_t.change_source();
        Self {
            constraint_ref,
            impact_type,
            extractor_a,
            extractor_b,
            extractor_t,
            key_a,
            key_b,
            filter,
            group_key_fn,
            key_t,
            collector,
            default_fn,
            weight_fn,
            is_hard,
            a_source,
            b_source,
            t_source,
            matches: HashMap::new(),
            match_rows: Vec::new(),
            a_to_matches: HashMap::new(),
            b_to_matches: HashMap::new(),
            a_by_key: HashMap::new(),
            b_by_key: HashMap::new(),
            a_index_to_key: HashMap::new(),
            b_index_to_key: HashMap::new(),
            t_by_key: HashMap::new(),
            t_index_to_key: HashMap::new(),
            groups: HashMap::new(),
            _phantom: PhantomData,
        }
    }

    /// Applies `weight_fn` to `(key, result)` and signs it by impact type:
    /// penalties are negated, rewards are returned unchanged.
    #[inline]
    pub(super) fn compute_score(&self, key: &GK, result: &R) -> Sc {
        let base = (self.weight_fn)(key, result);
        match self.impact_type {
            ImpactType::Penalty => -base,
            ImpactType::Reward => base,
        }
    }

    /// Builds a standalone join-key → B-index map without touching `self`'s
    /// own indexes. B entities rejected by `extractor_b.contains` are skipped;
    /// indices within each bucket are in ascending order.
    pub(super) fn b_index_for(&self, solution: &S, entities_b: &[B]) -> HashMap<JK, Vec<usize>> {
        let mut b_by_key = HashMap::<JK, Vec<usize>>::new();
        for (b_idx, b) in entities_b.iter().enumerate() {
            if !self.extractor_b.contains(solution, b) {
                continue;
            }
            let key = (self.key_b)(b);
            b_by_key.entry(key).or_default().push(b_idx);
        }
        b_by_key
    }

    /// Rebuilds the A and B join indexes (forward and reverse) from scratch.
    /// Entities rejected by their extractor's `contains` are not indexed.
    pub(super) fn build_join_indexes(&mut self, solution: &S, entities_a: &[A], entities_b: &[B]) {
        self.a_by_key.clear();
        self.b_by_key.clear();
        self.a_index_to_key.clear();
        self.b_index_to_key.clear();
        for (a_idx, a) in entities_a.iter().enumerate() {
            if !self.extractor_a.contains(solution, a) {
                continue;
            }
            let key = (self.key_a)(a);
            self.a_index_to_key.insert(a_idx, key.clone());
            self.a_by_key.entry(key).or_default().push(a_idx);
        }
        for (b_idx, b) in entities_b.iter().enumerate() {
            if !self.extractor_b.contains(solution, b) {
                continue;
            }
            let key = (self.key_b)(b);
            self.b_index_to_key.insert(b_idx, key.clone());
            self.b_by_key.entry(key).or_default().push(b_idx);
        }
    }

    /// Returns the B indices in `b_by_key` that share `a`'s join key, or an
    /// empty slice when there are none.
    #[inline]
    pub(super) fn matching_b_indices_in<'a>(
        &self,
        b_by_key: &'a HashMap<JK, Vec<usize>>,
        a: &A,
    ) -> &'a [usize] {
        let key = (self.key_a)(a);
        b_by_key.get(&key).map_or(&[], Vec::as_slice)
    }

    /// Score contributed by the T entity at `t_idx` under group `key`.
    ///
    /// Uses the live group's collected result when one exists, otherwise
    /// `default_fn` applied to that T. An out-of-range `t_idx` scores zero.
    pub(super) fn complement_score_for_index(
        &self,
        entities_t: &[T],
        key: &GK,
        t_idx: usize,
    ) -> Sc {
        let Some(complement) = entities_t.get(t_idx) else {
            return Sc::zero();
        };
        match self.groups.get(key) {
            Some(group) => group
                .accumulator
                .with_result(|result| self.compute_score(key, result)),
            None => {
                let default_result = (self.default_fn)(complement);
                self.compute_score(key, &default_result)
            }
        }
    }

    /// Total score of every T currently indexed under `key` (zero if none).
    pub(super) fn key_score(&self, entities_t: &[T], key: &GK) -> Sc {
        let Some(indices) = self.t_by_key.get(key) else {
            return Sc::zero();
        };
        indices.iter().fold(Sc::zero(), |total, &t_idx| {
            total + self.complement_score_for_index(entities_t, key, t_idx)
        })
    }

    /// Removes `idx` from the bucket stored under `key`, dropping the bucket
    /// entirely once it becomes empty. Bucket order is not preserved.
    fn remove_index_from_key_bucket(
        indexes_by_key: &mut HashMap<GK, Vec<usize>>,
        key: &GK,
        idx: usize,
    ) {
        let Some(indices) = indexes_by_key.get_mut(key) else {
            return;
        };
        if let Some(pos) = indices.iter().position(|&candidate| candidate == idx) {
            indices.swap_remove(pos);
        }
        // `indices` is not used past this check, so the map can be mutated again.
        if indices.is_empty() {
            indexes_by_key.remove(key);
        }
    }

    /// Records that the T at `t_idx` carries group `key`, first unlinking it
    /// from any key it was previously indexed under.
    pub(super) fn index_complement(&mut self, key: GK, t_idx: usize) {
        if let Some(old_key) = self.t_index_to_key.insert(t_idx, key.clone()) {
            Self::remove_index_from_key_bucket(&mut self.t_by_key, &old_key, t_idx);
        }
        self.t_by_key.entry(key).or_default().push(t_idx);
    }

    /// Accumulates `value` into the group for `key`, creating the group from
    /// `collector` if it does not exist yet.
    ///
    /// Returns the resulting change in `key`'s total complement score together
    /// with the retraction token produced by the accumulator.
    pub(super) fn insert_value(
        &mut self,
        entities_t: &[T],
        key: GK,
        value: V,
    ) -> (Sc, CollectorRetraction<Acc, V, R>) {
        let old = self.key_score(entities_t, &key);
        // Borrow the collector as its own field so the closure below does not
        // capture `self` while `self.groups` is mutably borrowed.
        let collector = &self.collector;
        let group = self
            .groups
            .entry(key.clone())
            .or_insert_with(|| GroupState {
                accumulator: collector.create_accumulator(),
                count: 0,
            });
        let retraction = group.accumulator.accumulate(value);
        group.count += 1;
        let new_score = self.key_score(entities_t, &key);
        (new_score - old, retraction)
    }

    /// Undoes a previous accumulation in the group for `key`.
    ///
    /// The group is dropped once its live count reaches zero, after which T
    /// entities with that key fall back to `default_fn`. Returns the change in
    /// `key`'s total complement score, or zero if no group exists for `key`.
    pub(super) fn retract_value(
        &mut self,
        entities_t: &[T],
        key: GK,
        retraction: CollectorRetraction<Acc, V, R>,
    ) -> Sc {
        let old = self.key_score(entities_t, &key);
        let Entry::Occupied(mut entry) = self.groups.entry(key.clone()) else {
            return Sc::zero();
        };
        let group = entry.get_mut();
        group.accumulator.retract(retraction);
        group.count = group.count.saturating_sub(1);
        if group.count == 0 {
            entry.remove();
        }
        let new_score = self.key_score(entities_t, &key);
        new_score - old
    }

    /// Discards every match, index and group, returning to the empty state
    /// produced by `new` (configuration and change sources are kept).
    pub(super) fn clear_state(&mut self) {
        self.matches.clear();
        self.match_rows.clear();
        self.a_to_matches.clear();
        self.b_to_matches.clear();
        self.a_by_key.clear();
        self.b_by_key.clear();
        self.a_index_to_key.clear();
        self.b_index_to_key.clear();
        self.t_by_key.clear();
        self.t_index_to_key.clear();
        self.groups.clear();
    }
}