pumpkin_core/branching/value_selection/random_splitter.rs

1use crate::branching::value_selection::ValueSelector;
2use crate::branching::BrancherEvent;
3use crate::branching::SelectionContext;
4use crate::engine::predicates::predicate::Predicate;
5use crate::engine::variables::DomainId;
6use crate::predicate;
7
/// A [`ValueSelector`] which splits the domain at a random point between the lower-bound and
/// the upper-bound (disregarding holes), randomly selecting whether to exclude the lower-half or
/// the upper-half.
///
/// If the randomly chosen point coincides with the lower-bound or the upper-bound, no split is
/// possible on that side, and the variable is instead assigned to that bound.
#[derive(Debug, Clone, Copy)]
pub struct RandomSplitter;
13
14impl ValueSelector<DomainId> for RandomSplitter {
15    fn select_value(
16        &mut self,
17        context: &mut SelectionContext,
18        decision_variable: DomainId,
19    ) -> Predicate {
20        // Randomly generate a value within the lower-bound and upper-bound
21        let lb = context.lower_bound(decision_variable);
22        let ub = context.upper_bound(decision_variable);
23        let bound = context.random().generate_i32_in_range(lb, ub);
24
25        // We need to handle two special cases:
26        //
27        // 1. If the bound is equal to the lower-bound then we need to assign it to this bound since
28        //    [x >= lb] is currently true
29        // 2. If the bound is equal to the upper-bound then we need to assign it to this bound since
30        //    [x <= ub] is currentl true
31        if bound == context.lower_bound(decision_variable) {
32            return predicate!(decision_variable <= bound);
33        } else if bound == context.upper_bound(decision_variable) {
34            return predicate!(decision_variable >= bound);
35        }
36
37        // Then randomly determine how to split the domain
38        if context.random().generate_bool(0.5) {
39            predicate!(decision_variable >= bound)
40        } else {
41            predicate!(decision_variable <= bound)
42        }
43    }
44
45    fn is_restart_pointless(&mut self) -> bool {
46        false
47    }
48
49    fn subscribe_to_events(&self) -> Vec<BrancherEvent> {
50        vec![]
51    }
52}
53
#[cfg(test)]
mod tests {
    use crate::basic_types::tests::TestRandom;
    use crate::branching::value_selection::RandomSplitter;
    use crate::branching::value_selection::ValueSelector;
    use crate::branching::SelectionContext;
    use crate::predicate;

    /// Interior split point with a `true` coin flip: the lower half is excluded.
    #[test]
    fn test_returns_correct_literal() {
        let (assignments, _) = SelectionContext::create_for_testing(vec![(0, 10)]);
        let mut test_random = TestRandom {
            integers: vec![2],
            bools: vec![true],
            ..Default::default()
        };
        let mut context = SelectionContext::new(&assignments, &mut test_random);
        let domain_ids = context.get_domains().collect::<Vec<_>>();

        let mut selector = RandomSplitter;

        let selected_predicate = selector.select_value(&mut context, domain_ids[0]);

        assert_eq!(selected_predicate, predicate!(domain_ids[0] >= 2));
    }

    /// Interior split point with a `false` coin flip: the upper half is excluded.
    #[test]
    fn test_false_coin_excludes_upper_half() {
        let (assignments, _) = SelectionContext::create_for_testing(vec![(0, 10)]);
        let mut test_random = TestRandom {
            integers: vec![2],
            bools: vec![false],
            ..Default::default()
        };
        let mut context = SelectionContext::new(&assignments, &mut test_random);
        let domain_ids = context.get_domains().collect::<Vec<_>>();

        let mut selector = RandomSplitter;

        let selected_predicate = selector.select_value(&mut context, domain_ids[0]);

        assert_eq!(selected_predicate, predicate!(domain_ids[0] <= 2));
    }

    /// A split point equal to the lower-bound assigns the variable to the lower-bound.
    /// No bool is scripted: this path must not consult the coin flip.
    #[test]
    fn test_split_on_lower_bound_assigns_lower_bound() {
        let (assignments, _) = SelectionContext::create_for_testing(vec![(0, 10)]);
        let mut test_random = TestRandom {
            integers: vec![0],
            ..Default::default()
        };
        let mut context = SelectionContext::new(&assignments, &mut test_random);
        let domain_ids = context.get_domains().collect::<Vec<_>>();

        let mut selector = RandomSplitter;

        let selected_predicate = selector.select_value(&mut context, domain_ids[0]);

        assert_eq!(selected_predicate, predicate!(domain_ids[0] <= 0));
    }
}