use std::hash::Hash;
use std::marker::PhantomData;
use solverforge_core::score::Score;
use solverforge_core::{ConstraintRef, ImpactType};
use crate::constraint::cross_bi_incremental::IncrementalCrossBiConstraint;
use super::collection_extract::CollectionExtract;
use super::filter::{AndBiFilter, BiFilter, FnBiFilter, TrueFilter};
use super::flattened_bi_stream::FlattenedBiConstraintStream;
/// A constraint stream over pairs drawn from two different entity collections.
///
/// Collection `A` is pulled from the solution `S` by `extractor_a` and collection `B`
/// by `extractor_b`. Both sides are mapped to a shared key type `K` by `key_a` and
/// `key_b`; presumably pairs are joined on equal keys inside
/// `IncrementalCrossBiConstraint` (TODO confirm). `filter` accumulates the pair
/// predicates added through `CrossBiConstraintStream::filter`.
///
/// A stream does nothing on its own: it must be turned into a builder with one of the
/// `penalize*` / `reward*` methods and then finished with `named`. Dropping a stream
/// unused is therefore almost certainly a mistake, hence `#[must_use]`.
#[must_use = "constraint streams have no effect until penalized/rewarded and named"]
pub struct CrossBiConstraintStream<S, A, B, K, EA, EB, KA, KB, F, Sc>
where
    Sc: Score,
{
    /// Extracts the `A` collection from the solution.
    extractor_a: EA,
    /// Extracts the `B` collection from the solution.
    extractor_b: EB,
    /// Maps an `A` entity to its join key.
    key_a: KA,
    /// Maps a `B` entity to its join key.
    key_b: KB,
    /// Accumulated pair filter (starts as `TrueFilter`).
    filter: F,
    /// Zero-sized marker for the otherwise-unused type parameters. Using `fn() -> T`
    /// keeps the marker itself `Send + Sync` and covariant whatever `T` is.
    _phantom: PhantomData<(fn() -> S, fn() -> A, fn() -> B, fn() -> K, fn() -> Sc)>,
}
impl<S, A, B, K, EA, EB, KA, KB, Sc>
    CrossBiConstraintStream<S, A, B, K, EA, EB, KA, KB, TrueFilter, Sc>
