Skip to main content

pumpkin_propagators/propagators/disjunctive/
checker.rs

1use std::cmp::max;
2use std::cmp::min;
3use std::marker::PhantomData;
4
5use pumpkin_checking::AtomicConstraint;
6use pumpkin_checking::CheckerVariable;
7use pumpkin_checking::InferenceChecker;
8use pumpkin_checking::IntExt;
9use pumpkin_checking::VariableState;
10use pumpkin_core::containers::KeyedVec;
11use pumpkin_core::containers::StorageKey;
12use pumpkin_core::propagation::LocalId;
13
14use crate::disjunctive::ArgDisjunctiveTask;
15use crate::disjunctive::disjunctive_task::DisjunctiveTask;
16use crate::disjunctive::theta_lambda_tree::Node;
17
18#[derive(Clone, Debug)]
19pub struct DisjunctiveEdgeFindingChecker<Var> {
20    pub tasks: Box<[ArgDisjunctiveTask<Var>]>,
21}
22
23impl<Var, Atomic> InferenceChecker<Atomic> for DisjunctiveEdgeFindingChecker<Var>
24where
25    Var: CheckerVariable<Atomic>,
26    Atomic: AtomicConstraint,
27{
28    fn check(
29        &self,
30        state: VariableState<Atomic>,
31        _premises: &[Atomic],
32        consequent: Option<&Atomic>,
33    ) -> bool {
34        // Recall the following:
35        // - For conflict detection, the explanation represents a set omega with the following
36        //   property: `p_omega > lct_omega - est_omega`.
37        //
38        //   We simply need to check whether the interval [est_omega, lct_omega] is overloaded
39        // - For propagation, the explanation represents a set omega (and omega') such that the
40        //   following holds: `min(est_i, est_omega) + p_omega + p_i > lct_omega -> [s_i >=
41        //   ect_omega]`.
42        let mut lb_interval = i32::MAX;
43        let mut ub_interval = i32::MIN;
44        let mut p = 0;
45        let mut propagating_task = None;
46        let mut theta = Vec::new();
47
48        // We go over all of the tasks
49        for task in self.tasks.iter() {
50            // Only if they are present in the explanation, do we actually process them
51            // - For tasks in omega, both bounds should be present to define the interval
52            // - For the propagating task, the lower-bound should be present, and the negation of
53            //   the consequent ensures that an upper-bound is present
54            if task.start_time.induced_lower_bound(&state) != IntExt::NegativeInf
55                && task.start_time.induced_upper_bound(&state) != IntExt::PositiveInf
56            {
57                // Now we calculate the durations of tasks
58                let est_task: i32 = task
59                    .start_time
60                    .induced_lower_bound(&state)
61                    .try_into()
62                    .unwrap();
63                let lst_task =
64                    <IntExt as TryInto<i32>>::try_into(task.start_time.induced_upper_bound(&state))
65                        .unwrap();
66
67                let is_propagating_task = if let Some(consequent) = consequent {
68                    task.start_time.does_atomic_constrain_self(consequent)
69                } else {
70                    false
71                };
72                if !is_propagating_task {
73                    theta.push(task.clone());
74                    p += task.processing_time;
75                    lb_interval = lb_interval.min(est_task);
76                    ub_interval = ub_interval.max(lst_task + task.processing_time);
77                } else {
78                    propagating_task = Some(task.clone());
79                }
80            }
81        }
82
83        if consequent.is_some() {
84            let propagating_task = propagating_task
85                .expect("If there is a consequent then there should be a propagating task");
86
87            let est_task = propagating_task
88                .start_time
89                .induced_lower_bound(&state)
90                .try_into()
91                .unwrap();
92
93            let mut theta_lambda_tree = CheckerThetaLambdaTree::new(
94                &theta
95                    .iter()
96                    .enumerate()
97                    .map(|(index, task)| DisjunctiveTask {
98                        start_time: task.start_time.clone(),
99                        processing_time: task.processing_time,
100                        id: LocalId::from(index as u32),
101                    })
102                    .collect::<Vec<_>>(),
103            );
104            theta_lambda_tree.update(&state);
105            for (index, task) in theta.iter().enumerate() {
106                theta_lambda_tree.add_to_theta(
107                    &DisjunctiveTask {
108                        start_time: task.start_time.clone(),
109                        processing_time: task.processing_time,
110                        id: LocalId::from(index as u32),
111                    },
112                    &state,
113                );
114            }
115
116            min(est_task, lb_interval) + p + propagating_task.processing_time > ub_interval
117                && theta_lambda_tree.ect() > propagating_task.start_time.induced_upper_bound(&state)
118        } else {
119            // We simply check whether the interval is overloaded
120            p > (ub_interval - lb_interval)
121        }
122    }
123}
124
#[cfg(test)]
mod tests {
    use pumpkin_checking::TestAtomic;
    use pumpkin_checking::VariableState;

    use super::*;

