1use mentedb_core::error::{MenteError, MenteResult};
4
/// A single lexical unit produced by [`tokenize`].
#[derive(Debug, Clone, PartialEq)]
pub struct Token {
    /// Classification of this token (keyword, operator, literal, ...).
    pub kind: TokenKind,
    /// Exact source text of the token; string literals keep their quotes.
    pub lexeme: String,
    /// Byte offset of the token's first character in the original input.
    pub position: usize,
}
11
/// Every kind of token the lexer can emit.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum TokenKind {
    // Statement keywords.
    Recall,
    Relate,
    Forget,
    Consolidate,
    Traverse,

    // Clause keywords.
    Where,
    And,
    Or,
    Not,
    Near,
    Within,
    Limit,
    OrderBy,
    As,
    From,
    To,
    With,

    // Field / attribute keywords.
    Agent,
    Space,
    Type,
    Tag,
    Salience,
    Confidence,
    Created,
    Accessed,
    Depth,
    Hops,
    Memories,
    By,
    EdgeType,

    // Operators and punctuation.
    Eq,
    Neq,       // !=
    Gt,
    Lt,
    Gte,       // >=
    Lte,       // <=
    SimilarTo, // ~>
    Arrow,     // ->
    LParen,
    RParen,
    LBracket,
    RBracket,
    Comma,
    Dot,
    Colon,
    Semicolon,

    // Literals and identifiers.
    StringLit,
    IntegerLit,
    FloatLit,
    Identifier,
    UuidLit,

    // Synthetic end-of-input marker; always the final token.
    Eof,
}
79
80pub fn tokenize(input: &str) -> MenteResult<Vec<Token>> {
81 let mut tokens = Vec::new();
82 let bytes = input.as_bytes();
83 let len = bytes.len();
84 let mut pos = 0;
85
86 while pos < len {
87 if bytes[pos].is_ascii_whitespace() {
89 pos += 1;
90 continue;
91 }
92
93 let start = pos;
94
95 if bytes[pos] == b'"' {
97 pos += 1;
98 while pos < len && bytes[pos] != b'"' {
99 if bytes[pos] == b'\\' {
100 pos += 1; }
102 pos += 1;
103 }
104 if pos >= len {
105 return Err(MenteError::Query("unterminated string literal".into()));
106 }
107 pos += 1; let lexeme = input[start..pos].to_string();
109 tokens.push(Token {
110 kind: TokenKind::StringLit,
111 lexeme,
112 position: start,
113 });
114 continue;
115 }
116
117 if pos + 1 < len {
119 let two = &input[start..start + 2];
120 let kind = match two {
121 "!=" => Some(TokenKind::Neq),
122 ">=" => Some(TokenKind::Gte),
123 "<=" => Some(TokenKind::Lte),
124 "~>" => Some(TokenKind::SimilarTo),
125 "->" => Some(TokenKind::Arrow),
126 _ => None,
127 };
128 if let Some(k) = kind {
129 tokens.push(Token {
130 kind: k,
131 lexeme: two.to_string(),
132 position: start,
133 });
134 pos += 2;
135 continue;
136 }
137 }
138
139 let single = match bytes[pos] {
141 b'=' => Some(TokenKind::Eq),
142 b'>' => Some(TokenKind::Gt),
143 b'<' => Some(TokenKind::Lt),
144 b'(' => Some(TokenKind::LParen),
145 b')' => Some(TokenKind::RParen),
146 b'[' => Some(TokenKind::LBracket),
147 b']' => Some(TokenKind::RBracket),
148 b',' => Some(TokenKind::Comma),
149 b'.' => Some(TokenKind::Dot),
150 b':' => Some(TokenKind::Colon),
151 b';' => Some(TokenKind::Semicolon),
152 _ => None,
153 };
154 if let Some(k) = single {
155 tokens.push(Token {
156 kind: k,
157 lexeme: input[start..start + 1].to_string(),
158 position: start,
159 });
160 pos += 1;
161 continue;
162 }
163
164 if bytes[pos].is_ascii_hexdigit() {
166 let saved = pos;
167 while pos < len
169 && (bytes[pos].is_ascii_alphanumeric() || bytes[pos] == b'_' || bytes[pos] == b'-')
170 {
171 pos += 1;
172 }
173 let candidate = &input[saved..pos];
174 if is_uuid_like(candidate) {
175 tokens.push(Token {
176 kind: TokenKind::UuidLit,
177 lexeme: candidate.to_string(),
178 position: start,
179 });
180 continue;
181 }
182 pos = saved;
184 }
185
186 if bytes[pos].is_ascii_digit()
188 || (bytes[pos] == b'-' && pos + 1 < len && bytes[pos + 1].is_ascii_digit())
189 {
190 if bytes[pos] == b'-' {
191 pos += 1;
192 }
193 while pos < len && bytes[pos].is_ascii_digit() {
194 pos += 1;
195 }
196 let mut is_float = false;
197 if pos < len && bytes[pos] == b'.' && pos + 1 < len && bytes[pos + 1].is_ascii_digit() {
198 is_float = true;
199 pos += 1;
200 while pos < len && bytes[pos].is_ascii_digit() {
201 pos += 1;
202 }
203 }
204 let lexeme = input[start..pos].to_string();
205 let kind = if is_float {
206 TokenKind::FloatLit
207 } else {
208 TokenKind::IntegerLit
209 };
210 tokens.push(Token {
211 kind,
212 lexeme,
213 position: start,
214 });
215 continue;
216 }
217
218 if bytes[pos].is_ascii_alphanumeric() || bytes[pos] == b'_' {
220 while pos < len
221 && (bytes[pos].is_ascii_alphanumeric() || bytes[pos] == b'_' || bytes[pos] == b'-')
222 {
223 pos += 1;
224 }
225 let lexeme = input[start..pos].to_string();
226
227 let kind = match lexeme.to_lowercase().as_str() {
228 "recall" => TokenKind::Recall,
229 "relate" => TokenKind::Relate,
230 "forget" => TokenKind::Forget,
231 "consolidate" => TokenKind::Consolidate,
232 "traverse" => TokenKind::Traverse,
233 "where" => TokenKind::Where,
234 "and" => TokenKind::And,
235 "or" => TokenKind::Or,
236 "not" => TokenKind::Not,
237 "near" => TokenKind::Near,
238 "within" => TokenKind::Within,
239 "limit" => TokenKind::Limit,
240 "order" => TokenKind::OrderBy,
241 "as" => TokenKind::As,
242 "from" => TokenKind::From,
243 "to" => TokenKind::To,
244 "with" => TokenKind::With,
245 "agent" => TokenKind::Agent,
246 "space" => TokenKind::Space,
247 "type" => TokenKind::Type,
248 "tag" => TokenKind::Tag,
249 "salience" => TokenKind::Salience,
250 "confidence" => TokenKind::Confidence,
251 "created" => TokenKind::Created,
252 "accessed" => TokenKind::Accessed,
253 "depth" => TokenKind::Depth,
254 "hops" => TokenKind::Hops,
255 "memories" => TokenKind::Memories,
256 "by" => TokenKind::By,
257 "edge_type" => TokenKind::EdgeType,
258 _ => TokenKind::Identifier,
259 };
260 tokens.push(Token {
261 kind,
262 lexeme,
263 position: start,
264 });
265 continue;
266 }
267
268 return Err(MenteError::Query(format!(
269 "unexpected character '{}' at position {}",
270 bytes[pos] as char, pos
271 )));
272 }
273
274 tokens.push(Token {
275 kind: TokenKind::Eof,
276 lexeme: String::new(),
277 position: pos,
278 });
279 Ok(tokens)
280}
281
/// Reports whether `s` has the canonical textual UUID shape: exactly 36
/// characters forming five hyphen-separated hex groups of lengths 8-4-4-4-12.
fn is_uuid_like(s: &str) -> bool {
    if s.len() != 36 {
        return false;
    }
    let groups: Vec<&str> = s.split('-').collect();
    groups.len() == 5
        && groups
            .iter()
            .zip([8usize, 4, 4, 4, 12])
            .all(|(group, want)| {
                group.len() == want && group.bytes().all(|b| b.is_ascii_hexdigit())
            })
}
299
#[cfg(test)]
mod tests {
    use super::*;

