// mago_syntax_core/input.rs

1use bumpalo::Bump;
2use memchr::memchr;
3use memchr::memmem::find;
4
5use mago_database::file::File;
6use mago_database::file::FileId;
7use mago_database::file::HasFileId;
8use mago_span::Position;
9
/// Byte-indexed lookup table marking the whitespace bytes recognised by the lexer:
/// space, `\t`, `\n`, `\r`, form feed (`0x0C`) and vertical tab (`0x0B`).
const WHITESPACE_TABLE: [bool; 256] = {
    let mut table = [false; 256];
    let mut index = 0usize;
    while index < table.len() {
        table[index] = matches!(index as u8, b' ' | b'\t' | b'\n' | b'\r' | 0x0B | 0x0C);
        index += 1;
    }
    table
};
21
/// Byte-indexed lookup table for bytes that may continue an identifier:
/// ASCII letters, ASCII digits, `_`, and every byte in `0x80..=0xFF`.
const IDENT_PART_TABLE: [bool; 256] = {
    let mut table = [false; 256];
    let mut index = 0usize;
    while index < table.len() {
        let byte = index as u8;
        table[index] = byte.is_ascii_alphanumeric() || byte == b'_' || byte >= 0x80;
        index += 1;
    }
    table
};
32
/// Byte-indexed lookup table for bytes that may begin an identifier:
/// ASCII letters, `_`, and every byte in `0x80..=0xFF` (digits are excluded).
const IDENT_START_TABLE: [bool; 256] = {
    let mut table = [false; 256];
    let mut index = 0usize;
    while index < table.len() {
        let byte = index as u8;
        table[index] = byte.is_ascii_alphabetic() || byte == b'_' || byte >= 0x80;
        index += 1;
    }
    table
};
43
44/// A struct representing the input code being lexed.
45///
46/// The `Input` struct provides methods to read, peek, consume, and skip characters
47/// from the bytes input code while keeping track of the current position (line, column, offset).
48#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)]
49#[allow(clippy::field_scoped_visibility_modifiers)]
50pub struct Input<'src> {
51    pub(crate) bytes: &'src [u8],
52    pub(crate) length: usize,
53    pub(crate) offset: usize,
54    pub(crate) starting_position: Position,
55    pub(crate) file_id: FileId,
56}
57
58impl<'src> Input<'src> {
59    /// Creates a new `Input` instance from the given input.
60    ///
61    /// # Arguments
62    ///
63    /// * `file_id` - The unique identifier for the source file this input belongs to.
64    /// * `bytes` - A byte slice representing the input code to be lexed.
65    ///
66    /// # Returns
67    ///
68    /// A new `Input` instance initialized at the beginning of the input.
69    #[must_use]
70    pub fn new(file_id: FileId, bytes: &'src [u8]) -> Self {
71        let length = bytes.len();
72
73        Self { bytes, length, offset: 0, file_id, starting_position: Position::new(0) }
74    }
75
76    /// Creates a new `Input` instance from the contents of a `File`.
77    ///
78    /// # Arguments
79    ///
80    /// * `file` - A reference to the `File` containing the source code.
81    ///
82    /// # Returns
83    ///
84    /// A new `Input` instance initialized with the file's ID and contents.
85    #[must_use]
86    pub fn from_file(file: &'src File) -> Self {
87        Self::new(file.id, file.contents.as_ref())
88    }
89
90    /// Creates a new `Input` instance from the contents of a `File`.
91    ///
92    /// # Arguments
93    ///
94    /// * `file` - A reference to the `File` containing the source code.
95    ///
96    /// # Returns
97    ///
98    /// A new `Input` instance initialized with the file's ID and contents.
99    pub fn from_file_in(arena: &'src Bump, file: &File) -> Self {
100        Self::new(file.id, arena.alloc_slice_clone(file.contents.as_ref()))
101    }
102
103    /// Creates a new `Input` instance representing a byte slice that is
104    /// "anchored" at a specific absolute position within a larger source file.
105    ///
106    /// This is useful when lexing a subset (slice) of a source file, as it allows
107    /// generated tokens to retain accurate absolute positions and spans relative
108    /// to the original file.
109    ///
110    /// The internal cursor (`offset`) starts at 0 relative to the `bytes` slice,
111    /// but the absolute position is calculated relative to the `anchor_position`.
112    ///
113    /// # Arguments
114    ///
115    /// * `file_id` - The unique identifier for the source file this input belongs to.
116    /// * `bytes` - A byte slice representing the input code subset to be lexed.
117    /// * `anchor_position` - The absolute `Position` in the original source file where the provided `bytes` slice begins.
118    ///
119    /// # Returns
120    ///
121    /// A new `Input` instance ready to lex the `bytes`, maintaining positions
122    /// relative to `anchor_position`.
123    #[must_use]
124    pub fn anchored_at(file_id: FileId, bytes: &'src [u8], anchor_position: Position) -> Self {
125        let length = bytes.len();
126
127        Self { bytes, length, offset: 0, file_id, starting_position: anchor_position }
128    }
129
130    /// Returns the source file identifier of the input code.
131    #[inline]
132    #[must_use]
133    pub const fn file_id(&self) -> FileId {
134        self.file_id
135    }
136
137    /// Returns the absolute current `Position` of the lexer within the original source file.
138    ///
139    /// It calculates this by adding the internal offset (progress within the current byte slice)
140    /// to the `starting_position` the `Input` was initialized with.
141    #[inline]
142    #[must_use]
143    pub const fn current_position(&self) -> Position {
144        // Calculate absolute position by adding internal offset to the starting base
145        self.starting_position.forward(self.offset as u32)
146    }
147
148    /// Returns the current internal byte offset relative to the start of the input slice.
149    ///
150    /// This indicates how many bytes have been consumed from the current `bytes` slice.
151    /// To get the absolute position in the original source file, use `current_position()`.
152    #[inline]
153    #[must_use]
154    pub const fn current_offset(&self) -> usize {
155        self.offset
156    }
157
158    /// Returns `true` if the input slice is empty (length is zero).
159    #[inline]
160    #[must_use]
161    pub const fn is_empty(&self) -> bool {
162        self.length == 0
163    }
164
165    /// Returns the total length in bytes of the input slice being processed.
166    #[inline]
167    #[must_use]
168    pub const fn len(&self) -> usize {
169        self.length
170    }
171
172    /// Checks if the current position is at the end of the input.
173    ///
174    /// # Returns
175    ///
176    /// `true` if the current offset is greater than or equal to the input length; `false` otherwise.
177    #[inline(always)]
178    #[must_use]
179    pub const fn has_reached_eof(&self) -> bool {
180        self.offset >= self.length
181    }
182
183    /// Returns a byte slice within a specified absolute range.
184    ///
185    /// The `from` and `to` arguments are absolute byte offsets from the beginning
186    /// of the original source file. The method calculates the correct slice
187    /// relative to the `starting_position` of this `Input`.
188    ///
189    /// This is useful for retrieving the raw text of a `Span` or `Token` whose
190    /// positions are absolute, even when the `Input` only contains a subsection
191    /// of the source file.
192    ///
193    /// The returned slice is defensively clamped to the bounds of the current
194    /// `Input`'s byte slice to prevent panics.
195    ///
196    /// # Arguments
197    ///
198    /// * `from` - The absolute starting byte offset.
199    /// * `to` - The absolute ending byte offset (exclusive).
200    ///
201    /// # Returns
202    ///
203    /// A byte slice `&[u8]` corresponding to the requested range.
204    #[inline]
205    #[must_use]
206    pub fn slice_in_range(&self, from: u32, to: u32) -> &'src [u8] {
207        let base_offset = self.starting_position.offset;
208
209        // Calculate the start and end positions relative to the local `bytes` slice.
210        // `saturating_sub` prevents underflow if `from`/`to` are smaller than `base_offset`.
211        let local_from = from.saturating_sub(base_offset) as usize;
212        let local_to = to.saturating_sub(base_offset) as usize;
213
214        // Clamp the local indices to the actual length of the `bytes` slice to prevent panics.
215        let start = local_from.min(self.length);
216        let end = local_to.min(self.length);
217
218        // Ensure the start index is not greater than the end index.
219        if start >= end {
220            return &[];
221        }
222
223        // If the start index is beyond the length of the input, return an empty slice.
224        if start >= self.length {
225            return &[];
226        }
227
228        &self.bytes[start..end]
229    }
230
231    /// Advances the current position by one character, updating line and column numbers.
232    ///
233    /// Handles different line endings (`\n`, `\r`, `\r\n`) and updates line and column counters accordingly.
234    ///
235    /// If the end of input is reached, no action is taken.
236    #[inline(always)]
237    pub fn next(&mut self) {
238        if self.offset < self.length {
239            self.offset += 1;
240        }
241    }
242
243    /// Skips the next `count` characters, advancing the position accordingly.
244    ///
245    /// Updates offset by `count`, clamping to the input length.
246    ///
247    /// # Arguments
248    ///
249    /// * `count` - The number of characters to skip.
250    #[inline]
251    pub fn skip(&mut self, count: usize) {
252        self.offset = (self.offset + count).min(self.length);
253    }
254
255    /// Consumes the next `count` characters and returns them as a slice.
256    ///
257    /// Advances the position by `count` characters.
258    ///
259    /// # Arguments
260    ///
261    /// * `count` - The number of characters to consume.
262    ///
263    /// # Returns
264    ///
265    /// A byte slice containing the consumed characters.
266    #[inline(always)]
267    pub fn consume(&mut self, count: usize) -> &'src [u8] {
268        let from = self.offset;
269        let until = (from + count).min(self.length);
270        self.offset = until;
271        // SAFETY: from <= until <= self.length is guaranteed
272        unsafe { self.bytes.get_unchecked(from..until) }
273    }
274
275    /// Consumes all remaining characters from the current position to the end of input.
276    ///
277    /// Advances the position to EOF.
278    ///
279    /// # Returns
280    ///
281    /// A byte slice containing the remaining characters.
282    #[inline]
283    pub fn consume_remaining(&mut self) -> &'src [u8] {
284        if self.has_reached_eof() {
285            return &[];
286        }
287
288        let from = self.offset;
289        self.offset = self.length;
290
291        &self.bytes[from..]
292    }
293
294    /// Consumes characters until the given byte slice is found.
295    ///
296    /// Advances the position to the start of the search slice if found,
297    /// or to EOF if not found.
298    ///
299    /// # Arguments
300    ///
301    /// * `search` - The byte slice to search for.
302    /// * `ignore_ascii_case` - Whether to ignore ASCII case when comparing characters.
303    ///
304    /// # Returns
305    ///
306    /// A byte slice containing the consumed characters.
307    #[inline]
308    pub fn consume_until(&mut self, search: &[u8], ignore_ascii_case: bool) -> &'src [u8] {
309        let start = self.offset;
310        if ignore_ascii_case {
311            while !self.has_reached_eof() && !self.is_at(search, ignore_ascii_case) {
312                self.offset += 1;
313            }
314
315            &self.bytes[start..self.offset]
316        } else {
317            // For a single-byte search, use memchr.
318            if search.len() == 1 {
319                if let Some(pos) = memchr(search[0], &self.bytes[self.offset..]) {
320                    self.offset += pos;
321                    &self.bytes[start..self.offset]
322                } else {
323                    self.offset = self.length;
324                    &self.bytes[start..self.length]
325                }
326            } else if let Some(pos) = find(&self.bytes[self.offset..], search) {
327                self.offset += pos;
328                &self.bytes[start..self.offset]
329            } else {
330                self.offset = self.length;
331                &self.bytes[start..self.length]
332            }
333        }
334    }
335
336    #[inline]
337    pub fn consume_through(&mut self, search: u8) -> &'src [u8] {
338        let start = self.offset;
339        if let Some(pos) = memchr::memchr(search, &self.bytes[self.offset..]) {
340            self.offset += pos + 1;
341
342            &self.bytes[start..self.offset]
343        } else {
344            self.offset = self.length;
345
346            &self.bytes[start..self.length]
347        }
348    }
349
350    /// Consumes whitespaces until a non-whitespace character is found.
351    ///
352    /// # Returns
353    ///
354    /// A byte slice containing the consumed whitespaces.
355    #[inline(always)]
356    pub fn consume_whitespaces(&mut self) -> &'src [u8] {
357        let start = self.offset;
358        let bytes = self.bytes;
359        let len = self.length;
360
361        while self.offset < len {
362            // SAFETY: `self.offset < len` was just checked, so the index is in bounds.
363            let b = unsafe { *bytes.get_unchecked(self.offset) };
364            if !WHITESPACE_TABLE[b as usize] {
365                break;
366            }
367            self.offset += 1;
368        }
369
370        // SAFETY: `start` and `self.offset` are both in `0..=len`, and `start <= self.offset` because
371        // the loop only ever increments `self.offset` from its initial `start` value.
372        unsafe { bytes.get_unchecked(start..self.offset) }
373    }
374
375    /// Scans identifier characters starting at `offset_from_current` without consuming them.
376    /// Returns the length of identifier characters found (not including any trailing `\`).
377    /// Also returns whether the identifier ends with `\` followed by an identifier start character.
378    ///
379    /// This is optimized for the common case of scanning simple identifiers.
380    #[inline(always)]
381    #[must_use]
382    pub fn scan_identifier(&self, offset_from_current: usize) -> (usize, bool) {
383        let start = self.offset + offset_from_current;
384        if start >= self.length {
385            return (0, false);
386        }
387
388        let bytes = self.bytes;
389        let len = self.length;
390        let mut pos = start + 1; // Skip first byte (already validated by caller)
391
392        while pos < len {
393            // SAFETY: `pos < len` was just checked.
394            let b = unsafe { *bytes.get_unchecked(pos) };
395            if IDENT_PART_TABLE[b as usize] {
396                pos += 1;
397            } else if b == b'\\' && pos + 1 < len {
398                // SAFETY: `pos + 1 < len` was just checked in the `else if` guard.
399                let next = unsafe { *bytes.get_unchecked(pos + 1) };
400                if IDENT_START_TABLE[next as usize] {
401                    // Found \ followed by identifier start
402                    return (pos - start, true);
403                }
404                break;
405            } else {
406                break;
407            }
408        }
409
410        (pos - start, false)
411    }
412
413    /// Reads the next `n` characters without advancing the position.
414    ///
415    /// # Arguments
416    ///
417    /// * `n` - The number of characters to read.
418    ///
419    /// # Returns
420    ///
421    /// A byte slice containing the next `n` characters.
422    #[inline(always)]
423    #[must_use]
424    pub fn read(&self, n: usize) -> &'src [u8] {
425        let from = self.offset;
426        let until = (from + n).min(self.length);
427        // SAFETY: from <= until <= self.length is guaranteed by min()
428        unsafe { self.bytes.get_unchecked(from..until) }
429    }
430
431    /// Reads all remaining characters from the current position to the end of input,
432    /// without advancing the position.
433    #[inline(always)]
434    #[must_use]
435    pub fn read_remaining(&self) -> &'src [u8] {
436        let from = self.offset;
437
438        // SAFETY: `from = self.offset` is always in `0..=self.length`, so the open-ended range is in bounds.
439        unsafe { self.bytes.get_unchecked(from..) }
440    }
441
442    /// Reads a single byte at a specific byte offset within the input slice,
443    /// without advancing the internal cursor.
444    ///
445    /// This provides direct, low-level access to the underlying byte data.
446    ///
447    /// # Arguments
448    ///
449    /// * `at` - The zero-based byte offset within the input slice (`self.bytes`)
450    ///   from which to read the byte.
451    ///
452    /// # Returns
453    ///
454    /// A reference to the byte located at the specified offset `at`.
455    ///
456    /// # Panics
457    ///
458    /// This method **panics** if the provided `at` offset is out of bounds
459    /// for the input byte slice (i.e., if `at >= self.bytes.len()`).
460    #[must_use]
461    pub fn read_at(&self, at: usize) -> &'src u8 {
462        &self.bytes[at]
463    }
464
465    /// Reads a single byte at a specific byte offset within the input slice,
466    /// without advancing the internal cursor.
467    ///
468    /// # Safety
469    ///
470    /// The caller must ensure that `at` is a valid index within the bounds of the input slice
471    /// (i.e., `at < self.bytes.len()`). Failing to do so results in undefined behavior.
472    #[inline(always)]
473    #[must_use]
474    pub unsafe fn read_at_unchecked(&self, at: usize) -> &'src u8 {
475        // SAFETY: Caller must ensure at < self.length
476        unsafe { self.bytes.get_unchecked(at) }
477    }
478
479    /// Checks if the input at the current position matches the given byte slice.
480    ///
481    /// # Arguments
482    ///
483    /// * `search` - The byte slice to compare against the input.
484    /// * `ignore_ascii_case` - Whether to ignore ASCII case when comparing.
485    ///
486    /// # Returns
487    ///
488    /// `true` if the next bytes match `search`; `false` otherwise.
489    #[inline(always)]
490    #[must_use]
491    pub fn is_at(&self, search: &[u8], ignore_ascii_case: bool) -> bool {
492        let len = search.len();
493        let from = self.offset;
494        let until = (from + len).min(self.length);
495
496        if until - from != len {
497            return false;
498        }
499
500        let slice = &self.bytes[from..until];
501        if ignore_ascii_case { slice.eq_ignore_ascii_case(search) } else { slice == search }
502    }
503
504    /// Attempts to match the given byte sequence at the current position, ignoring whitespace in the input.
505    ///
506    /// This method tries to match the provided byte slice `search` against the input starting
507    /// from the current position, possibly ignoring ASCII case. Whitespace characters in the input
508    /// are skipped during matching, but their length is included in the returned length.
509    ///
510    /// Importantly, the method **does not include** any trailing whitespace **after** the matched sequence
511    /// in the returned length.
512    ///
513    /// For example, to match the sequence `(string)`, the input could be `(string)`, `( string )`, `(  string )`, etc.,
514    /// and this method would return the total length of the input consumed to match `(string)`,
515    /// including any whitespace within the matched sequence, but **excluding** any whitespace after it.
516    ///
517    /// # Arguments
518    ///
519    /// * `search` - The byte slice to match against the input.
520    /// * `ignore_ascii_case` - If `true`, ASCII case is ignored during comparison.
521    ///
522    /// # Returns
523    ///
524    /// * `Some(length)` - If the input matches `search` (ignoring whitespace within the sequence), returns the total length
525    ///   of the input consumed to match `search`, including any skipped whitespace **within** the matched sequence.
526    /// * `None` - If the input does not match `search`.
527    #[inline]
528    #[must_use]
529    pub const fn match_sequence_ignore_whitespace(&self, search: &[u8], ignore_ascii_case: bool) -> Option<usize> {
530        let mut offset = self.offset;
531        let mut search_offset = 0;
532        let mut length = 0;
533        let bytes = self.bytes;
534        let total = self.length;
535        while search_offset < search.len() {
536            // Skip whitespace in the input.
537            while offset < total && bytes[offset].is_ascii_whitespace() {
538                offset += 1;
539                length += 1;
540            }
541
542            if offset >= total {
543                return None;
544            }
545
546            let input_byte = bytes[offset];
547            let search_byte = search[search_offset];
548            let matched = if ignore_ascii_case {
549                input_byte.eq_ignore_ascii_case(&search_byte)
550            } else {
551                input_byte == search_byte
552            };
553
554            if matched {
555                offset += 1;
556                length += 1;
557                search_offset += 1;
558            } else {
559                return None;
560            }
561        }
562
563        Some(length)
564    }
565
566    /// Peeks ahead `i` characters and reads the next `n` characters without advancing the position.
567    ///
568    /// # Arguments
569    ///
570    /// * `offset` - The number of characters to skip before reading.
571    /// * `n` - The number of characters to read after skipping.
572    ///
573    /// # Returns
574    ///
575    /// A byte slice containing the peeked characters.
576    #[inline(always)]
577    #[must_use]
578    pub fn peek(&self, offset: usize, n: usize) -> &'src [u8] {
579        let from = self.offset + offset;
580        let len = self.length;
581        if from >= len {
582            return &[];
583        }
584
585        let until = (from + n).min(len);
586        // SAFETY: We verified from < len and until <= len above
587        unsafe { self.bytes.get_unchecked(from..until) }
588    }
589}
590
591impl HasFileId for Input<'_> {
592    fn file_id(&self) -> FileId {
593        self.file_id
594    }
595}
596
#[cfg(test)]
mod tests {
    use mago_span::Position;

