mago_syntax_core/
input.rs

1use bumpalo::Bump;
2use mago_database::file::HasFileId;
3use memchr::memchr;
4use memchr::memmem::find;
5
6use mago_database::file::File;
7use mago_database::file::FileId;
8use mago_span::Position;
9
/// A cursor over the raw bytes of the source code being lexed.
///
/// The `Input` struct provides methods to read, peek, consume, and skip bytes
/// while tracking how far into the slice the lexer has advanced. Only a byte
/// offset is stored; absolute positions are derived on demand by moving
/// `starting_position` forward by that offset.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub struct Input<'a> {
    /// The bytes being lexed.
    pub(crate) bytes: &'a [u8],
    /// Length of `bytes`, recorded once when the `Input` is constructed.
    pub(crate) length: usize,
    /// Number of bytes consumed so far, relative to the start of `bytes`.
    pub(crate) offset: usize,
    /// Absolute position in the original source file at which `bytes` begins.
    pub(crate) starting_position: Position,
    /// Identifier of the source file these bytes belong to.
    pub(crate) file_id: FileId,
}
22
23impl<'a> Input<'a> {
24    /// Creates a new `Input` instance from the given input.
25    ///
26    /// # Arguments
27    ///
28    /// * `file_id` - The unique identifier for the source file this input belongs to.
29    /// * `bytes` - A byte slice representing the input code to be lexed.
30    ///
31    /// # Returns
32    ///
33    /// A new `Input` instance initialized at the beginning of the input.
34    pub fn new(file_id: FileId, bytes: &'a [u8]) -> Self {
35        let length = bytes.len();
36
37        Self { bytes, length, offset: 0, file_id, starting_position: Position::new(0) }
38    }
39
40    /// Creates a new `Input` instance from the contents of a `File`.
41    ///
42    /// # Arguments
43    ///
44    /// * `file` - A reference to the `File` containing the source code.
45    ///
46    /// # Returns
47    ///
48    /// A new `Input` instance initialized with the file's ID and contents.
49    pub fn from_file(file: &'a File) -> Self {
50        Self::new(file.id, file.contents.as_bytes())
51    }
52
53    /// Creates a new `Input` instance from the contents of a `File`.
54    ///
55    /// # Arguments
56    ///
57    /// * `file` - A reference to the `File` containing the source code.
58    ///
59    /// # Returns
60    ///
61    /// A new `Input` instance initialized with the file's ID and contents.
62    pub fn from_file_in(arena: &'a Bump, file: &File) -> Self {
63        Self::new(file.id, arena.alloc_slice_clone(file.contents.as_bytes()))
64    }
65
66    /// Creates a new `Input` instance representing a byte slice that is
67    /// "anchored" at a specific absolute position within a larger source file.
68    ///
69    /// This is useful when lexing a subset (slice) of a source file, as it allows
70    /// generated tokens to retain accurate absolute positions and spans relative
71    /// to the original file.
72    ///
73    /// The internal cursor (`offset`) starts at 0 relative to the `bytes` slice,
74    /// but the absolute position is calculated relative to the `anchor_position`.
75    ///
76    /// # Arguments
77    ///
78    /// * `file_id` - The unique identifier for the source file this input belongs to.
79    /// * `bytes` - A byte slice representing the input code subset to be lexed.
80    /// * `anchor_position` - The absolute `Position` in the original source file where the provided `bytes` slice begins.
81    ///
82    /// # Returns
83    ///
84    /// A new `Input` instance ready to lex the `bytes`, maintaining positions
85    /// relative to `anchor_position`.
86    pub fn anchored_at(file_id: FileId, bytes: &'a [u8], anchor_position: Position) -> Self {
87        let length = bytes.len();
88
89        Self { bytes, length, offset: 0, file_id, starting_position: anchor_position }
90    }
91
    /// Returns the identifier of the source file this input belongs to.
    ///
    /// This is the `file_id` value stored when the `Input` was constructed.
    #[inline]
    pub const fn file_id(&self) -> FileId {
        self.file_id
    }
97
98    /// Returns the absolute current `Position` of the lexer within the original source file.
99    ///
100    /// It calculates this by adding the internal offset (progress within the current byte slice)
101    /// to the `starting_position` the `Input` was initialized with.
102    #[inline]
103    pub const fn current_position(&self) -> Position {
104        // Calculate absolute position by adding internal offset to the starting base
105        self.starting_position.forward(self.offset as u32)
106    }
107
    /// Returns the current internal byte offset relative to the start of the input slice.
    ///
    /// This is the number of bytes consumed so far from `bytes`; it does not include
    /// `starting_position`. To get the absolute position in the original source file,
    /// use `current_position()`.
    #[inline]
    pub const fn current_offset(&self) -> usize {
        self.offset
    }
116
    /// Returns `true` if the input slice is empty (length is zero).
    ///
    /// This looks only at the total length, not at the cursor: a non-empty input that
    /// has been fully consumed still returns `false`. Use `has_reached_eof()` to test
    /// whether the cursor is at the end.
    #[inline]
    pub const fn is_empty(&self) -> bool {
        self.length == 0
    }
122
    /// Returns the total length in bytes of the input slice being processed.
    ///
    /// The value does not change as the cursor advances.
    #[inline]
    pub const fn len(&self) -> usize {
        self.length
    }
128
129    /// Checks if the current position is at the end of the input.
130    ///
131    /// # Returns
132    ///
133    /// `true` if the current offset is greater than or equal to the input length; `false` otherwise.
134    #[inline]
135    pub const fn has_reached_eof(&self) -> bool {
136        self.offset >= self.length
137    }
138
139    /// Returns a byte slice within a specified absolute range.
140    ///
141    /// The `from` and `to` arguments are absolute byte offsets from the beginning
142    /// of the original source file. The method calculates the correct slice
143    /// relative to the `starting_position` of this `Input`.
144    ///
145    /// This is useful for retrieving the raw text of a `Span` or `Token` whose
146    /// positions are absolute, even when the `Input` only contains a subsection
147    /// of the source file.
148    ///
149    /// The returned slice is defensively clamped to the bounds of the current
150    /// `Input`'s byte slice to prevent panics.
151    ///
152    /// # Arguments
153    ///
154    /// * `from` - The absolute starting byte offset.
155    /// * `to` - The absolute ending byte offset (exclusive).
156    ///
157    /// # Returns
158    ///
159    /// A byte slice `&[u8]` corresponding to the requested range.
160    #[inline]
161    pub fn slice_in_range(&self, from: u32, to: u32) -> &'a [u8] {
162        let base_offset = self.starting_position.offset;
163
164        // Calculate the start and end positions relative to the local `bytes` slice.
165        // `saturating_sub` prevents underflow if `from`/`to` are smaller than `base_offset`.
166        let local_from = from.saturating_sub(base_offset) as usize;
167        let local_to = to.saturating_sub(base_offset) as usize;
168
169        // Clamp the local indices to the actual length of the `bytes` slice to prevent panics.
170        let start = local_from.min(self.length);
171        let end = local_to.min(self.length);
172
173        // Ensure the start index is not greater than the end index.
174        if start >= end {
175            return &[];
176        }
177
178        // If the start index is beyond the length of the input, return an empty slice.
179        if start >= self.length {
180            return &[];
181        }
182
183        &self.bytes[start..end]
184    }
185
186    /// Advances the current position by one character, updating line and column numbers.
187    ///
188    /// Handles different line endings (`\n`, `\r`, `\r\n`) and updates line and column counters accordingly.
189    ///
190    /// If the end of input is reached, no action is taken.
191    #[inline]
192    pub fn next(&mut self) {
193        if !self.has_reached_eof() {
194            self.offset += 1;
195        }
196    }
197
198    /// Skips the next `count` characters, advancing the position accordingly.
199    ///
200    /// Updates line and column numbers as it advances.
201    ///
202    /// # Arguments
203    ///
204    /// * `count` - The number of characters to skip.
205    #[inline]
206    pub fn skip(&mut self, count: usize) {
207        self.offset = (self.offset + count).min(self.length);
208    }
209
210    /// Consumes the next `count` characters and returns them as a slice.
211    ///
212    /// Advances the position by `count` characters.
213    ///
214    /// # Arguments
215    ///
216    /// * `count` - The number of characters to consume.
217    ///
218    /// # Returns
219    ///
220    /// A byte slice containing the consumed characters.
221    #[inline]
222    pub fn consume(&mut self, count: usize) -> &'a [u8] {
223        let (from, until) = self.calculate_bound(count);
224
225        self.skip(count);
226
227        &self.bytes[from..until]
228    }
229
230    /// Consumes all remaining characters from the current position to the end of input.
231    ///
232    /// Advances the position to EOF.
233    ///
234    /// # Returns
235    ///
236    /// A byte slice containing the remaining characters.
237    #[inline]
238    pub fn consume_remaining(&mut self) -> &'a [u8] {
239        if self.has_reached_eof() {
240            return &[];
241        }
242
243        let from = self.offset;
244        self.offset = self.length;
245
246        &self.bytes[from..]
247    }
248
249    /// Consumes characters until the given byte slice is found.
250    ///
251    /// Advances the position to the start of the search slice if found,
252    /// or to EOF if not found.
253    ///
254    /// # Arguments
255    ///
256    /// * `search` - The byte slice to search for.
257    /// * `ignore_ascii_case` - Whether to ignore ASCII case when comparing characters.
258    ///
259    /// # Returns
260    ///
261    /// A byte slice containing the consumed characters.
262    #[inline]
263    pub fn consume_until(&mut self, search: &[u8], ignore_ascii_case: bool) -> &'a [u8] {
264        let start = self.offset;
265        if !ignore_ascii_case {
266            // For a single-byte search, use memchr.
267            if search.len() == 1 {
268                if let Some(pos) = memchr(search[0], &self.bytes[self.offset..]) {
269                    self.offset += pos;
270                    &self.bytes[start..self.offset]
271                } else {
272                    self.offset = self.length;
273                    &self.bytes[start..self.length]
274                }
275            } else if let Some(pos) = find(&self.bytes[self.offset..], search) {
276                self.offset += pos;
277                &self.bytes[start..self.offset]
278            } else {
279                self.offset = self.length;
280                &self.bytes[start..self.length]
281            }
282        } else {
283            while !self.has_reached_eof() && !self.is_at(search, ignore_ascii_case) {
284                self.offset += 1;
285            }
286
287            &self.bytes[start..self.offset]
288        }
289    }
290
291    #[inline]
292    pub fn consume_through(&mut self, search: u8) -> &'a [u8] {
293        let start = self.offset;
294        if let Some(pos) = memchr::memchr(search, &self.bytes[self.offset..]) {
295            self.offset += pos + 1;
296
297            &self.bytes[start..self.offset]
298        } else {
299            self.offset = self.length;
300
301            &self.bytes[start..self.length]
302        }
303    }
304
305    /// Consumes whitespaces until a non-whitespace character is found.
306    ///
307    /// # Returns
308    ///
309    /// A byte slice containing the consumed whitespaces.
310    #[inline]
311    pub fn consume_whitespaces(&mut self) -> &'a [u8] {
312        let start = self.offset;
313        let bytes = self.bytes;
314        let len = self.length;
315        while self.offset < len && bytes[self.offset].is_ascii_whitespace() {
316            self.offset += 1;
317        }
318
319        &bytes[start..self.offset]
320    }
321
322    /// Reads the next `n` characters without advancing the position.
323    ///
324    /// # Arguments
325    ///
326    /// * `n` - The number of characters to read.
327    ///
328    /// # Returns
329    ///
330    /// A byte slice containing the next `n` characters.
331    #[inline]
332    pub fn read(&self, n: usize) -> &'a [u8] {
333        let (from, until) = self.calculate_bound(n);
334
335        &self.bytes[from..until]
336    }
337
    /// Reads a single byte at a specific byte offset within the input slice,
    /// without advancing the internal cursor.
    ///
    /// This provides direct, low-level access to the underlying byte data. The
    /// index is applied to the slice as-is: it is independent of the current
    /// cursor and is not adjusted by `starting_position`.
    ///
    /// # Arguments
    ///
    /// * `at` - The zero-based byte offset within the input slice (`self.bytes`)
    ///   from which to read the byte.
    ///
    /// # Returns
    ///
    /// A reference to the byte located at the specified offset `at`.
    ///
    /// # Panics
    ///
    /// This method **panics** if the provided `at` offset is out of bounds
    /// for the input byte slice (i.e., if `at >= self.bytes.len()`).
    pub fn read_at(&self, at: usize) -> &'a u8 {
        &self.bytes[at]
    }
