// exp_rs/lexer.rs — tokenizer for the expression parser

1use crate::types::TokenKind;
2#[cfg(test)]
3use crate::Real;
4#[cfg(not(test))]
5use crate::{Real, String, ToString};
6#[cfg(not(test))]
7use alloc::format;
8#[cfg(test)]
9use std::format;
10#[cfg(test)]
11use std::string::{String, ToString};
12
/// A token produced by the lexer.
#[derive(Debug, Clone, PartialEq)]
pub struct Token {
    // Category of the token (number, operator, variable, open/close, error, ...).
    pub kind: TokenKind,
    // Parsed numeric value; populated only for `TokenKind::Number` tokens.
    pub value: Option<Real>,
    // Source text of the token; for error tokens this may instead carry
    // a diagnostic message (see the identifier-length check in the lexer).
    pub text: Option<String>,
    // Byte offset of the token's first character within the original input.
    pub position: usize,
}
21
/// The lexer struct, which produces tokens from an input string.
#[derive(Clone)]
pub struct Lexer<'a> {
    // The full input expression being tokenized.
    input: &'a str,
    // Current byte offset into `input`; public so callers can inspect progress.
    pub pos: usize,
}
28
29impl<'a> Lexer<'a> {
30    pub fn new(input: &'a str) -> Self {
31        // Check for invalid UTF-8 sequences
32        // This is a no-op in Rust since the &str type guarantees valid UTF-8
33        // But we can check for extremely long input
34        Self { input, pos: 0 }
35    }
36
37    /// Peek at the current character.
38    fn peek(&self) -> Option<char> {
39        self.input[self.pos..].chars().next()
40    }
41
42    /// Advance the position by one character.
43    fn advance(&mut self) {
44        if let Some(c) = self.peek() {
45            self.pos += c.len_utf8();
46        }
47    }
48
49    /// Peek at the next token without consuming it
50    pub fn peek_token(&self) -> Option<Token> {
51        let mut lexer_copy = self.clone();
52        lexer_copy.next_token()
53    }
54
55    /// Get the remaining input from the current position
56    pub fn get_remaining_input(&self) -> Option<&str> {
57        if self.pos < self.input.len() {
58            Some(&self.input[self.pos..])
59        } else {
60            None
61        }
62    }
63
64    /// Get the original input string
65    pub fn get_original_input(&self) -> &'a str {
66        self.input
67    }
68
69    /// Skip whitespace.
70    fn skip_whitespace(&mut self) {
71        while let Some(c) = self.peek() {
72            if c.is_whitespace() {
73                self.advance();
74            } else {
75                break;
76            }
77        }
78    }
79
80    /// Check if a token is too long
81    fn check_token_length(&self, start_pos: usize, end_pos: usize) -> Result<(), String> {
82        const MAX_TOKEN_LENGTH: usize = 1000; // Reasonable limit
83        if end_pos - start_pos > MAX_TOKEN_LENGTH {
84            return Err(format!(
85                "Token too long: {} characters (maximum is {})",
86                end_pos - start_pos,
87                MAX_TOKEN_LENGTH
88            ));
89        }
90        Ok(())
91    }
92
93    /// Get the next token from the input.
94    pub fn next_token(&mut self) -> Option<Token> {
95        self.skip_whitespace();
96        let start_pos = self.pos;
97        let c = self.peek()?;
98
99        // Special case for decimal numbers starting with a dot
100        if c == '.' && self.pos + 1 < self.input.len() {
101            let next_char = self.input[self.pos + 1..].chars().next();
102            if next_char.is_some_and(|d| d.is_ascii_digit()) {
103                // This is a decimal number starting with a dot (e.g., .5)
104                self.advance(); // Skip the dot
105                let mut has_digits = false;
106                let mut has_exp = false;
107
108                // Parse digits after the dot
109                while let Some(nc) = self.peek() {
110                    if nc.is_ascii_digit() {
111                        has_digits = true;
112                        self.advance();
113                    } else if (nc == 'e' || nc == 'E') && !has_exp {
114                        has_exp = true;
115                        self.advance();
116                        // Optional sign after e/E
117                        if let Some(sign) = self.peek() {
118                            if sign == '+' || sign == '-' {
119                                self.advance();
120                            }
121                        }
122
123                        // Must have at least one digit after e/E
124                        let mut has_exp_digits = false;
125                        while let Some(ec) = self.peek() {
126                            if ec.is_ascii_digit() {
127                                has_exp_digits = true;
128                                self.advance();
129                            } else {
130                                break;
131                            }
132                        }
133
134                        if !has_exp_digits {
135                            return Some(Token {
136                                kind: TokenKind::Error,
137                                value: None,
138                                text: Some(String::from(&self.input[start_pos..self.pos])),
139                                position: start_pos,
140                            });
141                        }
142                    } else {
143                        break;
144                    }
145                }
146
147                if !has_digits {
148                    return Some(Token {
149                        kind: TokenKind::Error,
150                        value: None,
151                        text: Some(String::from(".")),
152                        position: start_pos,
153                    });
154                }
155
156                // Parse the number with a leading zero
157                let num_str = format!("0{}", &self.input[start_pos..self.pos]);
158
159                if let Ok(val) = num_str.parse::<Real>() {
160                    return Some(Token {
161                        kind: TokenKind::Number,
162                        value: Some(val),
163                        text: Some(String::from(&self.input[start_pos..self.pos])),
164                        position: start_pos,
165                    });
166                } else {
167                    return Some(Token {
168                        kind: TokenKind::Error,
169                        value: None,
170                        text: Some(String::from(&self.input[start_pos..self.pos])),
171                        position: start_pos,
172                    });
173                }
174            }
175        }
176
177        // Number (integer or float, possibly scientific notation)
178        if c.is_ascii_digit() {
179            let mut saw_dot = false;
180            let mut saw_e = false;
181            let mut has_digits_after_e = false;
182
183            // Parse integer part
184            self.advance();
185
186            // Parse fractional part
187            while let Some(nc) = self.peek() {
188                if nc.is_ascii_digit() {
189                    self.advance();
190                    if saw_e {
191                        has_digits_after_e = true;
192                    }
193                } else if nc == '.' && !saw_dot {
194                    saw_dot = true;
195                    self.advance();
196                } else if (nc == 'e' || nc == 'E') && !saw_e {
197                    saw_e = true;
198                    self.advance();
199                    // Optional sign after e/E
200                    if let Some(sign) = self.peek() {
201                        if sign == '+' || sign == '-' {
202                            self.advance();
203                        }
204                    }
205                } else {
206                    break;
207                }
208            }
209
210            // Validate scientific notation has digits after 'e'
211            if saw_e && !has_digits_after_e {
212                return Some(Token {
213                    kind: TokenKind::Error,
214                    value: None,
215                    text: Some(String::from(&self.input[start_pos..self.pos])),
216                    position: start_pos,
217                });
218            }
219
220            let num_str = &self.input[start_pos..self.pos];
221            if let Ok(val) = num_str.parse::<Real>() {
222                return Some(Token {
223                    kind: TokenKind::Number,
224                    value: Some(val),
225                    text: Some(String::from(num_str)),
226                    position: start_pos,
227                });
228            } else {
229                return Some(Token {
230                    kind: TokenKind::Error,
231                    value: None,
232                    text: Some(num_str.to_string()),
233                    position: start_pos,
234                });
235            }
236        }
237
238        // Operators and punctuation
239        // Support multi-character operators for tinyexpr++ grammar
240        let op_start = "+-*/^%.<>=!&|~";
241        if op_start.contains(c) {
242            let kind = TokenKind::Operator;
243            let mut text = String::from(c);
244            self.advance();
245
246            // Lookahead for multi-character operators
247            let next = self.peek();
248            // Handle **, &&, ||, <<, >>, <<<, >>>, <=, >=, ==, !=, <>, and others
249            if let Some(nc) = next {
250                match (c, nc) {
251                    // Triple char: <<<, >>>
252                    ('<', '<') if self.input[self.pos..].starts_with("<<") => {
253                        // Could be <<< or <<, check for third '<'
254                        self.advance(); // 2nd '<'
255                        if self.peek() == Some('<') {
256                            text.push('<');
257                            self.advance();
258                        } else {
259                            text.push('<');
260                        }
261                    }
262                    ('>', '>') if self.input[self.pos..].starts_with(">>") => {
263                        // Could be >>> or >>, check for third '>'
264                        self.advance(); // 2nd '>'
265                        if self.peek() == Some('>') {
266                            text.push('>');
267                            self.advance();
268                        } else {
269                            text.push('>');
270                        }
271                    }
272                    // Double char ops
273                    ('*', '*') | ('&', '&') | ('|', '|') | ('<', '<') | ('>', '>') => {
274                        text.push(nc);
275                        self.advance();
276                    }
277                    ('<', '>') => {
278                        text.push(nc);
279                        self.advance();
280                    }
281                    ('<', '=') | ('>', '=') | ('=', '=') | ('!', '=') => {
282                        text.push(nc);
283                        self.advance();
284                    }
285                    _ => {}
286                }
287            }
288
289            return Some(Token {
290                kind,
291                value: None,
292                text: Some(text),
293                position: start_pos,
294            });
295        }
296
297        // Identifier (variable, function, constant)
298        if c.is_ascii_alphabetic() || c == '_' {
299            let start_pos = self.pos;
300            let mut end = self.pos;
301            while let Some(nc) = self.input[end..].chars().next() {
302                if nc.is_ascii_alphanumeric() || nc == '_' {
303                    end += nc.len_utf8();
304                } else {
305                    break;
306                }
307            }
308
309            // Check if the identifier is too long
310            if let Err(err) = self.check_token_length(start_pos, end) {
311                return Some(Token {
312                    kind: TokenKind::Error,
313                    value: None,
314                    text: Some(err),
315                    position: start_pos,
316                });
317            }
318
319            let ident = &self.input[self.pos..end];
320            self.pos = end;
321            return Some(Token {
322                kind: TokenKind::Variable,
323                value: None,
324                text: Some(String::from(ident)),
325                position: start_pos,
326            });
327        }
328
329        // Other punctuation
330        let kind = match c {
331            '(' | '[' => TokenKind::Open,
332            ')' | ']' => TokenKind::Close,
333            ',' | ';' => TokenKind::Separator, // Add ; as a separator
334            _ => TokenKind::Error,
335        };
336        let text = String::from(c);
337        self.advance();
338        Some(Token {
339            kind,
340            value: None,
341            text: Some(text),
342            position: start_pos,
343        })
344    }
345}
346
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::TokenKind;

    // One input containing every token category; checks that each category
    // is produced at least once.
    #[test]
    fn test_lexer_tokenization_all_types() {
        let mut lexer = Lexer::new("1 + foo_bar * (2.5e-1) , -baz_123 / 4.2 ^ _x");
        let mut tokens = Vec::new();
        while let Some(tok) = lexer.next_token() {
            tokens.push(tok);
        }
        let kinds: Vec<TokenKind> = tokens.iter().map(|t| t.kind).collect();
        assert!(kinds.contains(&TokenKind::Number));
        assert!(kinds.contains(&TokenKind::Operator));
        assert!(kinds.contains(&TokenKind::Variable));
        assert!(kinds.contains(&TokenKind::Open));
        assert!(kinds.contains(&TokenKind::Close));
        assert!(kinds.contains(&TokenKind::Separator));
    }

    // '$' is not a recognized character: the lexer must surface an Error
    // token rather than panicking or silently skipping it.
    #[test]
    fn test_lexer_tokenization_error_tokens() {
        let mut lexer = Lexer::new("1 $ 2");
        let mut found_error = false;
        while let Some(tok) = lexer.next_token() {
            if tok.kind == TokenKind::Error {
                found_error = true;
                break;
            }
        }
        assert!(
            found_error,
            "Lexer should produce error token for unknown character"
        );
    }

    // Malformed numeric input (the "1e--2" part — a dangling exponent sign)
    // must yield an Error token.
    #[test]
    fn test_lexer_tokenization_malformed_numbers() {
        let mut lexer = Lexer::new("1..2 1e--2");
        let mut found_error = false;
        let mut tokens = Vec::new();

        // Collect all tokens to avoid infinite loop
        while let Some(tok) = lexer.next_token() {
            tokens.push(tok);
            // Break after collecting a reasonable number of tokens
            if tokens.len() > 10 {
                break;
            }
        }

        // Check if any token is an error
        for tok in tokens {
            if tok.kind == TokenKind::Error {
                found_error = true;
                break;
            }
        }

        assert!(
            found_error,
            "Lexer should produce error token for malformed numbers"
        );
    }

    // Numbers written with a leading dot (".5") must lex as Number tokens
    // with the zero-prefixed value, including the scientific form ".9e2".
    #[test]
    fn test_lexer_decimal_with_leading_dot() {
        let mut lexer = Lexer::new(".5 .123 .0 .9e2");

        // Test .5
        let token = lexer.next_token().unwrap();
        assert_eq!(token.kind, TokenKind::Number);
        assert_eq!(token.value, Some(0.5));

        // Test .123
        let token = lexer.next_token().unwrap();
        assert_eq!(token.kind, TokenKind::Number);
        assert_eq!(token.value, Some(0.123));

        // Test .0
        let token = lexer.next_token().unwrap();
        assert_eq!(token.kind, TokenKind::Number);
        assert_eq!(token.value, Some(0.0));

        // Test .9e2
        let token = lexer.next_token().unwrap();
        assert_eq!(token.kind, TokenKind::Number);
        assert_eq!(token.value, Some(90.0));
    }

    // A dot between identifiers is not part of either identifier: "foo.bar"
    // must lex as Variable("foo"), Operator("."), Variable("bar").
    #[test]
    fn test_lexer_tokenization_variable_with_dot() {
        let mut lexer = Lexer::new("foo.bar");
        let t1 = lexer.next_token().unwrap();
        let t2 = lexer.next_token().unwrap();
        let t3 = lexer.next_token().unwrap();
        assert_eq!(t1.kind, TokenKind::Variable);
        assert_eq!(t1.text.as_deref(), Some("foo"));
        assert_eq!(t2.kind, TokenKind::Operator);
        assert_eq!(t2.text.as_deref(), Some("."));
        assert_eq!(t3.kind, TokenKind::Variable);
        assert_eq!(t3.text.as_deref(), Some("bar"));
    }

    // Multi-character operators must each be produced as a single token.
    #[test]
    fn test_lexer_tokenization_multichar_operators() {
        let mut lexer = Lexer::new("a && b || c == d != e <= f >= g << h >> i <<< j >>> k ** l <> m ; n");
        let mut tokens = Vec::new();
        while let Some(tok) = lexer.next_token() {
            tokens.push(tok);
        }
        let ops: Vec<_> = tokens.iter().filter(|t| t.kind == TokenKind::Operator || t.kind == TokenKind::Separator).map(|t| t.text.as_deref().unwrap()).collect();
        assert!(ops.contains(&"&&"));
        assert!(ops.contains(&"||"));
        assert!(ops.contains(&"=="));
        assert!(ops.contains(&"!="));
        assert!(ops.contains(&"<="));
        assert!(ops.contains(&">="));
        assert!(ops.contains(&"<<"));
        assert!(ops.contains(&">>"));
        // NOTE(review): the exact tokenization of "<<<" / ">>>" has varied
        // between lexer revisions (one token vs. a pair of tokens), so this
        // test deliberately avoids asserting on those token texts. The
        // "<<"/">>" assertions above are satisfied by "g << h" / "h >> i".
        assert!(ops.contains(&"**"));
        assert!(ops.contains(&"<>"));
        assert!(ops.contains(&";"));
    }
}