where
    S: Send + Sync + 'static,
    A: Clone + Send + Sync + 'static,
    B: Clone + Send + Sync + 'static,
    K: Eq + Hash + Clone + Send + Sync,
    EA: CollectionExtract<S, Item = A>,
    EB: CollectionExtract<S, Item = B>,
    KA: Fn(&A) -> K + Send + Sync,
    KB: Fn(&B) -> K + Send + Sync,
    Sc: Score + 'static,
{
    /// Creates an unfiltered cross stream.
    ///
    /// `extractor_a` / `extractor_b` pull the two collections out of the solution and
    /// `key_a` / `key_b` map each side to the shared key type `K`. The filter starts as
    /// `TrueFilter`; narrow it with `filter`.
    pub fn new(extractor_a: EA, extractor_b: EB, key_a: KA, key_b: KB) -> Self {
        let filter = TrueFilter;
        Self {
            extractor_a,
            key_a,
            extractor_b,
            key_b,
            filter,
            _phantom: PhantomData,
        }
    }
}
impl<S, A, B, K, EA, EB, KA, KB, F, Sc> CrossBiConstraintStream<S, A, B, K, EA, EB, KA, KB, F, Sc>
where
S: Send + Sync + 'static,
A: Clone + Send + Sync + 'static,
B: Clone + Send + Sync + 'static,
K: Eq + Hash + Clone + Send + Sync,
EA: CollectionExtract<S, Item = A>,
EB: CollectionExtract<S, Item = B>,
KA: Fn(&A) -> K + Send + Sync,
KB: Fn(&B) -> K + Send + Sync,
F: BiFilter<S, A, B>,
Sc: Score + 'static,
{
pub fn new_with_filter(
extractor_a: EA,
extractor_b: EB,
key_a: KA,
key_b: KB,
filter: F,
) -> Self {
Self {
extractor_a,
extractor_b,
key_a,
key_b,
filter,
_phantom: PhantomData,
}
}
pub fn filter<P>(
self,
predicate: P,
) -> CrossBiConstraintStream<
S,
A,
B,
K,
EA,
EB,
KA,
KB,
AndBiFilter<F, FnBiFilter<impl Fn(&S, &A, &B) -> bool + Send + Sync>>,
Sc,
>
where
P: Fn(&A, &B) -> bool + Send + Sync,
{
CrossBiConstraintStream {
extractor_a: self.extractor_a,
extractor_b: self.extractor_b,
key_a: self.key_a,
key_b: self.key_b,
filter: AndBiFilter::new(
self.filter,
FnBiFilter::new(move |_s: &S, a: &A, b: &B| predicate(a, b)),
),
_phantom: PhantomData,
}
}
pub fn penalize(
self,
weight: Sc,
) -> CrossBiConstraintBuilder<
S,
A,
B,
K,
EA,
EB,
KA,
KB,
F,
impl Fn(&A, &B) -> Sc + Send + Sync,
Sc,
>
where
Sc: Copy,
{
let is_hard = weight
.to_level_numbers()
.first()
.map(|&h| h != 0)
.unwrap_or(false);
CrossBiConstraintBuilder {
extractor_a: self.extractor_a,
extractor_b: self.extractor_b,
key_a: self.key_a,
key_b: self.key_b,
filter: self.filter,
impact_type: ImpactType::Penalty,
weight: move |_: &A, _: &B| weight,
is_hard,
_phantom: PhantomData,
}
}
pub fn penalize_with<W>(
self,
weight_fn: W,
) -> CrossBiConstraintBuilder<S, A, B, K, EA, EB, KA, KB, F, W, Sc>
where
W: Fn(&A, &B) -> Sc + Send + Sync,
{
CrossBiConstraintBuilder {
extractor_a: self.extractor_a,
extractor_b: self.extractor_b,
key_a: self.key_a,
key_b: self.key_b,
filter: self.filter,
impact_type: ImpactType::Penalty,
weight: weight_fn,
is_hard: false,
_phantom: PhantomData,
}
}
pub fn penalize_hard_with<W>(
self,
weight_fn: W,
) -> CrossBiConstraintBuilder<S, A, B, K, EA, EB, KA, KB, F, W, Sc>
where
W: Fn(&A, &B) -> Sc + Send + Sync,
{
CrossBiConstraintBuilder {
extractor_a: self.extractor_a,
extractor_b: self.extractor_b,
key_a: self.key_a,
key_b: self.key_b,
filter: self.filter,
impact_type: ImpactType::Penalty,
weight: weight_fn,
is_hard: true,
_phantom: PhantomData,
}
}
pub fn reward(
self,
weight: Sc,
) -> CrossBiConstraintBuilder<
S,
A,
B,
K,
EA,
EB,
KA,
KB,
F,
impl Fn(&A, &B) -> Sc + Send + Sync,
Sc,
>
where
Sc: Copy,
{
let is_hard = weight
.to_level_numbers()
.first()
.map(|&h| h != 0)
.unwrap_or(false);
CrossBiConstraintBuilder {
extractor_a: self.extractor_a,
extractor_b: self.extractor_b,
key_a: self.key_a,
key_b: self.key_b,
filter: self.filter,
impact_type: ImpactType::Reward,
weight: move |_: &A, _: &B| weight,
is_hard,
_phantom: PhantomData,
}
}
pub fn reward_with<W>(
self,
weight_fn: W,
) -> CrossBiConstraintBuilder<S, A, B, K, EA, EB, KA, KB, F, W, Sc>
where
W: Fn(&A, &B) -> Sc + Send + Sync,
{
CrossBiConstraintBuilder {
extractor_a: self.extractor_a,
extractor_b: self.extractor_b,
key_a: self.key_a,
key_b: self.key_b,
filter: self.filter,
impact_type: ImpactType::Reward,
weight: weight_fn,
is_hard: false,
_phantom: PhantomData,
}
}
pub fn reward_hard_with<W>(
self,
weight_fn: W,
) -> CrossBiConstraintBuilder<S, A, B, K, EA, EB, KA, KB, F, W, Sc>
where
W: Fn(&A, &B) -> Sc + Send + Sync,
{
CrossBiConstraintBuilder {
extractor_a: self.extractor_a,
extractor_b: self.extractor_b,
key_a: self.key_a,
key_b: self.key_b,
filter: self.filter,
impact_type: ImpactType::Reward,
weight: weight_fn,
is_hard: true,
_phantom: PhantomData,
}
}
pub fn penalize_hard(
self,
) -> CrossBiConstraintBuilder<
S,
A,
B,
K,
EA,
EB,
KA,
KB,
F,
impl Fn(&A, &B) -> Sc + Send + Sync,
Sc,
>
where
Sc: Copy,
{
self.penalize(Sc::one_hard())
}
pub fn penalize_soft(
self,
) -> CrossBiConstraintBuilder<
S,
A,
B,
K,
EA,
EB,
KA,
KB,
F,
impl Fn(&A, &B) -> Sc + Send + Sync,
Sc,
>
where
Sc: Copy,
{
self.penalize(Sc::one_soft())
}
pub fn reward_hard(
self,
) -> CrossBiConstraintBuilder<
S,
A,
B,
K,
EA,
EB,
KA,
KB,
F,
impl Fn(&A, &B) -> Sc + Send + Sync,
Sc,
>
where
Sc: Copy,
{
self.reward(Sc::one_hard())
}
pub fn reward_soft(
self,
) -> CrossBiConstraintBuilder<
S,
A,
B,
K,
EA,
EB,
KA,
KB,
F,
impl Fn(&A, &B) -> Sc + Send + Sync,
Sc,
>
where
Sc: Copy,
{
self.reward(Sc::one_soft())
}
pub fn flatten_last<C, CK, Flatten, CKeyFn, ALookup>(
self,
flatten: Flatten,
c_key_fn: CKeyFn,
a_lookup_fn: ALookup,
) -> FlattenedBiConstraintStream<
S,
A,
B,
C,
K,
CK,
EA,
EB,
KA,
KB,
Flatten,
CKeyFn,
ALookup,
super::filter::TrueFilter,
Sc,
>
where
C: Clone + Send + Sync + 'static,
CK: Eq + Hash + Clone + Send + Sync,
Flatten: Fn(&B) -> &[C] + Send + Sync,
CKeyFn: Fn(&C) -> CK + Send + Sync,
ALookup: Fn(&A) -> CK + Send + Sync,
{
FlattenedBiConstraintStream::new(
self.extractor_a,
self.extractor_b,
self.key_a,
self.key_b,
flatten,
c_key_fn,
a_lookup_fn,
)
}
}
impl<S, A, B, K, EA, EB, KA, KB, F, Sc: Score> std::fmt::Debug
    for CrossBiConstraintStream<S, A, B, K, EA, EB, KA, KB, F, Sc>
{
    // None of the generic fields carry a `Debug` bound, so only the type name is
    // printed. This is exactly what a field-less `debug_struct(..).finish()` emits.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("CrossBiConstraintStream")
    }
}
/// Intermediate builder produced by the `penalize*` / `reward*` methods of
/// `CrossBiConstraintStream`.
///
/// It carries everything needed to build an `IncrementalCrossBiConstraint`: both
/// extractors, both key functions, the pair filter, the impact direction, the pair
/// weight function `weight` and the `is_hard` flag. Nothing happens until `named` is
/// called, so an unused builder is almost certainly a mistake, hence `#[must_use]`.
#[must_use = "call `.named(...)` to turn the builder into a constraint"]
pub struct CrossBiConstraintBuilder<S, A, B, K, EA, EB, KA, KB, F, W, Sc>
where
    Sc: Score,
{
    /// Extracts the `A` collection from the solution.
    extractor_a: EA,
    /// Extracts the `B` collection from the solution.
    extractor_b: EB,
    /// Maps an `A` entity to its join key.
    key_a: KA,
    /// Maps a `B` entity to its join key.
    key_b: KB,
    /// Pair filter accumulated on the originating stream.
    filter: F,
    /// Whether matches are penalized or rewarded.
    impact_type: ImpactType,
    /// Computes the score impact of one `(a, b)` pair.
    weight: W,
    /// Hard flag chosen by the `penalize*` / `reward*` method that built this value.
    is_hard: bool,
    /// Zero-sized marker for the otherwise-unused type parameters. Using `fn() -> T`
    /// keeps the marker itself `Send + Sync` and covariant whatever `T` is.
    _phantom: PhantomData<(fn() -> S, fn() -> A, fn() -> B, fn() -> K, fn() -> Sc)>,
}
impl<S, A, B, K, EA, EB, KA, KB, F, W, Sc>
    CrossBiConstraintBuilder<S, A, B, K, EA, EB, KA, KB, F, W, Sc>
