mago_syntax_core/input.rs
1use mago_database::file::HasFileId;
2use memchr::memchr;
3use memchr::memmem::find;
4
5use mago_database::file::File;
6use mago_database::file::FileId;
7use mago_span::Position;
8
9/// A struct representing the input code being lexed.
10///
11/// The `Input` struct provides methods to read, peek, consume, and skip characters
12/// from the bytes input code while keeping track of the current position (line, column, offset).
13#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)]
14pub struct Input<'a> {
15 pub(crate) bytes: &'a [u8],
16 pub(crate) length: usize,
17 pub(crate) offset: usize,
18 pub(crate) starting_position: Position,
19 pub(crate) file_id: FileId,
20}
21
22impl<'a> Input<'a> {
23 /// Creates a new `Input` instance from the given input.
24 ///
25 /// # Arguments
26 ///
27 /// * `file_id` - The unique identifier for the source file this input belongs to.
28 /// * `bytes` - A byte slice representing the input code to be lexed.
29 ///
30 /// # Returns
31 ///
32 /// A new `Input` instance initialized at the beginning of the input.
33 pub fn new(file_id: FileId, bytes: &'a [u8]) -> Self {
34 let length = bytes.len();
35
36 Self { bytes, length, offset: 0, file_id, starting_position: Position::new(0) }
37 }
38
39 /// Creates a new `Input` instance from the contents of a `File`.
40 ///
41 /// # Arguments
42 ///
43 /// * `file` - A reference to the `File` containing the source code.
44 ///
45 /// # Returns
46 ///
47 /// A new `Input` instance initialized with the file's ID and contents.
48 pub fn from_file(file: &'a File) -> Self {
49 Self::new(file.id, file.contents.as_bytes())
50 }
51
52 /// Creates a new `Input` instance representing a byte slice that is
53 /// "anchored" at a specific absolute position within a larger source file.
54 ///
55 /// This is useful when lexing a subset (slice) of a source file, as it allows
56 /// generated tokens to retain accurate absolute positions and spans relative
57 /// to the original file.
58 ///
59 /// The internal cursor (`offset`) starts at 0 relative to the `bytes` slice,
60 /// but the absolute position is calculated relative to the `anchor_position`.
61 ///
62 /// # Arguments
63 ///
64 /// * `file_id` - The unique identifier for the source file this input belongs to.
65 /// * `bytes` - A byte slice representing the input code subset to be lexed.
66 /// * `anchor_position` - The absolute `Position` in the original source file where the provided `bytes` slice begins.
67 ///
68 /// # Returns
69 ///
70 /// A new `Input` instance ready to lex the `bytes`, maintaining positions
71 /// relative to `anchor_position`.
72 pub fn anchored_at(file_id: FileId, bytes: &'a [u8], anchor_position: Position) -> Self {
73 let length = bytes.len();
74
75 Self { bytes, length, offset: 0, file_id, starting_position: anchor_position }
76 }
77
78 /// Returns the source file identifier of the input code.
79 #[inline]
80 pub const fn file_id(&self) -> FileId {
81 self.file_id
82 }
83
84 /// Returns the absolute current `Position` of the lexer within the original source file.
85 ///
86 /// It calculates this by adding the internal offset (progress within the current byte slice)
87 /// to the `starting_position` the `Input` was initialized with.
88 #[inline]
89 pub const fn current_position(&self) -> Position {
90 // Calculate absolute position by adding internal offset to the starting base
91 self.starting_position.forward(self.offset as u32)
92 }
93
94 /// Returns the current internal byte offset relative to the start of the input slice.
95 ///
96 /// This indicates how many bytes have been consumed from the current `bytes` slice.
97 /// To get the absolute position in the original source file, use `current_position()`.
98 #[inline]
99 pub const fn current_offset(&self) -> usize {
100 self.offset
101 }
102
103 /// Returns `true` if the input slice is empty (length is zero).
104 #[inline]
105 pub const fn is_empty(&self) -> bool {
106 self.length == 0
107 }
108
109 /// Returns the total length in bytes of the input slice being processed.
110 #[inline]
111 pub const fn len(&self) -> usize {
112 self.length
113 }
114
115 /// Checks if the current position is at the end of the input.
116 ///
117 /// # Returns
118 ///
119 /// `true` if the current offset is greater than or equal to the input length; `false` otherwise.
120 #[inline]
121 pub const fn has_reached_eof(&self) -> bool {
122 self.offset >= self.length
123 }
124
125 /// Returns a byte slice within a specified absolute range.
126 ///
127 /// The `from` and `to` arguments are absolute byte offsets from the beginning
128 /// of the original source file. The method calculates the correct slice
129 /// relative to the `starting_position` of this `Input`.
130 ///
131 /// This is useful for retrieving the raw text of a `Span` or `Token` whose
132 /// positions are absolute, even when the `Input` only contains a subsection
133 /// of the source file.
134 ///
135 /// The returned slice is defensively clamped to the bounds of the current
136 /// `Input`'s byte slice to prevent panics.
137 ///
138 /// # Arguments
139 ///
140 /// * `from` - The absolute starting byte offset.
141 /// * `to` - The absolute ending byte offset (exclusive).
142 ///
143 /// # Returns
144 ///
145 /// A byte slice `&[u8]` corresponding to the requested range.
146 #[inline]
147 pub fn slice_in_range(&self, from: u32, to: u32) -> &'a [u8] {
148 let base_offset = self.starting_position.offset;
149
150 // Calculate the start and end positions relative to the local `bytes` slice.
151 // `saturating_sub` prevents underflow if `from`/`to` are smaller than `base_offset`.
152 let local_from = from.saturating_sub(base_offset) as usize;
153 let local_to = to.saturating_sub(base_offset) as usize;
154
155 // Clamp the local indices to the actual length of the `bytes` slice to prevent panics.
156 let start = local_from.min(self.length);
157 let end = local_to.min(self.length);
158
159 // Ensure the start index is not greater than the end index.
160 if start >= end {
161 return &[];
162 }
163
164 // If the start index is beyond the length of the input, return an empty slice.
165 if start >= self.length {
166 return &[];
167 }
168
169 &self.bytes[start..end]
170 }
171
172 /// Advances the current position by one character, updating line and column numbers.
173 ///
174 /// Handles different line endings (`\n`, `\r`, `\r\n`) and updates line and column counters accordingly.
175 ///
176 /// If the end of input is reached, no action is taken.
177 #[inline]
178 pub fn next(&mut self) {
179 if !self.has_reached_eof() {
180 self.offset += 1;
181 }
182 }
183
184 /// Skips the next `count` characters, advancing the position accordingly.
185 ///
186 /// Updates line and column numbers as it advances.
187 ///
188 /// # Arguments
189 ///
190 /// * `count` - The number of characters to skip.
191 #[inline]
192 pub fn skip(&mut self, count: usize) {
193 self.offset = (self.offset + count).min(self.length);
194 }
195
196 /// Consumes the next `count` characters and returns them as a slice.
197 ///
198 /// Advances the position by `count` characters.
199 ///
200 /// # Arguments
201 ///
202 /// * `count` - The number of characters to consume.
203 ///
204 /// # Returns
205 ///
206 /// A byte slice containing the consumed characters.
207 #[inline]
208 pub fn consume(&mut self, count: usize) -> &'a [u8] {
209 let (from, until) = self.calculate_bound(count);
210
211 self.skip(count);
212
213 &self.bytes[from..until]
214 }
215
216 /// Consumes all remaining characters from the current position to the end of input.
217 ///
218 /// Advances the position to EOF.
219 ///
220 /// # Returns
221 ///
222 /// A byte slice containing the remaining characters.
223 #[inline]
224 pub fn consume_remaining(&mut self) -> &'a [u8] {
225 if self.has_reached_eof() {
226 return &[];
227 }
228
229 let from = self.offset;
230 self.offset = self.length;
231
232 &self.bytes[from..]
233 }
234
235 /// Consumes characters until the given byte slice is found.
236 ///
237 /// Advances the position to the start of the search slice if found,
238 /// or to EOF if not found.
239 ///
240 /// # Arguments
241 ///
242 /// * `search` - The byte slice to search for.
243 /// * `ignore_ascii_case` - Whether to ignore ASCII case when comparing characters.
244 ///
245 /// # Returns
246 ///
247 /// A byte slice containing the consumed characters.
248 #[inline]
249 pub fn consume_until(&mut self, search: &[u8], ignore_ascii_case: bool) -> &'a [u8] {
250 let start = self.offset;
251 if !ignore_ascii_case {
252 // For a single-byte search, use memchr.
253 if search.len() == 1 {
254 if let Some(pos) = memchr(search[0], &self.bytes[self.offset..]) {
255 self.offset += pos;
256 &self.bytes[start..self.offset]
257 } else {
258 self.offset = self.length;
259 &self.bytes[start..self.length]
260 }
261 } else if let Some(pos) = find(&self.bytes[self.offset..], search) {
262 self.offset += pos;
263 &self.bytes[start..self.offset]
264 } else {
265 self.offset = self.length;
266 &self.bytes[start..self.length]
267 }
268 } else {
269 while !self.has_reached_eof() && !self.is_at(search, ignore_ascii_case) {
270 self.offset += 1;
271 }
272
273 &self.bytes[start..self.offset]
274 }
275 }
276
277 #[inline]
278 pub fn consume_through(&mut self, search: u8) -> &'a [u8] {
279 let start = self.offset;
280 if let Some(pos) = memchr::memchr(search, &self.bytes[self.offset..]) {
281 self.offset += pos + 1;
282
283 &self.bytes[start..self.offset]
284 } else {
285 self.offset = self.length;
286
287 &self.bytes[start..self.length]
288 }
289 }
290
291 /// Consumes whitespaces until a non-whitespace character is found.
292 ///
293 /// # Returns
294 ///
295 /// A byte slice containing the consumed whitespaces.
296 #[inline]
297 pub fn consume_whitespaces(&mut self) -> &'a [u8] {
298 let start = self.offset;
299 let bytes = self.bytes;
300 let len = self.length;
301 while self.offset < len && bytes[self.offset].is_ascii_whitespace() {
302 self.offset += 1;
303 }
304
305 &bytes[start..self.offset]
306 }
307
308 /// Reads the next `n` characters without advancing the position.
309 ///
310 /// # Arguments
311 ///
312 /// * `n` - The number of characters to read.
313 ///
314 /// # Returns
315 ///
316 /// A byte slice containing the next `n` characters.
317 #[inline]
318 pub fn read(&self, n: usize) -> &'a [u8] {
319 let (from, until) = self.calculate_bound(n);
320
321 &self.bytes[from..until]
322 }
323
324 /// Reads a single byte at a specific byte offset within the input slice,
325 /// without advancing the internal cursor.
326 ///
327 /// This provides direct, low-level access to the underlying byte data.
328 ///
329 /// # Arguments
330 ///
331 /// * `at` - The zero-based byte offset within the input slice (`self.bytes`)
332 /// from which to read the byte.
333 ///
334 /// # Returns
335 ///
336 /// A reference to the byte located at the specified offset `at`.
337 ///
338 /// # Panics
339 ///
340 /// This method **panics** if the provided `at` offset is out of bounds
341 /// for the input byte slice (i.e., if `at >= self.bytes.len()`).
342 pub fn read_at(&self, at: usize) -> &'a u8 {
343 &self.bytes[at]
344 }
345
346 /// Checks if the input at the current position matches the given byte slice.
347 ///
348 /// # Arguments
349 ///
350 /// * `search` - The byte slice to compare against the input.
351 /// * `ignore_ascii_case` - Whether to ignore ASCII case when comparing.
352 ///
353 /// # Returns
354 ///
355 /// `true` if the next bytes match `search`; `false` otherwise.
356 #[inline]
357 pub fn is_at(&self, search: &[u8], ignore_ascii_case: bool) -> bool {
358 let (from, until) = self.calculate_bound(search.len());
359 let slice = &self.bytes[from..until];
360
361 if ignore_ascii_case { slice.eq_ignore_ascii_case(search) } else { slice == search }
362 }
363
364 /// Attempts to match the given byte sequence at the current position, ignoring whitespace in the input.
365 ///
366 /// This method tries to match the provided byte slice `search` against the input starting
367 /// from the current position, possibly ignoring ASCII case. Whitespace characters in the input
368 /// are skipped during matching, but their length is included in the returned length.
369 ///
370 /// Importantly, the method **does not include** any trailing whitespace **after** the matched sequence
371 /// in the returned length.
372 ///
373 /// For example, to match the sequence `(string)`, the input could be `(string)`, `( string )`, `( string )`, etc.,
374 /// and this method would return the total length of the input consumed to match `(string)`,
375 /// including any whitespace within the matched sequence, but **excluding** any whitespace after it.
376 ///
377 /// # Arguments
378 ///
379 /// * `search` - The byte slice to match against the input.
380 /// * `ignore_ascii_case` - If `true`, ASCII case is ignored during comparison.
381 ///
382 /// # Returns
383 ///
384 /// * `Some(length)` - If the input matches `search` (ignoring whitespace within the sequence), returns the total length
385 /// of the input consumed to match `search`, including any skipped whitespace **within** the matched sequence.
386 /// * `None` - If the input does not match `search`.
387 #[inline]
388 pub const fn match_sequence_ignore_whitespace(&self, search: &[u8], ignore_ascii_case: bool) -> Option<usize> {
389 let mut offset = self.offset;
390 let mut search_offset = 0;
391 let mut length = 0;
392 let bytes = self.bytes;
393 let total = self.length;
394 while search_offset < search.len() {
395 // Skip whitespace in the input.
396 while offset < total && bytes[offset].is_ascii_whitespace() {
397 offset += 1;
398 length += 1;
399 }
400
401 if offset >= total {
402 return None;
403 }
404
405 let input_byte = bytes[offset];
406 let search_byte = search[search_offset];
407 let matched = if ignore_ascii_case {
408 input_byte.eq_ignore_ascii_case(&search_byte)
409 } else {
410 input_byte == search_byte
411 };
412
413 if matched {
414 offset += 1;
415 length += 1;
416 search_offset += 1;
417 } else {
418 return None;
419 }
420 }
421
422 Some(length)
423 }
424
425 /// Peeks ahead `i` characters and reads the next `n` characters without advancing the position.
426 ///
427 /// # Arguments
428 ///
429 /// * `offset` - The number of characters to skip before reading.
430 /// * `n` - The number of characters to read after skipping.
431 ///
432 /// # Returns
433 ///
434 /// A byte slice containing the peeked characters.
435 #[inline]
436 pub fn peek(&self, offset: usize, n: usize) -> &'a [u8] {
437 let from = self.offset + offset;
438 if from >= self.length {
439 return &self.bytes[self.length..self.length];
440 }
441
442 let mut until = from + n;
443 if until >= self.length {
444 until = self.length;
445 }
446
447 &self.bytes[from..until]
448 }
449
450 /// Calculates the bounds for slicing the input safely.
451 ///
452 /// Ensures that slicing does not go beyond the input length.
453 ///
454 /// # Arguments
455 ///
456 /// * `n` - The number of characters to include in the slice.
457 ///
458 /// # Returns
459 ///
460 /// A tuple `(from, until)` representing the start and end indices for slicing.
461 #[inline]
462 const fn calculate_bound(&self, n: usize) -> (usize, usize) {
463 if self.has_reached_eof() {
464 return (self.length, self.length);
465 }
466
467 let mut until = self.offset + n;
468
469 if until >= self.length {
470 until = self.length;
471 }
472
473 (self.offset, until)
474 }
475}
476
477impl HasFileId for Input<'_> {
478 fn file_id(&self) -> FileId {
479 self.file_id
480 }
481}
482
#[cfg(test)]
mod tests {
    use mago_span::Position;

