mentedb_query/lexer.rs

1//! Hand-written lexer for MQL.
2
3use mentedb_core::error::{MenteError, MenteResult};
4
/// A single lexical token produced by [`tokenize`].
#[derive(Debug, Clone, PartialEq)]
pub struct Token {
    /// Category of the token (statement keyword, operator, literal, ...).
    pub kind: TokenKind,
    /// The exact source text this token was scanned from.
    pub lexeme: String,
    /// Byte offset of the token's first character in the input string.
    pub position: usize,
}
11
/// All token categories recognized by the MQL lexer.
///
/// Keyword matching in [`tokenize`] is case-insensitive; the operator
/// variants carry their spelling in a trailing comment.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum TokenKind {
    // Statements
    Recall,
    Relate,
    Forget,
    Consolidate,
    Traverse,

    // Clauses
    Where,
    And,
    Or,
    Not,
    Near,
    Within,
    Limit,
    OrderBy,
    As,
    From,
    To,
    With,

    // Keywords
    Agent,
    Space,
    Type,
    Tag,
    Salience,
    Confidence,
    Created,
    Accessed,
    Depth,
    Hops,
    Memories,
    By,
    EdgeType,

    // Operators
    Eq,        // =
    Neq,       // !=
    Gt,        // >
    Lt,        // <
    Gte,       // >=
    Lte,       // <=
    SimilarTo, // ~>
    Arrow,     // ->

    // Punctuation
    LParen,
    RParen,
    LBracket,
    RBracket,
    Comma,
    Dot,
    Colon,
    Semicolon,

    // Literals
    StringLit,
    IntegerLit,
    FloatLit,
    Identifier,
    UuidLit,

    // Synthetic end-of-input marker, always the last token emitted.
    Eof,
}
79
80pub fn tokenize(input: &str) -> MenteResult<Vec<Token>> {
81    let mut tokens = Vec::new();
82    let bytes = input.as_bytes();
83    let len = bytes.len();
84    let mut pos = 0;
85
86    while pos < len {
87        // Skip whitespace
88        if bytes[pos].is_ascii_whitespace() {
89            pos += 1;
90            continue;
91        }
92
93        let start = pos;
94
95        // String literal
96        if bytes[pos] == b'"' {
97            pos += 1;
98            while pos < len && bytes[pos] != b'"' {
99                if bytes[pos] == b'\\' {
100                    pos += 1; // skip escaped char
101                }
102                pos += 1;
103            }
104            if pos >= len {
105                return Err(MenteError::Query("unterminated string literal".into()));
106            }
107            pos += 1; // closing quote
108            let lexeme = input[start..pos].to_string();
109            tokens.push(Token {
110                kind: TokenKind::StringLit,
111                lexeme,
112                position: start,
113            });
114            continue;
115        }
116
117        // Two-char operators
118        if pos + 1 < len {
119            let two = &input[start..start + 2];
120            let kind = match two {
121                "!=" => Some(TokenKind::Neq),
122                ">=" => Some(TokenKind::Gte),
123                "<=" => Some(TokenKind::Lte),
124                "~>" => Some(TokenKind::SimilarTo),
125                "->" => Some(TokenKind::Arrow),
126                _ => None,
127            };
128            if let Some(k) = kind {
129                tokens.push(Token {
130                    kind: k,
131                    lexeme: two.to_string(),
132                    position: start,
133                });
134                pos += 2;
135                continue;
136            }
137        }
138
139        // Single-char operators/punctuation
140        let single = match bytes[pos] {
141            b'=' => Some(TokenKind::Eq),
142            b'>' => Some(TokenKind::Gt),
143            b'<' => Some(TokenKind::Lt),
144            b'(' => Some(TokenKind::LParen),
145            b')' => Some(TokenKind::RParen),
146            b'[' => Some(TokenKind::LBracket),
147            b']' => Some(TokenKind::RBracket),
148            b',' => Some(TokenKind::Comma),
149            b'.' => Some(TokenKind::Dot),
150            b':' => Some(TokenKind::Colon),
151            b';' => Some(TokenKind::Semicolon),
152            _ => None,
153        };
154        if let Some(k) = single {
155            tokens.push(Token {
156                kind: k,
157                lexeme: input[start..start + 1].to_string(),
158                position: start,
159            });
160            pos += 1;
161            continue;
162        }
163
164        // Try UUID first: if we see a hex digit, speculatively scan for UUID pattern
165        if bytes[pos].is_ascii_hexdigit() {
166            let saved = pos;
167            // Consume alphanumeric + hyphens to check for UUID
168            while pos < len
169                && (bytes[pos].is_ascii_alphanumeric() || bytes[pos] == b'_' || bytes[pos] == b'-')
170            {
171                pos += 1;
172            }
173            let candidate = &input[saved..pos];
174            if is_uuid_like(candidate) {
175                tokens.push(Token {
176                    kind: TokenKind::UuidLit,
177                    lexeme: candidate.to_string(),
178                    position: start,
179                });
180                continue;
181            }
182            // Not a UUID — reset and fall through to number/identifier parsing
183            pos = saved;
184        }
185
186        // Numbers (may start with - for negative)
187        if bytes[pos].is_ascii_digit()
188            || (bytes[pos] == b'-' && pos + 1 < len && bytes[pos + 1].is_ascii_digit())
189        {
190            if bytes[pos] == b'-' {
191                pos += 1;
192            }
193            while pos < len && bytes[pos].is_ascii_digit() {
194                pos += 1;
195            }
196            let mut is_float = false;
197            if pos < len && bytes[pos] == b'.' && pos + 1 < len && bytes[pos + 1].is_ascii_digit() {
198                is_float = true;
199                pos += 1;
200                while pos < len && bytes[pos].is_ascii_digit() {
201                    pos += 1;
202                }
203            }
204            let lexeme = input[start..pos].to_string();
205            let kind = if is_float {
206                TokenKind::FloatLit
207            } else {
208                TokenKind::IntegerLit
209            };
210            tokens.push(Token {
211                kind,
212                lexeme,
213                position: start,
214            });
215            continue;
216        }
217
218        // Identifiers, keywords
219        if bytes[pos].is_ascii_alphanumeric() || bytes[pos] == b'_' {
220            while pos < len
221                && (bytes[pos].is_ascii_alphanumeric() || bytes[pos] == b'_' || bytes[pos] == b'-')
222            {
223                pos += 1;
224            }
225            let lexeme = input[start..pos].to_string();
226
227            let kind = match lexeme.to_lowercase().as_str() {
228                "recall" => TokenKind::Recall,
229                "relate" => TokenKind::Relate,
230                "forget" => TokenKind::Forget,
231                "consolidate" => TokenKind::Consolidate,
232                "traverse" => TokenKind::Traverse,
233                "where" => TokenKind::Where,
234                "and" => TokenKind::And,
235                "or" => TokenKind::Or,
236                "not" => TokenKind::Not,
237                "near" => TokenKind::Near,
238                "within" => TokenKind::Within,
239                "limit" => TokenKind::Limit,
240                "order" => TokenKind::OrderBy,
241                "as" => TokenKind::As,
242                "from" => TokenKind::From,
243                "to" => TokenKind::To,
244                "with" => TokenKind::With,
245                "agent" => TokenKind::Agent,
246                "space" => TokenKind::Space,
247                "type" => TokenKind::Type,
248                "tag" => TokenKind::Tag,
249                "salience" => TokenKind::Salience,
250                "confidence" => TokenKind::Confidence,
251                "created" => TokenKind::Created,
252                "accessed" => TokenKind::Accessed,
253                "depth" => TokenKind::Depth,
254                "hops" => TokenKind::Hops,
255                "memories" => TokenKind::Memories,
256                "by" => TokenKind::By,
257                "edge_type" => TokenKind::EdgeType,
258                _ => TokenKind::Identifier,
259            };
260            tokens.push(Token {
261                kind,
262                lexeme,
263                position: start,
264            });
265            continue;
266        }
267
268        return Err(MenteError::Query(format!(
269            "unexpected character '{}' at position {}",
270            bytes[pos] as char, pos
271        )));
272    }
273
274    tokens.push(Token {
275        kind: TokenKind::Eof,
276        lexeme: String::new(),
277        position: pos,
278    });
279    Ok(tokens)
280}
281
/// Reports whether `s` is a canonical textual UUID: 36 bytes arranged as
/// five hyphen-separated hex groups of lengths 8-4-4-4-12 (any case).
fn is_uuid_like(s: &str) -> bool {
    if s.len() != 36 {
        return false;
    }
    let groups: Vec<&str> = s.split('-').collect();
    groups.len() == 5
        && groups
            .iter()
            .zip([8usize, 4, 4, 4, 12])
            .all(|(group, want)| {
                group.len() == want && group.bytes().all(|b| b.is_ascii_hexdigit())
            })
}
299
#[cfg(test)]
mod tests {
    use super::*;

