use std::hash::Hash;
use std::marker::PhantomData;
use solverforge_core::score::Score;
use solverforge_core::{ConstraintRef, ImpactType};
use crate::constraint::grouped::{ComplementedGroupedScorerSet, GroupedTerminalScorer};
use crate::stream::collection_extract::CollectionExtract;
use crate::stream::collector::{Accumulator, Collector};
use crate::stream::ConstraintWeight;
use super::shared_set::SharedCrossComplementedGroupedConstraintSet;
use super::state::CrossComplementedGroupedNodeState;
pub struct CrossComplementedGroupedConstraintSetBuilder<
S,
A,
B,
T,
JK,
GK,
EA,
EB,
ET,
KA,
KB,
F,
GF,
KT,
C,
V,
R,
Acc,
D,
Scorers,
W,
Sc,
> where
Acc: Accumulator<V, R>,
Sc: Score,
{
state: CrossComplementedGroupedNodeState<
S,
A,
B,
T,
JK,
GK,
EA,
EB,
ET,
KA,
KB,
F,
GF,
KT,
C,
V,
R,
Acc,
D,
>,
scorers: Scorers,
cached_score: Sc,
impact_type: ImpactType,
weight_fn: W,
is_hard: bool,
_phantom: PhantomData<fn() -> Sc>,
}
impl<S, A, B, T, JK, GK, EA, EB, ET, KA, KB, F, GF, KT, C, V, R, Acc, D, Scorers, Sc>
SharedCrossComplementedGroupedConstraintSet<
S,
A,
B,
T,
JK,
GK,
EA,
EB,
ET,
KA,
KB,
F,
GF,
KT,
C,
V,
R,
Acc,
D,
Scorers,
Sc,
>
where
S: Send + Sync + 'static,
A: Send + Sync + 'static,
B: Send + Sync + 'static,
T: Send + Sync + 'static,
JK: Eq + Hash + Send + Sync + 'static,
GK: Eq + Hash + Send + Sync + 'static,
EA: CollectionExtract<S, Item = A> + Send + Sync,
EB: CollectionExtract<S, Item = B> + Send + Sync,
ET: CollectionExtract<S, Item = T> + Send + Sync,
KA: Fn(&A) -> JK + Send + Sync,
KB: Fn(&B) -> JK + Send + Sync,
F: Fn(&S, &A, &B, usize, usize) -> bool + Send + Sync,
GF: Fn(&A, &B) -> GK + Send + Sync,
KT: Fn(&T) -> GK + Send + Sync,
C: for<'i> Collector<(&'i A, &'i B), Value = V, Result = R, Accumulator = Acc> + Send + Sync,
V: Send + Sync + 'static,
R: Send + Sync + 'static,
Acc: Accumulator<V, R> + Send + Sync + 'static,
D: Fn(&T) -> R + Send + Sync,
Scorers: ComplementedGroupedScorerSet<GK, R, Sc>,
Sc: Score + 'static,
{
fn into_weighted_builder<W>(
self,
impact_type: ImpactType,
weight_fn: W,
is_hard: bool,
) -> CrossComplementedGroupedConstraintSetBuilder<
S,
A,
B,
T,
JK,
GK,
EA,
EB,
ET,
KA,
KB,
F,
GF,
KT,
C,
V,
R,
Acc,
D,
Scorers,
W,
Sc,
>
where
W: Fn(&GK, &R) -> Sc + Send + Sync,
{
CrossComplementedGroupedConstraintSetBuilder {
state: self.state,
scorers: self.scorers,
cached_score: self.cached_score,
impact_type,
weight_fn,
is_hard,
_phantom: PhantomData,
}
}
pub fn penalize<W>(
self,
weight: W,
) -> CrossComplementedGroupedConstraintSetBuilder<
S,
A,
B,
T,
JK,
GK,
EA,
EB,
ET,
KA,
KB,
F,
GF,
KT,
C,
V,
R,
Acc,
D,
Scorers,
impl Fn(&GK, &R) -> Sc + Send + Sync,
Sc,
>
where
W: for<'w> ConstraintWeight<(&'w GK, &'w R), Sc> + Send + Sync,
{
let is_hard = weight.is_hard();
self.into_weighted_builder(
ImpactType::Penalty,
move |key: &GK, result: &R| weight.score((key, result)),
is_hard,
)
}
pub fn reward<W>(
self,
weight: W,
) -> CrossComplementedGroupedConstraintSetBuilder<
S,
A,
B,
T,
JK,
GK,
EA,
EB,
ET,
KA,
KB,
F,
GF,
KT,
C,
V,
R,
Acc,
D,
Scorers,
impl Fn(&GK, &R) -> Sc + Send + Sync,
Sc,
>
where
W: for<'w> ConstraintWeight<(&'w GK, &'w R), Sc> + Send + Sync,
{
let is_hard = weight.is_hard();
self.into_weighted_builder(
ImpactType::Reward,
move |key: &GK, result: &R| weight.score((key, result)),
is_hard,
)
}
}
impl<S, A, B, T, JK, GK, EA, EB, ET, KA, KB, F, GF, KT, C, V, R, Acc, D, Scorers, W, Sc>
CrossComplementedGroupedConstraintSetBuilder<
S,
A,
B,
T,
JK,
GK,
EA,
EB,
ET,
KA,
KB,
F,
GF,
KT,
C,
V,
R,
Acc,
D,
Scorers,
W,
Sc,
>
where
S: Send + Sync + 'static,
A: Send + Sync + 'static,
B: Send + Sync + 'static,
T: Send + Sync + 'static,
JK: Eq + Hash + Send + Sync + 'static,
GK: Eq + Hash + Send + Sync + 'static,
EA: CollectionExtract<S, Item = A> + Send + Sync,
EB: CollectionExtract<S, Item = B> + Send + Sync,
ET: CollectionExtract<S, Item = T> + Send + Sync,
KA: Fn(&A) -> JK + Send + Sync,
KB: Fn(&B) -> JK + Send + Sync,
F: Fn(&S, &A, &B, usize, usize) -> bool + Send + Sync,
GF: Fn(&A, &B) -> GK + Send + Sync,
KT: Fn(&T) -> GK + Send + Sync,
C: for<'i> Collector<(&'i A, &'i B), Value = V, Result = R, Accumulator = Acc> + Send + Sync,
V: Send + Sync + 'static,
R: Send + Sync + 'static,
Acc: Accumulator<V, R> + Send + Sync + 'static,
D: Fn(&T) -> R + Send + Sync,
Scorers: ComplementedGroupedScorerSet<GK, R, Sc>,
W: Fn(&GK, &R) -> Sc + Send + Sync,
Sc: Score + 'static,
{
pub fn named(
self,
name: &str,
) -> SharedCrossComplementedGroupedConstraintSet<
S,
A,
B,
T,
JK,
GK,
EA,
EB,
ET,
KA,
KB,
F,
GF,
KT,
C,
V,
R,
Acc,
D,
(Scorers, GroupedTerminalScorer<GK, R, W, Sc>),
Sc,
> {
let scorer = GroupedTerminalScorer::new(
ConstraintRef::new("", name),
self.impact_type,
self.weight_fn,
self.is_hard,
);
SharedCrossComplementedGroupedConstraintSet {
state: self.state,
scorers: (self.scorers, scorer),
cached_score: self.cached_score,
_phantom: PhantomData,
}
}
}