1use super::condition::Condition;
2use super::element_expression::ElementExpression;
3use super::reference_expression::ReferenceExpression;
4use super::set_expression::SetExpression;
5use crate::state::StateInterface;
6use crate::state_functions::{StateFunctionCache, StateFunctions};
7use crate::table_registry::TableRegistry;
8use crate::variable_type::Vector;
9
/// An expression that evaluates to a `Vector`.
///
/// Evaluation is done by `VectorExpression::eval`; constant folding by
/// `VectorExpression::simplify`.
#[derive(Debug, PartialEq, Clone)]
pub enum VectorExpression {
    /// A vector given by a `ReferenceExpression` (e.g., a constant, a variable,
    /// or a table lookup).
    Reference(ReferenceExpression<Vector>),
    /// The vector `[0, 1, ..., n - 1]`, where `n` is the length of the inner vector.
    Indices(Box<VectorExpression>),
    /// The inner vector in reverse order.
    Reverse(Box<VectorExpression>),
    /// The vector (second field) with the entry at the index given by the third
    /// field overwritten by the value of the first field.
    Set(ElementExpression, Box<VectorExpression>, ElementExpression),
    /// The vector (second field) with the element (first field) appended at the end.
    Push(ElementExpression, Box<VectorExpression>),
    /// The inner vector with its last element removed; an empty vector stays empty.
    Pop(Box<VectorExpression>),
    /// The members of a set, in the order produced by the set's `ones()` iterator.
    FromSet(Box<SetExpression>),
    /// The first vector if the condition holds, otherwise the second vector.
    If(Box<Condition>, Box<VectorExpression>, Box<VectorExpression>),
}
30
31impl VectorExpression {
32 pub fn eval<T: StateInterface>(
38 &self,
39 state: &T,
40 function_cache: &mut StateFunctionCache,
41 state_functions: &StateFunctions,
42 registry: &TableRegistry,
43 ) -> Vector {
44 match self {
45 Self::Reference(expression) => expression
46 .eval(state, function_cache, state_functions, registry)
47 .clone(),
48 Self::Indices(vector) => {
49 let mut vector = vector.eval(state, function_cache, state_functions, registry);
50 vector.iter_mut().enumerate().for_each(|(i, v)| *v = i);
51 vector
52 }
53 Self::Reverse(vector) => {
54 let mut vector = vector.eval(state, function_cache, state_functions, registry);
55 vector.reverse();
56 vector
57 }
58 Self::Set(element, vector, i) => {
59 let mut vector = vector.eval(state, function_cache, state_functions, registry);
60 vector[i.eval(state, function_cache, state_functions, registry)] =
61 element.eval(state, function_cache, state_functions, registry);
62 vector
63 }
64 Self::Push(element, vector) => {
65 let element = element.eval(state, function_cache, state_functions, registry);
66 let mut vector = vector.eval(state, function_cache, state_functions, registry);
67 vector.push(element);
68 vector
69 }
70 Self::Pop(vector) => {
71 let mut vector = vector.eval(state, function_cache, state_functions, registry);
72 vector.pop();
73 vector
74 }
75 Self::FromSet(set) => match set.as_ref() {
76 SetExpression::Reference(set) => set
77 .eval(state, function_cache, state_functions, registry)
78 .ones()
79 .collect(),
80 set => set
81 .eval(state, function_cache, state_functions, registry)
82 .ones()
83 .collect(),
84 },
85 Self::If(condition, x, y) => {
86 if condition.eval(state, function_cache, state_functions, registry) {
87 x.eval(state, function_cache, state_functions, registry)
88 } else {
89 y.eval(state, function_cache, state_functions, registry)
90 }
91 }
92 }
93 }
94
95 pub fn simplify(&self, registry: &TableRegistry) -> VectorExpression {
101 match self {
102 Self::Reference(vector) => {
103 Self::Reference(vector.simplify(registry, ®istry.vector_tables))
104 }
105 Self::Indices(vector) => match vector.simplify(registry) {
106 VectorExpression::Reference(ReferenceExpression::Constant(mut vector)) => {
107 vector.iter_mut().enumerate().for_each(|(i, v)| *v = i);
108 Self::Reference(ReferenceExpression::Constant(vector))
109 }
110 vector => Self::Indices(Box::new(vector)),
111 },
112 Self::Reverse(vector) => match vector.simplify(registry) {
113 VectorExpression::Reference(ReferenceExpression::Constant(mut vector)) => {
114 vector.reverse();
115 Self::Reference(ReferenceExpression::Constant(vector))
116 }
117 vector => Self::Reverse(Box::new(vector)),
118 },
119 Self::Set(element, vector, i) => match (
120 element.simplify(registry),
121 vector.simplify(registry),
122 i.simplify(registry),
123 ) {
124 (
125 ElementExpression::Constant(element),
126 VectorExpression::Reference(ReferenceExpression::Constant(mut vector)),
127 ElementExpression::Constant(i),
128 ) => {
129 vector[i] = element;
130 Self::Reference(ReferenceExpression::Constant(vector))
131 }
132 (element, vector, i) => Self::Set(element, Box::new(vector), i),
133 },
134 Self::Push(element, vector) => {
135 match (element.simplify(registry), vector.simplify(registry)) {
136 (
137 ElementExpression::Constant(element),
138 VectorExpression::Reference(ReferenceExpression::Constant(mut vector)),
139 ) => {
140 vector.push(element);
141 Self::Reference(ReferenceExpression::Constant(vector))
142 }
143 (element, vector) => Self::Push(element, Box::new(vector)),
144 }
145 }
146 Self::Pop(vector) => match vector.simplify(registry) {
147 VectorExpression::Reference(ReferenceExpression::Constant(mut vector)) => {
148 vector.pop();
149 Self::Reference(ReferenceExpression::Constant(vector))
150 }
151 vector => Self::Pop(Box::new(vector)),
152 },
153 Self::FromSet(set) => match set.simplify(registry) {
154 SetExpression::Reference(ReferenceExpression::Constant(set)) => {
155 Self::Reference(ReferenceExpression::Constant(set.ones().collect()))
156 }
157 set => Self::FromSet(Box::new(set)),
158 },
159 Self::If(condition, x, y) => match condition.simplify(registry) {
160 Condition::Constant(true) => x.simplify(registry),
161 Condition::Constant(false) => y.simplify(registry),
162 condition => Self::If(
163 Box::new(condition),
164 Box::new(x.simplify(registry)),
165 Box::new(y.simplify(registry)),
166 ),
167 },
168 }
169 }
170}
171
#[cfg(test)]
mod tests {
    use super::super::condition::ComparisonOperator;
    use super::super::integer_expression::IntegerExpression;
    use super::super::table_expression::TableExpression;
    use super::*;
    use crate::state::*;
    use crate::table::*;
    use crate::table_data::*;
    use crate::variable_type::Set;
    use rustc_hash::FxHashMap;