    use super::*;

    #[test]
    fn test_new() {
        let bytes = b"Hello, world!";
        let input = Input::new(FileId::zero(), bytes);

        assert_eq!(input.current_position(), Position::new(0));
        assert_eq!(input.length, bytes.len());
        assert_eq!(input.bytes, bytes);
    }

    #[test]
    fn test_is_eof() {
        let bytes = b"";
        let input = Input::new(FileId::zero(), bytes);

        assert!(input.has_reached_eof());

        let bytes = b"data";
        let mut input = Input::new(FileId::zero(), bytes);

        assert!(!input.has_reached_eof());

        input.skip(4);

        assert!(input.has_reached_eof());
    }

    #[test]
    fn test_next() {
        // `next` advances exactly one byte per call; line endings are not special-cased.
        let bytes = b"a\nb\r\nc\rd";
        let mut input = Input::new(FileId::zero(), bytes);

        // over 'a'
        input.next();
        assert_eq!(input.current_position(), Position::new(1));

        // over '\n'
        input.next();
        assert_eq!(input.current_position(), Position::new(2));

        // over 'b'
        input.next();
        assert_eq!(input.current_position(), Position::new(3));

        // over the '\r' of "\r\n" (only one byte, not the whole pair)
        input.next();
        assert_eq!(input.current_position(), Position::new(4));

        // over the '\n' of "\r\n"
        input.next();
        assert_eq!(input.current_position(), Position::new(5));

        // over 'c'
        input.next();
        assert_eq!(input.current_position(), Position::new(6));

        // over the lone '\r'
        input.next();
        assert_eq!(input.current_position(), Position::new(7));

        // over 'd'
        input.next();
        assert_eq!(input.current_position(), Position::new(8));
        assert!(input.has_reached_eof());

        // at EOF `next` is a no-op
        input.next();
        assert_eq!(input.current_position(), Position::new(8));
    }

