ipfrs_tensorlogic/
datalog.rs1use crate::ir::{Constant, Predicate, Rule, Term};
9use std::fmt;
10
/// Error produced when Datalog input cannot be parsed.
///
/// Carries a human-readable description together with the byte offset into
/// the input at which the problem was reported. `PartialEq`/`Eq` are derived
/// so that callers and tests can compare errors directly.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ParseError {
    /// Human-readable description of what went wrong.
    pub message: String,
    /// Byte offset into the input at which the error was reported.
    pub position: usize,
}
17
18impl fmt::Display for ParseError {
19 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
20 write!(
21 f,
22 "Parse error at position {}: {}",
23 self.position, self.message
24 )
25 }
26}
27
28impl std::error::Error for ParseError {}
29
30type ParseResult<T> = Result<T, ParseError>;
31
/// Hand-written parser over a Datalog source string.
///
/// Owns the complete input and a cursor into it; the parsing methods move
/// the cursor forward as they consume text.
pub struct DatalogParser {
    /// The complete source text being parsed.
    input: String,
    /// Cursor into `input`, expressed as a byte offset.
    position: usize,
}
37
38impl DatalogParser {
39 pub fn new(input: String) -> Self {
41 Self { input, position: 0 }
42 }
43
44 pub fn parse_statement(&mut self) -> ParseResult<Statement> {
46 self.skip_whitespace();
47
48 if self.peek_char() == Some('?') {
49 self.advance(); self.expect_char('-')?;
52 self.skip_whitespace();
53 let predicate = self.parse_predicate()?;
54 self.skip_whitespace();
55 self.expect_char('.')?;
56 Ok(Statement::Query(predicate))
57 } else {
58 let head = self.parse_predicate()?;
60 self.skip_whitespace();
61
62 if self.peek_char() == Some('.') {
63 self.advance();
65 Ok(Statement::Fact(head))
66 } else if self.peek_str(2) == Some(":-") {
67 self.advance();
69 self.advance();
70 self.skip_whitespace();
71
72 let body = self.parse_predicate_list()?;
73 self.skip_whitespace();
74 self.expect_char('.')?;
75
76 Ok(Statement::Rule(Rule::new(head, body)))
77 } else {
78 Err(ParseError {
79 message: "Expected '.' or ':-'".to_string(),
80 position: self.position,
81 })
82 }
83 }
84 }
85
86 fn parse_predicate(&mut self) -> ParseResult<Predicate> {
88 let name = self.parse_identifier()?;
89 self.skip_whitespace();
90 self.expect_char('(')?;
91 self.skip_whitespace();
92
93 let args = self.parse_term_list()?;
94 self.skip_whitespace();
95 self.expect_char(')')?;
96
97 Ok(Predicate::new(name, args))
98 }
99
100 fn parse_predicate_list(&mut self) -> ParseResult<Vec<Predicate>> {
102 let mut predicates = Vec::new();
103
104 loop {
105 predicates.push(self.parse_predicate()?);
106 self.skip_whitespace();
107
108 if self.peek_char() == Some(',') {
109 self.advance();
110 self.skip_whitespace();
111 } else {
112 break;
113 }
114 }
115
116 Ok(predicates)
117 }
118
119 fn parse_term_list(&mut self) -> ParseResult<Vec<Term>> {
121 let mut terms = Vec::new();
122
123 if self.peek_char() == Some(')') {
124 return Ok(terms); }
126
127 loop {
128 terms.push(self.parse_term()?);
129 self.skip_whitespace();
130
131 if self.peek_char() == Some(',') {
132 self.advance();
133 self.skip_whitespace();
134 } else {
135 break;
136 }
137 }
138
139 Ok(terms)
140 }
141
142 fn parse_term(&mut self) -> ParseResult<Term> {
144 self.skip_whitespace();
145
146 let ch = self.peek_char().ok_or_else(|| ParseError {
147 message: "Unexpected end of input".to_string(),
148 position: self.position,
149 })?;
150
151 if ch == '?' || ch.is_uppercase() {
152 if ch == '?' {
154 self.advance();
155 }
156 let name = self.parse_identifier()?;
157 Ok(Term::Var(name))
158 } else if ch == '"' {
159 self.advance(); let value = self.parse_string()?;
162 self.expect_char('"')?;
163 Ok(Term::Const(Constant::String(value)))
164 } else if ch.is_ascii_digit() || ch == '-' {
165 let value = self.parse_number()?;
167 Ok(Term::Const(Constant::Int(value)))
168 } else if ch.is_lowercase() {
169 let name = self.parse_identifier()?;
171 self.skip_whitespace();
172
173 if self.peek_char() == Some('(') {
174 self.advance();
176 self.skip_whitespace();
177 let args = self.parse_term_list()?;
178 self.skip_whitespace();
179 self.expect_char(')')?;
180 Ok(Term::Fun(name, args))
181 } else {
182 Ok(Term::Const(Constant::String(name)))
184 }
185 } else {
186 Err(ParseError {
187 message: format!("Unexpected character: '{}'", ch),
188 position: self.position,
189 })
190 }
191 }
192
193 fn parse_identifier(&mut self) -> ParseResult<String> {
195 let start = self.position;
196 while let Some(ch) = self.peek_char() {
197 if ch.is_alphanumeric() || ch == '_' {
198 self.advance();
199 } else {
200 break;
201 }
202 }
203
204 if self.position == start {
205 return Err(ParseError {
206 message: "Expected identifier".to_string(),
207 position: self.position,
208 });
209 }
210
211 Ok(self.input[start..self.position].to_string())
212 }
213
214 fn parse_string(&mut self) -> ParseResult<String> {
216 let start = self.position;
217 while let Some(ch) = self.peek_char() {
218 if ch == '"' {
219 break;
220 }
221 self.advance();
222 }
223
224 Ok(self.input[start..self.position].to_string())
225 }
226
227 fn parse_number(&mut self) -> ParseResult<i64> {
229 let start = self.position;
230
231 if self.peek_char() == Some('-') {
232 self.advance();
233 }
234
235 while let Some(ch) = self.peek_char() {
236 if ch.is_ascii_digit() {
237 self.advance();
238 } else {
239 break;
240 }
241 }
242
243 self.input[start..self.position]
244 .parse()
245 .map_err(|_| ParseError {
246 message: "Invalid number".to_string(),
247 position: start,
248 })
249 }
250
251 fn skip_whitespace(&mut self) {
253 while let Some(ch) = self.peek_char() {
254 if ch.is_whitespace() {
255 self.advance();
256 } else if ch == '%' {
257 while let Some(ch) = self.peek_char() {
259 self.advance();
260 if ch == '\n' {
261 break;
262 }
263 }
264 } else {
265 break;
266 }
267 }
268 }
269
270 fn peek_char(&self) -> Option<char> {
272 self.input[self.position..].chars().next()
273 }
274
275 fn peek_str(&self, n: usize) -> Option<&str> {
277 if self.position + n <= self.input.len() {
278 Some(&self.input[self.position..self.position + n])
279 } else {
280 None
281 }
282 }
283
284 fn advance(&mut self) {
286 if let Some(ch) = self.peek_char() {
287 self.position += ch.len_utf8();
288 }
289 }
290
291 fn expect_char(&mut self, expected: char) -> ParseResult<()> {
293 self.skip_whitespace();
294 let ch = self.peek_char().ok_or_else(|| ParseError {
295 message: format!("Expected '{}' but found end of input", expected),
296 position: self.position,
297 })?;
298
299 if ch == expected {
300 self.advance();
301 Ok(())
302 } else {
303 Err(ParseError {
304 message: format!("Expected '{}' but found '{}'", expected, ch),
305 position: self.position,
306 })
307 }
308 }
309}
310
311#[derive(Debug, Clone)]
313pub enum Statement {
314 Fact(Predicate),
316 Rule(Rule),
318 Query(Predicate),
320}
321
322pub fn parse_fact(input: &str) -> ParseResult<Predicate> {
324 let mut parser = DatalogParser::new(input.to_string());
325 match parser.parse_statement()? {
326 Statement::Fact(fact) => Ok(fact),
327 _ => Err(ParseError {
328 message: "Expected a fact".to_string(),
329 position: 0,
330 }),
331 }
332}
333
334pub fn parse_rule(input: &str) -> ParseResult<Rule> {
336 let mut parser = DatalogParser::new(input.to_string());
337 match parser.parse_statement()? {
338 Statement::Rule(rule) => Ok(rule),
339 _ => Err(ParseError {
340 message: "Expected a rule".to_string(),
341 position: 0,
342 }),
343 }
344}
345
346pub fn parse_query(input: &str) -> ParseResult<Predicate> {
348 let mut parser = DatalogParser::new(input.to_string());
349 match parser.parse_statement()? {
350 Statement::Query(query) => Ok(query),
351 _ => Err(ParseError {
352 message: "Expected a query".to_string(),
353 position: 0,
354 }),
355 }
356}
357
#[cfg(test)]
mod tests {
    use super::*;

    /// A two-argument fact yields a predicate with the right name and arity.
    #[test]
    fn test_parse_fact() {
        let source = "parent(alice, bob).";
        let parsed = parse_fact(source).unwrap();
        assert_eq!(parsed.name, "parent");
        assert_eq!(parsed.arity(), 2);
    }

    /// A rule keeps its head and every predicate of its body.
    #[test]
    fn test_parse_rule() {
        let source = "grandparent(X, Z) :- parent(X, Y), parent(Y, Z).";
        let parsed = parse_rule(source).unwrap();
        assert_eq!(parsed.head.name, "grandparent");
        assert_eq!(parsed.body.len(), 2);
    }

    /// A `?-` query yields the queried predicate.
    #[test]
    fn test_parse_query() {
        let source = "?- parent(alice, X).";
        let parsed = parse_query(source).unwrap();
        assert_eq!(parsed.name, "parent");
        assert_eq!(parsed.arity(), 2);
    }

    /// A trailing `%` comment after the fact does not affect parsing.
    #[test]
    fn test_parse_with_comments() {
        let source = "parent(alice, bob). % Alice is parent of Bob";
        let parsed = parse_fact(source).unwrap();
        assert_eq!(parsed.name, "parent");
    }
}