// mago_syntax_core/input.rs

1use bumpalo::Bump;
2use mago_database::file::HasFileId;
3use memchr::memchr;
4use memchr::memmem::find;
5
6use mago_database::file::File;
7use mago_database::file::FileId;
8use mago_span::Position;
9
10/// A struct representing the input code being lexed.
11///
12/// The `Input` struct provides methods to read, peek, consume, and skip characters
13/// from the bytes input code while keeping track of the current position (line, column, offset).
14#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)]
15pub struct Input<'a> {
16    pub(crate) bytes: &'a [u8],
17    pub(crate) length: usize,
18    pub(crate) offset: usize,
19    pub(crate) starting_position: Position,
20    pub(crate) file_id: FileId,
21}
22
23impl<'a> Input<'a> {
24    /// Creates a new `Input` instance from the given input.
25    ///
26    /// # Arguments
27    ///
28    /// * `file_id` - The unique identifier for the source file this input belongs to.
29    /// * `bytes` - A byte slice representing the input code to be lexed.
30    ///
31    /// # Returns
32    ///
33    /// A new `Input` instance initialized at the beginning of the input.
34    #[must_use]
35    pub fn new(file_id: FileId, bytes: &'a [u8]) -> Self {
36        let length = bytes.len();
37
38        Self { bytes, length, offset: 0, file_id, starting_position: Position::new(0) }
39    }
40
41    /// Creates a new `Input` instance from the contents of a `File`.
42    ///
43    /// # Arguments
44    ///
45    /// * `file` - A reference to the `File` containing the source code.
46    ///
47    /// # Returns
48    ///
49    /// A new `Input` instance initialized with the file's ID and contents.
50    #[must_use]
51    pub fn from_file(file: &'a File) -> Self {
52        Self::new(file.id, file.contents.as_bytes())
53    }
54
55    /// Creates a new `Input` instance from the contents of a `File`.
56    ///
57    /// # Arguments
58    ///
59    /// * `file` - A reference to the `File` containing the source code.
60    ///
61    /// # Returns
62    ///
63    /// A new `Input` instance initialized with the file's ID and contents.
64    pub fn from_file_in(arena: &'a Bump, file: &File) -> Self {
65        Self::new(file.id, arena.alloc_slice_clone(file.contents.as_bytes()))
66    }
67
68    /// Creates a new `Input` instance representing a byte slice that is
69    /// "anchored" at a specific absolute position within a larger source file.
70    ///
71    /// This is useful when lexing a subset (slice) of a source file, as it allows
72    /// generated tokens to retain accurate absolute positions and spans relative
73    /// to the original file.
74    ///
75    /// The internal cursor (`offset`) starts at 0 relative to the `bytes` slice,
76    /// but the absolute position is calculated relative to the `anchor_position`.
77    ///
78    /// # Arguments
79    ///
80    /// * `file_id` - The unique identifier for the source file this input belongs to.
81    /// * `bytes` - A byte slice representing the input code subset to be lexed.
82    /// * `anchor_position` - The absolute `Position` in the original source file where the provided `bytes` slice begins.
83    ///
84    /// # Returns
85    ///
86    /// A new `Input` instance ready to lex the `bytes`, maintaining positions
87    /// relative to `anchor_position`.
88    #[must_use]
89    pub fn anchored_at(file_id: FileId, bytes: &'a [u8], anchor_position: Position) -> Self {
90        let length = bytes.len();
91
92        Self { bytes, length, offset: 0, file_id, starting_position: anchor_position }
93    }
94
95    /// Returns the source file identifier of the input code.
96    #[inline]
97    #[must_use]
98    pub const fn file_id(&self) -> FileId {
99        self.file_id
100    }
101
102    /// Returns the absolute current `Position` of the lexer within the original source file.
103    ///
104    /// It calculates this by adding the internal offset (progress within the current byte slice)
105    /// to the `starting_position` the `Input` was initialized with.
106    #[inline]
107    #[must_use]
108    pub const fn current_position(&self) -> Position {
109        // Calculate absolute position by adding internal offset to the starting base
110        self.starting_position.forward(self.offset as u32)
111    }
112
113    /// Returns the current internal byte offset relative to the start of the input slice.
114    ///
115    /// This indicates how many bytes have been consumed from the current `bytes` slice.
116    /// To get the absolute position in the original source file, use `current_position()`.
117    #[inline]
118    #[must_use]
119    pub const fn current_offset(&self) -> usize {
120        self.offset
121    }
122
123    /// Returns `true` if the input slice is empty (length is zero).
124    #[inline]
125    #[must_use]
126    pub const fn is_empty(&self) -> bool {
127        self.length == 0
128    }
129
130    /// Returns the total length in bytes of the input slice being processed.
131    #[inline]
132    #[must_use]
133    pub const fn len(&self) -> usize {
134        self.length
135    }
136
137    /// Checks if the current position is at the end of the input.
138    ///
139    /// # Returns
140    ///
141    /// `true` if the current offset is greater than or equal to the input length; `false` otherwise.
142    #[inline]
143    #[must_use]
144    pub const fn has_reached_eof(&self) -> bool {
145        self.offset >= self.length
146    }
147
148    /// Returns a byte slice within a specified absolute range.
149    ///
150    /// The `from` and `to` arguments are absolute byte offsets from the beginning
151    /// of the original source file. The method calculates the correct slice
152    /// relative to the `starting_position` of this `Input`.
153    ///
154    /// This is useful for retrieving the raw text of a `Span` or `Token` whose
155    /// positions are absolute, even when the `Input` only contains a subsection
156    /// of the source file.
157    ///
158    /// The returned slice is defensively clamped to the bounds of the current
159    /// `Input`'s byte slice to prevent panics.
160    ///
161    /// # Arguments
162    ///
163    /// * `from` - The absolute starting byte offset.
164    /// * `to` - The absolute ending byte offset (exclusive).
165    ///
166    /// # Returns
167    ///
168    /// A byte slice `&[u8]` corresponding to the requested range.
169    #[inline]
170    #[must_use]
171    pub fn slice_in_range(&self, from: u32, to: u32) -> &'a [u8] {
172        let base_offset = self.starting_position.offset;
173
174        // Calculate the start and end positions relative to the local `bytes` slice.
175        // `saturating_sub` prevents underflow if `from`/`to` are smaller than `base_offset`.
176        let local_from = from.saturating_sub(base_offset) as usize;
177        let local_to = to.saturating_sub(base_offset) as usize;
178
179        // Clamp the local indices to the actual length of the `bytes` slice to prevent panics.
180        let start = local_from.min(self.length);
181        let end = local_to.min(self.length);
182
183        // Ensure the start index is not greater than the end index.
184        if start >= end {
185            return &[];
186        }
187
188        // If the start index is beyond the length of the input, return an empty slice.
189        if start >= self.length {
190            return &[];
191        }
192
193        &self.bytes[start..end]
194    }
195
196    /// Advances the current position by one character, updating line and column numbers.
197    ///
198    /// Handles different line endings (`\n`, `\r`, `\r\n`) and updates line and column counters accordingly.
199    ///
200    /// If the end of input is reached, no action is taken.
201    #[inline]
202    pub fn next(&mut self) {
203        if !self.has_reached_eof() {
204            self.offset += 1;
205        }
206    }
207
208    /// Skips the next `count` characters, advancing the position accordingly.
209    ///
210    /// Updates line and column numbers as it advances.
211    ///
212    /// # Arguments
213    ///
214    /// * `count` - The number of characters to skip.
215    #[inline]
216    pub fn skip(&mut self, count: usize) {
217        self.offset = (self.offset + count).min(self.length);
218    }
219
220    /// Consumes the next `count` characters and returns them as a slice.
221    ///
222    /// Advances the position by `count` characters.
223    ///
224    /// # Arguments
225    ///
226    /// * `count` - The number of characters to consume.
227    ///
228    /// # Returns
229    ///
230    /// A byte slice containing the consumed characters.
231    #[inline]
232    pub fn consume(&mut self, count: usize) -> &'a [u8] {
233        let (from, until) = self.calculate_bound(count);
234
235        self.skip(count);
236
237        &self.bytes[from..until]
238    }
239
240    /// Consumes all remaining characters from the current position to the end of input.
241    ///
242    /// Advances the position to EOF.
243    ///
244    /// # Returns
245    ///
246    /// A byte slice containing the remaining characters.
247    #[inline]
248    pub fn consume_remaining(&mut self) -> &'a [u8] {
249        if self.has_reached_eof() {
250            return &[];
251        }
252
253        let from = self.offset;
254        self.offset = self.length;
255
256        &self.bytes[from..]
257    }
258
259    /// Consumes characters until the given byte slice is found.
260    ///
261    /// Advances the position to the start of the search slice if found,
262    /// or to EOF if not found.
263    ///
264    /// # Arguments
265    ///
266    /// * `search` - The byte slice to search for.
267    /// * `ignore_ascii_case` - Whether to ignore ASCII case when comparing characters.
268    ///
269    /// # Returns
270    ///
271    /// A byte slice containing the consumed characters.
272    #[inline]
273    pub fn consume_until(&mut self, search: &[u8], ignore_ascii_case: bool) -> &'a [u8] {
274        let start = self.offset;
275        if ignore_ascii_case {
276            while !self.has_reached_eof() && !self.is_at(search, ignore_ascii_case) {
277                self.offset += 1;
278            }
279
280            &self.bytes[start..self.offset]
281        } else {
282            // For a single-byte search, use memchr.
283            if search.len() == 1 {
284                if let Some(pos) = memchr(search[0], &self.bytes[self.offset..]) {
285                    self.offset += pos;
286                    &self.bytes[start..self.offset]
287                } else {
288                    self.offset = self.length;
289                    &self.bytes[start..self.length]
290                }
291            } else if let Some(pos) = find(&self.bytes[self.offset..], search) {
292                self.offset += pos;
293                &self.bytes[start..self.offset]
294            } else {
295                self.offset = self.length;
296                &self.bytes[start..self.length]
297            }
298        }
299    }
300
301    #[inline]
302    pub fn consume_through(&mut self, search: u8) -> &'a [u8] {
303        let start = self.offset;
304        if let Some(pos) = memchr::memchr(search, &self.bytes[self.offset..]) {
305            self.offset += pos + 1;
306
307            &self.bytes[start..self.offset]
308        } else {
309            self.offset = self.length;
310
311            &self.bytes[start..self.length]
312        }
313    }
314
315    /// Consumes whitespaces until a non-whitespace character is found.
316    ///
317    /// # Returns
318    ///
319    /// A byte slice containing the consumed whitespaces.
320    #[inline]
321    pub fn consume_whitespaces(&mut self) -> &'a [u8] {
322        let start = self.offset;
323        let bytes = self.bytes;
324        let len = self.length;
325        while self.offset < len && bytes[self.offset].is_ascii_whitespace() {
326            self.offset += 1;
327        }
328
329        &bytes[start..self.offset]
330    }
331
332    /// Reads the next `n` characters without advancing the position.
333    ///
334    /// # Arguments
335    ///
336    /// * `n` - The number of characters to read.
337    ///
338    /// # Returns
339    ///
340    /// A byte slice containing the next `n` characters.
341    #[inline]
342    #[must_use]
343    pub fn read(&self, n: usize) -> &'a [u8] {
344        let (from, until) = self.calculate_bound(n);
345
346        &self.bytes[from..until]
347    }
348
349    /// Reads a single byte at a specific byte offset within the input slice,
350    /// without advancing the internal cursor.
351    ///
352    /// This provides direct, low-level access to the underlying byte data.
353    ///
354    /// # Arguments
355    ///
356    /// * `at` - The zero-based byte offset within the input slice (`self.bytes`)
357    ///   from which to read the byte.
358    ///
359    /// # Returns
360    ///
361    /// A reference to the byte located at the specified offset `at`.
362    ///
363    /// # Panics
364    ///
365    /// This method **panics** if the provided `at` offset is out of bounds
366    /// for the input byte slice (i.e., if `at >= self.bytes.len()`).
367    #[must_use]
368    pub fn read_at(&self, at: usize) -> &'a u8 {
369        &self.bytes[at]
370    }
371
372    /// Checks if the input at the current position matches the given byte slice.
373    ///
374    /// # Arguments
375    ///
376    /// * `search` - The byte slice to compare against the input.
377    /// * `ignore_ascii_case` - Whether to ignore ASCII case when comparing.
378    ///
379    /// # Returns
380    ///
381    /// `true` if the next bytes match `search`; `false` otherwise.
382    #[inline]
383    #[must_use]
384    pub fn is_at(&self, search: &[u8], ignore_ascii_case: bool) -> bool {
385        let (from, until) = self.calculate_bound(search.len());
386        let slice = &self.bytes[from..until];
387
388        if ignore_ascii_case { slice.eq_ignore_ascii_case(search) } else { slice == search }
389    }
390
391    /// Attempts to match the given byte sequence at the current position, ignoring whitespace in the input.
392    ///
393    /// This method tries to match the provided byte slice `search` against the input starting
394    /// from the current position, possibly ignoring ASCII case. Whitespace characters in the input
395    /// are skipped during matching, but their length is included in the returned length.
396    ///
397    /// Importantly, the method **does not include** any trailing whitespace **after** the matched sequence
398    /// in the returned length.
399    ///
400    /// For example, to match the sequence `(string)`, the input could be `(string)`, `( string )`, `(  string )`, etc.,
401    /// and this method would return the total length of the input consumed to match `(string)`,
402    /// including any whitespace within the matched sequence, but **excluding** any whitespace after it.
403    ///
404    /// # Arguments
405    ///
406    /// * `search` - The byte slice to match against the input.
407    /// * `ignore_ascii_case` - If `true`, ASCII case is ignored during comparison.
408    ///
409    /// # Returns
410    ///
411    /// * `Some(length)` - If the input matches `search` (ignoring whitespace within the sequence), returns the total length
412    ///   of the input consumed to match `search`, including any skipped whitespace **within** the matched sequence.
413    /// * `None` - If the input does not match `search`.
414    #[inline]
415    #[must_use]
416    pub const fn match_sequence_ignore_whitespace(&self, search: &[u8], ignore_ascii_case: bool) -> Option<usize> {
417        let mut offset = self.offset;
418        let mut search_offset = 0;
419        let mut length = 0;
420        let bytes = self.bytes;
421        let total = self.length;
422        while search_offset < search.len() {
423            // Skip whitespace in the input.
424            while offset < total && bytes[offset].is_ascii_whitespace() {
425                offset += 1;
426                length += 1;
427            }
428
429            if offset >= total {
430                return None;
431            }
432
433            let input_byte = bytes[offset];
434            let search_byte = search[search_offset];
435            let matched = if ignore_ascii_case {
436                input_byte.eq_ignore_ascii_case(&search_byte)
437            } else {
438                input_byte == search_byte
439            };
440
441            if matched {
442                offset += 1;
443                length += 1;
444                search_offset += 1;
445            } else {
446                return None;
447            }
448        }
449
450        Some(length)
451    }
452
453    /// Peeks ahead `i` characters and reads the next `n` characters without advancing the position.
454    ///
455    /// # Arguments
456    ///
457    /// * `offset` - The number of characters to skip before reading.
458    /// * `n` - The number of characters to read after skipping.
459    ///
460    /// # Returns
461    ///
462    /// A byte slice containing the peeked characters.
463    #[inline]
464    #[must_use]
465    pub fn peek(&self, offset: usize, n: usize) -> &'a [u8] {
466        let from = self.offset + offset;
467        if from >= self.length {
468            return &self.bytes[self.length..self.length];
469        }
470
471        let mut until = from + n;
472        if until >= self.length {
473            until = self.length;
474        }
475
476        &self.bytes[from..until]
477    }
478
479    /// Calculates the bounds for slicing the input safely.
480    ///
481    /// Ensures that slicing does not go beyond the input length.
482    ///
483    /// # Arguments
484    ///
485    /// * `n` - The number of characters to include in the slice.
486    ///
487    /// # Returns
488    ///
489    /// A tuple `(from, until)` representing the start and end indices for slicing.
490    #[inline]
491    const fn calculate_bound(&self, n: usize) -> (usize, usize) {
492        if self.has_reached_eof() {
493            return (self.length, self.length);
494        }
495
496        let mut until = self.offset + n;
497
498        if until >= self.length {
499            until = self.length;
500        }
501
502        (self.offset, until)
503    }
504}
505
506impl HasFileId for Input<'_> {
507    fn file_id(&self) -> FileId {
508        self.file_id
509    }
510}
511
#[cfg(test)]
mod tests {
    use mago_span::Position;

