use nom::branch::alt;
use nom::bytes::complete::{tag, take_until, take_while, take_while1, take_while_m_n};
use nom::character::complete::{i64, multispace1, satisfy};
use nom::character::is_hex_digit;
use nom::combinator::{cut, map, map_res, not, opt, value};
use nom::error::{context, ContextError, FromExternalError, ParseError};
use nom::multi::{many0, separated_list0, separated_list1};
use nom::sequence::{delimited, preceded, separated_pair, terminated, tuple};
use nom::IResult;
10
11#[cfg_attr(
12 feature = "serde_internal",
13 derive(serde::Serialize, serde::Deserialize)
14)]
15#[derive(Debug, PartialEq, Eq, Clone)]
16pub enum ParseTerm {
17 Variable(String),
18 Bool(bool),
19 Integer(i64),
20 String(String),
21 Uuid(uuid::Uuid),
22}
23
24impl std::fmt::Display for ParseTerm {
25 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
26 match self {
27 ParseTerm::Variable(v) => write!(f, "{v}"),
28 ParseTerm::Bool(v) => write!(f, "{v}"),
29 ParseTerm::Integer(v) => write!(f, "{v}"),
30 ParseTerm::String(s) => {
31 if s.chars().any(|c| c.is_whitespace()) {
32 write!(f, "\"{s}\"")
33 } else {
34 write!(f, "{s}")
35 }
36 }
37 ParseTerm::Uuid(id) => write!(f, "#{id}"),
38 }
39 }
40}
41
/// The name of a predicate, optionally marked as intrinsic with a leading `@`.
#[cfg_attr(
    feature = "serde_internal",
    derive(serde::Serialize, serde::Deserialize)
)]
#[derive(Debug, PartialEq, Eq, Clone, Hash)]
pub struct Predicate {
    /// `true` when the predicate was written as `@name`.
    pub is_intrinsic: bool,
    /// The predicate identifier, without the `@` prefix.
    pub name: String,
}