359
360    /// Checks if the input at the current position matches the given byte slice.
361    ///
362    /// # Arguments
363    ///
364    /// * `search` - The byte slice to compare against the input.
365    /// * `ignore_ascii_case` - Whether to ignore ASCII case when comparing.
366    ///
367    /// # Returns
368    ///
369    /// `true` if the next bytes match `search`; `false` otherwise.
370    #[inline]
371    pub fn is_at(&self, search: &[u8], ignore_ascii_case: bool) -> bool {
372        let (from, until) = self.calculate_bound(search.len());
373        let slice = &self.bytes[from..until];
374
375        if ignore_ascii_case { slice.eq_ignore_ascii_case(search) } else { slice == search }
376    }
377
378    /// Attempts to match the given byte sequence at the current position, ignoring whitespace in the input.
379    ///
380    /// This method tries to match the provided byte slice `search` against the input starting
381    /// from the current position, possibly ignoring ASCII case. Whitespace characters in the input
382    /// are skipped during matching, but their length is included in the returned length.
383    ///
384    /// Importantly, the method **does not include** any trailing whitespace **after** the matched sequence
385    /// in the returned length.
386    ///
387    /// For example, to match the sequence `(string)`, the input could be `(string)`, `( string )`, `(  string )`, etc.,
388    /// and this method would return the total length of the input consumed to match `(string)`,
389    /// including any whitespace within the matched sequence, but **excluding** any whitespace after it.
390    ///
391    /// # Arguments
392    ///
393    /// * `search` - The byte slice to match against the input.
394    /// * `ignore_ascii_case` - If `true`, ASCII case is ignored during comparison.
395    ///
396    /// # Returns
397    ///
398    /// * `Some(length)` - If the input matches `search` (ignoring whitespace within the sequence), returns the total length
399    ///   of the input consumed to match `search`, including any skipped whitespace **within** the matched sequence.
400    /// * `None` - If the input does not match `search`.
401    #[inline]
402    pub const fn match_sequence_ignore_whitespace(&self, search: &[u8], ignore_ascii_case: bool) -> Option<usize> {
403        let mut offset = self.offset;
404        let mut search_offset = 0;
405        let mut length = 0;
406        let bytes = self.bytes;
407        let total = self.length;
408        while search_offset < search.len() {
409            // Skip whitespace in the input.
410            while offset < total && bytes[offset].is_ascii_whitespace() {
411                offset += 1;
412                length += 1;
413            }
414
415            if offset >= total {
416                return None;
417            }
418
419            let input_byte = bytes[offset];
420            let search_byte = search[search_offset];
421            let matched = if ignore_ascii_case {
422                input_byte.eq_ignore_ascii_case(&search_byte)
423            } else {
424                input_byte == search_byte
425            };
426
427            if matched {
428                offset += 1;
429                length += 1;
430                search_offset += 1;
431            } else {
432                return None;
433            }
434        }
435
436        Some(length)
437    }
438
439    /// Peeks ahead `i` characters and reads the next `n` characters without advancing the position.
440    ///
441    /// # Arguments
442    ///
443    /// * `offset` - The number of characters to skip before reading.
444    /// * `n` - The number of characters to read after skipping.
445    ///
446    /// # Returns
447    ///
448    /// A byte slice containing the peeked characters.
449    #[inline]
450    pub fn peek(&self, offset: usize, n: usize) -> &'a [u8] {
451        let from = self.offset + offset;
452        if from >= self.length {
453            return &self.bytes[self.length..self.length];
454        }
455
456        let mut until = from + n;
457        if until >= self.length {
458            until = self.length;
459        }
460
461        &self.bytes[from..until]
462    }
463
464    /// Calculates the bounds for slicing the input safely.
465    ///
466    /// Ensures that slicing does not go beyond the input length.
467    ///
468    /// # Arguments
469    ///
470    /// * `n` - The number of characters to include in the slice.
471    ///
472    /// # Returns
473    ///
474    /// A tuple `(from, until)` representing the start and end indices for slicing.
475    #[inline]
476    const fn calculate_bound(&self, n: usize) -> (usize, usize) {
477        if self.has_reached_eof() {
478            return (self.length, self.length);
479        }
480
481        let mut until = self.offset + n;
482
483        if until >= self.length {
484            until = self.length;
485        }
486
487        (self.offset, until)
488    }
489}
490
/// Exposes the input's source file identifier through the `HasFileId` trait.
impl HasFileId for Input<'_> {
    /// Returns the identifier of the source file this input belongs to.
    fn file_id(&self) -> FileId {
        self.file_id
    }
}
496
#[cfg(test)]
mod tests {
    use mago_span::Position;