    /// Projects a token slice down to its kinds for compact comparisons.
    fn kinds_of(tokens: &[Token]) -> Vec<TokenKind> {
        tokens.iter().map(|t| t.kind).collect()
    }

    #[test]
    fn test_recall_statement_tokens() {
        let toks = tokenize("RECALL memories WHERE type = episodic LIMIT 10").unwrap();
        assert_eq!(
            kinds_of(&toks[..9]),
            vec![
                TokenKind::Recall,
                TokenKind::Memories,
                TokenKind::Where,
                TokenKind::Type,
                TokenKind::Eq,
                TokenKind::Identifier,
                TokenKind::Limit,
                TokenKind::IntegerLit,
                TokenKind::Eof,
            ]
        );
        assert_eq!(toks[5].lexeme, "episodic");
    }

    #[test]
    fn test_string_literal() {
        let toks = tokenize(r#"content ~> "database migration""#).unwrap();
        assert_eq!(toks[0].kind, TokenKind::Identifier);
        assert_eq!(toks[1].kind, TokenKind::SimilarTo);
        assert_eq!(toks[2].kind, TokenKind::StringLit);
        // The lexeme keeps its surrounding quotes.
        assert_eq!(toks[2].lexeme, r#""database migration""#);
    }

    #[test]
    fn test_operators() {
        use TokenKind::*;
        let toks = tokenize("= != > < >= <= ~> ->").unwrap();
        assert_eq!(
            kinds_of(&toks),
            vec![Eq, Neq, Gt, Lt, Gte, Lte, SimilarTo, Arrow, Eof]
        );
    }

    #[test]
    fn test_uuid_token() {
        let toks = tokenize("550e8400-e29b-41d4-a716-446655440000").unwrap();
        assert_eq!(toks[0].kind, TokenKind::UuidLit);
    }

    #[test]
    fn test_float_literal() {
        let toks = tokenize("0.1 42 3.14").unwrap();
        assert_eq!(
            kinds_of(&toks[..3]),
            vec![TokenKind::FloatLit, TokenKind::IntegerLit, TokenKind::FloatLit]
        );
    }

    #[test]
    fn test_vector_literal() {
        let toks = tokenize("[0.1, 0.2, 0.3]").unwrap();
        assert_eq!(toks[0].kind, TokenKind::LBracket);
        assert_eq!(toks[1].kind, TokenKind::FloatLit);
        assert_eq!(toks[2].kind, TokenKind::Comma);
        assert_eq!(toks[5].kind, TokenKind::FloatLit);
        assert_eq!(toks[6].kind, TokenKind::RBracket);
    }

    #[test]
    fn test_punctuation() {
        use TokenKind::*;
        let toks = tokenize("( ) [ ] , . : ;").unwrap();
        assert_eq!(
            kinds_of(&toks),
            vec![LParen, RParen, LBracket, RBracket, Comma, Dot, Colon, Semicolon, Eof]
        );
    }
}