use std::hash::Hash;
use std::marker::PhantomData;
use solverforge_core::score::Score;
use solverforge_core::{ConstraintRef, ImpactType};
use super::collection_extract::CollectionExtract;
use super::collector::UniCollector;
use crate::constraint::complemented::ComplementedGroupConstraint;
/// Intermediate stream for a "complemented" grouping constraint.
///
/// It stores everything needed to build a [`ComplementedGroupConstraint`]:
/// two collection extractors over the solution `S` (one yielding `A` items, one
/// yielding `B` items), a key function for each side, a collector over `A`, and a
/// default-result function over `B`. The weighting and impact type are chosen
/// later via one of the `penalize*` / `reward*` methods.
///
/// NOTE(review): the exact grouping/complement semantics live in
/// `ComplementedGroupConstraint`. Presumably every `B` yields a group and `default_fn`
/// supplies the result when no `A` maps to its key; confirm there.
#[must_use = "a constraint stream does nothing until finished with `penalize*`/`reward*` and `.named(..)`"]
pub struct ComplementedConstraintStream<S, A, B, K, EA, EB, KA, KB, C, D, Sc>
where
    Sc: Score,
{
    /// Extracts the `A` items (the collected side) from the solution.
    extractor_a: EA,
    /// Extracts the `B` items (the complement side) from the solution.
    extractor_b: EB,
    /// Maps an `A` item to its optional group key.
    key_a: KA,
    /// Maps a `B` item to its group key.
    key_b: KB,
    /// Aggregates `A` items into a `C::Result`.
    collector: C,
    /// Produces a `C::Result` for a `B` item.
    default_fn: D,
    /// Ties the otherwise-unused type parameters to the struct; the `fn() -> T`
    /// form keeps auto traits independent of `S`, `A`, `B`, `K` and `Sc`.
    _phantom: PhantomData<(fn() -> S, fn() -> A, fn() -> B, fn() -> K, fn() -> Sc)>,
}
impl<S, A, B, K, EA, EB, KA, KB, C, D, Sc>
ComplementedConstraintStream<S, A, B, K, EA, EB, KA, KB, C, D, Sc>
where
S: Send + Sync + 'static,
A: Clone + Send + Sync + 'static,
B: Clone + Send + Sync + 'static,
K: Clone + Eq + Hash + Send + Sync + 'static,
EA: CollectionExtract<S, Item = A>,
EB: CollectionExtract<S, Item = B>,
KA: Fn(&A) -> Option<K> + Send + Sync,
KB: Fn(&B) -> K + Send + Sync,
C: UniCollector<A> + Send + Sync + 'static,
C::Accumulator: Send + Sync,
C::Result: Clone + Send + Sync,
D: Fn(&B) -> C::Result + Send + Sync,
Sc: Score + 'static,
{
pub(crate) fn new(
extractor_a: EA,
extractor_b: EB,
key_a: KA,
key_b: KB,
collector: C,
default_fn: D,
) -> Self {
Self {
extractor_a,
extractor_b,
key_a,
key_b,
collector,
default_fn,
_phantom: PhantomData,
}
}
pub fn penalize_with<W>(
self,
weight_fn: W,
) -> ComplementedConstraintBuilder<S, A, B, K, EA, EB, KA, KB, C, D, W, Sc>
where
W: Fn(&C::Result) -> Sc + Send + Sync,
{
ComplementedConstraintBuilder {
extractor_a: self.extractor_a,
extractor_b: self.extractor_b,
key_a: self.key_a,
key_b: self.key_b,
collector: self.collector,
default_fn: self.default_fn,
impact_type: ImpactType::Penalty,
weight_fn,
is_hard: false,
_phantom: PhantomData,
}
}
pub fn penalize_hard_with<W>(
self,
weight_fn: W,
) -> ComplementedConstraintBuilder<S, A, B, K, EA, EB, KA, KB, C, D, W, Sc>
where
W: Fn(&C::Result) -> Sc + Send + Sync,
{
ComplementedConstraintBuilder {
extractor_a: self.extractor_a,
extractor_b: self.extractor_b,
key_a: self.key_a,
key_b: self.key_b,
collector: self.collector,
default_fn: self.default_fn,
impact_type: ImpactType::Penalty,
weight_fn,
is_hard: true,
_phantom: PhantomData,
}
}
pub fn reward_with<W>(
self,
weight_fn: W,
) -> ComplementedConstraintBuilder<S, A, B, K, EA, EB, KA, KB, C, D, W, Sc>
where
W: Fn(&C::Result) -> Sc + Send + Sync,
{
ComplementedConstraintBuilder {
extractor_a: self.extractor_a,
extractor_b: self.extractor_b,
key_a: self.key_a,
key_b: self.key_b,
collector: self.collector,
default_fn: self.default_fn,
impact_type: ImpactType::Reward,
weight_fn,
is_hard: false,
_phantom: PhantomData,
}
}
pub fn reward_hard_with<W>(
self,
weight_fn: W,
) -> ComplementedConstraintBuilder<S, A, B, K, EA, EB, KA, KB, C, D, W, Sc>
where
W: Fn(&C::Result) -> Sc + Send + Sync,
{
ComplementedConstraintBuilder {
extractor_a: self.extractor_a,
extractor_b: self.extractor_b,
key_a: self.key_a,
key_b: self.key_b,
collector: self.collector,
default_fn: self.default_fn,
impact_type: ImpactType::Reward,
weight_fn,
is_hard: true,
_phantom: PhantomData,
}
}
pub fn penalize_hard(
self,
) -> ComplementedConstraintBuilder<
S,
A,
B,
K,
EA,
EB,
KA,
KB,
C,
D,
impl Fn(&C::Result) -> Sc + Send + Sync,
Sc,
>
where
Sc: Copy,
{
let w = Sc::one_hard();
self.penalize_hard_with(move |_: &C::Result| w)
}
pub fn penalize_soft(
self,
) -> ComplementedConstraintBuilder<
S,
A,
B,
K,
EA,
EB,
KA,
KB,
C,
D,
impl Fn(&C::Result) -> Sc + Send + Sync,
Sc,
>
where
Sc: Copy,
{
let w = Sc::one_soft();
self.penalize_with(move |_: &C::Result| w)
}
pub fn reward_hard(
self,
) -> ComplementedConstraintBuilder<
S,
A,
B,
K,
EA,
EB,
KA,
KB,
C,
D,
impl Fn(&C::Result) -> Sc + Send + Sync,
Sc,
>
where
Sc: Copy,
{
let w = Sc::one_hard();
self.reward_hard_with(move |_: &C::Result| w)
}
pub fn reward_soft(
self,
) -> ComplementedConstraintBuilder<
S,
A,
B,
K,
EA,
EB,
KA,
KB,
C,
D,
impl Fn(&C::Result) -> Sc + Send + Sync,
Sc,
>
where
Sc: Copy,
{
let w = Sc::one_soft();
self.reward_with(move |_: &C::Result| w)
}
}
impl<S, A, B, K, EA, EB, KA, KB, C, D, Sc> std::fmt::Debug
    for ComplementedConstraintStream<S, A, B, K, EA, EB, KA, KB, C, D, Sc>