    /// Builds the shared fixture registry:
    /// - element tables: constant `f0`, 1D `f1`, 2D `f2`, 3D `f3`, and n-dimensional `f4`;
    /// - vector table `t1` (1D) whose entry 0 is `[0, 1]`;
    /// - set table `s1` (1D) whose entries are `{0, 2}`, `{}`, `{}` (capacity 3).
    fn generate_registry() -> TableRegistry {
        let mut name_to_constant = FxHashMap::default();
        name_to_constant.insert(String::from("f0"), 1);

        let tables_1d = vec![Table1D::new(vec![1, 0])];
        let mut name_to_table_1d = FxHashMap::default();
        name_to_table_1d.insert(String::from("f1"), 0);

        let tables_2d = vec![Table2D::new(vec![vec![1, 0]])];
        let mut name_to_table_2d = FxHashMap::default();
        name_to_table_2d.insert(String::from("f2"), 0);

        let tables_3d = vec![Table3D::new(vec![vec![vec![1, 0]]])];
        let mut name_to_table_3d = FxHashMap::default();
        name_to_table_3d.insert(String::from("f3"), 0);

        let mut map = FxHashMap::default();
        let key = vec![0, 0, 0, 0];
        map.insert(key, 1);
        let key = vec![0, 0, 0, 1];
        map.insert(key, 0);
        let tables = vec![Table::new(map, 0)];
        let mut name_to_table = FxHashMap::default();
        name_to_table.insert(String::from("f4"), 0);

        let element_tables = TableData {
            name_to_constant,
            tables_1d,
            name_to_table_1d,
            tables_2d,
            name_to_table_2d,
            tables_3d,
            name_to_table_3d,
            tables,
            name_to_table,
        };

        let mut name_to_table_1d = FxHashMap::default();
        name_to_table_1d.insert(String::from("t1"), 0);
        let vector_tables = TableData {
            tables_1d: vec![Table1D::new(vec![vec![0, 1]])],
            name_to_table_1d,
            ..Default::default()
        };

        let mut set = Set::with_capacity(3);
        set.insert(0);
        set.insert(2);
        let default = Set::with_capacity(3);
        let tables_1d = vec![Table1D::new(vec![set, default.clone(), default])];
        let mut name_to_table_1d = FxHashMap::default();
        name_to_table_1d.insert(String::from("s1"), 0);
        let set_tables = TableData {
            tables_1d,
            name_to_table_1d,
            ..Default::default()
        };

        TableRegistry {
            element_tables,
            set_tables,
            vector_tables,
            ..Default::default()
        }
    }

