Skip to main content

patch_prolog_core/
tokenizer.rs

1use serde::{Deserialize, Serialize};
2
/// Token types for Edinburgh Prolog.
///
/// One value of this enum is produced per lexical token by [`Tokenizer`].
/// Operators that have no dedicated variant (`[]`, `=..`, `@<`, `@>`, `@>=`,
/// `@=<`) are emitted by the tokenizer as [`TokenKind::Atom`] carrying the
/// operator text.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum TokenKind {
    // Identifiers
    Atom(String),     // lowercase-starting or single-quoted (quotes stripped, escapes resolved)
    Variable(String), // uppercase-starting or _
    Integer(i64),     // unsigned run of ASCII digits; a leading `-` is a separate Minus token
    Float(f64),       // digits '.' digits; no exponent syntax is recognised by the tokenizer

    // Operators
    Neck,      // :-
    QueryOp,   // ?-
    Equals,    // =
    NotEquals, // \=
    Is,        // is   (keyword, recognised while reading an unquoted atom)
    Lt,        // <
    Gt,        // >
    Lte,       // =<
    Gte,       // >=
    ArithEq,   // =:=
    ArithNeq,  // =\=
    Plus,      // +
    Minus,     // -
    Star,      // *
    Slash,     // /
    Mod,       // mod  (keyword, recognised while reading an unquoted atom)
    Not,       // \+
    Cut,       // !
    Arrow,     // ->
    Semicolon, // ;

    // Punctuation
    Dot,      // .
    Comma,    // ,
    LParen,   // (
    RParen,   // )
    LBracket, // [   (only when not immediately followed by `]`)
    RBracket, // ]
    Pipe,     // |

    // End of input
    Eof,
}
46
/// A single lexical token together with the source position where it starts.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct Token {
    /// What kind of token this is, including any payload (name, number).
    pub kind: TokenKind,
    /// 1-based line of the token's first byte.
    pub line: usize,
    /// 1-based column of the token's first byte. The tokenizer advances the
    /// column once per input byte, so this is a byte column, not a character column.
    pub col: usize,
}
53
/// Byte-oriented scanner that turns Prolog source text into [`Token`]s.
pub struct Tokenizer<'a> {
    /// The source text viewed as raw bytes (borrowed from the caller's `&str`).
    input: &'a [u8],
    /// Index into `input` of the next unread byte.
    pos: usize,
    /// 1-based line number of the byte at `pos`.
    line: usize,
    /// 1-based byte column of the byte at `pos`; reset to 1 after each `\n`.
    col: usize,
}
60
61impl<'a> Tokenizer<'a> {
62    pub fn new(input: &'a str) -> Self {
63        Tokenizer {
64            input: input.as_bytes(),
65            pos: 0,
66            line: 1,
67            col: 1,
68        }
69    }
70
71    pub fn tokenize(input: &str) -> Result<Vec<Token>, String> {
72        let mut tok = Tokenizer::new(input);
73        let mut tokens = Vec::new();
74        loop {
75            let t = tok.next_token()?;
76            if t.kind == TokenKind::Eof {
77                tokens.push(t);
78                break;
79            }
80            tokens.push(t);
81        }
82        Ok(tokens)
83    }
84
85    fn peek(&self) -> Option<u8> {
86        if self.pos < self.input.len() {
87            Some(self.input[self.pos])
88        } else {
89            None
90        }
91    }
92
93    fn peek_at(&self, offset: usize) -> Option<u8> {
94        let idx = self.pos + offset;
95        if idx < self.input.len() {
96            Some(self.input[idx])
97        } else {
98            None
99        }
100    }
101
102    fn advance(&mut self) -> u8 {
103        let ch = self.input[self.pos];
104        self.pos += 1;
105        if ch == b'\n' {
106            self.line += 1;
107            self.col = 1;
108        } else {
109            self.col += 1;
110        }
111        ch
112    }
113
114    fn skip_whitespace(&mut self) {
115        while let Some(ch) = self.peek() {
116            match ch {
117                b' ' | b'\t' | b'\r' | b'\n' => {
118                    self.advance();
119                }
120                b'%' => {
121                    // Line comment
122                    while let Some(ch) = self.peek() {
123                        if ch == b'\n' {
124                            break;
125                        }
126                        self.advance();
127                    }
128                }
129                b'/' if self.peek_at(1) == Some(b'*') => {
130                    // Block comment
131                    self.advance(); // /
132                    self.advance(); // *
133                    loop {
134                        match self.peek() {
135                            None => break,
136                            Some(b'*') if self.peek_at(1) == Some(b'/') => {
137                                self.advance();
138                                self.advance();
139                                break;
140                            }
141                            _ => {
142                                self.advance();
143                            }
144                        }
145                    }
146                }
147                _ => break,
148            }
149        }
150    }
151
152    fn next_token(&mut self) -> Result<Token, String> {
153        self.skip_whitespace();
154
155        let line = self.line;
156        let col = self.col;
157
158        let ch = match self.peek() {
159            None => {
160                return Ok(Token {
161                    kind: TokenKind::Eof,
162                    line,
163                    col,
164                })
165            }
166            Some(ch) => ch,
167        };
168
169        match ch {
170            b'(' => {
171                self.advance();
172                Ok(Token {
173                    kind: TokenKind::LParen,
174                    line,
175                    col,
176                })
177            }
178            b')' => {
179                self.advance();
180                Ok(Token {
181                    kind: TokenKind::RParen,
182                    line,
183                    col,
184                })
185            }
186            b'[' => {
187                self.advance();
188                // Check for []
189                if self.peek() == Some(b']') {
190                    self.advance();
191                    Ok(Token {
192                        kind: TokenKind::Atom("[]".into()),
193                        line,
194                        col,
195                    })
196                } else {
197                    Ok(Token {
198                        kind: TokenKind::LBracket,
199                        line,
200                        col,
201                    })
202                }
203            }
204            b']' => {
205                self.advance();
206                Ok(Token {
207                    kind: TokenKind::RBracket,
208                    line,
209                    col,
210                })
211            }
212            b'|' => {
213                self.advance();
214                Ok(Token {
215                    kind: TokenKind::Pipe,
216                    line,
217                    col,
218                })
219            }
220            b',' => {
221                self.advance();
222                Ok(Token {
223                    kind: TokenKind::Comma,
224                    line,
225                    col,
226                })
227            }
228            b'!' => {
229                self.advance();
230                Ok(Token {
231                    kind: TokenKind::Cut,
232                    line,
233                    col,
234                })
235            }
236            b';' => {
237                self.advance();
238                Ok(Token {
239                    kind: TokenKind::Semicolon,
240                    line,
241                    col,
242                })
243            }
244
245            b'.' => {
246                self.advance();
247                // Check if followed by whitespace/EOF/comment (end of clause)
248                // vs followed by digit (float - but we handle that in number parsing)
249                Ok(Token {
250                    kind: TokenKind::Dot,
251                    line,
252                    col,
253                })
254            }
255
256            b':' => {
257                self.advance();
258                if self.peek() == Some(b'-') {
259                    self.advance();
260                    Ok(Token {
261                        kind: TokenKind::Neck,
262                        line,
263                        col,
264                    })
265                } else {
266                    Err(format!("Unexpected ':' at line {} col {}", line, col))
267                }
268            }
269
270            b'?' => {
271                self.advance();
272                if self.peek() == Some(b'-') {
273                    self.advance();
274                    Ok(Token {
275                        kind: TokenKind::QueryOp,
276                        line,
277                        col,
278                    })
279                } else {
280                    Err(format!("Unexpected '?' at line {} col {}", line, col))
281                }
282            }
283
284            b'=' => {
285                self.advance();
286                match self.peek() {
287                    Some(b':') if self.peek_at(1) == Some(b'=') => {
288                        self.advance();
289                        self.advance();
290                        Ok(Token {
291                            kind: TokenKind::ArithEq,
292                            line,
293                            col,
294                        })
295                    }
296                    Some(b'\\') if self.peek_at(1) == Some(b'=') => {
297                        self.advance();
298                        self.advance();
299                        Ok(Token {
300                            kind: TokenKind::ArithNeq,
301                            line,
302                            col,
303                        })
304                    }
305                    Some(b'<') => {
306                        self.advance();
307                        Ok(Token {
308                            kind: TokenKind::Lte,
309                            line,
310                            col,
311                        })
312                    }
313                    Some(b'.') if self.peek_at(1) == Some(b'.') => {
314                        self.advance();
315                        self.advance();
316                        Ok(Token {
317                            kind: TokenKind::Atom("=..".into()),
318                            line,
319                            col,
320                        })
321                    }
322                    _ => Ok(Token {
323                        kind: TokenKind::Equals,
324                        line,
325                        col,
326                    }),
327                }
328            }
329
330            b'\\' => {
331                self.advance();
332                match self.peek() {
333                    Some(b'=') => {
334                        self.advance();
335                        Ok(Token {
336                            kind: TokenKind::NotEquals,
337                            line,
338                            col,
339                        })
340                    }
341                    Some(b'+') => {
342                        self.advance();
343                        Ok(Token {
344                            kind: TokenKind::Not,
345                            line,
346                            col,
347                        })
348                    }
349                    _ => Err(format!("Unexpected '\\' at line {} col {}", line, col)),
350                }
351            }
352
353            b'<' => {
354                self.advance();
355                Ok(Token {
356                    kind: TokenKind::Lt,
357                    line,
358                    col,
359                })
360            }
361            b'>' => {
362                self.advance();
363                if self.peek() == Some(b'=') {
364                    self.advance();
365                    Ok(Token {
366                        kind: TokenKind::Gte,
367                        line,
368                        col,
369                    })
370                } else {
371                    Ok(Token {
372                        kind: TokenKind::Gt,
373                        line,
374                        col,
375                    })
376                }
377            }
378
379            b'@' => {
380                self.advance();
381                match self.peek() {
382                    Some(b'<') => {
383                        self.advance();
384                        Ok(Token {
385                            kind: TokenKind::Atom("@<".into()),
386                            line,
387                            col,
388                        })
389                    }
390                    Some(b'>') => {
391                        self.advance();
392                        if self.peek() == Some(b'=') {
393                            self.advance();
394                            Ok(Token {
395                                kind: TokenKind::Atom("@>=".into()),
396                                line,
397                                col,
398                            })
399                        } else {
400                            Ok(Token {
401                                kind: TokenKind::Atom("@>".into()),
402                                line,
403                                col,
404                            })
405                        }
406                    }
407                    Some(b'=') if self.peek_at(1) == Some(b'<') => {
408                        self.advance();
409                        self.advance();
410                        Ok(Token {
411                            kind: TokenKind::Atom("@=<".into()),
412                            line,
413                            col,
414                        })
415                    }
416                    _ => Err(format!("Unexpected '@' at line {} col {}", line, col)),
417                }
418            }
419
420            b'+' => {
421                self.advance();
422                Ok(Token {
423                    kind: TokenKind::Plus,
424                    line,
425                    col,
426                })
427            }
428            b'*' => {
429                self.advance();
430                Ok(Token {
431                    kind: TokenKind::Star,
432                    line,
433                    col,
434                })
435            }
436            b'/' => {
437                self.advance();
438                Ok(Token {
439                    kind: TokenKind::Slash,
440                    line,
441                    col,
442                })
443            }
444
445            b'-' => {
446                self.advance();
447                // Check for -> (arrow)
448                if self.peek() == Some(b'>') {
449                    self.advance();
450                    return Ok(Token {
451                        kind: TokenKind::Arrow,
452                        line,
453                        col,
454                    });
455                }
456                // Check if this is a negative number: dash followed by digit
457                if let Some(d) = self.peek() {
458                    if d.is_ascii_digit() {
459                        return Ok(Token {
460                            kind: TokenKind::Minus,
461                            line,
462                            col,
463                        });
464                    }
465                }
466                Ok(Token {
467                    kind: TokenKind::Minus,
468                    line,
469                    col,
470                })
471            }
472
473            b'\'' => self.read_quoted_atom(line, col),
474
475            b'0'..=b'9' => self.read_number(line, col),
476
477            b'a'..=b'z' => self.read_atom(line, col),
478
479            b'A'..=b'Z' | b'_' => self.read_variable(line, col),
480
481            _ => {
482                self.advance();
483                Err(format!(
484                    "Unexpected character '{}' at line {} col {}",
485                    ch as char, line, col
486                ))
487            }
488        }
489    }
490
491    fn read_atom(&mut self, line: usize, col: usize) -> Result<Token, String> {
492        let mut s = String::new();
493        while let Some(ch) = self.peek() {
494            if ch.is_ascii_alphanumeric() || ch == b'_' {
495                s.push(self.advance() as char);
496            } else {
497                break;
498            }
499        }
500        // Check for keyword operators
501        let kind = match s.as_str() {
502            "is" => TokenKind::Is,
503            "mod" => TokenKind::Mod,
504            _ => TokenKind::Atom(s),
505        };
506        Ok(Token { kind, line, col })
507    }
508
509    fn read_variable(&mut self, line: usize, col: usize) -> Result<Token, String> {
510        let mut s = String::new();
511        while let Some(ch) = self.peek() {
512            if ch.is_ascii_alphanumeric() || ch == b'_' {
513                s.push(self.advance() as char);
514            } else {
515                break;
516            }
517        }
518        Ok(Token {
519            kind: TokenKind::Variable(s),
520            line,
521            col,
522        })
523    }
524
525    fn read_number(&mut self, line: usize, col: usize) -> Result<Token, String> {
526        let mut s = String::new();
527        let mut is_float = false;
528
529        while let Some(ch) = self.peek() {
530            if ch.is_ascii_digit() {
531                s.push(self.advance() as char);
532            } else if ch == b'.' {
533                // Check if next char after dot is a digit (float), otherwise it's a clause terminator
534                if let Some(next) = self.peek_at(1) {
535                    if next.is_ascii_digit() {
536                        is_float = true;
537                        s.push(self.advance() as char); // consume .
538                        while let Some(d) = self.peek() {
539                            if d.is_ascii_digit() {
540                                s.push(self.advance() as char);
541                            } else {
542                                break;
543                            }
544                        }
545                    } else {
546                        break; // dot is clause terminator
547                    }
548                } else {
549                    break; // dot at EOF
550                }
551            } else {
552                break;
553            }
554        }
555
556        if is_float {
557            let val: f64 = s
558                .parse()
559                .map_err(|e| format!("Invalid float '{}': {}", s, e))?;
560            if val.is_infinite() {
561                return Err(format!(
562                    "Float literal '{}' overflows f64 at line {} col {}",
563                    s, line, col
564                ));
565            }
566            Ok(Token {
567                kind: TokenKind::Float(val),
568                line,
569                col,
570            })
571        } else {
572            let val: i64 = s
573                .parse()
574                .map_err(|e| format!("Invalid integer '{}': {}", s, e))?;
575            Ok(Token {
576                kind: TokenKind::Integer(val),
577                line,
578                col,
579            })
580        }
581    }
582
583    fn read_quoted_atom(&mut self, line: usize, col: usize) -> Result<Token, String> {
584        self.advance(); // skip opening quote
585        let mut s = String::new();
586        loop {
587            match self.peek() {
588                None => {
589                    return Err(format!(
590                        "Unterminated quoted atom at line {} col {}",
591                        line, col
592                    ))
593                }
594                Some(b'\'') => {
595                    self.advance();
596                    // Check for escaped quote ''
597                    if self.peek() == Some(b'\'') {
598                        s.push('\'');
599                        self.advance();
600                    } else {
601                        break;
602                    }
603                }
604                Some(b'\\') => {
605                    self.advance();
606                    match self.peek() {
607                        Some(b'\'') => {
608                            s.push('\'');
609                            self.advance();
610                        }
611                        Some(b'\\') => {
612                            s.push('\\');
613                            self.advance();
614                        }
615                        Some(b'n') => {
616                            s.push('\n');
617                            self.advance();
618                        }
619                        Some(b't') => {
620                            s.push('\t');
621                            self.advance();
622                        }
623                        Some(ch) => {
624                            s.push(ch as char);
625                            self.advance();
626                        }
627                        None => {
628                            return Err(format!(
629                                "Unterminated escape at line {} col {}",
630                                self.line, self.col
631                            ))
632                        }
633                    }
634                }
635                Some(ch) => {
636                    s.push(ch as char);
637                    self.advance();
638                }
639            }
640        }
641        Ok(Token {
642            kind: TokenKind::Atom(s),
643            line,
644            col,
645        })
646    }
647}
648
// Unit tests for the tokenizer. Each test compares the sequence of token
// kinds produced for a small input; source positions are not checked here.
#[cfg(test)]
mod tests {
    use super::*;

