solverforge-scoring 0.15.0

Incremental constraint scoring for SolverForge
Documentation
use std::marker::PhantomData;

use solverforge_core::score::Score;
use solverforge_core::{ConstraintRef, ImpactType};

use super::state::GroupedStateView;

pub struct GroupedTerminalScorer<K, R, W, Sc>
where
    Sc: Score,
{
    pub(super) constraint_ref: ConstraintRef,
    pub(super) impact_type: ImpactType,
    pub(super) weight_fn: W,
    pub(super) is_hard: bool,
    pub(super) match_count: usize,
    pub(super) refresh_count: usize,
    cached_scores: Vec<Sc>,
    _phantom: PhantomData<fn() -> (K, R, Sc)>,
}

impl<K, R, W, Sc> GroupedTerminalScorer<K, R, W, Sc>
where
    R: Send + Sync + 'static,
    W: Fn(&K, &R) -> Sc + Send + Sync,
    Sc: Score + 'static,
{
    pub fn new(
        constraint_ref: ConstraintRef,
        impact_type: ImpactType,
        weight_fn: W,
        is_hard: bool,
    ) -> Self {
        Self {
            constraint_ref,
            impact_type,
            weight_fn,
            is_hard,
            match_count: 0,
            refresh_count: 0,
            cached_scores: Vec::new(),
            _phantom: PhantomData,
        }
    }

    pub fn evaluate<State>(&self, state: &State) -> Sc
    where
        State: GroupedStateView<K, R>,
    {
        let mut total = Sc::zero();
        state.for_each_group_result(|key, result| {
            total = total + self.compute_score(key, result);
        });
        total
    }

    pub fn initialize<State>(&mut self, state: &State) -> Sc
    where
        State: GroupedStateView<K, R>,
    {
        self.match_count = state.group_count();
        self.refresh_count = 0;
        self.cached_scores.clear();
        let mut total = Sc::zero();
        state.for_each_group_slot_result(|group_id, entry| {
            let score = self.score_entry(entry);
            self.replace_cached_score(group_id, score);
            total = total + score;
        });
        total
    }

    pub fn refresh_all<State>(&mut self, state: &State) -> Sc
    where
        State: GroupedStateView<K, R>,
    {
        self.match_count = state.group_count();
        self.refresh_count += 1;
        self.cached_scores.clear();
        let mut total = Sc::zero();
        state.for_each_group_slot_result(|group_id, entry| {
            let score = self.score_entry(entry);
            self.replace_cached_score(group_id, score);
            total = total + score;
        });
        total
    }

    pub fn refresh_changed<State>(&mut self, state: &State) -> Sc
    where
        State: GroupedStateView<K, R>,
    {
        self.match_count = state.group_count();
        self.refresh_count += 1;
        let mut delta = Sc::zero();
        state.for_each_changed_group_slot_result(|group_id, entry| {
            let score = self.score_entry(entry);
            delta = delta + self.replace_cached_score(group_id, score);
        });
        delta
    }

    pub fn reset(&mut self) {
        self.match_count = 0;
        self.refresh_count = 0;
        self.cached_scores.clear();
    }

    pub fn constraint_ref(&self) -> &ConstraintRef {
        &self.constraint_ref
    }

    pub fn name(&self) -> &str {
        &self.constraint_ref.name
    }

    pub fn is_hard(&self) -> bool {
        self.is_hard
    }

    pub fn match_count(&self) -> usize {
        self.match_count
    }

    pub fn refresh_count(&self) -> usize {
        self.refresh_count
    }

    #[inline]
    pub(super) fn compute_score(&self, key: &K, result: &R) -> Sc {
        let base = (self.weight_fn)(key, result);
        match self.impact_type {
            ImpactType::Penalty => -base,
            ImpactType::Reward => base,
        }
    }

    pub(crate) fn score_entry(&self, entry: Option<(&K, &R)>) -> Sc {
        match entry {
            Some((key, result)) => self.compute_score(key, result),
            None => Sc::zero(),
        }
    }

    pub(crate) fn replace_cached_score(&mut self, slot: usize, score: Sc) -> Sc {
        while self.cached_scores.len() <= slot {
            self.cached_scores.push(Sc::zero());
        }
        let previous = self.cached_scores[slot];
        self.cached_scores[slot] = score;
        score - previous
    }

    pub(crate) fn reset_incremental_cache(&mut self, match_count: usize) {
        self.match_count = match_count;
        self.refresh_count = 0;
        self.cached_scores.clear();
    }

    pub(crate) fn mark_incremental_refresh(&mut self, match_count: usize) {
        self.match_count = match_count;
        self.refresh_count += 1;
    }

    pub(crate) fn clear_incremental_scores(&mut self) {
        self.cached_scores.clear();
    }
}

impl<K, R, W, Sc> std::fmt::Debug for GroupedTerminalScorer<K, R, W, Sc>
where
    Sc: Score,
{
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("GroupedTerminalScorer")
            .field("name", &self.constraint_ref.name)
            .field("impact_type", &self.impact_type)
            .field("match_count", &self.match_count)
            .field("refresh_count", &self.refresh_count)
            .finish()
    }
}