    #[test]
    fn test_consume() {
        let bytes = b"abcdef";
        let mut input = Input::new(FileId::zero(), bytes);

        let consumed = input.consume(3);
        assert_eq!(consumed, b"abc");
        assert_eq!(input.current_position(), Position::new(3));

        let consumed = input.consume(3);
        assert_eq!(consumed, b"def");
        assert_eq!(input.current_position(), Position::new(6));

        let consumed = input.consume(1); // Should return empty slice at EOF
        assert_eq!(consumed, b"");
        assert!(input.has_reached_eof());
    }

    #[test]
    fn test_consume_remaining() {
        let bytes = b"abcdef";
        let mut input = Input::new(FileId::zero(), bytes);

        input.skip(2);
        let remaining = input.consume_remaining();
        assert_eq!(remaining, b"cdef");
        assert!(input.has_reached_eof());

        // Nothing left to consume once at EOF.
        assert_eq!(input.consume_remaining(), b"");
    }

    #[test]
    fn test_consume_until() {
        let mut input = Input::new(FileId::zero(), b"abc*/def");

        // Multi-byte needle: stops at the start of the match without consuming it.
        assert_eq!(input.consume_until(b"*/", false), b"abc");
        assert!(input.is_at(b"*/", false));

        input.skip(2);

        // Single-byte needle.
        assert_eq!(input.consume_until(b"f", false), b"de");

        // Missing needle consumes to EOF.
        assert_eq!(input.consume_until(b"zz", false), b"f");
        assert!(input.has_reached_eof());
    }

