solverforge-scoring 0.10.0

Incremental constraint scoring for SolverForge
Documentation
use std::collections::HashMap;
use std::hash::Hash;
use std::marker::PhantomData;

use solverforge_core::score::Score;
use solverforge_core::{ConstraintRef, ImpactType};

use crate::api::constraint_set::IncrementalConstraint;
use crate::stream::filter::{BiFilter, UniFilter};
use crate::stream::{ProjectedRowCoordinate, ProjectedSource};

struct ProjectedJoinRow<Out, K> {
    key: K,
    output: Out,
    order: ProjectedRowCoordinate,
}

pub struct ProjectedBiConstraint<S, Out, K, Src, F, KF, PF, W, Sc>
where
    Sc: Score,
{
    constraint_ref: ConstraintRef,
    impact_type: ImpactType,
    source: Src,
    filter: F,
    key_fn: KF,
    pair_filter: PF,
    weight: W,
    is_hard: bool,
    rows: Vec<Option<ProjectedJoinRow<Out, K>>>,
    free_row_ids: Vec<usize>,
    rows_by_entity: HashMap<(usize, usize), Vec<usize>>,
    rows_by_key: HashMap<K, Vec<usize>>,
    _phantom: PhantomData<(fn() -> S, fn() -> Out, fn() -> Sc)>,
}

impl<S, Out, K, Src, F, KF, PF, W, Sc> ProjectedBiConstraint<S, Out, K, Src, F, KF, PF, W, Sc>
where
    S: Send + Sync + 'static,
    Out: Clone + Send + Sync + 'static,
    K: Clone + Eq + Hash + Send + Sync + 'static,
    Src: ProjectedSource<S, Out>,
    F: UniFilter<S, Out>,
    KF: Fn(&Out) -> K + Send + Sync,
    PF: BiFilter<S, Out, Out>,
    W: Fn(&Out, &Out) -> Sc + Send + Sync,
    Sc: Score + 'static,
{
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        constraint_ref: ConstraintRef,
        impact_type: ImpactType,
        source: Src,
        filter: F,
        key_fn: KF,
        pair_filter: PF,
        weight: W,
        is_hard: bool,
    ) -> Self {
        Self {
            constraint_ref,
            impact_type,
            source,
            filter,
            key_fn,
            pair_filter,
            weight,
            is_hard,
            rows: Vec::new(),
            free_row_ids: Vec::new(),
            rows_by_entity: HashMap::new(),
            rows_by_key: HashMap::new(),
            _phantom: PhantomData,
        }
    }

    fn compute_score(&self, left: &Out, right: &Out) -> Sc {
        let base = (self.weight)(left, right);
        match self.impact_type {
            ImpactType::Penalty => -base,
            ImpactType::Reward => base,
        }
    }

    fn score_ordered_rows(
        &self,
        solution: &S,
        first: &ProjectedJoinRow<Out, K>,
        second: &ProjectedJoinRow<Out, K>,
    ) -> Sc {
        let (left, right) = if first.order <= second.order {
            (first, second)
        } else {
            (second, first)
        };
        if !self
            .pair_filter
            .test(solution, &left.output, &right.output, 0, 1)
        {
            return Sc::zero();
        }
        self.compute_score(&left.output, &right.output)
    }

    fn score_pair(&self, solution: &S, first_id: usize, second_id: usize) -> Sc {
        let Some(first) = self.rows.get(first_id).and_then(Option::as_ref) else {
            return Sc::zero();
        };
        let Some(second) = self.rows.get(second_id).and_then(Option::as_ref) else {
            return Sc::zero();
        };
        self.score_ordered_rows(solution, first, second)
    }

    fn insert_row(&mut self, solution: &S, coordinate: ProjectedRowCoordinate, output: Out) -> Sc {
        let key = (self.key_fn)(&output);
        let existing = self.rows_by_key.get(&key).cloned().unwrap_or_default();
        let row = Some(ProjectedJoinRow {
            key: key.clone(),
            output,
            order: coordinate,
        });
        let row_id = if let Some(row_id) = self.free_row_ids.pop() {
            debug_assert!(self.rows[row_id].is_none());
            self.rows[row_id] = row;
            row_id
        } else {
            let row_id = self.rows.len();
            self.rows.push(row);
            row_id
        };
        self.rows_by_entity
            .entry((coordinate.source_slot, coordinate.entity_index))
            .or_default()
            .push(row_id);

        let mut total = Sc::zero();
        for other_id in existing {
            total = total + self.score_pair(solution, row_id, other_id);
        }
        self.rows_by_key.entry(key).or_default().push(row_id);
        total
    }

    fn retract_row(&mut self, solution: &S, row_id: usize) -> Sc {
        let Some(row) = self.rows.get(row_id).and_then(Option::as_ref) else {
            return Sc::zero();
        };
        let key = row.key.clone();
        let candidates = self.rows_by_key.get(&key).cloned().unwrap_or_default();
        let mut total = Sc::zero();
        for other_id in candidates {
            if other_id == row_id {
                continue;
            }
            total = total - self.score_pair(solution, row_id, other_id);
        }

        if let Some(ids) = self.rows_by_key.get_mut(&key) {
            ids.retain(|&id| id != row_id);
            if ids.is_empty() {
                self.rows_by_key.remove(&key);
            }
        }
        self.rows[row_id] = None;
        self.free_row_ids.push(row_id);
        total
    }

    fn insert_entity_outputs(&mut self, solution: &S, slot: usize, entity_index: usize) -> Sc {
        let mut outputs = Vec::new();
        self.source
            .collect_entity(solution, slot, entity_index, |coordinate, output| {
                if self.filter.test(solution, &output) {
                    outputs.push((coordinate, output));
                }
            });

        outputs
            .into_iter()
            .fold(Sc::zero(), |total, (coordinate, output)| {
                total + self.insert_row(solution, coordinate, output)
            })
    }

    fn retract_entity_outputs(&mut self, solution: &S, slot: usize, entity_index: usize) -> Sc {
        let Some(row_ids) = self.rows_by_entity.remove(&(slot, entity_index)) else {
            return Sc::zero();
        };
        row_ids.into_iter().fold(Sc::zero(), |total, row_id| {
            total + self.retract_row(solution, row_id)
        })
    }

    fn evaluate_rows(&self, solution: &S) -> Vec<ProjectedJoinRow<Out, K>> {
        let mut rows = Vec::new();
        self.source.collect_all(solution, |coordinate, output| {
            if self.filter.test(solution, &output) {
                rows.push(ProjectedJoinRow {
                    key: (self.key_fn)(&output),
                    output,
                    order: coordinate,
                });
            }
        });
        rows
    }

    fn score_evaluation_pair(
        &self,
        solution: &S,
        first: &ProjectedJoinRow<Out, K>,
        second: &ProjectedJoinRow<Out, K>,
    ) -> Sc {
        if first.key == second.key {
            self.score_ordered_rows(solution, first, second)
        } else {
            Sc::zero()
        }
    }

    fn evaluation_pair_matches(
        &self,
        solution: &S,
        first: &ProjectedJoinRow<Out, K>,
        second: &ProjectedJoinRow<Out, K>,
    ) -> bool {
        if first.key != second.key {
            return false;
        }
        let (left, right) = if first.order <= second.order {
            (first, second)
        } else {
            (second, first)
        };
        self.pair_filter
            .test(solution, &left.output, &right.output, 0, 1)
    }

    fn localized_slots(&self, descriptor_index: usize) -> Vec<usize> {
        let mut slots = Vec::new();
        for slot in 0..self.source.source_count() {
            if self
                .source
                .change_source(slot)
                .assert_localizes(descriptor_index, &self.constraint_ref.name)
            {
                slots.push(slot);
            }
        }
        slots
    }

    #[cfg(test)]
    pub(crate) fn debug_row_storage_len(&self) -> usize {
        self.rows.len()
    }

    #[cfg(test)]
    pub(crate) fn debug_free_row_count(&self) -> usize {
        self.free_row_ids.len()
    }
}

