mago_syntax_core/input.rs
1use memchr::memchr;
2use memchr::memmem::find;
3
4use mago_source::SourceIdentifier;
5use mago_span::Position;
6
7/// A struct representing the input code being lexed.
8///
9/// The `Input` struct provides methods to read, peek, consume, and skip characters
10/// from the bytes input code while keeping track of the current position (line, column, offset).
11#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)]
12pub struct Input<'a> {
13 pub(crate) bytes: &'a [u8],
14 pub(crate) length: usize,
15 pub(crate) offset: usize,
16 pub(crate) starting_position: Position,
17}
18
19impl<'a> Input<'a> {
20 /// Creates a new `Input` instance from the given input.
21 ///
22 /// # Arguments
23 ///
24 /// * `input` - A byte slice representing the input code to be processed.
25 ///
26 /// # Returns
27 ///
28 /// A new `Input` instance initialized at the beginning of the input.
29 pub fn new(source: SourceIdentifier, bytes: &'a [u8]) -> Self {
30 let length = bytes.len();
31
32 Self { bytes, length, offset: 0, starting_position: Position::start_of(source) }
33 }
34
35 /// Creates a new `Input` instance representing a byte slice that is
36 /// "anchored" at a specific absolute position within a larger source file.
37 ///
38 /// This is useful when lexing a subset (slice) of a source file, as it allows
39 /// generated tokens to retain accurate absolute positions and spans relative
40 /// to the original file.
41 ///
42 /// The internal cursor (`offset`) starts at 0 relative to the `bytes` slice,
43 /// but the absolute position is calculated relative to the `anchor_position`.
44 ///
45 /// # Arguments
46 ///
47 /// * `bytes` - A byte slice representing the input code subset to be lexed.
48 /// * `anchor_position` - The absolute `Position` in the original source file where
49 /// the provided `bytes` slice begins.
50 ///
51 /// # Returns
52 ///
53 /// A new `Input` instance ready to lex the `bytes`, maintaining positions
54 /// relative to `anchor_position`.
55 pub fn anchored_at(bytes: &'a [u8], anchor_position: Position) -> Self {
56 let length = bytes.len();
57
58 Self { bytes, length, offset: 0, starting_position: anchor_position }
59 }
60
61 /// Returns the source identifier of the input code.
62 #[inline]
63 pub const fn source_identifier(&self) -> SourceIdentifier {
64 self.starting_position.source
65 }
66
67 /// Returns the absolute current `Position` of the lexer within the original source file.
68 ///
69 /// It calculates this by adding the internal offset (progress within the current byte slice)
70 /// to the `starting_position` the `Input` was initialized with.
71 #[inline]
72 pub const fn current_position(&self) -> Position {
73 // Calculate absolute position by adding internal offset to the starting base
74 self.starting_position.forward(self.offset)
75 }
76
77 /// Returns the current internal byte offset relative to the start of the input slice.
78 ///
79 /// This indicates how many bytes have been consumed from the current `bytes` slice.
80 /// To get the absolute position in the original source file, use `current_position()`.
81 #[inline]
82 pub const fn current_offset(&self) -> usize {
83 self.offset
84 }
85
86 /// Returns `true` if the input slice is empty (length is zero).
87 #[inline]
88 pub const fn is_empty(&self) -> bool {
89 self.length == 0
90 }
91
92 /// Returns the total length in bytes of the input slice being processed.
93 #[inline]
94 pub const fn len(&self) -> usize {
95 self.length
96 }
97
98 /// Checks if the current position is at the end of the input.
99 ///
100 /// # Returns
101 ///
102 /// `true` if the current offset is greater than or equal to the input length; `false` otherwise.
103 #[inline]
104 pub const fn has_reached_eof(&self) -> bool {
105 self.offset >= self.length
106 }
107
108 /// Advances the current position by one character, updating line and column numbers.
109 ///
110 /// Handles different line endings (`\n`, `\r`, `\r\n`) and updates line and column counters accordingly.
111 ///
112 /// If the end of input is reached, no action is taken.
113 #[inline]
114 pub fn next(&mut self) {
115 if !self.has_reached_eof() {
116 self.offset += 1;
117 }
118 }
119
120 /// Skips the next `count` characters, advancing the position accordingly.
121 ///
122 /// Updates line and column numbers as it advances.
123 ///
124 /// # Arguments
125 ///
126 /// * `count` - The number of characters to skip.
127 #[inline]
128 pub fn skip(&mut self, count: usize) {
129 self.offset = (self.offset + count).min(self.length);
130 }
131
132 /// Consumes the next `count` characters and returns them as a slice.
133 ///
134 /// Advances the position by `count` characters.
135 ///
136 /// # Arguments
137 ///
138 /// * `count` - The number of characters to consume.
139 ///
140 /// # Returns
141 ///
142 /// A byte slice containing the consumed characters.
143 #[inline]
144 pub fn consume(&mut self, count: usize) -> &'a [u8] {
145 let (from, until) = self.calculate_bound(count);
146
147 self.skip(count);
148
149 &self.bytes[from..until]
150 }
151
152 /// Consumes all remaining characters from the current position to the end of input.
153 ///
154 /// Advances the position to EOF.
155 ///
156 /// # Returns
157 ///
158 /// A byte slice containing the remaining characters.
159 #[inline]
160 pub fn consume_remaining(&mut self) -> &'a [u8] {
161 if self.has_reached_eof() {
162 return &[];
163 }
164
165 let from = self.offset;
166 self.offset = self.length;
167
168 &self.bytes[from..]
169 }
170
171 /// Consumes characters until the given byte slice is found.
172 ///
173 /// Advances the position to the start of the search slice if found,
174 /// or to EOF if not found.
175 ///
176 /// # Arguments
177 ///
178 /// * `search` - The byte slice to search for.
179 /// * `ignore_ascii_case` - Whether to ignore ASCII case when comparing characters.
180 ///
181 /// # Returns
182 ///
183 /// A byte slice containing the consumed characters.
184 #[inline]
185 pub fn consume_until(&mut self, search: &[u8], ignore_ascii_case: bool) -> &'a [u8] {
186 let start = self.offset;
187 if !ignore_ascii_case {
188 // For a single-byte search, use memchr.
189 if search.len() == 1 {
190 if let Some(pos) = memchr(search[0], &self.bytes[self.offset..]) {
191 self.offset += pos;
192 &self.bytes[start..self.offset]
193 } else {
194 self.offset = self.length;
195 &self.bytes[start..self.length]
196 }
197 } else if let Some(pos) = find(&self.bytes[self.offset..], search) {
198 self.offset += pos;
199 &self.bytes[start..self.offset]
200 } else {
201 self.offset = self.length;
202 &self.bytes[start..self.length]
203 }
204 } else {
205 while !self.has_reached_eof() && !self.is_at(search, ignore_ascii_case) {
206 self.offset += 1;
207 }
208
209 &self.bytes[start..self.offset]
210 }
211 }
212
213 #[inline]
214 pub fn consume_through(&mut self, search: u8) -> &'a [u8] {
215 let start = self.offset;
216 if let Some(pos) = memchr::memchr(search, &self.bytes[self.offset..]) {
217 self.offset += pos + 1;
218
219 &self.bytes[start..self.offset]
220 } else {
221 self.offset = self.length;
222
223 &self.bytes[start..self.length]
224 }
225 }
226
227 /// Consumes whitespaces until a non-whitespace character is found.
228 ///
229 /// # Returns
230 ///
231 /// A byte slice containing the consumed whitespaces.
232 #[inline]
233 pub fn consume_whitespaces(&mut self) -> &'a [u8] {
234 let start = self.offset;
235 let bytes = self.bytes;
236 let len = self.length;
237 while self.offset < len && bytes[self.offset].is_ascii_whitespace() {
238 self.offset += 1;
239 }
240
241 &bytes[start..self.offset]
242 }
243
244 /// Reads the next `n` characters without advancing the position.
245 ///
246 /// # Arguments
247 ///
248 /// * `n` - The number of characters to read.
249 ///
250 /// # Returns
251 ///
252 /// A byte slice containing the next `n` characters.
253 #[inline]
254 pub fn read(&self, n: usize) -> &'a [u8] {
255 let (from, until) = self.calculate_bound(n);
256
257 &self.bytes[from..until]
258 }
259
260 /// Reads a single byte at a specific byte offset within the input slice,
261 /// without advancing the internal cursor.
262 ///
263 /// This provides direct, low-level access to the underlying byte data.
264 ///
265 /// # Arguments
266 ///
267 /// * `at` - The zero-based byte offset within the input slice (`self.bytes`)
268 /// from which to read the byte.
269 ///
270 /// # Returns
271 ///
272 /// A reference to the byte located at the specified offset `at`.
273 ///
274 /// # Panics
275 ///
276 /// This method **panics** if the provided `at` offset is out of bounds
277 /// for the input byte slice (i.e., if `at >= self.bytes.len()`).
278 pub fn read_at(&self, at: usize) -> &'a u8 {
279 &self.bytes[at]
280 }
281
282 /// Checks if the input at the current position matches the given byte slice.
283 ///
284 /// # Arguments
285 ///
286 /// * `search` - The byte slice to compare against the input.
287 /// * `ignore_ascii_case` - Whether to ignore ASCII case when comparing.
288 ///
289 /// # Returns
290 ///
291 /// `true` if the next bytes match `search`; `false` otherwise.
292 #[inline]
293 pub fn is_at(&self, search: &[u8], ignore_ascii_case: bool) -> bool {
294 let (from, until) = self.calculate_bound(search.len());
295 let slice = &self.bytes[from..until];
296
297 if ignore_ascii_case { slice.eq_ignore_ascii_case(search) } else { slice == search }
298 }
299
300 /// Attempts to match the given byte sequence at the current position, ignoring whitespace in the input.
301 ///
302 /// This method tries to match the provided byte slice `search` against the input starting
303 /// from the current position, possibly ignoring ASCII case. Whitespace characters in the input
304 /// are skipped during matching, but their length is included in the returned length.
305 ///
306 /// Importantly, the method **does not include** any trailing whitespace **after** the matched sequence
307 /// in the returned length.
308 ///
309 /// For example, to match the sequence `(string)`, the input could be `(string)`, `( string )`, `( string )`, etc.,
310 /// and this method would return the total length of the input consumed to match `(string)`,
311 /// including any whitespace within the matched sequence, but **excluding** any whitespace after it.
312 ///
313 /// # Arguments
314 ///
315 /// * `search` - The byte slice to match against the input.
316 /// * `ignore_ascii_case` - If `true`, ASCII case is ignored during comparison.
317 ///
318 /// # Returns
319 ///
320 /// * `Some(length)` - If the input matches `search` (ignoring whitespace within the sequence), returns the total length
321 /// of the input consumed to match `search`, including any skipped whitespace **within** the matched sequence.
322 /// * `None` - If the input does not match `search`.
323 ///
324 /// # Examples
325 ///
326 /// ```rust
327 /// use mago_syntax_core::input::Input;
328 /// use mago_source::SourceIdentifier;
329 ///
330 /// let source = SourceIdentifier::dummy();
331 ///
332 /// // Given input "( string ) x", starting at offset 0:
333 /// let input = Input::new(source.clone(), b"( string ) x");
334 /// assert_eq!(input.match_sequence_ignore_whitespace(b"(string)", true), Some(10)); // 10 bytes consumed up to ')'
335 ///
336 /// // Given input "(int)", with no whitespace:
337 /// let input = Input::new(source.clone(), b"(int)");
338 /// assert_eq!(input.match_sequence_ignore_whitespace(b"(int)", true), Some(5)); // 5 bytes consumed
339 ///
340 /// // Given input "( InT )abc", ignoring ASCII case:
341 /// let input = Input::new(source.clone(), b"( InT )abc");
342 /// assert_eq!(input.match_sequence_ignore_whitespace(b"(int)", true), Some(10)); // 10 bytes consumed up to ')'
343 ///
344 /// // Given input "(integer)", attempting to match "(int)":
345 /// let input = Input::new(source.clone(), b"(integer)");
346 /// assert_eq!(input.match_sequence_ignore_whitespace(b"(int)", false), None); // Does not match
347 ///
348 /// // Trailing whitespace after ')':
349 /// let input = Input::new(source.clone(), b"(int) x");
350 /// assert_eq!(input.match_sequence_ignore_whitespace(b"(int)", true), Some(5)); // Length up to ')', excludes spaces after ')'
351 /// ```
352 #[inline]
353 pub const fn match_sequence_ignore_whitespace(&self, search: &[u8], ignore_ascii_case: bool) -> Option<usize> {
354 let mut offset = self.offset;
355 let mut search_offset = 0;
356 let mut length = 0;
357 let bytes = self.bytes;
358 let total = self.length;
359 while search_offset < search.len() {
360 // Skip whitespace in the input.
361 while offset < total && bytes[offset].is_ascii_whitespace() {
362 offset += 1;
363 length += 1;
364 }
365
366 if offset >= total {
367 return None;
368 }
369
370 let input_byte = bytes[offset];
371 let search_byte = search[search_offset];
372 let matched = if ignore_ascii_case {
373 input_byte.eq_ignore_ascii_case(&search_byte)
374 } else {
375 input_byte == search_byte
376 };
377
378 if matched {
379 offset += 1;
380 length += 1;
381 search_offset += 1;
382 } else {
383 return None;
384 }
385 }
386
387 Some(length)
388 }
389
390 /// Peeks ahead `i` characters and reads the next `n` characters without advancing the position.
391 ///
392 /// # Arguments
393 ///
394 /// * `offset` - The number of characters to skip before reading.
395 /// * `n` - The number of characters to read after skipping.
396 ///
397 /// # Returns
398 ///
399 /// A byte slice containing the peeked characters.
400 #[inline]
401 pub fn peek(&self, offset: usize, n: usize) -> &'a [u8] {
402 let from = self.offset + offset;
403 if from >= self.length {
404 return &self.bytes[self.length..self.length];
405 }
406
407 let mut until = from + n;
408 if until >= self.length {
409 until = self.length;
410 }
411
412 &self.bytes[from..until]
413 }
414
415 /// Calculates the bounds for slicing the input safely.
416 ///
417 /// Ensures that slicing does not go beyond the input length.
418 ///
419 /// # Arguments
420 ///
421 /// * `n` - The number of characters to include in the slice.
422 ///
423 /// # Returns
424 ///
425 /// A tuple `(from, until)` representing the start and end indices for slicing.
426 #[inline]
427 const fn calculate_bound(&self, n: usize) -> (usize, usize) {
428 if self.has_reached_eof() {
429 return (self.length, self.length);
430 }
431
432 let mut until = self.offset + n;
433
434 if until >= self.length {
435 until = self.length;
436 }
437
438 (self.offset, until)
439 }
440}
441
#[cfg(test)]
mod tests {
    use mago_span::Position;

