mago_lexer/input.rs
1use memchr::memchr;
2use memchr::memmem::find;
3
4use mago_source::SourceIdentifier;
5use mago_span::Position;
6
7/// A struct representing the input code being lexed.
8///
9/// The `Input` struct provides methods to read, peek, consume, and skip characters
10/// from the bytes input code while keeping track of the current position (line, column, offset).
11#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)]
12pub struct Input<'a> {
13 pub(crate) bytes: &'a [u8],
14 pub(crate) length: usize,
15 pub(crate) position: Position,
16}
17
18impl<'a> Input<'a> {
19 /// Creates a new `Input` instance from the given input.
20 ///
21 /// # Arguments
22 ///
23 /// * `input` - A byte slice representing the input code to be processed.
24 ///
25 /// # Returns
26 ///
27 /// A new `Input` instance initialized at the beginning of the input.
28 pub fn new(source: SourceIdentifier, bytes: &'a [u8]) -> Self {
29 let length = bytes.len();
30
31 Self { bytes, length, position: Position::start_of(source) }
32 }
33
34 /// Returns the source identifier of the input code.
35 #[inline]
36 pub const fn source_identifier(&self) -> SourceIdentifier {
37 self.position.source
38 }
39
40 /// Returns the current position in the input code.
41 ///
42 /// # Returns
43 ///
44 /// A `Position` struct containing the current line, column, and offset.
45 #[inline]
46 pub const fn position(&self) -> Position {
47 self.position
48 }
49
50 /// Checks if the current position is at the end of the input.
51 ///
52 /// # Returns
53 ///
54 /// `true` if the current offset is greater than or equal to the input length; `false` otherwise.
55 #[inline]
56 pub const fn has_reached_eof(&self) -> bool {
57 self.position.offset >= self.length
58 }
59
60 /// Advances the current position by one character, updating line and column numbers.
61 ///
62 /// Handles different line endings (`\n`, `\r`, `\r\n`) and updates line and column counters accordingly.
63 ///
64 /// If the end of input is reached, no action is taken.
65 #[inline]
66 pub fn next(&mut self) {
67 if !self.has_reached_eof() {
68 self.position.offset += 1;
69 }
70 }
71
72 /// Skips the next `count` characters, advancing the position accordingly.
73 ///
74 /// Updates line and column numbers as it advances.
75 ///
76 /// # Arguments
77 ///
78 /// * `count` - The number of characters to skip.
79 #[inline]
80 pub fn skip(&mut self, count: usize) {
81 self.position.offset = (self.position.offset + count).min(self.length);
82 }
83
84 /// Consumes the next `count` characters and returns them as a slice.
85 ///
86 /// Advances the position by `count` characters.
87 ///
88 /// # Arguments
89 ///
90 /// * `count` - The number of characters to consume.
91 ///
92 /// # Returns
93 ///
94 /// A byte slice containing the consumed characters.
95 #[inline]
96 pub fn consume(&mut self, count: usize) -> &'a [u8] {
97 let (from, until) = self.calculate_bound(count);
98
99 self.skip(count);
100
101 &self.bytes[from..until]
102 }
103
104 /// Consumes all remaining characters from the current position to the end of input.
105 ///
106 /// Advances the position to EOF.
107 ///
108 /// # Returns
109 ///
110 /// A byte slice containing the remaining characters.
111 #[inline]
112 pub fn consume_remaining(&mut self) -> &'a [u8] {
113 if self.has_reached_eof() {
114 return &[];
115 }
116
117 let from = self.position.offset;
118 self.position.offset = self.length;
119
120 &self.bytes[from..]
121 }
122
123 /// Consumes characters until the given byte slice is found.
124 ///
125 /// Advances the position to the start of the search slice if found,
126 /// or to EOF if not found.
127 ///
128 /// # Arguments
129 ///
130 /// * `search` - The byte slice to search for.
131 /// * `ignore_ascii_case` - Whether to ignore ASCII case when comparing characters.
132 ///
133 /// # Returns
134 ///
135 /// A byte slice containing the consumed characters.
136 #[inline]
137 pub fn consume_until(&mut self, search: &[u8], ignore_ascii_case: bool) -> &'a [u8] {
138 let start = self.position.offset;
139 if !ignore_ascii_case {
140 // For a single-byte search, use memchr.
141 if search.len() == 1 {
142 if let Some(pos) = memchr(search[0], &self.bytes[self.position.offset..]) {
143 self.position.offset += pos;
144 &self.bytes[start..self.position.offset]
145 } else {
146 self.position.offset = self.length;
147 &self.bytes[start..self.length]
148 }
149 } else if let Some(pos) = find(&self.bytes[self.position.offset..], search) {
150 self.position.offset += pos;
151 &self.bytes[start..self.position.offset]
152 } else {
153 self.position.offset = self.length;
154 &self.bytes[start..self.length]
155 }
156 } else {
157 while !self.has_reached_eof() && !self.is_at(search, ignore_ascii_case) {
158 self.position.offset += 1;
159 }
160
161 &self.bytes[start..self.position.offset]
162 }
163 }
164
165 #[inline]
166 pub fn consume_through(&mut self, search: u8) -> &'a [u8] {
167 let start = self.position.offset;
168 if let Some(pos) = memchr::memchr(search, &self.bytes[self.position.offset..]) {
169 self.position.offset += pos + 1;
170
171 &self.bytes[start..self.position.offset]
172 } else {
173 self.position.offset = self.length;
174
175 &self.bytes[start..self.length]
176 }
177 }
178
179 /// Consumes whitespaces until a non-whitespace character is found.
180 ///
181 /// # Returns
182 ///
183 /// A byte slice containing the consumed whitespaces.
184 #[inline]
185 pub fn consume_whitespaces(&mut self) -> &'a [u8] {
186 let start = self.position.offset;
187 let bytes = self.bytes;
188 let len = self.length;
189 while self.position.offset < len && bytes[self.position.offset].is_ascii_whitespace() {
190 self.position.offset += 1;
191 }
192
193 &bytes[start..self.position.offset]
194 }
195
196 /// Reads the next `n` characters without advancing the position.
197 ///
198 /// # Arguments
199 ///
200 /// * `n` - The number of characters to read.
201 ///
202 /// # Returns
203 ///
204 /// A byte slice containing the next `n` characters.
205 #[inline]
206 pub fn read(&self, n: usize) -> &'a [u8] {
207 let (from, until) = self.calculate_bound(n);
208
209 &self.bytes[from..until]
210 }
211
212 /// Checks if the input at the current position matches the given byte slice.
213 ///
214 /// # Arguments
215 ///
216 /// * `search` - The byte slice to compare against the input.
217 /// * `ignore_ascii_case` - Whether to ignore ASCII case when comparing.
218 ///
219 /// # Returns
220 ///
221 /// `true` if the next bytes match `search`; `false` otherwise.
222 #[inline]
223 pub fn is_at(&self, search: &[u8], ignore_ascii_case: bool) -> bool {
224 let (from, until) = self.calculate_bound(search.len());
225 let slice = &self.bytes[from..until];
226
227 if ignore_ascii_case { slice.eq_ignore_ascii_case(search) } else { slice == search }
228 }
229
230 /// Attempts to match the given byte sequence at the current position, ignoring whitespace in the input.
231 ///
232 /// This method tries to match the provided byte slice `search` against the input starting
233 /// from the current position, possibly ignoring ASCII case. Whitespace characters in the input
234 /// are skipped during matching, but their length is included in the returned length.
235 ///
236 /// Importantly, the method **does not include** any trailing whitespace **after** the matched sequence
237 /// in the returned length.
238 ///
239 /// For example, to match the sequence `(string)`, the input could be `(string)`, `( string )`, `( string )`, etc.,
240 /// and this method would return the total length of the input consumed to match `(string)`,
241 /// including any whitespace within the matched sequence, but **excluding** any whitespace after it.
242 ///
243 /// # Arguments
244 ///
245 /// * `search` - The byte slice to match against the input.
246 /// * `ignore_ascii_case` - If `true`, ASCII case is ignored during comparison.
247 ///
248 /// # Returns
249 ///
250 /// * `Some(length)` - If the input matches `search` (ignoring whitespace within the sequence), returns the total length
251 /// of the input consumed to match `search`, including any skipped whitespace **within** the matched sequence.
252 /// * `None` - If the input does not match `search`.
253 ///
254 /// # Examples
255 ///
256 /// ```rust
257 /// use mago_lexer::input::Input;
258 /// use mago_source::SourceIdentifier;
259 ///
260 /// let source = SourceIdentifier::dummy();
261 ///
262 /// // Given input "( string ) x", starting at offset 0:
263 /// let input = Input::new(source.clone(), b"( string ) x");
264 /// assert_eq!(input.match_sequence_ignore_whitespace(b"(string)", true), Some(10)); // 10 bytes consumed up to ')'
265 ///
266 /// // Given input "(int)", with no whitespace:
267 /// let input = Input::new(source.clone(), b"(int)");
268 /// assert_eq!(input.match_sequence_ignore_whitespace(b"(int)", true), Some(5)); // 5 bytes consumed
269 ///
270 /// // Given input "( InT )abc", ignoring ASCII case:
271 /// let input = Input::new(source.clone(), b"( InT )abc");
272 /// assert_eq!(input.match_sequence_ignore_whitespace(b"(int)", true), Some(10)); // 10 bytes consumed up to ')'
273 ///
274 /// // Given input "(integer)", attempting to match "(int)":
275 /// let input = Input::new(source.clone(), b"(integer)");
276 /// assert_eq!(input.match_sequence_ignore_whitespace(b"(int)", false), None); // Does not match
277 ///
278 /// // Trailing whitespace after ')':
279 /// let input = Input::new(source.clone(), b"(int) x");
280 /// assert_eq!(input.match_sequence_ignore_whitespace(b"(int)", true), Some(5)); // Length up to ')', excludes spaces after ')'
281 /// ```
282 #[inline]
283 pub const fn match_sequence_ignore_whitespace(&self, search: &[u8], ignore_ascii_case: bool) -> Option<usize> {
284 let mut offset = self.position.offset;
285 let mut search_offset = 0;
286 let mut length = 0;
287 let bytes = self.bytes;
288 let total = self.length;
289 while search_offset < search.len() {
290 // Skip whitespace in the input.
291 while offset < total && bytes[offset].is_ascii_whitespace() {
292 offset += 1;
293 length += 1;
294 }
295
296 if offset >= total {
297 return None;
298 }
299
300 let input_byte = bytes[offset];
301 let search_byte = search[search_offset];
302 let matched = if ignore_ascii_case {
303 input_byte.eq_ignore_ascii_case(&search_byte)
304 } else {
305 input_byte == search_byte
306 };
307
308 if matched {
309 offset += 1;
310 length += 1;
311 search_offset += 1;
312 } else {
313 return None;
314 }
315 }
316
317 Some(length)
318 }
319
320 /// Peeks ahead `i` characters and reads the next `n` characters without advancing the position.
321 ///
322 /// # Arguments
323 ///
324 /// * `offset` - The number of characters to skip before reading.
325 /// * `n` - The number of characters to read after skipping.
326 ///
327 /// # Returns
328 ///
329 /// A byte slice containing the peeked characters.
330 #[inline]
331 pub fn peek(&self, offset: usize, n: usize) -> &'a [u8] {
332 let from = self.position.offset + offset;
333 if from >= self.length {
334 return &self.bytes[self.length..self.length];
335 }
336
337 let mut until = from + n;
338 if until >= self.length {
339 until = self.length;
340 }
341
342 &self.bytes[from..until]
343 }
344
345 /// Calculates the bounds for slicing the input safely.
346 ///
347 /// Ensures that slicing does not go beyond the input length.
348 ///
349 /// # Arguments
350 ///
351 /// * `n` - The number of characters to include in the slice.
352 ///
353 /// # Returns
354 ///
355 /// A tuple `(from, until)` representing the start and end indices for slicing.
356 #[inline]
357 const fn calculate_bound(&self, n: usize) -> (usize, usize) {
358 if self.has_reached_eof() {
359 return (self.length, self.length);
360 }
361
362 let mut until = self.position.offset + n;
363
364 if until >= self.length {
365 until = self.length;
366 }
367
368 (self.position.offset, until)
369 }
370}
371
#[cfg(test)]
mod tests {
    use mago_span::Position;