    /// Shorthand for a bound on a named test variable: `atomic!("x" >= 3)` or
    /// `atomic!("x" <= 3)`.
    macro_rules! atomic {
        ($name:literal >= $value:literal) => {
            TestAtomic {
                name: $name,
                comparison: pumpkin_checking::Comparison::GreaterEqual,
                value: $value,
            }
        };
        ($name:literal <= $value:literal) => {
            TestAtomic {
                name: $name,
                comparison: pumpkin_checking::Comparison::LessEqual,
                value: $value,
            }
        };
    }

    /// Builds a checker whose tasks are given as `(start time variable, processing time)` pairs.
    fn checker_for(tasks: &[(&'static str, i32)]) -> DisjunctiveEdgeFindingChecker<&'static str> {
        DisjunctiveEdgeFindingChecker {
            tasks: tasks
                .iter()
                .map(|&(start_time, processing_time)| ArgDisjunctiveTask {
                    start_time,
                    processing_time,
                })
                .collect(),
        }
    }

    /// Prepares the state from `premises` and `consequent`, then runs the checker over `tasks`.
    fn verdict<const N: usize>(
        premises: [TestAtomic; N],
        consequent: Option<TestAtomic>,
        tasks: &[(&'static str, i32)],
    ) -> bool {
        let state = VariableState::prepare_for_conflict_check(premises, consequent)
            .expect("no conflicting atomics");
        checker_for(tasks).check(state, &premises, consequent.as_ref())
    }

    #[test]
    fn test_simple_propagation() {
        let premises = [
            atomic!("x1" >= 0),
            atomic!("x1" <= 7),
            atomic!("x2" >= 5),
            atomic!("x2" <= 6),
            atomic!("x3" >= 0),
        ];
        let consequent = Some(atomic!("x3" >= 8));

        assert!(verdict(
            premises,
            consequent,
            &[("x1", 2), ("x2", 3), ("x3", 5)]
        ));
    }

    #[test]
    fn test_conflict() {
        let premises = [
            atomic!("x1" >= 0),
            atomic!("x1" <= 1),
            atomic!("x2" >= 0),
            atomic!("x2" <= 1),
        ];

        assert!(verdict(premises, None, &[("x1", 2), ("x2", 3)]));
    }

    #[test]
    fn test_simple_propagation_not_accepted() {
        let premises = [
            atomic!("x1" >= 0),
            atomic!("x1" <= 7),
            atomic!("x2" >= 5),
            atomic!("x2" <= 6),
            atomic!("x3" >= 0),
        ];
        let consequent = Some(atomic!("x3" >= 9));

        assert!(!verdict(
            premises,
            consequent,
            &[("x1", 2), ("x2", 3), ("x3", 5)]
        ));
    }

    #[test]
    fn test_conflict_not_accepted() {
        let premises = [
            atomic!("x1" >= 0),
            atomic!("x1" <= 1),
            atomic!("x2" >= 0),
            atomic!("x2" <= 2),
        ];

        assert!(!verdict(premises, None, &[("x1", 2), ("x2", 3)]));
    }
}
340
341#[derive(Debug, Clone)]
342pub(super) struct CheckerThetaLambdaTree<Var, Atomic> {
343    pub(super) nodes: Vec<Node>,
344    /// Then we keep track of a mapping from the [`LocalId`] to its position in the tree since the
345    /// methods take as input tasks with [`LocalId`]s.
346    mapping: KeyedVec<LocalId, usize>,
347    /// The number of internal nodes in the tree; used to calculate the leaf node index based on
348    /// the index in the tree
349    number_of_internal_nodes: usize,
350    /// The tasks which are stored in the leaves of the tree.
351    ///
352    /// These tasks are sorted based on non-decreasing start time.
353    sorted_tasks: Vec<DisjunctiveTask<Var>>,
354    phantom_data: PhantomData<Atomic>,
355}
356
357impl<Atomic: AtomicConstraint, Var: CheckerVariable<Atomic>> CheckerThetaLambdaTree<Var, Atomic> {
358    /// Initialises the theta-lambda tree.
359    ///
360    /// Note that [`Self::update`] should be called to actually create the tree itself.
361    pub(super) fn new(tasks: &[DisjunctiveTask<Var>]) -> Self {
362        // Calculate the number of internal nodes which are required to create the binary tree
363        let mut number_of_internal_nodes = 1;
364        while number_of_internal_nodes < tasks.len() {
365            number_of_internal_nodes <<= 1;
366        }
367
368        CheckerThetaLambdaTree {
369            nodes: Default::default(),
370            mapping: KeyedVec::default(),
371            number_of_internal_nodes: number_of_internal_nodes - 1,
372            sorted_tasks: tasks.to_vec(),
373            phantom_data: PhantomData,
374        }
375    }
376
377    /// Update the theta-lambda tree based on the provided `context`.
378    ///
379    /// It resets theta and lambda to be the empty set.
380    pub(super) fn update(&mut self, context: &VariableState<Atomic>) {
381        // First we sort the tasks by lower-bound/earliest start time.
382        self.sorted_tasks
383            .sort_by_key(|task| task.start_time.induced_lower_bound(context));
384
385        // Then we keep track of a mapping from the [`LocalId`] to its position in the tree and a
386        // reverse mapping
387        self.mapping.clear();
388        for (index, task) in self.sorted_tasks.iter().enumerate() {
389            while self.mapping.len() <= task.id.index() {
390                let _ = self.mapping.push(usize::MAX);
391            }
392            self.mapping[task.id] = index;
393        }
394
395        // Finally, we reset the entire tree to be empty
396        self.nodes.clear();
397        for _ in 0..=2 * self.number_of_internal_nodes {
398            self.nodes.push(Node::empty())
399        }
400    }
401
402    /// Returns the earliest completion time of Theta
403    pub(super) fn ect(&self) -> i32 {
404        assert!(!self.nodes.is_empty());
405        self.nodes[0].ect
406    }
407
408    /// Add the provided task to Theta
409    pub(super) fn add_to_theta(
410        &mut self,
411        task: &DisjunctiveTask<Var>,
412        context: &VariableState<Atomic>,
413    ) {
414        // We need to find the leaf node index; note that there are |nodes| / 2 leaves
415        let position = self.nodes.len() / 2 + self.mapping[task.id];
416        let ect = task.start_time.induced_lower_bound(context) + task.processing_time;
417
418        self.nodes[position] = Node::new_white_node(
419            ect.try_into().expect("Should have bounds"),
420            task.processing_time,
421        );
422        self.upheap(position)
423    }
424
425    /// Returns the index of the left child of the provided index
426    fn get_left_child_index(index: usize) -> usize {
427        2 * index + 1
428    }
429
430    /// Returns the index of the right child of the provided index
431    fn get_right_child_index(index: usize) -> usize {
432        2 * index + 2
433    }
434
435    /// Returns the index of the parent of the provided index
436    fn get_parent(index: usize) -> usize {
437        assert!(index > 0);
438        (index - 1) / 2
439    }
440
441    /// Calculate the new values for the ancestors of the provided index
442    pub(super) fn upheap(&mut self, mut index: usize) {
443        while index != 0 {
444            let parent = Self::get_parent(index);
445            let left_child_of_parent = Self::get_left_child_index(parent);
446            let right_child_of_parent = Self::get_right_child_index(parent);
447            assert!(left_child_of_parent == index || right_child_of_parent == index);
448
449            // The sum of processing times is the sum of processing times in the left child + the
450            // sum of processing times in right child
451            self.nodes[parent].sum_of_processing_times = self.nodes[left_child_of_parent]
452                .sum_of_processing_times
453                + self.nodes[right_child_of_parent].sum_of_processing_times;
454
455            // The ECT is either the ECT of the left child node + the processing times of the right
456            // child or it is the ECT of the right child (we do not know whether the processing
457            // times of the left child influence the processing times of the right child)
458            let ect_left = self.nodes[left_child_of_parent].ect
459                + self.nodes[right_child_of_parent].sum_of_processing_times;
460            self.nodes[parent].ect = max(self.nodes[right_child_of_parent].ect, ect_left);
461
462            // The sum of processing times (including one element of lambda) is either:
463            // 1) The sum of processing times of the right child + the sum of processing times of
464            //    the left child including one element of lambda
465            // 2) The sum of processing times of the left child + the sum of processing times of the
466            //    right child include one element of lambda
467            let sum_of_processing_times_left_child_lambda = self.nodes[left_child_of_parent]
468                .sum_of_processing_times_bar
469                + self.nodes[right_child_of_parent].sum_of_processing_times;
470            let sum_of_processing_times_right_child_lambda = self.nodes[left_child_of_parent]
471                .sum_of_processing_times
472                + self.nodes[right_child_of_parent].sum_of_processing_times_bar;
473            self.nodes[parent].sum_of_processing_times_bar = max(
474                sum_of_processing_times_left_child_lambda,
475                sum_of_processing_times_right_child_lambda,
476            );
477
478            // The earliest completion time (including one element of lambda) is either:
479            // 1) The earliest completion time including one element of lambda from the right child
480            // 2) The earliest completion time of the right child + the sum of processing times
481            //    including one element of lambda of the right child
482            // 2) The earliest completion time of the left child + the sum of processing times
483            //    including one element of lambda of the left child
484            let ect_right_child_lambda = self.nodes[left_child_of_parent].ect
485                + self.nodes[right_child_of_parent].sum_of_processing_times_bar;
486            let ect_left_child_lambda = self.nodes[left_child_of_parent].ect_bar
487                + self.nodes[right_child_of_parent].sum_of_processing_times;
488            self.nodes[parent].ect_bar = max(
489                self.nodes[right_child_of_parent].ect_bar,
490                max(ect_right_child_lambda, ect_left_child_lambda),
491            );
492
493            index = parent;
494        }
495    }
496}