1use nom::character::complete::multispace0;
23use nom_locate::LocatedSpan;
24
25use crate::analysis::Kind;
26use crate::logic::parser::Span;
27use crate::sld;
28use crate::unification::Rename;
29
30use std::convert::TryInto;
31use std::fmt;
32use std::fmt::Debug;
33use std::ops::Range;
34use std::str;
35use std::sync::atomic::{AtomicU32, Ordering};
36use std::{collections::HashSet, hash::Hash};
37
38impl fmt::Display for IRTerm {
39    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
40        match self {
41            IRTerm::Constant(s) => write!(f, "\"{}\"", s),
42            IRTerm::UserVariable(s) => write!(f, "{}", s),
43            IRTerm::Array(ts) => write!(
44                f,
45                "[{}]",
46                ts.iter()
47                    .map(|x| x.to_string())
48                    .collect::<Vec<_>>()
49                    .join(", ")
50            ),
51            IRTerm::AuxiliaryVariable(i) => write!(f, "__AUX_{}", i),
53            IRTerm::RenamedVariable(i, t) => write!(f, "{}_{}", t, i),
54            IRTerm::AnonymousVariable(_) => write!(f, "_"),
55        }
56    }
57}
58
/// Shared counter from which fresh variable indices are drawn with
/// `fetch_add`. It is used by `IRTerm::rename` and `sld::Auxiliary::aux`, so
/// every renamed, auxiliary or anonymous variable minted through them gets a
/// distinct index (until the `u32` wraps around).
pub static AVAILABLE_VARIABLE_INDEX: AtomicU32 = AtomicU32::new(0);
60
61impl Rename<IRTerm> for IRTerm {
62    fn rename(&self) -> IRTerm {
63        match self {
64            IRTerm::Constant(_) => (*self).clone(),
65            IRTerm::Array(ts) => IRTerm::Array(ts.iter().map(|t| t.rename()).collect()),
66            _ => {
67                let index = AVAILABLE_VARIABLE_INDEX.fetch_add(1, Ordering::SeqCst);
68                IRTerm::RenamedVariable(index, Box::new((*self).clone()))
69            }
70        }
71    }
72}
73
74impl sld::Auxiliary for IRTerm {
75    fn aux(anonymous: bool) -> IRTerm {
76        let index = AVAILABLE_VARIABLE_INDEX.fetch_add(1, Ordering::SeqCst);
77        if anonymous {
78            IRTerm::AnonymousVariable(index)
79        } else {
80            IRTerm::AuxiliaryVariable(index)
81        }
82    }
83}
84
/// The name of a predicate.
///
/// Operator predicates are stored in mangled form with an `_operator_`
/// prefix (see `Predicate::is_operator` and `Predicate::unmangle`).
#[derive(Clone, PartialEq, Eq, Hash, Debug)]
pub struct Predicate(pub String);
88
89impl Predicate {
90    pub fn is_operator(&self) -> bool {
92        self.0.starts_with("_operator_")
93    }
94
95    pub fn unmangle(self) -> Predicate {
97        if self.is_operator() {
98            Predicate(
99                self.0
100                    .trim_start_matches("_operator_")
101                    .trim_end_matches("_begin")
102                    .trim_end_matches("_end")
103                    .to_string(),
104            )
105        } else {
106            self
107        }
108    }
109
110    pub fn naive_predicate_kind(&self) -> Kind {
113        match self.0.as_str() {
114            "from" => Kind::Image,
115            "run" | "copy" => Kind::Layer,
116            _ => Kind::Logic,
117        }
118    }
119}
120
121impl From<String> for Predicate {
122    fn from(s: String) -> Self {
123        Predicate(s)
124    }
125}
126
/// A term of the intermediate representation.
///
/// NOTE: variant declaration order determines the derived `PartialOrd`/`Ord`,
/// so do not reorder variants casually.
#[derive(Clone, PartialEq, Eq, Hash, Debug, PartialOrd, Ord)]
pub enum IRTerm {
    /// A string constant.
    Constant(String),
    /// A variable named by the user.
    UserVariable(String),
    /// An ordered sequence of terms.
    Array(Vec<IRTerm>),

    /// A fresh, non-anonymous variable identified by a numeric index.
    AuxiliaryVariable(u32),

    /// Another term tagged with a fresh index (the result of renaming).
    RenamedVariable(u32, Box<IRTerm>),

    /// A fresh anonymous variable identified by a numeric index.
    AnonymousVariable(u32),
}