    use super::*;

    /// Shorthand for the position expected after advancing to `offset` in the dummy source.
    fn at(offset: usize) -> Position {
        Position::new(SourceIdentifier::dummy(), offset)
    }

    #[test]
    fn test_new() {
        let bytes = b"Hello, world!";
        let input = Input::new(SourceIdentifier::dummy(), bytes);

        assert_eq!(input.position(), at(0));
        assert_eq!(input.length, bytes.len());
        assert_eq!(input.bytes, bytes);
    }

    #[test]
    fn test_is_eof() {
        let input = Input::new(SourceIdentifier::dummy(), b"");
        assert!(input.has_reached_eof());

        let mut input = Input::new(SourceIdentifier::dummy(), b"data");
        assert!(!input.has_reached_eof());

        input.skip(4);
        assert!(input.has_reached_eof());
    }

    #[test]
    fn test_next() {
        // 8 bytes: 'a', '\n', 'b', '\r', '\n', 'c', '\r', 'd'.
        let bytes = b"a\nb\r\nc\rd";
        let mut input = Input::new(SourceIdentifier::dummy(), bytes);

        // Every call moves exactly one byte; line endings get no special treatment,
        // so the "\r\n" pair takes two calls to step over.
        for expected_offset in 1..=bytes.len() {
            input.next();
            assert_eq!(input.position(), at(expected_offset));
        }

        // At EOF, `next` is a no-op.
        assert!(input.has_reached_eof());
        input.next();
        assert_eq!(input.position(), at(bytes.len()));
    }

