dypdl/expression/
vector_expression.rs

1use super::condition::Condition;
2use super::element_expression::ElementExpression;
3use super::reference_expression::ReferenceExpression;
4use super::set_expression::SetExpression;
5use crate::state::StateInterface;
6use crate::state_functions::{StateFunctionCache, StateFunctions};
7use crate::table_registry::TableRegistry;
8use crate::variable_type::Vector;
9
10/// Vector expression.
11#[derive(Debug, PartialEq, Clone)]
12pub enum VectorExpression {
13    /// Reference to a constant or a variable.
14    Reference(ReferenceExpression<Vector>),
15    /// Indices of a vector.
16    Indices(Box<VectorExpression>),
17    /// Reverse a vector.
18    Reverse(Box<VectorExpression>),
19    /// Set an element in a vector.
20    Set(ElementExpression, Box<VectorExpression>, ElementExpression),
21    /// Push an element to a vector.
22    Push(ElementExpression, Box<VectorExpression>),
23    /// Pop an element from a vector.
24    Pop(Box<VectorExpression>),
25    /// Conversion from a set.
26    FromSet(Box<SetExpression>),
27    /// If-then-else expression, which returns the first one if the condition holds and the second one otherwise.
28    If(Box<Condition>, Box<VectorExpression>, Box<VectorExpression>),
29}
30
31impl VectorExpression {
32    /// Returns the evaluation result.
33    ///
34    /// # Panics
35    ///
36    /// Panics if the cost of the transition state is used or a min/max reduce operation is performed on an empty set or vector.
37    pub fn eval<T: StateInterface>(
38        &self,
39        state: &T,
40        function_cache: &mut StateFunctionCache,
41        state_functions: &StateFunctions,
42        registry: &TableRegistry,
43    ) -> Vector {
44        match self {
45            Self::Reference(expression) => expression
46                .eval(state, function_cache, state_functions, registry)
47                .clone(),
48            Self::Indices(vector) => {
49                let mut vector = vector.eval(state, function_cache, state_functions, registry);
50                vector.iter_mut().enumerate().for_each(|(i, v)| *v = i);
51                vector
52            }
53            Self::Reverse(vector) => {
54                let mut vector = vector.eval(state, function_cache, state_functions, registry);
55                vector.reverse();
56                vector
57            }
58            Self::Set(element, vector, i) => {
59                let mut vector = vector.eval(state, function_cache, state_functions, registry);
60                vector[i.eval(state, function_cache, state_functions, registry)] =
61                    element.eval(state, function_cache, state_functions, registry);
62                vector
63            }
64            Self::Push(element, vector) => {
65                let element = element.eval(state, function_cache, state_functions, registry);
66                let mut vector = vector.eval(state, function_cache, state_functions, registry);
67                vector.push(element);
68                vector
69            }
70            Self::Pop(vector) => {
71                let mut vector = vector.eval(state, function_cache, state_functions, registry);
72                vector.pop();
73                vector
74            }
75            Self::FromSet(set) => match set.as_ref() {
76                SetExpression::Reference(set) => set
77                    .eval(state, function_cache, state_functions, registry)
78                    .ones()
79                    .collect(),
80                set => set
81                    .eval(state, function_cache, state_functions, registry)
82                    .ones()
83                    .collect(),
84            },
85            Self::If(condition, x, y) => {
86                if condition.eval(state, function_cache, state_functions, registry) {
87                    x.eval(state, function_cache, state_functions, registry)
88                } else {
89                    y.eval(state, function_cache, state_functions, registry)
90                }
91            }
92        }
93    }
94
95    /// Returns a simplified version by precomputation.
96    ///
97    /// # Panics
98    ///
99    /// Panics if the cost of the transition state is used or a min/max reduce operation is performed on an empty set or vector.
100    pub fn simplify(&self, registry: &TableRegistry) -> VectorExpression {
101        match self {
102            Self::Reference(vector) => {
103                Self::Reference(vector.simplify(registry, &registry.vector_tables))
104            }
105            Self::Indices(vector) => match vector.simplify(registry) {
106                VectorExpression::Reference(ReferenceExpression::Constant(mut vector)) => {
107                    vector.iter_mut().enumerate().for_each(|(i, v)| *v = i);
108                    Self::Reference(ReferenceExpression::Constant(vector))
109                }
110                vector => Self::Indices(Box::new(vector)),
111            },
112            Self::Reverse(vector) => match vector.simplify(registry) {
113                VectorExpression::Reference(ReferenceExpression::Constant(mut vector)) => {
114                    vector.reverse();
115                    Self::Reference(ReferenceExpression::Constant(vector))
116                }
117                vector => Self::Reverse(Box::new(vector)),
118            },
119            Self::Set(element, vector, i) => match (
120                element.simplify(registry),
121                vector.simplify(registry),
122                i.simplify(registry),
123            ) {
124                (
125                    ElementExpression::Constant(element),
126                    VectorExpression::Reference(ReferenceExpression::Constant(mut vector)),
127                    ElementExpression::Constant(i),
128                ) => {
129                    vector[i] = element;
130                    Self::Reference(ReferenceExpression::Constant(vector))
131                }
132                (element, vector, i) => Self::Set(element, Box::new(vector), i),
133            },
134            Self::Push(element, vector) => {
135                match (element.simplify(registry), vector.simplify(registry)) {
136                    (
137                        ElementExpression::Constant(element),
138                        VectorExpression::Reference(ReferenceExpression::Constant(mut vector)),
139                    ) => {
140                        vector.push(element);
141                        Self::Reference(ReferenceExpression::Constant(vector))
142                    }
143                    (element, vector) => Self::Push(element, Box::new(vector)),
144                }
145            }
146            Self::Pop(vector) => match vector.simplify(registry) {
147                VectorExpression::Reference(ReferenceExpression::Constant(mut vector)) => {
148                    vector.pop();
149                    Self::Reference(ReferenceExpression::Constant(vector))
150                }
151                vector => Self::Pop(Box::new(vector)),
152            },
153            Self::FromSet(set) => match set.simplify(registry) {
154                SetExpression::Reference(ReferenceExpression::Constant(set)) => {
155                    Self::Reference(ReferenceExpression::Constant(set.ones().collect()))
156                }
157                set => Self::FromSet(Box::new(set)),
158            },
159            Self::If(condition, x, y) => match condition.simplify(registry) {
160                Condition::Constant(true) => x.simplify(registry),
161                Condition::Constant(false) => y.simplify(registry),
162                condition => Self::If(
163                    Box::new(condition),
164                    Box::new(x.simplify(registry)),
165                    Box::new(y.simplify(registry)),
166                ),
167            },
168        }
169    }
170}
171
#[cfg(test)]
mod tests {
    use super::super::condition::ComparisonOperator;
    use super::super::integer_expression::IntegerExpression;
    use super::super::table_expression::TableExpression;
    use super::*;
    use crate::state::*;
    use crate::table::*;
    use crate::table_data::*;
    use crate::variable_type::Set;
    use rustc_hash::FxHashMap;