    use super::*;

    #[test]
    fn test_new() {
        let bytes = b"Hello, world!";
        let input = Input::new(FileId::zero(), bytes);

        assert_eq!(input.current_position(), Position::new(0));
        assert_eq!(input.length, bytes.len());
        assert_eq!(input.bytes, bytes);
        assert_eq!(input.len(), bytes.len());
        assert!(!input.is_empty());
    }

    #[test]
    fn test_is_eof() {
        let bytes = b"";
        let input = Input::new(FileId::zero(), bytes);

        assert!(input.is_empty());
        assert!(input.has_reached_eof());

        let bytes = b"data";
        let mut input = Input::new(FileId::zero(), bytes);

        assert!(!input.has_reached_eof());

        input.skip(4);

        assert!(input.has_reached_eof());
    }

    #[test]
    fn test_next() {
        // 8 bytes: 'a', '\n', 'b', '\r', '\n', 'c', '\r', 'd'.
        // `next` advances exactly one byte per call; line endings get no special treatment.
        let bytes = b"a\nb\r\nc\rd";
        let mut input = Input::new(FileId::zero(), bytes);

        // 'a'
        input.next();
        assert_eq!(input.current_position(), Position::new(1));

        // '\n'
        input.next();
        assert_eq!(input.current_position(), Position::new(2));

        // 'b'
        input.next();
        assert_eq!(input.current_position(), Position::new(3));

        // '\r' (first half of the "\r\n" pair)
        input.next();
        assert_eq!(input.current_position(), Position::new(4));

        // '\n' (second half of the "\r\n" pair)
        input.next();
        assert_eq!(input.current_position(), Position::new(5));

        // 'c'
        input.next();
        assert_eq!(input.current_position(), Position::new(6));

        // '\r'
        input.next();
        assert_eq!(input.current_position(), Position::new(7));

        // 'd'
        input.next();
        assert_eq!(input.current_position(), Position::new(8));
        assert!(input.has_reached_eof());

        // At EOF, `next` is a no-op.
        input.next();
        assert_eq!(input.current_position(), Position::new(8));
    }

