Skip to main content

pumpkin_core/branching/variable_selection/random.rs

1use super::VariableSelector;
2use crate::branching::BrancherEvent;
3use crate::branching::SelectionContext;
4use crate::containers::SparseSet;
5use crate::containers::StorageKey;
6use crate::variables::DomainId;
7
8/// A [`VariableSelector`] which selects a random unfixed variable.
9#[derive(Debug)]
10pub struct RandomSelector {
11    variables: SparseSet<DomainId>,
12}
13
14impl RandomSelector {
15    pub fn new(variables: impl IntoIterator<Item = DomainId>) -> Self {
16        // Note the -1 due to the fact that the indices of the domain ids start at 1
17        Self {
18            variables: SparseSet::new(variables.into_iter().collect(), |element| {
19                element.index() - 1
20            }),
21        }
22    }
23
24    /// Add a domain to consideration in the variable selection.
25    pub fn add_domain(&mut self, domain: DomainId) {
26        self.variables.insert(domain);
27    }
28}
29
30impl VariableSelector<DomainId> for RandomSelector {
31    fn select_variable(&mut self, context: &mut SelectionContext) -> Option<DomainId> {
32        if self.variables.is_empty() {
33            return None;
34        }
35
36        let mut variable = *self.variables.get(
37            context
38                .random()
39                .generate_usize_in_range(0..self.variables.len()),
40        );
41
42        while context.is_integer_fixed(variable) {
43            self.variables.remove_temporarily(&variable);
44            if self.variables.is_empty() {
45                return None;
46            }
47
48            variable = *self.variables.get(
49                context
50                    .random()
51                    .generate_usize_in_range(0..self.variables.len()),
52            );
53        }
54
55        Some(variable)
56    }
57
58    fn on_unassign_integer(&mut self, variable: DomainId, _value: i32) {
59        self.variables.insert(variable);
60    }
61
62    fn is_restart_pointless(&mut self) -> bool {
63        false
64    }
65
66    fn subscribe_to_events(&self) -> Vec<BrancherEvent> {
67        vec![BrancherEvent::UnassignInteger]
68    }
69}
70
#[cfg(test)]
mod tests {
    use crate::basic_types::tests::TestRandom;
    use crate::branching::SelectionContext;
    use crate::branching::variable_selection::RandomSelector;
    use crate::branching::variable_selection::VariableSelector;
    use crate::predicate;

    /// When the first scripted draw lands on an unfixed variable, that variable is selected.
    #[test]
    fn test_selects_randomly() {
        let (assignments, _) = SelectionContext::create_for_testing(vec![(0, 10), (5, 20), (1, 3)]);
        let mut test_rng = TestRandom {
            usizes: vec![1],
            ..Default::default()
        };
        let integer_variables = assignments.get_domains().collect::<Vec<_>>();
        let mut strategy = RandomSelector::new(assignments.get_domains());

        let mut context = SelectionContext::new(&assignments, &mut test_rng);

        let selected = strategy.select_variable(&mut context);
        assert_eq!(selected, Some(integer_variables[1]));
    }

    /// A draw that lands on a fixed variable is rejected and another draw is made.
    #[test]
    fn test_selects_randomly_not_fixed() {
        // The variable with domain [5, 5] is fixed, so the first draw (1) must be rejected and
        // the next draw (0) is expected to yield the first variable.
        let (assignments, _) = SelectionContext::create_for_testing(vec![(0, 10), (5, 5), (1, 3)]);
        let mut test_rng = TestRandom {
            usizes: vec![1, 0],
            ..Default::default()
        };
        let integer_variables = assignments.get_domains().collect::<Vec<_>>();
        let mut strategy = RandomSelector::new(assignments.get_domains());

        let mut context = SelectionContext::new(&assignments, &mut test_rng);

        let selected = strategy.select_variable(&mut context);
        assert_eq!(selected, Some(integer_variables[0]));
    }

    /// With every variable fixed, the pool is drained and nothing is selected.
    #[test]
    fn test_select_nothing_if_all_fixed() {
        let (assignments, _) = SelectionContext::create_for_testing(vec![(0, 0), (5, 5), (1, 1)]);
        let mut test_rng = TestRandom {
            usizes: vec![1, 0, 0],
            ..Default::default()
        };
        let mut strategy = RandomSelector::new(assignments.get_domains());

        let mut context = SelectionContext::new(&assignments, &mut test_rng);

        let selected = strategy.select_variable(&mut context);
        assert!(selected.is_none());
    }

    /// A variable that becomes fixed is skipped, and is selectable again once it is unassigned.
    #[test]
    fn test_select_unfixed_variable_after_fixing() {
        let (mut assignments, mut notification_engine) =
            SelectionContext::create_for_testing(vec![(0, 0), (5, 7), (1, 1)]);
        let mut test_rng = TestRandom {
            usizes: vec![2, 0, 0, 0, 0],
            ..Default::default()
        };
        let integer_variables = assignments.get_domains().collect::<Vec<_>>();
        let mut strategy = RandomSelector::new(assignments.get_domains());

        // The scripted draws steer the selector past the two fixed variables to the only
        // unfixed one.
        {
            let mut context = SelectionContext::new(&assignments, &mut test_rng);

            let selected = strategy.select_variable(&mut context);
            assert_eq!(selected, Some(integer_variables[1]));
        }

        // Fix the remaining variable at a new checkpoint; nothing is left to select.
        assignments.new_checkpoint();
        let _ = assignments.post_predicate(
            predicate!(integer_variables[1] >= 7),
            None,
            &mut notification_engine,
        );

        {
            let mut context = SelectionContext::new(&assignments, &mut test_rng);

            let selected = strategy.select_variable(&mut context);
            assert!(selected.is_none());
        }

        // Backtrack to the root and report the unassignment; the variable is selectable again.
        let _ = assignments.synchronise(0, &mut notification_engine);
        strategy.on_unassign_integer(integer_variables[1], 7);
        let mut context = SelectionContext::new(&assignments, &mut test_rng);
        let selected = strategy.select_variable(&mut context);
        assert_eq!(selected, Some(integer_variables[1]));
    }
}