impl std::fmt::Display for Predicate {
    /// Writes `@name` for intrinsic predicates and plain `name` otherwise.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let sigil = if self.is_intrinsic { "@" } else { "" };
        write!(f, "{sigil}{}", self.name)
    }
}
61
62#[cfg_attr(
63 feature = "serde_internal",
64 derive(serde::Serialize, serde::Deserialize)
65)]
66#[derive(Debug, PartialEq, Eq, Clone)]
67pub struct Atom {
68 pub predicate: Predicate,
69 pub terms: Vec<ParseTerm>,
70}
71
72impl std::fmt::Display for Atom {
73 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
74 write!(f, "{}", self.predicate)?;
75 write!(f, "(")?;
76 if let Some(fr) = self.terms.first() {
77 write!(f, "{}", fr)?;
78 }
79 for term in self.terms.iter().skip(1) {
80 write!(f, ",")?;
81 write!(f, "{}", term)?;
82 }
83 write!(f, ")")
84 }
85}
86
87#[cfg_attr(
88 feature = "serde_internal",
89 derive(serde::Serialize, serde::Deserialize)
90)]
91#[derive(Debug, PartialEq, Eq, Clone)]
92pub enum BodyAtom {
93 Positive(Atom),
94 Negative(Atom),
95}
96
97impl std::fmt::Display for BodyAtom {
98 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
99 match self {
100 Self::Negative(atom) => write!(f, "not {atom}"),
101 Self::Positive(atom) => write!(f, "{atom}"),
102 }
103 }
104}
105
106impl BodyAtom {
107 pub fn atom(&self) -> &Atom {
108 match self {
109 Self::Positive(a) => a,
110 Self::Negative(n) => n,
111 }
112 }
113}
114
115#[cfg_attr(
116 feature = "serde_internal",
117 derive(serde::Serialize, serde::Deserialize)
118)]
119#[derive(Debug, PartialEq, Eq)]
120pub enum Constraint {
121 Fact(Atom),
122 Goal(BodyAtom),
123 Rule { head: Atom, body: Vec<BodyAtom> },
124}
125
126#[derive(Debug, PartialEq, Eq)]
127pub enum Ast<'a> {
128 ParseTerm(&'a ParseTerm),
129 Atom(&'a Atom),
130 BodyAtom(&'a BodyAtom),
131 Constraint(&'a Constraint),
132}
133
134impl<'a> From<&'a Constraint> for Ast<'a> {
135 fn from(value: &'a Constraint) -> Self {
136 Ast::Constraint(value)
137 }
138}
139
140impl<'a> From<&'a BodyAtom> for Ast<'a> {
141 fn from(value: &'a BodyAtom) -> Self {
142 Ast::BodyAtom(value)
143 }
144}
145
146impl<'a> From<&'a Atom> for Ast<'a> {
147 fn from(value: &'a Atom) -> Self {
148 Ast::Atom(value)
149 }
150}
151
152impl<'a> From<&'a ParseTerm> for Ast<'a> {
153 fn from(value: &'a ParseTerm) -> Self {
154 Ast::ParseTerm(value)
155 }
156}
157
158pub trait Visitor {
159 type Output;
160
161 fn visit_parse_term(&self, term: &ParseTerm) -> Self::Output;
162 fn visit_atom(&self, atom: &Atom) -> Self::Output;
163 fn visit_body_atom(&self, body_atom: &BodyAtom) -> Self::Output;
164 fn visit_constraint(&self, constraint: &Constraint) -> Self::Output;
165
166 fn visit(&self, ast: Ast) -> Self::Output {
167 match ast {
168 Ast::ParseTerm(r) => self.visit_parse_term(r),
169 Ast::Atom(r) => self.visit_atom(r),
170 Ast::BodyAtom(r) => self.visit_body_atom(r),
171 Ast::Constraint(r) => self.visit_constraint(r),
172 }
173 }
174}
175
176pub trait VisitorMut {
177 type Output;
178
179 fn visit_parse_term(&mut self, term: &ParseTerm) -> Self::Output;
180 fn visit_atom(&mut self, atom: &Atom) -> Self::Output;
181 fn visit_body_atom(&mut self, body_atom: &BodyAtom) -> Self::Output;
182 fn visit_constraint(&mut self, constraint: &Constraint) -> Self::Output;
183
184 fn visit(&mut self, ast: Ast) -> Self::Output {
185 match ast {
186 Ast::ParseTerm(r) => self.visit_parse_term(r),
187 Ast::Atom(r) => self.visit_atom(r),
188 Ast::BodyAtom(r) => self.visit_body_atom(r),
189 Ast::Constraint(r) => self.visit_constraint(r),
190 }
191 }
192}
193
194impl<T: Visitor> VisitorMut for T {
195 type Output = <T as Visitor>::Output;
196
197 fn visit_parse_term(&mut self, term: &ParseTerm) -> Self::Output {
198 <T as Visitor>::visit_parse_term(self, term)
199 }
200
201 fn visit_atom(&mut self, atom: &Atom) -> Self::Output {
202 <T as Visitor>::visit_atom(self, atom)
203 }
204
205 fn visit_body_atom(&mut self, body_atom: &BodyAtom) -> Self::Output {
206 <T as Visitor>::visit_body_atom(self, body_atom)
207 }
208
209 fn visit_constraint(&mut self, constraint: &Constraint) -> Self::Output {
210 <T as Visitor>::visit_constraint(self, constraint)
211 }
212}
213
214fn parse_bool<'input, E: ParseError<&'input str> + ContextError<&'input str>>(
215 input: &'input str,
216) -> IResult<&str, bool, E> {
217 context(
218 "bool",
219 alt((value(true, tag("true")), value(false, tag("false")))),
220 )(input)
221}
222
/// Returns `true` for the non-alphanumeric symbols allowed inside identifiers:
/// `_ ! ~ + - * / & |`.
fn is_ident_char(c: char) -> bool {
    matches!(c, '_' | '!' | '~' | '+' | '-' | '*' | '/' | '&' | '|')
}
226
227fn identifier<'input, E: ParseError<&'input str>>(input: &'input str) -> IResult<&str, String, E> {
228 map(
229 tuple((
230 take_while1(|s: char| s.is_alphabetic() || is_ident_char(s)),
231 take_while(|s: char| s.is_alphanumeric() || is_ident_char(s)),
232 )),
233 |(s, t): (&str, &str)| format!("{s}{t}"),
234 )(input)
235}
236
237fn cap_identifier<'input, E: ParseError<&'input str>>(
238 input: &'input str,
239) -> IResult<&str, String, E> {
240 map(
241 tuple((
242 take_while1(|s: char| s.is_uppercase() && s.is_alphabetic()),
243 take_while(|s: char| s.is_alphanumeric() || is_ident_char(s)),
244 )),
245 |(s, t): (&str, &str)| format!("{s}{t}"),
246 )(input)
247}
248
249#[derive(thiserror::Error, Debug)]
250pub enum TermError {
251 #[error("uuid error: {0}")]
252 Uuid(#[source] uuid::Error),
253 #[error("predicate cannot be used")]
254 PredicateNotError,
255}
256
257impl From<uuid::Error> for TermError {
258 fn from(value: uuid::Error) -> Self {
259 Self::Uuid(value)
260 }
261}
262
263fn parse_term<
264 'input,
265 E: ParseError<&'input str> + ContextError<&'input str> + FromExternalError<&'input str, TermError>,
266>(
267 input: &'input str,
268) -> IResult<&str, ParseTerm, E> {
269 alt((
270 map(parse_bool, ParseTerm::Bool),
271 map(context("integer", i64), ParseTerm::Integer),
272 map(context("variable", cap_identifier), ParseTerm::Variable),
273 map_res(context("atomic-string", identifier), |s: String| {
274 forbidden_predicates(&s).map(|_| ParseTerm::String(s))
275 }),
276 map(
277 context(
278 "string",
279 preceded(tag("\""), cut(terminated(take_until("\""), tag("\"")))),
280 ),
281 |s: &str| ParseTerm::String(s.to_string()),
282 ),
283 context(
284 "uuid",
285 preceded(
286 tag("#"),
287 cut(alt((
288 map_res(
289 take_while_m_n(32, 32, |c| is_hex_digit(c as u8)),
290 |s: &str| Ok(ParseTerm::Uuid(uuid::Uuid::parse_str(s)?)),
291 ),
292 map_res(
293 tuple((
294 take_while_m_n(8, 8, |c| is_hex_digit(c as u8)),
295 tag("-"),
296 take_while_m_n(4, 4, |c| is_hex_digit(c as u8)),
297 tag("-"),
298 take_while_m_n(4, 4, |c| is_hex_digit(c as u8)),
299 tag("-"),
300 take_while_m_n(4, 4, |c| is_hex_digit(c as u8)),
301 tag("-"),
302 take_while_m_n(12, 12, |c| is_hex_digit(c as u8)),
303 )),
304 |(a, _, b, _, c, _, d, _, e)| {
305 Ok(ParseTerm::Uuid(uuid::Uuid::parse_str(&format!(
306 "{a}-{b}-{c}-{d}-{e}"
307 ))?))
308 },
309 ),
310 ))),
311 ),
312 ),
313 ))(input)
314}
315
316fn forbidden_predicates(s: &str) -> Result<(), TermError> {
317 if ["not"].contains(&s) {
318 Err(TermError::PredicateNotError)
319 } else {
320 Ok(())
321 }
322}
323
324fn parse_predicate<
325 'input,
326 E: ParseError<&'input str> + ContextError<&'input str> + FromExternalError<&'input str, TermError>,
327>(
328 input: &'input str,
329) -> IResult<&str, Predicate, E> {
330 context(
331 "predicate",
332 alt((
333 map_res(preceded(tag("@"), identifier), |name: String| {
334 forbidden_predicates(&name).map(|_| Predicate {
335 is_intrinsic: true,
336 name,
337 })
338 }),
339 map_res(identifier, |name: String| {
340 forbidden_predicates(&name).map(|_| Predicate {
341 is_intrinsic: false,
342 name,
343 })
344 }),
345 )),
346 )(input)
347}
348
349fn parse_comment<'input, E: ParseError<&'input str>>(input: &'input str) -> IResult<&str, &str, E> {
350 delimited(tag("%"), take_until("\n"), tag("\n"))(input)
351}
352
353fn parse_trivia<'input, E: ParseError<&'input str>>(
354 input: &'input str,
355) -> IResult<&str, Vec<&str>, E> {
356 many0(alt((multispace1, parse_comment)))(input)
357}
358
359fn parse_atom<
360 'input,
361 E: ParseError<&'input str>
362 + ContextError<&'input str>
363 + FromExternalError<&'input str, TermError>
364 + FromExternalError<&'input str, uuid::Error>,
365>(
366 input: &'input str,
367) -> IResult<&str, Atom, E> {
368 let (input, predicate) = context("atom_predicate", parse_predicate)(input)?;
369 let (input, terms) = context(
370 "atom_terms",
371 opt(delimited(
372 tuple((parse_trivia, tag("("), parse_trivia)),
373 separated_list0(tuple((parse_trivia, tag(","), parse_trivia)), parse_term),
374 tuple((parse_trivia, tag(")"), parse_trivia)),
375 )),
376 )(input)?;
377 Ok((
378 input,
379 Atom {
380 predicate,
381 terms: terms.unwrap_or(vec![]),
382 },
383 ))
384}
385
386fn parse_body_atom<
387 'input,
388 E: ParseError<&'input str>
389 + ContextError<&'input str>
390 + FromExternalError<&'input str, TermError>
391 + FromExternalError<&'input str, uuid::Error>,
392>(
393 input: &'input str,
394) -> IResult<&str, BodyAtom, E> {
395 context(
396 "body_atom",
397 alt((
398 map(
399 preceded(
400 tuple((tag("not"), tuple((multispace1, parse_trivia)))),
401 parse_atom,
402 ),
403 BodyAtom::Negative,
404 ),
405 map(parse_atom, BodyAtom::Positive),
406 )),
407 )(input)
408}
409
410fn parse_constraint<
411 'input,
412 E: ParseError<&'input str>
413 + ContextError<&'input str>
414 + FromExternalError<&'input str, TermError>
415 + FromExternalError<&'input str, uuid::Error>,
416>(
417 input: &'input str,
418) -> IResult<&str, Constraint, E> {
419 context(
420 "constraint",
421 alt((
422 map(
423 terminated(
424 separated_pair(
425 parse_atom,
426 tuple((parse_trivia, tag(":-"), parse_trivia)),
427 separated_list1(
428 tuple((parse_trivia, tag(","), parse_trivia)),
429 parse_body_atom,
430 ),
431 ),
432 preceded(parse_trivia, tag(".")),
433 ),
434 |(atom, body_atoms): (Atom, Vec<BodyAtom>)| Constraint::Rule {
435 head: atom,
436 body: body_atoms,
437 },
438 ),
439 map(
440 terminated(parse_atom, preceded(parse_trivia, tag("."))),
441 Constraint::Fact,
442 ),
443 map(
444 terminated(parse_body_atom, preceded(parse_trivia, tag("?"))),
445 Constraint::Goal,
446 ),
447 )),
448 )(input)
449}
450
451pub fn parser<
452 'input,
453 E: ParseError<&'input str>
454 + ContextError<&'input str>
455 + FromExternalError<&'input str, TermError>
456 + FromExternalError<&'input str, uuid::Error>,
457>(
458 input: &'input str,
459) -> IResult<&str, Vec<Constraint>, E> {
460 context(
461 "program",
462 preceded(
463 parse_trivia,
464 many0(terminated(parse_constraint, parse_trivia)),
465 ),
466 )(input)
467}
468
#[cfg(test)]
mod tests {
    use datadriven::walk;
    use nom::{error::VerboseError, Finish};

    use super::parser;

    /// Runs every datadriven case under `tests/parser`.
    ///
    /// A `root` case must parse its whole input; its expected output is the
    /// pretty-printed JSON of the parsed constraints, or the parser error text.
    #[test]
    fn run() {
        walk("tests/parser", |file| {
            file.run(|case| -> String {
                if case.directive != "root" {
                    return "Invalid directive".to_string();
                }
                let parsed = parser::<VerboseError<&str>>(&case.input).finish();
                let (remaining, output) = match parsed {
                    Ok(data) => data,
                    Err(e) => return e.to_string(),
                };
                assert_eq!(remaining, "");
                serde_json::to_string_pretty(&output).unwrap()
            })
        });
    }
}