mago_lexer/
input.rs

1use memchr::memchr;
2use memchr::memmem::find;
3
4use mago_source::SourceIdentifier;
5use mago_span::Position;
6
7/// A struct representing the input code being lexed.
8///
9/// The `Input` struct provides methods to read, peek, consume, and skip characters
10/// from the bytes input code while keeping track of the current position (line, column, offset).
11#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)]
12pub struct Input<'a> {
13    pub(crate) bytes: &'a [u8],
14    pub(crate) length: usize,
15    pub(crate) position: Position,
16}
17
18impl<'a> Input<'a> {
19    /// Creates a new `Input` instance from the given input.
20    ///
21    /// # Arguments
22    ///
23    /// * `input` - A byte slice representing the input code to be processed.
24    ///
25    /// # Returns
26    ///
27    /// A new `Input` instance initialized at the beginning of the input.
28    pub fn new(source: SourceIdentifier, bytes: &'a [u8]) -> Self {
29        let length = bytes.len();
30
31        Self { bytes, length, position: Position::start_of(source) }
32    }
33
34    /// Returns the source identifier of the input code.
35    #[inline]
36    pub const fn source_identifier(&self) -> SourceIdentifier {
37        self.position.source
38    }
39
40    /// Returns the current position in the input code.
41    ///
42    /// # Returns
43    ///
44    /// A `Position` struct containing the current line, column, and offset.
45    #[inline]
46    pub const fn position(&self) -> Position {
47        self.position
48    }
49
50    /// Checks if the current position is at the end of the input.
51    ///
52    /// # Returns
53    ///
54    /// `true` if the current offset is greater than or equal to the input length; `false` otherwise.
55    #[inline]
56    pub const fn has_reached_eof(&self) -> bool {
57        self.position.offset >= self.length
58    }
59
60    /// Advances the current position by one character, updating line and column numbers.
61    ///
62    /// Handles different line endings (`\n`, `\r`, `\r\n`) and updates line and column counters accordingly.
63    ///
64    /// If the end of input is reached, no action is taken.
65    #[inline]
66    pub fn next(&mut self) {
67        if !self.has_reached_eof() {
68            self.position.offset += 1;
69        }
70    }
71
72    /// Skips the next `count` characters, advancing the position accordingly.
73    ///
74    /// Updates line and column numbers as it advances.
75    ///
76    /// # Arguments
77    ///
78    /// * `count` - The number of characters to skip.
79    #[inline]
80    pub fn skip(&mut self, count: usize) {
81        self.position.offset = (self.position.offset + count).min(self.length);
82    }
83
84    /// Consumes the next `count` characters and returns them as a slice.
85    ///
86    /// Advances the position by `count` characters.
87    ///
88    /// # Arguments
89    ///
90    /// * `count` - The number of characters to consume.
91    ///
92    /// # Returns
93    ///
94    /// A byte slice containing the consumed characters.
95    #[inline]
96    pub fn consume(&mut self, count: usize) -> &'a [u8] {
97        let (from, until) = self.calculate_bound(count);
98
99        self.skip(count);
100
101        &self.bytes[from..until]
102    }
103
104    /// Consumes all remaining characters from the current position to the end of input.
105    ///
106    /// Advances the position to EOF.
107    ///
108    /// # Returns
109    ///
110    /// A byte slice containing the remaining characters.
111    #[inline]
112    pub fn consume_remaining(&mut self) -> &'a [u8] {
113        if self.has_reached_eof() {
114            return &[];
115        }
116
117        let from = self.position.offset;
118        self.position.offset = self.length;
119
120        &self.bytes[from..]
121    }
122
123    /// Consumes characters until the given byte slice is found.
124    ///
125    /// Advances the position to the start of the search slice if found,
126    /// or to EOF if not found.
127    ///
128    /// # Arguments
129    ///
130    /// * `search` - The byte slice to search for.
131    /// * `ignore_ascii_case` - Whether to ignore ASCII case when comparing characters.
132    ///
133    /// # Returns
134    ///
135    /// A byte slice containing the consumed characters.
136    #[inline]
137    pub fn consume_until(&mut self, search: &[u8], ignore_ascii_case: bool) -> &'a [u8] {
138        let start = self.position.offset;
139        if !ignore_ascii_case {
140            // For a single-byte search, use memchr.
141            if search.len() == 1 {
142                if let Some(pos) = memchr(search[0], &self.bytes[self.position.offset..]) {
143                    self.position.offset += pos;
144                    &self.bytes[start..self.position.offset]
145                } else {
146                    self.position.offset = self.length;
147                    &self.bytes[start..self.length]
148                }
149            } else if let Some(pos) = find(&self.bytes[self.position.offset..], search) {
150                self.position.offset += pos;
151                &self.bytes[start..self.position.offset]
152            } else {
153                self.position.offset = self.length;
154                &self.bytes[start..self.length]
155            }
156        } else {
157            while !self.has_reached_eof() && !self.is_at(search, ignore_ascii_case) {
158                self.position.offset += 1;
159            }
160
161            &self.bytes[start..self.position.offset]
162        }
163    }
164
165    #[inline]
166    pub fn consume_through(&mut self, search: u8) -> &'a [u8] {
167        let start = self.position.offset;
168        if let Some(pos) = memchr::memchr(search, &self.bytes[self.position.offset..]) {
169            self.position.offset += pos + 1;
170
171            &self.bytes[start..self.position.offset]
172        } else {
173            self.position.offset = self.length;
174
175            &self.bytes[start..self.length]
176        }
177    }
178
179    /// Consumes whitespaces until a non-whitespace character is found.
180    ///
181    /// # Returns
182    ///
183    /// A byte slice containing the consumed whitespaces.
184    #[inline]
185    pub fn consume_whitespaces(&mut self) -> &'a [u8] {
186        let start = self.position.offset;
187        let bytes = self.bytes;
188        let len = self.length;
189        while self.position.offset < len && bytes[self.position.offset].is_ascii_whitespace() {
190            self.position.offset += 1;
191        }
192
193        &bytes[start..self.position.offset]
194    }
195
196    /// Reads the next `n` characters without advancing the position.
197    ///
198    /// # Arguments
199    ///
200    /// * `n` - The number of characters to read.
201    ///
202    /// # Returns
203    ///
204    /// A byte slice containing the next `n` characters.
205    #[inline]
206    pub fn read(&self, n: usize) -> &'a [u8] {
207        let (from, until) = self.calculate_bound(n);
208
209        &self.bytes[from..until]
210    }
211
212    /// Checks if the input at the current position matches the given byte slice.
213    ///
214    /// # Arguments
215    ///
216    /// * `search` - The byte slice to compare against the input.
217    /// * `ignore_ascii_case` - Whether to ignore ASCII case when comparing.
218    ///
219    /// # Returns
220    ///
221    /// `true` if the next bytes match `search`; `false` otherwise.
222    #[inline]
223    pub fn is_at(&self, search: &[u8], ignore_ascii_case: bool) -> bool {
224        let (from, until) = self.calculate_bound(search.len());
225        let slice = &self.bytes[from..until];
226
227        if ignore_ascii_case { slice.eq_ignore_ascii_case(search) } else { slice == search }
228    }
229
230    /// Attempts to match the given byte sequence at the current position, ignoring whitespace in the input.
231    ///
232    /// This method tries to match the provided byte slice `search` against the input starting
233    /// from the current position, possibly ignoring ASCII case. Whitespace characters in the input
234    /// are skipped during matching, but their length is included in the returned length.
235    ///
236    /// Importantly, the method **does not include** any trailing whitespace **after** the matched sequence
237    /// in the returned length.
238    ///
239    /// For example, to match the sequence `(string)`, the input could be `(string)`, `( string )`, `(  string )`, etc.,
240    /// and this method would return the total length of the input consumed to match `(string)`,
241    /// including any whitespace within the matched sequence, but **excluding** any whitespace after it.
242    ///
243    /// # Arguments
244    ///
245    /// * `search` - The byte slice to match against the input.
246    /// * `ignore_ascii_case` - If `true`, ASCII case is ignored during comparison.
247    ///
248    /// # Returns
249    ///
250    /// * `Some(length)` - If the input matches `search` (ignoring whitespace within the sequence), returns the total length
251    ///   of the input consumed to match `search`, including any skipped whitespace **within** the matched sequence.
252    /// * `None` - If the input does not match `search`.
253    ///
254    /// # Examples
255    ///
256    /// ```rust
257    /// use mago_lexer::input::Input;
258    /// use mago_source::SourceIdentifier;
259    ///
260    /// let source = SourceIdentifier::dummy();
261    ///
262    /// // Given input "( string ) x", starting at offset 0:
263    /// let input = Input::new(source.clone(), b"( string ) x");
264    /// assert_eq!(input.match_sequence_ignore_whitespace(b"(string)", true), Some(10)); // 10 bytes consumed up to ')'
265    ///
266    /// // Given input "(int)", with no whitespace:
267    /// let input = Input::new(source.clone(), b"(int)");
268    /// assert_eq!(input.match_sequence_ignore_whitespace(b"(int)", true), Some(5)); // 5 bytes consumed
269    ///
270    /// // Given input "(  InT   )abc", ignoring ASCII case:
271    /// let input = Input::new(source.clone(), b"(  InT   )abc");
272    /// assert_eq!(input.match_sequence_ignore_whitespace(b"(int)", true), Some(10)); // 10 bytes consumed up to ')'
273    ///
274    /// // Given input "(integer)", attempting to match "(int)":
275    /// let input = Input::new(source.clone(), b"(integer)");
276    /// assert_eq!(input.match_sequence_ignore_whitespace(b"(int)", false), None); // Does not match
277    ///
278    /// // Trailing whitespace after ')':
279    /// let input = Input::new(source.clone(), b"(int)   x");
280    /// assert_eq!(input.match_sequence_ignore_whitespace(b"(int)", true), Some(5)); // Length up to ')', excludes spaces after ')'
281    /// ```
282    #[inline]
283    pub const fn match_sequence_ignore_whitespace(&self, search: &[u8], ignore_ascii_case: bool) -> Option<usize> {
284        let mut offset = self.position.offset;
285        let mut search_offset = 0;
286        let mut length = 0;
287        let bytes = self.bytes;
288        let total = self.length;
289        while search_offset < search.len() {
290            // Skip whitespace in the input.
291            while offset < total && bytes[offset].is_ascii_whitespace() {
292                offset += 1;
293                length += 1;
294            }
295
296            if offset >= total {
297                return None;
298            }
299
300            let input_byte = bytes[offset];
301            let search_byte = search[search_offset];
302            let matched = if ignore_ascii_case {
303                input_byte.eq_ignore_ascii_case(&search_byte)
304            } else {
305                input_byte == search_byte
306            };
307
308            if matched {
309                offset += 1;
310                length += 1;
311                search_offset += 1;
312            } else {
313                return None;
314            }
315        }
316
317        Some(length)
318    }
319
320    /// Peeks ahead `i` characters and reads the next `n` characters without advancing the position.
321    ///
322    /// # Arguments
323    ///
324    /// * `offset` - The number of characters to skip before reading.
325    /// * `n` - The number of characters to read after skipping.
326    ///
327    /// # Returns
328    ///
329    /// A byte slice containing the peeked characters.
330    #[inline]
331    pub fn peek(&self, offset: usize, n: usize) -> &'a [u8] {
332        let from = self.position.offset + offset;
333        if from >= self.length {
334            return &self.bytes[self.length..self.length];
335        }
336
337        let mut until = from + n;
338        if until >= self.length {
339            until = self.length;
340        }
341
342        &self.bytes[from..until]
343    }
344
345    /// Calculates the bounds for slicing the input safely.
346    ///
347    /// Ensures that slicing does not go beyond the input length.
348    ///
349    /// # Arguments
350    ///
351    /// * `n` - The number of characters to include in the slice.
352    ///
353    /// # Returns
354    ///
355    /// A tuple `(from, until)` representing the start and end indices for slicing.
356    #[inline]
357    const fn calculate_bound(&self, n: usize) -> (usize, usize) {
358        if self.has_reached_eof() {
359            return (self.length, self.length);
360        }
361
362        let mut until = self.position.offset + n;
363
364        if until >= self.length {
365            until = self.length;
366        }
367
368        (self.position.offset, until)
369    }
370}
371
/// Unit tests for `Input`: construction, cursor movement, and the
/// consume/read/peek/match helpers.
#[cfg(test)]
mod tests {
    use mago_span::Position;