    #[test]
    fn test_skip_clamps_to_length() {
        let mut input = Input::new(FileId::zero(), b"abc");

        input.skip(10);

        assert_eq!(input.current_offset(), 3);
        assert!(input.has_reached_eof());
    }

    #[test]
    fn test_anchored_at() {
        let mut input = Input::anchored_at(FileId::zero(), b"cdef", Position::new(2));

        assert_eq!(input.current_offset(), 0);
        assert_eq!(input.current_position(), Position::new(2));

        input.skip(3);

        // The local offset is relative to the slice, the position to the anchor.
        assert_eq!(input.current_offset(), 3);
        assert_eq!(input.current_position(), Position::new(5));
    }

    #[test]
    fn test_slice_in_range() {
        let input = Input::anchored_at(FileId::zero(), b"abcdef", Position::new(10));

        assert_eq!(input.slice_in_range(10, 13), b"abc");
        // Offsets before the anchor are clamped to the start of the slice.
        assert_eq!(input.slice_in_range(0, 12), b"ab");
        // Offsets past the end are clamped to the end of the slice.
        assert_eq!(input.slice_in_range(14, 100), b"ef");
        // Inverted and fully out-of-range requests yield an empty slice.
        assert_eq!(input.slice_in_range(13, 11), b"");
        assert_eq!(input.slice_in_range(100, 200), b"");
    }

