dypdl_heuristic_search/search_algorithm/data_structure/
transition_mutex.rs

1use super::successor_generator::SuccessorGenerator;
2use super::transition::TransitionWithId;
3use dypdl::expression::*;
4use dypdl::variable_type::Element;
5use dypdl::{Model, Transition, TransitionInterface};
6use rustc_hash::{FxHashMap, FxHashSet};
7use std::ops::Deref;
8
9#[derive(Default, PartialEq, Eq, Debug)]
10struct AffectedElements {
11    achieved: Vec<(usize, Element)>,
12    removed: Vec<(usize, Element)>,
13    arbitrary: Vec<usize>,
14}
15
16fn get_affected_elements(transition: &Transition) -> AffectedElements {
17    let mut achieved = Vec::default();
18    let mut removed = Vec::default();
19    let mut arbitrary = Vec::default();
20
21    for (var_id, expression) in &transition.effect.set_effects {
22        match expression {
23            SetExpression::SetElementOperation(
24                SetElementOperator::Add,
25                ElementExpression::Constant(element),
26                _,
27            ) => achieved.push((*var_id, *element)),
28            SetExpression::SetElementOperation(
29                SetElementOperator::Remove,
30                ElementExpression::Constant(element),
31                _,
32            ) => removed.push((*var_id, *element)),
33            _ => arbitrary.push(*var_id),
34        }
35    }
36
37    AffectedElements {
38        achieved,
39        removed,
40        arbitrary,
41    }
42}
43
44#[derive(Default, PartialEq, Eq, Debug)]
45struct RequiredElements {
46    positively: Vec<(usize, Element)>,
47    negatively: Vec<(usize, Element)>,
48}
49
50fn get_required_elements(transition: &Transition) -> RequiredElements {
51    let mut positively = Vec::default();
52    let mut negatively = Vec::default();
53
54    for condition in transition.get_preconditions() {
55        if let Condition::Set(condition) = condition {
56            match condition.as_ref() {
57                SetCondition::IsIn(
58                    ElementExpression::Constant(element),
59                    SetExpression::Reference(ReferenceExpression::Variable(var_id)),
60                ) => positively.push((*var_id, *element)),
61                SetCondition::IsIn(
62                    ElementExpression::Constant(element),
63                    SetExpression::Complement(expression),
64                ) => {
65                    if let SetExpression::Reference(ReferenceExpression::Variable(var_id)) =
66                        expression.as_ref()
67                    {
68                        negatively.push((*var_id, *element))
69                    }
70                }
71                _ => {}
72            }
73        } else if let Condition::Not(condition) = condition {
74            if let Condition::Set(condition) = condition.as_ref() {
75                match condition.as_ref() {
76                    SetCondition::IsIn(
77                        ElementExpression::Constant(element),
78                        SetExpression::Reference(ReferenceExpression::Variable(var_id)),
79                    ) => negatively.push((*var_id, *element)),
80                    SetCondition::IsIn(
81                        ElementExpression::Constant(element),
82                        SetExpression::Complement(expression),
83                    ) => {
84                        if let SetExpression::Reference(ReferenceExpression::Variable(var_id)) =
85                            expression.as_ref()
86                        {
87                            positively.push((*var_id, *element))
88                        }
89                    }
90                    _ => {}
91                }
92            }
93        }
94    }
95
96    RequiredElements {
97        positively,
98        negatively,
99    }
100}
101
/// Data structure storing pairs of transitions that must not happen before or after each other.
///
/// Transitions are identified by a `(forced, id)` pair, where `forced` tells whether the
/// transition is a forced transition and `id` is its transition id.
///
/// # Examples
///
/// ```
/// use dypdl::prelude::*;
/// use dypdl_heuristic_search::search_algorithm::data_structure::{
///     TransitionMutex,
/// };
/// use dypdl_heuristic_search::search_algorithm::SuccessorGenerator;
/// use std::rc::Rc;
///
/// let mut model = Model::default();
/// let object_type = model.add_object_type("object", 4).unwrap();
/// let set = model.create_set(object_type, &[0, 1, 2, 3]).unwrap();
/// let variable = model.add_set_variable("variable", object_type, set).unwrap();
///
/// let mut transition = Transition::new("remove 0");
/// transition.add_effect(variable, variable.remove(0)).unwrap();
/// model.add_forward_transition(transition.clone()).unwrap();
///
/// let mut transition = Transition::new("remove 1");
/// transition.add_effect(variable, variable.remove(1)).unwrap();
/// model.add_forward_forced_transition(transition.clone()).unwrap();
///
/// let mut transition = Transition::new("require 0");
/// transition.add_precondition(variable.contains(0));
/// model.add_forward_forced_transition(transition.clone()).unwrap();
///
/// let mut transition = Transition::new("require 1");
/// transition.add_precondition(variable.contains(1));
/// model.add_forward_transition(transition).unwrap();
///
/// let model = Rc::new(model);
/// let generator = SuccessorGenerator::<Transition>::from_model(model, false);
/// let transitions = generator
///     .transitions
///     .iter()
///     .chain(generator.forced_transitions.iter())
///     .map(|t| t.as_ref().clone())
///     .collect::<Vec<_>>();
/// let mutex = TransitionMutex::new(transitions);
///
/// assert_eq!(mutex.get_forbidden_before(false, 0), &[]);
/// assert_eq!(mutex.get_forbidden_before(true, 0), &[]);
/// assert_eq!(mutex.get_forbidden_before(true, 1), &[(false, 0)]);
/// assert_eq!(mutex.get_forbidden_before(false, 1), &[(true, 0)]);
///
/// assert_eq!(mutex.get_forbidden_after(false, 0), &[(true, 1)]);
/// assert_eq!(mutex.get_forbidden_after(true, 0), &[(false, 1)]);
/// assert_eq!(mutex.get_forbidden_after(true, 1), &[]);
/// assert_eq!(mutex.get_forbidden_after(false, 1), &[]);
///
/// let remove_0 = generator.transitions[0].clone();
/// let require_1 = generator.transitions[1].clone();
/// let prefix = &[(*remove_0).clone()];
/// let suffix = &[(*require_1).clone()];
/// let generator = mutex.filter_successor_generator(&generator, prefix, suffix);
/// assert_eq!(generator.transitions, vec![remove_0, require_1]);
/// assert_eq!(generator.forced_transitions, vec![]);
/// ```
#[derive(Debug, Default, PartialEq, Eq)]
pub struct TransitionMutex {
    /// Indexed by the id of a non-forced transition: transitions that must not happen before it.
    forbidden_before: Vec<Vec<(bool, usize)>>,
    /// Indexed by the id of a forced transition: transitions that must not happen before it.
    forced_forbidden_before: Vec<Vec<(bool, usize)>>,
    /// Indexed by the id of a non-forced transition: transitions that must not happen after it.
    forbidden_after: Vec<Vec<(bool, usize)>>,
    /// Indexed by the id of a forced transition: transitions that must not happen after it.
    forced_forbidden_after: Vec<Vec<(bool, usize)>>,
}
170
171impl TransitionMutex {
172    /// Get transitions that must not happen before the given transition.
173    ///
174    /// The first return value indicates whether it is a forced transition,
175    /// and the second return value is the transition id.
176    pub fn get_forbidden_before(&self, forced: bool, id: usize) -> &[(bool, usize)] {
177        if forced {
178            &self.forced_forbidden_before[id]
179        } else {
180            &self.forbidden_before[id]
181        }
182    }
183
184    /// Get transitions that must not happen after the given transition.
185    ///
186    /// The first return value indicates whether it is a forced transition,
187    /// and the second return value is the transition id.
188    pub fn get_forbidden_after(&self, forced: bool, id: usize) -> &[(bool, usize)] {
189        if forced {
190            &self.forced_forbidden_after[id]
191        } else {
192            &self.forbidden_after[id]
193        }
194    }
195
196    /// Create a successor generator filtering forbidden transitions by the given prefix ans suffix.
197    pub fn filter_successor_generator<T, U, R>(
198        &self,
199        generator: &SuccessorGenerator<T, U, R>,
200        prefix: &[TransitionWithId<T>],
201        suffix: &[TransitionWithId<T>],
202    ) -> SuccessorGenerator<T, U, R>
203    where
204        T: TransitionInterface,
205        U: Deref<Target = TransitionWithId<T>> + Clone + From<TransitionWithId<T>>,
206        R: Deref<Target = Model> + Clone,
207    {
208        let (forbidden_forced_ids, forbidden_ids): (Vec<_>, Vec<_>) = suffix
209            .iter()
210            .flat_map(|t| self.get_forbidden_before(t.forced, t.id).iter())
211            .chain(
212                prefix
213                    .iter()
214                    .flat_map(|t| self.get_forbidden_after(t.forced, t.id).iter()),
215            )
216            .copied()
217            .partition(|(forced, _)| *forced);
218        let forbidden_forced_ids =
219            FxHashSet::<usize>::from_iter(forbidden_forced_ids.into_iter().map(|(_, id)| id));
220        let forbidden_ids =
221            FxHashSet::<usize>::from_iter(forbidden_ids.into_iter().map(|(_, id)| id));
222
223        let forced_transitions = generator
224            .forced_transitions
225            .iter()
226            .enumerate()
227            .filter_map(|(id, t)| {
228                if forbidden_forced_ids.contains(&id) {
229                    None
230                } else {
231                    Some(t.clone())
232                }
233            })
234            .collect();
235
236        let transitions = generator
237            .transitions
238            .iter()
239            .enumerate()
240            .filter_map(|(id, t)| {
241                if forbidden_ids.contains(&id) {
242                    None
243                } else {
244                    Some(t.clone())
245                }
246            })
247            .collect();
248
249        SuccessorGenerator::new(
250            forced_transitions,
251            transitions,
252            generator.backward,
253            generator.model.clone(),
254        )
255    }
256
257    /// Create a new transition mutex from the given transitions.
258    pub fn new<T>(transitions: Vec<TransitionWithId<T>>) -> Self
259    where
260        T: TransitionInterface + Clone,
261        Transition: From<T>,
262    {
263        let len = transitions
264            .iter()
265            .filter_map(|t| if !t.forced { Some(t.id) } else { None })
266            .max()
267            .map_or(0, |id_max| id_max + 1);
268        let forced_len = transitions
269            .iter()
270            .filter_map(|t| if t.forced { Some(t.id) } else { None })
271            .max()
272            .map_or(0, |id_max| id_max + 1);
273
274        let mut achievers = FxHashMap::default();
275        let mut removers = FxHashMap::default();
276        let mut arbitrary_affected = FxHashSet::default();
277        let mut positively_conditioned = FxHashMap::default();
278        let mut negatively_conditioned = FxHashMap::default();
279
280        for t in transitions {
281            let id = t.id;
282            let forced = t.forced;
283            let transition = Transition::from(t.transition);
284            let affected = get_affected_elements(&transition);
285            extend_element_transitions_map(&mut achievers, &affected.achieved, forced, id);
286            extend_element_transitions_map(&mut removers, &affected.removed, forced, id);
287            arbitrary_affected.extend(affected.arbitrary.into_iter());
288            let required = get_required_elements(&transition);
289            extend_element_transitions_map(
290                &mut positively_conditioned,
291                &required.positively,
292                forced,
293                id,
294            );
295            extend_element_transitions_map(
296                &mut negatively_conditioned,
297                &required.negatively,
298                forced,
299                id,
300            );
301        }
302
303        let mut forbidden_before = vec![FxHashSet::default(); len];
304        let mut forced_forbidden_before = vec![FxHashSet::default(); forced_len];
305        let mut forbidden_after = vec![FxHashSet::default(); len];
306        let mut forced_forbidden_after = vec![FxHashSet::default(); forced_len];
307
308        // For each transition that positively requires an element.
309        for ((var_id, element), operator_ids) in positively_conditioned {
310            // If an element can be added, it does not matter.
311            if arbitrary_affected.contains(&var_id) || achievers.contains_key(&(var_id, element)) {
312                continue;
313            }
314
315            // A transition that removes the element must not happen before.
316            for (forced, id) in operator_ids {
317                if let Some(removers) = removers.get(&(var_id, element)) {
318                    for remover in removers {
319                        if forced {
320                            forced_forbidden_before[id].insert(*remover);
321                        } else {
322                            forbidden_before[id].insert(*remover);
323                        }
324
325                        if remover.0 {
326                            forced_forbidden_after[remover.1].insert((forced, id));
327                        } else {
328                            forbidden_after[remover.1].insert((forced, id));
329                        }
330                    }
331                }
332            }
333        }
334
335        // For each transition that negatively requires an element.
336        for ((var_id, element), operator_ids) in negatively_conditioned {
337            // If an element can be removed, it does not matter.
338            if arbitrary_affected.contains(&var_id) || removers.contains_key(&(var_id, element)) {
339                continue;
340            }
341
342            // A transition that adds the element must not happen before.
343            for (forced, id) in operator_ids {
344                if let Some(achievers) = achievers.get(&(var_id, element)) {
345                    for achiever in achievers {
346                        if forced {
347                            forced_forbidden_before[id].insert(*achiever);
348                        } else {
349                            forbidden_before[id].insert(*achiever);
350                        }
351
352                        if achiever.0 {
353                            forced_forbidden_after[achiever.1].insert((forced, id));
354                        } else {
355                            forbidden_after[achiever.1].insert((forced, id));
356                        }
357                    }
358                }
359            }
360        }
361
362        let forbidden_before = forbidden_before
363            .into_iter()
364            .map(|x| {
365                let mut v = Vec::from_iter(x);
366                v.sort();
367                v
368            })
369            .collect();
370        let forced_forbidden_before = forced_forbidden_before
371            .into_iter()
372            .map(|x| {
373                let mut v = Vec::from_iter(x);
374                v.sort();
375                v
376            })
377            .collect();
378        let forbidden_after = forbidden_after
379            .into_iter()
380            .map(|x| {
381                let mut v = Vec::from_iter(x);
382                v.sort();
383                v
384            })
385            .collect();
386        let forced_forbidden_after = forced_forbidden_after
387            .into_iter()
388            .map(|x| {
389                let mut v = Vec::from_iter(x);
390                v.sort();
391                v
392            })
393            .collect();
394
395        Self {
396            forbidden_before,
397            forced_forbidden_before,
398            forbidden_after,
399            forced_forbidden_after,
400        }
401    }
402}
403
404fn extend_element_transitions_map(
405    map: &mut FxHashMap<(usize, Element), Vec<(bool, usize)>>,
406    elements: &[(usize, Element)],
407    forced: bool,
408    id: usize,
409) {
410    for key in elements {
411        map.entry(*key)
412            .and_modify(|e| e.push((forced, id)))
413            .or_insert_with(|| vec![(forced, id)]);
414    }
415}
416
417#[cfg(test)]
418mod tests {
419    use super::*;
420    use dypdl::{Effect, GroundedCondition};
421    use std::rc::Rc;
422
423    #[test]
424    fn get_affected_elements_without_set_effect() {
425        let transition = Transition::default();
426        let result = get_affected_elements(&transition);
427        assert_eq!(result, AffectedElements::default());
428    }
429
430    #[test]
431    fn get_affected_elements_with_set_effect() {
432        let transition = Transition {
433            effect: Effect {
434                set_effects: vec![
435                    (
436                        0,
437                        SetExpression::SetElementOperation(
438                            SetElementOperator::Add,
439                            ElementExpression::Constant(2),
440                            Box::new(SetExpression::Reference(ReferenceExpression::Variable(1))),
441                        ),
442                    ),
443                    (
444                        1,
445                        SetExpression::SetElementOperation(
446                            SetElementOperator::Remove,
447                            ElementExpression::Constant(3),
448                            Box::new(SetExpression::Reference(ReferenceExpression::Variable(2))),
449                        ),
450                    ),
451                    (
452                        2,
453                        SetExpression::SetOperation(
454                            SetOperator::Union,
455                            Box::new(SetExpression::Reference(ReferenceExpression::Variable(0))),
456                            Box::new(SetExpression::Reference(ReferenceExpression::Variable(1))),
457                        ),
458                    ),
459                ],
460                ..Default::default()
461            },
462            ..Default::default()
463        };
464        let result = get_affected_elements(&transition);
465        assert_eq!(result.achieved, vec![(0, 2)]);
466        assert_eq!(result.removed, vec![(1, 3)]);
467        assert_eq!(result.arbitrary, vec![2]);
468    }
469
470    #[test]
471    fn get_required_elements_without_set_effect() {
472        let transition = Transition::default();
473        let result = get_required_elements(&transition);
474        assert_eq!(result, RequiredElements::default());
475    }
476
477    #[test]
478    fn get_required_elements_with_set_effect() {
479        let transition = Transition {
480            elements_in_set_variable: vec![(0, 1)],
481            preconditions: vec![
482                GroundedCondition {
483                    condition: Condition::Set(Box::new(SetCondition::IsIn(
484                        ElementExpression::Constant(2),
485                        SetExpression::Reference(ReferenceExpression::Variable(1)),
486                    ))),
487                    ..Default::default()
488                },
489                GroundedCondition {
490                    condition: Condition::Set(Box::new(SetCondition::IsIn(
491                        ElementExpression::Constant(3),
492                        SetExpression::Complement(Box::new(SetExpression::Reference(
493                            ReferenceExpression::Variable(2),
494                        ))),
495                    ))),
496                    ..Default::default()
497                },
498                GroundedCondition {
499                    condition: Condition::Set(Box::new(SetCondition::IsIn(
500                        ElementExpression::Variable(0),
501                        SetExpression::Complement(Box::new(SetExpression::Reference(
502                            ReferenceExpression::Variable(3),
503                        ))),
504                    ))),
505                    ..Default::default()
506                },
507                GroundedCondition {
508                    condition: Condition::Not(Box::new(Condition::Set(Box::new(
509                        SetCondition::IsIn(
510                            ElementExpression::Constant(5),
511                            SetExpression::Reference(ReferenceExpression::Variable(4)),
512                        ),
513                    )))),
514                    ..Default::default()
515                },
516                GroundedCondition {
517                    condition: Condition::Not(Box::new(Condition::Set(Box::new(
518                        SetCondition::IsIn(
519                            ElementExpression::Constant(6),
520                            SetExpression::Complement(Box::new(SetExpression::Reference(
521                                ReferenceExpression::Variable(5),
522                            ))),
523                        ),
524                    )))),
525                    ..Default::default()
526                },
527                GroundedCondition {
528                    condition: Condition::Not(Box::new(Condition::Set(Box::new(
529                        SetCondition::IsIn(
530                            ElementExpression::Variable(0),
531                            SetExpression::Reference(ReferenceExpression::Variable(6)),
532                        ),
533                    )))),
534                    ..Default::default()
535                },
536            ],
537            ..Default::default()
538        };
539        let result = get_required_elements(&transition);
540        assert_eq!(result.positively, vec![(0, 1), (1, 2), (5, 6)]);
541        assert_eq!(result.negatively, vec![(2, 3), (4, 5)]);
542    }
543
544    #[test]
545    #[should_panic]
546    fn get_forbidden_before_with_no_transitions() {
547        let constraints = TransitionMutex::new(Vec::<TransitionWithId>::default());
548        constraints.get_forbidden_before(false, 0);
549    }
550
551    #[test]
552    #[should_panic]
553    fn get_forbidden_after_with_no_transitions() {
554        let constraints = TransitionMutex::new(Vec::<TransitionWithId>::default());
555        constraints.get_forbidden_after(false, 0);
556    }
557
558    #[test]
559    fn new() {
560        let transitions = vec![
561            TransitionWithId {
562                id: 0,
563                forced: true,
564                transition: Transition {
565                    effect: Effect {
566                        set_effects: vec![
567                            (
568                                0,
569                                SetExpression::SetElementOperation(
570                                    SetElementOperator::Remove,
571                                    ElementExpression::Constant(1),
572                                    Box::new(SetExpression::Reference(
573                                        ReferenceExpression::Variable(0),
574                                    )),
575                                ),
576                            ),
577                            (
578                                1,
579                                SetExpression::SetElementOperation(
580                                    SetElementOperator::Add,
581                                    ElementExpression::Constant(1),
582                                    Box::new(SetExpression::Reference(
583                                        ReferenceExpression::Variable(1),
584                                    )),
585                                ),
586                            ),
587                        ],
588                        ..Default::default()
589                    },
590                    ..Default::default()
591                },
592            },
593            TransitionWithId {
594                id: 1,
595                forced: true,
596                transition: Transition {
597                    preconditions: vec![
598                        GroundedCondition {
599                            condition: Condition::Set(Box::new(SetCondition::IsIn(
600                                ElementExpression::Constant(1),
601                                SetExpression::Reference(ReferenceExpression::Variable(0)),
602                            ))),
603                            ..Default::default()
604                        },
605                        GroundedCondition {
606                            condition: Condition::Set(Box::new(SetCondition::IsIn(
607                                ElementExpression::Constant(2),
608                                SetExpression::Reference(ReferenceExpression::Variable(2)),
609                            ))),
610                            ..Default::default()
611                        },
612                    ],
613                    effect: Effect {
614                        set_effects: vec![
615                            (
616                                0,
617                                SetExpression::SetElementOperation(
618                                    SetElementOperator::Remove,
619                                    ElementExpression::Constant(1),
620                                    Box::new(SetExpression::Reference(
621                                        ReferenceExpression::Variable(0),
622                                    )),
623                                ),
624                            ),
625                            (
626                                2,
627                                SetExpression::SetElementOperation(
628                                    SetElementOperator::Remove,
629                                    ElementExpression::Constant(2),
630                                    Box::new(SetExpression::Reference(
631                                        ReferenceExpression::Variable(2),
632                                    )),
633                                ),
634                            ),
635                        ],
636                        ..Default::default()
637                    },
638                    ..Default::default()
639                },
640            },
641            TransitionWithId {
642                id: 0,
643                forced: false,
644                transition: Transition {
645                    effect: Effect {
646                        set_effects: vec![
647                            (
648                                1,
649                                SetExpression::SetElementOperation(
650                                    SetElementOperator::Add,
651                                    ElementExpression::Constant(2),
652                                    Box::new(SetExpression::Reference(
653                                        ReferenceExpression::Variable(1),
654                                    )),
655                                ),
656                            ),
657                            (
658                                0,
659                                SetExpression::SetElementOperation(
660                                    SetElementOperator::Remove,
661                                    ElementExpression::Constant(2),
662                                    Box::new(SetExpression::Reference(
663                                        ReferenceExpression::Variable(0),
664                                    )),
665                                ),
666                            ),
667                            (
668                                2,
669                                SetExpression::SetElementOperation(
670                                    SetElementOperator::Remove,
671                                    ElementExpression::Constant(2),
672                                    Box::new(SetExpression::Reference(
673                                        ReferenceExpression::Variable(2),
674                                    )),
675                                ),
676                            ),
677                        ],
678                        ..Default::default()
679                    },
680                    ..Default::default()
681                },
682            },
683            TransitionWithId {
684                id: 1,
685                forced: false,
686                transition: Transition {
687                    preconditions: vec![GroundedCondition {
688                        condition: Condition::Not(Box::new(Condition::Set(Box::new(
689                            SetCondition::IsIn(
690                                ElementExpression::Constant(2),
691                                SetExpression::Reference(ReferenceExpression::Variable(1)),
692                            ),
693                        )))),
694                        ..Default::default()
695                    }],
696                    effect: Effect {
697                        set_effects: vec![
698                            (
699                                1,
700                                SetExpression::SetElementOperation(
701                                    SetElementOperator::Add,
702                                    ElementExpression::Constant(2),
703                                    Box::new(SetExpression::Reference(
704                                        ReferenceExpression::Variable(1),
705                                    )),
706                                ),
707                            ),
708                            (
709                                2,
710                                SetExpression::SetElementOperation(
711                                    SetElementOperator::Remove,
712                                    ElementExpression::Variable(0),
713                                    Box::new(SetExpression::Reference(
714                                        ReferenceExpression::Variable(2),
715                                    )),
716                                ),
717                            ),
718                        ],
719                        ..Default::default()
720                    },
721                    ..Default::default()
722                },
723            },
724        ];
725        let mutex = TransitionMutex::new(transitions);
726        assert_eq!(mutex.get_forbidden_before(true, 0), &[]);
727        assert_eq!(mutex.get_forbidden_before(true, 1), &[(true, 0), (true, 1)]);
728        assert_eq!(mutex.get_forbidden_before(false, 0), &[]);
729        assert_eq!(
730            mutex.get_forbidden_before(false, 1),
731            &[(false, 0), (false, 1)]
732        );
733        assert_eq!(mutex.get_forbidden_after(true, 0), &[(true, 1)]);
734        assert_eq!(mutex.get_forbidden_after(false, 0), &[(false, 1)]);
735        assert_eq!(mutex.get_forbidden_after(true, 1), &[(true, 1)]);
736        assert_eq!(mutex.get_forbidden_after(false, 1), &[(false, 1)]);
737    }
738
739    #[test]
740    fn filter_successor_generator_without_any_constraints_and_suffix() {
741        let model = Rc::new(Model {
742            forward_transitions: vec![Transition::default()],
743            forward_forced_transitions: vec![Transition::default()],
744            ..Default::default()
745        });
746        let expected = SuccessorGenerator::<Transition>::from_model(model, false);
747        let transitions = expected
748            .transitions
749            .iter()
750            .chain(expected.forced_transitions.iter())
751            .map(|t| t.as_ref().clone())
752            .collect::<Vec<_>>();
753        let mutex = TransitionMutex::new(transitions);
754        let generator = mutex.filter_successor_generator(&expected, &[], &[]);
755        assert_eq!(generator, expected);
756    }
757
758    #[test]
759    fn filter_successor_generator_without_any_constraints() {
760        let model = Rc::new(Model {
761            forward_transitions: vec![Transition::default()],
762            forward_forced_transitions: vec![Transition::default()],
763            ..Default::default()
764        });
765        let expected = SuccessorGenerator::<Transition>::from_model(model, false);
766        let transitions = expected
767            .transitions
768            .iter()
769            .chain(expected.forced_transitions.iter())
770            .map(|t| t.as_ref().clone())
771            .collect::<Vec<_>>();
772        let mutex = TransitionMutex::new(transitions);
773        let generator = mutex.filter_successor_generator(
774            &expected,
775            &[],
776            &[
777                TransitionWithId {
778                    id: 0,
779                    forced: false,
780                    ..Default::default()
781                },
782                TransitionWithId {
783                    id: 0,
784                    forced: true,
785                    ..Default::default()
786                },
787            ],
788        );
789        assert_eq!(generator, expected);
790    }
791
792    #[test]
793    fn filter_successor_generator_with_constraints_without_suffix() {
794        let model = Rc::new(Model {
795            forward_transitions: vec![Transition {
796                preconditions: vec![GroundedCondition {
797                    condition: Condition::Set(Box::new(SetCondition::IsIn(
798                        ElementExpression::Constant(1),
799                        SetExpression::Reference(ReferenceExpression::Variable(0)),
800                    ))),
801                    ..Default::default()
802                }],
803                effect: Effect {
804                    set_effects: vec![(
805                        0,
806                        SetExpression::SetElementOperation(
807                            SetElementOperator::Remove,
808                            ElementExpression::Constant(1),
809                            Box::new(SetExpression::Reference(ReferenceExpression::Variable(0))),
810                        ),
811                    )],
812                    ..Default::default()
813                },
814                ..Default::default()
815            }],
816            forward_forced_transitions: vec![Transition {
817                preconditions: vec![GroundedCondition {
818                    condition: Condition::Not(Box::new(Condition::Set(Box::new(
819                        SetCondition::IsIn(
820                            ElementExpression::Constant(2),
821                            SetExpression::Reference(ReferenceExpression::Variable(1)),
822                        ),
823                    )))),
824                    ..Default::default()
825                }],
826                effect: Effect {
827                    set_effects: vec![(
828                        1,
829                        SetExpression::SetElementOperation(
830                            SetElementOperator::Add,
831                            ElementExpression::Constant(2),
832                            Box::new(SetExpression::Reference(ReferenceExpression::Variable(1))),
833                        ),
834                    )],
835                    ..Default::default()
836                },
837                ..Default::default()
838            }],
839            ..Default::default()
840        });
841        let expected = SuccessorGenerator::<Transition>::from_model(model, false);
842        let transitions = expected
843            .transitions
844            .iter()
845            .chain(expected.forced_transitions.iter())
846            .map(|t| t.as_ref().clone())
847            .collect::<Vec<_>>();
848        let mutex = TransitionMutex::new(transitions);
849        let generator = mutex.filter_successor_generator(&expected, &[], &[]);
850        assert_eq!(generator, expected);
851    }
852
853    #[test]
854    fn filter_successor_generator_with_constraints_and_prefix() {
855        let t1 = Transition {
856            preconditions: vec![GroundedCondition {
857                condition: Condition::Set(Box::new(SetCondition::IsIn(
858                    ElementExpression::Constant(1),
859                    SetExpression::Reference(ReferenceExpression::Variable(0)),
860                ))),
861                ..Default::default()
862            }],
863            effect: Effect {
864                set_effects: vec![(
865                    0,
866                    SetExpression::SetElementOperation(
867                        SetElementOperator::Remove,
868                        ElementExpression::Constant(1),
869                        Box::new(SetExpression::Reference(ReferenceExpression::Variable(0))),
870                    ),
871                )],
872                ..Default::default()
873            },
874            ..Default::default()
875        };
876        let t2 = Transition {
877            preconditions: vec![GroundedCondition {
878                condition: Condition::Not(Box::new(Condition::Set(Box::new(SetCondition::IsIn(
879                    ElementExpression::Constant(2),
880                    SetExpression::Reference(ReferenceExpression::Variable(1)),
881                ))))),
882                ..Default::default()
883            }],
884            effect: Effect {
885                set_effects: vec![(
886                    1,
887                    SetExpression::SetElementOperation(
888                        SetElementOperator::Add,
889                        ElementExpression::Constant(2),
890                        Box::new(SetExpression::Reference(ReferenceExpression::Variable(1))),
891                    ),
892                )],
893                ..Default::default()
894            },
895            ..Default::default()
896        };
897        let model = Rc::new(Model {
898            forward_transitions: vec![t1.clone(), Transition::default()],
899            forward_forced_transitions: vec![t2.clone(), Transition::default()],
900            ..Default::default()
901        });
902        let generator = SuccessorGenerator::<Transition>::from_model(model.clone(), false);
903        let transitions = generator
904            .transitions
905            .iter()
906            .chain(generator.forced_transitions.iter())
907            .map(|t| t.as_ref().clone())
908            .collect::<Vec<_>>();
909        let mutex = TransitionMutex::new(transitions);
910        let t1 = TransitionWithId {
911            id: 0,
912            forced: false,
913            transition: t1,
914        };
915        let t2 = TransitionWithId {
916            id: 0,
917            forced: true,
918            transition: t2,
919        };
920        let generator = mutex.filter_successor_generator(&generator, &[t1, t2], &[]);
921        let forced_transitions = vec![Rc::new(TransitionWithId {
922            id: 1,
923            forced: true,
924            ..Default::default()
925        })];
926        let transitions = vec![Rc::new(TransitionWithId {
927            id: 1,
928            forced: false,
929            ..Default::default()
930        })];
931        let expected = SuccessorGenerator::new(forced_transitions, transitions, false, model);
932        assert_eq!(generator, expected);
933    }
934
935    #[test]
936    fn filter_successor_generator_with_constraints_and_suffix() {
937        let t1 = Transition {
938            preconditions: vec![GroundedCondition {
939                condition: Condition::Set(Box::new(SetCondition::IsIn(
940                    ElementExpression::Constant(1),
941                    SetExpression::Reference(ReferenceExpression::Variable(0)),
942                ))),
943                ..Default::default()
944            }],
945            effect: Effect {
946                set_effects: vec![(
947                    0,
948                    SetExpression::SetElementOperation(
949                        SetElementOperator::Remove,
950                        ElementExpression::Constant(1),
951                        Box::new(SetExpression::Reference(ReferenceExpression::Variable(0))),
952                    ),
953                )],
954                ..Default::default()
955            },
956            ..Default::default()
957        };
958        let t2 = Transition {
959            preconditions: vec![GroundedCondition {
960                condition: Condition::Not(Box::new(Condition::Set(Box::new(SetCondition::IsIn(
961                    ElementExpression::Constant(2),
962                    SetExpression::Reference(ReferenceExpression::Variable(1)),
963                ))))),
964                ..Default::default()
965            }],
966            effect: Effect {
967                set_effects: vec![(
968                    1,
969                    SetExpression::SetElementOperation(
970                        SetElementOperator::Add,
971                        ElementExpression::Constant(2),
972                        Box::new(SetExpression::Reference(ReferenceExpression::Variable(1))),
973                    ),
974                )],
975                ..Default::default()
976            },
977            ..Default::default()
978        };
979        let model = Rc::new(Model {
980            forward_transitions: vec![t1.clone(), Transition::default()],
981            forward_forced_transitions: vec![t2.clone(), Transition::default()],
982            ..Default::default()
983        });
984        let generator = SuccessorGenerator::<Transition>::from_model(model.clone(), false);
985        let transitions = generator
986            .transitions
987            .iter()
988            .chain(generator.forced_transitions.iter())
989            .map(|t| t.as_ref().clone())
990            .collect::<Vec<_>>();
991        let mutex = TransitionMutex::new(transitions);
992        let t1 = TransitionWithId {
993            id: 0,
994            forced: false,
995            transition: t1,
996        };
997        let t2 = TransitionWithId {
998            id: 0,
999            forced: true,
1000            transition: t2,
1001        };
1002        let generator = mutex.filter_successor_generator(&generator, &[], &[t1, t2]);
1003        let forced_transitions = vec![Rc::new(TransitionWithId {
1004            id: 1,
1005            forced: true,
1006            ..Default::default()
1007        })];
1008        let transitions = vec![Rc::new(TransitionWithId {
1009            id: 1,
1010            forced: false,
1011            ..Default::default()
1012        })];
1013        let expected =
1014            SuccessorGenerator::new(forced_transitions, transitions, false, model.clone());
1015        assert_eq!(generator, expected);
1016    }
1017
1018    #[test]
1019    fn filter_successor_generator_with_constraints_and_prefix_and_suffix() {
1020        let t1 = Transition {
1021            preconditions: vec![GroundedCondition {
1022                condition: Condition::Set(Box::new(SetCondition::IsIn(
1023                    ElementExpression::Constant(1),
1024                    SetExpression::Reference(ReferenceExpression::Variable(0)),
1025                ))),
1026                ..Default::default()
1027            }],
1028            effect: Effect {
1029                set_effects: vec![(
1030                    0,
1031                    SetExpression::SetElementOperation(
1032                        SetElementOperator::Remove,
1033                        ElementExpression::Constant(1),
1034                        Box::new(SetExpression::Reference(ReferenceExpression::Variable(0))),
1035                    ),
1036                )],
1037                ..Default::default()
1038            },
1039            ..Default::default()
1040        };
1041        let t2 = Transition {
1042            preconditions: vec![GroundedCondition {
1043                condition: Condition::Not(Box::new(Condition::Set(Box::new(SetCondition::IsIn(
1044                    ElementExpression::Constant(2),
1045                    SetExpression::Reference(ReferenceExpression::Variable(1)),
1046                ))))),
1047                ..Default::default()
1048            }],
1049            effect: Effect {
1050                set_effects: vec![(
1051                    1,
1052                    SetExpression::SetElementOperation(
1053                        SetElementOperator::Add,
1054                        ElementExpression::Constant(2),
1055                        Box::new(SetExpression::Reference(ReferenceExpression::Variable(1))),
1056                    ),
1057                )],
1058                ..Default::default()
1059            },
1060            ..Default::default()
1061        };
1062        let model = Rc::new(Model {
1063            forward_transitions: vec![t1.clone(), Transition::default()],
1064            forward_forced_transitions: vec![t2.clone(), Transition::default()],
1065            ..Default::default()
1066        });
1067        let generator = SuccessorGenerator::<Transition>::from_model(model.clone(), false);
1068        let transitions = generator
1069            .transitions
1070            .iter()
1071            .chain(generator.forced_transitions.iter())
1072            .map(|t| t.as_ref().clone())
1073            .collect::<Vec<_>>();
1074        let mutex = TransitionMutex::new(transitions);
1075        let t1 = TransitionWithId {
1076            id: 0,
1077            forced: false,
1078            transition: t1,
1079        };
1080        let t2 = TransitionWithId {
1081            id: 0,
1082            forced: true,
1083            transition: t2,
1084        };
1085        let generator = mutex.filter_successor_generator(&generator, &[t1], &[t2]);
1086        let forced_transitions = vec![Rc::new(TransitionWithId {
1087            id: 1,
1088            forced: true,
1089            ..Default::default()
1090        })];
1091        let transitions = vec![Rc::new(TransitionWithId {
1092            id: 1,
1093            forced: false,
1094            ..Default::default()
1095        })];
1096        let expected =
1097            SuccessorGenerator::new(forced_transitions, transitions, false, model.clone());
1098        assert_eq!(generator, expected);
1099    }
1100}