Skip to main content

mago_syntax_core/
input.rs

1use bumpalo::Bump;
2use memchr::memchr;
3use memchr::memmem::find;
4
5use mago_database::file::File;
6use mago_database::file::FileId;
7use mago_database::file::HasFileId;
8use mago_span::Position;
9
/// Byte-indexed lookup table marking the ASCII whitespace bytes recognised by the lexer:
/// space, horizontal tab (0x09), line feed (0x0A), vertical tab (0x0B),
/// form feed (0x0C) and carriage return (0x0D).
const WHITESPACE_TABLE: [bool; 256] = {
    const MEMBERS: [u8; 6] = [b' ', b'\t', b'\n', 0x0B, 0x0C, b'\r'];

    let mut lookup = [false; 256];
    let mut idx = 0;
    while idx < MEMBERS.len() {
        lookup[MEMBERS[idx] as usize] = true;
        idx += 1;
    }
    lookup
};
21
/// Byte-indexed lookup table for bytes that may continue an identifier:
/// ASCII letters, ASCII digits, `_`, and every byte in `0x80..=0xFF`
/// (so multi-byte UTF-8 sequences are accepted wholesale).
const IDENT_PART_TABLE: [bool; 256] = {
    let mut lookup = [false; 256];
    let mut index: usize = 0;
    while index < lookup.len() {
        let byte = index as u8;
        lookup[index] = byte.is_ascii_alphanumeric() || byte == b'_' || byte >= 0x80;
        index += 1;
    }
    lookup
};
32
/// Byte-indexed lookup table for bytes that may begin an identifier:
/// ASCII letters, `_`, and every byte in `0x80..=0xFF`. Digits are excluded.
const IDENT_START_TABLE: [bool; 256] = {
    let mut lookup = [false; 256];
    let mut index: usize = 0;
    while index < lookup.len() {
        let byte = index as u8;
        lookup[index] = byte.is_ascii_alphabetic() || byte == b'_' || byte >= 0x80;
        index += 1;
    }
    lookup
};
43
/// A struct representing the input code being lexed.
///
/// The `Input` struct provides methods to read, peek, consume, and skip bytes
/// of the input code while keeping track of a byte cursor (`offset`).
///
/// No line or column information is stored. Absolute positions are derived on demand
/// by adding the cursor to `starting_position` (see `current_position`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub struct Input<'a> {
    /// The bytes being lexed. This may be a sub-slice of a larger source file.
    pub(crate) bytes: &'a [u8],
    /// Cached `bytes.len()`; set by the constructors.
    pub(crate) length: usize,
    /// Cursor into `bytes`, relative to the start of the slice (not of the file).
    /// The methods in this file keep it within `0..=length`; several `unsafe`
    /// reads rely on that.
    pub(crate) offset: usize,
    /// Absolute position in the original file at which `bytes[0]` sits.
    pub(crate) starting_position: Position,
    /// Identifier of the source file this input belongs to.
    pub(crate) file_id: FileId,
}
56
57impl<'a> Input<'a> {
58    /// Creates a new `Input` instance from the given input.
59    ///
60    /// # Arguments
61    ///
62    /// * `file_id` - The unique identifier for the source file this input belongs to.
63    /// * `bytes` - A byte slice representing the input code to be lexed.
64    ///
65    /// # Returns
66    ///
67    /// A new `Input` instance initialized at the beginning of the input.
68    #[must_use]
69    pub fn new(file_id: FileId, bytes: &'a [u8]) -> Self {
70        let length = bytes.len();
71
72        Self { bytes, length, offset: 0, file_id, starting_position: Position::new(0) }
73    }
74
75    /// Creates a new `Input` instance from the contents of a `File`.
76    ///
77    /// # Arguments
78    ///
79    /// * `file` - A reference to the `File` containing the source code.
80    ///
81    /// # Returns
82    ///
83    /// A new `Input` instance initialized with the file's ID and contents.
84    #[must_use]
85    pub fn from_file(file: &'a File) -> Self {
86        Self::new(file.id, file.contents.as_bytes())
87    }
88
89    /// Creates a new `Input` instance from the contents of a `File`.
90    ///
91    /// # Arguments
92    ///
93    /// * `file` - A reference to the `File` containing the source code.
94    ///
95    /// # Returns
96    ///
97    /// A new `Input` instance initialized with the file's ID and contents.
98    pub fn from_file_in(arena: &'a Bump, file: &File) -> Self {
99        Self::new(file.id, arena.alloc_slice_clone(file.contents.as_bytes()))
100    }
101
102    /// Creates a new `Input` instance representing a byte slice that is
103    /// "anchored" at a specific absolute position within a larger source file.
104    ///
105    /// This is useful when lexing a subset (slice) of a source file, as it allows
106    /// generated tokens to retain accurate absolute positions and spans relative
107    /// to the original file.
108    ///
109    /// The internal cursor (`offset`) starts at 0 relative to the `bytes` slice,
110    /// but the absolute position is calculated relative to the `anchor_position`.
111    ///
112    /// # Arguments
113    ///
114    /// * `file_id` - The unique identifier for the source file this input belongs to.
115    /// * `bytes` - A byte slice representing the input code subset to be lexed.
116    /// * `anchor_position` - The absolute `Position` in the original source file where the provided `bytes` slice begins.
117    ///
118    /// # Returns
119    ///
120    /// A new `Input` instance ready to lex the `bytes`, maintaining positions
121    /// relative to `anchor_position`.
122    #[must_use]
123    pub fn anchored_at(file_id: FileId, bytes: &'a [u8], anchor_position: Position) -> Self {
124        let length = bytes.len();
125
126        Self { bytes, length, offset: 0, file_id, starting_position: anchor_position }
127    }
128
129    /// Returns the source file identifier of the input code.
130    #[inline]
131    #[must_use]
132    pub const fn file_id(&self) -> FileId {
133        self.file_id
134    }
135
136    /// Returns the absolute current `Position` of the lexer within the original source file.
137    ///
138    /// It calculates this by adding the internal offset (progress within the current byte slice)
139    /// to the `starting_position` the `Input` was initialized with.
140    #[inline]
141    #[must_use]
142    pub const fn current_position(&self) -> Position {
143        // Calculate absolute position by adding internal offset to the starting base
144        self.starting_position.forward(self.offset as u32)
145    }
146
147    /// Returns the current internal byte offset relative to the start of the input slice.
148    ///
149    /// This indicates how many bytes have been consumed from the current `bytes` slice.
150    /// To get the absolute position in the original source file, use `current_position()`.
151    #[inline]
152    #[must_use]
153    pub const fn current_offset(&self) -> usize {
154        self.offset
155    }
156
157    /// Returns `true` if the input slice is empty (length is zero).
158    #[inline]
159    #[must_use]
160    pub const fn is_empty(&self) -> bool {
161        self.length == 0
162    }
163
164    /// Returns the total length in bytes of the input slice being processed.
165    #[inline]
166    #[must_use]
167    pub const fn len(&self) -> usize {
168        self.length
169    }
170
171    /// Checks if the current position is at the end of the input.
172    ///
173    /// # Returns
174    ///
175    /// `true` if the current offset is greater than or equal to the input length; `false` otherwise.
176    #[inline(always)]
177    #[must_use]
178    pub const fn has_reached_eof(&self) -> bool {
179        self.offset >= self.length
180    }
181
182    /// Returns a byte slice within a specified absolute range.
183    ///
184    /// The `from` and `to` arguments are absolute byte offsets from the beginning
185    /// of the original source file. The method calculates the correct slice
186    /// relative to the `starting_position` of this `Input`.
187    ///
188    /// This is useful for retrieving the raw text of a `Span` or `Token` whose
189    /// positions are absolute, even when the `Input` only contains a subsection
190    /// of the source file.
191    ///
192    /// The returned slice is defensively clamped to the bounds of the current
193    /// `Input`'s byte slice to prevent panics.
194    ///
195    /// # Arguments
196    ///
197    /// * `from` - The absolute starting byte offset.
198    /// * `to` - The absolute ending byte offset (exclusive).
199    ///
200    /// # Returns
201    ///
202    /// A byte slice `&[u8]` corresponding to the requested range.
203    #[inline]
204    #[must_use]
205    pub fn slice_in_range(&self, from: u32, to: u32) -> &'a [u8] {
206        let base_offset = self.starting_position.offset;
207
208        // Calculate the start and end positions relative to the local `bytes` slice.
209        // `saturating_sub` prevents underflow if `from`/`to` are smaller than `base_offset`.
210        let local_from = from.saturating_sub(base_offset) as usize;
211        let local_to = to.saturating_sub(base_offset) as usize;
212
213        // Clamp the local indices to the actual length of the `bytes` slice to prevent panics.
214        let start = local_from.min(self.length);
215        let end = local_to.min(self.length);
216
217        // Ensure the start index is not greater than the end index.
218        if start >= end {
219            return &[];
220        }
221
222        // If the start index is beyond the length of the input, return an empty slice.
223        if start >= self.length {
224            return &[];
225        }
226
227        &self.bytes[start..end]
228    }
229
230    /// Advances the current position by one character, updating line and column numbers.
231    ///
232    /// Handles different line endings (`\n`, `\r`, `\r\n`) and updates line and column counters accordingly.
233    ///
234    /// If the end of input is reached, no action is taken.
235    #[inline(always)]
236    pub fn next(&mut self) {
237        if self.offset < self.length {
238            self.offset += 1;
239        }
240    }
241
242    /// Skips the next `count` characters, advancing the position accordingly.
243    ///
244    /// Updates offset by `count`, clamping to the input length.
245    ///
246    /// # Arguments
247    ///
248    /// * `count` - The number of characters to skip.
249    #[inline]
250    pub fn skip(&mut self, count: usize) {
251        self.offset = (self.offset + count).min(self.length);
252    }
253
254    /// Consumes the next `count` characters and returns them as a slice.
255    ///
256    /// Advances the position by `count` characters.
257    ///
258    /// # Arguments
259    ///
260    /// * `count` - The number of characters to consume.
261    ///
262    /// # Returns
263    ///
264    /// A byte slice containing the consumed characters.
265    #[inline(always)]
266    pub fn consume(&mut self, count: usize) -> &'a [u8] {
267        let from = self.offset;
268        let until = (from + count).min(self.length);
269        self.offset = until;
270        // SAFETY: from <= until <= self.length is guaranteed
271        unsafe { self.bytes.get_unchecked(from..until) }
272    }
273
274    /// Consumes all remaining characters from the current position to the end of input.
275    ///
276    /// Advances the position to EOF.
277    ///
278    /// # Returns
279    ///
280    /// A byte slice containing the remaining characters.
281    #[inline]
282    pub fn consume_remaining(&mut self) -> &'a [u8] {
283        if self.has_reached_eof() {
284            return &[];
285        }
286
287        let from = self.offset;
288        self.offset = self.length;
289
290        &self.bytes[from..]
291    }
292
293    /// Consumes characters until the given byte slice is found.
294    ///
295    /// Advances the position to the start of the search slice if found,
296    /// or to EOF if not found.
297    ///
298    /// # Arguments
299    ///
300    /// * `search` - The byte slice to search for.
301    /// * `ignore_ascii_case` - Whether to ignore ASCII case when comparing characters.
302    ///
303    /// # Returns
304    ///
305    /// A byte slice containing the consumed characters.
306    #[inline]
307    pub fn consume_until(&mut self, search: &[u8], ignore_ascii_case: bool) -> &'a [u8] {
308        let start = self.offset;
309        if ignore_ascii_case {
310            while !self.has_reached_eof() && !self.is_at(search, ignore_ascii_case) {
311                self.offset += 1;
312            }
313
314            &self.bytes[start..self.offset]
315        } else {
316            // For a single-byte search, use memchr.
317            if search.len() == 1 {
318                if let Some(pos) = memchr(search[0], &self.bytes[self.offset..]) {
319                    self.offset += pos;
320                    &self.bytes[start..self.offset]
321                } else {
322                    self.offset = self.length;
323                    &self.bytes[start..self.length]
324                }
325            } else if let Some(pos) = find(&self.bytes[self.offset..], search) {
326                self.offset += pos;
327                &self.bytes[start..self.offset]
328            } else {
329                self.offset = self.length;
330                &self.bytes[start..self.length]
331            }
332        }
333    }
334
335    #[inline]
336    pub fn consume_through(&mut self, search: u8) -> &'a [u8] {
337        let start = self.offset;
338        if let Some(pos) = memchr::memchr(search, &self.bytes[self.offset..]) {
339            self.offset += pos + 1;
340
341            &self.bytes[start..self.offset]
342        } else {
343            self.offset = self.length;
344
345            &self.bytes[start..self.length]
346        }
347    }
348
349    /// Consumes whitespaces until a non-whitespace character is found.
350    ///
351    /// # Returns
352    ///
353    /// A byte slice containing the consumed whitespaces.
354    #[inline(always)]
355    pub fn consume_whitespaces(&mut self) -> &'a [u8] {
356        let start = self.offset;
357        let bytes = self.bytes;
358        let len = self.length;
359
360        // SAFETY: offset < len is checked in condition, table lookup is always valid
361        while self.offset < len {
362            let b = unsafe { *bytes.get_unchecked(self.offset) };
363            if !WHITESPACE_TABLE[b as usize] {
364                break;
365            }
366            self.offset += 1;
367        }
368
369        // SAFETY: start <= self.offset <= len
370        unsafe { bytes.get_unchecked(start..self.offset) }
371    }
372
373    /// Scans identifier characters starting at `offset_from_current` without consuming them.
374    /// Returns the length of identifier characters found (not including any trailing `\`).
375    /// Also returns whether the identifier ends with `\` followed by an identifier start character.
376    ///
377    /// This is optimized for the common case of scanning simple identifiers.
378    #[inline(always)]
379    pub fn scan_identifier(&self, offset_from_current: usize) -> (usize, bool) {
380        let start = self.offset + offset_from_current;
381        if start >= self.length {
382            return (0, false);
383        }
384
385        let bytes = self.bytes;
386        let len = self.length;
387        let mut pos = start + 1; // Skip first byte (already validated by caller)
388
389        // SAFETY: pos < len checked in condition, table lookups are always valid
390        while pos < len {
391            let b = unsafe { *bytes.get_unchecked(pos) };
392            if IDENT_PART_TABLE[b as usize] {
393                pos += 1;
394            } else if b == b'\\' && pos + 1 < len {
395                let next = unsafe { *bytes.get_unchecked(pos + 1) };
396                if IDENT_START_TABLE[next as usize] {
397                    // Found \ followed by identifier start
398                    return (pos - start, true);
399                }
400                break;
401            } else {
402                break;
403            }
404        }
405
406        (pos - start, false)
407    }
408
409    /// Reads the next `n` characters without advancing the position.
410    ///
411    /// # Arguments
412    ///
413    /// * `n` - The number of characters to read.
414    ///
415    /// # Returns
416    ///
417    /// A byte slice containing the next `n` characters.
418    #[inline(always)]
419    #[must_use]
420    pub fn read(&self, n: usize) -> &'a [u8] {
421        let from = self.offset;
422        let until = (from + n).min(self.length);
423        // SAFETY: from <= until <= self.length is guaranteed by min()
424        unsafe { self.bytes.get_unchecked(from..until) }
425    }
426
427    /// Reads all remaining characters from the current position to the end of input,
428    /// without advancing the position.
429    #[inline(always)]
430    pub fn read_remaining(&self) -> &'a [u8] {
431        let from = self.offset;
432
433        unsafe { self.bytes.get_unchecked(from..) }
434    }
435
436    /// Reads a single byte at a specific byte offset within the input slice,
437    /// without advancing the internal cursor.
438    ///
439    /// This provides direct, low-level access to the underlying byte data.
440    ///
441    /// # Arguments
442    ///
443    /// * `at` - The zero-based byte offset within the input slice (`self.bytes`)
444    ///   from which to read the byte.
445    ///
446    /// # Returns
447    ///
448    /// A reference to the byte located at the specified offset `at`.
449    ///
450    /// # Panics
451    ///
452    /// This method **panics** if the provided `at` offset is out of bounds
453    /// for the input byte slice (i.e., if `at >= self.bytes.len()`).
454    #[must_use]
455    pub fn read_at(&self, at: usize) -> &'a u8 {
456        &self.bytes[at]
457    }
458
459    /// Reads a single byte at a specific byte offset within the input slice,
460    /// without advancing the internal cursor.
461    ///
462    /// # Safety
463    ///
464    /// The caller must ensure that `at` is a valid index within the bounds of the input slice
465    /// (i.e., `at < self.bytes.len()`). Failing to do so results in undefined behavior.
466    #[inline(always)]
467    pub unsafe fn read_at_unchecked(&self, at: usize) -> &'a u8 {
468        // SAFETY: Caller must ensure at < self.length
469        unsafe { self.bytes.get_unchecked(at) }
470    }
471
472    /// Checks if the input at the current position matches the given byte slice.
473    ///
474    /// # Arguments
475    ///
476    /// * `search` - The byte slice to compare against the input.
477    /// * `ignore_ascii_case` - Whether to ignore ASCII case when comparing.
478    ///
479    /// # Returns
480    ///
481    /// `true` if the next bytes match `search`; `false` otherwise.
482    #[inline(always)]
483    #[must_use]
484    pub fn is_at(&self, search: &[u8], ignore_ascii_case: bool) -> bool {
485        let len = search.len();
486        let from = self.offset;
487        let until = (from + len).min(self.length);
488
489        if until - from != len {
490            return false;
491        }
492
493        let slice = &self.bytes[from..until];
494        if ignore_ascii_case { slice.eq_ignore_ascii_case(search) } else { slice == search }
495    }
496
497    /// Attempts to match the given byte sequence at the current position, ignoring whitespace in the input.
498    ///
499    /// This method tries to match the provided byte slice `search` against the input starting
500    /// from the current position, possibly ignoring ASCII case. Whitespace characters in the input
501    /// are skipped during matching, but their length is included in the returned length.
502    ///
503    /// Importantly, the method **does not include** any trailing whitespace **after** the matched sequence
504    /// in the returned length.
505    ///
506    /// For example, to match the sequence `(string)`, the input could be `(string)`, `( string )`, `(  string )`, etc.,
507    /// and this method would return the total length of the input consumed to match `(string)`,
508    /// including any whitespace within the matched sequence, but **excluding** any whitespace after it.
509    ///
510    /// # Arguments
511    ///
512    /// * `search` - The byte slice to match against the input.
513    /// * `ignore_ascii_case` - If `true`, ASCII case is ignored during comparison.
514    ///
515    /// # Returns
516    ///
517    /// * `Some(length)` - If the input matches `search` (ignoring whitespace within the sequence), returns the total length
518    ///   of the input consumed to match `search`, including any skipped whitespace **within** the matched sequence.
519    /// * `None` - If the input does not match `search`.
520    #[inline]
521    #[must_use]
522    pub const fn match_sequence_ignore_whitespace(&self, search: &[u8], ignore_ascii_case: bool) -> Option<usize> {
523        let mut offset = self.offset;
524        let mut search_offset = 0;
525        let mut length = 0;
526        let bytes = self.bytes;
527        let total = self.length;
528        while search_offset < search.len() {
529            // Skip whitespace in the input.
530            while offset < total && bytes[offset].is_ascii_whitespace() {
531                offset += 1;
532                length += 1;
533            }
534
535            if offset >= total {
536                return None;
537            }
538
539            let input_byte = bytes[offset];
540            let search_byte = search[search_offset];
541            let matched = if ignore_ascii_case {
542                input_byte.eq_ignore_ascii_case(&search_byte)
543            } else {
544                input_byte == search_byte
545            };
546
547            if matched {
548                offset += 1;
549                length += 1;
550                search_offset += 1;
551            } else {
552                return None;
553            }
554        }
555
556        Some(length)
557    }
558
559    /// Peeks ahead `i` characters and reads the next `n` characters without advancing the position.
560    ///
561    /// # Arguments
562    ///
563    /// * `offset` - The number of characters to skip before reading.
564    /// * `n` - The number of characters to read after skipping.
565    ///
566    /// # Returns
567    ///
568    /// A byte slice containing the peeked characters.
569    #[inline(always)]
570    #[must_use]
571    pub fn peek(&self, offset: usize, n: usize) -> &'a [u8] {
572        let from = self.offset + offset;
573        let len = self.length;
574        if from >= len {
575            return &[];
576        }
577        let until = (from + n).min(len);
578        // SAFETY: We verified from < len and until <= len above
579        unsafe { self.bytes.get_unchecked(from..until) }
580    }
581}
582
583impl HasFileId for Input<'_> {
584    fn file_id(&self) -> FileId {
585        self.file_id
586    }
587}
588
#[cfg(test)]
mod tests {
    use mago_span::Position;