    use super::*;

    /// A fresh input starts at position 0 and exposes the bytes it was given.
    #[test]
    fn test_new() {
        let bytes = b"Hello, world!";
        let input = Input::new(FileId::zero(), bytes);

        assert_eq!(input.current_position(), Position::new(0));
        assert_eq!(input.length, bytes.len());
        assert_eq!(input.bytes, bytes);
    }

    /// EOF is reported for empty input and once every byte has been skipped.
    #[test]
    fn test_is_eof() {
        let bytes = b"";
        let input = Input::new(FileId::zero(), bytes);

        assert!(input.has_reached_eof());

        let bytes = b"data";
        let mut input = Input::new(FileId::zero(), bytes);

        assert!(!input.has_reached_eof());

        input.skip(4);

        assert!(input.has_reached_eof());
    }

    /// `next` moves exactly one byte per call; line endings get no special treatment,
    /// so `\r\n` takes two calls.
    #[test]
    fn test_next() {
        let bytes = b"a\nb\r\nc\rd";
        let mut input = Input::new(FileId::zero(), bytes);

        // 'a'
        input.next();
        assert_eq!(input.current_position(), Position::new(1));

        // '\n'
        input.next();
        assert_eq!(input.current_position(), Position::new(2));

        // 'b'
        input.next();
        assert_eq!(input.current_position(), Position::new(3));

        // '\r' (first half of "\r\n")
        input.next();
        assert_eq!(input.current_position(), Position::new(4));

        // '\n' (second half of "\r\n")
        input.next();
        assert_eq!(input.current_position(), Position::new(5));

        // 'c'
        input.next();
        assert_eq!(input.current_position(), Position::new(6));

        // '\r'
        input.next();
        assert_eq!(input.current_position(), Position::new(7));

        // 'd'
        input.next();
        assert_eq!(input.current_position(), Position::new(8));
        assert!(input.has_reached_eof());

        // At EOF, `next` is a no-op.
        input.next();
        assert_eq!(input.current_position(), Position::new(8));
    }