    use super::*;

    #[test]
    fn test_new() {
        let bytes = b"Hello, world!";
        let input = Input::new(FileId::zero(), bytes);

        assert_eq!(input.current_position(), Position::new(0));
        assert_eq!(input.length, bytes.len());
        assert_eq!(input.bytes, bytes);
    }

    #[test]
    fn test_is_eof() {
        let bytes = b"";
        let input = Input::new(FileId::zero(), bytes);

        assert!(input.has_reached_eof());

        let bytes = b"data";
        let mut input = Input::new(FileId::zero(), bytes);

        assert!(!input.has_reached_eof());

        input.skip(4);

        assert!(input.has_reached_eof());
    }

    #[test]
    fn test_next() {
        // `next` advances exactly one byte per call; `\n`, `\r` and `\r\n`
        // receive no special treatment (`\r\n` takes two calls).
        let bytes = b"a\nb\r\nc\rd";
        let mut input = Input::new(FileId::zero(), bytes);

        for expected in 1..=bytes.len() as u32 {
            assert!(!input.has_reached_eof());
            input.next();
            assert_eq!(input.current_position(), Position::new(expected));
        }

        assert!(input.has_reached_eof());

        // Calling `next` at EOF is a no-op.
        input.next();
        assert_eq!(input.current_position(), Position::new(8));
    }

