mago_syntax_core/
input.rs

1use memchr::memchr;
2use memchr::memmem::find;
3
4use mago_source::SourceIdentifier;
5use mago_span::Position;
6
7/// A struct representing the input code being lexed.
8///
9/// The `Input` struct provides methods to read, peek, consume, and skip characters
10/// from the bytes input code while keeping track of the current position (line, column, offset).
11#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)]
12pub struct Input<'a> {
13    pub(crate) bytes: &'a [u8],
14    pub(crate) length: usize,
15    pub(crate) offset: usize,
16    pub(crate) starting_position: Position,
17}
18
19impl<'a> Input<'a> {
20    /// Creates a new `Input` instance from the given input.
21    ///
22    /// # Arguments
23    ///
24    /// * `input` - A byte slice representing the input code to be processed.
25    ///
26    /// # Returns
27    ///
28    /// A new `Input` instance initialized at the beginning of the input.
29    pub fn new(source: SourceIdentifier, bytes: &'a [u8]) -> Self {
30        let length = bytes.len();
31
32        Self { bytes, length, offset: 0, starting_position: Position::start_of(source) }
33    }
34
35    /// Creates a new `Input` instance representing a byte slice that is
36    /// "anchored" at a specific absolute position within a larger source file.
37    ///
38    /// This is useful when lexing a subset (slice) of a source file, as it allows
39    /// generated tokens to retain accurate absolute positions and spans relative
40    /// to the original file.
41    ///
42    /// The internal cursor (`offset`) starts at 0 relative to the `bytes` slice,
43    /// but the absolute position is calculated relative to the `anchor_position`.
44    ///
45    /// # Arguments
46    ///
47    /// * `bytes` - A byte slice representing the input code subset to be lexed.
48    /// * `anchor_position` - The absolute `Position` in the original source file where
49    ///   the provided `bytes` slice begins.
50    ///
51    /// # Returns
52    ///
53    /// A new `Input` instance ready to lex the `bytes`, maintaining positions
54    /// relative to `anchor_position`.
55    pub fn anchored_at(bytes: &'a [u8], anchor_position: Position) -> Self {
56        let length = bytes.len();
57
58        Self { bytes, length, offset: 0, starting_position: anchor_position }
59    }
60
61    /// Returns the source identifier of the input code.
62    #[inline]
63    pub const fn source_identifier(&self) -> SourceIdentifier {
64        self.starting_position.source
65    }
66
67    /// Returns the absolute current `Position` of the lexer within the original source file.
68    ///
69    /// It calculates this by adding the internal offset (progress within the current byte slice)
70    /// to the `starting_position` the `Input` was initialized with.
71    #[inline]
72    pub const fn current_position(&self) -> Position {
73        // Calculate absolute position by adding internal offset to the starting base
74        self.starting_position.forward(self.offset)
75    }
76
77    /// Returns the current internal byte offset relative to the start of the input slice.
78    ///
79    /// This indicates how many bytes have been consumed from the current `bytes` slice.
80    /// To get the absolute position in the original source file, use `current_position()`.
81    #[inline]
82    pub const fn current_offset(&self) -> usize {
83        self.offset
84    }
85
86    /// Returns `true` if the input slice is empty (length is zero).
87    #[inline]
88    pub const fn is_empty(&self) -> bool {
89        self.length == 0
90    }
91
92    /// Returns the total length in bytes of the input slice being processed.
93    #[inline]
94    pub const fn len(&self) -> usize {
95        self.length
96    }
97
98    /// Checks if the current position is at the end of the input.
99    ///
100    /// # Returns
101    ///
102    /// `true` if the current offset is greater than or equal to the input length; `false` otherwise.
103    #[inline]
104    pub const fn has_reached_eof(&self) -> bool {
105        self.offset >= self.length
106    }
107
108    /// Advances the current position by one character, updating line and column numbers.
109    ///
110    /// Handles different line endings (`\n`, `\r`, `\r\n`) and updates line and column counters accordingly.
111    ///
112    /// If the end of input is reached, no action is taken.
113    #[inline]
114    pub fn next(&mut self) {
115        if !self.has_reached_eof() {
116            self.offset += 1;
117        }
118    }
119
120    /// Skips the next `count` characters, advancing the position accordingly.
121    ///
122    /// Updates line and column numbers as it advances.
123    ///
124    /// # Arguments
125    ///
126    /// * `count` - The number of characters to skip.
127    #[inline]
128    pub fn skip(&mut self, count: usize) {
129        self.offset = (self.offset + count).min(self.length);
130    }
131
132    /// Consumes the next `count` characters and returns them as a slice.
133    ///
134    /// Advances the position by `count` characters.
135    ///
136    /// # Arguments
137    ///
138    /// * `count` - The number of characters to consume.
139    ///
140    /// # Returns
141    ///
142    /// A byte slice containing the consumed characters.
143    #[inline]
144    pub fn consume(&mut self, count: usize) -> &'a [u8] {
145        let (from, until) = self.calculate_bound(count);
146
147        self.skip(count);
148
149        &self.bytes[from..until]
150    }
151
152    /// Consumes all remaining characters from the current position to the end of input.
153    ///
154    /// Advances the position to EOF.
155    ///
156    /// # Returns
157    ///
158    /// A byte slice containing the remaining characters.
159    #[inline]
160    pub fn consume_remaining(&mut self) -> &'a [u8] {
161        if self.has_reached_eof() {
162            return &[];
163        }
164
165        let from = self.offset;
166        self.offset = self.length;
167
168        &self.bytes[from..]
169    }
170
171    /// Consumes characters until the given byte slice is found.
172    ///
173    /// Advances the position to the start of the search slice if found,
174    /// or to EOF if not found.
175    ///
176    /// # Arguments
177    ///
178    /// * `search` - The byte slice to search for.
179    /// * `ignore_ascii_case` - Whether to ignore ASCII case when comparing characters.
180    ///
181    /// # Returns
182    ///
183    /// A byte slice containing the consumed characters.
184    #[inline]
185    pub fn consume_until(&mut self, search: &[u8], ignore_ascii_case: bool) -> &'a [u8] {
186        let start = self.offset;
187        if !ignore_ascii_case {
188            // For a single-byte search, use memchr.
189            if search.len() == 1 {
190                if let Some(pos) = memchr(search[0], &self.bytes[self.offset..]) {
191                    self.offset += pos;
192                    &self.bytes[start..self.offset]
193                } else {
194                    self.offset = self.length;
195                    &self.bytes[start..self.length]
196                }
197            } else if let Some(pos) = find(&self.bytes[self.offset..], search) {
198                self.offset += pos;
199                &self.bytes[start..self.offset]
200            } else {
201                self.offset = self.length;
202                &self.bytes[start..self.length]
203            }
204        } else {
205            while !self.has_reached_eof() && !self.is_at(search, ignore_ascii_case) {
206                self.offset += 1;
207            }
208
209            &self.bytes[start..self.offset]
210        }
211    }
212
213    #[inline]
214    pub fn consume_through(&mut self, search: u8) -> &'a [u8] {
215        let start = self.offset;
216        if let Some(pos) = memchr::memchr(search, &self.bytes[self.offset..]) {
217            self.offset += pos + 1;
218
219            &self.bytes[start..self.offset]
220        } else {
221            self.offset = self.length;
222
223            &self.bytes[start..self.length]
224        }
225    }
226
227    /// Consumes whitespaces until a non-whitespace character is found.
228    ///
229    /// # Returns
230    ///
231    /// A byte slice containing the consumed whitespaces.
232    #[inline]
233    pub fn consume_whitespaces(&mut self) -> &'a [u8] {
234        let start = self.offset;
235        let bytes = self.bytes;
236        let len = self.length;
237        while self.offset < len && bytes[self.offset].is_ascii_whitespace() {
238            self.offset += 1;
239        }
240
241        &bytes[start..self.offset]
242    }
243
244    /// Reads the next `n` characters without advancing the position.
245    ///
246    /// # Arguments
247    ///
248    /// * `n` - The number of characters to read.
249    ///
250    /// # Returns
251    ///
252    /// A byte slice containing the next `n` characters.
253    #[inline]
254    pub fn read(&self, n: usize) -> &'a [u8] {
255        let (from, until) = self.calculate_bound(n);
256
257        &self.bytes[from..until]
258    }
259
260    /// Reads a single byte at a specific byte offset within the input slice,
261    /// without advancing the internal cursor.
262    ///
263    /// This provides direct, low-level access to the underlying byte data.
264    ///
265    /// # Arguments
266    ///
267    /// * `at` - The zero-based byte offset within the input slice (`self.bytes`)
268    ///   from which to read the byte.
269    ///
270    /// # Returns
271    ///
272    /// A reference to the byte located at the specified offset `at`.
273    ///
274    /// # Panics
275    ///
276    /// This method **panics** if the provided `at` offset is out of bounds
277    /// for the input byte slice (i.e., if `at >= self.bytes.len()`).
278    pub fn read_at(&self, at: usize) -> &'a u8 {
279        &self.bytes[at]
280    }
281
282    /// Checks if the input at the current position matches the given byte slice.
283    ///
284    /// # Arguments
285    ///
286    /// * `search` - The byte slice to compare against the input.
287    /// * `ignore_ascii_case` - Whether to ignore ASCII case when comparing.
288    ///
289    /// # Returns
290    ///
291    /// `true` if the next bytes match `search`; `false` otherwise.
292    #[inline]
293    pub fn is_at(&self, search: &[u8], ignore_ascii_case: bool) -> bool {
294        let (from, until) = self.calculate_bound(search.len());
295        let slice = &self.bytes[from..until];
296
297        if ignore_ascii_case { slice.eq_ignore_ascii_case(search) } else { slice == search }
298    }
299
300    /// Attempts to match the given byte sequence at the current position, ignoring whitespace in the input.
301    ///
302    /// This method tries to match the provided byte slice `search` against the input starting
303    /// from the current position, possibly ignoring ASCII case. Whitespace characters in the input
304    /// are skipped during matching, but their length is included in the returned length.
305    ///
306    /// Importantly, the method **does not include** any trailing whitespace **after** the matched sequence
307    /// in the returned length.
308    ///
309    /// For example, to match the sequence `(string)`, the input could be `(string)`, `( string )`, `(  string )`, etc.,
310    /// and this method would return the total length of the input consumed to match `(string)`,
311    /// including any whitespace within the matched sequence, but **excluding** any whitespace after it.
312    ///
313    /// # Arguments
314    ///
315    /// * `search` - The byte slice to match against the input.
316    /// * `ignore_ascii_case` - If `true`, ASCII case is ignored during comparison.
317    ///
318    /// # Returns
319    ///
320    /// * `Some(length)` - If the input matches `search` (ignoring whitespace within the sequence), returns the total length
321    ///   of the input consumed to match `search`, including any skipped whitespace **within** the matched sequence.
322    /// * `None` - If the input does not match `search`.
323    ///
324    /// # Examples
325    ///
326    /// ```rust
327    /// use mago_syntax_core::input::Input;
328    /// use mago_source::SourceIdentifier;
329    ///
330    /// let source = SourceIdentifier::dummy();
331    ///
332    /// // Given input "( string ) x", starting at offset 0:
333    /// let input = Input::new(source.clone(), b"( string ) x");
334    /// assert_eq!(input.match_sequence_ignore_whitespace(b"(string)", true), Some(10)); // 10 bytes consumed up to ')'
335    ///
336    /// // Given input "(int)", with no whitespace:
337    /// let input = Input::new(source.clone(), b"(int)");
338    /// assert_eq!(input.match_sequence_ignore_whitespace(b"(int)", true), Some(5)); // 5 bytes consumed
339    ///
340    /// // Given input "(  InT   )abc", ignoring ASCII case:
341    /// let input = Input::new(source.clone(), b"(  InT   )abc");
342    /// assert_eq!(input.match_sequence_ignore_whitespace(b"(int)", true), Some(10)); // 10 bytes consumed up to ')'
343    ///
344    /// // Given input "(integer)", attempting to match "(int)":
345    /// let input = Input::new(source.clone(), b"(integer)");
346    /// assert_eq!(input.match_sequence_ignore_whitespace(b"(int)", false), None); // Does not match
347    ///
348    /// // Trailing whitespace after ')':
349    /// let input = Input::new(source.clone(), b"(int)   x");
350    /// assert_eq!(input.match_sequence_ignore_whitespace(b"(int)", true), Some(5)); // Length up to ')', excludes spaces after ')'
351    /// ```
352    #[inline]
353    pub const fn match_sequence_ignore_whitespace(&self, search: &[u8], ignore_ascii_case: bool) -> Option<usize> {
354        let mut offset = self.offset;
355        let mut search_offset = 0;
356        let mut length = 0;
357        let bytes = self.bytes;
358        let total = self.length;
359        while search_offset < search.len() {
360            // Skip whitespace in the input.
361            while offset < total && bytes[offset].is_ascii_whitespace() {
362                offset += 1;
363                length += 1;
364            }
365
366            if offset >= total {
367                return None;
368            }
369
370            let input_byte = bytes[offset];
371            let search_byte = search[search_offset];
372            let matched = if ignore_ascii_case {
373                input_byte.eq_ignore_ascii_case(&search_byte)
374            } else {
375                input_byte == search_byte
376            };
377
378            if matched {
379                offset += 1;
380                length += 1;
381                search_offset += 1;
382            } else {
383                return None;
384            }
385        }
386
387        Some(length)
388    }
389
390    /// Peeks ahead `i` characters and reads the next `n` characters without advancing the position.
391    ///
392    /// # Arguments
393    ///
394    /// * `offset` - The number of characters to skip before reading.
395    /// * `n` - The number of characters to read after skipping.
396    ///
397    /// # Returns
398    ///
399    /// A byte slice containing the peeked characters.
400    #[inline]
401    pub fn peek(&self, offset: usize, n: usize) -> &'a [u8] {
402        let from = self.offset + offset;
403        if from >= self.length {
404            return &self.bytes[self.length..self.length];
405        }
406
407        let mut until = from + n;
408        if until >= self.length {
409            until = self.length;
410        }
411
412        &self.bytes[from..until]
413    }
414
415    /// Calculates the bounds for slicing the input safely.
416    ///
417    /// Ensures that slicing does not go beyond the input length.
418    ///
419    /// # Arguments
420    ///
421    /// * `n` - The number of characters to include in the slice.
422    ///
423    /// # Returns
424    ///
425    /// A tuple `(from, until)` representing the start and end indices for slicing.
426    #[inline]
427    const fn calculate_bound(&self, n: usize) -> (usize, usize) {
428        if self.has_reached_eof() {
429            return (self.length, self.length);
430        }
431
432        let mut until = self.offset + n;
433
434        if until >= self.length {
435            until = self.length;
436        }
437
438        (self.offset, until)
439    }
440}
441
#[cfg(test)]
mod tests {
    use mago_span::Position;

