solverforge-scoring 0.13.1

Incremental constraint scoring for SolverForge
Documentation
use std::collections::{hash_map::Entry, HashMap, HashSet};
use std::hash::Hash;
use std::marker::PhantomData;

use solverforge_core::score::Score;
use solverforge_core::{ConstraintRef, ImpactType};

use crate::stream::collection_extract::{ChangeSource, CollectionExtract};
use crate::stream::collector::{Accumulator, Collector};
use crate::stream::filter::UniFilter;
use crate::stream::{ProjectedRowCoordinate, ProjectedRowOwner, ProjectedSource};

type CollectorRetraction<Acc, V, R> = <Acc as Accumulator<V, R>>::Retraction;

pub(super) struct GroupState<Acc> {
    accumulator: Acc,
    count: usize,
}

pub struct ProjectedComplementedGroupedConstraint<
    S,
    Out,
    B,
    K,
    Src,
    EB,
    F,
    KA,
    KB,
    C,
    V,
    R,
    Acc,
    D,
    W,
    Sc,
> where
    Src: ProjectedSource<S, Out>,
    Acc: Accumulator<V, R>,
    Sc: Score,
{
    pub(super) constraint_ref: ConstraintRef,
    pub(super) impact_type: ImpactType,
    pub(super) source: Src,
    pub(super) extractor_b: EB,
    pub(super) filter: F,
    pub(super) key_a: KA,
    pub(super) key_b: KB,
    pub(super) collector: C,
    pub(super) default_fn: D,
    pub(super) weight_fn: W,
    pub(super) is_hard: bool,
    pub(super) b_source: ChangeSource,
    pub(super) source_state: Option<Src::State>,
    pub(super) groups: HashMap<K, GroupState<Acc>>,
    pub(super) row_outputs: HashMap<ProjectedRowCoordinate, Out>,
    pub(super) row_keys: HashMap<ProjectedRowCoordinate, K>,
    pub(super) row_retractions: HashMap<ProjectedRowCoordinate, CollectorRetraction<Acc, V, R>>,
    pub(super) rows_by_owner: HashMap<ProjectedRowOwner, Vec<ProjectedRowCoordinate>>,
    pub(super) b_by_key: HashMap<K, Vec<usize>>,
    pub(super) b_index_to_key: HashMap<usize, K>,
    pub(super) _phantom: PhantomData<(
        fn() -> S,
        fn() -> Out,
        fn() -> B,
        fn() -> V,
        fn() -> R,
        fn() -> Acc,
        fn() -> Sc,
    )>,
}

impl<S, Out, B, K, Src, EB, F, KA, KB, C, V, R, Acc, D, W, Sc>
    ProjectedComplementedGroupedConstraint<S, Out, B, K, Src, EB, F, KA, KB, C, V, R, Acc, D, W, Sc>