    use super::*;

    #[test]
    fn test_new() {
        let bytes = b"Hello, world!";
        let input = Input::new(FileId::zero(), bytes);

        assert_eq!(input.current_position(), Position::new(0));
        assert_eq!(input.length, bytes.len());
        assert_eq!(input.bytes, bytes);
    }

    #[test]
    fn test_anchored_at() {
        let mut input = Input::anchored_at(FileId::zero(), b"abc", Position::new(10));

        assert_eq!(input.current_position(), Position::new(10));
        assert_eq!(input.current_offset(), 0);

        input.next();
        assert_eq!(input.current_position(), Position::new(11));
        assert_eq!(input.current_offset(), 1);
    }

    #[test]
    fn test_is_eof() {
        let bytes = b"";
        let input = Input::new(FileId::zero(), bytes);

        assert!(input.has_reached_eof());

        let bytes = b"data";
        let mut input = Input::new(FileId::zero(), bytes);

        assert!(!input.has_reached_eof());

        input.skip(4);

        assert!(input.has_reached_eof());
    }

    #[test]
    fn test_next() {
        let bytes = b"a\nb\r\nc\rd";
        let mut input = Input::new(FileId::zero(), bytes);

        // `next` advances exactly one byte per call; line endings get no special treatment.

        // 'a'
        input.next();
        assert_eq!(input.current_position(), Position::new(1));

        // '\n'
        input.next();
        assert_eq!(input.current_position(), Position::new(2));

        // 'b'
        input.next();
        assert_eq!(input.current_position(), Position::new(3));

        // '\r' (first byte of "\r\n")
        input.next();
        assert_eq!(input.current_position(), Position::new(4));

        // '\n' (second byte of "\r\n")
        input.next();
        assert_eq!(input.current_position(), Position::new(5));

        // 'c'
        input.next();
        assert_eq!(input.current_position(), Position::new(6));

        // '\r'
        input.next();
        assert_eq!(input.current_position(), Position::new(7));

        // 'd'
        input.next();
        assert_eq!(input.current_position(), Position::new(8));
        assert!(input.has_reached_eof());

        // Calling `next` at EOF is a no-op.
        input.next();
        assert_eq!(input.current_position(), Position::new(8));
    }