    #[test]
    fn test_skip_stops_at_eof() {
        let mut input = Input::new(SourceIdentifier::dummy(), b"data");

        input.skip(10); // More than what is available.
        assert_eq!(input.position(), at(4));
        assert!(input.has_reached_eof());
    }

    #[test]
    fn test_consume() {
        let mut input = Input::new(SourceIdentifier::dummy(), b"abcdef");

        assert_eq!(input.consume(3), b"abc");
        assert_eq!(input.position(), at(3));

        assert_eq!(input.consume(3), b"def");
        assert_eq!(input.position(), at(6));

        // Should return an empty slice at EOF.
        assert_eq!(input.consume(1), b"");
        assert!(input.has_reached_eof());
    }

    #[test]
    fn test_consume_remaining() {
        let mut input = Input::new(SourceIdentifier::dummy(), b"abcdef");

        input.skip(2);
        assert_eq!(input.consume_remaining(), b"cdef");
        assert!(input.has_reached_eof());

        // Nothing is left once EOF has been reached.
        assert_eq!(input.consume_remaining(), b"");
    }

    #[test]
    fn test_consume_until() {
        let mut input = Input::new(SourceIdentifier::dummy(), b"abc?>def");

        // Multi-byte needle: stops right before the match.
        assert_eq!(input.consume_until(b"?>", false), b"abc");
        assert_eq!(input.position(), at(3));

        // Single-byte needle that never occurs: consumes through EOF.
        assert_eq!(input.consume_until(b"X", false), b"?>def");
        assert!(input.has_reached_eof());
    }

