lemma/inversion/
shape.rs

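//! Shape representation for inversion results: solution branches, value
//! domains and range bounds, together with their `Display` and `Serialize`
//! implementations.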
2
3use crate::{Expression, FactReference, LiteralValue};
4use serde::ser::{Serialize, SerializeMap, SerializeStruct, Serializer};
5use std::fmt;
6
/// A shape representing the solution space for an inversion query.
///
/// Contains branches, each representing one solution: a condition paired with
/// the outcome produced when that condition holds. Callers are expected to
/// supply at least one branch, but the type itself does not enforce this.
#[derive(Debug, Clone, PartialEq)]
pub struct Shape {
    /// Solution branches - each branch is a valid solution.
    pub branches: Vec<ShapeBranch>,

    /// Variables that are not fully constrained (free to vary).
    ///
    /// An empty list means the shape is fully constrained.
    pub free_variables: Vec<FactReference>,
}
19
/// A single branch in a shape - represents one solution.
///
/// Rendered by `Display` as `if <condition> then <outcome>`.
#[derive(Debug, Clone, PartialEq)]
pub struct ShapeBranch {
    /// Condition under which this branch applies.
    pub condition: Expression,

    /// Outcome when the condition is met (a value expression or a veto).
    pub outcome: BranchOutcome,
}
29
/// Outcome of a piecewise branch.
#[derive(Debug, Clone, PartialEq)]
pub enum BranchOutcome {
    /// Produces a value defined by an expression.
    Value(Expression),
    /// Produces a veto, optionally carrying a human-readable message.
    Veto(Option<String>),
}
38
/// Domain specification for valid values.
///
/// Describes a set of values as an interval, a union of sub-domains, an
/// explicit list of values, the complement of another domain, or "anything".
#[derive(Debug, Clone, PartialEq)]
pub enum Domain {
    /// A single continuous range between `min` and `max`.
    Range { min: Bound, max: Bound },

    /// Union of several sub-domains.
    ///
    /// NOTE(review): intended to hold disjoint ranges, but neither
    /// disjointness nor the "ranges only" shape is enforced by the type —
    /// confirm what producers guarantee.
    Union(Vec<Domain>),

    /// Specific enumerated values only.
    Enumeration(Vec<LiteralValue>),

    /// Every value that is not in the inner domain.
    Complement(Box<Domain>),

    /// Any value (no constraints).
    Unconstrained,
}
57
/// Bound specification for one end of a range (used by `Domain::Range`).
#[derive(Debug, Clone, PartialEq)]
pub enum Bound {
    /// The endpoint value belongs to the range (interval notation `[v` / `v]`).
    Inclusive(LiteralValue),

    /// The endpoint value is excluded from the range (interval notation `(v` / `v)`).
    Exclusive(LiteralValue),

    /// No limit on this side (-∞ as a minimum, +∞ as a maximum).
    Unbounded,
}
70
71impl Shape {
72    /// Create a new shape
73    pub fn new(branches: Vec<ShapeBranch>, free_variables: Vec<FactReference>) -> Self {
74        Shape {
75            branches,
76            free_variables,
77        }
78    }
79
80    /// Check if this shape has any free variables
81    pub fn is_fully_constrained(&self) -> bool {
82        self.free_variables.is_empty()
83    }
84}
85
86// ---------------------------
87// Display implementations
88// ---------------------------
89
90impl fmt::Display for Shape {
91    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
92        if self.branches.len() == 1 {
93            write!(f, "{}", self.branches[0])
94        } else {
95            writeln!(f, "shape with {} branches:", self.branches.len())?;
96            for (i, br) in self.branches.iter().enumerate() {
97                writeln!(f, "  {}. {}", i + 1, br)?;
98            }
99            Ok(())
100        }
101    }
102}
103
104impl fmt::Display for ShapeBranch {
105    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
106        write!(f, "if {} then {}", self.condition, self.outcome)
107    }
108}
109
110impl fmt::Display for BranchOutcome {
111    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
112        match self {
113            BranchOutcome::Value(expr) => write!(f, "{}", expr),
114            BranchOutcome::Veto(Some(msg)) => write!(f, "veto \"{}\"", msg),
115            BranchOutcome::Veto(None) => write!(f, "veto"),
116        }
117    }
118}
119
120impl fmt::Display for Domain {
121    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
122        match self {
123            Domain::Unconstrained => write!(f, "any"),
124            Domain::Enumeration(vals) => {
125                write!(f, "{{")?;
126                for (i, v) in vals.iter().enumerate() {
127                    if i > 0 {
128                        write!(f, ", ")?;
129                    }
130                    write!(f, "{}", v)?;
131                }
132                write!(f, "}}")
133            }
134            Domain::Range { min, max } => {
135                // Represent ranges in mathematical interval notation: (a, b], [a, +∞), etc.
136                let (l_bracket, r_bracket) = match (min, max) {
137                    (Bound::Inclusive(_), Bound::Inclusive(_)) => ('[', ']'),
138                    (Bound::Inclusive(_), Bound::Exclusive(_)) => ('[', ')'),
139                    (Bound::Exclusive(_), Bound::Inclusive(_)) => ('(', ']'),
140                    (Bound::Exclusive(_), Bound::Exclusive(_)) => ('(', ')'),
141                    (Bound::Unbounded, Bound::Inclusive(_)) => ('(', ']'),
142                    (Bound::Unbounded, Bound::Exclusive(_)) => ('(', ')'),
143                    (Bound::Inclusive(_), Bound::Unbounded) => ('[', ')'),
144                    (Bound::Exclusive(_), Bound::Unbounded) => ('(', ')'),
145                    (Bound::Unbounded, Bound::Unbounded) => ('(', ')'),
146                };
147
148                let min_str = match min {
149                    Bound::Unbounded => "-∞".to_string(),
150                    Bound::Inclusive(v) | Bound::Exclusive(v) => v.to_string(),
151                };
152                let max_str = match max {
153                    Bound::Unbounded => "+∞".to_string(),
154                    Bound::Inclusive(v) | Bound::Exclusive(v) => v.to_string(),
155                };
156                write!(f, "{}{}, {}{}", l_bracket, min_str, max_str, r_bracket)
157            }
158            Domain::Union(parts) => {
159                for (i, p) in parts.iter().enumerate() {
160                    if i > 0 {
161                        write!(f, " ∪ ")?;
162                    }
163                    write!(f, "{}", p)?;
164                }
165                Ok(())
166            }
167            Domain::Complement(inner) => write!(f, "not ({})", inner),
168        }
169    }
170}
171
172impl fmt::Display for Bound {
173    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
174        match self {
175            Bound::Unbounded => write!(f, "∞"),
176            Bound::Inclusive(v) => write!(f, "[{}", v),
177            Bound::Exclusive(v) => write!(f, "({}", v),
178        }
179    }
180}
181
182// ---------------------------
183// Serialize implementations
184// ---------------------------
185
186impl Serialize for Shape {
187    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
188    where
189        S: Serializer,
190    {
191        let mut st = serializer.serialize_struct("shape", 2)?;
192        st.serialize_field("branches", &self.branches)?;
193        st.serialize_field("free_variables", &self.free_variables)?;
194        st.end()
195    }
196}
197
198impl Serialize for FactReference {
199    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
200    where
201        S: Serializer,
202    {
203        serializer.serialize_str(&self.reference.join("."))
204    }
205}
206
207impl Serialize for ShapeBranch {
208    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
209    where
210        S: Serializer,
211    {
212        let mut st = serializer.serialize_struct("shape_branch", 2)?;
213        st.serialize_field("condition", &self.condition.to_string())?;
214        st.serialize_field("outcome", &self.outcome)?;
215        st.end()
216    }
217}
218
219impl Serialize for BranchOutcome {
220    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
221    where
222        S: Serializer,
223    {
224        match self {
225            BranchOutcome::Value(expr) => {
226                let mut st = serializer.serialize_map(Some(2))?;
227                st.serialize_entry("type", "value")?;
228                st.serialize_entry("expression", &expr.to_string())?;
229                st.end()
230            }
231            BranchOutcome::Veto(msg) => {
232                let mut st = serializer.serialize_map(Some(2))?;
233                st.serialize_entry("type", "veto")?;
234                if let Some(m) = msg {
235                    st.serialize_entry("message", m)?;
236                }
237                st.end()
238            }
239        }
240    }
241}
242
243impl Serialize for Domain {
244    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
245    where
246        S: Serializer,
247    {
248        match self {
249            Domain::Unconstrained => {
250                let mut st = serializer.serialize_struct("domain", 1)?;
251                st.serialize_field("type", "unconstrained")?;
252                st.end()
253            }
254            Domain::Enumeration(vals) => {
255                let mut st = serializer.serialize_struct("domain", 2)?;
256                st.serialize_field("type", "enumeration")?;
257                st.serialize_field("values", vals)?;
258                st.end()
259            }
260            Domain::Range { min, max } => {
261                let mut st = serializer.serialize_struct("domain", 3)?;
262                st.serialize_field("type", "range")?;
263                st.serialize_field("min", min)?;
264                st.serialize_field("max", max)?;
265                st.end()
266            }
267            Domain::Union(parts) => {
268                let mut st = serializer.serialize_struct("domain", 2)?;
269                st.serialize_field("type", "union")?;
270                st.serialize_field("parts", parts)?;
271                st.end()
272            }
273            Domain::Complement(inner) => {
274                let mut st = serializer.serialize_struct("domain", 2)?;
275                st.serialize_field("type", "complement")?;
276                st.serialize_field("inner", inner)?;
277                st.end()
278            }
279        }
280    }
281}
282
283impl Serialize for Bound {
284    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
285    where
286        S: Serializer,
287    {
288        match self {
289            Bound::Unbounded => {
290                let mut st = serializer.serialize_struct("bound", 1)?;
291                st.serialize_field("type", "unbounded")?;
292                st.end()
293            }
294            Bound::Inclusive(v) => {
295                let mut st = serializer.serialize_struct("bound", 2)?;
296                st.serialize_field("type", "inclusive")?;
297                st.serialize_field("value", v)?;
298                st.end()
299            }
300            Bound::Exclusive(v) => {
301                let mut st = serializer.serialize_struct("bound", 2)?;
302                st.serialize_field("type", "exclusive")?;
303                st.serialize_field("value", v)?;
304                st.end()
305            }
306        }
307    }
308}