Skip to main content

lemma/inversion/
derived.rs

//! Derived expressions for inversion.
//!
//! Expressions created during solving have no source location: they are
//! derived from plan expressions rather than parsed from user input.
//! Strict separation: `Expression` (planning) has a source; `DerivedExpression` (inversion) does not.
6
7use crate::planning::semantics::{
8    ArithmeticComputation, ComparisonComputation, FactPath, LiteralValue, MathematicalComputation,
9    NegationType, RulePath, SemanticConversionTarget, VetoExpression,
10};
11use serde::{Deserialize, Serialize};
12use std::collections::HashSet;
13use std::hash::{Hash, Hasher};
14use std::sync::Arc;
15
16/// Expression derived during inversion/solving. No source location.
17#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
18pub struct DerivedExpression {
19    pub kind: DerivedExpressionKind,
20}
21
22impl DerivedExpression {
23    pub fn collect_fact_paths(&self, facts: &mut HashSet<FactPath>) {
24        self.kind.collect_fact_paths(facts);
25    }
26
27    pub fn semantic_hash<H: Hasher>(&self, state: &mut H) {
28        self.kind.semantic_hash(state);
29    }
30}
31
32#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
33#[serde(rename_all = "snake_case")]
34pub enum DerivedExpressionKind {
35    /// Boxed to keep enum size small (LiteralValue is large)
36    Literal(Box<LiteralValue>),
37    FactPath(FactPath),
38    RulePath(RulePath),
39    LogicalAnd(Arc<DerivedExpression>, Arc<DerivedExpression>),
40    LogicalOr(Arc<DerivedExpression>, Arc<DerivedExpression>),
41    Arithmetic(
42        Arc<DerivedExpression>,
43        ArithmeticComputation,
44        Arc<DerivedExpression>,
45    ),
46    Comparison(
47        Arc<DerivedExpression>,
48        ComparisonComputation,
49        Arc<DerivedExpression>,
50    ),
51    UnitConversion(Arc<DerivedExpression>, SemanticConversionTarget),
52    LogicalNegation(Arc<DerivedExpression>, NegationType),
53    MathematicalComputation(MathematicalComputation, Arc<DerivedExpression>),
54    Veto(VetoExpression),
55}
56
57impl DerivedExpressionKind {
58    fn collect_fact_paths(&self, facts: &mut HashSet<FactPath>) {
59        match self {
60            DerivedExpressionKind::FactPath(fp) => {
61                facts.insert(fp.clone());
62            }
63            DerivedExpressionKind::LogicalAnd(left, right)
64            | DerivedExpressionKind::LogicalOr(left, right)
65            | DerivedExpressionKind::Arithmetic(left, _, right)
66            | DerivedExpressionKind::Comparison(left, _, right) => {
67                left.collect_fact_paths(facts);
68                right.collect_fact_paths(facts);
69            }
70            DerivedExpressionKind::UnitConversion(inner, _)
71            | DerivedExpressionKind::LogicalNegation(inner, _)
72            | DerivedExpressionKind::MathematicalComputation(_, inner) => {
73                inner.collect_fact_paths(facts);
74            }
75            DerivedExpressionKind::Literal(_)
76            | DerivedExpressionKind::RulePath(_)
77            | DerivedExpressionKind::Veto(_) => {}
78        }
79    }
80
81    fn semantic_hash<H: Hasher>(&self, state: &mut H) {
82        std::mem::discriminant(self).hash(state);
83        match self {
84            DerivedExpressionKind::Literal(lit) => lit.hash(state),
85            DerivedExpressionKind::FactPath(fp) => fp.hash(state),
86            DerivedExpressionKind::RulePath(rp) => rp.hash(state),
87            DerivedExpressionKind::LogicalAnd(left, right)
88            | DerivedExpressionKind::LogicalOr(left, right) => {
89                left.semantic_hash(state);
90                right.semantic_hash(state);
91            }
92            DerivedExpressionKind::Arithmetic(left, op, right) => {
93                left.semantic_hash(state);
94                op.hash(state);
95                right.semantic_hash(state);
96            }
97            DerivedExpressionKind::Comparison(left, op, right) => {
98                left.semantic_hash(state);
99                op.hash(state);
100                right.semantic_hash(state);
101            }
102            DerivedExpressionKind::UnitConversion(expr, target) => {
103                expr.semantic_hash(state);
104                target.hash(state);
105            }
106            DerivedExpressionKind::LogicalNegation(expr, neg_type) => {
107                expr.semantic_hash(state);
108                neg_type.hash(state);
109            }
110            DerivedExpressionKind::MathematicalComputation(op, expr) => {
111                op.hash(state);
112                expr.semantic_hash(state);
113            }
114            DerivedExpressionKind::Veto(v) => v.message.hash(state),
115        }
116    }
117}
118
119impl Eq for DerivedExpression {}
120impl Hash for DerivedExpression {
121    fn hash<H: Hasher>(&self, state: &mut H) {
122        self.semantic_hash(state);
123    }
124}
125
126impl Eq for DerivedExpressionKind {}
127impl Hash for DerivedExpressionKind {
128    fn hash<H: Hasher>(&self, state: &mut H) {
129        self.semantic_hash(state);
130    }
131}