    use super::*;

    #[test]
    fn test_new() {
        let bytes = b"Hello, world!";
        let input = Input::new(SourceIdentifier::dummy(), bytes);

        assert_eq!(input.current_position(), Position::new(SourceIdentifier::dummy(), 0));
        assert_eq!(input.length, bytes.len());
        assert_eq!(input.bytes, bytes);
    }

    #[test]
    fn test_anchored_at() {
        // Positions are reported relative to the anchor, while the cursor is slice-relative.
        let mut input = Input::anchored_at(b"abc", Position::new(SourceIdentifier::dummy(), 10));

        assert_eq!(input.current_offset(), 0);
        assert_eq!(input.current_position(), Position::new(SourceIdentifier::dummy(), 10));

        input.skip(2);

        assert_eq!(input.current_offset(), 2);
        assert_eq!(input.current_position(), Position::new(SourceIdentifier::dummy(), 12));
    }

    #[test]
    fn test_is_eof() {
        let bytes = b"";
        let input = Input::new(SourceIdentifier::dummy(), bytes);

        assert!(input.has_reached_eof());

        let bytes = b"data";
        let mut input = Input::new(SourceIdentifier::dummy(), bytes);

        assert!(!input.has_reached_eof());

        input.skip(4);

        assert!(input.has_reached_eof());
    }