    #[test]
    fn test_skip_clamps_to_length() {
        let mut input = Input::new(FileId::zero(), b"abc");

        input.skip(100);
        assert!(input.has_reached_eof());
        assert_eq!(input.current_offset(), 3);
    }

    #[test]
    fn test_consume() {
        let bytes = b"abcdef";
        let mut input = Input::new(FileId::zero(), bytes);

        let consumed = input.consume(3);
        assert_eq!(consumed, b"abc");
        assert_eq!(input.current_position(), Position::new(3));

        let consumed = input.consume(3);
        assert_eq!(consumed, b"def");
        assert_eq!(input.current_position(), Position::new(6));

        let consumed = input.consume(1); // Should return empty slice at EOF
        assert_eq!(consumed, b"");
        assert!(input.has_reached_eof());
    }

    #[test]
    fn test_consume_remaining() {
        let bytes = b"abcdef";
        let mut input = Input::new(FileId::zero(), bytes);

        input.skip(2);
        let remaining = input.consume_remaining();
        assert_eq!(remaining, b"cdef");
        assert!(input.has_reached_eof());
    }

    #[test]
    fn test_consume_until() {
        let mut input = Input::new(FileId::zero(), b"foo END bar");
        assert_eq!(input.consume_until(b"END", false), b"foo ");
        assert!(input.is_at(b"END", false));

        let mut input = Input::new(FileId::zero(), b"abcXyZdef");
        assert_eq!(input.consume_until(b"xyz", true), b"abc");
        assert_eq!(input.current_offset(), 3);

        let mut input = Input::new(FileId::zero(), b"a;b");
        assert_eq!(input.consume_until(b";", false), b"a");
        assert_eq!(input.current_offset(), 1);

        let mut input = Input::new(FileId::zero(), b"no match here");
        assert_eq!(input.consume_until(b"zz", false), b"no match here");
        assert!(input.has_reached_eof());
    }