where
    Sc: Score,
{
    /// Prints just the type name.
    ///
    /// None of the type parameters carry a `Debug` bound here, so the stored
    /// extractors and closures cannot be shown. With no fields, this is exactly
    /// what an empty `debug_struct(..).finish()` would emit.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("ComplementedConstraintStream")
    }
}
/// Builder produced by a `penalize*` / `reward*` call on a complemented stream.
///
/// It carries the stream's components plus the chosen impact type, weight
/// function and hardness flag. Call `.named(..)` to turn it into a
/// [`ComplementedGroupConstraint`]; dropping it without doing so discards the
/// constraint.
#[must_use = "a constraint builder does nothing until `.named(..)` is called"]
pub struct ComplementedConstraintBuilder<S, A, B, K, EA, EB, KA, KB, C, D, W, Sc>
where
    Sc: Score,
{
    /// Extracts the `A` items (the collected side) from the solution.
    extractor_a: EA,
    /// Extracts the `B` items (the complement side) from the solution.
    extractor_b: EB,
    /// Maps an `A` item to its optional group key.
    key_a: KA,
    /// Maps a `B` item to its group key.
    key_b: KB,
    /// Aggregates `A` items into a `C::Result`.
    collector: C,
    /// Produces a `C::Result` for a `B` item.
    default_fn: D,
    /// Whether the constraint penalizes or rewards.
    impact_type: ImpactType,
    /// Maps a collected result to a score weight.
    weight_fn: W,
    /// Set by the `*_hard*` builder methods; forwarded unchanged to the constraint.
    is_hard: bool,
    /// Ties the otherwise-unused type parameters to the struct.
    _phantom: PhantomData<(fn() -> S, fn() -> A, fn() -> B, fn() -> K, fn() -> Sc)>,
}
impl<S, A, B, K, EA, EB, KA, KB, C, D, W, Sc>
    ComplementedConstraintBuilder<S, A, B, K, EA, EB, KA, KB, C, D, W, Sc>