    #[test]
    fn test_consume_until_ignore_ascii_case() {
        let mut input = Input::new(FileId::zero(), b"abcEND");

        assert_eq!(input.consume_until(b"end", true), b"abc");
        assert!(input.is_at(b"END", false));
    }

    #[test]
    fn test_consume_through() {
        let mut input = Input::new(FileId::zero(), b"ab\ncd");

        // The needle itself is included.
        assert_eq!(input.consume_through(b'\n'), b"ab\n");

        // Missing needle consumes to EOF.
        assert_eq!(input.consume_through(b'\n'), b"cd");
        assert!(input.has_reached_eof());
    }

    #[test]
    fn test_consume_whitespaces() {
        let mut input = Input::new(FileId::zero(), b" \t\n\rabc");

        assert_eq!(input.consume_whitespaces(), b" \t\n\r");
        assert_eq!(input.consume_whitespaces(), b"");
        assert!(input.is_at(b"abc", false));
    }

    #[test]
    fn test_read() {
        let bytes = b"abcdef";
        let input = Input::new(FileId::zero(), bytes);

        let read = input.read(3);
        assert_eq!(read, b"abc");
        // Position should not change
        assert_eq!(input.current_position(), Position::new(0));
    }

    #[test]
    fn test_is_at() {
        let bytes = b"abcdef";
        let mut input = Input::new(FileId::zero(), bytes);

        assert!(input.is_at(b"abc", false));
        input.skip(2);
        assert!(input.is_at(b"cde", false));
        assert!(!input.is_at(b"xyz", false));
    }