    #[test]
    fn test_consume_through() {
        let mut input = Input::new(FileId::zero(), b"key=value");

        assert_eq!(input.consume_through(b'='), b"key=");
        assert_eq!(input.current_offset(), 4);

        // Not found: consumes the rest of the input.
        assert_eq!(input.consume_through(b'='), b"value");
        assert!(input.has_reached_eof());
    }

    #[test]
    fn test_consume_whitespaces() {
        let mut input = Input::new(FileId::zero(), b" \t\n\x0B\x0Cx");

        assert_eq!(input.consume_whitespaces(), b" \t\n\x0B\x0C");
        assert_eq!(input.read(1), b"x");
        assert_eq!(input.consume_whitespaces(), b"");
    }

    #[test]
    fn test_scan_identifier() {
        let input = Input::new(FileId::zero(), b"foo_1 bar");
        assert_eq!(input.scan_identifier(0), (5, false));

        let input = Input::new(FileId::zero(), b"Foo\\Bar");
        assert_eq!(input.scan_identifier(0), (3, true));

        let input = Input::new(FileId::zero(), b"foo\\ ");
        assert_eq!(input.scan_identifier(0), (3, false));

        let input = Input::new(FileId::zero(), b"abc");
        assert_eq!(input.scan_identifier(100), (0, false));
    }

