mago_syntax_core/
input.rs

use memchr::memchr;
use memchr::memmem::find;

use mago_database::file::File;
use mago_database::file::FileId;
use mago_database::file::HasFileId;
use mago_span::Position;
8
9/// A struct representing the input code being lexed.
10///
11/// The `Input` struct provides methods to read, peek, consume, and skip characters
12/// from the bytes input code while keeping track of the current position (line, column, offset).
13#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)]
14pub struct Input<'a> {
15    pub(crate) bytes: &'a [u8],
16    pub(crate) length: usize,
17    pub(crate) offset: usize,
18    pub(crate) starting_position: Position,
19    pub(crate) file_id: FileId,
20}
21
22impl<'a> Input<'a> {
23    /// Creates a new `Input` instance from the given input.
24    ///
25    /// # Arguments
26    ///
27    /// * `file_id` - The unique identifier for the source file this input belongs to.
28    /// * `bytes` - A byte slice representing the input code to be lexed.
29    ///
30    /// # Returns
31    ///
32    /// A new `Input` instance initialized at the beginning of the input.
33    pub fn new(file_id: FileId, bytes: &'a [u8]) -> Self {
34        let length = bytes.len();
35
36        Self { bytes, length, offset: 0, file_id, starting_position: Position::new(0) }
37    }
38
39    /// Creates a new `Input` instance from the contents of a `File`.
40    ///
41    /// # Arguments
42    ///
43    /// * `file` - A reference to the `File` containing the source code.
44    ///
45    /// # Returns
46    ///
47    /// A new `Input` instance initialized with the file's ID and contents.
48    pub fn from_file(file: &'a File) -> Self {
49        Self::new(file.id, file.contents.as_bytes())
50    }
51
52    /// Creates a new `Input` instance representing a byte slice that is
53    /// "anchored" at a specific absolute position within a larger source file.
54    ///
55    /// This is useful when lexing a subset (slice) of a source file, as it allows
56    /// generated tokens to retain accurate absolute positions and spans relative
57    /// to the original file.
58    ///
59    /// The internal cursor (`offset`) starts at 0 relative to the `bytes` slice,
60    /// but the absolute position is calculated relative to the `anchor_position`.
61    ///
62    /// # Arguments
63    ///
64    /// * `file_id` - The unique identifier for the source file this input belongs to.
65    /// * `bytes` - A byte slice representing the input code subset to be lexed.
66    /// * `anchor_position` - The absolute `Position` in the original source file where the provided `bytes` slice begins.
67    ///
68    /// # Returns
69    ///
70    /// A new `Input` instance ready to lex the `bytes`, maintaining positions
71    /// relative to `anchor_position`.
72    pub fn anchored_at(file_id: FileId, bytes: &'a [u8], anchor_position: Position) -> Self {
73        let length = bytes.len();
74
75        Self { bytes, length, offset: 0, file_id, starting_position: anchor_position }
76    }
77
78    /// Returns the source file identifier of the input code.
79    #[inline]
80    pub const fn file_id(&self) -> FileId {
81        self.file_id
82    }
83
84    /// Returns the absolute current `Position` of the lexer within the original source file.
85    ///
86    /// It calculates this by adding the internal offset (progress within the current byte slice)
87    /// to the `starting_position` the `Input` was initialized with.
88    #[inline]
89    pub const fn current_position(&self) -> Position {
90        // Calculate absolute position by adding internal offset to the starting base
91        self.starting_position.forward(self.offset as u32)
92    }
93
94    /// Returns the current internal byte offset relative to the start of the input slice.
95    ///
96    /// This indicates how many bytes have been consumed from the current `bytes` slice.
97    /// To get the absolute position in the original source file, use `current_position()`.
98    #[inline]
99    pub const fn current_offset(&self) -> usize {
100        self.offset
101    }
102
103    /// Returns `true` if the input slice is empty (length is zero).
104    #[inline]
105    pub const fn is_empty(&self) -> bool {
106        self.length == 0
107    }
108
109    /// Returns the total length in bytes of the input slice being processed.
110    #[inline]
111    pub const fn len(&self) -> usize {
112        self.length
113    }
114
115    /// Checks if the current position is at the end of the input.
116    ///
117    /// # Returns
118    ///
119    /// `true` if the current offset is greater than or equal to the input length; `false` otherwise.
120    #[inline]
121    pub const fn has_reached_eof(&self) -> bool {
122        self.offset >= self.length
123    }
124
125    /// Returns a byte slice within a specified absolute range.
126    ///
127    /// The `from` and `to` arguments are absolute byte offsets from the beginning
128    /// of the original source file. The method calculates the correct slice
129    /// relative to the `starting_position` of this `Input`.
130    ///
131    /// This is useful for retrieving the raw text of a `Span` or `Token` whose
132    /// positions are absolute, even when the `Input` only contains a subsection
133    /// of the source file.
134    ///
135    /// The returned slice is defensively clamped to the bounds of the current
136    /// `Input`'s byte slice to prevent panics.
137    ///
138    /// # Arguments
139    ///
140    /// * `from` - The absolute starting byte offset.
141    /// * `to` - The absolute ending byte offset (exclusive).
142    ///
143    /// # Returns
144    ///
145    /// A byte slice `&[u8]` corresponding to the requested range.
146    #[inline]
147    pub fn slice_in_range(&self, from: u32, to: u32) -> &'a [u8] {
148        let base_offset = self.starting_position.offset;
149
150        // Calculate the start and end positions relative to the local `bytes` slice.
151        // `saturating_sub` prevents underflow if `from`/`to` are smaller than `base_offset`.
152        let local_from = from.saturating_sub(base_offset) as usize;
153        let local_to = to.saturating_sub(base_offset) as usize;
154
155        // Clamp the local indices to the actual length of the `bytes` slice to prevent panics.
156        let start = local_from.min(self.length);
157        let end = local_to.min(self.length);
158
159        // Ensure the start index is not greater than the end index.
160        if start >= end {
161            return &[];
162        }
163
164        // If the start index is beyond the length of the input, return an empty slice.
165        if start >= self.length {
166            return &[];
167        }
168
169        &self.bytes[start..end]
170    }
171
172    /// Advances the current position by one character, updating line and column numbers.
173    ///
174    /// Handles different line endings (`\n`, `\r`, `\r\n`) and updates line and column counters accordingly.
175    ///
176    /// If the end of input is reached, no action is taken.
177    #[inline]
178    pub fn next(&mut self) {
179        if !self.has_reached_eof() {
180            self.offset += 1;
181        }
182    }
183
184    /// Skips the next `count` characters, advancing the position accordingly.
185    ///
186    /// Updates line and column numbers as it advances.
187    ///
188    /// # Arguments
189    ///
190    /// * `count` - The number of characters to skip.
191    #[inline]
192    pub fn skip(&mut self, count: usize) {
193        self.offset = (self.offset + count).min(self.length);
194    }
195
196    /// Consumes the next `count` characters and returns them as a slice.
197    ///
198    /// Advances the position by `count` characters.
199    ///
200    /// # Arguments
201    ///
202    /// * `count` - The number of characters to consume.
203    ///
204    /// # Returns
205    ///
206    /// A byte slice containing the consumed characters.
207    #[inline]
208    pub fn consume(&mut self, count: usize) -> &'a [u8] {
209        let (from, until) = self.calculate_bound(count);
210
211        self.skip(count);
212
213        &self.bytes[from..until]
214    }
215
216    /// Consumes all remaining characters from the current position to the end of input.
217    ///
218    /// Advances the position to EOF.
219    ///
220    /// # Returns
221    ///
222    /// A byte slice containing the remaining characters.
223    #[inline]
224    pub fn consume_remaining(&mut self) -> &'a [u8] {
225        if self.has_reached_eof() {
226            return &[];
227        }
228
229        let from = self.offset;
230        self.offset = self.length;
231
232        &self.bytes[from..]
233    }
234
235    /// Consumes characters until the given byte slice is found.
236    ///
237    /// Advances the position to the start of the search slice if found,
238    /// or to EOF if not found.
239    ///
240    /// # Arguments
241    ///
242    /// * `search` - The byte slice to search for.
243    /// * `ignore_ascii_case` - Whether to ignore ASCII case when comparing characters.
244    ///
245    /// # Returns
246    ///
247    /// A byte slice containing the consumed characters.
248    #[inline]
249    pub fn consume_until(&mut self, search: &[u8], ignore_ascii_case: bool) -> &'a [u8] {
250        let start = self.offset;
251        if !ignore_ascii_case {
252            // For a single-byte search, use memchr.
253            if search.len() == 1 {
254                if let Some(pos) = memchr(search[0], &self.bytes[self.offset..]) {
255                    self.offset += pos;
256                    &self.bytes[start..self.offset]
257                } else {
258                    self.offset = self.length;
259                    &self.bytes[start..self.length]
260                }
261            } else if let Some(pos) = find(&self.bytes[self.offset..], search) {
262                self.offset += pos;
263                &self.bytes[start..self.offset]
264            } else {
265                self.offset = self.length;
266                &self.bytes[start..self.length]
267            }
268        } else {
269            while !self.has_reached_eof() && !self.is_at(search, ignore_ascii_case) {
270                self.offset += 1;
271            }
272
273            &self.bytes[start..self.offset]
274        }
275    }
276
277    #[inline]
278    pub fn consume_through(&mut self, search: u8) -> &'a [u8] {
279        let start = self.offset;
280        if let Some(pos) = memchr::memchr(search, &self.bytes[self.offset..]) {
281            self.offset += pos + 1;
282
283            &self.bytes[start..self.offset]
284        } else {
285            self.offset = self.length;
286
287            &self.bytes[start..self.length]
288        }
289    }
290
291    /// Consumes whitespaces until a non-whitespace character is found.
292    ///
293    /// # Returns
294    ///
295    /// A byte slice containing the consumed whitespaces.
296    #[inline]
297    pub fn consume_whitespaces(&mut self) -> &'a [u8] {
298        let start = self.offset;
299        let bytes = self.bytes;
300        let len = self.length;
301        while self.offset < len && bytes[self.offset].is_ascii_whitespace() {
302            self.offset += 1;
303        }
304
305        &bytes[start..self.offset]
306    }
307
308    /// Reads the next `n` characters without advancing the position.
309    ///
310    /// # Arguments
311    ///
312    /// * `n` - The number of characters to read.
313    ///
314    /// # Returns
315    ///
316    /// A byte slice containing the next `n` characters.
317    #[inline]
318    pub fn read(&self, n: usize) -> &'a [u8] {
319        let (from, until) = self.calculate_bound(n);
320
321        &self.bytes[from..until]
322    }
323
324    /// Reads a single byte at a specific byte offset within the input slice,
325    /// without advancing the internal cursor.
326    ///
327    /// This provides direct, low-level access to the underlying byte data.
328    ///
329    /// # Arguments
330    ///
331    /// * `at` - The zero-based byte offset within the input slice (`self.bytes`)
332    ///   from which to read the byte.
333    ///
334    /// # Returns
335    ///
336    /// A reference to the byte located at the specified offset `at`.
337    ///
338    /// # Panics
339    ///
340    /// This method **panics** if the provided `at` offset is out of bounds
341    /// for the input byte slice (i.e., if `at >= self.bytes.len()`).
342    pub fn read_at(&self, at: usize) -> &'a u8 {
343        &self.bytes[at]
344    }
345
346    /// Checks if the input at the current position matches the given byte slice.
347    ///
348    /// # Arguments
349    ///
350    /// * `search` - The byte slice to compare against the input.
351    /// * `ignore_ascii_case` - Whether to ignore ASCII case when comparing.
352    ///
353    /// # Returns
354    ///
355    /// `true` if the next bytes match `search`; `false` otherwise.
356    #[inline]
357    pub fn is_at(&self, search: &[u8], ignore_ascii_case: bool) -> bool {
358        let (from, until) = self.calculate_bound(search.len());
359        let slice = &self.bytes[from..until];
360
361        if ignore_ascii_case { slice.eq_ignore_ascii_case(search) } else { slice == search }
362    }
363
364    /// Attempts to match the given byte sequence at the current position, ignoring whitespace in the input.
365    ///
366    /// This method tries to match the provided byte slice `search` against the input starting
367    /// from the current position, possibly ignoring ASCII case. Whitespace characters in the input
368    /// are skipped during matching, but their length is included in the returned length.
369    ///
370    /// Importantly, the method **does not include** any trailing whitespace **after** the matched sequence
371    /// in the returned length.
372    ///
373    /// For example, to match the sequence `(string)`, the input could be `(string)`, `( string )`, `(  string )`, etc.,
374    /// and this method would return the total length of the input consumed to match `(string)`,
375    /// including any whitespace within the matched sequence, but **excluding** any whitespace after it.
376    ///
377    /// # Arguments
378    ///
379    /// * `search` - The byte slice to match against the input.
380    /// * `ignore_ascii_case` - If `true`, ASCII case is ignored during comparison.
381    ///
382    /// # Returns
383    ///
384    /// * `Some(length)` - If the input matches `search` (ignoring whitespace within the sequence), returns the total length
385    ///   of the input consumed to match `search`, including any skipped whitespace **within** the matched sequence.
386    /// * `None` - If the input does not match `search`.
387    #[inline]
388    pub const fn match_sequence_ignore_whitespace(&self, search: &[u8], ignore_ascii_case: bool) -> Option<usize> {
389        let mut offset = self.offset;
390        let mut search_offset = 0;
391        let mut length = 0;
392        let bytes = self.bytes;
393        let total = self.length;
394        while search_offset < search.len() {
395            // Skip whitespace in the input.
396            while offset < total && bytes[offset].is_ascii_whitespace() {
397                offset += 1;
398                length += 1;
399            }
400
401            if offset >= total {
402                return None;
403            }
404
405            let input_byte = bytes[offset];
406            let search_byte = search[search_offset];
407            let matched = if ignore_ascii_case {
408                input_byte.eq_ignore_ascii_case(&search_byte)
409            } else {
410                input_byte == search_byte
411            };
412
413            if matched {
414                offset += 1;
415                length += 1;
416                search_offset += 1;
417            } else {
418                return None;
419            }
420        }
421
422        Some(length)
423    }
424
425    /// Peeks ahead `i` characters and reads the next `n` characters without advancing the position.
426    ///
427    /// # Arguments
428    ///
429    /// * `offset` - The number of characters to skip before reading.
430    /// * `n` - The number of characters to read after skipping.
431    ///
432    /// # Returns
433    ///
434    /// A byte slice containing the peeked characters.
435    #[inline]
436    pub fn peek(&self, offset: usize, n: usize) -> &'a [u8] {
437        let from = self.offset + offset;
438        if from >= self.length {
439            return &self.bytes[self.length..self.length];
440        }
441
442        let mut until = from + n;
443        if until >= self.length {
444            until = self.length;
445        }
446
447        &self.bytes[from..until]
448    }
449
450    /// Calculates the bounds for slicing the input safely.
451    ///
452    /// Ensures that slicing does not go beyond the input length.
453    ///
454    /// # Arguments
455    ///
456    /// * `n` - The number of characters to include in the slice.
457    ///
458    /// # Returns
459    ///
460    /// A tuple `(from, until)` representing the start and end indices for slicing.
461    #[inline]
462    const fn calculate_bound(&self, n: usize) -> (usize, usize) {
463        if self.has_reached_eof() {
464            return (self.length, self.length);
465        }
466
467        let mut until = self.offset + n;
468
469        if until >= self.length {
470            until = self.length;
471        }
472
473        (self.offset, until)
474    }
475}
476
477impl HasFileId for Input<'_> {
478    fn file_id(&self) -> FileId {
479        self.file_id
480    }
481}
482
#[cfg(test)]
mod tests {
    use mago_span::Position;