where
    S: Send + Sync + 'static,
    Out: Send + Sync + 'static,
    B: Clone + Send + Sync + 'static,
    K: Clone + Eq + Hash + Send + Sync + 'static,
    Src: ProjectedSource<S, Out>,
    EB: CollectionExtract<S, Item = B>,
    F: UniFilter<S, Out>,
    KA: Fn(&Out) -> Option<K> + Send + Sync,
    KB: Fn(&B) -> K + Send + Sync,
    C: for<'i> Collector<&'i Out, Value = V, Result = R, Accumulator = Acc> + Send + Sync,
    V: Send + Sync,
    R: Send + Sync,
    Acc: Accumulator<V, R> + Send + Sync,
    D: Fn(&B) -> R + Send + Sync,
    W: Fn(&K, &R) -> Sc + Send + Sync,
    Sc: Score + 'static,
{
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        constraint_ref: ConstraintRef,
        impact_type: ImpactType,
        source: Src,
        extractor_b: EB,
        filter: F,
        key_a: KA,
        key_b: KB,
        collector: C,
        default_fn: D,
        weight_fn: W,
        is_hard: bool,
    ) -> Self {
        let b_source = extractor_b.change_source();
        Self {
            constraint_ref,
            impact_type,
            source,
            extractor_b,
            filter,
            key_a,
            key_b,
            collector,
            default_fn,
            weight_fn,
            is_hard,
            b_source,
            source_state: None,
            groups: HashMap::new(),
            row_outputs: HashMap::new(),
            row_keys: HashMap::new(),
            row_retractions: HashMap::new(),
            rows_by_owner: HashMap::new(),
            b_by_key: HashMap::new(),
            b_index_to_key: HashMap::new(),
            _phantom: PhantomData,
        }
    }

    pub(super) fn compute_score(&self, key: &K, result: &R) -> Sc {
        let base = (self.weight_fn)(key, result);
        match self.impact_type {
            ImpactType::Penalty => -base,
            ImpactType::Reward => base,
        }
    }
    fn b_score_for_index(&self, entities_b: &[B], key: &K, b_idx: usize) -> Sc {
        if b_idx >= entities_b.len() {
            return Sc::zero();
        }
        if let Some(group) = self.groups.get(key) {
            return group
                .accumulator
                .with_result(|result| self.compute_score(key, result));
        }
        let default_result = (self.default_fn)(&entities_b[b_idx]);
        self.compute_score(key, &default_result)
    }

    fn key_score(&self, entities_b: &[B], key: &K) -> Sc {
        let Some(indices) = self.b_by_key.get(key) else {
            return Sc::zero();
        };
        indices.iter().fold(Sc::zero(), |total, &b_idx| {
            total + self.b_score_for_index(entities_b, key, b_idx)
        })
    }

    fn remove_index_from_key_bucket(
        indexes_by_key: &mut HashMap<K, Vec<usize>>,
        key: &K,
        idx: usize,
    ) {
        let mut remove_bucket = false;
        if let Some(indices) = indexes_by_key.get_mut(key) {
            if let Some(pos) = indices.iter().position(|candidate| *candidate == idx) {
                indices.swap_remove(pos);
            }
            remove_bucket = indices.is_empty();
        }
        if remove_bucket {
            indexes_by_key.remove(key);
        }
    }

    fn index_b(&mut self, key: K, b_idx: usize) {
        if let Some(old_key) = self.b_index_to_key.insert(b_idx, key.clone()) {
            Self::remove_index_from_key_bucket(&mut self.b_by_key, &old_key, b_idx);
        }
        self.b_by_key.entry(key).or_default().push(b_idx);
    }

    fn insert_value(
        &mut self,
        entities_b: &[B],
        key: K,
        value: V,
    ) -> (Sc, CollectorRetraction<Acc, V, R>) {
        let old = self.key_score(entities_b, &key);
        let retraction = match self.groups.entry(key.clone()) {
            Entry::Occupied(mut entry) => {
                let group = entry.get_mut();
                let retraction = group.accumulator.accumulate(value);
                group.count += 1;
                retraction
            }
            Entry::Vacant(entry) => {
                let group = entry.insert(GroupState {
                    accumulator: self.collector.create_accumulator(),
                    count: 0,
                });
                let retraction = group.accumulator.accumulate(value);
                group.count += 1;
                retraction
            }
        };
        let new_score = self.key_score(entities_b, &key);
        (new_score - old, retraction)
    }
    fn retract_value(
        &mut self,
        entities_b: &[B],
        key: K,
        retraction: CollectorRetraction<Acc, V, R>,
    ) -> Sc {
        let old = self.key_score(entities_b, &key);
        let Entry::Occupied(mut entry) = self.groups.entry(key.clone()) else {
            return Sc::zero();
        };
        let group = entry.get_mut();
        group.accumulator.retract(retraction);
        group.count = group.count.saturating_sub(1);
        if group.count == 0 {
            entry.remove();
        }
        let new_score = self.key_score(entities_b, &key);
        new_score - old
    }

    pub(super) fn ensure_source_state(&mut self, solution: &S) {
        if self.source_state.is_none() {
            self.source_state = Some(self.source.build_state(solution));
        }
    }
    fn index_coordinate(&mut self, coordinate: ProjectedRowCoordinate) {
        coordinate.for_each_owner(|owner| {
            self.rows_by_owner
                .entry(owner)
                .or_default()
                .push(coordinate);
        });
    }

    fn unindex_coordinate(&mut self, coordinate: ProjectedRowCoordinate) {
        coordinate.for_each_owner(|owner| {
            let mut remove_bucket = false;
            if let Some(rows) = self.rows_by_owner.get_mut(&owner) {
                rows.retain(|candidate| *candidate != coordinate);
                remove_bucket = rows.is_empty();
            }
            if remove_bucket {
                self.rows_by_owner.remove(&owner);
            }
        });
    }

    pub(super) fn insert_row(
        &mut self,
        solution: &S,
        entities_b: &[B],
        coordinate: ProjectedRowCoordinate,
        output: Out,
    ) -> Sc {
        if self.row_outputs.contains_key(&coordinate) || !self.filter.test(solution, &output) {
            return Sc::zero();
        }
        let Some(key) = (self.key_a)(&output) else {
            return Sc::zero();
        };
        let value = self.collector.extract(&output);
        let (delta, retraction) = self.insert_value(entities_b, key.clone(), value);
        self.row_outputs.insert(coordinate, output);
        self.row_keys.insert(coordinate, key);
        self.row_retractions.insert(coordinate, retraction);
        self.index_coordinate(coordinate);
        delta
    }

    pub(super) fn retract_row(
        &mut self,
        entities_b: &[B],
        coordinate: ProjectedRowCoordinate,
    ) -> Sc {
        let Some(_output) = self.row_outputs.remove(&coordinate) else {
            return Sc::zero();
        };
        self.unindex_coordinate(coordinate);
        let Some(key) = self.row_keys.remove(&coordinate) else {
            return Sc::zero();
        };
        let Some(retraction) = self.row_retractions.remove(&coordinate) else {
            return Sc::zero();
        };
        self.retract_value(entities_b, key, retraction)
    }

    pub(super) fn localized_owners(
        &self,
        descriptor_index: usize,
        entity_index: usize,
    ) -> Vec<ProjectedRowOwner> {
        let mut owners = Vec::new();
        for slot in 0..self.source.source_count() {
            if self
                .source
                .change_source(slot)
                .assert_localizes(descriptor_index, &self.constraint_ref.name)
            {
                owners.push(ProjectedRowOwner {
                    source_slot: slot,
                    entity_index,
                });
            }
        }
        owners
    }

    pub(super) fn coordinates_for_owners(
        &self,
        owners: &[ProjectedRowOwner],
    ) -> Vec<ProjectedRowCoordinate> {
        let mut seen = HashSet::new();
        let mut coordinates = Vec::new();
        for owner in owners {
            let Some(rows) = self.rows_by_owner.get(owner) else {
                continue;
            };
            for &coordinate in rows {
                if seen.insert(coordinate) {
                    coordinates.push(coordinate);
                }
            }
        }
        coordinates
    }

    pub(super) fn insert_b(&mut self, entities_b: &[B], b_idx: usize) -> Sc {
        if b_idx >= entities_b.len() {
            return Sc::zero();
        }
        let key = (self.key_b)(&entities_b[b_idx]);
        self.index_b(key.clone(), b_idx);
        self.b_score_for_index(entities_b, &key, b_idx)
    }

    pub(super) fn retract_b(&mut self, entities_b: &[B], b_idx: usize) -> Sc {
        let Some(key) = self.b_index_to_key.remove(&b_idx) else {
            return Sc::zero();
        };
        let delta = -self.b_score_for_index(entities_b, &key, b_idx);
        Self::remove_index_from_key_bucket(&mut self.b_by_key, &key, b_idx);
        delta
    }
}