    use super::*;

    #[test]
    fn test_new() {
        let bytes = b"Hello, world!";
        let input = Input::new(SourceIdentifier::dummy(), bytes);

        assert_eq!(input.position(), Position::new(SourceIdentifier::dummy(), 0));
        assert_eq!(input.source_identifier(), SourceIdentifier::dummy());
        assert_eq!(input.length, bytes.len());
        assert_eq!(input.bytes, bytes);
    }

    #[test]
    fn test_is_eof() {
        let bytes = b"";
        let input = Input::new(SourceIdentifier::dummy(), bytes);

        assert!(input.has_reached_eof());

        let bytes = b"data";
        let mut input = Input::new(SourceIdentifier::dummy(), bytes);

        assert!(!input.has_reached_eof());

        input.skip(4);

        assert!(input.has_reached_eof());
    }

    #[test]
    fn test_next() {
        // `next` advances exactly one byte per call; line endings get no
        // special treatment, so `\r\n` takes two calls.
        let bytes = b"a\nb\r\nc\rd";
        let mut input = Input::new(SourceIdentifier::dummy(), bytes);

        // 'a'
        input.next();
        assert_eq!(input.position(), Position::new(SourceIdentifier::dummy(), 1));

        // '\n'
        input.next();
        assert_eq!(input.position(), Position::new(SourceIdentifier::dummy(), 2));

        // 'b'
        input.next();
        assert_eq!(input.position(), Position::new(SourceIdentifier::dummy(), 3));

        // '\r' (first half of "\r\n")
        input.next();
        assert_eq!(input.position(), Position::new(SourceIdentifier::dummy(), 4));

        // '\n' (second half of "\r\n")
        input.next();
        assert_eq!(input.position(), Position::new(SourceIdentifier::dummy(), 5));

        // 'c'
        input.next();
        assert_eq!(input.position(), Position::new(SourceIdentifier::dummy(), 6));

        // lone '\r'
        input.next();
        assert_eq!(input.position(), Position::new(SourceIdentifier::dummy(), 7));

        // 'd'
        input.next();
        assert_eq!(input.position(), Position::new(SourceIdentifier::dummy(), 8));
        assert!(input.has_reached_eof());

        // At EOF, `next` is a no-op.
        input.next();
        assert_eq!(input.position(), Position::new(SourceIdentifier::dummy(), 8));
    }

