arde/
parser.rs

1use nom::branch::alt;
2use nom::bytes::complete::{tag, take_until, take_while, take_while1, take_while_m_n};
3use nom::character::complete::{i64, multispace1};
4use nom::character::is_hex_digit;
5use nom::combinator::{cut, map, map_res, opt, value};
6use nom::error::{context, ContextError, FromExternalError, ParseError};
7use nom::multi::{many0, separated_list0, separated_list1};
8use nom::sequence::{delimited, preceded, separated_pair, terminated, tuple};
9use nom::IResult;
10
11#[cfg_attr(
12    feature = "serde_internal",
13    derive(serde::Serialize, serde::Deserialize)
14)]
15#[derive(Debug, PartialEq, Eq, Clone)]
16pub enum ParseTerm {
17    Variable(String),
18    Bool(bool),
19    Integer(i64),
20    String(String),
21    Uuid(uuid::Uuid),
22}
23
24impl std::fmt::Display for ParseTerm {
25    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
26        match self {
27            ParseTerm::Variable(v) => write!(f, "{v}"),
28            ParseTerm::Bool(v) => write!(f, "{v}"),
29            ParseTerm::Integer(v) => write!(f, "{v}"),
30            ParseTerm::String(s) => {
31                if s.chars().any(|c| c.is_whitespace()) {
32                    write!(f, "\"{s}\"")
33                } else {
34                    write!(f, "{s}")
35                }
36            }
37            ParseTerm::Uuid(id) => write!(f, "#{id}"),
38        }
39    }
40}
41
/// The name part of an atom, together with a flag telling whether it was
/// written with the `@` prefix.
#[cfg_attr(
    feature = "serde_internal",
    derive(serde::Serialize, serde::Deserialize)
)]
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct Predicate {
    /// `true` when the predicate was written as `@name`.
    pub is_intrinsic: bool,
    /// The predicate name without any `@` prefix.
    pub name: String,
}