    #[test]
    fn test_read() {
        let bytes = b"abcdef";
        let input = Input::new(FileId::zero(), bytes);

        let read = input.read(3);
        assert_eq!(read, b"abc");
        assert_eq!(input.current_position(), Position::new(0));
        // Position should not change
    }

    #[test]
    fn test_read_remaining() {
        let mut input = Input::new(FileId::zero(), b"abcdef");

        input.skip(2);
        assert_eq!(input.read_remaining(), b"cdef");
        // Position should not change
        assert_eq!(input.current_offset(), 2);
    }

    #[test]
    fn test_slice_in_range() {
        let input = Input::anchored_at(FileId::zero(), b"hello world", Position::new(10));

        assert_eq!(input.slice_in_range(10, 15), b"hello");
        assert_eq!(input.slice_in_range(16, 21), b"world");
        // Offsets before the anchor are clamped to the anchor.
        assert_eq!(input.slice_in_range(0, 12), b"he");
        // Ends past the slice are clamped to its length.
        assert_eq!(input.slice_in_range(15, 100), b" world");
        // Empty, inverted, and fully out-of-range spans yield empty slices.
        assert_eq!(input.slice_in_range(12, 12), b"");
        assert_eq!(input.slice_in_range(15, 12), b"");
        assert_eq!(input.slice_in_range(50, 60), b"");
    }