    use super::*;

    #[test]
    fn test_new() {
        let bytes = b"Hello, world!";
        let input = Input::new(FileId::zero(), bytes);

        assert_eq!(input.current_position(), Position::new(0));
        assert_eq!(input.length, bytes.len());
        assert_eq!(input.bytes, bytes);
    }

    #[test]
    fn test_is_eof() {
        let bytes = b"";
        let input = Input::new(FileId::zero(), bytes);

        assert!(input.has_reached_eof());

        let bytes = b"data";
        let mut input = Input::new(FileId::zero(), bytes);

        assert!(!input.has_reached_eof());

        input.skip(4);

        assert!(input.has_reached_eof());
    }

    #[test]
    fn test_next() {
        // `next` moves exactly one byte per call; line endings get no special
        // treatment, so "\r\n" takes two calls to pass.
        let bytes = b"a\nb\r\nc\rd";
        let mut input = Input::new(FileId::zero(), bytes);

        // past 'a'
        input.next();
        assert_eq!(input.current_position(), Position::new(1));

        // past '\n'
        input.next();
        assert_eq!(input.current_position(), Position::new(2));

        // past 'b'
        input.next();
        assert_eq!(input.current_position(), Position::new(3));

        // past '\r' (first byte of "\r\n")
        input.next();
        assert_eq!(input.current_position(), Position::new(4));

        // past '\n' (second byte of "\r\n")
        input.next();
        assert_eq!(input.current_position(), Position::new(5));

        // past 'c'
        input.next();
        assert_eq!(input.current_position(), Position::new(6));

        // past '\r'
        input.next();
        assert_eq!(input.current_position(), Position::new(7));

        // past 'd', which is the last byte
        input.next();
        assert_eq!(input.current_position(), Position::new(8));
        assert!(input.has_reached_eof());

        // At EOF, `next` leaves the cursor where it is.
        input.next();
        assert_eq!(input.current_position(), Position::new(8));
    }