    /// Tokenizes `input` and returns only the token kinds, with the trailing
    /// `Eof` removed. Panics if tokenization fails.
    fn tok(input: &str) -> Vec<TokenKind> {
        Tokenizer::tokenize(input)
            .unwrap()
            .into_iter()
            .map(|t| t.kind)
            .filter(|k| *k != TokenKind::Eof)
            .collect()
    }

    /// Unquoted atoms may contain lowercase letters, digits and underscores.
    #[test]
    fn test_atoms() {
        assert_eq!(tok("hello"), vec![TokenKind::Atom("hello".into())]);
        assert_eq!(tok("foo_bar"), vec![TokenKind::Atom("foo_bar".into())]);
        assert_eq!(tok("a123"), vec![TokenKind::Atom("a123".into())]);
    }

    /// Quoted atoms keep inner spaces, and `''` is an escaped single quote.
    #[test]
    fn test_quoted_atoms() {
        assert_eq!(
            tok("'hello world'"),
            vec![TokenKind::Atom("hello world".into())]
        );
        assert_eq!(tok("'it''s'"), vec![TokenKind::Atom("it's".into())]);
    }

    /// Variables start with an uppercase letter or an underscore.
    #[test]
    fn test_variables() {
        assert_eq!(tok("X"), vec![TokenKind::Variable("X".into())]);
        assert_eq!(tok("_foo"), vec![TokenKind::Variable("_foo".into())]);
        assert_eq!(tok("_"), vec![TokenKind::Variable("_".into())]);
        assert_eq!(tok("MyVar"), vec![TokenKind::Variable("MyVar".into())]);
    }

