mago_syntax_core/
input.rs

1use memchr::memchr;
2use memchr::memmem::find;
3
4use mago_source::SourceIdentifier;
5use mago_span::Position;
6
/// A cursor over the raw bytes of the source code being lexed.
///
/// The `Input` struct provides methods to read, peek, consume, and skip bytes
/// while tracking a single byte `offset` into the slice. It does not store line
/// or column information itself; absolute positions are derived on demand from
/// `starting_position` (see `current_position`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub struct Input<'a> {
    /// The byte slice being lexed.
    pub(crate) bytes: &'a [u8],
    /// Length of `bytes` in bytes, cached by the constructors.
    pub(crate) length: usize,
    /// Cursor into `bytes`: the number of bytes consumed so far, relative to
    /// the start of the slice (not to the original source file).
    pub(crate) offset: usize,
    /// Position of the first byte of `bytes`; the base from which absolute
    /// positions are computed.
    pub(crate) starting_position: Position,
}
18
19impl<'a> Input<'a> {
20    /// Creates a new `Input` instance from the given input.
21    ///
22    /// # Arguments
23    ///
24    /// * `input` - A byte slice representing the input code to be processed.
25    ///
26    /// # Returns
27    ///
28    /// A new `Input` instance initialized at the beginning of the input.
29    pub fn new(source: SourceIdentifier, bytes: &'a [u8]) -> Self {
30        let length = bytes.len();
31
32        Self { bytes, length, offset: 0, starting_position: Position::start_of(source) }
33    }
34
35    /// Creates a new `Input` instance representing a byte slice that is
36    /// "anchored" at a specific absolute position within a larger source file.
37    ///
38    /// This is useful when lexing a subset (slice) of a source file, as it allows
39    /// generated tokens to retain accurate absolute positions and spans relative
40    /// to the original file.
41    ///
42    /// The internal cursor (`offset`) starts at 0 relative to the `bytes` slice,
43    /// but the absolute position is calculated relative to the `anchor_position`.
44    ///
45    /// # Arguments
46    ///
47    /// * `bytes` - A byte slice representing the input code subset to be lexed.
48    /// * `anchor_position` - The absolute `Position` in the original source file where
49    ///   the provided `bytes` slice begins.
50    ///
51    /// # Returns
52    ///
53    /// A new `Input` instance ready to lex the `bytes`, maintaining positions
54    /// relative to `anchor_position`.
55    pub fn anchored_at(bytes: &'a [u8], anchor_position: Position) -> Self {
56        let length = bytes.len();
57
58        Self { bytes, length, offset: 0, starting_position: anchor_position }
59    }
60
61    /// Returns the source identifier of the input code.
62    #[inline]
63    pub const fn source_identifier(&self) -> SourceIdentifier {
64        self.starting_position.source
65    }
66
67    /// Returns the absolute current `Position` of the lexer within the original source file.
68    ///
69    /// It calculates this by adding the internal offset (progress within the current byte slice)
70    /// to the `starting_position` the `Input` was initialized with.
71    #[inline]
72    pub const fn current_position(&self) -> Position {
73        // Calculate absolute position by adding internal offset to the starting base
74        self.starting_position.forward(self.offset)
75    }
76
77    /// Returns the current internal byte offset relative to the start of the input slice.
78    ///
79    /// This indicates how many bytes have been consumed from the current `bytes` slice.
80    /// To get the absolute position in the original source file, use `current_position()`.
81    #[inline]
82    pub const fn current_offset(&self) -> usize {
83        self.offset
84    }
85
86    /// Returns `true` if the input slice is empty (length is zero).
87    #[inline]
88    pub const fn is_empty(&self) -> bool {
89        self.length == 0
90    }
91
92    /// Returns the total length in bytes of the input slice being processed.
93    #[inline]
94    pub const fn len(&self) -> usize {
95        self.length
96    }
97
98    /// Checks if the current position is at the end of the input.
99    ///
100    /// # Returns
101    ///
102    /// `true` if the current offset is greater than or equal to the input length; `false` otherwise.
103    #[inline]
104    pub const fn has_reached_eof(&self) -> bool {
105        self.offset >= self.length
106    }
107
108    /// Returns a byte slice within a specified absolute range.
109    ///
110    /// The `from` and `to` arguments are absolute byte offsets from the beginning
111    /// of the original source file. The method calculates the correct slice
112    /// relative to the `starting_position` of this `Input`.
113    ///
114    /// This is useful for retrieving the raw text of a `Span` or `Token` whose
115    /// positions are absolute, even when the `Input` only contains a subsection
116    /// of the source file.
117    ///
118    /// The returned slice is defensively clamped to the bounds of the current
119    /// `Input`'s byte slice to prevent panics.
120    ///
121    /// # Arguments
122    ///
123    /// * `from` - The absolute starting byte offset.
124    /// * `to` - The absolute ending byte offset (exclusive).
125    ///
126    /// # Returns
127    ///
128    /// A byte slice `&[u8]` corresponding to the requested range.
129    #[inline]
130    pub fn slice_in_range(&self, from: usize, to: usize) -> &'a [u8] {
131        let base_offset = self.starting_position.offset;
132
133        // Calculate the start and end positions relative to the local `bytes` slice.
134        // `saturating_sub` prevents underflow if `from`/`to` are smaller than `base_offset`.
135        let local_from = from.saturating_sub(base_offset);
136        let local_to = to.saturating_sub(base_offset);
137
138        // Clamp the local indices to the actual length of the `bytes` slice to prevent panics.
139        let start = local_from.min(self.length);
140        let end = local_to.min(self.length);
141
142        // Ensure the start index is not greater than the end index.
143        if start >= end {
144            return &[];
145        }
146
147        // If the start index is beyond the length of the input, return an empty slice.
148        if start >= self.length {
149            return &[];
150        }
151
152        &self.bytes[start..end]
153    }
154
155    /// Advances the current position by one character, updating line and column numbers.
156    ///
157    /// Handles different line endings (`\n`, `\r`, `\r\n`) and updates line and column counters accordingly.
158    ///
159    /// If the end of input is reached, no action is taken.
160    #[inline]
161    pub fn next(&mut self) {
162        if !self.has_reached_eof() {
163            self.offset += 1;
164        }
165    }
166
167    /// Skips the next `count` characters, advancing the position accordingly.
168    ///
169    /// Updates line and column numbers as it advances.
170    ///
171    /// # Arguments
172    ///
173    /// * `count` - The number of characters to skip.
174    #[inline]
175    pub fn skip(&mut self, count: usize) {
176        self.offset = (self.offset + count).min(self.length);
177    }
178
179    /// Consumes the next `count` characters and returns them as a slice.
180    ///
181    /// Advances the position by `count` characters.
182    ///
183    /// # Arguments
184    ///
185    /// * `count` - The number of characters to consume.
186    ///
187    /// # Returns
188    ///
189    /// A byte slice containing the consumed characters.
190    #[inline]
191    pub fn consume(&mut self, count: usize) -> &'a [u8] {
192        let (from, until) = self.calculate_bound(count);
193
194        self.skip(count);
195
196        &self.bytes[from..until]
197    }
198
199    /// Consumes all remaining characters from the current position to the end of input.
200    ///
201    /// Advances the position to EOF.
202    ///
203    /// # Returns
204    ///
205    /// A byte slice containing the remaining characters.
206    #[inline]
207    pub fn consume_remaining(&mut self) -> &'a [u8] {
208        if self.has_reached_eof() {
209            return &[];
210        }
211
212        let from = self.offset;
213        self.offset = self.length;
214
215        &self.bytes[from..]
216    }
217
218    /// Consumes characters until the given byte slice is found.
219    ///
220    /// Advances the position to the start of the search slice if found,
221    /// or to EOF if not found.
222    ///
223    /// # Arguments
224    ///
225    /// * `search` - The byte slice to search for.
226    /// * `ignore_ascii_case` - Whether to ignore ASCII case when comparing characters.
227    ///
228    /// # Returns
229    ///
230    /// A byte slice containing the consumed characters.
231    #[inline]
232    pub fn consume_until(&mut self, search: &[u8], ignore_ascii_case: bool) -> &'a [u8] {
233        let start = self.offset;
234        if !ignore_ascii_case {
235            // For a single-byte search, use memchr.
236            if search.len() == 1 {
237                if let Some(pos) = memchr(search[0], &self.bytes[self.offset..]) {
238                    self.offset += pos;
239                    &self.bytes[start..self.offset]
240                } else {
241                    self.offset = self.length;
242                    &self.bytes[start..self.length]
243                }
244            } else if let Some(pos) = find(&self.bytes[self.offset..], search) {
245                self.offset += pos;
246                &self.bytes[start..self.offset]
247            } else {
248                self.offset = self.length;
249                &self.bytes[start..self.length]
250            }
251        } else {
252            while !self.has_reached_eof() && !self.is_at(search, ignore_ascii_case) {
253                self.offset += 1;
254            }
255
256            &self.bytes[start..self.offset]
257        }
258    }
259
260    #[inline]
261    pub fn consume_through(&mut self, search: u8) -> &'a [u8] {
262        let start = self.offset;
263        if let Some(pos) = memchr::memchr(search, &self.bytes[self.offset..]) {
264            self.offset += pos + 1;
265
266            &self.bytes[start..self.offset]
267        } else {
268            self.offset = self.length;
269
270            &self.bytes[start..self.length]
271        }
272    }
273
274    /// Consumes whitespaces until a non-whitespace character is found.
275    ///
276    /// # Returns
277    ///
278    /// A byte slice containing the consumed whitespaces.
279    #[inline]
280    pub fn consume_whitespaces(&mut self) -> &'a [u8] {
281        let start = self.offset;
282        let bytes = self.bytes;
283        let len = self.length;
284        while self.offset < len && bytes[self.offset].is_ascii_whitespace() {
285            self.offset += 1;
286        }
287
288        &bytes[start..self.offset]
289    }
290
291    /// Reads the next `n` characters without advancing the position.
292    ///
293    /// # Arguments
294    ///
295    /// * `n` - The number of characters to read.
296    ///
297    /// # Returns
298    ///
299    /// A byte slice containing the next `n` characters.
300    #[inline]
301    pub fn read(&self, n: usize) -> &'a [u8] {
302        let (from, until) = self.calculate_bound(n);
303
304        &self.bytes[from..until]
305    }
306
307    /// Reads a single byte at a specific byte offset within the input slice,
308    /// without advancing the internal cursor.
309    ///
310    /// This provides direct, low-level access to the underlying byte data.
311    ///
312    /// # Arguments
313    ///
314    /// * `at` - The zero-based byte offset within the input slice (`self.bytes`)
315    ///   from which to read the byte.
316    ///
317    /// # Returns
318    ///
319    /// A reference to the byte located at the specified offset `at`.
320    ///
321    /// # Panics
322    ///
323    /// This method **panics** if the provided `at` offset is out of bounds
324    /// for the input byte slice (i.e., if `at >= self.bytes.len()`).
325    pub fn read_at(&self, at: usize) -> &'a u8 {
326        &self.bytes[at]
327    }
328
329    /// Checks if the input at the current position matches the given byte slice.
330    ///
331    /// # Arguments
332    ///
333    /// * `search` - The byte slice to compare against the input.
334    /// * `ignore_ascii_case` - Whether to ignore ASCII case when comparing.
335    ///
336    /// # Returns
337    ///
338    /// `true` if the next bytes match `search`; `false` otherwise.
339    #[inline]
340    pub fn is_at(&self, search: &[u8], ignore_ascii_case: bool) -> bool {
341        let (from, until) = self.calculate_bound(search.len());
342        let slice = &self.bytes[from..until];
343
344        if ignore_ascii_case { slice.eq_ignore_ascii_case(search) } else { slice == search }
345    }
346
347    /// Attempts to match the given byte sequence at the current position, ignoring whitespace in the input.
348    ///
349    /// This method tries to match the provided byte slice `search` against the input starting
350    /// from the current position, possibly ignoring ASCII case. Whitespace characters in the input
351    /// are skipped during matching, but their length is included in the returned length.
352    ///
353    /// Importantly, the method **does not include** any trailing whitespace **after** the matched sequence
354    /// in the returned length.
355    ///
356    /// For example, to match the sequence `(string)`, the input could be `(string)`, `( string )`, `(  string )`, etc.,
357    /// and this method would return the total length of the input consumed to match `(string)`,
358    /// including any whitespace within the matched sequence, but **excluding** any whitespace after it.
359    ///
360    /// # Arguments
361    ///
362    /// * `search` - The byte slice to match against the input.
363    /// * `ignore_ascii_case` - If `true`, ASCII case is ignored during comparison.
364    ///
365    /// # Returns
366    ///
367    /// * `Some(length)` - If the input matches `search` (ignoring whitespace within the sequence), returns the total length
368    ///   of the input consumed to match `search`, including any skipped whitespace **within** the matched sequence.
369    /// * `None` - If the input does not match `search`.
370    ///
371    /// # Examples
372    ///
373    /// ```rust
374    /// use mago_syntax_core::input::Input;
375    /// use mago_source::SourceIdentifier;
376    ///
377    /// let source = SourceIdentifier::dummy();
378    ///
379    /// // Given input "( string ) x", starting at offset 0:
380    /// let input = Input::new(source.clone(), b"( string ) x");
381    /// assert_eq!(input.match_sequence_ignore_whitespace(b"(string)", true), Some(10)); // 10 bytes consumed up to ')'
382    ///
383    /// // Given input "(int)", with no whitespace:
384    /// let input = Input::new(source.clone(), b"(int)");
385    /// assert_eq!(input.match_sequence_ignore_whitespace(b"(int)", true), Some(5)); // 5 bytes consumed
386    ///
387    /// // Given input "(  InT   )abc", ignoring ASCII case:
388    /// let input = Input::new(source.clone(), b"(  InT   )abc");
389    /// assert_eq!(input.match_sequence_ignore_whitespace(b"(int)", true), Some(10)); // 10 bytes consumed up to ')'
390    ///
391    /// // Given input "(integer)", attempting to match "(int)":
392    /// let input = Input::new(source.clone(), b"(integer)");
393    /// assert_eq!(input.match_sequence_ignore_whitespace(b"(int)", false), None); // Does not match
394    ///
395    /// // Trailing whitespace after ')':
396    /// let input = Input::new(source.clone(), b"(int)   x");
397    /// assert_eq!(input.match_sequence_ignore_whitespace(b"(int)", true), Some(5)); // Length up to ')', excludes spaces after ')'
398    /// ```
399    #[inline]
400    pub const fn match_sequence_ignore_whitespace(&self, search: &[u8], ignore_ascii_case: bool) -> Option<usize> {
401        let mut offset = self.offset;
402        let mut search_offset = 0;
403        let mut length = 0;
404        let bytes = self.bytes;
405        let total = self.length;
406        while search_offset < search.len() {
407            // Skip whitespace in the input.
408            while offset < total && bytes[offset].is_ascii_whitespace() {
409                offset += 1;
410                length += 1;
411            }
412
413            if offset >= total {
414                return None;
415            }
416
417            let input_byte = bytes[offset];
418            let search_byte = search[search_offset];
419            let matched = if ignore_ascii_case {
420                input_byte.eq_ignore_ascii_case(&search_byte)
421            } else {
422                input_byte == search_byte
423            };
424
425            if matched {
426                offset += 1;
427                length += 1;
428                search_offset += 1;
429            } else {
430                return None;
431            }
432        }
433
434        Some(length)
435    }
436
437    /// Peeks ahead `i` characters and reads the next `n` characters without advancing the position.
438    ///
439    /// # Arguments
440    ///
441    /// * `offset` - The number of characters to skip before reading.
442    /// * `n` - The number of characters to read after skipping.
443    ///
444    /// # Returns
445    ///
446    /// A byte slice containing the peeked characters.
447    #[inline]
448    pub fn peek(&self, offset: usize, n: usize) -> &'a [u8] {
449        let from = self.offset + offset;
450        if from >= self.length {
451            return &self.bytes[self.length..self.length];
452        }
453
454        let mut until = from + n;
455        if until >= self.length {
456            until = self.length;
457        }
458
459        &self.bytes[from..until]
460    }
461
462    /// Calculates the bounds for slicing the input safely.
463    ///
464    /// Ensures that slicing does not go beyond the input length.
465    ///
466    /// # Arguments
467    ///
468    /// * `n` - The number of characters to include in the slice.
469    ///
470    /// # Returns
471    ///
472    /// A tuple `(from, until)` representing the start and end indices for slicing.
473    #[inline]
474    const fn calculate_bound(&self, n: usize) -> (usize, usize) {
475        if self.has_reached_eof() {
476            return (self.length, self.length);
477        }
478
479        let mut until = self.offset + n;
480
481        if until >= self.length {
482            until = self.length;
483        }
484
485        (self.offset, until)
486    }
487}
488
#[cfg(test)]
mod tests {
    use mago_span::Position;