    #[test]
    fn test_consume() {
        let bytes = b"abcdef";
        let mut input = Input::new(FileId::zero(), bytes);

        let consumed = input.consume(3);
        assert_eq!(consumed, b"abc");
        assert_eq!(input.current_position(), Position::new(3));

        let consumed = input.consume(3);
        assert_eq!(consumed, b"def");
        assert_eq!(input.current_position(), Position::new(6));

        let consumed = input.consume(1); // Should return empty slice at EOF
        assert_eq!(consumed, b"");
        assert!(input.has_reached_eof());
    }

    #[test]
    fn test_consume_remaining() {
        let bytes = b"abcdef";
        let mut input = Input::new(FileId::zero(), bytes);

        input.skip(2);
        let remaining = input.consume_remaining();
        assert_eq!(remaining, b"cdef");
        assert!(input.has_reached_eof());
    }

    #[test]
    fn test_read() {
        let bytes = b"abcdef";
        let input = Input::new(FileId::zero(), bytes);

        let read = input.read(3);
        assert_eq!(read, b"abc");
        // Reading must not move the cursor.
        assert_eq!(input.current_position(), Position::new(0));
    }

    #[test]
    fn test_is_at() {
        let bytes = b"abcdef";
        let mut input = Input::new(FileId::zero(), bytes);

        assert!(input.is_at(b"abc", false));
        input.skip(2);
        assert!(input.is_at(b"cde", false));
        assert!(!input.is_at(b"xyz", false));
    }

    #[test]
    fn test_is_at_ignore_ascii_case() {
        let bytes = b"AbCdEf";
        let mut input = Input::new(FileId::zero(), bytes);

        assert!(input.is_at(b"abc", true));
        input.skip(2);
        assert!(input.is_at(b"cde", true));
        assert!(!input.is_at(b"xyz", true));
    }

    #[test]
    fn test_peek() {
        let bytes = b"abcdef";
        let input = Input::new(FileId::zero(), bytes);

        let peeked = input.peek(2, 3);
        assert_eq!(peeked, b"cde");
        // Peeking must not move the cursor.
        assert_eq!(input.current_position(), Position::new(0));
    }

    #[test]
    fn test_to_bound() {
        let bytes = b"abcdef";
        let input = Input::new(FileId::zero(), bytes);

        let (from, until) = input.calculate_bound(3);
        assert_eq!((from, until), (0, 3));

        let (from, until) = input.calculate_bound(10); // Exceeds length
        assert_eq!((from, until), (0, 6));
    }
}