impl<S, Out, K, Src, F, KF, PF, W, Sc> IncrementalConstraint<S, Sc>
    for ProjectedBiConstraint<S, Out, K, Src, F, KF, PF, W, Sc>
where
    S: Send + Sync + 'static,
    Out: Clone + Send + Sync + 'static,
    K: Clone + Eq + Hash + Send + Sync + 'static,
    Src: ProjectedSource<S, Out>,
    F: UniFilter<S, Out>,
    KF: Fn(&Out) -> K + Send + Sync,
    PF: BiFilter<S, Out, Out>,
    W: Fn(&Out, &Out) -> Sc + Send + Sync,
    Sc: Score + 'static,
{
    fn evaluate(&self, solution: &S) -> Sc {
        let rows = self.evaluate_rows(solution);

        let mut total = Sc::zero();
        for left_index in 0..rows.len() {
            for right_index in (left_index + 1)..rows.len() {
                total = total
                    + self.score_evaluation_pair(solution, &rows[left_index], &rows[right_index]);
            }
        }
        total
    }

    fn match_count(&self, solution: &S) -> usize {
        let rows = self.evaluate_rows(solution);

        let mut count = 0;
        for left_index in 0..rows.len() {
            for right_index in (left_index + 1)..rows.len() {
                if self.evaluation_pair_matches(solution, &rows[left_index], &rows[right_index]) {
                    count += 1;
                }
            }
        }
        count
    }

    fn initialize(&mut self, solution: &S) -> Sc {
        self.reset();
        let mut rows = Vec::new();
        self.source.collect_all(solution, |coordinate, output| {
            if self.filter.test(solution, &output) {
                rows.push((coordinate, output));
            }
        });

        rows.into_iter()
            .fold(Sc::zero(), |total, (coordinate, output)| {
                total + self.insert_row(solution, coordinate, output)
            })
    }

    fn on_insert(&mut self, solution: &S, entity_index: usize, descriptor_index: usize) -> Sc {
        let mut total = Sc::zero();
        for slot in self.localized_slots(descriptor_index) {
            total = total + self.insert_entity_outputs(solution, slot, entity_index);
        }
        total
    }

    fn on_retract(&mut self, solution: &S, entity_index: usize, descriptor_index: usize) -> Sc {
        let mut total = Sc::zero();
        for slot in self.localized_slots(descriptor_index) {
            total = total + self.retract_entity_outputs(solution, slot, entity_index);
        }
        total
    }

    fn reset(&mut self) {
        self.rows.clear();
        self.free_row_ids.clear();
        self.rows_by_entity.clear();
        self.rows_by_key.clear();
    }

    fn name(&self) -> &str {
        &self.constraint_ref.name
    }

    fn is_hard(&self) -> bool {
        self.is_hard
    }

    fn constraint_ref(&self) -> ConstraintRef {
        self.constraint_ref.clone()
    }
}