    #[test]
    fn test_next() {
        // `next` advances exactly one byte per call; `\r\n` is NOT collapsed into one step.
        let bytes = b"a\nb\r\nc\rd";
        let mut input = Input::new(SourceIdentifier::dummy(), bytes);

        // 'a'
        input.next();
        assert_eq!(input.current_position(), Position::new(SourceIdentifier::dummy(), 1));

        // '\n'
        input.next();
        assert_eq!(input.current_position(), Position::new(SourceIdentifier::dummy(), 2));

        // 'b'
        input.next();
        assert_eq!(input.current_position(), Position::new(SourceIdentifier::dummy(), 3));

        // '\r' (first byte of "\r\n")
        input.next();
        assert_eq!(input.current_position(), Position::new(SourceIdentifier::dummy(), 4));

        // '\n' (second byte of "\r\n")
        input.next();
        assert_eq!(input.current_position(), Position::new(SourceIdentifier::dummy(), 5));

        // 'c'
        input.next();
        assert_eq!(input.current_position(), Position::new(SourceIdentifier::dummy(), 6));

        // '\r'
        input.next();
        assert_eq!(input.current_position(), Position::new(SourceIdentifier::dummy(), 7));

        // 'd'
        input.next();
        assert_eq!(input.current_position(), Position::new(SourceIdentifier::dummy(), 8));
        assert!(input.has_reached_eof());

        // Advancing at EOF is a no-op.
        input.next();
        assert_eq!(input.current_position(), Position::new(SourceIdentifier::dummy(), 8));
    }

