mago_syntax_core/input.rs
1use memchr::memchr;
2use memchr::memmem::find;
3
4use mago_source::SourceIdentifier;
5use mago_span::Position;
6
/// A struct representing the input code being lexed.
///
/// The `Input` struct provides methods to read, peek, consume, and skip bytes
/// from the input code. It tracks a single byte offset into the slice; absolute
/// positions are derived on demand from `starting_position` plus that offset.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub struct Input<'a> {
    /// The bytes being lexed.
    pub(crate) bytes: &'a [u8],
    /// Cached length of `bytes` (set to `bytes.len()` by the constructors).
    pub(crate) length: usize,
    /// Cursor into `bytes`, relative to the start of the slice; starts at 0.
    pub(crate) offset: usize,
    /// Absolute position of the first byte of `bytes`.
    pub(crate) starting_position: Position,
}
18
19impl<'a> Input<'a> {
20 /// Creates a new `Input` instance from the given input.
21 ///
22 /// # Arguments
23 ///
24 /// * `input` - A byte slice representing the input code to be processed.
25 ///
26 /// # Returns
27 ///
28 /// A new `Input` instance initialized at the beginning of the input.
29 pub fn new(source: SourceIdentifier, bytes: &'a [u8]) -> Self {
30 let length = bytes.len();
31
32 Self { bytes, length, offset: 0, starting_position: Position::start_of(source) }
33 }
34
35 /// Creates a new `Input` instance representing a byte slice that is
36 /// "anchored" at a specific absolute position within a larger source file.
37 ///
38 /// This is useful when lexing a subset (slice) of a source file, as it allows
39 /// generated tokens to retain accurate absolute positions and spans relative
40 /// to the original file.
41 ///
42 /// The internal cursor (`offset`) starts at 0 relative to the `bytes` slice,
43 /// but the absolute position is calculated relative to the `anchor_position`.
44 ///
45 /// # Arguments
46 ///
47 /// * `bytes` - A byte slice representing the input code subset to be lexed.
48 /// * `anchor_position` - The absolute `Position` in the original source file where
49 /// the provided `bytes` slice begins.
50 ///
51 /// # Returns
52 ///
53 /// A new `Input` instance ready to lex the `bytes`, maintaining positions
54 /// relative to `anchor_position`.
55 pub fn anchored_at(bytes: &'a [u8], anchor_position: Position) -> Self {
56 let length = bytes.len();
57
58 Self { bytes, length, offset: 0, starting_position: anchor_position }
59 }
60
    /// Returns the source identifier of the input code.
    ///
    /// The value is read from the `starting_position` this `Input` was created with.
    #[inline]
    pub const fn source_identifier(&self) -> SourceIdentifier {
        self.starting_position.source
    }
66
67 /// Returns the absolute current `Position` of the lexer within the original source file.
68 ///
69 /// It calculates this by adding the internal offset (progress within the current byte slice)
70 /// to the `starting_position` the `Input` was initialized with.
71 #[inline]
72 pub const fn current_position(&self) -> Position {
73 // Calculate absolute position by adding internal offset to the starting base
74 self.starting_position.forward(self.offset)
75 }
76
    /// Returns the current internal byte offset relative to the start of the input slice.
    ///
    /// This is the cursor into the current `bytes` slice (0 for a freshly created `Input`).
    /// To get the absolute position in the original source file, use `current_position()`.
    #[inline]
    pub const fn current_offset(&self) -> usize {
        self.offset
    }
85
86 /// Returns `true` if the input slice is empty (length is zero).
87 #[inline]
88 pub const fn is_empty(&self) -> bool {
89 self.length == 0
90 }
91
    /// Returns the total length in bytes of the input slice being processed.
    ///
    /// The value does not change as the cursor advances.
    #[inline]
    pub const fn len(&self) -> usize {
        self.length
    }
97
98 /// Checks if the current position is at the end of the input.
99 ///
100 /// # Returns
101 ///
102 /// `true` if the current offset is greater than or equal to the input length; `false` otherwise.
103 #[inline]
104 pub const fn has_reached_eof(&self) -> bool {
105 self.offset >= self.length
106 }
107
108 /// Returns a byte slice within a specified absolute range.
109 ///
110 /// The `from` and `to` arguments are absolute byte offsets from the beginning
111 /// of the original source file. The method calculates the correct slice
112 /// relative to the `starting_position` of this `Input`.
113 ///
114 /// This is useful for retrieving the raw text of a `Span` or `Token` whose
115 /// positions are absolute, even when the `Input` only contains a subsection
116 /// of the source file.
117 ///
118 /// The returned slice is defensively clamped to the bounds of the current
119 /// `Input`'s byte slice to prevent panics.
120 ///
121 /// # Arguments
122 ///
123 /// * `from` - The absolute starting byte offset.
124 /// * `to` - The absolute ending byte offset (exclusive).
125 ///
126 /// # Returns
127 ///
128 /// A byte slice `&[u8]` corresponding to the requested range.
129 #[inline]
130 pub fn slice_in_range(&self, from: usize, to: usize) -> &'a [u8] {
131 let base_offset = self.starting_position.offset;
132
133 // Calculate the start and end positions relative to the local `bytes` slice.
134 // `saturating_sub` prevents underflow if `from`/`to` are smaller than `base_offset`.
135 let local_from = from.saturating_sub(base_offset);
136 let local_to = to.saturating_sub(base_offset);
137
138 // Clamp the local indices to the actual length of the `bytes` slice to prevent panics.
139 let start = local_from.min(self.length);
140 let end = local_to.min(self.length);
141
142 // Ensure the start index is not greater than the end index.
143 if start >= end {
144 return &[];
145 }
146
147 // If the start index is beyond the length of the input, return an empty slice.
148 if start >= self.length {
149 return &[];
150 }
151
152 &self.bytes[start..end]
153 }
154
155 /// Advances the current position by one character, updating line and column numbers.
156 ///
157 /// Handles different line endings (`\n`, `\r`, `\r\n`) and updates line and column counters accordingly.
158 ///
159 /// If the end of input is reached, no action is taken.
160 #[inline]
161 pub fn next(&mut self) {
162 if !self.has_reached_eof() {
163 self.offset += 1;
164 }
165 }
166
167 /// Skips the next `count` characters, advancing the position accordingly.
168 ///
169 /// Updates line and column numbers as it advances.
170 ///
171 /// # Arguments
172 ///
173 /// * `count` - The number of characters to skip.
174 #[inline]
175 pub fn skip(&mut self, count: usize) {
176 self.offset = (self.offset + count).min(self.length);
177 }
178
179 /// Consumes the next `count` characters and returns them as a slice.
180 ///
181 /// Advances the position by `count` characters.
182 ///
183 /// # Arguments
184 ///
185 /// * `count` - The number of characters to consume.
186 ///
187 /// # Returns
188 ///
189 /// A byte slice containing the consumed characters.
190 #[inline]
191 pub fn consume(&mut self, count: usize) -> &'a [u8] {
192 let (from, until) = self.calculate_bound(count);
193
194 self.skip(count);
195
196 &self.bytes[from..until]
197 }
198
199 /// Consumes all remaining characters from the current position to the end of input.
200 ///
201 /// Advances the position to EOF.
202 ///
203 /// # Returns
204 ///
205 /// A byte slice containing the remaining characters.
206 #[inline]
207 pub fn consume_remaining(&mut self) -> &'a [u8] {
208 if self.has_reached_eof() {
209 return &[];
210 }
211
212 let from = self.offset;
213 self.offset = self.length;
214
215 &self.bytes[from..]
216 }
217
218 /// Consumes characters until the given byte slice is found.
219 ///
220 /// Advances the position to the start of the search slice if found,
221 /// or to EOF if not found.
222 ///
223 /// # Arguments
224 ///
225 /// * `search` - The byte slice to search for.
226 /// * `ignore_ascii_case` - Whether to ignore ASCII case when comparing characters.
227 ///
228 /// # Returns
229 ///
230 /// A byte slice containing the consumed characters.
231 #[inline]
232 pub fn consume_until(&mut self, search: &[u8], ignore_ascii_case: bool) -> &'a [u8] {
233 let start = self.offset;
234 if !ignore_ascii_case {
235 // For a single-byte search, use memchr.
236 if search.len() == 1 {
237 if let Some(pos) = memchr(search[0], &self.bytes[self.offset..]) {
238 self.offset += pos;
239 &self.bytes[start..self.offset]
240 } else {
241 self.offset = self.length;
242 &self.bytes[start..self.length]
243 }
244 } else if let Some(pos) = find(&self.bytes[self.offset..], search) {
245 self.offset += pos;
246 &self.bytes[start..self.offset]
247 } else {
248 self.offset = self.length;
249 &self.bytes[start..self.length]
250 }
251 } else {
252 while !self.has_reached_eof() && !self.is_at(search, ignore_ascii_case) {
253 self.offset += 1;
254 }
255
256 &self.bytes[start..self.offset]
257 }
258 }
259
260 #[inline]
261 pub fn consume_through(&mut self, search: u8) -> &'a [u8] {
262 let start = self.offset;
263 if let Some(pos) = memchr::memchr(search, &self.bytes[self.offset..]) {
264 self.offset += pos + 1;
265
266 &self.bytes[start..self.offset]
267 } else {
268 self.offset = self.length;
269
270 &self.bytes[start..self.length]
271 }
272 }
273
274 /// Consumes whitespaces until a non-whitespace character is found.
275 ///
276 /// # Returns
277 ///
278 /// A byte slice containing the consumed whitespaces.
279 #[inline]
280 pub fn consume_whitespaces(&mut self) -> &'a [u8] {
281 let start = self.offset;
282 let bytes = self.bytes;
283 let len = self.length;
284 while self.offset < len && bytes[self.offset].is_ascii_whitespace() {
285 self.offset += 1;
286 }
287
288 &bytes[start..self.offset]
289 }
290
291 /// Reads the next `n` characters without advancing the position.
292 ///
293 /// # Arguments
294 ///
295 /// * `n` - The number of characters to read.
296 ///
297 /// # Returns
298 ///
299 /// A byte slice containing the next `n` characters.
300 #[inline]
301 pub fn read(&self, n: usize) -> &'a [u8] {
302 let (from, until) = self.calculate_bound(n);
303
304 &self.bytes[from..until]
305 }
306
307 /// Reads a single byte at a specific byte offset within the input slice,
308 /// without advancing the internal cursor.
309 ///
310 /// This provides direct, low-level access to the underlying byte data.
311 ///
312 /// # Arguments
313 ///
314 /// * `at` - The zero-based byte offset within the input slice (`self.bytes`)
315 /// from which to read the byte.
316 ///
317 /// # Returns
318 ///
319 /// A reference to the byte located at the specified offset `at`.
320 ///
321 /// # Panics
322 ///
323 /// This method **panics** if the provided `at` offset is out of bounds
324 /// for the input byte slice (i.e., if `at >= self.bytes.len()`).
325 pub fn read_at(&self, at: usize) -> &'a u8 {
326 &self.bytes[at]
327 }
328
329 /// Checks if the input at the current position matches the given byte slice.
330 ///
331 /// # Arguments
332 ///
333 /// * `search` - The byte slice to compare against the input.
334 /// * `ignore_ascii_case` - Whether to ignore ASCII case when comparing.
335 ///
336 /// # Returns
337 ///
338 /// `true` if the next bytes match `search`; `false` otherwise.
339 #[inline]
340 pub fn is_at(&self, search: &[u8], ignore_ascii_case: bool) -> bool {
341 let (from, until) = self.calculate_bound(search.len());
342 let slice = &self.bytes[from..until];
343
344 if ignore_ascii_case { slice.eq_ignore_ascii_case(search) } else { slice == search }
345 }
346
347 /// Attempts to match the given byte sequence at the current position, ignoring whitespace in the input.
348 ///
349 /// This method tries to match the provided byte slice `search` against the input starting
350 /// from the current position, possibly ignoring ASCII case. Whitespace characters in the input
351 /// are skipped during matching, but their length is included in the returned length.
352 ///
353 /// Importantly, the method **does not include** any trailing whitespace **after** the matched sequence
354 /// in the returned length.
355 ///
356 /// For example, to match the sequence `(string)`, the input could be `(string)`, `( string )`, `( string )`, etc.,
357 /// and this method would return the total length of the input consumed to match `(string)`,
358 /// including any whitespace within the matched sequence, but **excluding** any whitespace after it.
359 ///
360 /// # Arguments
361 ///
362 /// * `search` - The byte slice to match against the input.
363 /// * `ignore_ascii_case` - If `true`, ASCII case is ignored during comparison.
364 ///
365 /// # Returns
366 ///
367 /// * `Some(length)` - If the input matches `search` (ignoring whitespace within the sequence), returns the total length
368 /// of the input consumed to match `search`, including any skipped whitespace **within** the matched sequence.
369 /// * `None` - If the input does not match `search`.
370 ///
371 /// # Examples
372 ///
373 /// ```rust
374 /// use mago_syntax_core::input::Input;
375 /// use mago_source::SourceIdentifier;
376 ///
377 /// let source = SourceIdentifier::dummy();
378 ///
379 /// // Given input "( string ) x", starting at offset 0:
380 /// let input = Input::new(source.clone(), b"( string ) x");
381 /// assert_eq!(input.match_sequence_ignore_whitespace(b"(string)", true), Some(10)); // 10 bytes consumed up to ')'
382 ///
383 /// // Given input "(int)", with no whitespace:
384 /// let input = Input::new(source.clone(), b"(int)");
385 /// assert_eq!(input.match_sequence_ignore_whitespace(b"(int)", true), Some(5)); // 5 bytes consumed
386 ///
387 /// // Given input "( InT )abc", ignoring ASCII case:
388 /// let input = Input::new(source.clone(), b"( InT )abc");
389 /// assert_eq!(input.match_sequence_ignore_whitespace(b"(int)", true), Some(10)); // 10 bytes consumed up to ')'
390 ///
391 /// // Given input "(integer)", attempting to match "(int)":
392 /// let input = Input::new(source.clone(), b"(integer)");
393 /// assert_eq!(input.match_sequence_ignore_whitespace(b"(int)", false), None); // Does not match
394 ///
395 /// // Trailing whitespace after ')':
396 /// let input = Input::new(source.clone(), b"(int) x");
397 /// assert_eq!(input.match_sequence_ignore_whitespace(b"(int)", true), Some(5)); // Length up to ')', excludes spaces after ')'
398 /// ```
399 #[inline]
400 pub const fn match_sequence_ignore_whitespace(&self, search: &[u8], ignore_ascii_case: bool) -> Option<usize> {
401 let mut offset = self.offset;
402 let mut search_offset = 0;
403 let mut length = 0;
404 let bytes = self.bytes;
405 let total = self.length;
406 while search_offset < search.len() {
407 // Skip whitespace in the input.
408 while offset < total && bytes[offset].is_ascii_whitespace() {
409 offset += 1;
410 length += 1;
411 }
412
413 if offset >= total {
414 return None;
415 }
416
417 let input_byte = bytes[offset];
418 let search_byte = search[search_offset];
419 let matched = if ignore_ascii_case {
420 input_byte.eq_ignore_ascii_case(&search_byte)
421 } else {
422 input_byte == search_byte
423 };
424
425 if matched {
426 offset += 1;
427 length += 1;
428 search_offset += 1;
429 } else {
430 return None;
431 }
432 }
433
434 Some(length)
435 }
436
437 /// Peeks ahead `i` characters and reads the next `n` characters without advancing the position.
438 ///
439 /// # Arguments
440 ///
441 /// * `offset` - The number of characters to skip before reading.
442 /// * `n` - The number of characters to read after skipping.
443 ///
444 /// # Returns
445 ///
446 /// A byte slice containing the peeked characters.
447 #[inline]
448 pub fn peek(&self, offset: usize, n: usize) -> &'a [u8] {
449 let from = self.offset + offset;
450 if from >= self.length {
451 return &self.bytes[self.length..self.length];
452 }
453
454 let mut until = from + n;
455 if until >= self.length {
456 until = self.length;
457 }
458
459 &self.bytes[from..until]
460 }
461
462 /// Calculates the bounds for slicing the input safely.
463 ///
464 /// Ensures that slicing does not go beyond the input length.
465 ///
466 /// # Arguments
467 ///
468 /// * `n` - The number of characters to include in the slice.
469 ///
470 /// # Returns
471 ///
472 /// A tuple `(from, until)` representing the start and end indices for slicing.
473 #[inline]
474 const fn calculate_bound(&self, n: usize) -> (usize, usize) {
475 if self.has_reached_eof() {
476 return (self.length, self.length);
477 }
478
479 let mut until = self.offset + n;
480
481 if until >= self.length {
482 until = self.length;
483 }
484
485 (self.offset, until)
486 }
487}
488
#[cfg(test)]
mod tests {
    use mago_span::Position;