    #[test]
    fn test_consume() {
        let bytes = b"abcdef";
        let mut input = Input::new(FileId::zero(), bytes);

        let consumed = input.consume(3);
        assert_eq!(consumed, b"abc");
        assert_eq!(input.current_position(), Position::new(3));

        let consumed = input.consume(3);
        assert_eq!(consumed, b"def");
        assert_eq!(input.current_position(), Position::new(6));

        let consumed = input.consume(1); // Should return empty slice at EOF
        assert_eq!(consumed, b"");
        assert!(input.has_reached_eof());
    }

    #[test]
    fn test_consume_remaining() {
        let bytes = b"abcdef";
        let mut input = Input::new(FileId::zero(), bytes);

        input.skip(2);
        let remaining = input.consume_remaining();
        assert_eq!(remaining, b"cdef");
        assert!(input.has_reached_eof());

        // Nothing is left once EOF has been reached.
        assert_eq!(input.consume_remaining(), b"");
    }

    #[test]
    fn test_consume_until() {
        let mut input = Input::new(FileId::zero(), b"abc-->def");

        // Multi-byte needle: stops right before the match.
        assert_eq!(input.consume_until(b"-->", false), b"abc");
        assert!(input.is_at(b"-->", false));

        // Single-byte needle.
        assert_eq!(input.consume_until(b"d", false), b"-->");
        assert_eq!(input.current_offset(), 6);

        // Needle not present: everything up to EOF is consumed.
        assert_eq!(input.consume_until(b"zz", false), b"def");
        assert!(input.has_reached_eof());
    }

