exp_rs/
lexer.rs

1use crate::types::TokenKind;
2use crate::{Real, String, ToString};
3
4#[cfg(test)]
5use std::format;
6
7#[cfg(all(not(test), target_arch = "arm"))]
8use alloc::format;
9
/// A token produced by the lexer.
#[derive(Debug, Clone, PartialEq)]
pub struct Token {
    // The category of this token (number, operator, variable, ...).
    pub kind: TokenKind,
    // Parsed numeric value; populated only for `TokenKind::Number` tokens.
    pub value: Option<Real>,
    // Source text of the token. For `Error` tokens this may instead hold a
    // diagnostic message (e.g. the "Token too long" error).
    pub text: Option<String>,
    // Byte offset into the input where this token starts.
    pub position: usize,
}
18
/// The lexer struct, which produces tokens from an input string.
#[derive(Clone)]
pub struct Lexer<'a> {
    // The full expression being tokenized.
    input: &'a str,
    // Current byte offset into `input`; always lands on a UTF-8 char
    // boundary because it only advances by `char::len_utf8`.
    pub pos: usize,
}
25
26impl<'a> Lexer<'a> {
27    pub fn new(input: &'a str) -> Self {
28        // Check for invalid UTF-8 sequences
29        // This is a no-op in Rust since the &str type guarantees valid UTF-8
30        // But we can check for extremely long input
31        Self { input, pos: 0 }
32    }
33
34    /// Peek at the current character.
35    fn peek(&self) -> Option<char> {
36        self.input[self.pos..].chars().next()
37    }
38
39    /// Advance the position by one character.
40    fn advance(&mut self) {
41        if let Some(c) = self.peek() {
42            self.pos += c.len_utf8();
43        }
44    }
45
46    /// Peek at the next token without consuming it
47    pub fn peek_token(&self) -> Option<Token> {
48        let mut lexer_copy = self.clone();
49        lexer_copy.next_token()
50    }
51
52    /// Get the remaining input from the current position
53    pub fn get_remaining_input(&self) -> Option<&str> {
54        if self.pos < self.input.len() {
55            Some(&self.input[self.pos..])
56        } else {
57            None
58        }
59    }
60
    /// Get the original input string (the full expression, independent of
    /// how far the lexer has advanced).
    pub fn get_original_input(&self) -> &'a str {
        self.input
    }