    #[test]
    fn test_consume() {
        let bytes = b"abcdef";
        let mut input = Input::new(SourceIdentifier::dummy(), bytes);

        let consumed = input.consume(3);
        assert_eq!(consumed, b"abc");
        assert_eq!(input.current_position(), Position::new(SourceIdentifier::dummy(), 3));

        let consumed = input.consume(3);
        assert_eq!(consumed, b"def");
        assert_eq!(input.current_position(), Position::new(SourceIdentifier::dummy(), 6));

        let consumed = input.consume(1); // Should return empty slice at EOF
        assert_eq!(consumed, b"");
        assert!(input.has_reached_eof());
    }

    #[test]
    fn test_consume_remaining() {
        let bytes = b"abcdef";
        let mut input = Input::new(SourceIdentifier::dummy(), bytes);

        input.skip(2);
        let remaining = input.consume_remaining();
        assert_eq!(remaining, b"cdef");
        assert!(input.has_reached_eof());

        // Nothing is left once EOF has been reached.
        assert_eq!(input.consume_remaining(), b"");
    }

    #[test]
    fn test_consume_until() {
        let mut input = Input::new(SourceIdentifier::dummy(), b"hello world");

        // Multi-byte, case-sensitive search stops before the needle.
        assert_eq!(input.consume_until(b"wor", false), b"hello ");
        assert!(input.is_at(b"world", false));

        // Missing needle consumes everything.
        assert_eq!(input.consume_until(b"xyz", false), b"world");
        assert!(input.has_reached_eof());

        // Single-byte search.
        let mut input = Input::new(SourceIdentifier::dummy(), b"hello");
        assert_eq!(input.consume_until(b"l", false), b"he");

        // Case-insensitive search.
        let mut input = Input::new(SourceIdentifier::dummy(), b"Hello WORLD");
        assert_eq!(input.consume_until(b"world", true), b"Hello ");
        assert!(input.is_at(b"WORLD", false));
    }