    use super::*;

    #[test]
    fn test_new() {
        let bytes = b"Hello, world!";
        let input = Input::new(SourceIdentifier::dummy(), bytes);

        assert_eq!(input.current_position(), Position::new(SourceIdentifier::dummy(), 0));
        assert_eq!(input.length, bytes.len());
        assert_eq!(input.bytes, bytes);
    }

    #[test]
    fn test_is_eof() {
        let bytes = b"";
        let input = Input::new(SourceIdentifier::dummy(), bytes);

        assert!(input.has_reached_eof());

        let bytes = b"data";
        let mut input = Input::new(SourceIdentifier::dummy(), bytes);

        assert!(!input.has_reached_eof());

        input.skip(4);

        assert!(input.has_reached_eof());
    }

    #[test]
    fn test_next() {
        // `next` advances one byte at a time; line endings get no special
        // treatment, so the `\r\n` pair below takes two calls to step over.
        let bytes = b"a\nb\r\nc\rd";
        let mut input = Input::new(SourceIdentifier::dummy(), bytes);

        // 'a'
        input.next();
        assert_eq!(input.current_position(), Position::new(SourceIdentifier::dummy(), 1));

        // '\n'
        input.next();
        assert_eq!(input.current_position(), Position::new(SourceIdentifier::dummy(), 2));

        // 'b'
        input.next();
        assert_eq!(input.current_position(), Position::new(SourceIdentifier::dummy(), 3));

        // '\r' (first half of "\r\n")
        input.next();
        assert_eq!(input.current_position(), Position::new(SourceIdentifier::dummy(), 4));

        // '\n' (second half of "\r\n")
        input.next();
        assert_eq!(input.current_position(), Position::new(SourceIdentifier::dummy(), 5));

        // 'c'
        input.next();
        assert_eq!(input.current_position(), Position::new(SourceIdentifier::dummy(), 6));

        // '\r'
        input.next();
        assert_eq!(input.current_position(), Position::new(SourceIdentifier::dummy(), 7));

        // 'd' is the last byte; consuming it reaches EOF.
        input.next();
        assert_eq!(input.current_position(), Position::new(SourceIdentifier::dummy(), 8));
        assert!(input.has_reached_eof());

        // Calling `next` at EOF is a no-op.
        input.next();
        assert_eq!(input.current_position(), Position::new(SourceIdentifier::dummy(), 8));
    }