impl std::fmt::Display for Predicate {
    /// Writes the name, prefixed with `@` for intrinsic predicates.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let sigil = if self.is_intrinsic { "@" } else { "" };
        write!(f, "{sigil}{}", self.name)
    }
}
61
62#[cfg_attr(
63    feature = "serde_internal",
64    derive(serde::Serialize, serde::Deserialize)
65)]
66#[derive(Debug, PartialEq, Eq, Clone)]
67pub struct Atom {
68    pub predicate: Predicate,
69    pub terms: Vec<ParseTerm>,
70}
71
72impl std::fmt::Display for Atom {
73    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
74        write!(f, "{}", self.predicate)?;
75        write!(f, "(")?;
76        if let Some(fr) = self.terms.first() {
77            write!(f, "{}", fr)?;
78        }
79        for term in self.terms.iter().skip(1) {
80            write!(f, ",")?;
81            write!(f, "{}", term)?;
82        }
83        write!(f, ")")
84    }
85}
86
87#[cfg_attr(
88    feature = "serde_internal",
89    derive(serde::Serialize, serde::Deserialize)
90)]
91#[derive(Debug, PartialEq, Eq, Clone)]
92pub enum BodyAtom {
93    Positive(Atom),
94    Negative(Atom),
95}
96
97impl std::fmt::Display for BodyAtom {
98    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
99        match self {
100            Self::Negative(atom) => write!(f, "not {atom}"),
101            Self::Positive(atom) => write!(f, "{atom}"),
102        }
103    }
104}
105
106impl BodyAtom {
107    pub fn atom(&self) -> &Atom {
108        match self {
109            Self::Positive(a) => a,
110            Self::Negative(n) => n,
111        }
112    }
113}
114
115#[cfg_attr(
116    feature = "serde_internal",
117    derive(serde::Serialize, serde::Deserialize)
118)]
119#[derive(Debug, PartialEq, Eq)]
120pub enum Constraint {
121    Fact(Atom),
122    Goal(BodyAtom),
123    Rule { head: Atom, body: Vec<BodyAtom> },
124}
125
126#[derive(Debug, PartialEq, Eq)]
127pub enum Ast<'a> {
128    ParseTerm(&'a ParseTerm),
129    Atom(&'a Atom),
130    BodyAtom(&'a BodyAtom),
131    Constraint(&'a Constraint),
132}
133
134impl<'a> From<&'a Constraint> for Ast<'a> {
135    fn from(value: &'a Constraint) -> Self {
136        Ast::Constraint(value)
137    }
138}
139
140impl<'a> From<&'a BodyAtom> for Ast<'a> {
141    fn from(value: &'a BodyAtom) -> Self {
142        Ast::BodyAtom(value)
143    }
144}
145
146impl<'a> From<&'a Atom> for Ast<'a> {
147    fn from(value: &'a Atom) -> Self {
148        Ast::Atom(value)
149    }
150}
151
152impl<'a> From<&'a ParseTerm> for Ast<'a> {
153    fn from(value: &'a ParseTerm) -> Self {
154        Ast::ParseTerm(value)
155    }
156}
157
158pub trait Visitor {
159    type Output;
160
161    fn visit_parse_term(&self, term: &ParseTerm) -> Self::Output;
162    fn visit_atom(&self, atom: &Atom) -> Self::Output;
163    fn visit_body_atom(&self, body_atom: &BodyAtom) -> Self::Output;
164    fn visit_constraint(&self, constraint: &Constraint) -> Self::Output;
165
166    fn visit(&self, ast: Ast) -> Self::Output {
167        match ast {
168            Ast::ParseTerm(r) => self.visit_parse_term(r),
169            Ast::Atom(r) => self.visit_atom(r),
170            Ast::BodyAtom(r) => self.visit_body_atom(r),
171            Ast::Constraint(r) => self.visit_constraint(r),
172        }
173    }
174}
175
176pub trait VisitorMut {
177    type Output;
178
179    fn visit_parse_term(&mut self, term: &ParseTerm) -> Self::Output;
180    fn visit_atom(&mut self, atom: &Atom) -> Self::Output;
181    fn visit_body_atom(&mut self, body_atom: &BodyAtom) -> Self::Output;
182    fn visit_constraint(&mut self, constraint: &Constraint) -> Self::Output;
183
184    fn visit(&mut self, ast: Ast) -> Self::Output {
185        match ast {
186            Ast::ParseTerm(r) => self.visit_parse_term(r),
187            Ast::Atom(r) => self.visit_atom(r),
188            Ast::BodyAtom(r) => self.visit_body_atom(r),
189            Ast::Constraint(r) => self.visit_constraint(r),
190        }
191    }
192}
193
194impl<T: Visitor> VisitorMut for T {
195    type Output = <T as Visitor>::Output;
196
197    fn visit_parse_term(&mut self, term: &ParseTerm) -> Self::Output {
198        <T as Visitor>::visit_parse_term(self, term)
199    }
200
201    fn visit_atom(&mut self, atom: &Atom) -> Self::Output {
202        <T as Visitor>::visit_atom(self, atom)
203    }
204
205    fn visit_body_atom(&mut self, body_atom: &BodyAtom) -> Self::Output {
206        <T as Visitor>::visit_body_atom(self, body_atom)
207    }
208
209    fn visit_constraint(&mut self, constraint: &Constraint) -> Self::Output {
210        <T as Visitor>::visit_constraint(self, constraint)
211    }
212}
213
214fn parse_bool<'input, E: ParseError<&'input str> + ContextError<&'input str>>(
215    input: &'input str,
216) -> IResult<&str, bool, E> {
217    context(
218        "bool",
219        alt((value(true, tag("true")), value(false, tag("false")))),
220    )(input)
221}
222
/// Returns `true` for the punctuation characters that are allowed inside
/// identifiers: `_ ! ~ + - * / & |`.
fn is_ident_char(c: char) -> bool {
    matches!(c, '_' | '!' | '~' | '+' | '-' | '*' | '/' | '&' | '|')
}
226
227fn identifier<'input, E: ParseError<&'input str>>(input: &'input str) -> IResult<&str, String, E> {
228    map(
229        tuple((
230            take_while1(|s: char| s.is_alphabetic() || is_ident_char(s)),
231            take_while(|s: char| s.is_alphanumeric() || is_ident_char(s)),
232        )),
233        |(s, t): (&str, &str)| format!("{s}{t}"),
234    )(input)
235}
236
237fn cap_identifier<'input, E: ParseError<&'input str>>(
238    input: &'input str,
239) -> IResult<&str, String, E> {
240    map(
241        tuple((
242            take_while1(|s: char| s.is_uppercase() && s.is_alphabetic()),
243            take_while(|s: char| s.is_alphanumeric() || is_ident_char(s)),
244        )),
245        |(s, t): (&str, &str)| format!("{s}{t}"),
246    )(input)
247}
248
249#[derive(thiserror::Error, Debug)]
250pub enum TermError {
251    #[error("uuid error: {0}")]
252    Uuid(#[source] uuid::Error),
253    #[error("predicate cannot be used")]
254    PredicateNotError,
255}
256
257impl From<uuid::Error> for TermError {
258    fn from(value: uuid::Error) -> Self {
259        Self::Uuid(value)
260    }
261}
262
263fn parse_term<
264    'input,
265    E: ParseError<&'input str> + ContextError<&'input str> + FromExternalError<&'input str, TermError>,
266>(
267    input: &'input str,
268) -> IResult<&str, ParseTerm, E> {
269    alt((
270        map(parse_bool, ParseTerm::Bool),
271        map(context("integer", i64), ParseTerm::Integer),
272        map(context("variable", cap_identifier), ParseTerm::Variable),
273        map_res(context("atomic-string", identifier), |s: String| {
274            forbidden_predicates(&s).map(|_| ParseTerm::String(s))
275        }),
276        map(
277            context(
278                "string",
279                preceded(tag("\""), cut(terminated(take_until("\""), tag("\"")))),
280            ),
281            |s: &str| ParseTerm::String(s.to_string()),
282        ),
283        context(
284            "uuid",
285            preceded(
286                tag("#"),
287                cut(alt((
288                    map_res(
289                        take_while_m_n(32, 32, |c| is_hex_digit(c as u8)),
290                        |s: &str| Ok(ParseTerm::Uuid(uuid::Uuid::parse_str(s)?)),
291                    ),
292                    map_res(
293                        tuple((
294                            take_while_m_n(8, 8, |c| is_hex_digit(c as u8)),
295                            tag("-"),
296                            take_while_m_n(4, 4, |c| is_hex_digit(c as u8)),
297                            tag("-"),
298                            take_while_m_n(4, 4, |c| is_hex_digit(c as u8)),
299                            tag("-"),
300                            take_while_m_n(4, 4, |c| is_hex_digit(c as u8)),
301                            tag("-"),
302                            take_while_m_n(12, 12, |c| is_hex_digit(c as u8)),
303                        )),
304                        |(a, _, b, _, c, _, d, _, e)| {
305                            Ok(ParseTerm::Uuid(uuid::Uuid::parse_str(&format!(
306                                "{a}-{b}-{c}-{d}-{e}"
307                            ))?))
308                        },
309                    ),
310                ))),
311            ),
312        ),
313    ))(input)
314}
315
316fn forbidden_predicates(s: &str) -> Result<(), TermError> {
317    if ["not"].contains(&s) {
318        Err(TermError::PredicateNotError)
319    } else {
320        Ok(())
321    }
322}
323
324fn parse_predicate<
325    'input,
326    E: ParseError<&'input str> + ContextError<&'input str> + FromExternalError<&'input str, TermError>,
327>(
328    input: &'input str,
329) -> IResult<&str, Predicate, E> {
330    context(
331        "predicate",
332        alt((
333            map_res(preceded(tag("@"), identifier), |name: String| {
334                forbidden_predicates(&name).map(|_| Predicate {
335                    is_intrinsic: true,
336                    name,
337                })
338            }),
339            map_res(identifier, |name: String| {
340                forbidden_predicates(&name).map(|_| Predicate {
341                    is_intrinsic: false,
342                    name,
343                })
344            }),
345        )),
346    )(input)
347}
348
349fn parse_comment<'input, E: ParseError<&'input str>>(input: &'input str) -> IResult<&str, &str, E> {
350    delimited(tag("%"), take_until("\n"), tag("\n"))(input)
351}
352
353fn parse_trivia<'input, E: ParseError<&'input str>>(
354    input: &'input str,
355) -> IResult<&str, Vec<&str>, E> {
356    many0(alt((multispace1, parse_comment)))(input)
357}
358
359fn parse_atom<
360    'input,
361    E: ParseError<&'input str>
362        + ContextError<&'input str>
363        + FromExternalError<&'input str, TermError>
364        + FromExternalError<&'input str, uuid::Error>,
365>(
366    input: &'input str,
367) -> IResult<&str, Atom, E> {
368    let (input, predicate) = context("atom_predicate", parse_predicate)(input)?;
369    let (input, terms) = context(
370        "atom_terms",
371        opt(delimited(
372            tuple((parse_trivia, tag("("), parse_trivia)),
373            separated_list0(tuple((parse_trivia, tag(","), parse_trivia)), parse_term),
374            tuple((parse_trivia, tag(")"), parse_trivia)),
375        )),
376    )(input)?;
377    Ok((
378        input,
379        Atom {
380            predicate,
381            terms: terms.unwrap_or(vec![]),
382        },
383    ))
384}
385
386fn parse_body_atom<
387    'input,
388    E: ParseError<&'input str>
389        + ContextError<&'input str>
390        + FromExternalError<&'input str, TermError>
391        + FromExternalError<&'input str, uuid::Error>,
392>(
393    input: &'input str,
394) -> IResult<&str, BodyAtom, E> {
395    context(
396        "body_atom",
397        alt((
398            map(
399                preceded(
400                    tuple((tag("not"), tuple((multispace1, parse_trivia)))),
401                    parse_atom,
402                ),
403                BodyAtom::Negative,
404            ),
405            map(parse_atom, BodyAtom::Positive),
406        )),
407    )(input)
408}
409
410fn parse_constraint<
411    'input,
412    E: ParseError<&'input str>
413        + ContextError<&'input str>
414        + FromExternalError<&'input str, TermError>
415        + FromExternalError<&'input str, uuid::Error>,
416>(
417    input: &'input str,
418) -> IResult<&str, Constraint, E> {
419    context(
420        "constraint",
421        alt((
422            map(
423                terminated(
424                    separated_pair(
425                        parse_atom,
426                        tuple((parse_trivia, tag(":-"), parse_trivia)),
427                        separated_list1(
428                            tuple((parse_trivia, tag(","), parse_trivia)),
429                            parse_body_atom,
430                        ),
431                    ),
432                    preceded(parse_trivia, tag(".")),
433                ),
434                |(atom, body_atoms): (Atom, Vec<BodyAtom>)| Constraint::Rule {
435                    head: atom,
436                    body: body_atoms,
437                },
438            ),
439            map(
440                terminated(parse_atom, preceded(parse_trivia, tag("."))),
441                Constraint::Fact,
442            ),
443            map(
444                terminated(parse_body_atom, preceded(parse_trivia, tag("?"))),
445                Constraint::Goal,
446            ),
447        )),
448    )(input)
449}
450
451pub fn parser<
452    'input,
453    E: ParseError<&'input str>
454        + ContextError<&'input str>
455        + FromExternalError<&'input str, TermError>
456        + FromExternalError<&'input str, uuid::Error>,
457>(
458    input: &'input str,
459) -> IResult<&str, Vec<Constraint>, E> {
460    context(
461        "program",
462        preceded(
463            parse_trivia,
464            many0(terminated(parse_constraint, parse_trivia)),
465        ),
466    )(input)
467}
468
#[cfg(test)]
mod tests {
    use datadriven::walk;
    use nom::{error::VerboseError, Finish};

    use super::parser;

    /// Parses `input` as a program and renders either the pretty-printed JSON
    /// of the resulting constraints or the parse error text.
    ///
    /// Panics when parsing succeeds but leaves input unconsumed.
    fn render_root(input: &str) -> String {
        match parser::<VerboseError<&str>>(input).finish() {
            Ok((remaining, output)) => {
                assert_eq!(remaining, "");
                serde_json::to_string_pretty(&output).unwrap()
            }
            Err(e) => e.to_string(),
        }
    }

    /// Runs every data-driven case found under `tests/parser`.
    #[test]
    fn run() {
        walk("tests/parser", |f| {
            f.run(|test| -> String {
                match test.directive.as_str() {
                    "root" => render_root(&test.input),
                    _ => "Invalid directive".to_string(),
                }
            })
        });
    }
}