    /// Integer and float literals.
    #[test]
    fn test_numbers() {
        assert_eq!(tok("42"), vec![TokenKind::Integer(42)]);
        assert_eq!(tok("3.14"), vec![TokenKind::Float(3.14)]);
        assert_eq!(tok("0"), vec![TokenKind::Integer(0)]);
    }

    /// Each symbolic or keyword operator maps to its dedicated token kind.
    #[test]
    fn test_operators() {
        assert_eq!(tok(":-"), vec![TokenKind::Neck]);
        assert_eq!(tok("?-"), vec![TokenKind::QueryOp]);
        assert_eq!(tok("="), vec![TokenKind::Equals]);
        assert_eq!(tok("\\="), vec![TokenKind::NotEquals]);
        assert_eq!(tok("is"), vec![TokenKind::Is]);
        assert_eq!(tok("<"), vec![TokenKind::Lt]);
        assert_eq!(tok(">"), vec![TokenKind::Gt]);
        assert_eq!(tok("=<"), vec![TokenKind::Lte]);
        assert_eq!(tok(">="), vec![TokenKind::Gte]);
        assert_eq!(tok("=:="), vec![TokenKind::ArithEq]);
        assert_eq!(tok("=\\="), vec![TokenKind::ArithNeq]);
        assert_eq!(tok("\\+"), vec![TokenKind::Not]);
    }