where
    S: Send + Sync + 'static,
    A: Clone + Send + Sync + 'static,
    B: Clone + Send + Sync + 'static,
    K: Eq + Hash + Clone + Send + Sync,
    EA: CollectionExtract<S, Item = A> + Clone,
    EB: CollectionExtract<S, Item = B> + Clone,
    KA: Fn(&A) -> K + Send + Sync,
    KB: Fn(&B) -> K + Send + Sync,
    F: BiFilter<S, A, B>,
    W: Fn(&A, &B) -> Sc + Send + Sync,
    Sc: Score + 'static,
{
    /// Finishes the builder as an `IncrementalCrossBiConstraint` identified by
    /// `ConstraintRef::new("", name)` (empty package, given name).
    ///
    /// The pair filter is adapted to a `(solution, a, b)` closure, and the entity-based
    /// weight is adapted to a `(solution, a_index, b_index)` closure that re-extracts both
    /// collections and indexes into them on every call.
    pub fn named(
        self,
        name: &str,
    ) -> IncrementalCrossBiConstraint<
        S,
        A,
        B,
        K,
        EA,
        EB,
        KA,
        KB,
        impl Fn(&S, &A, &B) -> bool + Send + Sync,
        impl Fn(&S, usize, usize) -> Sc + Send + Sync,
        Sc,
    > {
        let Self {
            extractor_a,
            extractor_b,
            key_a,
            key_b,
            filter,
            impact_type,
            weight,
            is_hard,
            ..
        } = self;

        // The adapted filter only receives entity references, so the index arguments of
        // `BiFilter::test` are filled with placeholder zeros.
        // NOTE(review): a `BiFilter` that inspects its index arguments will not see the
        // real indices here — confirm no such filter is used with cross streams.
        let pair_filter = move |s: &S, a: &A, b: &B| filter.test(s, a, b, 0, 0);

        // Private copies of the extractors for the weight closure; the originals are
        // handed to the constraint itself below.
        let source_a = extractor_a.clone();
        let source_b = extractor_b.clone();
        let indexed_weight = move |s: &S, a_idx: usize, b_idx: usize| {
            let all_a = source_a.extract(s);
            let all_b = source_b.extract(s);
            // Presumably panics if an index is out of range for its collection.
            weight(&all_a[a_idx], &all_b[b_idx])
        };

        IncrementalCrossBiConstraint::new(
            ConstraintRef::new("", name),
            impact_type,
            extractor_a,
            extractor_b,
            key_a,
            key_b,
            pair_filter,
            indexed_weight,
            is_hard,
        )
    }
}
impl<S, A, B, K, EA, EB, KA, KB, F, W, Sc: Score> std::fmt::Debug
    for CrossBiConstraintBuilder<S, A, B, K, EA, EB, KA, KB, F, W, Sc>
{
    // Only `impact_type` is printed; the remaining fields are generic and carry no
    // `Debug` bound.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("CrossBiConstraintBuilder");
        out.field("impact_type", &self.impact_type);
        out.finish()
    }
}