use crate::{
constraints::props::{Propagate, Prune},
variables::VarId,
variables::views::{Context, View},
};
#[derive(Clone, Debug)]
#[doc(hidden)]
pub struct Count<T: View> {
vars: Vec<VarId>,
target: T,
count_var: VarId,
}
impl<T: View> Count<T> {
    /// Creates the propagator for "`count_var` equals the number of variables
    /// in `vars` whose value equals `target`".
    pub const fn new(vars: Vec<VarId>, target: T, count_var: VarId) -> Self {
        Self {
            vars,
            target,
            count_var,
        }
    }

    /// Number of variables already fixed to the target's value.
    ///
    /// Returns 0 while the target itself is not fixed, since no variable can be
    /// known to match it yet. This is a lower bound on the final count.
    fn count_definitely_equal(&self, ctx: &Context) -> i64 {
        let target_min = self.target.min(ctx);
        let target_max = self.target.max(ctx);
        if target_min != target_max {
            return 0;
        }
        self.vars
            .iter()
            .filter(|&&var| {
                let min = var.min(ctx);
                let max = var.max(ctx);
                min == max && min == target_min
            })
            .count() as i64
    }

    /// Number of variables whose `[min, max]` bounds overlap the target's
    /// `[min, max]` bounds.
    ///
    /// Only bounds are inspected, so this is an upper bound on the final count.
    fn count_possibly_equal(&self, ctx: &Context) -> i64 {
        let target_min = self.target.min(ctx);
        let target_max = self.target.max(ctx);
        self.vars
            .iter()
            .filter(|&&var| {
                let var_min = var.min(ctx);
                let var_max = var.max(ctx);
                var_min <= target_max && var_max >= target_min
            })
            .count() as i64
    }

    /// Converts a match count into an integer `Val`.
    ///
    /// Counts are never negative; clamping into `i32` range before narrowing
    /// keeps the conversion from silently wrapping the way a bare `as` would.
    fn count_to_val(count: i64) -> crate::variables::Val {
        crate::variables::Val::ValI(count.clamp(0, i64::from(i32::MAX)) as i32)
    }

    /// Interprets the value of a fixed count variable as an integer count.
    ///
    /// Returns `None` for a non-integral float (including NaN and infinities):
    /// no number of matching variables can ever equal such a value.
    fn fixed_count(value: crate::variables::Val) -> Option<i64> {
        match value {
            crate::variables::Val::ValI(i) => Some(i64::from(i)),
            crate::variables::Val::ValF(f) if f.fract() == 0.0 => Some(f as i64),
            crate::variables::Val::ValF(_) => None,
        }
    }

    /// Returns `var`'s `(min, max)` when the variable is still undecided
    /// (`min != max`) and its bounds contain `target`; otherwise `None`.
    fn undecided_bounds_containing(
        var: VarId,
        target: crate::variables::Val,
        ctx: &Context,
    ) -> Option<(crate::variables::Val, crate::variables::Val)> {
        let min = var.min(ctx);
        let max = var.max(ctx);
        (min != max && min <= target && target <= max).then_some((min, max))
    }

    /// Shrinks the bounds of undecided variables away from `target`. It is
    /// used when the required count is already met by fixed variables, so no
    /// further variable may take that value.
    ///
    /// Only integer bounds that sit exactly on `target` are tightened. A target
    /// strictly inside a variable's bounds cannot be removed by bounds alone.
    fn exclude_target_from_bounds(
        &self,
        target: crate::variables::Val,
        ctx: &mut Context,
    ) -> Option<()> {
        for &var in &self.vars {
            let Some((min, max)) = Self::undecided_bounds_containing(var, target, ctx) else {
                continue;
            };
            if let (
                crate::variables::Val::ValI(tgt),
                crate::variables::Val::ValI(lo),
                crate::variables::Val::ValI(hi),
            ) = (target, min, max)
            {
                // `tgt < hi` / `tgt > lo` guarantee the +1 / -1 cannot overflow.
                if tgt == lo && tgt < hi {
                    var.try_set_min(crate::variables::Val::ValI(tgt + 1), ctx)?;
                } else if tgt == hi && tgt > lo {
                    var.try_set_max(crate::variables::Val::ValI(tgt - 1), ctx)?;
                }
            }
        }
        Some(())
    }

    /// Fixes to `target` every undecided variable whose bounds contain it. It is
    /// used when every variable that could still match is needed to reach the
    /// required count.
    fn force_candidates_to_target(
        &self,
        target: crate::variables::Val,
        ctx: &mut Context,
    ) -> Option<()> {
        for &var in &self.vars {
            if Self::undecided_bounds_containing(var, target, ctx).is_some() {
                var.try_set_min(target, ctx)?;
                var.try_set_max(target, ctx)?;
            }
        }
        Some(())
    }