    #[test]
    fn test_consume_through() {
        let mut input = Input::new(SourceIdentifier::dummy(), b"ab\ncd");

        // The terminator is included in the consumed slice.
        assert_eq!(input.consume_through(b'\n'), b"ab\n");
        assert_eq!(input.current_offset(), 3);

        // Without a terminator everything up to EOF is consumed.
        assert_eq!(input.consume_through(b'\n'), b"cd");
        assert!(input.has_reached_eof());
    }

    #[test]
    fn test_consume_whitespaces() {
        let mut input = Input::new(SourceIdentifier::dummy(), b"  \t\nx");

        assert_eq!(input.consume_whitespaces(), b"  \t\n");
        assert!(input.is_at(b"x", false));
        assert_eq!(input.consume_whitespaces(), b"");
    }

    #[test]
    fn test_read() {
        let bytes = b"abcdef";
        let input = Input::new(SourceIdentifier::dummy(), bytes);

        let read = input.read(3);
        assert_eq!(read, b"abc");
        // Position should not change
        assert_eq!(input.current_position(), Position::new(SourceIdentifier::dummy(), 0));
    }

    #[test]
    fn test_read_at() {
        let mut input = Input::new(SourceIdentifier::dummy(), b"abc");
        input.skip(2);

        // `read_at` indexes the slice absolutely, independent of the cursor.
        assert_eq!(*input.read_at(1), b'b');
        assert_eq!(input.current_offset(), 2);
    }

