Skip to main content

kermit_algos/
const_rewrite.rs

//! Const-view rewrite implementing Veldhuizen 2014 §3.4 point 4.
//!
//! Transforms constant atoms in body predicates (e.g. the `c42` in
//! `p(X, c42)`) into fresh variables filtered by synthetic unary
//! `Const_c42` predicates, so the existing LFTJ engine can handle them
//! without modification. Intended to run immediately before
//! [`crate::JoinAlgo::join_iter`].
7
8use {
9    kermit_parser::{JoinQuery, Predicate, Term},
10    std::fmt,
11};
12
/// Failure modes of [`rewrite_atoms`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum RewriteError {
    /// The contained atom text is not `c` followed by one or more ASCII
    /// digits (or those digits do not fit in a `usize`). kermit currently
    /// only accepts constants spelled as dictionary IDs in this
    /// `c<digits>` form.
    BadAtom(String),
}

impl fmt::Display for RewriteError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // The enum has a single variant, so this destructuring is
        // irrefutable; adding a variant makes it a compile error, exactly
        // like an exhaustive `match` would.
        let RewriteError::BadAtom(s) = self;
        write!(
            f,
            "atom {s:?} does not match the expected c<digits> shape — kermit currently only \
             supports constants encoded as dictionary IDs",
        )
    }
}

impl std::error::Error for RewriteError {}
36
/// A `(predicate_name, dictionary_id)` pair describing one synthetic
/// constant filter, e.g. `("Const_c42".to_string(), 42)`.
///
/// [`rewrite_atoms`] emits one entry per rewritten atom occurrence, so the
/// same pair can appear more than once when a constant is repeated.
pub type ConstSpec = (String, usize);
40
41/// Rewrites `query.body`: each `Term::Atom("c<id>")` becomes a fresh
42/// variable `K<i>`, with a new unary predicate `Const_c<id>(K<i>)`
43/// appended to the body.
44///
45/// Each atom occurrence gets its own fresh variable, even if the same
46/// dictionary ID appears multiple times. This avoids forcing equality
47/// between unrelated body positions.
48///
49/// # Head asymmetry
50///
51/// **Only body atoms are rewritten.** Head atoms (e.g.
52/// `Q(c5) :- p(X).`) are left unchanged. The head list describes the
53/// output shape and does not flow through the LFTJ engine the way
54/// body predicates do, so filtering there is the parser / caller's
55/// responsibility. The preprocessor emits queries of the form
56/// `Head(V0, …, Vn) :- body.` where every head term is a variable, so
57/// in practice head atoms never reach this function from the WatDiv
58/// pipeline. Keep this asymmetry in mind if authoring queries by
59/// hand: a `Term::Atom` in the head position will not be filtered.
60///
61/// # Errors
62///
63/// Returns [`RewriteError::BadAtom`] if any atom doesn't match `c\d+`.
64pub fn rewrite_atoms(mut query: JoinQuery) -> Result<(JoinQuery, Vec<ConstSpec>), RewriteError> {
65    let mut next_k = highest_k_index(&query).map_or(0, |n| n + 1);
66    let mut specs: Vec<ConstSpec> = Vec::new();
67    let mut new_preds: Vec<Predicate> = Vec::new();
68
69    // Body only — head atoms are intentionally not rewritten; see the
70    // "Head asymmetry" section of this function's doc-comment.
71    for pred in &mut query.body {
72        for term in &mut pred.terms {
73            let atom = match term {
74                | Term::Atom(s) => s.clone(),
75                | _ => continue,
76            };
77            let id = parse_const_atom(&atom)?;
78            let fresh = format!("K{next_k}");
79            next_k += 1;
80            *term = Term::Var(fresh.clone());
81            let const_name = format!("Const_{atom}");
82            new_preds.push(Predicate {
83                name: const_name.clone(),
84                terms: vec![Term::Var(fresh)],
85            });
86            specs.push((const_name, id));
87        }
88    }
89    query.body.extend(new_preds);
90    Ok((query, specs))
91}
92
93fn parse_const_atom(s: &str) -> Result<usize, RewriteError> {
94    let rest = s
95        .strip_prefix('c')
96        .ok_or_else(|| RewriteError::BadAtom(s.to_string()))?;
97    if rest.is_empty() || !rest.chars().all(|c| c.is_ascii_digit()) {
98        return Err(RewriteError::BadAtom(s.to_string()));
99    }
100    rest.parse::<usize>()
101        .map_err(|_| RewriteError::BadAtom(s.to_string()))
102}
103
104fn highest_k_index(query: &JoinQuery) -> Option<usize> {
105    let scan = |p: &Predicate| -> Option<usize> {
106        p.terms
107            .iter()
108            .filter_map(|t| match t {
109                | Term::Var(name) => name.strip_prefix('K').and_then(|r| r.parse::<usize>().ok()),
110                | _ => None,
111            })
112            .max()
113    };
114    query
115        .body
116        .iter()
117        .chain(std::iter::once(&query.head))
118        .filter_map(scan)
119        .max()
120}
121
#[cfg(test)]
mod tests {
    use super::*;

    fn parse(q: &str) -> JoinQuery { q.parse().unwrap() }

    /// Returns the variable name at `out.body[pred].terms[pos]`, panicking
    /// if that term is not a variable.
    fn var_at(out: &JoinQuery, pred: usize, pos: usize) -> String {
        match &out.body[pred].terms[pos] {
            | Term::Var(n) => n.clone(),
            | _ => panic!("expected a variable at body[{pred}].terms[{pos}]"),
        }
    }