    /// Bounds-based propagation for the count constraint.
    ///
    /// 1. Restricts `count_var` to `[definitely_equal, possibly_equal]`.
    /// 2. If `count_var` and `target` are both fixed:
    ///    - when the fixed variables already reach the count, no undecided
    ///      variable may take the target value (bounds are shrunk away from it);
    ///    - otherwise, when every possible match is needed, all undecided
    ///      candidates are fixed to the target value.
    ///
    /// Returns `None` when any bound update fails, or when `count_var` is fixed
    /// to a non-integral float that no match count can equal.
    fn propagate_count_bounds(&self, ctx: &mut Context) -> Option<()> {
        let definitely_equal = self.count_definitely_equal(ctx);
        let possibly_equal = self.count_possibly_equal(ctx);

        self.count_var
            .try_set_min(Self::count_to_val(definitely_equal), ctx)?;
        self.count_var
            .try_set_max(Self::count_to_val(possibly_equal), ctx)?;

        let count_min = self.count_var.min(ctx);
        let count_max = self.count_var.max(ctx);
        let target_min = self.target.min(ctx);
        let target_max = self.target.max(ctx);
        if count_min != count_max || target_min != target_max {
            // Value-level pruning requires both the count and the target fixed.
            return Some(());
        }

        let target_count = Self::fixed_count(count_min)?;
        if definitely_equal == target_count {
            self.exclude_target_from_bounds(target_min, ctx)
        } else if possibly_equal == target_count {
            self.force_candidates_to_target(target_min, ctx)
        } else {
            Some(())
        }
    }
}
impl<T: View> Prune for Count<T> {
    /// Entry point used by the solver: runs the count propagation and passes
    /// its result through unchanged (`None` means propagation failed).
    fn prune(&self, ctx: &mut Context) -> Option<()> {
        Self::propagate_count_bounds(self, ctx)
    }
}
impl<T: View + 'static> Propagate for Count<T> {
    /// Variables whose changes should re-run this propagator, in order:
    /// every counted variable, then whatever `target.get_underlying_var()`
    /// yields (presumably zero or one variable), then `count_var`.
    fn list_trigger_vars(&self) -> impl Iterator<Item = VarId> {
        let counted = self.vars.iter().copied();
        let target_backing: Vec<VarId> =
            self.target.get_underlying_var().into_iter().collect();
        let count = std::iter::once(self.count_var);
        counted.chain(target_backing).chain(count)
    }
}
#[cfg(test)]
mod test_count_direct {
    use super::*;
    use crate::variables::views::Context;
    use crate::variables::Val;
    use crate::variables::Vars;

    /// Three variables in [1, 3], target fixed to 1, count in [0, 3]:
    /// pruning a satisfiable instance must not fail.
    #[test]
    fn test_count_constraint_direct() {
        let mut vars = Vars::new();
        let [v1, v2, v3] =
            [(); 3].map(|()| vars.new_var_with_bounds(Val::int(1), Val::int(3)));
        let target_var = vars.new_var_with_bounds(Val::int(1), Val::int(1));
        let count_var = vars.new_var_with_bounds(Val::int(0), Val::int(3));
        let constraint = Count::new(vec![v1, v2, v3], target_var, count_var);

        let mut events = Vec::new();
        let mut ctx = Context::new(&mut vars, &mut events);
        assert!(constraint.prune(&mut ctx).is_some());
    }

    /// The constraint boxed as `dyn Prune` and shared behind an `Rc`
    /// must still prune successfully through dynamic dispatch.
    #[test]
    fn test_count_trait_object_dispatch() {
        println!("=== Testing Count trait object dispatch ===");
        let mut vars = Vars::new();
        let [v1, v2] = [(); 2].map(|()| vars.new_var_with_bounds(Val::int(1), Val::int(3)));
        let target_var = vars.new_var_with_bounds(Val::int(2), Val::int(2));
        let count_var = vars.new_var_with_bounds(Val::int(0), Val::int(2));

        let constraint = Count::new(vec![v1, v2], target_var, count_var);
        let boxed: Box<dyn Prune> = Box::new(constraint);
        let shared = std::rc::Rc::new(boxed);

        let mut events = Vec::new();
        let mut ctx = Context::new(&mut vars, &mut events);
        println!("Calling prune through trait object...");
        let outcome = shared.as_ref().prune(&mut ctx);
        println!("Trait object prune result: {:?}", outcome.is_some());
        assert!(outcome.is_some(), "Count constraint should work through trait object");
    }
}