    /// `consume` returns the requested bytes, advances the cursor, and yields nothing at EOF.
    #[test]
    fn test_consume() {
        let bytes = b"abcdef";
        let mut input = Input::new(FileId::zero(), bytes);

        let consumed = input.consume(3);
        assert_eq!(consumed, b"abc");
        assert_eq!(input.current_position(), Position::new(3));

        let consumed = input.consume(3);
        assert_eq!(consumed, b"def");
        assert_eq!(input.current_position(), Position::new(6));

        let consumed = input.consume(1); // Should return empty slice at EOF
        assert_eq!(consumed, b"");
        assert!(input.has_reached_eof());
    }

    /// `consume_remaining` returns everything after the cursor and leaves it at EOF.
    #[test]
    fn test_consume_remaining() {
        let bytes = b"abcdef";
        let mut input = Input::new(FileId::zero(), bytes);

        input.skip(2);
        let remaining = input.consume_remaining();
        assert_eq!(remaining, b"cdef");
        assert!(input.has_reached_eof());
    }

    /// `read` returns bytes at the cursor without moving it.
    #[test]
    fn test_read() {
        let bytes = b"abcdef";
        let input = Input::new(FileId::zero(), bytes);

        let read = input.read(3);
        assert_eq!(read, b"abc");
        // Position should not change
        assert_eq!(input.current_position(), Position::new(0));
    }

    /// `is_at` compares against the bytes at the cursor, case-sensitively.
    #[test]
    fn test_is_at() {
        let bytes = b"abcdef";
        let mut input = Input::new(FileId::zero(), bytes);

        assert!(input.is_at(b"abc", false));
        input.skip(2);
        assert!(input.is_at(b"cde", false));
        assert!(!input.is_at(b"xyz", false));
    }

    /// `is_at` can ignore ASCII case.
    #[test]
    fn test_is_at_ignore_ascii_case() {
        let bytes = b"AbCdEf";
        let mut input = Input::new(FileId::zero(), bytes);

        assert!(input.is_at(b"abc", true));
        input.skip(2);
        assert!(input.is_at(b"cde", true));
        assert!(!input.is_at(b"xyz", true));
    }

    /// `peek` reads ahead of the cursor without moving it.
    #[test]
    fn test_peek() {
        let bytes = b"abcdef";
        let input = Input::new(FileId::zero(), bytes);

        let peeked = input.peek(2, 3);
        assert_eq!(peeked, b"cde");
        // Position should not change
        assert_eq!(input.current_position(), Position::new(0));
    }
}