    #[test]
    fn test_skip_clamps_to_length() {
        let bytes = b"abc";
        let mut input = Input::new(SourceIdentifier::dummy(), bytes);

        input.skip(10);
        assert_eq!(input.position(), Position::new(SourceIdentifier::dummy(), 3));
        assert!(input.has_reached_eof());
    }

    #[test]
    fn test_consume() {
        let bytes = b"abcdef";
        let mut input = Input::new(SourceIdentifier::dummy(), bytes);

        let consumed = input.consume(3);
        assert_eq!(consumed, b"abc");
        assert_eq!(input.position(), Position::new(SourceIdentifier::dummy(), 3));

        let consumed = input.consume(3);
        assert_eq!(consumed, b"def");
        assert_eq!(input.position(), Position::new(SourceIdentifier::dummy(), 6));

        let consumed = input.consume(1); // Should return empty slice at EOF
        assert_eq!(consumed, b"");
        assert!(input.has_reached_eof());
    }

    #[test]
    fn test_consume_remaining() {
        let bytes = b"abcdef";
        let mut input = Input::new(SourceIdentifier::dummy(), bytes);

        input.skip(2);
        let remaining = input.consume_remaining();
        assert_eq!(remaining, b"cdef");
        assert!(input.has_reached_eof());

        // Nothing is left once EOF has been reached.
        assert_eq!(input.consume_remaining(), b"");
    }