    use super::*;

    // A fresh input starts at offset 0 and exposes the given bytes unchanged.
    #[test]
    fn test_new() {
        let bytes = b"Hello, world!";
        let input = Input::new(SourceIdentifier::dummy(), bytes);

        assert_eq!(input.current_position(), Position::new(SourceIdentifier::dummy(), 0));
        assert_eq!(input.length, bytes.len());
        assert_eq!(input.bytes, bytes);
    }

    // EOF is reported for empty input and after skipping every byte.
    #[test]
    fn test_is_eof() {
        let bytes = b"";
        let input = Input::new(SourceIdentifier::dummy(), bytes);

        assert!(input.has_reached_eof());

        let bytes = b"data";
        let mut input = Input::new(SourceIdentifier::dummy(), bytes);

        assert!(!input.has_reached_eof());

        input.skip(4);

        assert!(input.has_reached_eof());
    }

    // `next` advances exactly one byte per call, regardless of line endings.
    #[test]
    fn test_next() {
        let bytes = b"a\nb\r\nc\rd";
        let mut input = Input::new(SourceIdentifier::dummy(), bytes);

        // consumes 'a'
        input.next();
        assert_eq!(input.current_position(), Position::new(SourceIdentifier::dummy(), 1));

        // consumes '\n'
        input.next();
        assert_eq!(input.current_position(), Position::new(SourceIdentifier::dummy(), 2));

        // consumes 'b'
        input.next();
        assert_eq!(input.current_position(), Position::new(SourceIdentifier::dummy(), 3));

        // consumes '\r' (the first byte of "\r\n"; each byte counts separately)
        input.next();
        assert_eq!(input.current_position(), Position::new(SourceIdentifier::dummy(), 4));

        // consumes '\n' (the second byte of "\r\n")
        input.next();
        assert_eq!(input.current_position(), Position::new(SourceIdentifier::dummy(), 5));

        // consumes 'c'
        input.next();
        assert_eq!(input.current_position(), Position::new(SourceIdentifier::dummy(), 6));

        // consumes the lone '\r'; the trailing 'd' is left unread
        input.next();
        assert_eq!(input.current_position(), Position::new(SourceIdentifier::dummy(), 7));
    }