    fn generate_registry() -> TableRegistry {
        let mut name_to_constant = FxHashMap::default();
        name_to_constant.insert(String::from("f0"), 1);

        let tables_1d = vec![Table1D::new(vec![1, 0])];
        let mut name_to_table_1d = FxHashMap::default();
        name_to_table_1d.insert(String::from("f1"), 0);

        let tables_2d = vec![Table2D::new(vec![vec![1, 0]])];
        let mut name_to_table_2d = FxHashMap::default();
        name_to_table_2d.insert(String::from("f2"), 0);

        let tables_3d = vec![Table3D::new(vec![vec![vec![1, 0]]])];
        let mut name_to_table_3d = FxHashMap::default();
        name_to_table_3d.insert(String::from("f3"), 0);

        let mut map = FxHashMap::default();
        let key = vec![0, 0, 0, 0];
        map.insert(key, 1);
        let key = vec![0, 0, 0, 1];
        map.insert(key, 0);
        let tables = vec![Table::new(map, 0)];
        let mut name_to_table = FxHashMap::default();
        name_to_table.insert(String::from("f4"), 0);

        let element_tables = TableData {
            name_to_constant,
            tables_1d,
            name_to_table_1d,
            tables_2d,
            name_to_table_2d,
            tables_3d,
            name_to_table_3d,
            tables,
            name_to_table,
        };

        let mut name_to_table_1d = FxHashMap::default();
        name_to_table_1d.insert(String::from("t1"), 0);
        let vector_tables = TableData {
            tables_1d: vec![Table1D::new(vec![vec![0, 1]])],
            name_to_table_1d,
            ..Default::default()
        };

        let mut set = Set::with_capacity(3);
        set.insert(0);
        set.insert(2);
        let default = Set::with_capacity(3);
        let tables_1d = vec![Table1D::new(vec![set, default.clone(), default])];
        let mut name_to_table_1d = FxHashMap::default();
        name_to_table_1d.insert(String::from("s1"), 0);
        let set_tables = TableData {
            tables_1d,
            name_to_table_1d,
            ..Default::default()
        };

        TableRegistry {
            element_tables,
            set_tables,
            vector_tables,
            ..Default::default()
        }
    }