65
66    /// Skip whitespace.
67    fn skip_whitespace(&mut self) {
68        while let Some(c) = self.peek() {
69            if c.is_whitespace() {
70                self.advance();
71            } else {
72                break;
73            }
74        }
75    }
76
77    /// Check if a token is too long
78    fn check_token_length(&self, start_pos: usize, end_pos: usize) -> Result<(), String> {
79        const MAX_TOKEN_LENGTH: usize = 1000; // Reasonable limit
80        if end_pos - start_pos > MAX_TOKEN_LENGTH {
81            return Err(format!(
82                "Token too long: {} characters (maximum is {})",
83                end_pos - start_pos,
84                MAX_TOKEN_LENGTH
85            ));
86        }
87        Ok(())
88    }
89
90    /// Get the next token from the input.
91    pub fn next_token(&mut self) -> Option<Token> {
92        self.skip_whitespace();
93        let start_pos = self.pos;
94        let c = self.peek()?;
95
96        // Special case for decimal numbers starting with a dot
97        if c == '.' && self.pos + 1 < self.input.len() {
98            let next_char = self.input[self.pos + 1..].chars().next();
99            if next_char.is_some_and(|d| d.is_ascii_digit()) {
100                // This is a decimal number starting with a dot (e.g., .5)
101                self.advance(); // Skip the dot
102                let mut has_digits = false;
103                let mut has_exp = false;
104
105                // Parse digits after the dot
106                while let Some(nc) = self.peek() {
107                    if nc.is_ascii_digit() {
108                        has_digits = true;
109                        self.advance();
110                    } else if (nc == 'e' || nc == 'E') && !has_exp {
111                        has_exp = true;
112                        self.advance();
113                        // Optional sign after e/E
114                        if let Some(sign) = self.peek() {
115                            if sign == '+' || sign == '-' {
116                                self.advance();
117                            }
118                        }
119
120                        // Must have at least one digit after e/E
121                        let mut has_exp_digits = false;
122                        while let Some(ec) = self.peek() {
123                            if ec.is_ascii_digit() {
124                                has_exp_digits = true;
125                                self.advance();
126                            } else {
127                                break;
128                            }
129                        }
130
131                        if !has_exp_digits {
132                            return Some(Token {
133                                kind: TokenKind::Error,
134                                value: None,
135                                text: Some(String::from(&self.input[start_pos..self.pos])),
136                                position: start_pos,
137                            });
138                        }
139                    } else {
140                        break;
141                    }
142                }
143
144                if !has_digits {
145                    return Some(Token {
146                        kind: TokenKind::Error,
147                        value: None,
148                        text: Some(String::from(".")),
149                        position: start_pos,
150                    });
151                }
152
153                // Parse the number with a leading zero
154                let num_str = format!("0{}", &self.input[start_pos..self.pos]);
155
156                if let Ok(val) = num_str.parse::<Real>() {
157                    return Some(Token {
158                        kind: TokenKind::Number,
159                        value: Some(val),
160                        text: Some(String::from(&self.input[start_pos..self.pos])),
161                        position: start_pos,
162                    });
163                } else {
164                    return Some(Token {
165                        kind: TokenKind::Error,
166                        value: None,
167                        text: Some(String::from(&self.input[start_pos..self.pos])),
168                        position: start_pos,
169                    });
170                }
171            }
172        }
173
174        // Number (integer or float, possibly scientific notation)
175        if c.is_ascii_digit() {
176            let mut saw_dot = false;
177            let mut saw_e = false;
178            let mut has_digits_after_e = false;
179
180            // Parse integer part
181            self.advance();
182
183            // Parse fractional part
184            while let Some(nc) = self.peek() {
185                if nc.is_ascii_digit() {
186                    self.advance();
187                    if saw_e {
188                        has_digits_after_e = true;
189                    }
190                } else if nc == '.' && !saw_dot {
191                    saw_dot = true;
192                    self.advance();
193                } else if (nc == 'e' || nc == 'E') && !saw_e {
194                    saw_e = true;
195                    self.advance();
196                    // Optional sign after e/E
197                    if let Some(sign) = self.peek() {
198                        if sign == '+' || sign == '-' {
199                            self.advance();
200                        }
201                    }
202                } else {
203                    break;
204                }
205            }
206
207            // Validate scientific notation has digits after 'e'
208            if saw_e && !has_digits_after_e {
209                return Some(Token {
210                    kind: TokenKind::Error,
211                    value: None,
212                    text: Some(String::from(&self.input[start_pos..self.pos])),
213                    position: start_pos,
214                });
215            }
216
217            let num_str = &self.input[start_pos..self.pos];
218            if let Ok(val) = num_str.parse::<Real>() {
219                return Some(Token {
220                    kind: TokenKind::Number,
221                    value: Some(val),
222                    text: Some(String::from(num_str)),
223                    position: start_pos,
224                });
225            } else {
226                return Some(Token {
227                    kind: TokenKind::Error,
228                    value: None,
229                    text: Some(num_str.to_string()),
230                    position: start_pos,
231                });
232            }
233        }
234
235        // Operators and punctuation
236        // Support multi-character operators for tinyexpr++ grammar
237        let op_start = "+-*/^%.<>=!&|~?:"; // Added ? and : for ternary operators
238        if op_start.contains(c) {
239            let kind = TokenKind::Operator;
240            let mut text = String::from(c);
241            self.advance();
242
243            // Lookahead for multi-character operators
244            let next = self.peek();
245            // Handle **, &&, ||, <<, >>, <<<, >>>, <=, >=, ==, !=, <>, and others
246            if let Some(nc) = next {
247                match (c, nc) {
248                    // Triple char: <<<, >>>
249                    ('<', '<') if self.input[self.pos..].starts_with("<<") => {
250                        // Could be <<< or <<, check for third '<'
251                        self.advance(); // 2nd '<'
252                        if self.peek() == Some('<') {
253                            text.push('<');
254                            self.advance();
255                        } else {
256                            text.push('<');
257                        }
258                    }
259                    ('>', '>') if self.input[self.pos..].starts_with(">>") => {
260                        // Could be >>> or >>, check for third '>'
261                        self.advance(); // 2nd '>'
262                        if self.peek() == Some('>') {
263                            text.push('>');
264                            self.advance();
265                        } else {
266                            text.push('>');
267                        }
268                    }
269                    // Double char ops
270                    ('*', '*') | ('&', '&') | ('|', '|') | ('<', '<') | ('>', '>') => {
271                        text.push(nc);
272                        self.advance();
273                    }
274                    ('<', '>') => {
275                        text.push(nc);
276                        self.advance();
277                    }
278                    ('<', '=') | ('>', '=') | ('=', '=') | ('!', '=') => {
279                        text.push(nc);
280                        self.advance();
281                    }
282                    _ => {}
283                }
284            }
285
286            return Some(Token {
287                kind,
288                value: None,
289                text: Some(text),
290                position: start_pos,
291            });
292        }
293
294        // Identifier (variable, function, constant)
295        if c.is_ascii_alphabetic() || c == '_' {
296            let start_pos = self.pos;
297            let mut end = self.pos;
298            while let Some(nc) = self.input[end..].chars().next() {
299                if nc.is_ascii_alphanumeric() || nc == '_' {
300                    end += nc.len_utf8();
301                } else {
302                    break;
303                }
304            }
305
306            // Check if the identifier is too long
307            if let Err(err) = self.check_token_length(start_pos, end) {
308                return Some(Token {
309                    kind: TokenKind::Error,
310                    value: None,
311                    text: Some(err),
312                    position: start_pos,
313                });
314            }
315
316            let ident = &self.input[self.pos..end];
317            self.pos = end;
318            return Some(Token {
319                kind: TokenKind::Variable,
320                value: None,
321                text: Some(String::from(ident)),
322                position: start_pos,
323            });
324        }
325
326        // Other punctuation
327        let kind = match c {
328            '(' | '[' => TokenKind::Open,
329            ')' | ']' => TokenKind::Close,
330            ',' | ';' => TokenKind::Separator, // Add ; as a separator
331            _ => TokenKind::Error,
332        };
333        let text = String::from(c);
334        self.advance();
335        Some(Token {
336            kind,
337            value: None,
338            text: Some(text),
339            position: start_pos,
340        })
341    }
342}
343
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::TokenKind;

    // Every token category should appear at least once for this expression.
    #[test]
    fn test_lexer_tokenization_all_types() {
        let mut lexer = Lexer::new("1 + foo_bar * (2.5e-1) , -baz_123 / 4.2 ^ _x");
        let mut tokens = Vec::new();
        while let Some(tok) = lexer.next_token() {
            tokens.push(tok);
        }
        let kinds: Vec<TokenKind> = tokens.iter().map(|t| t.kind).collect();
        assert!(kinds.contains(&TokenKind::Number));
        assert!(kinds.contains(&TokenKind::Operator));
        assert!(kinds.contains(&TokenKind::Variable));
        assert!(kinds.contains(&TokenKind::Open));
        assert!(kinds.contains(&TokenKind::Close));
        assert!(kinds.contains(&TokenKind::Separator));
    }

    // '$' is not in the operator set and is not alphanumeric, so it falls
    // through to the punctuation match and becomes an Error token.
    #[test]
    fn test_lexer_tokenization_error_tokens() {
        let mut lexer = Lexer::new("1 $ 2");
        let mut found_error = false;
        while let Some(tok) = lexer.next_token() {
            if tok.kind == TokenKind::Error {
                found_error = true;
                break;
            }
        }
        assert!(
            found_error,
            "Lexer should produce error token for unknown character"
        );
    }

    // "1e--2" has no digit immediately after the exponent sign, which the
    // lexer reports as an Error token.
    #[test]
    fn test_lexer_tokenization_malformed_numbers() {
        let mut lexer = Lexer::new("1..2 1e--2");
        let mut found_error = false;
        let mut tokens = Vec::new();

        // Collect all tokens to avoid infinite loop
        while let Some(tok) = lexer.next_token() {
            tokens.push(tok);
            // Break after collecting a reasonable number of tokens
            if tokens.len() > 10 {
                break;
            }
        }

        // Check if any token is an error
        for tok in tokens {
            if tok.kind == TokenKind::Error {
                found_error = true;
                break;
            }
        }

        assert!(
            found_error,
            "Lexer should produce error token for malformed numbers"
        );
    }

    // Numbers with a leading dot are parsed by prepending "0" internally.
    #[test]
    fn test_lexer_decimal_with_leading_dot() {
        let mut lexer = Lexer::new(".5 .123 .0 .9e2");

        // Test .5
        let token = lexer.next_token().unwrap();
        assert_eq!(token.kind, TokenKind::Number);
        assert_eq!(token.value, Some(0.5));

        // Test .123
        let token = lexer.next_token().unwrap();
        assert_eq!(token.kind, TokenKind::Number);
        assert_eq!(token.value, Some(0.123));

        // Test .0
        let token = lexer.next_token().unwrap();
        assert_eq!(token.kind, TokenKind::Number);
        assert_eq!(token.value, Some(0.0));

        // Test .9e2
        let token = lexer.next_token().unwrap();
        assert_eq!(token.kind, TokenKind::Number);
        assert_eq!(token.value, Some(90.0));
    }

    // A dot between identifiers is an operator token, not part of a name.
    #[test]
    fn test_lexer_tokenization_variable_with_dot() {
        let mut lexer = Lexer::new("foo.bar");
        let t1 = lexer.next_token().unwrap();
        let t2 = lexer.next_token().unwrap();
        let t3 = lexer.next_token().unwrap();
        assert_eq!(t1.kind, TokenKind::Variable);
        assert_eq!(t1.text.as_deref(), Some("foo"));
        assert_eq!(t2.kind, TokenKind::Operator);
        assert_eq!(t2.text.as_deref(), Some("."));
        assert_eq!(t3.kind, TokenKind::Variable);
        assert_eq!(t3.text.as_deref(), Some("bar"));
    }

    #[test]
    fn test_lexer_tokenization_multichar_operators() {
        let mut lexer =
            Lexer::new("a && b || c == d != e <= f >= g << h >> i <<< j >>> k ** l <> m ; n");
        let mut tokens = Vec::new();
        while let Some(tok) = lexer.next_token() {
            tokens.push(tok);
        }
        let ops: Vec<_> = tokens
            .iter()
            .filter(|t| t.kind == TokenKind::Operator || t.kind == TokenKind::Separator)
            .map(|t| t.text.as_deref().unwrap())
            .collect();
        assert!(ops.contains(&"&&"));
        assert!(ops.contains(&"||"));
        assert!(ops.contains(&"=="));
        assert!(ops.contains(&"!="));
        assert!(ops.contains(&"<="));
        assert!(ops.contains(&">="));
        assert!(ops.contains(&"<<"));
        assert!(ops.contains(&">>"));
        // NOTE(review): the exact tokenization of "<<<" and ">>>" is
        // deliberately not asserted here; only the two-character forms
        // are checked above.
        assert!(ops.contains(&"**"));
        assert!(ops.contains(&"<>"));
        assert!(ops.contains(&";"));
    }

    // '?' and ':' are single-character operator tokens used for ternaries.
    #[test]
    fn test_lexer_tokenization_ternary_operators() {
        let mut lexer = Lexer::new("x > 0 ? y : z");
        let mut tokens = Vec::new();
        while let Some(tok) = lexer.next_token() {
            tokens.push(tok);
        }

        // Check that we have the right number of tokens
        assert_eq!(tokens.len(), 7);

        // Verify the ternary operator tokens
        assert_eq!(tokens[0].kind, TokenKind::Variable); // x
        assert_eq!(tokens[1].kind, TokenKind::Operator); // >
        assert_eq!(tokens[2].kind, TokenKind::Number); // 0
        assert_eq!(tokens[3].kind, TokenKind::Operator); // ?
        assert_eq!(tokens[3].text.as_deref(), Some("?"));
        assert_eq!(tokens[4].kind, TokenKind::Variable); // y
        assert_eq!(tokens[5].kind, TokenKind::Operator); // :
        assert_eq!(tokens[5].text.as_deref(), Some(":"));
        assert_eq!(tokens[6].kind, TokenKind::Variable); // z
    }
}