    /// Builds the shared fixture state: set variables `{0, 2}` and `{0, 1}`,
    /// vector variable `[0, 2]`, element variable `1`, and element resource variable `2`.
    fn generate_state() -> State {
        let mut set1 = Set::with_capacity(3);
        set1.insert(0);
        set1.insert(2);
        let mut set2 = Set::with_capacity(3);
        set2.insert(0);
        set2.insert(1);
        State {
            signature_variables: SignatureVariables {
                set_variables: vec![set1, set2],
                vector_variables: vec![vec![0, 2]],
                element_variables: vec![1],
                ..Default::default()
            },
            resource_variables: ResourceVariables {
                element_variables: vec![2],
                ..Default::default()
            },
        }
    }

    #[test]
    fn vector_reference_eval() {
        let state = generate_state();
        let state_functions = StateFunctions::default();
        let mut function_cache = StateFunctionCache::new(&state_functions);
        let registry = generate_registry();
        let expression = VectorExpression::Reference(ReferenceExpression::Constant(vec![1, 2]));
        assert_eq!(
            expression.eval(
                &state,
                &mut function_cache,
                &state_functions,
                &registry
            ),
            vec![1, 2]
        );
    }

    #[test]
    fn vector_indices_eval() {
        let state = generate_state();
        let state_functions = StateFunctions::default();
        let mut function_cache = StateFunctionCache::new(&state_functions);
        let registry = generate_registry();
        let expression = VectorExpression::Indices(Box::new(VectorExpression::Reference(
            ReferenceExpression::Constant(vec![1, 2]),
        )));
        assert_eq!(
            expression.eval(
                &state,
                &mut function_cache,
                &state_functions,
                &registry
            ),
            vec![0, 1]
        );
    }

    #[test]
    fn vector_reverse_eval() {
        let state = generate_state();
        let state_functions = StateFunctions::default();
        let mut function_cache = StateFunctionCache::new(&state_functions);
        let registry = generate_registry();
        let expression = VectorExpression::Reverse(Box::new(VectorExpression::Reference(
            ReferenceExpression::Constant(vec![1, 2]),
        )));
        assert_eq!(
            expression.eval(
                &state,
                &mut function_cache,
                &state_functions,
                &registry
            ),
            vec![2, 1]
        );
    }

    #[test]
    fn vector_set_eval() {
        let state = generate_state();
        let state_functions = StateFunctions::default();
        let mut function_cache = StateFunctionCache::new(&state_functions);
        let registry = generate_registry();
        let expression = VectorExpression::Set(
            ElementExpression::Constant(3),
            Box::new(VectorExpression::Reference(ReferenceExpression::Constant(
                vec![1, 2],
            ))),
            ElementExpression::Constant(0),
        );
        assert_eq!(
            expression.eval(
                &state,
                &mut function_cache,
                &state_functions,
                &registry
            ),
            vec![3, 2]
        );
    }

    #[test]
    fn vector_push_eval() {
        let state = generate_state();
        let state_functions = StateFunctions::default();
        let mut function_cache = StateFunctionCache::new(&state_functions);
        let registry = generate_registry();
        let expression = VectorExpression::Push(
            ElementExpression::Constant(0),
            Box::new(VectorExpression::Reference(ReferenceExpression::Constant(
                vec![1, 2],
            ))),
        );
        assert_eq!(
            expression.eval(
                &state,
                &mut function_cache,
                &state_functions,
                &registry
            ),
            vec![1, 2, 0]
        );
    }