    #[test]
    fn test_consume() {
        let bytes = b"abcdef";
        let mut input = Input::new(FileId::zero(), bytes);

        let consumed = input.consume(3);
        assert_eq!(consumed, b"abc");
        assert_eq!(input.current_position(), Position::new(3));

        let consumed = input.consume(3);
        assert_eq!(consumed, b"def");
        assert_eq!(input.current_position(), Position::new(6));

        let consumed = input.consume(1); // Should return empty slice at EOF
        assert_eq!(consumed, b"");
        assert!(input.has_reached_eof());
    }

    #[test]
    fn test_consume_remaining() {
        let bytes = b"abcdef";
        let mut input = Input::new(FileId::zero(), bytes);

        input.skip(2);
        let remaining = input.consume_remaining();
        assert_eq!(remaining, b"cdef");
        assert!(input.has_reached_eof());

        // Nothing is left once EOF has been reached.
        assert_eq!(input.consume_remaining(), b"");
    }

    #[test]
    fn test_consume_until() {
        let mut input = Input::new(FileId::zero(), b"foo bar?>baz");

        // Multi-byte needle: stops right before the match.
        assert_eq!(input.consume_until(b"?>", false), b"foo bar");
        assert!(input.is_at(b"?>", false));

        // Single-byte needle.
        assert_eq!(input.consume_until(b"z", false), b"?>ba");

        // Needle not present: everything up to EOF is consumed.
        assert_eq!(input.consume_until(b"nope", false), b"z");
        assert!(input.has_reached_eof());

        // Case-insensitive search.
        let mut input = Input::new(FileId::zero(), b"abc END def");
        assert_eq!(input.consume_until(b"end", true), b"abc ");
        assert!(input.is_at(b"END", false));
    }