    #[test]
    fn test_is_at() {
        let bytes = b"abcdef";
        let mut input = Input::new(FileId::zero(), bytes);

        assert!(input.is_at(b"abc", false));
        input.skip(2);
        assert!(input.is_at(b"cde", false));
        assert!(!input.is_at(b"xyz", false));
    }

    #[test]
    fn test_is_at_ignore_ascii_case() {
        let bytes = b"AbCdEf";
        let mut input = Input::new(FileId::zero(), bytes);

        assert!(input.is_at(b"abc", true));
        input.skip(2);
        assert!(input.is_at(b"cde", true));
        assert!(!input.is_at(b"xyz", true));
    }

    #[test]
    fn test_match_sequence_ignore_whitespace() {
        let input = Input::new(FileId::zero(), b"( string ) x");
        // Inner whitespace is counted, trailing whitespace is not.
        assert_eq!(input.match_sequence_ignore_whitespace(b"(string)", false), Some(10));

        let input = Input::new(FileId::zero(), b"  (int)");
        // Leading whitespace is counted as well.
        assert_eq!(input.match_sequence_ignore_whitespace(b"(int)", false), Some(7));

        let input = Input::new(FileId::zero(), b"(INT)");
        assert_eq!(input.match_sequence_ignore_whitespace(b"(int)", true), Some(5));
        assert_eq!(input.match_sequence_ignore_whitespace(b"(int)", false), None);

        let input = Input::new(FileId::zero(), b"(in");
        assert_eq!(input.match_sequence_ignore_whitespace(b"(int)", false), None);
    }

    #[test]
    fn test_peek() {
        let bytes = b"abcdef";
        let input = Input::new(FileId::zero(), bytes);

        let peeked = input.peek(2, 3);
        assert_eq!(peeked, b"cde");
        assert_eq!(input.current_position(), Position::new(0));
        // Position should not change
    }

    #[test]
    fn test_peek_clamps_to_length() {
        let input = Input::new(FileId::zero(), b"abc");

        assert_eq!(input.peek(1, 10), b"bc");
        assert_eq!(input.peek(3, 1), b"");
        assert_eq!(input.peek(10, 2), b"");
    }
}