    #[test]
    fn test_consume() {
        let bytes = b"abcdef";
        let mut input = Input::new(SourceIdentifier::dummy(), bytes);

        let consumed = input.consume(3);
        assert_eq!(consumed, b"abc");
        assert_eq!(input.current_position(), Position::new(SourceIdentifier::dummy(), 3));

        let consumed = input.consume(3);
        assert_eq!(consumed, b"def");
        assert_eq!(input.current_position(), Position::new(SourceIdentifier::dummy(), 6));

        let consumed = input.consume(1); // Should return empty slice at EOF
        assert_eq!(consumed, b"");
        assert!(input.has_reached_eof());
    }

    #[test]
    fn test_consume_remaining() {
        let bytes = b"abcdef";
        let mut input = Input::new(SourceIdentifier::dummy(), bytes);

        input.skip(2);
        let remaining = input.consume_remaining();
        assert_eq!(remaining, b"cdef");
        assert!(input.has_reached_eof());
    }

    #[test]
    fn test_consume_until() {
        let bytes = b"abcdef";

        // Multi-byte needle, case-sensitive.
        let mut input = Input::new(SourceIdentifier::dummy(), bytes);
        assert_eq!(input.consume_until(b"cd", false), b"ab");
        assert_eq!(input.current_offset(), 2);

        // Single-byte needle, case-sensitive.
        let mut input = Input::new(SourceIdentifier::dummy(), bytes);
        assert_eq!(input.consume_until(b"e", false), b"abcd");
        assert_eq!(input.current_offset(), 4);

        // Case-insensitive search.
        let mut input = Input::new(SourceIdentifier::dummy(), bytes);
        assert_eq!(input.consume_until(b"CD", true), b"ab");
        assert_eq!(input.current_offset(), 2);

        // A needle that is not present consumes everything.
        let mut input = Input::new(SourceIdentifier::dummy(), bytes);
        assert_eq!(input.consume_until(b"xyz", false), b"abcdef");
        assert!(input.has_reached_eof());
    }

