1use crate::premise::{PremiseId, PremiseTracker};
13use crate::proof::{Proof, ProofNodeId, ProofStep};
14use rustc_hash::{FxHashMap, FxHashSet};
15use std::fmt;
16
17#[derive(Debug, Clone)]
19pub struct Partition {
20 a_premises: FxHashSet<PremiseId>,
22 b_premises: FxHashSet<PremiseId>,
24}
25
26impl Partition {
27 #[must_use]
29 pub fn new(
30 a_premises: impl IntoIterator<Item = PremiseId>,
31 b_premises: impl IntoIterator<Item = PremiseId>,
32 ) -> Self {
33 Self {
34 a_premises: a_premises.into_iter().collect(),
35 b_premises: b_premises.into_iter().collect(),
36 }
37 }
38
39 #[must_use]
41 pub fn is_a_premise(&self, premise: PremiseId) -> bool {
42 self.a_premises.contains(&premise)
43 }
44
45 #[must_use]
47 pub fn is_b_premise(&self, premise: PremiseId) -> bool {
48 self.b_premises.contains(&premise)
49 }
50
51 #[must_use]
53 pub fn a_premises(&self) -> &FxHashSet<PremiseId> {
54 &self.a_premises
55 }
56
57 #[must_use]
59 pub fn b_premises(&self) -> &FxHashSet<PremiseId> {
60 &self.b_premises
61 }
62}
63
/// Color of a proof node: which partition side(s) the node depends on.
///
/// For inference nodes the color is the union of the premises' colors
/// (see `InterpolantExtractor::compute_colors`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Color {
    /// Depends on the `A` side only.
    A,
    /// Depends on the `B` side only.
    B,
    /// Depends on both sides.
    AB,
}