    #[test]
    fn test_consume_until_ignore_ascii_case() {
        let mut input = Input::new(FileId::zero(), b"fooENDbar");

        assert_eq!(input.consume_until(b"end", true), b"foo");
        assert_eq!(input.current_offset(), 3);

        assert_eq!(input.consume_until(b"xyz", true), b"ENDbar");
        assert!(input.has_reached_eof());
    }

    #[test]
    fn test_consume_through() {
        let mut input = Input::new(FileId::zero(), b"line1\nline2");

        // The matched byte is included in the consumed slice.
        assert_eq!(input.consume_through(b'\n'), b"line1\n");
        assert_eq!(input.current_offset(), 6);

        // No further match: the remainder is consumed.
        assert_eq!(input.consume_through(b'\n'), b"line2");
        assert!(input.has_reached_eof());
    }

    #[test]
    fn test_consume_whitespaces() {
        let mut input = Input::new(FileId::zero(), b" \t\nabc");

        assert_eq!(input.consume_whitespaces(), b" \t\n");
        assert_eq!(input.current_offset(), 3);

        // Already at a non-whitespace byte: nothing is consumed.
        assert_eq!(input.consume_whitespaces(), b"");
        assert_eq!(input.current_offset(), 3);
    }

    #[test]
    fn test_read() {
        let bytes = b"abcdef";
        let input = Input::new(FileId::zero(), bytes);

        let read = input.read(3);
        assert_eq!(read, b"abc");
        // Position should not change
        assert_eq!(input.current_position(), Position::new(0));

        // Reading past the end is clamped to the remaining bytes.
        assert_eq!(input.read(10), b"abcdef");
    }