    #[test]
    fn test_consume_until_ignore_ascii_case() {
        let mut input = Input::new(SourceIdentifier::dummy(), b"foo END bar");

        assert_eq!(input.consume_until(b"end", true), b"foo ");
        assert_eq!(input.position(), at(4));
        assert!(input.is_at(b"END", false));
    }

    #[test]
    fn test_consume_through() {
        let mut input = Input::new(SourceIdentifier::dummy(), b"line1\nline2");

        // The matched byte is part of the consumed slice.
        assert_eq!(input.consume_through(b'\n'), b"line1\n");
        assert_eq!(input.position(), at(6));

        // No further newline: the rest of the input is consumed.
        assert_eq!(input.consume_through(b'\n'), b"line2");
        assert!(input.has_reached_eof());
    }

    #[test]
    fn test_consume_whitespaces() {
        let mut input = Input::new(SourceIdentifier::dummy(), b" \t\n x");

        assert_eq!(input.consume_whitespaces(), b" \t\n ");
        assert_eq!(input.position(), at(4));

        // Already at a non-whitespace byte: nothing is consumed.
        assert_eq!(input.consume_whitespaces(), b"");
        assert_eq!(input.position(), at(4));
    }

    #[test]
    fn test_read() {
        let input = Input::new(SourceIdentifier::dummy(), b"abcdef");

        assert_eq!(input.read(3), b"abc");
        assert_eq!(input.read(10), b"abcdef"); // Truncated at the end of the input.

        // Position should not change.
        assert_eq!(input.position(), at(0));
    }

