lemma/inversion/
derived.rs1use crate::planning::semantics::{
8 ArithmeticComputation, ComparisonComputation, FactPath, LiteralValue, MathematicalComputation,
9 NegationType, RulePath, SemanticConversionTarget, VetoExpression,
10};
11use serde::{Deserialize, Serialize};
12use std::collections::HashSet;
13use std::hash::{Hash, Hasher};
14use std::sync::Arc;
15
16#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
18pub struct DerivedExpression {
19 pub kind: DerivedExpressionKind,
20}
21
22impl DerivedExpression {
23 pub fn collect_fact_paths(&self, facts: &mut HashSet<FactPath>) {
24 self.kind.collect_fact_paths(facts);
25 }
26
27 pub fn semantic_hash<H: Hasher>(&self, state: &mut H) {
28 self.kind.semantic_hash(state);
29 }
30}
31
32#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
33#[serde(rename_all = "snake_case")]
34pub enum DerivedExpressionKind {
35 Literal(Box<LiteralValue>),
37 FactPath(FactPath),
38 RulePath(RulePath),
39 LogicalAnd(Arc<DerivedExpression>, Arc<DerivedExpression>),
40 LogicalOr(Arc<DerivedExpression>, Arc<DerivedExpression>),
41 Arithmetic(
42 Arc<DerivedExpression>,
43 ArithmeticComputation,
44 Arc<DerivedExpression>,
45 ),
46 Comparison(
47 Arc<DerivedExpression>,
48 ComparisonComputation,
49 Arc<DerivedExpression>,
50 ),
51 UnitConversion(Arc<DerivedExpression>, SemanticConversionTarget),
52 LogicalNegation(Arc<DerivedExpression>, NegationType),
53 MathematicalComputation(MathematicalComputation, Arc<DerivedExpression>),
54 Veto(VetoExpression),
55}
56
57impl DerivedExpressionKind {
58 fn collect_fact_paths(&self, facts: &mut HashSet<FactPath>) {
59 match self {
60 DerivedExpressionKind::FactPath(fp) => {
61 facts.insert(fp.clone());
62 }
63 DerivedExpressionKind::LogicalAnd(left, right)
64 | DerivedExpressionKind::LogicalOr(left, right)
65 | DerivedExpressionKind::Arithmetic(left, _, right)
66 | DerivedExpressionKind::Comparison(left, _, right) => {
67 left.collect_fact_paths(facts);
68 right.collect_fact_paths(facts);
69 }
70 DerivedExpressionKind::UnitConversion(inner, _)
71 | DerivedExpressionKind::LogicalNegation(inner, _)
72 | DerivedExpressionKind::MathematicalComputation(_, inner) => {
73 inner.collect_fact_paths(facts);
74 }
75 DerivedExpressionKind::Literal(_)
76 | DerivedExpressionKind::RulePath(_)
77 | DerivedExpressionKind::Veto(_) => {}
78 }
79 }
80
81 fn semantic_hash<H: Hasher>(&self, state: &mut H) {
82 std::mem::discriminant(self).hash(state);
83 match self {
84 DerivedExpressionKind::Literal(lit) => lit.hash(state),
85 DerivedExpressionKind::FactPath(fp) => fp.hash(state),
86 DerivedExpressionKind::RulePath(rp) => rp.hash(state),
87 DerivedExpressionKind::LogicalAnd(left, right)
88 | DerivedExpressionKind::LogicalOr(left, right) => {
89 left.semantic_hash(state);
90 right.semantic_hash(state);
91 }
92 DerivedExpressionKind::Arithmetic(left, op, right) => {
93 left.semantic_hash(state);
94 op.hash(state);
95 right.semantic_hash(state);
96 }
97 DerivedExpressionKind::Comparison(left, op, right) => {
98 left.semantic_hash(state);
99 op.hash(state);
100 right.semantic_hash(state);
101 }
102 DerivedExpressionKind::UnitConversion(expr, target) => {
103 expr.semantic_hash(state);
104 target.hash(state);
105 }
106 DerivedExpressionKind::LogicalNegation(expr, neg_type) => {
107 expr.semantic_hash(state);
108 neg_type.hash(state);
109 }
110 DerivedExpressionKind::MathematicalComputation(op, expr) => {
111 op.hash(state);
112 expr.semantic_hash(state);
113 }
114 DerivedExpressionKind::Veto(v) => v.message.hash(state),
115 }
116 }
117}
118
119impl Eq for DerivedExpression {}
120impl Hash for DerivedExpression {
121 fn hash<H: Hasher>(&self, state: &mut H) {
122 self.semantic_hash(state);
123 }
124}
125
126impl Eq for DerivedExpressionKind {}
127impl Hash for DerivedExpressionKind {
128 fn hash<H: Hasher>(&self, state: &mut H) {
129 self.semantic_hash(state);
130 }
131}