    /// Punctuation tokens, and the difference between `[ ]` and `[]`.
    #[test]
    fn test_punctuation() {
        assert_eq!(
            tok("( ) | , ."),
            vec![
                TokenKind::LParen,
                TokenKind::RParen,
                TokenKind::Pipe,
                TokenKind::Comma,
                TokenKind::Dot,
            ]
        );
        // [ ] with space is separate tokens, not []
        assert_eq!(tok("[ ]"), vec![TokenKind::LBracket, TokenKind::RBracket,]);
    }

    /// `!` is the cut.
    #[test]
    fn test_cut() {
        assert_eq!(tok("!"), vec![TokenKind::Cut]);
    }

    /// A complete fact, including the terminating dot.
    #[test]
    fn test_clause() {
        let tokens = tok("parent(tom, mary).");
        assert_eq!(
            tokens,
            vec![
                TokenKind::Atom("parent".into()),
                TokenKind::LParen,
                TokenKind::Atom("tom".into()),
                TokenKind::Comma,
                TokenKind::Atom("mary".into()),
                TokenKind::RParen,
                TokenKind::Dot,
            ]
        );
    }

    /// A complete rule with head, neck and body.
    #[test]
    fn test_rule() {
        let tokens = tok("happy(X) :- likes(X, food).");
        assert_eq!(
            tokens,
            vec![
                TokenKind::Atom("happy".into()),
                TokenKind::LParen,
                TokenKind::Variable("X".into()),
                TokenKind::RParen,
                TokenKind::Neck,
                TokenKind::Atom("likes".into()),
                TokenKind::LParen,
                TokenKind::Variable("X".into()),
                TokenKind::Comma,
                TokenKind::Atom("food".into()),
                TokenKind::RParen,
                TokenKind::Dot,
            ]
        );
    }