    #[test]
    fn test_is_at() {
        let mut input = Input::new(SourceIdentifier::dummy(), b"abcdef");

        assert!(input.is_at(b"abc", false));
        input.skip(2);
        assert!(input.is_at(b"cde", false));
        assert!(!input.is_at(b"xyz", false));
        assert!(!input.is_at(b"CDE", false));
    }

    #[test]
    fn test_is_at_ignore_ascii_case() {
        let mut input = Input::new(SourceIdentifier::dummy(), b"AbCdEf");

        assert!(input.is_at(b"abc", true));
        input.skip(2);
        assert!(input.is_at(b"cde", true));
        assert!(!input.is_at(b"xyz", true));
    }

    #[test]
    fn test_match_sequence_ignore_whitespace() {
        let source = SourceIdentifier::dummy();

        // Whitespace inside the sequence counts towards the length...
        let input = Input::new(source, b"( string ) x");
        assert_eq!(input.match_sequence_ignore_whitespace(b"(string)", true), Some(10));

        let input = Input::new(source, b"( InT )abc");
        assert_eq!(input.match_sequence_ignore_whitespace(b"(int)", true), Some(7));
        assert_eq!(input.match_sequence_ignore_whitespace(b"(int)", false), None);

        // ...while whitespace after it does not.
        let input = Input::new(source, b"(int)   x");
        assert_eq!(input.match_sequence_ignore_whitespace(b"(int)", true), Some(5));

        // Mismatch and premature EOF both yield `None`.
        let input = Input::new(source, b"(integer)");
        assert_eq!(input.match_sequence_ignore_whitespace(b"(int)", false), None);

        let input = Input::new(source, b"(int");
        assert_eq!(input.match_sequence_ignore_whitespace(b"(int)", false), None);

        // Position should not change.
        assert_eq!(input.position(), at(0));
    }

    #[test]
    fn test_peek() {
        let input = Input::new(SourceIdentifier::dummy(), b"abcdef");

        assert_eq!(input.peek(2, 3), b"cde");
        assert_eq!(input.peek(4, 10), b"ef"); // Truncated at the end of the input.
        assert_eq!(input.peek(10, 2), b""); // Starting past the end yields nothing.

        // Position should not change.
        assert_eq!(input.position(), at(0));
    }

    #[test]
    fn test_to_bound() {
        let input = Input::new(SourceIdentifier::dummy(), b"abcdef");

        assert_eq!(input.calculate_bound(3), (0, 3));
        assert_eq!(input.calculate_bound(10), (0, 6)); // Exceeds length
    }
}