    #[test]
    fn test_is_at_ignore_ascii_case() {
        let bytes = b"AbCdEf";
        let mut input = Input::new(FileId::zero(), bytes);

        assert!(input.is_at(b"abc", true));
        input.skip(2);
        assert!(input.is_at(b"cde", true));
        assert!(!input.is_at(b"xyz", true));
    }

    #[test]
    fn test_match_sequence_ignore_whitespace() {
        let input = Input::new(FileId::zero(), b"( string ) rest");
        // Inner whitespace counts toward the length, trailing whitespace does not.
        assert_eq!(input.match_sequence_ignore_whitespace(b"(string)", false), Some(10));

        let input = Input::new(FileId::zero(), b"(STRING)");
        assert_eq!(input.match_sequence_ignore_whitespace(b"(string)", true), Some(8));
        assert_eq!(input.match_sequence_ignore_whitespace(b"(string)", false), None);

        // Input ends before the sequence is complete.
        let input = Input::new(FileId::zero(), b"(str");
        assert_eq!(input.match_sequence_ignore_whitespace(b"(string)", false), None);
    }

    #[test]
    fn test_peek() {
        let bytes = b"abcdef";
        let input = Input::new(FileId::zero(), bytes);

        let peeked = input.peek(2, 3);
        assert_eq!(peeked, b"cde");
        // Position should not change
        assert_eq!(input.current_position(), Position::new(0));
    }