    #[test]
    fn test_read_at() {
        let input = Input::new(FileId::zero(), b"abcdef");

        assert_eq!(*input.read_at(2), b'c');
        assert_eq!(input.current_offset(), 0);
    }

    #[test]
    fn test_is_at() {
        let bytes = b"abcdef";
        let mut input = Input::new(FileId::zero(), bytes);

        assert!(input.is_at(b"abc", false));
        input.skip(2);
        assert!(input.is_at(b"cde", false));
        assert!(!input.is_at(b"xyz", false));
    }

    #[test]
    fn test_is_at_ignore_ascii_case() {
        let bytes = b"AbCdEf";
        let mut input = Input::new(FileId::zero(), bytes);

        assert!(input.is_at(b"abc", true));
        input.skip(2);
        assert!(input.is_at(b"cde", true));
        assert!(!input.is_at(b"xyz", true));
    }

    #[test]
    fn test_match_sequence_ignore_whitespace() {
        let input = Input::new(FileId::zero(), b"( String )  rest");

        // Inner whitespace is counted, trailing whitespace is not.
        assert_eq!(input.match_sequence_ignore_whitespace(b"(string)", true), Some(10));
        // Case matters unless explicitly ignored.
        assert_eq!(input.match_sequence_ignore_whitespace(b"(string)", false), None);
        assert_eq!(input.match_sequence_ignore_whitespace(b"(int)", true), None);
        // The cursor is left untouched.
        assert_eq!(input.current_offset(), 0);

        // Leading whitespace is skipped and counted as well.
        let input = Input::new(FileId::zero(), b"  ab");
        assert_eq!(input.match_sequence_ignore_whitespace(b"ab", false), Some(4));
    }

    #[test]
    fn test_peek() {
        let bytes = b"abcdef";
        let input = Input::new(FileId::zero(), bytes);

        let peeked = input.peek(2, 3);
        assert_eq!(peeked, b"cde");
        // Position should not change
        assert_eq!(input.current_position(), Position::new(0));

        // Ranges running past the end are clamped; ranges starting past the end are empty.
        assert_eq!(input.peek(4, 10), b"ef");
        assert_eq!(input.peek(6, 1), b"");
        assert_eq!(input.peek(10, 1), b"");
    }

    #[test]
    fn test_to_bound() {
        let bytes = b"abcdef";
        let input = Input::new(FileId::zero(), bytes);

        let (from, until) = input.calculate_bound(3);
        assert_eq!((from, until), (0, 3));

        let (from, until) = input.calculate_bound(10); // Exceeds length
        assert_eq!((from, until), (0, 6));
    }
}