    #[test]
    fn vector_pop_eval() {
        let state = generate_state();
        let state_functions = StateFunctions::default();
        let mut function_cache = StateFunctionCache::new(&state_functions);
        let registry = generate_registry();
        let expression = VectorExpression::Pop(Box::new(VectorExpression::Reference(
            ReferenceExpression::Constant(vec![1, 2]),
        )));
        assert_eq!(
            expression.eval(
                &state,
                &mut function_cache,
                &state_functions,
                &registry
            ),
            vec![1]
        );
    }

    #[test]
    fn vector_from_set_eval() {
        let state = generate_state();
        let state_functions = StateFunctions::default();
        let mut function_cache = StateFunctionCache::new(&state_functions);
        let registry = generate_registry();
        let mut set = Set::with_capacity(3);
        set.insert(0);
        set.insert(1);
        let expression = VectorExpression::FromSet(Box::new(SetExpression::Reference(
            ReferenceExpression::Constant(set),
        )));
        assert_eq!(
            expression.eval(
                &state,
                &mut function_cache,
                &state_functions,
                &registry
            ),
            vec![0, 1]
        );
        let expression = VectorExpression::FromSet(Box::new(SetExpression::Reference(
            ReferenceExpression::Variable(0),
        )));
        assert_eq!(
            expression.eval(
                &state,
                &mut function_cache,
                &state_functions,
                &registry
            ),
            vec![0, 2]
        );
    }

    #[test]
    fn vector_if_eval() {
        let state = generate_state();
        let state_functions = StateFunctions::default();
        let mut function_cache = StateFunctionCache::new(&state_functions);
        let registry = generate_registry();
        let expression = VectorExpression::If(
            Box::new(Condition::Constant(true)),
            Box::new(VectorExpression::Reference(ReferenceExpression::Constant(
                vec![0, 1],
            ))),
            Box::new(VectorExpression::Reference(ReferenceExpression::Constant(
                vec![1, 0],
            ))),
        );
        assert_eq!(
            expression.eval(
                &state,
                &mut function_cache,
                &state_functions,
                &registry
            ),
            vec![0, 1]
        );
        let expression = VectorExpression::If(
            Box::new(Condition::Constant(false)),
            Box::new(VectorExpression::Reference(ReferenceExpression::Constant(
                vec![0, 1],
            ))),
            Box::new(VectorExpression::Reference(ReferenceExpression::Constant(
                vec![1, 0],
            ))),
        );
        assert_eq!(
            expression.eval(
                &state,
                &mut function_cache,
                &state_functions,
                &registry
            ),
            vec![1, 0]
        );
    }

    #[test]
    fn vector_reference_simplify() {
        let registry = generate_registry();
        let expression = VectorExpression::Reference(ReferenceExpression::Constant(vec![1, 2]));
        assert_eq!(expression.simplify(&registry), expression);
        let expression = VectorExpression::Reference(ReferenceExpression::Table(
            TableExpression::Table1D(0, ElementExpression::Constant(0)),
        ));
        assert_eq!(
            expression.simplify(&registry),
            VectorExpression::Reference(ReferenceExpression::Constant(vec![0, 1]))
        );
    }

    #[test]
    fn vector_indices_simplify() {
        let registry = generate_registry();

        let expression = VectorExpression::Indices(Box::new(VectorExpression::Reference(
            ReferenceExpression::Variable(0),
        )));
        assert_eq!(expression.simplify(&registry), expression);

        let expression = VectorExpression::Indices(Box::new(VectorExpression::Reference(
            ReferenceExpression::Constant(vec![1, 2]),
        )));
        assert_eq!(
            expression.simplify(&registry),
            VectorExpression::Reference(ReferenceExpression::Constant(vec![0, 1]))
        );
    }

    #[test]
    fn vector_reverse_simplify() {
        let registry = generate_registry();

        let expression = VectorExpression::Reverse(Box::new(VectorExpression::Reference(
            ReferenceExpression::Variable(0),
        )));
        assert_eq!(expression.simplify(&registry), expression);

        let expression = VectorExpression::Reverse(Box::new(VectorExpression::Reference(
            ReferenceExpression::Constant(vec![1, 2]),
        )));
        assert_eq!(
            expression.simplify(&registry),
            VectorExpression::Reference(ReferenceExpression::Constant(vec![2, 1]))
        );
    }