    #[test]
    fn test_consume_until() {
        // Multi-byte, case-sensitive search stops right before the match.
        let mut input = Input::new(SourceIdentifier::dummy(), b"hello ?> world");
        assert_eq!(input.consume_until(b"?>", false), b"hello ");
        assert_eq!(input.position(), Position::new(SourceIdentifier::dummy(), 6));
        assert!(input.is_at(b"?>", false));

        // A search that is never found consumes everything up to EOF.
        assert_eq!(input.consume_until(b"zzz", false), b"?> world");
        assert!(input.has_reached_eof());

        // Single-byte search.
        let mut input = Input::new(SourceIdentifier::dummy(), b"abc;def");
        assert_eq!(input.consume_until(b";", false), b"abc");
        assert!(input.is_at(b";", false));
    }

    #[test]
    fn test_consume_until_ignore_ascii_case() {
        let mut input = Input::new(SourceIdentifier::dummy(), b"foo END bar");
        assert_eq!(input.consume_until(b"end", true), b"foo ");
        assert_eq!(input.position(), Position::new(SourceIdentifier::dummy(), 4));

        // The same search is case-sensitive when the flag is off.
        let mut input = Input::new(SourceIdentifier::dummy(), b"foo END bar");
        assert_eq!(input.consume_until(b"end", false), b"foo END bar");
        assert!(input.has_reached_eof());
    }