impl IRTerm {
    /// Returns `true` only for a plain `Constant`. Arrays of constants do not
    /// count; see `is_constant_or_compound_constant` for that.
    pub fn is_constant(&self) -> bool {
        self.as_constant().is_some()
    }

    /// Returns `true` for a constant, or for an array (possibly nested or
    /// empty) whose elements are all constants.
    pub fn is_constant_or_compound_constant(&self) -> bool {
        self.is_constant()
            || matches!(
                self,
                Self::Array(items) if items.iter().all(|item| item.is_constant_or_compound_constant())
            )
    }

    /// Returns `true` if this term is an anonymous variable, possibly behind
    /// any number of `RenamedVariable` wrappers.
    pub fn is_underlying_anonymous_variable(&self) -> bool {
        self.get_original().is_anonymous_variable()
    }

    /// Borrows the text of a `Constant`, or returns `None` for any other term.
    pub fn as_constant(&self) -> Option<&str> {
        if let IRTerm::Constant(text) = self {
            Some(text.as_str())
        } else {
            None
        }
    }

    /// Strips every `RenamedVariable` layer and returns the innermost term.
    pub fn get_original(&self) -> &IRTerm {
        let mut current = self;
        while let IRTerm::RenamedVariable(_, inner) = current {
            current = &**inner;
        }
        current
    }

    /// Returns `true` only for a bare `AnonymousVariable`. Renamed wrappers do
    /// not count.
    pub fn is_anonymous_variable(&self) -> bool {
        matches!(self, IRTerm::AnonymousVariable(_))
    }
}
193
/// A contiguous region of parsed input, stored as a start offset plus a
/// length.
#[derive(Clone, PartialEq, Eq, Hash, Debug)]
pub struct SpannedPosition {
    /// Offset of the region's start within the parsed input.
    pub offset: usize,

    /// Length of the region.
    pub length: usize,
}

/// Converts a position into the half-open range `offset..offset + length`.
impl From<&SpannedPosition> for Range<usize> {
    fn from(pos: &SpannedPosition) -> Self {
        let end = pos.offset + pos.length;
        Range {
            start: pos.offset,
            end,
        }
    }
}
211
212impl From<Span<'_>> for SpannedPosition {
213    fn from(s: Span) -> Self {
214        SpannedPosition {
215            length: s.fragment().len(),
216            offset: s.location_offset(),
217        }
218    }
219}
220
/// A possibly negated predicate applied to a list of arguments.
///
/// `T` is the argument type and defaults to [`IRTerm`]. Note that the derived
/// `PartialEq` also compares `position`; tests use `eq_ignoring_position`
/// to compare structure only.
#[derive(Clone, PartialEq, Eq, Hash, Debug)]
pub struct Literal<T = IRTerm> {
    /// `false` for a negated literal (displayed with a leading `!`).
    pub positive: bool,

