ipfrs_tensorlogic/
datalog.rs

1//! Datalog syntax parser for TensorLogic
2//!
3//! Supports parsing Datalog syntax for facts, rules, and queries:
4//! - Facts: `parent(alice, bob).`
5//! - Rules: `grandparent(X, Z) :- parent(X, Y), parent(Y, Z).`
6//! - Queries: `?- parent(alice, X).`
7
8use crate::ir::{Constant, Predicate, Rule, Term};
9use std::fmt;
10
/// Error produced when Datalog input cannot be parsed.
///
/// Implements `PartialEq`/`Eq` so that callers and tests can compare errors
/// directly.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ParseError {
    /// Human-readable description of what went wrong.
    pub message: String,
    /// Byte offset into the input at which the error was reported.
    pub position: usize,
}
17
18impl fmt::Display for ParseError {
19    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
20        write!(
21            f,
22            "Parse error at position {}: {}",
23            self.position, self.message
24        )
25    }
26}
27
28impl std::error::Error for ParseError {}
29
30type ParseResult<T> = Result<T, ParseError>;
31
/// Datalog parser over an owned input string.
///
/// Derives `Debug` so parser state can be inspected in diagnostics, and
/// `Clone` so a parser can be snapshotted at its current position.
#[derive(Debug, Clone)]
pub struct DatalogParser {
    /// The complete text being parsed.
    input: String,
    /// Byte offset of the next unread character in `input`.
    position: usize,
}
37
38impl DatalogParser {
39    /// Create a new parser for the given input
40    pub fn new(input: String) -> Self {
41        Self { input, position: 0 }
42    }
43
44    /// Parse a fact or rule
45    pub fn parse_statement(&mut self) -> ParseResult<Statement> {
46        self.skip_whitespace();
47
48        if self.peek_char() == Some('?') {
49            // Query
50            self.advance(); // skip '?'
51            self.expect_char('-')?;
52            self.skip_whitespace();
53            let predicate = self.parse_predicate()?;
54            self.skip_whitespace();
55            self.expect_char('.')?;
56            Ok(Statement::Query(predicate))
57        } else {
58            // Fact or Rule
59            let head = self.parse_predicate()?;
60            self.skip_whitespace();
61
62            if self.peek_char() == Some('.') {
63                // Fact
64                self.advance();
65                Ok(Statement::Fact(head))
66            } else if self.peek_str(2) == Some(":-") {
67                // Rule
68                self.advance();
69                self.advance();
70                self.skip_whitespace();
71
72                let body = self.parse_predicate_list()?;
73                self.skip_whitespace();
74                self.expect_char('.')?;
75
76                Ok(Statement::Rule(Rule::new(head, body)))
77            } else {
78                Err(ParseError {
79                    message: "Expected '.' or ':-'".to_string(),
80                    position: self.position,
81                })
82            }
83        }
84    }
85
86    /// Parse a predicate like `parent(alice, bob)`
87    fn parse_predicate(&mut self) -> ParseResult<Predicate> {
88        let name = self.parse_identifier()?;
89        self.skip_whitespace();
90        self.expect_char('(')?;
91        self.skip_whitespace();
92
93        let args = self.parse_term_list()?;
94        self.skip_whitespace();
95        self.expect_char(')')?;
96
97        Ok(Predicate::new(name, args))
98    }
99
100    /// Parse a comma-separated list of predicates
101    fn parse_predicate_list(&mut self) -> ParseResult<Vec<Predicate>> {
102        let mut predicates = Vec::new();
103
104        loop {
105            predicates.push(self.parse_predicate()?);
106            self.skip_whitespace();
107
108            if self.peek_char() == Some(',') {
109                self.advance();
110                self.skip_whitespace();
111            } else {
112                break;
113            }
114        }
115
116        Ok(predicates)
117    }
118
119    /// Parse a comma-separated list of terms
120    fn parse_term_list(&mut self) -> ParseResult<Vec<Term>> {
121        let mut terms = Vec::new();
122
123        if self.peek_char() == Some(')') {
124            return Ok(terms); // Empty list
125        }
126
127        loop {
128            terms.push(self.parse_term()?);
129            self.skip_whitespace();
130
131            if self.peek_char() == Some(',') {
132                self.advance();
133                self.skip_whitespace();
134            } else {
135                break;
136            }
137        }
138
139        Ok(terms)
140    }
141
142    /// Parse a term (variable, constant, or function)
143    fn parse_term(&mut self) -> ParseResult<Term> {
144        self.skip_whitespace();
145
146        let ch = self.peek_char().ok_or_else(|| ParseError {
147            message: "Unexpected end of input".to_string(),
148            position: self.position,
149        })?;
150
151        if ch == '?' || ch.is_uppercase() {
152            // Variable
153            if ch == '?' {
154                self.advance();
155            }
156            let name = self.parse_identifier()?;
157            Ok(Term::Var(name))
158        } else if ch == '"' {
159            // String constant
160            self.advance(); // skip opening quote
161            let value = self.parse_string()?;
162            self.expect_char('"')?;
163            Ok(Term::Const(Constant::String(value)))
164        } else if ch.is_ascii_digit() || ch == '-' {
165            // Numeric constant
166            let value = self.parse_number()?;
167            Ok(Term::Const(Constant::Int(value)))
168        } else if ch.is_lowercase() {
169            // Could be a constant atom or function
170            let name = self.parse_identifier()?;
171            self.skip_whitespace();
172
173            if self.peek_char() == Some('(') {
174                // Function
175                self.advance();
176                self.skip_whitespace();
177                let args = self.parse_term_list()?;
178                self.skip_whitespace();
179                self.expect_char(')')?;
180                Ok(Term::Fun(name, args))
181            } else {
182                // Atom constant
183                Ok(Term::Const(Constant::String(name)))
184            }
185        } else {
186            Err(ParseError {
187                message: format!("Unexpected character: '{}'", ch),
188                position: self.position,
189            })
190        }
191    }
192
193    /// Parse an identifier
194    fn parse_identifier(&mut self) -> ParseResult<String> {
195        let start = self.position;
196        while let Some(ch) = self.peek_char() {
197            if ch.is_alphanumeric() || ch == '_' {
198                self.advance();
199            } else {
200                break;
201            }
202        }
203
204        if self.position == start {
205            return Err(ParseError {
206                message: "Expected identifier".to_string(),
207                position: self.position,
208            });
209        }
210
211        Ok(self.input[start..self.position].to_string())
212    }
213
214    /// Parse a string literal
215    fn parse_string(&mut self) -> ParseResult<String> {
216        let start = self.position;
217        while let Some(ch) = self.peek_char() {
218            if ch == '"' {
219                break;
220            }
221            self.advance();
222        }
223
224        Ok(self.input[start..self.position].to_string())
225    }
226
227    /// Parse a number
228    fn parse_number(&mut self) -> ParseResult<i64> {
229        let start = self.position;
230
231        if self.peek_char() == Some('-') {
232            self.advance();
233        }
234
235        while let Some(ch) = self.peek_char() {
236            if ch.is_ascii_digit() {
237                self.advance();
238            } else {
239                break;
240            }
241        }
242
243        self.input[start..self.position]
244            .parse()
245            .map_err(|_| ParseError {
246                message: "Invalid number".to_string(),
247                position: start,
248            })
249    }
250
251    /// Skip whitespace and comments
252    fn skip_whitespace(&mut self) {
253        while let Some(ch) = self.peek_char() {
254            if ch.is_whitespace() {
255                self.advance();
256            } else if ch == '%' {
257                // Comment - skip until end of line
258                while let Some(ch) = self.peek_char() {
259                    self.advance();
260                    if ch == '\n' {
261                        break;
262                    }
263                }
264            } else {
265                break;
266            }
267        }
268    }
269
270    /// Peek at the next character
271    fn peek_char(&self) -> Option<char> {
272        self.input[self.position..].chars().next()
273    }
274
275    /// Peek at the next n characters
276    fn peek_str(&self, n: usize) -> Option<&str> {
277        if self.position + n <= self.input.len() {
278            Some(&self.input[self.position..self.position + n])
279        } else {
280            None
281        }
282    }
283
284    /// Advance the position by one character
285    fn advance(&mut self) {
286        if let Some(ch) = self.peek_char() {
287            self.position += ch.len_utf8();
288        }
289    }
290
291    /// Expect a specific character
292    fn expect_char(&mut self, expected: char) -> ParseResult<()> {
293        self.skip_whitespace();
294        let ch = self.peek_char().ok_or_else(|| ParseError {
295            message: format!("Expected '{}' but found end of input", expected),
296            position: self.position,
297        })?;
298
299        if ch == expected {
300            self.advance();
301            Ok(())
302        } else {
303            Err(ParseError {
304                message: format!("Expected '{}' but found '{}'", expected, ch),
305                position: self.position,
306            })
307        }
308    }
309}
310
311/// Parsed Datalog statement
312#[derive(Debug, Clone)]
313pub enum Statement {
314    /// A fact
315    Fact(Predicate),
316    /// A rule
317    Rule(Rule),
318    /// A query
319    Query(Predicate),
320}
321
322/// Parse a Datalog fact
323pub fn parse_fact(input: &str) -> ParseResult<Predicate> {
324    let mut parser = DatalogParser::new(input.to_string());
325    match parser.parse_statement()? {
326        Statement::Fact(fact) => Ok(fact),
327        _ => Err(ParseError {
328            message: "Expected a fact".to_string(),
329            position: 0,
330        }),
331    }
332}
333
334/// Parse a Datalog rule
335pub fn parse_rule(input: &str) -> ParseResult<Rule> {
336    let mut parser = DatalogParser::new(input.to_string());
337    match parser.parse_statement()? {
338        Statement::Rule(rule) => Ok(rule),
339        _ => Err(ParseError {
340            message: "Expected a rule".to_string(),
341            position: 0,
342        }),
343    }
344}
345
346/// Parse a Datalog query
347pub fn parse_query(input: &str) -> ParseResult<Predicate> {
348    let mut parser = DatalogParser::new(input.to_string());
349    match parser.parse_statement()? {
350        Statement::Query(query) => Ok(query),
351        _ => Err(ParseError {
352            message: "Expected a query".to_string(),
353            position: 0,
354        }),
355    }
356}
357
#[cfg(test)]
mod tests {
    use super::*;

    /// A two-argument fact yields a predicate with the right name and arity.
    #[test]
    fn test_parse_fact() {
        let source = "parent(alice, bob).";
        let parsed = parse_fact(source).unwrap();

        assert_eq!(parsed.name, "parent");
        assert_eq!(parsed.arity(), 2);
    }

    /// A rule keeps its head predicate and both body predicates.
    #[test]
    fn test_parse_rule() {
        let source = "grandparent(X, Z) :- parent(X, Y), parent(Y, Z).";
        let parsed = parse_rule(source).unwrap();

        assert_eq!(parsed.head.name, "grandparent");
        assert_eq!(parsed.body.len(), 2);
    }

    /// A query yields the queried predicate with the right name and arity.
    #[test]
    fn test_parse_query() {
        let source = "?- parent(alice, X).";
        let parsed = parse_query(source).unwrap();

        assert_eq!(parsed.name, "parent");
        assert_eq!(parsed.arity(), 2);
    }

    /// A `%` comment after the statement does not affect the parsed fact.
    #[test]
    fn test_parse_with_comments() {
        let source = "parent(alice, bob). % Alice is parent of Bob";
        let parsed = parse_fact(source).unwrap();

        assert_eq!(parsed.name, "parent");
    }
}