use std::hash::Hash;
use std::marker::PhantomData;
use solverforge_core::score::Score;
use solverforge_core::{ConstraintRef, ImpactType};
use super::collection_extract::CollectionExtract;
use super::filter::UniFilter;
use crate::constraint::balance::BalanceConstraint;
/// Stream stage that extracts `A` items from a solution `S`, filters them, and maps
/// each surviving item to an optional key `K`.
///
/// It is turned into a [`BalanceConstraintBuilder`] by one of the `penalize*` / `reward*`
/// methods.
pub struct BalanceConstraintStream<S, A, K, E, F, KF, Sc: Score> {
    /// Extracts the `A` items from a solution `S` (via `CollectionExtract`).
    extractor: E,
    /// Predicate (`UniFilter`) applied to each extracted item.
    filter: F,
    /// Maps an item to its key; how a `None` key is treated is decided downstream
    /// (presumably the item is skipped — confirm in `BalanceConstraint`).
    key_fn: KF,
    /// Type markers only. `fn() -> T` keeps the stream `Send`/`Sync` regardless of the
    /// parameters, and no values of these types are stored.
    _phantom: PhantomData<(fn() -> S, fn() -> A, fn() -> K, fn() -> Sc)>,
}
impl<S, A, K, E, F, KF, Sc> BalanceConstraintStream<S, A, K, E, F, KF, Sc>
where
S: Send + Sync + 'static,
A: Clone + Send + Sync + 'static,
K: Clone + Eq + Hash + Send + Sync + 'static,
E: CollectionExtract<S, Item = A>,
F: UniFilter<S, A>,
KF: Fn(&A) -> Option<K> + Send + Sync,
Sc: Score + 'static,
{
pub(crate) fn new(extractor: E, filter: F, key_fn: KF) -> Self {
Self {
extractor,
filter,
key_fn,
_phantom: PhantomData,
}
}
pub fn penalize(self, base_score: Sc) -> BalanceConstraintBuilder<S, A, K, E, F, KF, Sc> {
let is_hard = base_score
.to_level_numbers()
.first()
.map(|&h| h != 0)
.unwrap_or(false);
BalanceConstraintBuilder {
extractor: self.extractor,
filter: self.filter,
key_fn: self.key_fn,
impact_type: ImpactType::Penalty,
base_score,
is_hard,
_phantom: PhantomData,
}
}
pub fn penalize_hard(self) -> BalanceConstraintBuilder<S, A, K, E, F, KF, Sc>
where
Sc: Copy,
{
self.penalize(Sc::one_hard())
}
pub fn penalize_soft(self) -> BalanceConstraintBuilder<S, A, K, E, F, KF, Sc>
where
Sc: Copy,
{
self.penalize(Sc::one_soft())
}
pub fn reward_hard(self) -> BalanceConstraintBuilder<S, A, K, E, F, KF, Sc>
where
Sc: Copy,
{
self.reward(Sc::one_hard())
}
pub fn reward_soft(self) -> BalanceConstraintBuilder<S, A, K, E, F, KF, Sc>
where
Sc: Copy,
{
self.reward(Sc::one_soft())
}
pub fn reward(self, base_score: Sc) -> BalanceConstraintBuilder<S, A, K, E, F, KF, Sc> {
let is_hard = base_score
.to_level_numbers()
.first()
.map(|&h| h != 0)
.unwrap_or(false);
BalanceConstraintBuilder {
extractor: self.extractor,
filter: self.filter,
key_fn: self.key_fn,
impact_type: ImpactType::Reward,
base_score,
is_hard,
_phantom: PhantomData,
}
}
}
impl<S, A, K, E, F, KF, Sc> std::fmt::Debug for BalanceConstraintStream<S, A, K, E, F, KF, Sc>
where
    Sc: Score,
{
    /// Writes only the type name.
    ///
    /// The extractor, filter and key function are not required to implement `Debug`,
    /// so none of them are printed.
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let builder = &mut formatter.debug_struct("BalanceConstraintStream");
        builder.finish()
    }
}
/// Intermediate builder produced by `BalanceConstraintStream::penalize*` / `reward*`.
///
/// Call `named` on it to obtain the final `BalanceConstraint`.
pub struct BalanceConstraintBuilder<S, A, K, E, F, KF, Sc: Score> {
    /// Extracts the `A` items from a solution `S`.
    extractor: E,
    /// Predicate applied to each extracted item.
    filter: F,
    /// Maps an item to its optional key.
    key_fn: KF,
    /// Whether the constraint penalizes or rewards.
    impact_type: ImpactType,
    /// Score weight passed to the constraint.
    base_score: Sc,
    /// Hard/soft classification computed from `base_score` when the builder was created.
    is_hard: bool,
    /// Type markers only; `Sc` is omitted because `base_score` already stores one.
    _phantom: PhantomData<(fn() -> S, fn() -> A, fn() -> K)>,
}
impl<S, A, K, E, F, KF, Sc> BalanceConstraintBuilder<S, A, K, E, F, KF, Sc>
where
    S: Send + Sync + 'static,
    A: Clone + Send + Sync + 'static,
    K: Clone + Eq + Hash + Send + Sync + 'static,
    E: CollectionExtract<S, Item = A>,
    F: UniFilter<S, A>,
    KF: Fn(&A) -> Option<K> + Send + Sync,
    Sc: Score + 'static,
{
    /// Consumes the builder and produces a `BalanceConstraint` identified by `name`.
    ///
    /// The `ConstraintRef` is built with `""` as its first component (presumably the
    /// package/namespace — confirm against `ConstraintRef::new`).
    pub fn named(self, name: &str) -> BalanceConstraint<S, A, K, E, F, KF, Sc> {
        let Self {
            extractor,
            filter,
            key_fn,
            impact_type,
            base_score,
            is_hard,
            _phantom: _,
        } = self;
        let constraint_ref = ConstraintRef::new("", name);
        BalanceConstraint::new(
            constraint_ref,
            impact_type,
            extractor,
            filter,
            key_fn,
            base_score,
            is_hard,
        )
    }
}
impl<S, A, K, E, F, KF, Sc> std::fmt::Debug for BalanceConstraintBuilder<S, A, K, E, F, KF, Sc>
where
    Sc: Score,
{
    /// Writes the type name plus `impact_type`.
    ///
    /// The other fields are not required to implement `Debug`, so they are omitted.
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = formatter.debug_struct("BalanceConstraintBuilder");
        builder.field("impact_type", &self.impact_type);
        builder.finish()
    }
}