    #[test]
    fn vector_push_simplify() {
        let registry = generate_registry();
        let expression = VectorExpression::Push(
            ElementExpression::Constant(0),
            Box::new(VectorExpression::Reference(ReferenceExpression::Constant(
                vec![1, 2],
            ))),
        );
        assert_eq!(
            expression.simplify(&registry),
            VectorExpression::Reference(ReferenceExpression::Constant(vec![1, 2, 0]))
        );
        let expression = VectorExpression::Push(
            ElementExpression::Constant(0),
            Box::new(VectorExpression::Reference(ReferenceExpression::Variable(
                0,
            ))),
        );
        assert_eq!(expression.simplify(&registry), expression);
    }

    #[test]
    fn vector_pop_simplify() {
        let registry = generate_registry();
        let expression = VectorExpression::Pop(Box::new(VectorExpression::Reference(
            ReferenceExpression::Constant(vec![1, 2]),
        )));
        assert_eq!(
            expression.simplify(&registry),
            VectorExpression::Reference(ReferenceExpression::Constant(vec![1]))
        );
        let expression = VectorExpression::Pop(Box::new(VectorExpression::Reference(
            ReferenceExpression::Variable(0),
        )));
        assert_eq!(expression.simplify(&registry), expression);
    }

    #[test]
    fn vector_set_simplify() {
        let registry = generate_registry();
        let expression = VectorExpression::Set(
            ElementExpression::Constant(0),
            Box::new(VectorExpression::Reference(ReferenceExpression::Constant(
                vec![1, 2],
            ))),
            ElementExpression::Constant(0),
        );
        assert_eq!(
            expression.simplify(&registry),
            VectorExpression::Reference(ReferenceExpression::Constant(vec![0, 2]))
        );
        let expression = VectorExpression::Set(
            ElementExpression::Constant(0),
            Box::new(VectorExpression::Reference(ReferenceExpression::Variable(
                0,
            ))),
            ElementExpression::Variable(0),
        );
        assert_eq!(expression.simplify(&registry), expression);
    }

    #[test]
    fn vector_from_set_simplify() {
        let registry = generate_registry();
        let mut set = Set::with_capacity(3);
        set.insert(0);
        set.insert(1);
        let expression = VectorExpression::FromSet(Box::new(SetExpression::Reference(
            ReferenceExpression::Constant(set),
        )));
        assert_eq!(
            expression.simplify(&registry),
            VectorExpression::Reference(ReferenceExpression::Constant(vec![0, 1]))
        );
        let expression = VectorExpression::FromSet(Box::new(SetExpression::Reference(
            ReferenceExpression::Variable(0),
        )));
        assert_eq!(expression.simplify(&registry), expression);
    }

    #[test]
    fn vector_if_simplify() {
        let registry = generate_registry();
        let expression = VectorExpression::If(
            Box::new(Condition::Constant(true)),
            Box::new(VectorExpression::Reference(ReferenceExpression::Constant(
                vec![0, 1],
            ))),
            Box::new(VectorExpression::Reference(ReferenceExpression::Constant(
                vec![1, 0],
            ))),
        );
        assert_eq!(
            expression.simplify(&registry),
            VectorExpression::Reference(ReferenceExpression::Constant(vec![0, 1]))
        );
        let expression = VectorExpression::If(
            Box::new(Condition::Constant(false)),
            Box::new(VectorExpression::Reference(ReferenceExpression::Constant(
                vec![0, 1],
            ))),
            Box::new(VectorExpression::Reference(ReferenceExpression::Constant(
                vec![1, 0],
            ))),
        );
        assert_eq!(
            expression.simplify(&registry),
            VectorExpression::Reference(ReferenceExpression::Constant(vec![1, 0]))
        );
        let expression = VectorExpression::If(
            Box::new(Condition::ComparisonI(
                ComparisonOperator::Gt,
                Box::new(IntegerExpression::Variable(0)),
                Box::new(IntegerExpression::Constant(1)),
            )),
            Box::new(VectorExpression::Reference(ReferenceExpression::Constant(
                vec![0, 1],
            ))),
            Box::new(VectorExpression::Reference(ReferenceExpression::Constant(
                vec![1, 0],
            ))),
        );
        assert_eq!(expression.simplify(&registry), expression);
    }
}
632}