1use std::cmp::max;
2use std::cmp::min;
3use std::marker::PhantomData;
4
5use pumpkin_checking::AtomicConstraint;
6use pumpkin_checking::CheckerVariable;
7use pumpkin_checking::InferenceChecker;
8use pumpkin_checking::IntExt;
9use pumpkin_checking::VariableState;
10use pumpkin_core::containers::KeyedVec;
11use pumpkin_core::containers::StorageKey;
12use pumpkin_core::propagation::LocalId;
13
14use crate::disjunctive::ArgDisjunctiveTask;
15use crate::disjunctive::disjunctive_task::DisjunctiveTask;
16use crate::disjunctive::theta_lambda_tree::Node;
17
/// Checker for inferences made by disjunctive edge finding.
///
/// Without a consequent it checks an overload conflict over the bounded tasks; with a
/// consequent it checks an edge-finding propagation on the start time of one of the tasks.
#[derive(Clone, Debug)]
pub struct DisjunctiveEdgeFindingChecker<Var> {
    /// The tasks that may not overlap in time (each with a start-time variable and a fixed
    /// processing time).
    pub tasks: Box<[ArgDisjunctiveTask<Var>]>,
}
22
23impl<Var, Atomic> InferenceChecker<Atomic> for DisjunctiveEdgeFindingChecker<Var>
24where
25 Var: CheckerVariable<Atomic>,
26 Atomic: AtomicConstraint,
27{
28 fn check(
29 &self,
30 state: VariableState<Atomic>,
31 _premises: &[Atomic],
32 consequent: Option<&Atomic>,
33 ) -> bool {
34 let mut lb_interval = i32::MAX;
43 let mut ub_interval = i32::MIN;
44 let mut p = 0;
45 let mut propagating_task = None;
46 let mut theta = Vec::new();
47
48 for task in self.tasks.iter() {
50 if task.start_time.induced_lower_bound(&state) != IntExt::NegativeInf
55 && task.start_time.induced_upper_bound(&state) != IntExt::PositiveInf
56 {
57 let est_task: i32 = task
59 .start_time
60 .induced_lower_bound(&state)
61 .try_into()
62 .unwrap();
63 let lst_task =
64 <IntExt as TryInto<i32>>::try_into(task.start_time.induced_upper_bound(&state))
65 .unwrap();
66
67 let is_propagating_task = if let Some(consequent) = consequent {
68 task.start_time.does_atomic_constrain_self(consequent)
69 } else {
70 false
71 };
72 if !is_propagating_task {
73 theta.push(task.clone());
74 p += task.processing_time;
75 lb_interval = lb_interval.min(est_task);
76 ub_interval = ub_interval.max(lst_task + task.processing_time);
77 } else {
78 propagating_task = Some(task.clone());
79 }
80 }
81 }
82
83 if consequent.is_some() {
84 let propagating_task = propagating_task
85 .expect("If there is a consequent then there should be a propagating task");
86
87 let est_task = propagating_task
88 .start_time
89 .induced_lower_bound(&state)
90 .try_into()
91 .unwrap();
92
93 let mut theta_lambda_tree = CheckerThetaLambdaTree::new(
94 &theta
95 .iter()
96 .enumerate()
97 .map(|(index, task)| DisjunctiveTask {
98 start_time: task.start_time.clone(),
99 processing_time: task.processing_time,
100 id: LocalId::from(index as u32),
101 })
102 .collect::<Vec<_>>(),
103 );
104 theta_lambda_tree.update(&state);
105 for (index, task) in theta.iter().enumerate() {
106 theta_lambda_tree.add_to_theta(
107 &DisjunctiveTask {
108 start_time: task.start_time.clone(),
109 processing_time: task.processing_time,
110 id: LocalId::from(index as u32),
111 },
112 &state,
113 );
114 }
115
116 min(est_task, lb_interval) + p + propagating_task.processing_time > ub_interval
117 && theta_lambda_tree.ect() > propagating_task.start_time.induced_upper_bound(&state)
118 } else {
119 p > (ub_interval - lb_interval)
121 }
122 }
123}
124
#[cfg(test)]
mod tests {
    use pumpkin_checking::Comparison;
    use pumpkin_checking::TestAtomic;
    use pumpkin_checking::VariableState;

    use super::*;

    // Shorthand for a `TestAtomic` literal, e.g. `atomic!("x1", LessEqual, 7)`.
    macro_rules! atomic {
        ($name:expr, $comparison:ident, $value:expr) => {
            TestAtomic {
                name: $name,
                comparison: Comparison::$comparison,
                value: $value,
            }
        };
    }