    // `consume` returns the requested bytes and advances; at EOF it returns an empty slice.
    #[test]
    fn test_consume() {
        let bytes = b"abcdef";
        let mut input = Input::new(SourceIdentifier::dummy(), bytes);

        let consumed = input.consume(3);
        assert_eq!(consumed, b"abc");
        assert_eq!(input.current_position(), Position::new(SourceIdentifier::dummy(), 3));

        let consumed = input.consume(3);
        assert_eq!(consumed, b"def");
        assert_eq!(input.current_position(), Position::new(SourceIdentifier::dummy(), 6));

        let consumed = input.consume(1); // Should return empty slice at EOF
        assert_eq!(consumed, b"");
        assert!(input.has_reached_eof());
    }

    // `consume_remaining` returns everything after the cursor and leaves it at EOF.
    #[test]
    fn test_consume_remaining() {
        let bytes = b"abcdef";
        let mut input = Input::new(SourceIdentifier::dummy(), bytes);

        input.skip(2);
        let remaining = input.consume_remaining();
        assert_eq!(remaining, b"cdef");
        assert!(input.has_reached_eof());
    }

    // `read` returns bytes without moving the cursor.
    #[test]
    fn test_read() {
        let bytes = b"abcdef";
        let input = Input::new(SourceIdentifier::dummy(), bytes);

        let read = input.read(3);
        assert_eq!(read, b"abc");
        assert_eq!(input.current_position(), Position::new(SourceIdentifier::dummy(), 0));
        // Position should not change
    }