    #[test]
    fn test_is_at() {
        let bytes = b"abcdef";
        let mut input = Input::new(SourceIdentifier::dummy(), bytes);

        assert!(input.is_at(b"abc", false));
        input.skip(2);
        assert!(input.is_at(b"cde", false));
        assert!(!input.is_at(b"xyz", false));
        // A search longer than the remaining input never matches.
        assert!(!input.is_at(b"cdefgh", false));
    }

    #[test]
    fn test_is_at_ignore_ascii_case() {
        let bytes = b"AbCdEf";
        let mut input = Input::new(SourceIdentifier::dummy(), bytes);

        assert!(input.is_at(b"abc", true));
        input.skip(2);
        assert!(input.is_at(b"cde", true));
        assert!(!input.is_at(b"xyz", true));
    }

    #[test]
    fn test_match_sequence_ignore_whitespace() {
        let input = Input::new(SourceIdentifier::dummy(), b"( string ) x");
        assert_eq!(input.match_sequence_ignore_whitespace(b"(string)", true), Some(10));

        let input = Input::new(SourceIdentifier::dummy(), b"(  InT   )abc");
        assert_eq!(input.match_sequence_ignore_whitespace(b"(int)", true), Some(10));
        assert_eq!(input.match_sequence_ignore_whitespace(b"(int)", false), None);

        let input = Input::new(SourceIdentifier::dummy(), b"(int)   x");
        assert_eq!(input.match_sequence_ignore_whitespace(b"(int)", true), Some(5));

        let input = Input::new(SourceIdentifier::dummy(), b"(integer)");
        assert_eq!(input.match_sequence_ignore_whitespace(b"(int)", false), None);

        // Running out of input mid-match fails.
        let input = Input::new(SourceIdentifier::dummy(), b"(in");
        assert_eq!(input.match_sequence_ignore_whitespace(b"(int)", false), None);

        // The cursor is never moved.
        assert_eq!(input.current_offset(), 0);
    }

    #[test]
    fn test_peek() {
        let bytes = b"abcdef";
        let input = Input::new(SourceIdentifier::dummy(), bytes);

        let peeked = input.peek(2, 3);
        assert_eq!(peeked, b"cde");
        // Position should not change
        assert_eq!(input.current_position(), Position::new(SourceIdentifier::dummy(), 0));

        // Reads are clamped to EOF, and peeking past EOF yields an empty slice.
        assert_eq!(input.peek(4, 10), b"ef");
        assert_eq!(input.peek(6, 1), b"");
        assert_eq!(input.peek(10, 3), b"");
    }

    #[test]
    fn test_to_bound() {
        let bytes = b"abcdef";
        let input = Input::new(SourceIdentifier::dummy(), bytes);

        let (from, until) = input.calculate_bound(3);
        assert_eq!((from, until), (0, 3));

        let (from, until) = input.calculate_bound(10); // Exceeds length
        assert_eq!((from, until), (0, 6));
    }
}