    /// An arithmetic goal; the final `.` after `4` is a Dot, not part of a float.
    #[test]
    fn test_arithmetic() {
        let tokens = tok("X is 2 + 3 * 4.");
        assert_eq!(
            tokens,
            vec![
                TokenKind::Variable("X".into()),
                TokenKind::Is,
                TokenKind::Integer(2),
                TokenKind::Plus,
                TokenKind::Integer(3),
                TokenKind::Star,
                TokenKind::Integer(4),
                TokenKind::Dot,
            ]
        );
    }

    /// `%` comments run to the end of the line and produce no tokens.
    #[test]
    fn test_line_comment() {
        assert_eq!(
            tok("foo % this is a comment\nbar"),
            vec![TokenKind::Atom("foo".into()), TokenKind::Atom("bar".into()),]
        );
    }

    /// `/* ... */` comments produce no tokens.
    #[test]
    fn test_block_comment() {
        assert_eq!(
            tok("foo /* block */ bar"),
            vec![TokenKind::Atom("foo".into()), TokenKind::Atom("bar".into()),]
        );
    }

    /// `[]` with no space is a single atom.
    #[test]
    fn test_empty_list() {
        assert_eq!(tok("[]"), vec![TokenKind::Atom("[]".into())]);
    }

    /// A non-empty list literal is tokenized as brackets, elements and commas.
    #[test]
    fn test_list_syntax() {
        let tokens = tok("[1, 2, 3]");
        assert_eq!(
            tokens,
            vec![
                TokenKind::LBracket,
                TokenKind::Integer(1),
                TokenKind::Comma,
                TokenKind::Integer(2),
                TokenKind::Comma,
                TokenKind::Integer(3),
                TokenKind::RBracket,
            ]
        );
    }

    /// Binary minus between two integers is its own token.
    #[test]
    fn test_minus_operator() {
        assert_eq!(
            tok("5 - 3"),
            vec![
                TokenKind::Integer(5),
                TokenKind::Minus,
                TokenKind::Integer(3),
            ]
        );
    }
}