Skip to main content

lemma/inversion/
derived.rs

1//! Derived expressions for inversion
2//!
3//! Expressions created during solving have no source location.
4//! They are derived from plan expressions, not parsed from user input.
5//! Strong separation: Expression (planning) has source; DerivedExpression (inversion) does not.
6
7use crate::planning::semantics::{
8    ArithmeticComputation, ComparisonComputation, FactPath, LiteralValue, MathematicalComputation,
9    NegationType, RulePath, SemanticConversionTarget, VetoExpression,
10};
11use serde::{Deserialize, Serialize};
12use std::collections::HashSet;
13use std::hash::{Hash, Hasher};
14use std::sync::Arc;
15
16/// Expression derived during inversion/solving. No source location.
17#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
18pub struct DerivedExpression {
19    pub kind: DerivedExpressionKind,
20}
21
22impl DerivedExpression {
23    pub fn collect_fact_paths(&self, facts: &mut HashSet<FactPath>) {
24        self.kind.collect_fact_paths(facts);
25    }
26
27    pub fn semantic_hash<H: Hasher>(&self, state: &mut H) {
28        self.kind.semantic_hash(state);
29    }
30}
31
32#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
33#[serde(rename_all = "snake_case")]
34pub enum DerivedExpressionKind {
35    /// Boxed to keep enum size small (LiteralValue is large)
36    Literal(Box<LiteralValue>),
37    FactPath(FactPath),
38    RulePath(RulePath),
39    LogicalAnd(Arc<DerivedExpression>, Arc<DerivedExpression>),
40    Arithmetic(
41        Arc<DerivedExpression>,
42        ArithmeticComputation,
43        Arc<DerivedExpression>,
44    ),
45    Comparison(
46        Arc<DerivedExpression>,
47        ComparisonComputation,
48        Arc<DerivedExpression>,
49    ),
50    UnitConversion(Arc<DerivedExpression>, SemanticConversionTarget),
51    LogicalNegation(Arc<DerivedExpression>, NegationType),
52    MathematicalComputation(MathematicalComputation, Arc<DerivedExpression>),
53    Veto(VetoExpression),
54}
55
56impl DerivedExpressionKind {
57    fn collect_fact_paths(&self, facts: &mut HashSet<FactPath>) {
58        match self {
59            DerivedExpressionKind::FactPath(fp) => {
60                facts.insert(fp.clone());
61            }
62            DerivedExpressionKind::LogicalAnd(left, right)
63            | DerivedExpressionKind::Arithmetic(left, _, right)
64            | DerivedExpressionKind::Comparison(left, _, right) => {
65                left.collect_fact_paths(facts);
66                right.collect_fact_paths(facts);
67            }
68            DerivedExpressionKind::UnitConversion(inner, _)
69            | DerivedExpressionKind::LogicalNegation(inner, _)
70            | DerivedExpressionKind::MathematicalComputation(_, inner) => {
71                inner.collect_fact_paths(facts);
72            }
73            DerivedExpressionKind::Literal(_)
74            | DerivedExpressionKind::RulePath(_)
75            | DerivedExpressionKind::Veto(_) => {}
76        }
77    }
78
79    fn semantic_hash<H: Hasher>(&self, state: &mut H) {
80        std::mem::discriminant(self).hash(state);
81        match self {
82            DerivedExpressionKind::Literal(lit) => lit.hash(state),
83            DerivedExpressionKind::FactPath(fp) => fp.hash(state),
84            DerivedExpressionKind::RulePath(rp) => rp.hash(state),
85            DerivedExpressionKind::LogicalAnd(left, right) => {
86                left.semantic_hash(state);
87                right.semantic_hash(state);
88            }
89            DerivedExpressionKind::Arithmetic(left, op, right) => {
90                left.semantic_hash(state);
91                op.hash(state);
92                right.semantic_hash(state);
93            }
94            DerivedExpressionKind::Comparison(left, op, right) => {
95                left.semantic_hash(state);
96                op.hash(state);
97                right.semantic_hash(state);
98            }
99            DerivedExpressionKind::UnitConversion(expr, target) => {
100                expr.semantic_hash(state);
101                target.hash(state);
102            }
103            DerivedExpressionKind::LogicalNegation(expr, neg_type) => {
104                expr.semantic_hash(state);
105                neg_type.hash(state);
106            }
107            DerivedExpressionKind::MathematicalComputation(op, expr) => {
108                op.hash(state);
109                expr.semantic_hash(state);
110            }
111            DerivedExpressionKind::Veto(v) => v.message.hash(state),
112        }
113    }
114}
115
116impl Eq for DerivedExpression {}
117impl Hash for DerivedExpression {
118    fn hash<H: Hasher>(&self, state: &mut H) {
119        self.semantic_hash(state);
120    }
121}
122
123impl Eq for DerivedExpressionKind {}
124impl Hash for DerivedExpressionKind {
125    fn hash<H: Hasher>(&self, state: &mut H) {
126        self.semantic_hash(state);
127    }
128}