    use super::*;

    /// Builds an `Input` over `bytes` using the dummy source identifier.
    fn input_of(bytes: &[u8]) -> Input<'_> {
        Input::new(SourceIdentifier::dummy(), bytes)
    }

    /// Builds the position expected after consuming `offset` bytes of a dummy source.
    fn position_at(offset: usize) -> Position {
        Position::new(SourceIdentifier::dummy(), offset)
    }

    #[test]
    fn test_new() {
        let bytes = b"Hello, world!";
        let input = input_of(bytes);

        assert_eq!(input.current_position(), position_at(0));
        assert_eq!(input.length, bytes.len());
        assert_eq!(input.bytes, bytes);
    }

    #[test]
    fn test_is_eof() {
        let input = input_of(b"");
        assert!(input.has_reached_eof());

        let mut input = input_of(b"data");
        assert!(!input.has_reached_eof());

        input.skip(4);
        assert!(input.has_reached_eof());
    }

    #[test]
    fn test_next() {
        // `next` advances exactly one byte per call; line endings are not special,
        // so `\r\n` takes two calls to step over.
        let mut input = input_of(b"a\nb\r\nc\rd");

        for expected_offset in 1..=7 {
            input.next();
            assert_eq!(input.current_position(), position_at(expected_offset));
        }
    }

    #[test]
    fn test_next_stops_at_eof() {
        let mut input = input_of(b"a");

        input.next();
        assert!(input.has_reached_eof());

        // A further call must leave the cursor where it is.
        input.next();
        assert_eq!(input.current_offset(), 1);
    }

    #[test]
    fn test_skip_clamps_to_length() {
        let mut input = input_of(b"abc");

        input.skip(2);
        assert_eq!(input.current_offset(), 2);

        input.skip(10);
        assert_eq!(input.current_offset(), 3);
        assert!(input.has_reached_eof());
    }

    #[test]
    fn test_consume() {
        let mut input = input_of(b"abcdef");

        assert_eq!(input.consume(3), b"abc");
        assert_eq!(input.current_position(), position_at(3));

        assert_eq!(input.consume(3), b"def");
        assert_eq!(input.current_position(), position_at(6));

        // At EOF an empty slice is returned.
        assert_eq!(input.consume(1), b"");
        assert!(input.has_reached_eof());
    }

    #[test]
    fn test_consume_remaining() {
        let mut input = input_of(b"abcdef");

        input.skip(2);
        assert_eq!(input.consume_remaining(), b"cdef");
        assert!(input.has_reached_eof());

        assert_eq!(input.consume_remaining(), b"");
    }

    #[test]
    fn test_consume_until() {
        let mut input = input_of(b"hello world");

        // Single-byte needle.
        assert_eq!(input.consume_until(b" ", false), b"hello");
        assert_eq!(input.current_offset(), 5);

        // Multi-byte needle; the needle itself is not consumed.
        assert_eq!(input.consume_until(b"ld", false), b" wor");
        assert_eq!(input.current_offset(), 9);

        // Needle not present: everything up to EOF is consumed.
        assert_eq!(input.consume_until(b"zzz", false), b"ld");
        assert!(input.has_reached_eof());
    }

    #[test]
    fn test_consume_until_ignore_ascii_case() {
        let mut input = input_of(b"abc END def");

        assert_eq!(input.consume_until(b"end", true), b"abc ");
        assert_eq!(input.current_offset(), 4);

        let mut input = input_of(b"abc");
        assert_eq!(input.consume_until(b"x", true), b"abc");
        assert!(input.has_reached_eof());
    }

    #[test]
    fn test_consume_through() {
        let mut input = input_of(b"line1\nline2");

        assert_eq!(input.consume_through(b'\n'), b"line1\n");
        assert_eq!(input.current_offset(), 6);

        assert_eq!(input.consume_through(b'\n'), b"line2");
        assert!(input.has_reached_eof());
    }

    #[test]
    fn test_consume_whitespaces() {
        let mut input = input_of(b" \t\r\nabc");

        assert_eq!(input.consume_whitespaces(), b" \t\r\n");
        assert_eq!(input.current_offset(), 4);
        assert!(input.is_at(b"abc", false));

        // Not at whitespace: nothing is consumed.
        assert_eq!(input.consume_whitespaces(), b"");
        assert_eq!(input.current_offset(), 4);
    }

    #[test]
    fn test_read() {
        let input = input_of(b"abcdef");

        assert_eq!(input.read(3), b"abc");
        // Reading must not move the cursor.
        assert_eq!(input.current_position(), position_at(0));
    }

    #[test]
    fn test_read_at() {
        let input = input_of(b"abcdef");

        assert_eq!(*input.read_at(2), b'c');
        assert_eq!(input.current_offset(), 0);
    }

    #[test]
    fn test_is_at() {
        let mut input = input_of(b"abcdef");

        assert!(input.is_at(b"abc", false));
        input.skip(2);
        assert!(input.is_at(b"cde", false));
        assert!(!input.is_at(b"xyz", false));
    }

    #[test]
    fn test_is_at_ignore_ascii_case() {
        let mut input = input_of(b"AbCdEf");

        assert!(input.is_at(b"abc", true));
        input.skip(2);
        assert!(input.is_at(b"cde", true));
        assert!(!input.is_at(b"xyz", true));
    }

    #[test]
    fn test_match_sequence_ignore_whitespace() {
        let input = input_of(b"( string ) x");
        assert_eq!(input.match_sequence_ignore_whitespace(b"(string)", true), Some(10));

        let input = input_of(b"(int) x");
        assert_eq!(input.match_sequence_ignore_whitespace(b"(int)", true), Some(5));

        let input = input_of(b"(INT)");
        assert_eq!(input.match_sequence_ignore_whitespace(b"(int)", true), Some(5));
        assert_eq!(input.match_sequence_ignore_whitespace(b"(int)", false), None);

        let input = input_of(b"(integer)");
        assert_eq!(input.match_sequence_ignore_whitespace(b"(int)", false), None);

        // Input ends before the sequence is complete.
        let input = input_of(b"( int");
        assert_eq!(input.match_sequence_ignore_whitespace(b"(int)", true), None);
    }

    #[test]
    fn test_peek() {
        let input = input_of(b"abcdef");

        assert_eq!(input.peek(2, 3), b"cde");
        // Peeking must not move the cursor.
        assert_eq!(input.current_position(), position_at(0));
    }

    #[test]
    fn test_peek_is_clamped_to_length() {
        let input = input_of(b"abcdef");

        assert_eq!(input.peek(4, 10), b"ef");
        assert_eq!(input.peek(6, 1), b"");
    }

    #[test]
    fn test_to_bound() {
        let input = input_of(b"abcdef");

        assert_eq!(input.calculate_bound(3), (0, 3));
        // A request that exceeds the length is clamped.
        assert_eq!(input.calculate_bound(10), (0, 6));
    }
}