    fn generate_state() -> State {
        let mut set1 = Set::with_capacity(3);
        set1.insert(0);
        set1.insert(2);
        let mut set2 = Set::with_capacity(3);
        set2.insert(0);
        set2.insert(1);
        State {
            signature_variables: SignatureVariables {
                set_variables: vec![set1, set2],
                vector_variables: vec![vec![0, 2]],
                element_variables: vec![1],
                ..Default::default()
            },
            resource_variables: ResourceVariables {
                element_variables: vec![2],
                ..Default::default()
            },
        }
    }

    /// Evaluates `expression` against the fixture state and registry, with no state functions
    /// and a fresh function cache.
    fn eval_in_fixture(expression: &VectorExpression) -> Vector {
        let state = generate_state();
        let state_functions = StateFunctions::default();
        let mut function_cache = StateFunctionCache::new(&state_functions);
        let registry = generate_registry();
        expression.eval(&state, &mut function_cache, &state_functions, &registry)
    }

    /// Shorthand for a constant vector reference expression.
    fn constant(vector: Vector) -> VectorExpression {
        VectorExpression::Reference(ReferenceExpression::Constant(vector))
    }

    #[test]
    fn vector_reference_eval() {
        let expression = constant(vec![1, 2]);
        assert_eq!(eval_in_fixture(&expression), vec![1, 2]);
    }

    #[test]
    fn vector_indices_eval() {
        let expression = VectorExpression::Indices(Box::new(constant(vec![1, 2])));
        assert_eq!(eval_in_fixture(&expression), vec![0, 1]);
    }

    #[test]
    fn vector_reverse_eval() {
        let expression = VectorExpression::Reverse(Box::new(constant(vec![1, 2])));
        assert_eq!(eval_in_fixture(&expression), vec![2, 1]);
    }

    #[test]
    fn vector_set_eval() {
        let expression = VectorExpression::Set(
            ElementExpression::Constant(3),
            Box::new(constant(vec![1, 2])),
            ElementExpression::Constant(0),
        );
        assert_eq!(eval_in_fixture(&expression), vec![3, 2]);
    }

    #[test]
    #[should_panic]
    fn vector_set_eval_out_of_bounds_panics() {
        let expression = VectorExpression::Set(
            ElementExpression::Constant(3),
            Box::new(constant(vec![1, 2])),
            ElementExpression::Constant(2),
        );
        let _ = eval_in_fixture(&expression);
    }

    #[test]
    fn vector_push_eval() {
        let expression =
            VectorExpression::Push(ElementExpression::Constant(0), Box::new(constant(vec![1, 2])));
        assert_eq!(eval_in_fixture(&expression), vec![1, 2, 0]);
    }

    #[test]
    fn vector_pop_eval() {
        let expression = VectorExpression::Pop(Box::new(constant(vec![1, 2])));
        assert_eq!(eval_in_fixture(&expression), vec![1]);
        // Popping an empty vector leaves it empty rather than panicking.
        let expression = VectorExpression::Pop(Box::new(constant(vec![])));
        assert_eq!(eval_in_fixture(&expression), Vec::new());
    }

    #[test]
    fn vector_from_set_eval() {
        let mut set = Set::with_capacity(3);
        set.insert(0);
        set.insert(1);
        let expression = VectorExpression::FromSet(Box::new(SetExpression::Reference(
            ReferenceExpression::Constant(set),
        )));
        assert_eq!(eval_in_fixture(&expression), vec![0, 1]);
        let expression = VectorExpression::FromSet(Box::new(SetExpression::Reference(
            ReferenceExpression::Variable(0),
        )));
        assert_eq!(eval_in_fixture(&expression), vec![0, 2]);
    }

    #[test]
    fn vector_if_eval() {
        let expression = VectorExpression::If(
            Box::new(Condition::Constant(true)),
            Box::new(constant(vec![0, 1])),
            Box::new(constant(vec![1, 0])),
        );
        assert_eq!(eval_in_fixture(&expression), vec![0, 1]);
        let expression = VectorExpression::If(
            Box::new(Condition::Constant(false)),
            Box::new(constant(vec![0, 1])),
            Box::new(constant(vec![1, 0])),
        );
        assert_eq!(eval_in_fixture(&expression), vec![1, 0]);
    }

    #[test]
    fn vector_reference_simplify() {
        let registry = generate_registry();
        let expression = constant(vec![1, 2]);
        assert_eq!(expression.simplify(&registry), expression);
        let expression = VectorExpression::Reference(ReferenceExpression::Table(
            TableExpression::Table1D(0, ElementExpression::Constant(0)),
        ));
        assert_eq!(expression.simplify(&registry), constant(vec![0, 1]));
    }