    #[test]
    fn test_consume_through() {
        let bytes = b"abcdef";

        let mut input = Input::new(SourceIdentifier::dummy(), bytes);
        assert_eq!(input.consume_through(b'c'), b"abc");
        assert_eq!(input.current_offset(), 3);

        // A byte that is not present consumes everything.
        let mut input = Input::new(SourceIdentifier::dummy(), bytes);
        assert_eq!(input.consume_through(b'z'), b"abcdef");
        assert!(input.has_reached_eof());
    }

    #[test]
    fn test_consume_whitespaces() {
        let bytes = b" \t\nabc";
        let mut input = Input::new(SourceIdentifier::dummy(), bytes);

        assert_eq!(input.consume_whitespaces(), b" \t\n");
        assert_eq!(input.current_offset(), 3);

        // No whitespace at the cursor: nothing is consumed.
        assert_eq!(input.consume_whitespaces(), b"");
        assert_eq!(input.current_offset(), 3);
    }

    #[test]
    fn test_read() {
        let bytes = b"abcdef";
        let input = Input::new(SourceIdentifier::dummy(), bytes);

        let read = input.read(3);
        assert_eq!(read, b"abc");
        assert_eq!(input.current_position(), Position::new(SourceIdentifier::dummy(), 0));
        // Position should not change
    }