where
    S: Send + Sync + 'static,
    A: Clone + Send + Sync + 'static,
    B: Clone + Send + Sync + 'static,
    K: Clone + Eq + Hash + Send + Sync + 'static,
    EA: CollectionExtract<S, Item = A>,
    EB: CollectionExtract<S, Item = B>,
    KA: Fn(&A) -> Option<K> + Send + Sync,
    KB: Fn(&B) -> K + Send + Sync,
    C: UniCollector<A> + Send + Sync + 'static,
    C::Accumulator: Send + Sync,
    C::Result: Clone + Send + Sync,
    D: Fn(&B) -> C::Result + Send + Sync,
    W: Fn(&C::Result) -> Sc + Send + Sync,
    Sc: Score + 'static,
{
    /// Consumes the builder and produces the finished constraint, identified by `name`.
    ///
    /// The constraint reference is built as `ConstraintRef::new("", name)`.
    /// NOTE(review): the empty first argument is presumably a package/namespace;
    /// confirm against `ConstraintRef`.
    pub fn named(
        self,
        name: &str,
    ) -> ComplementedGroupConstraint<S, A, B, K, EA, EB, KA, KB, C, D, W, Sc> {
        let Self {
            extractor_a,
            extractor_b,
            key_a,
            key_b,
            collector,
            default_fn,
            impact_type,
            weight_fn,
            is_hard,
            _phantom: _,
        } = self;

        let constraint_ref = ConstraintRef::new("", name);

        ComplementedGroupConstraint::new(
            constraint_ref,
            impact_type,
            extractor_a,
            extractor_b,
            key_a,
            key_b,
            collector,
            default_fn,
            weight_fn,
            is_hard,
        )
    }
}
impl<S, A, B, K, EA, EB, KA, KB, C, D, W, Sc: Score> std::fmt::Debug
    for ComplementedConstraintBuilder<S, A, B, K, EA, EB, KA, KB, C, D, W, Sc>
{
    /// Shows the builder's impact settings.
    ///
    /// Only `impact_type` and `is_hard` are printed. The type parameters have no
    /// `Debug` bound, so the extractors and closures are omitted. `is_hard` is
    /// included so that e.g. `penalize_with` and `penalize_hard_with` builders are
    /// distinguishable.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("ComplementedConstraintBuilder")
            .field("impact_type", &self.impact_type)
            .field("is_hard", &self.is_hard)
            .finish()
    }
}