    // `is_at` compares case-sensitively at the cursor.
    #[test]
    fn test_is_at() {
        let bytes = b"abcdef";
        let mut input = Input::new(SourceIdentifier::dummy(), bytes);

        assert!(input.is_at(b"abc", false));
        input.skip(2);
        assert!(input.is_at(b"cde", false));
        assert!(!input.is_at(b"xyz", false));
    }

    // `is_at` with `ignore_ascii_case = true` matches regardless of ASCII case.
    #[test]
    fn test_is_at_ignore_ascii_case() {
        let bytes = b"AbCdEf";
        let mut input = Input::new(SourceIdentifier::dummy(), bytes);

        assert!(input.is_at(b"abc", true));
        input.skip(2);
        assert!(input.is_at(b"cde", true));
        assert!(!input.is_at(b"xyz", true));
    }

    // `peek` reads ahead of the cursor without moving it.
    #[test]
    fn test_peek() {
        let bytes = b"abcdef";
        let input = Input::new(SourceIdentifier::dummy(), bytes);

        let peeked = input.peek(2, 3);
        assert_eq!(peeked, b"cde");
        assert_eq!(input.current_position(), Position::new(SourceIdentifier::dummy(), 0));
        // Position should not change
    }

    // `calculate_bound` clamps the upper bound to the input length.
    #[test]
    fn test_to_bound() {
        let bytes = b"abcdef";
        let input = Input::new(SourceIdentifier::dummy(), bytes);

        let (from, until) = input.calculate_bound(3);
        assert_eq!((from, until), (0, 3));

        let (from, until) = input.calculate_bound(10); // Exceeds length
        assert_eq!((from, until), (0, 6));
    }
}