    #[test]
    fn test_consume_through() {
        let mut input = Input::new(FileId::zero(), b"line1\nline2");

        // The matched byte is part of the consumed slice.
        assert_eq!(input.consume_through(b'\n'), b"line1\n");

        // No further match: the rest of the input is consumed.
        assert_eq!(input.consume_through(b'\n'), b"line2");
        assert!(input.has_reached_eof());
    }

    #[test]
    fn test_consume_whitespaces() {
        let mut input = Input::new(FileId::zero(), b" \t\r\n x");

        assert_eq!(input.consume_whitespaces(), b" \t\r\n ");
        assert_eq!(input.read(1), b"x");

        // The cursor is not on whitespace any more.
        assert_eq!(input.consume_whitespaces(), b"");
    }

    #[test]
    fn test_read() {
        let bytes = b"abcdef";
        let input = Input::new(FileId::zero(), bytes);

        let read = input.read(3);
        assert_eq!(read, b"abc");
        assert_eq!(input.current_position(), Position::new(0));
        // Position should not change
    }

    #[test]
    fn test_is_at() {
        let bytes = b"abcdef";
        let mut input = Input::new(FileId::zero(), bytes);

        assert!(input.is_at(b"abc", false));
        input.skip(2);
        assert!(input.is_at(b"cde", false));
        assert!(!input.is_at(b"xyz", false));
    }