impl fmt::Display for Color {
    /// Writes the bare variant name (`A`, `B` or `AB`); width/padding flags
    /// are not applied.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let label = match self {
            Self::A => "A",
            Self::B => "B",
            Self::AB => "AB",
        };
        f.write_str(label)
    }
}
84
85#[derive(Debug, Clone)]
87pub struct Interpolant {
88 pub formula: String,
90 pub symbols: FxHashSet<String>,
92}
93
94impl Interpolant {
95 #[must_use]
97 pub fn new(formula: impl Into<String>) -> Self {
98 let formula = formula.into();
99 let symbols = extract_symbols(&formula);
100 Self { formula, symbols }
101 }
102
103 #[must_use]
105 pub fn is_valid(&self, a_symbols: &FxHashSet<String>, b_symbols: &FxHashSet<String>) -> bool {
106 let common: FxHashSet<String> = a_symbols.intersection(b_symbols).cloned().collect();
107 self.symbols.is_subset(&common)
108 }
109}
110
111impl fmt::Display for Interpolant {
112 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
113 write!(f, "{}", self.formula)
114 }
115}
116
117#[derive(Debug)]
122pub struct InterpolantExtractor {
123 #[allow(dead_code)]
125 partition: Partition,
126 #[allow(dead_code)]
128 premise_tracker: PremiseTracker,
129 colors: FxHashMap<ProofNodeId, Color>,
131 partial_interpolants: FxHashMap<ProofNodeId, String>,
133}
134
135impl InterpolantExtractor {
136 #[must_use]
138 pub fn new(partition: Partition, premise_tracker: PremiseTracker) -> Self {
139 Self {
140 partition,
141 premise_tracker,
142 colors: FxHashMap::default(),
143 partial_interpolants: FxHashMap::default(),
144 }
145 }
146
147 pub fn extract(&mut self, proof: &Proof) -> Result<Interpolant, String> {
151 let root = proof
152 .root()
153 .ok_or_else(|| "Proof has no root".to_string())?;
154
155 self.compute_colors(proof, root)?;
157
158 self.build_interpolants(proof, root)?;
160
161 let root_color = self
163 .colors
164 .get(&root)
165 .copied()
166 .ok_or_else(|| "Root node has no color".to_string())?;
167
168 if root_color != Color::AB {
169 return Err(format!(
170 "Root should have color AB, but has color {}",
171 root_color
172 ));
173 }
174
175 let interpolant_formula = self
177 .partial_interpolants
178 .get(&root)
179 .ok_or_else(|| "No interpolant at root".to_string())?
180 .clone();
181
182 Ok(Interpolant::new(interpolant_formula))
183 }
184
185 fn compute_colors(&mut self, proof: &Proof, node_id: ProofNodeId) -> Result<Color, String> {
187 if let Some(&color) = self.colors.get(&node_id) {
189 return Ok(color);
190 }
191
192 let node = proof
193 .get_node(node_id)
194 .ok_or_else(|| format!("Node {} not found", node_id))?;
195
196 let color = match &node.step {
197 ProofStep::Axiom { .. } => {
198 Color::A
201 }
202 ProofStep::Inference { premises, .. } => {
203 let mut has_a = false;
205 let mut has_b = false;
206
207 for &premise_id in premises {
208 let premise_color = self.compute_colors(proof, premise_id)?;
209 match premise_color {
210 Color::A => has_a = true,
211 Color::B => has_b = true,
212 Color::AB => {
213 has_a = true;
214 has_b = true;
215 }
216 }
217 }
218
219 if has_a && has_b {
220 Color::AB
221 } else if has_a {
222 Color::A
223 } else if has_b {
224 Color::B
225 } else {
226 Color::A
228 }
229 }
230 };
231
232 self.colors.insert(node_id, color);
233 Ok(color)
234 }
235
236 fn build_interpolants(&mut self, proof: &Proof, node_id: ProofNodeId) -> Result<(), String> {
238 if self.partial_interpolants.contains_key(&node_id) {
240 return Ok(());
241 }
242
243 let node = proof
244 .get_node(node_id)
245 .ok_or_else(|| format!("Node {} not found", node_id))?;
246
247 let color = *self
248 .colors
249 .get(&node_id)
250 .ok_or_else(|| format!("Node {} has no color", node_id))?;
251
252 let interpolant = match &node.step {
253 ProofStep::Axiom { conclusion } => {
254 match color {
256 Color::A => "true".to_string(),
257 Color::B => conclusion.clone(),
258 Color::AB => {
259 "true".to_string()
261 }
262 }
263 }
264 ProofStep::Inference {
265 rule,
266 premises,
267 conclusion,
268 ..
269 } => {
270 for &premise_id in premises {
272 self.build_interpolants(proof, premise_id)?;
273 }
274
275 self.combine_interpolants(rule, premises, conclusion, color)?
277 }
278 };
279
280 self.partial_interpolants.insert(node_id, interpolant);
281 Ok(())
282 }
283
284 fn combine_interpolants(
286 &self,
287 rule: &str,
288 premises: &[ProofNodeId],
289 _conclusion: &str,
290 color: Color,
291 ) -> Result<String, String> {
292 match color {
293 Color::A => {
294 Ok("true".to_string())
296 }
297 Color::B => {
298 Ok("(b-node)".to_string())
301 }
302 Color::AB => {
303 if rule == "resolution" && premises.len() == 2 {
305 let i1 = self
306 .partial_interpolants
307 .get(&premises[0])
308 .ok_or_else(|| format!("No interpolant for premise {}", premises[0]))?;
309 let i2 = self
310 .partial_interpolants
311 .get(&premises[1])
312 .ok_or_else(|| format!("No interpolant for premise {}", premises[1]))?;
313
314 let c1 = self.colors.get(&premises[0]).copied().unwrap_or(Color::A);
315 let c2 = self.colors.get(&premises[1]).copied().unwrap_or(Color::A);
316
317 match (c1, c2) {
319 (Color::A, Color::B) | (Color::B, Color::A) => {
320 Ok(format!("(or {} {})", i1, i2))
322 }
323 (Color::A, Color::AB) | (Color::AB, Color::A) => {
324 if c1 == Color::AB {
326 Ok(i1.clone())
327 } else {
328 Ok(i2.clone())
329 }
330 }
331 (Color::B, Color::AB) | (Color::AB, Color::B) => {
332 if c1 == Color::AB {
334 Ok(i1.clone())
335 } else {
336 Ok(i2.clone())
337 }
338 }
339 (Color::AB, Color::AB) => {
340 Ok(format!("(or {} {})", i1, i2))
342 }
343 _ => {
344 Ok("(error)".to_string())
346 }
347 }
348 } else {
349 Ok("(combined)".to_string())
351 }
352 }
353 }
354 }
355}
356
/// Collects the user symbols mentioned in `formula`.
///
/// A token is a maximal run of alphanumeric characters and `_`; every other
/// character (parentheses, whitespace, operators such as `=`) separates
/// tokens. Tokens made only of numeric characters and the formula-syntax
/// keywords below are skipped.
fn extract_symbols(formula: &str) -> FxHashSet<String> {
    // Connectives, binders and literals of the formula syntax. They are not
    // user symbols and must not count against interpolant validity.
    const KEYWORDS: &[&str] = &[
        "and", "or", "not", "implies", "iff", "xor", "forall", "exists", "true", "false", "let",
        "ite", "distinct",
    ];

    formula
        .split(|c: char| !(c.is_alphanumeric() || c == '_'))
        .filter(|token| !token.is_empty())
        .filter(|token| !token.chars().all(char::is_numeric))
        .filter(|token| !KEYWORDS.contains(token))
        .map(str::to_owned)
        .collect()
}
397
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an owned symbol set from string literals.
    fn symbol_set(names: &[&str]) -> FxHashSet<String> {
        names.iter().map(|name| (*name).to_string()).collect()
    }

    /// Builds an extractor whose partition has premise 0 on the `A` side and
    /// premise 1 on the `B` side.
    fn one_each_extractor() -> InterpolantExtractor {
        InterpolantExtractor::new(
            Partition::new(vec![PremiseId(0)], vec![PremiseId(1)]),
            PremiseTracker::new(),
        )
    }

    #[test]
    fn test_partition_creation() {
        let partition = Partition::new(vec![PremiseId(0), PremiseId(1)], vec![PremiseId(2)]);

        assert!(partition.is_a_premise(PremiseId(0)));
        assert!(partition.is_a_premise(PremiseId(1)));
        assert!(partition.is_b_premise(PremiseId(2)));
        assert!(!partition.is_b_premise(PremiseId(0)));
    }

    #[test]
    fn test_color_display() {
        let cases = [(Color::A, "A"), (Color::B, "B"), (Color::AB, "AB")];
        for (color, expected) in cases {
            assert_eq!(format!("{}", color), expected);
        }
    }

    #[test]
    fn test_interpolant_creation() {
        let interp = Interpolant::new("(and p q)");

        assert_eq!(interp.formula, "(and p q)");
        // Keywords are syntax, not symbols.
        assert!(!interp.symbols.contains("and"));
        assert!(interp.symbols.contains("p"));
        assert!(interp.symbols.contains("q"));
    }

    #[test]
    fn test_interpolant_validity() {
        let interp = Interpolant::new("(and x y)");
        let a_symbols = symbol_set(&["x", "y", "z"]);
        let b_symbols = symbol_set(&["x", "y", "w"]);

        // x and y are shared by both sides.
        assert!(interp.is_valid(&a_symbols, &b_symbols));
    }

    #[test]
    fn test_interpolant_invalid() {
        let interp = Interpolant::new("(and x z)");
        let a_symbols = symbol_set(&["x", "z"]);
        let b_symbols = symbol_set(&["x", "w"]);

        // z occurs only on the A side.
        assert!(!interp.is_valid(&a_symbols, &b_symbols));
    }

    #[test]
    fn test_extract_symbols() {
        let symbols = extract_symbols("(and x (or y z))");

        for keyword in ["and", "or"] {
            assert!(!symbols.contains(keyword));
        }
        for name in ["x", "y", "z"] {
            assert!(symbols.contains(name));
        }
    }

    #[test]
    fn test_extract_symbols_numbers() {
        let symbols = extract_symbols("(= x 42)");

        assert!(symbols.contains("x"));
        // Numeric literals are not symbols.
        assert!(!symbols.contains("42"));
    }

    #[test]
    fn test_extractor_creation() {
        let extractor = one_each_extractor();

        assert!(extractor.colors.is_empty());
        assert!(extractor.partial_interpolants.is_empty());
    }

    #[test]
    fn test_simple_interpolant_extraction() {
        let mut extractor = one_each_extractor();
        let mut proof = Proof::new();
        proof.add_axiom("p");

        // Extraction must fail: either the proof has no root, or the lone
        // axiom is colored A rather than AB.
        assert!(extractor.extract(&proof).is_err());
    }
}