    #[test]
    fn test_consume_through() {
        let mut input = Input::new(SourceIdentifier::dummy(), b"line one\nline two");

        // The matching byte is part of the consumed slice.
        assert_eq!(input.consume_through(b'\n'), b"line one\n");
        assert_eq!(input.position(), Position::new(SourceIdentifier::dummy(), 9));

        // Without a match, everything up to EOF is consumed.
        assert_eq!(input.consume_through(b'\n'), b"line two");
        assert!(input.has_reached_eof());
    }

    #[test]
    fn test_consume_whitespaces() {
        let mut input = Input::new(SourceIdentifier::dummy(), b" \t\r\n x");

        assert_eq!(input.consume_whitespaces(), b" \t\r\n ");
        assert_eq!(input.position(), Position::new(SourceIdentifier::dummy(), 5));

        // Already at a non-whitespace byte: nothing is consumed.
        assert_eq!(input.consume_whitespaces(), b"");
        assert_eq!(input.position(), Position::new(SourceIdentifier::dummy(), 5));
    }

    #[test]
    fn test_read() {
        let bytes = b"abcdef";
        let input = Input::new(SourceIdentifier::dummy(), bytes);

        let read = input.read(3);
        assert_eq!(read, b"abc");
        assert_eq!(input.position(), Position::new(SourceIdentifier::dummy(), 0));
        // Position should not change

        // Asking for more than is available yields what is left.
        assert_eq!(input.read(10), b"abcdef");
    }

    #[test]
    fn test_is_at() {
        let bytes = b"abcdef";
        let mut input = Input::new(SourceIdentifier::dummy(), bytes);

        assert!(input.is_at(b"abc", false));
        input.skip(2);
        assert!(input.is_at(b"cde", false));
        assert!(!input.is_at(b"xyz", false));
    }

    #[test]
    fn test_is_at_ignore_ascii_case() {
        let bytes = b"AbCdEf";
        let mut input = Input::new(SourceIdentifier::dummy(), bytes);

        assert!(input.is_at(b"abc", true));
        input.skip(2);
        assert!(input.is_at(b"cde", true));
        assert!(!input.is_at(b"xyz", true));
    }

    #[test]
    fn test_match_sequence_ignore_whitespace() {
        let input = Input::new(SourceIdentifier::dummy(), b"( string ) x");
        assert_eq!(input.match_sequence_ignore_whitespace(b"(string)", true), Some(10));

        let input = Input::new(SourceIdentifier::dummy(), b"(int)");
        assert_eq!(input.match_sequence_ignore_whitespace(b"(int)", true), Some(5));

        let input = Input::new(SourceIdentifier::dummy(), b"(  InT   )abc");
        assert_eq!(input.match_sequence_ignore_whitespace(b"(int)", true), Some(10));
        assert_eq!(input.match_sequence_ignore_whitespace(b"(int)", false), None);

        let input = Input::new(SourceIdentifier::dummy(), b"(integer)");
        assert_eq!(input.match_sequence_ignore_whitespace(b"(int)", false), None);

        // Trailing whitespace after the match is not counted.
        let input = Input::new(SourceIdentifier::dummy(), b"(int)   x");
        assert_eq!(input.match_sequence_ignore_whitespace(b"(int)", true), Some(5));

        // The input ends before the sequence is complete.
        let input = Input::new(SourceIdentifier::dummy(), b"(int");
        assert_eq!(input.match_sequence_ignore_whitespace(b"(int)", true), None);

        // Matching never moves the cursor.
        assert_eq!(input.position(), Position::new(SourceIdentifier::dummy(), 0));
    }

    #[test]
    fn test_peek() {
        let bytes = b"abcdef";
        let input = Input::new(SourceIdentifier::dummy(), bytes);

        let peeked = input.peek(2, 3);
        assert_eq!(peeked, b"cde");
        assert_eq!(input.position(), Position::new(SourceIdentifier::dummy(), 0));
        // Position should not change

        // Peeking across or past the end is truncated rather than panicking.
        assert_eq!(input.peek(4, 10), b"ef");
        assert_eq!(input.peek(6, 1), b"");
        assert_eq!(input.peek(9, 1), b"");
    }

    #[test]
    fn test_to_bound() {
        // Exercises the private `calculate_bound` helper.
        let bytes = b"abcdef";
        let input = Input::new(SourceIdentifier::dummy(), bytes);

        let (from, until) = input.calculate_bound(3);
        assert_eq!((from, until), (0, 3));

        let (from, until) = input.calculate_bound(10); // Exceeds length
        assert_eq!((from, until), (0, 6));
    }
}