    #[test]
    fn test_is_at_ignore_ascii_case() {
        let bytes = b"AbCdEf";
        let mut input = Input::new(FileId::zero(), bytes);

        assert!(input.is_at(b"abc", true));
        input.skip(2);
        assert!(input.is_at(b"cde", true));
        assert!(!input.is_at(b"xyz", true));
    }

    #[test]
    fn test_match_sequence_ignore_whitespace() {
        let input = Input::new(FileId::zero(), b"( String ) rest");

        // Inner whitespace is counted, trailing whitespace is not.
        assert_eq!(input.match_sequence_ignore_whitespace(b"(string)", true), Some(10));
        assert_eq!(input.match_sequence_ignore_whitespace(b"(string)", false), None);
        assert_eq!(input.match_sequence_ignore_whitespace(b"(int)", true), None);

        // Input ends before the sequence is complete.
        let input = Input::new(FileId::zero(), b"(string");
        assert_eq!(input.match_sequence_ignore_whitespace(b"(string)", false), None);
    }

    #[test]
    fn test_peek() {
        let bytes = b"abcdef";
        let input = Input::new(FileId::zero(), bytes);

        let peeked = input.peek(2, 3);
        assert_eq!(peeked, b"cde");
        assert_eq!(input.current_position(), Position::new(0));
        // Position should not change

        // Reads are clamped to the end of the input.
        assert_eq!(input.peek(4, 10), b"ef");
        assert_eq!(input.peek(6, 1), b"");
        assert_eq!(input.peek(10, 1), b"");
    }

    #[test]
    fn test_anchored_slice_in_range() {
        // `bytes` represents the tail of a larger file, starting at absolute offset 7.
        let mut input = Input::anchored_at(FileId::zero(), b"world", Position::new(7));

        assert_eq!(input.current_offset(), 0);
        assert_eq!(input.current_position(), Position::new(7));

        input.skip(2);
        assert_eq!(input.current_offset(), 2);
        assert_eq!(input.current_position(), Position::new(9));

        // Absolute ranges are translated relative to the anchor and clamped.
        assert_eq!(input.slice_in_range(7, 10), b"wor");
        assert_eq!(input.slice_in_range(0, 9), b"wo");
        assert_eq!(input.slice_in_range(10, 100), b"ld");
        assert_eq!(input.slice_in_range(3, 5), b"");
        assert_eq!(input.slice_in_range(10, 8), b"");
    }

    #[test]
    fn test_to_bound() {
        let bytes = b"abcdef";
        let input = Input::new(FileId::zero(), bytes);

        let (from, until) = input.calculate_bound(3);
        assert_eq!((from, until), (0, 3));

        let (from, until) = input.calculate_bound(10); // Exceeds length
        assert_eq!((from, until), (0, 6));
    }
}