    #[test]
    fn test_peek_past_end() {
        let input = Input::new(FileId::zero(), b"abc");

        assert_eq!(input.peek(1, 10), b"bc");
        assert_eq!(input.peek(3, 1), b"");
        assert_eq!(input.peek(10, 1), b"");
    }

    #[test]
    fn test_anchored_current_position() {
        let mut input = Input::anchored_at(FileId::zero(), b"hello", Position::new(10));

        assert_eq!(input.current_position(), Position::new(10));

        input.skip(3);
        assert_eq!(input.current_offset(), 3);
        assert_eq!(input.current_position(), Position::new(13));
    }

    #[test]
    fn test_slice_in_range() {
        let input = Input::anchored_at(FileId::zero(), b"hello", Position::new(10));

        // Absolute offsets are translated relative to the anchor.
        assert_eq!(input.slice_in_range(11, 13), b"el");
        // Out-of-range offsets are clamped to the slice.
        assert_eq!(input.slice_in_range(0, 100), b"hello");
        assert_eq!(input.slice_in_range(20, 30), b"");
        // Inverted ranges are empty.
        assert_eq!(input.slice_in_range(13, 11), b"");
    }

    #[test]
    fn test_to_bound() {
        let bytes = b"abcdef";
        let input = Input::new(FileId::zero(), bytes);

        let (from, until) = input.calculate_bound(3);
        assert_eq!((from, until), (0, 3));

        let (from, until) = input.calculate_bound(10); // Exceeds length
        assert_eq!((from, until), (0, 6));
    }
}