pumpkin_core/branching/value_selection/
random_splitter.rs1use crate::branching::value_selection::ValueSelector;
2use crate::branching::BrancherEvent;
3use crate::branching::SelectionContext;
4use crate::engine::predicates::predicate::Predicate;
5use crate::engine::variables::DomainId;
6use crate::predicate;
7
/// A [`ValueSelector`] which splits the domain of a [`DomainId`] at a randomly chosen value.
///
/// A value `v` is drawn from `[lb, ub]` of the decision variable and either `[x <= v]` or
/// `[x >= v]` is returned. If `v` coincides with one of the current bounds, the direction is
/// forced so that the returned predicate is not already true (`[x <= lb]` at the lower bound,
/// `[x >= ub]` at the upper bound); otherwise the direction is picked at random.
#[derive(Debug, Clone, Copy, Default)]
pub struct RandomSplitter;
13
14impl ValueSelector<DomainId> for RandomSplitter {
15 fn select_value(
16 &mut self,
17 context: &mut SelectionContext,
18 decision_variable: DomainId,
19 ) -> Predicate {
20 let range =
22 context.lower_bound(decision_variable)..context.upper_bound(decision_variable) + 1;
23 let bound = context.random().generate_i32_in_range(range);
24
25 if bound == context.lower_bound(decision_variable) {
32 return predicate!(decision_variable <= bound);
33 } else if bound == context.upper_bound(decision_variable) {
34 return predicate!(decision_variable >= bound);
35 }
36
37 if context.random().generate_bool(0.5) {
39 predicate!(decision_variable >= bound)
40 } else {
41 predicate!(decision_variable <= bound)
42 }
43 }
44
45 fn is_restart_pointless(&mut self) -> bool {
46 false
47 }
48
49 fn subscribe_to_events(&self) -> Vec<BrancherEvent> {
50 vec![]
51 }
52}
53
#[cfg(test)]
mod tests {
    use crate::basic_types::tests::TestRandom;
    use crate::branching::value_selection::RandomSplitter;
    use crate::branching::value_selection::ValueSelector;
    use crate::branching::SelectionContext;
    use crate::engine::predicates::predicate::Predicate;
    use crate::engine::variables::DomainId;
    use crate::predicate;

    /// Runs [`RandomSplitter`] on a single variable with domain `[0, 10]` using the provided
    /// scripted randomness, returning the variable together with the selected predicate.
    fn split_with(mut test_random: TestRandom) -> (DomainId, Predicate) {
        let (assignments, _) = SelectionContext::create_for_testing(vec![(0, 10)]);
        let mut context = SelectionContext::new(&assignments, &mut test_random);
        let domain_ids = context.get_domains().collect::<Vec<_>>();

        let mut selector = RandomSplitter;
        let selected_predicate = selector.select_value(&mut context, domain_ids[0]);

        (domain_ids[0], selected_predicate)
    }

    #[test]
    fn test_returns_correct_literal() {
        let (domain_id, selected_predicate) = split_with(TestRandom {
            integers: vec![2],
            bools: vec![true],
            ..Default::default()
        });

        assert_eq!(selected_predicate, predicate!(domain_id >= 2))
    }

    #[test]
    fn test_interior_value_with_false_coin_splits_downwards() {
        let (domain_id, selected_predicate) = split_with(TestRandom {
            integers: vec![2],
            bools: vec![false],
            ..Default::default()
        });

        assert_eq!(selected_predicate, predicate!(domain_id <= 2))
    }

    #[test]
    fn test_lower_bound_value_forces_downward_split() {
        // No coin flips are scripted: the boundary case must not consult `generate_bool`.
        let (domain_id, selected_predicate) = split_with(TestRandom {
            integers: vec![0],
            ..Default::default()
        });

        assert_eq!(selected_predicate, predicate!(domain_id <= 0))
    }

    #[test]
    fn test_upper_bound_value_forces_upward_split() {
        // No coin flips are scripted: the boundary case must not consult `generate_bool`.
        let (domain_id, selected_predicate) = split_with(TestRandom {
            integers: vec![10],
            ..Default::default()
        });

        assert_eq!(selected_predicate, predicate!(domain_id >= 10))
    }
}