    // Full RECALL statement: keywords, operator, identifier, integer, Eof.
    #[test]
    fn test_recall_statement_tokens() {
        let tokens = tokenize("RECALL memories WHERE type = episodic LIMIT 10").unwrap();
        assert_eq!(tokens[0].kind, TokenKind::Recall);
        assert_eq!(tokens[1].kind, TokenKind::Memories);
        assert_eq!(tokens[2].kind, TokenKind::Where);
        assert_eq!(tokens[3].kind, TokenKind::Type);
        assert_eq!(tokens[4].kind, TokenKind::Eq);
        assert_eq!(tokens[5].kind, TokenKind::Identifier);
        assert_eq!(tokens[5].lexeme, "episodic");
        assert_eq!(tokens[6].kind, TokenKind::Limit);
        assert_eq!(tokens[7].kind, TokenKind::IntegerLit);
        assert_eq!(tokens[8].kind, TokenKind::Eof);
    }

    // String lexemes keep their surrounding quotes.
    #[test]
    fn test_string_literal() {
        let tokens = tokenize(r#"content ~> "database migration""#).unwrap();
        assert_eq!(tokens[0].kind, TokenKind::Identifier);
        assert_eq!(tokens[1].kind, TokenKind::SimilarTo);
        assert_eq!(tokens[2].kind, TokenKind::StringLit);
        assert_eq!(tokens[2].lexeme, r#""database migration""#);
    }

    // Every one- and two-character operator in one pass.
    #[test]
    fn test_operators() {
        let tokens = tokenize("= != > < >= <= ~> ->").unwrap();
        let kinds: Vec<TokenKind> = tokens.iter().map(|t| t.kind).collect();
        assert_eq!(
            kinds,
            vec![
                TokenKind::Eq,
                TokenKind::Neq,
                TokenKind::Gt,
                TokenKind::Lt,
                TokenKind::Gte,
                TokenKind::Lte,
                TokenKind::SimilarTo,
                TokenKind::Arrow,
                TokenKind::Eof,
            ]
        );
    }

    // A canonical UUID lexes as a single UuidLit, not identifier/number parts.
    #[test]
    fn test_uuid_token() {
        let tokens = tokenize("550e8400-e29b-41d4-a716-446655440000").unwrap();
        assert_eq!(tokens[0].kind, TokenKind::UuidLit);
    }

    // Decimal point with trailing digits makes a float; otherwise an integer.
    #[test]
    fn test_float_literal() {
        let tokens = tokenize("0.1 42 3.14").unwrap();
        assert_eq!(tokens[0].kind, TokenKind::FloatLit);
        assert_eq!(tokens[1].kind, TokenKind::IntegerLit);
        assert_eq!(tokens[2].kind, TokenKind::FloatLit);
    }

    // Bracketed float list lexes as punctuation + floats (vector syntax).
    #[test]
    fn test_vector_literal() {
        let tokens = tokenize("[0.1, 0.2, 0.3]").unwrap();
        assert_eq!(tokens[0].kind, TokenKind::LBracket);
        assert_eq!(tokens[1].kind, TokenKind::FloatLit);
        assert_eq!(tokens[2].kind, TokenKind::Comma);
        assert_eq!(tokens[5].kind, TokenKind::FloatLit);
        assert_eq!(tokens[6].kind, TokenKind::RBracket);
    }

    // All single-character punctuation tokens.
    #[test]
    fn test_punctuation() {
        let tokens = tokenize("( ) [ ] , . : ;").unwrap();
        let kinds: Vec<TokenKind> = tokens.iter().map(|t| t.kind).collect();
        assert_eq!(
            kinds,
            vec![
                TokenKind::LParen,
                TokenKind::RParen,
                TokenKind::LBracket,
                TokenKind::RBracket,
                TokenKind::Comma,
                TokenKind::Dot,
                TokenKind::Colon,
                TokenKind::Semicolon,
                TokenKind::Eof,
            ]
        );
    }
}