    #[test]
    fn zero_atoms_is_identity() {
        let q = parse("Q(X) :- p(X), r(X, Y).");
        let (out, specs) = rewrite_atoms(q.clone()).unwrap();
        assert_eq!(out, q);
        assert!(specs.is_empty());
    }

    #[test]
    fn single_atom_produces_one_fresh_var_and_one_const_pred() {
        let q = parse("Q(X) :- p(X, c42).");
        let (out, specs) = rewrite_atoms(q).unwrap();
        assert_eq!(out.body.len(), 2);
        assert_eq!(out.body[0].name, "p");
        assert!(matches!(out.body[0].terms[1], Term::Var(ref n) if n == "K0"));
        assert_eq!(out.body[1].name, "Const_c42");
        assert!(matches!(out.body[1].terms[0], Term::Var(ref n) if n == "K0"));
        assert_eq!(specs, vec![("Const_c42".into(), 42)]);
    }

    #[test]
    fn zero_id_is_accepted() {
        let q = parse("Q(X) :- p(X, c0).");
        let (_, specs) = rewrite_atoms(q).unwrap();
        assert_eq!(specs, vec![("Const_c0".into(), 0)]);
    }

    #[test]
    fn multiple_atoms_get_distinct_fresh_vars() {
        let q = parse("Q(X) :- p(X, c42), r(Y, c99).");
        let (out, specs) = rewrite_atoms(q).unwrap();
        assert_eq!(out.body.len(), 4);
        assert_eq!(specs, vec![
            ("Const_c42".into(), 42),
            ("Const_c99".into(), 99),
        ]);

        let k0 = var_at(&out, 0, 1);
        let k1 = var_at(&out, 1, 1);
        assert_ne!(k0, k1);

        // Each synthetic predicate must filter the fresh variable that
        // replaced its own atom.
        assert_eq!(out.body[2].name, "Const_c42");
        assert_eq!(var_at(&out, 2, 0), k0);
        assert_eq!(out.body[3].name, "Const_c99");
        assert_eq!(var_at(&out, 3, 0), k1);
    }

    #[test]
    fn repeated_atom_value_gets_distinct_vars_but_same_const_pred() {
        let q = parse("Q(X) :- p(X, c5), r(Y, c5).");
        let (out, specs) = rewrite_atoms(q).unwrap();
        assert_eq!(out.body.len(), 4);
        assert_eq!(specs, vec![("Const_c5".into(), 5), ("Const_c5".into(), 5)]);

        let k0 = var_at(&out, 0, 1);
        let k1 = var_at(&out, 1, 1);
        assert_ne!(k0, k1);

        // Same synthetic predicate name, but each occurrence is wired to
        // its own fresh variable.
        assert_eq!(out.body[2].name, "Const_c5");
        assert_eq!(out.body[3].name, "Const_c5");
        assert_eq!(var_at(&out, 2, 0), k0);
        assert_eq!(var_at(&out, 3, 0), k1);
    }

    #[test]
    fn fresh_var_allocation_avoids_existing_k_names() {
        let q = parse("Q(K5) :- p(K5, c7).");
        let (out, _) = rewrite_atoms(q).unwrap();
        let fresh = var_at(&out, 0, 1);
        let n: usize = fresh.strip_prefix('K').unwrap().parse().unwrap();
        assert!(n > 5, "got {fresh}, expected > K5");
    }

    #[test]
    fn malformed_atom_errors() {
        // "c+5": usize::from_str accepts a leading '+', so this pins the
        // explicit digit check. The long one overflows usize.
        for bad in [
            "foo",
            "c",
            "c1x",
            "cc5",
            "x42",
            "",
            "c+5",
            "c-1",
            "c 5",
            "c99999999999999999999999",
        ] {
            let q = JoinQuery {
                head: Predicate {
                    name: "Q".into(),
                    terms: vec![Term::Var("X".into())],
                },
                body: vec![Predicate {
                    name: "p".into(),
                    terms: vec![Term::Var("X".into()), Term::Atom(bad.into())],
                }],
            };
            assert!(
                matches!(rewrite_atoms(q), Err(RewriteError::BadAtom(ref s)) if s == bad),
                "expected BadAtom({bad:?})"
            );
        }
    }

    #[test]
    fn placeholders_left_alone() {
        let q = parse("Q(X) :- p(X, _), r(_, c7).");
        let (out, specs) = rewrite_atoms(q).unwrap();
        assert_eq!(out.body.len(), 3);
        assert!(matches!(out.body[0].terms[1], Term::Placeholder));
        assert_eq!(specs, vec![("Const_c7".into(), 7)]);
    }

    #[test]
    fn head_atoms_are_not_rewritten() {
        // Head atoms are outside `rewrite_atoms`'s contract; see the
        // "Head asymmetry" section of its doc-comment. This test pins the
        // asymmetry so a future refactor can't accidentally start
        // rewriting head terms.
        let q = JoinQuery {
            head: Predicate {
                name: "Q".into(),
                terms: vec![Term::Atom("c5".into()), Term::Var("X".into())],
            },
            body: vec![Predicate {
                name: "p".into(),
                terms: vec![Term::Var("X".into()), Term::Atom("c7".into())],
            }],
        };
        let (out, specs) = rewrite_atoms(q).unwrap();
        assert!(matches!(out.head.terms[0], Term::Atom(ref s) if s == "c5"));
        assert!(matches!(out.head.terms[1], Term::Var(ref n) if n == "X"));
        assert_eq!(specs, vec![("Const_c7".into(), 7)]);
    }
}