1use {
9 kermit_parser::{JoinQuery, Predicate, Term},
10 std::fmt,
11};
12
/// Failure raised while rewriting the constant atoms of a query body.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum RewriteError {
    /// The atom is not `c` followed by one or more ASCII digits that fit in a `usize`.
    ///
    /// Carries the offending atom text verbatim.
    BadAtom(String),
}
22
23impl fmt::Display for RewriteError {
24 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
25 match self {
26 | RewriteError::BadAtom(s) => write!(
27 f,
28 "atom {s:?} does not match the expected c<digits> shape — kermit currently only \
29 supports constants encoded as dictionary IDs",
30 ),
31 }
32 }
33}
34
35impl std::error::Error for RewriteError {}
36
/// Description of one rewritten constant: the generated predicate name (`Const_<atom>`) paired
/// with the numeric ID parsed from the atom's digits.
pub type ConstSpec = (String, usize);
40
41pub fn rewrite_atoms(mut query: JoinQuery) -> Result<(JoinQuery, Vec<ConstSpec>), RewriteError> {
65 let mut next_k = highest_k_index(&query).map_or(0, |n| n + 1);
66 let mut specs: Vec<ConstSpec> = Vec::new();
67 let mut new_preds: Vec<Predicate> = Vec::new();
68
69 for pred in &mut query.body {
72 for term in &mut pred.terms {
73 let atom = match term {
74 | Term::Atom(s) => s.clone(),
75 | _ => continue,
76 };
77 let id = parse_const_atom(&atom)?;
78 let fresh = format!("K{next_k}");
79 next_k += 1;
80 *term = Term::Var(fresh.clone());
81 let const_name = format!("Const_{atom}");
82 new_preds.push(Predicate {
83 name: const_name.clone(),
84 terms: vec![Term::Var(fresh)],
85 });
86 specs.push((const_name, id));
87 }
88 }
89 query.body.extend(new_preds);
90 Ok((query, specs))
91}
92
93fn parse_const_atom(s: &str) -> Result<usize, RewriteError> {
94 let rest = s
95 .strip_prefix('c')
96 .ok_or_else(|| RewriteError::BadAtom(s.to_string()))?;
97 if rest.is_empty() || !rest.chars().all(|c| c.is_ascii_digit()) {
98 return Err(RewriteError::BadAtom(s.to_string()));
99 }
100 rest.parse::<usize>()
101 .map_err(|_| RewriteError::BadAtom(s.to_string()))
102}
103
104fn highest_k_index(query: &JoinQuery) -> Option<usize> {
105 let scan = |p: &Predicate| -> Option<usize> {
106 p.terms
107 .iter()
108 .filter_map(|t| match t {
109 | Term::Var(name) => name.strip_prefix('K').and_then(|r| r.parse::<usize>().ok()),
110 | _ => None,
111 })
112 .max()
113 };
114 query
115 .body
116 .iter()
117 .chain(std::iter::once(&query.head))
118 .filter_map(scan)
119 .max()
120}
121
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses `q` into a [`JoinQuery`], panicking if the test input is malformed.
    fn parse(q: &str) -> JoinQuery {
        q.parse().unwrap()
    }

    /// Returns the name held by `term`, panicking if it is not a `Term::Var`.
    fn var_name(term: &Term) -> &str {
        match term {
            | Term::Var(n) => n.as_str(),
            | _ => panic!(),
        }
    }

    #[test]
    fn zero_atoms_is_identity() {
        let original = parse("Q(X) :- p(X), r(X, Y).");
        let (rewritten, specs) = rewrite_atoms(original.clone()).unwrap();
        assert_eq!(rewritten, original);
        assert!(specs.is_empty());
    }

    #[test]
    fn single_atom_produces_one_fresh_var_and_one_const_pred() {
        let (out, specs) = rewrite_atoms(parse("Q(X) :- p(X, c42).")).unwrap();
        assert_eq!(out.body.len(), 2);

        let (rewritten, constant) = (&out.body[0], &out.body[1]);
        assert_eq!(rewritten.name, "p");
        assert_eq!(var_name(&rewritten.terms[1]), "K0");
        assert_eq!(constant.name, "Const_c42");
        assert_eq!(var_name(&constant.terms[0]), "K0");
        assert_eq!(specs, vec![("Const_c42".to_string(), 42)]);
    }

    #[test]
    fn multiple_atoms_get_distinct_fresh_vars() {
        let (out, specs) = rewrite_atoms(parse("Q(X) :- p(X, c42), r(Y, c99).")).unwrap();
        assert_eq!(out.body.len(), 4);

        let expected: Vec<ConstSpec> = vec![
            ("Const_c42".to_string(), 42),
            ("Const_c99".to_string(), 99),
        ];
        assert_eq!(specs, expected);
    }

    #[test]
    fn repeated_atom_value_gets_distinct_vars_but_same_const_pred() {
        let (out, specs) = rewrite_atoms(parse("Q(X) :- p(X, c5), r(Y, c5).")).unwrap();
        assert_eq!(out.body.len(), 4);
        assert_eq!(specs.len(), 2);
        assert!(specs.iter().all(|(name, _)| name == "Const_c5"));

        // Each occurrence of the same atom still receives its own variable.
        let first = var_name(&out.body[0].terms[1]);
        let second = var_name(&out.body[1].terms[1]);
        assert_ne!(first, second);
    }

    #[test]
    fn fresh_var_allocation_avoids_existing_k_names() {
        let (out, _) = rewrite_atoms(parse("Q(K5) :- p(K5, c7).")).unwrap();
        let fresh = var_name(&out.body[0].terms[1]);
        let n: usize = fresh.strip_prefix('K').unwrap().parse().unwrap();
        assert!(n > 5, "got {fresh}, expected > K5");
    }

    #[test]
    fn malformed_atom_errors() {
        for bad in ["foo", "c", "c1x", "cc5", "x42"] {
            // Built by hand rather than parsed, so the atom text reaches `rewrite_atoms` exactly
            // as written here.
            let head = Predicate {
                name: "Q".into(),
                terms: vec![Term::Var("X".into())],
            };
            let body = vec![Predicate {
                name: "p".into(),
                terms: vec![Term::Var("X".into()), Term::Atom(bad.into())],
            }];
            let outcome = rewrite_atoms(JoinQuery { head, body });
            assert!(
                matches!(outcome, Err(RewriteError::BadAtom(_))),
                "expected error for {bad}"
            );
        }
    }

    #[test]
    fn placeholders_left_alone() {
        let (out, specs) = rewrite_atoms(parse("Q(X) :- p(X, _), r(_, c7).")).unwrap();
        assert_eq!(out.body.len(), 3);
        assert!(matches!(out.body[0].terms[1], Term::Placeholder));
        assert_eq!(specs, vec![("Const_c7".to_string(), 7)]);
    }

    #[test]
    fn head_atoms_are_not_rewritten() {
        let head = Predicate {
            name: "Q".into(),
            terms: vec![Term::Atom("c5".into()), Term::Var("X".into())],
        };
        let body = vec![Predicate {
            name: "p".into(),
            terms: vec![Term::Var("X".into()), Term::Atom("c7".into())],
        }];

        let (out, specs) = rewrite_atoms(JoinQuery { head, body }).unwrap();

        // Only the body atom is turned into a spec; the head keeps its atom untouched.
        assert!(matches!(out.head.terms[0], Term::Atom(ref s) if s == "c5"));
        assert!(matches!(out.head.terms[1], Term::Var(ref n) if n == "X"));
        assert_eq!(specs, vec![("Const_c7".to_string(), 7)]);
    }
}