    #[test]
    fn test_is_at() {
        let bytes = b"abcdef";
        let mut input = Input::new(SourceIdentifier::dummy(), bytes);

        assert!(input.is_at(b"abc", false));
        input.skip(2);
        assert!(input.is_at(b"cde", false));
        assert!(!input.is_at(b"xyz", false));
    }

    #[test]
    fn test_is_at_ignore_ascii_case() {
        let bytes = b"AbCdEf";
        let mut input = Input::new(SourceIdentifier::dummy(), bytes);

        assert!(input.is_at(b"abc", true));
        input.skip(2);
        assert!(input.is_at(b"cde", true));
        assert!(!input.is_at(b"xyz", true));
    }

    #[test]
    fn test_match_sequence_ignore_whitespace() {
        let input = Input::new(SourceIdentifier::dummy(), b"(  InT   )abc");
        assert_eq!(input.match_sequence_ignore_whitespace(b"(int)", true), Some(10));
        assert_eq!(input.match_sequence_ignore_whitespace(b"(int)", false), None);

        // The input ends before the sequence is complete.
        let input = Input::new(SourceIdentifier::dummy(), b"(int");
        assert_eq!(input.match_sequence_ignore_whitespace(b"(int)", true), None);
    }

    #[test]
    fn test_peek() {
        let bytes = b"abcdef";
        let input = Input::new(SourceIdentifier::dummy(), bytes);

        let peeked = input.peek(2, 3);
        assert_eq!(peeked, b"cde");
        assert_eq!(input.current_position(), Position::new(SourceIdentifier::dummy(), 0));
        // Position should not change

        // The window is clamped to the end of the input.
        assert_eq!(input.peek(4, 10), b"ef");
        assert_eq!(input.peek(10, 2), b"");
    }

    #[test]
    fn test_slice_in_range() {
        // The slice holds bytes 2..6 of an imaginary source "abcdef".
        let input = Input::anchored_at(b"cdef", Position::new(SourceIdentifier::dummy(), 2));

        assert_eq!(input.slice_in_range(3, 5), b"de");
        // Offsets before the anchor are clamped to the start of the slice.
        assert_eq!(input.slice_in_range(0, 3), b"c");
        // Offsets past the end are clamped to the end of the slice.
        assert_eq!(input.slice_in_range(5, 100), b"f");
        // An inverted range yields an empty slice.
        assert_eq!(input.slice_in_range(4, 3), b"");
    }

    #[test]
    fn test_to_bound() {
        let bytes = b"abcdef";
        let input = Input::new(SourceIdentifier::dummy(), bytes);

        let (from, until) = input.calculate_bound(3);
        assert_eq!((from, until), (0, 3));

        let (from, until) = input.calculate_bound(10); // Exceeds length
        assert_eq!((from, until), (0, 6));
    }
}