    /// Where the literal was found in the parsed input, if known.
    pub position: Option<SpannedPosition>,
    /// The predicate being applied.
    pub predicate: Predicate,
    /// The arguments, in order.
    pub args: Vec<T>,
}
231
#[cfg(test)]
impl<T: PartialEq> Literal<T> {
    /// Test helper: structural equality that disregards `position`.
    pub fn eq_ignoring_position(&self, other: &Literal<T>) -> bool {
        let same_sign = self.positive == other.positive;
        same_sign && self.predicate == other.predicate && self.args == other.args
    }
}
241
/// A predicate together with its arity (number of arguments), displayed as
/// `name/arity`.
#[derive(Clone, PartialEq, Eq, Hash, Debug)]
pub struct Signature(pub Predicate, pub u32);
244
/// A rule of the form `head :- body`.
///
/// `body` holds the literals in source order. The parser produces an empty
/// body when no `:-` follows the head.
#[derive(Clone, PartialEq, Debug)]
pub struct Clause<T = IRTerm> {
    pub head: Literal<T>,
    pub body: Vec<Literal<T>>,
}
250
#[cfg(test)]
impl<T: PartialEq> Clause<T> {
    /// Test helper: compares the head and the body literal by literal,
    /// disregarding source positions.
    pub fn eq_ignoring_position(&self, other: &Clause<T>) -> bool {
        if !self.head.eq_ignoring_position(&other.head) {
            return false;
        }
        self.body.len() == other.body.len()
            && self
                .body
                .iter()
                .zip(other.body.iter())
                .all(|(mine, theirs)| mine.eq_ignoring_position(theirs))
    }
}
263
/// Types that can report whether they are ground, i.e. free of variables.
/// See the individual impls for exactly what counts as ground.
pub trait Ground {
    /// Returns `true` when the value is ground.
    fn is_ground(&self) -> bool;
}
267
268impl IRTerm {
269    pub fn variables(&self, include_anonymous: bool) -> HashSet<IRTerm> {
270        let mut set = HashSet::<IRTerm>::new();
271        match (self, include_anonymous) {
272            (IRTerm::AnonymousVariable(_), true) => {
273                set.insert(self.clone());
274            }
275            (IRTerm::Array(ts), b) => {
276                set.extend(ts.iter().flat_map(|t| t.variables(b)));
277            }
278            (IRTerm::AuxiliaryVariable(_), _)
279            | (IRTerm::RenamedVariable(..), _)
280            | (IRTerm::UserVariable(_), _) => {
281                set.insert(self.clone());
282            }
283            (IRTerm::Constant(_), _) | (IRTerm::AnonymousVariable(_), false) => (),
284        }
285        set
286    }
287}
288
289impl Literal {
290    pub fn signature(&self) -> Signature {
291        Signature(self.predicate.clone(), self.args.len().try_into().unwrap())
292    }
293    pub fn variables(&self, include_anonymous: bool) -> HashSet<IRTerm> {
294        self.args
295            .iter()
296            .map(|r| r.variables(include_anonymous))
297            .reduce(|mut l, r| {
298                l.extend(r);
299                l
300            })
301            .unwrap_or_default()
302    }
303
304    pub fn unmangle(self) -> Literal {
307        if self.predicate.is_operator() {
308            Literal {
309                predicate: self.predicate.unmangle(),
310                args: self.args[1..].to_vec(),
311                ..self
312            }
313        } else {
314            self
315        }
316    }
317
318    pub fn negated(&self) -> Literal {
319        Literal {
320            positive: !self.positive,
321            ..self.clone()
322        }
323    }
324}
325
326impl<T> Literal<T> {
327    pub fn with_position(self, position: Option<SpannedPosition>) -> Literal<T> {
328        Literal { position, ..self }
329    }
330}
331
332impl Clause {
333    pub fn variables(&self, include_anonymous: bool) -> HashSet<IRTerm> {
334        let mut body = self
335            .body
336            .iter()
337            .map(|r| r.variables(include_anonymous))
338            .reduce(|mut l, r| {
339                l.extend(r);
340                l
341            })
342            .unwrap_or_default();
343        body.extend(self.head.variables(include_anonymous));
344        body
345    }
346}
347
348impl Ground for IRTerm {
349    fn is_ground(&self) -> bool {
350        matches!(self, IRTerm::Constant(_))
351    }
352}
353
354impl Ground for Literal {
355    fn is_ground(&self) -> bool {
356        self.variables(true).is_empty()
357    }
358}
359
360impl Ground for Clause {
361    fn is_ground(&self) -> bool {
362        self.variables(true).is_empty()
363    }
364}
365
366impl fmt::Display for Signature {
367    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
368        write!(f, "{}/{}", self.0, self.1)
369    }
370}
371
372impl fmt::Display for Predicate {
373    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
374        write!(f, "{}", self.0)
375    }
376}
377
/// Joins the `Display` renderings of `seq`, separated by `sep`.
///
/// Returns an empty string for an empty slice. Writes straight into one
/// output `String` instead of collecting an intermediate `Vec<String>`.
fn display_sep<T: fmt::Display>(seq: &[T], sep: &str) -> String {
    use std::fmt::Write as _;

    let mut out = String::new();
    for (index, item) in seq.iter().enumerate() {
        if index > 0 {
            out.push_str(sep);
        }
        write!(out, "{}", item).expect("writing to a String cannot fail");
    }
    out
}
385
386impl fmt::Display for Literal {
387    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
388        match &*self.args {
389            [] => write!(
390                f,
391                "{}{}",
392                if self.positive { "" } else { "!" },
393                self.predicate
394            ),
395            _ => write!(
396                f,
397                "{}{}({})",
398                if self.positive { "" } else { "!" },
399                self.predicate,
400                display_sep(&self.args, ", ")
401            ),
402        }
403    }
404}
405
406impl fmt::Display for Clause {
407    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
408        write!(f, "{} :- {}", self.head, display_sep(&self.body, ", "))
409    }
410}
411
412impl str::FromStr for Clause {
413    type Err = String;
414
415    fn from_str(s: &str) -> Result<Self, Self::Err> {
416        let span = Span::new(s);
417        match parser::clause(parser::term)(span) {
418            Result::Ok((_, o)) => Ok(o),
419            Result::Err(e) => Result::Err(format!("{}", e)),
420        }
421    }
422}
423
424impl str::FromStr for Literal {
425    type Err = String;
426
427    fn from_str(s: &str) -> Result<Self, Self::Err> {
428        let span = Span::new(s);
429        match parser::literal(parser::term, multispace0)(span) {
430            Result::Ok((_, o)) => Ok(o),
431            Result::Err(e) => Result::Err(format!("{}", e)),
432        }
433    }
434}
435
436pub mod parser {
438
439    use super::*;
440
441    use nom::{
442        branch::alt,
443        bytes::complete::{is_a, take_until},
444        character::complete::{alpha1, alphanumeric1, multispace0},
445        combinator::{cut, map, opt, recognize},
446        multi::{many0, many0_count, separated_list0, separated_list1},
447        sequence::{delimited, pair, preceded, terminated, tuple},
448        Offset, Slice,
449    };
450    use nom_supreme::{error::ErrorTree, tag::complete::tag};
451
452    pub type Span<'a> = LocatedSpan<&'a str>;
453
454    pub type IResult<T, O> = nom::IResult<T, O, ErrorTree<T>>;
456
457    pub fn recognized_span<'a, P, T>(
460        mut inner: P,
461    ) -> impl FnMut(Span<'a>) -> IResult<Span<'a>, (SpannedPosition, T)>
462    where
463        P: FnMut(Span<'a>) -> IResult<Span<'a>, T>,
464    {
465        move |i| {
466            let original_i = i.clone();
467
468            let (i, o) = inner(i)?;
469
470            let index = original_i.offset(&i);
471            let recognized_section = original_i.slice(..index);
472            let spanned_pos: SpannedPosition = recognized_section.into();
473
474            Ok((i, (spanned_pos, o)))
475        }
476    }
477
478    fn ws<'a, F: 'a, O>(inner: F) -> impl FnMut(Span<'a>) -> IResult<Span<'a>, O>
479    where
480        F: FnMut(Span<'a>) -> IResult<Span<'a>, O>,
481    {
482        delimited(multispace0, inner, multispace0)
483    }
484
485    fn constant(i: Span) -> IResult<Span, Span> {
486        delimited(tag("\""), take_until("\""), tag("\""))(i)
487    }
488
489    fn variable(i: Span) -> IResult<Span, Span> {
490        literal_identifier(i)
491    }
492
493    fn array_term(i: Span) -> IResult<Span, Vec<IRTerm>> {
494        delimited(
495            terminated(tag("["), multispace0),
496            separated_list0(delimited(multispace0, tag(","), multispace0), term),
497            preceded(multispace0, tag("]")),
498        )(i)
499    }
500
501    pub fn term(i: Span) -> IResult<Span, IRTerm> {
502        alt((
503            map(array_term, IRTerm::Array),
504            map(constant, |s| IRTerm::Constant(s.fragment().to_string())),
505            map(is_a("_"), |_| sld::Auxiliary::aux(true)),
506            map(variable, |s| IRTerm::UserVariable(s.fragment().to_string())),
507        ))(i)
508    }
509
510    pub fn literal_identifier(i: Span) -> IResult<Span, Span> {
512        recognize(pair(
513            alt((alpha1, tag("_"))),
514            many0(alt((alphanumeric1, tag("_"), tag("-")))),
515        ))(i)
516    }
517
518    pub fn literal<'a, FT: 'a, T, S, Any>(
520        term: FT,
521        space0: S,
522    ) -> impl FnMut(Span<'a>) -> IResult<Span<'a>, Literal<T>>
523    where
524        FT: FnMut(Span<'a>) -> IResult<Span<'a>, T> + Clone,
525        S: FnMut(Span<'a>) -> IResult<Span<'a>, Any> + Clone,
526    {
527        move |i| {
528            let (i, (spanned_pos, (neg_count, name, args))) = recognized_span(tuple((
529                many0_count(terminated(
531                    nom::character::complete::char('!'),
532                    space0.clone(),
533                )),
534                terminated(literal_identifier, space0.clone()),
535                opt(delimited(
536                    terminated(tag("("), space0.clone()),
537                    separated_list1(
538                        terminated(tag(","), space0.clone()),
539                        terminated(term.clone(), space0.clone()),
540                    ),
541                    cut(terminated(tag(")"), space0.clone())),
542                )),
543            )))(i)?;
544
545            Ok((
546                i,
547                Literal {
548                    positive: neg_count % 2 == 0,
549                    position: Some(spanned_pos),
550                    predicate: Predicate(name.fragment().to_string()),
551                    args: match args {
552                        Some(args) => args,
553                        None => Vec::new(),
554                    },
555                },
556            ))
557        }
558    }
559
560    pub fn clause<'a, FT: 'a, T>(term: FT) -> impl FnMut(Span<'a>) -> IResult<Span<'a>, Clause<T>>
561    where
562        FT: FnMut(Span) -> IResult<Span, T> + Clone,
563    {
564        map(
565            pair(
566                literal(term.clone(), multispace0),
567                opt(preceded(
568                    ws(tag(":-")),
569                    separated_list0(ws(tag(",")), literal(term, multispace0)),
570                )),
571            ),
572            |(head, body)| Clause {
573                head,
574                body: body.unwrap_or(Vec::new()),
575            },
576        )
577    }
578}
579
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a literal with no source position.
    fn lit(positive: bool, name: &str, args: Vec<IRTerm>) -> Literal {
        Literal {
            positive,
            position: None,
            predicate: Predicate(name.into()),
            args,
        }
    }

    /// Shorthand for a constant term.
    fn constant(text: &str) -> IRTerm {
        IRTerm::Constant(text.into())
    }

    /// Shorthand for a user variable term.
    fn var(name: &str) -> IRTerm {
        IRTerm::UserVariable(name.into())
    }

    #[test]
    fn simple_term() {
        let parsed: IRTerm = parser::term(Span::new("\"\"")).unwrap().1;
        assert_eq!(constant(""), parsed);
    }

    #[test]
    fn literals() {
        let expected = lit(true, "l1", vec![constant("c"), constant("d")]);
        assert_eq!("l1(\"c\", \"d\")", expected.to_string());

        // Whitespace, including newlines and tabs, is allowed between arguments.
        for source in ["l1(\"c\", \"d\")", "l1(\"c\",\n\t\"d\")"].iter() {
            let parsed: Literal = source.parse().unwrap();
            assert!(expected.eq_ignoring_position(&parsed));
        }
    }

    #[test]
    fn literal_with_variable() {
        let expected = lit(true, "l1", vec![constant(""), var("X")]);
        assert_eq!("l1(\"\", X)", expected.to_string());

        let parsed: Literal = "l1(\"\", X)".parse().unwrap();
        assert!(expected.eq_ignoring_position(&parsed));
    }

    #[test]
    fn negated_literal() {
        let expected = lit(false, "l1", vec![constant(""), var("X")]);
        assert_eq!("!l1(\"\", X)", expected.to_string());

        // An odd number of `!` negates the literal.
        let parsed: Literal = "!!!l1(\"\", X)".parse().unwrap();
        assert!(expected.eq_ignoring_position(&parsed));
    }

    #[test]
    fn span_of_literal() {
        let parsed: Literal = "l1(\"test_constant\", X)".parse().unwrap();
        let expected = SpannedPosition {
            offset: 0,
            length: 22,
        };
        assert_eq!(Some(expected), parsed.position);
    }

    #[test]
    fn simple_rule() {
        let (a, b, c) = (var("A"), var("B"), constant("c"));
        let expected = Clause {
            head: lit(true, "l1", vec![a.clone(), b.clone()]),
            body: vec![
                lit(true, "l2", vec![a, c.clone()]),
                lit(true, "l3", vec![b, c]),
            ],
        };

        let text = "l1(A, B) :- l2(A, \"c\"), l3(B, \"c\")";
        assert_eq!(text, expected.to_string());

        let parsed: Clause = text.parse().unwrap();
        assert!(expected.eq_ignoring_position(&parsed));
    }

    #[test]
    fn nullary_predicate() {
        let expected = Clause {
            head: lit(true, "l1", Vec::new()),
            body: vec![lit(true, "l2", vec![var("A")])],
        };

        assert_eq!("l1 :- l2(A)", expected.to_string());

        let parsed: Clause = "l1 :- l2(A)".parse().unwrap();
        assert!(expected.eq_ignoring_position(&parsed))
    }
}