    /// Premises shared by the propagation tests: `x1 ∈ [0, 7]`, `x2 ∈ [5, 6]` and `x3 ≥ 0`.
    fn propagation_premises() -> [TestAtomic; 5] {
        [
            atomic!("x1", GreaterEqual, 0),
            atomic!("x1", LessEqual, 7),
            atomic!("x2", GreaterEqual, 5),
            atomic!("x2", LessEqual, 6),
            atomic!("x3", GreaterEqual, 0),
        ]
    }

    /// Checker over tasks `x1` (duration 2), `x2` (duration 3) and `x3` (duration 5).
    fn three_task_checker() -> DisjunctiveEdgeFindingChecker<&'static str> {
        DisjunctiveEdgeFindingChecker {
            tasks: vec![
                ArgDisjunctiveTask {
                    start_time: "x1",
                    processing_time: 2,
                },
                ArgDisjunctiveTask {
                    start_time: "x2",
                    processing_time: 3,
                },
                ArgDisjunctiveTask {
                    start_time: "x3",
                    processing_time: 5,
                },
            ]
            .into(),
        }
    }

    /// Checker over tasks `x1` (duration 2) and `x2` (duration 3).
    fn two_task_checker() -> DisjunctiveEdgeFindingChecker<&'static str> {
        DisjunctiveEdgeFindingChecker {
            tasks: vec![
                ArgDisjunctiveTask {
                    start_time: "x1",
                    processing_time: 2,
                },
                ArgDisjunctiveTask {
                    start_time: "x2",
                    processing_time: 3,
                },
            ]
            .into(),
        }
    }

    #[test]
    fn test_simple_propagation() {
        let premises = propagation_premises();
        let consequent = Some(atomic!("x3", GreaterEqual, 8));
        let state = VariableState::prepare_for_conflict_check(premises, consequent)
            .expect("no conflicting atomics");

        assert!(three_task_checker().check(state, &premises, consequent.as_ref()));
    }

    #[test]
    fn test_conflict() {
        let premises = [
            atomic!("x1", GreaterEqual, 0),
            atomic!("x1", LessEqual, 1),
            atomic!("x2", GreaterEqual, 0),
            atomic!("x2", LessEqual, 1),
        ];
        let state = VariableState::prepare_for_conflict_check(premises, None)
            .expect("no conflicting atomics");

        assert!(two_task_checker().check(state, &premises, None));
    }

    #[test]
    fn test_simple_propagation_not_accepted() {
        let premises = propagation_premises();
        let consequent = Some(atomic!("x3", GreaterEqual, 9));
        let state = VariableState::prepare_for_conflict_check(premises, consequent)
            .expect("no conflicting atomics");

        assert!(!three_task_checker().check(state, &premises, consequent.as_ref()));
    }

    #[test]
    fn test_conflict_not_accepted() {
        let premises = [
            atomic!("x1", GreaterEqual, 0),
            atomic!("x1", LessEqual, 1),
            atomic!("x2", GreaterEqual, 0),
            atomic!("x2", LessEqual, 2),
        ];
        let state = VariableState::prepare_for_conflict_check(premises, None)
            .expect("no conflicting atomics");

        assert!(!two_task_checker().check(state, &premises, None));
    }
}
340
/// Theta-lambda tree used by the edge-finding checker to compute the earliest completion
/// time (ECT) of a set of tasks from the bounds in a [`VariableState`].
///
/// The nodes form an implicit complete binary tree stored in a flat vector (the children of
/// node `i` are at `2i + 1` and `2i + 2`); the leaves correspond to the tasks in order of the
/// lower bound of their start time.
#[derive(Debug, Clone)]
pub(super) struct CheckerThetaLambdaTree<Var, Atomic> {
    /// Flat storage of the tree; the root (index 0) holds the aggregate over all leaves.
    pub(super) nodes: Vec<Node>,
    /// Maps a task's [`LocalId`] to its position in `sorted_tasks`, i.e. to its leaf offset.
    mapping: KeyedVec<LocalId, usize>,
    /// Number of internal (non-leaf) nodes: one less than the number of leaves, which is the
    /// smallest power of two that is at least the number of tasks (and at least one).
    number_of_internal_nodes: usize,
    /// The tasks, re-sorted by the lower bound of their start time on every `update`.
    sorted_tasks: Vec<DisjunctiveTask<Var>>,
    /// Ties the `Atomic` type parameter to the tree without storing any atomic.
    phantom_data: PhantomData<Atomic>,
}
356
357impl<Atomic: AtomicConstraint, Var: CheckerVariable<Atomic>> CheckerThetaLambdaTree<Var, Atomic> {
358 pub(super) fn new(tasks: &[DisjunctiveTask<Var>]) -> Self {
362 let mut number_of_internal_nodes = 1;
364 while number_of_internal_nodes < tasks.len() {
365 number_of_internal_nodes <<= 1;
366 }
367
368 CheckerThetaLambdaTree {
369 nodes: Default::default(),
370 mapping: KeyedVec::default(),
371 number_of_internal_nodes: number_of_internal_nodes - 1,
372 sorted_tasks: tasks.to_vec(),
373 phantom_data: PhantomData,
374 }
375 }
376
377 pub(super) fn update(&mut self, context: &VariableState<Atomic>) {
381 self.sorted_tasks
383 .sort_by_key(|task| task.start_time.induced_lower_bound(context));
384
385 self.mapping.clear();
388 for (index, task) in self.sorted_tasks.iter().enumerate() {
389 while self.mapping.len() <= task.id.index() {
390 let _ = self.mapping.push(usize::MAX);
391 }
392 self.mapping[task.id] = index;
393 }
394
395 self.nodes.clear();
397 for _ in 0..=2 * self.number_of_internal_nodes {
398 self.nodes.push(Node::empty())
399 }
400 }
401
402 pub(super) fn ect(&self) -> i32 {
404 assert!(!self.nodes.is_empty());
405 self.nodes[0].ect
406 }
407
408 pub(super) fn add_to_theta(
410 &mut self,
411 task: &DisjunctiveTask<Var>,
412 context: &VariableState<Atomic>,
413 ) {
414 let position = self.nodes.len() / 2 + self.mapping[task.id];
416 let ect = task.start_time.induced_lower_bound(context) + task.processing_time;
417
418 self.nodes[position] = Node::new_white_node(
419 ect.try_into().expect("Should have bounds"),
420 task.processing_time,
421 );
422 self.upheap(position)
423 }
424
425 fn get_left_child_index(index: usize) -> usize {
427 2 * index + 1
428 }
429
430 fn get_right_child_index(index: usize) -> usize {
432 2 * index + 2
433 }
434
435 fn get_parent(index: usize) -> usize {
437 assert!(index > 0);
438 (index - 1) / 2
439 }
440
441 pub(super) fn upheap(&mut self, mut index: usize) {
443 while index != 0 {
444 let parent = Self::get_parent(index);
445 let left_child_of_parent = Self::get_left_child_index(parent);
446 let right_child_of_parent = Self::get_right_child_index(parent);
447 assert!(left_child_of_parent == index || right_child_of_parent == index);
448
449 self.nodes[parent].sum_of_processing_times = self.nodes[left_child_of_parent]
452 .sum_of_processing_times
453 + self.nodes[right_child_of_parent].sum_of_processing_times;
454
455 let ect_left = self.nodes[left_child_of_parent].ect
459 + self.nodes[right_child_of_parent].sum_of_processing_times;
460 self.nodes[parent].ect = max(self.nodes[right_child_of_parent].ect, ect_left);
461
462 let sum_of_processing_times_left_child_lambda = self.nodes[left_child_of_parent]
468 .sum_of_processing_times_bar
469 + self.nodes[right_child_of_parent].sum_of_processing_times;
470 let sum_of_processing_times_right_child_lambda = self.nodes[left_child_of_parent]
471 .sum_of_processing_times
472 + self.nodes[right_child_of_parent].sum_of_processing_times_bar;
473 self.nodes[parent].sum_of_processing_times_bar = max(
474 sum_of_processing_times_left_child_lambda,
475 sum_of_processing_times_right_child_lambda,
476 );
477
478 let ect_right_child_lambda = self.nodes[left_child_of_parent].ect
485 + self.nodes[right_child_of_parent].sum_of_processing_times_bar;
486 let ect_left_child_lambda = self.nodes[left_child_of_parent].ect_bar
487 + self.nodes[right_child_of_parent].sum_of_processing_times;
488 self.nodes[parent].ect_bar = max(
489 self.nodes[right_child_of_parent].ect_bar,
490 max(ect_right_child_lambda, ect_left_child_lambda),
491 );
492
493 index = parent;
494 }
495 }
496}