// lemma/inversion/derived.rs
use crate::planning::semantics::{
8 ArithmeticComputation, ComparisonComputation, FactPath, LiteralValue, MathematicalComputation,
9 NegationType, RulePath, SemanticConversionTarget, VetoExpression,
10};
11use serde::{Deserialize, Serialize};
12use std::collections::HashSet;
13use std::hash::{Hash, Hasher};
14use std::sync::Arc;
15
16#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
18pub struct DerivedExpression {
19 pub kind: DerivedExpressionKind,
20}
21
22impl DerivedExpression {
23 pub fn collect_fact_paths(&self, facts: &mut HashSet<FactPath>) {
24 self.kind.collect_fact_paths(facts);
25 }
26
27 pub fn semantic_hash<H: Hasher>(&self, state: &mut H) {
28 self.kind.semantic_hash(state);
29 }
30}
31
/// The shape of a [`DerivedExpression`] node.
///
/// Variants are (de)serialized with snake_case names. Child expressions are
/// held behind `Arc`, so cloning a node shares its subtrees instead of
/// deep-copying them.
///
/// NOTE: `semantic_hash` mixes in `std::mem::discriminant`, which depends on
/// declaration order; reordering variants changes the resulting hash values.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum DerivedExpressionKind {
    /// A literal value.
    Literal(Box<LiteralValue>),
    /// A reference to a fact; the only variant collected by `collect_fact_paths`.
    FactPath(FactPath),
    /// A reference to a rule (not descended into when collecting fact paths).
    RulePath(RulePath),
    /// Logical conjunction of two sub-expressions.
    LogicalAnd(Arc<DerivedExpression>, Arc<DerivedExpression>),
    /// `left <op> right` with an arithmetic operator.
    Arithmetic(
        Arc<DerivedExpression>,
        ArithmeticComputation,
        Arc<DerivedExpression>,
    ),
    /// `left <op> right` with a comparison operator.
    Comparison(
        Arc<DerivedExpression>,
        ComparisonComputation,
        Arc<DerivedExpression>,
    ),
    /// Unit conversion of the inner expression to a `SemanticConversionTarget`.
    UnitConversion(Arc<DerivedExpression>, SemanticConversionTarget),
    /// Negation of the inner expression; `NegationType` selects the kind of negation.
    LogicalNegation(Arc<DerivedExpression>, NegationType),
    /// A mathematical operation applied to the inner expression.
    MathematicalComputation(MathematicalComputation, Arc<DerivedExpression>),
    /// A veto; only its `message` participates in `semantic_hash`.
    Veto(VetoExpression),
}
55
56impl DerivedExpressionKind {
57 fn collect_fact_paths(&self, facts: &mut HashSet<FactPath>) {
58 match self {
59 DerivedExpressionKind::FactPath(fp) => {
60 facts.insert(fp.clone());
61 }
62 DerivedExpressionKind::LogicalAnd(left, right)
63 | DerivedExpressionKind::Arithmetic(left, _, right)
64 | DerivedExpressionKind::Comparison(left, _, right) => {
65 left.collect_fact_paths(facts);
66 right.collect_fact_paths(facts);
67 }
68 DerivedExpressionKind::UnitConversion(inner, _)
69 | DerivedExpressionKind::LogicalNegation(inner, _)
70 | DerivedExpressionKind::MathematicalComputation(_, inner) => {
71 inner.collect_fact_paths(facts);
72 }
73 DerivedExpressionKind::Literal(_)
74 | DerivedExpressionKind::RulePath(_)
75 | DerivedExpressionKind::Veto(_) => {}
76 }
77 }
78
79 fn semantic_hash<H: Hasher>(&self, state: &mut H) {
80 std::mem::discriminant(self).hash(state);
81 match self {
82 DerivedExpressionKind::Literal(lit) => lit.hash(state),
83 DerivedExpressionKind::FactPath(fp) => fp.hash(state),
84 DerivedExpressionKind::RulePath(rp) => rp.hash(state),
85 DerivedExpressionKind::LogicalAnd(left, right) => {
86 left.semantic_hash(state);
87 right.semantic_hash(state);
88 }
89 DerivedExpressionKind::Arithmetic(left, op, right) => {
90 left.semantic_hash(state);
91 op.hash(state);
92 right.semantic_hash(state);
93 }
94 DerivedExpressionKind::Comparison(left, op, right) => {
95 left.semantic_hash(state);
96 op.hash(state);
97 right.semantic_hash(state);
98 }
99 DerivedExpressionKind::UnitConversion(expr, target) => {
100 expr.semantic_hash(state);
101 target.hash(state);
102 }
103 DerivedExpressionKind::LogicalNegation(expr, neg_type) => {
104 expr.semantic_hash(state);
105 neg_type.hash(state);
106 }
107 DerivedExpressionKind::MathematicalComputation(op, expr) => {
108 op.hash(state);
109 expr.semantic_hash(state);
110 }
111 DerivedExpressionKind::Veto(v) => v.message.hash(state),
112 }
113 }
114}
115
116impl Eq for DerivedExpression {}
117impl Hash for DerivedExpression {
118 fn hash<H: Hasher>(&self, state: &mut H) {
119 self.semantic_hash(state);
120 }
121}
122
123impl Eq for DerivedExpressionKind {}
124impl Hash for DerivedExpressionKind {
125 fn hash<H: Hasher>(&self, state: &mut H) {
126 self.semantic_hash(state);
127 }
128}