    #[test]
    fn vector_indices_simplify() {
        let registry = generate_registry();

        let expression = VectorExpression::Indices(Box::new(VectorExpression::Reference(
            ReferenceExpression::Variable(0),
        )));
        assert_eq!(expression.simplify(&registry), expression);

        let expression = VectorExpression::Indices(Box::new(constant(vec![1, 2])));
        assert_eq!(expression.simplify(&registry), constant(vec![0, 1]));
    }

    #[test]
    fn vector_reverse_simplify() {
        let registry = generate_registry();

        let expression = VectorExpression::Reverse(Box::new(VectorExpression::Reference(
            ReferenceExpression::Variable(0),
        )));
        assert_eq!(expression.simplify(&registry), expression);

        let expression = VectorExpression::Reverse(Box::new(constant(vec![1, 2])));
        assert_eq!(expression.simplify(&registry), constant(vec![2, 1]));

        // A table lookup with a constant index simplifies to a constant first, so the
        // reversal is folded as well.
        let expression = VectorExpression::Reverse(Box::new(VectorExpression::Reference(
            ReferenceExpression::Table(TableExpression::Table1D(
                0,
                ElementExpression::Constant(0),
            )),
        )));
        assert_eq!(expression.simplify(&registry), constant(vec![1, 0]));
    }

    #[test]
    fn vector_push_simplify() {
        let registry = generate_registry();
        let expression =
            VectorExpression::Push(ElementExpression::Constant(0), Box::new(constant(vec![1, 2])));
        assert_eq!(expression.simplify(&registry), constant(vec![1, 2, 0]));
        let expression = VectorExpression::Push(
            ElementExpression::Constant(0),
            Box::new(VectorExpression::Reference(ReferenceExpression::Variable(
                0,
            ))),
        );
        assert_eq!(expression.simplify(&registry), expression);
    }

    #[test]
    fn vector_pop_simplify() {
        let registry = generate_registry();
        let expression = VectorExpression::Pop(Box::new(constant(vec![1, 2])));
        assert_eq!(expression.simplify(&registry), constant(vec![1]));
        let expression = VectorExpression::Pop(Box::new(VectorExpression::Reference(
            ReferenceExpression::Variable(0),
        )));
        assert_eq!(expression.simplify(&registry), expression);
    }

    #[test]
    fn vector_set_simplify() {
        let registry = generate_registry();
        let expression = VectorExpression::Set(
            ElementExpression::Constant(0),
            Box::new(constant(vec![1, 2])),
            ElementExpression::Constant(0),
        );
        assert_eq!(expression.simplify(&registry), constant(vec![0, 2]));
        let expression = VectorExpression::Set(
            ElementExpression::Constant(0),
            Box::new(VectorExpression::Reference(ReferenceExpression::Variable(
                0,
            ))),
            ElementExpression::Variable(0),
        );
        assert_eq!(expression.simplify(&registry), expression);
    }

    #[test]
    fn vector_from_set_simplify() {
        let registry = generate_registry();
        let mut set = Set::with_capacity(3);
        set.insert(0);
        set.insert(1);
        let expression = VectorExpression::FromSet(Box::new(SetExpression::Reference(
            ReferenceExpression::Constant(set),
        )));
        assert_eq!(expression.simplify(&registry), constant(vec![0, 1]));
        let expression = VectorExpression::FromSet(Box::new(SetExpression::Reference(
            ReferenceExpression::Variable(0),
        )));
        assert_eq!(expression.simplify(&registry), expression);
    }

    #[test]
    fn vector_if_simplify() {
        let registry = generate_registry();
        let expression = VectorExpression::If(
            Box::new(Condition::Constant(true)),
            Box::new(constant(vec![0, 1])),
            Box::new(constant(vec![1, 0])),
        );
        assert_eq!(expression.simplify(&registry), constant(vec![0, 1]));
        let expression = VectorExpression::If(
            Box::new(Condition::Constant(false)),
            Box::new(constant(vec![0, 1])),
            Box::new(constant(vec![1, 0])),
        );
        assert_eq!(expression.simplify(&registry), constant(vec![1, 0]));
        let expression = VectorExpression::If(
            Box::new(Condition::ComparisonI(
                ComparisonOperator::Gt,
                Box::new(IntegerExpression::Variable(0)),
                Box::new(IntegerExpression::Constant(1)),
            )),
            Box::new(constant(vec![0, 1])),
            Box::new(constant(vec![1, 0])),
        );
        assert_eq!(expression.simplify(&registry), expression);
    }
}