Skip to main content

strata/
lexer.rs

1use crate::error::{ParseError, ParseErrorKind, Span};
2
3#[derive(Debug, Clone, PartialEq, Eq)]
4pub struct Token {
5    pub kind: TokenKind,
6    pub span: Span,
7}
8
/// Every category of token the lexer can produce.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum TokenKind {
    // ---- Literals ----
    /// The keyword `null`.
    Null,
    /// The keyword `true`.
    True,
    /// The keyword `false`.
    False,
    /// A decimal integer literal, optionally prefixed with `-`.
    Int(i64),
    /// A double-quoted string literal with its escape sequences already decoded.
    String(String),
    /// A `0x`-prefixed hexadecimal literal decoded into raw bytes.
    Bytes(Vec<u8>),

    // ---- Identifiers ----
    /// A name that is not one of the keywords `null`, `true`, or `false`.
    Ident(String),

    // ---- Punctuation ----
    /// `{`
    LBrace,
    /// `}`
    RBrace,
    /// `[`
    LBracket,
    /// `]`
    RBracket,
    /// `:`
    Colon,
    /// `,`
    Comma,

    /// Emitted once the input is exhausted.
    EOF,
}
32
/// Byte-oriented tokenizer over a borrowed input string.
pub struct Lexer<'a> {
    /// Raw bytes of the text being tokenized.
    input: &'a [u8],
    /// Index into `input` of the next unread byte.
    offset: usize,
    /// Line of the next unread byte; starts at 1 and grows by one per `\n` consumed.
    line: usize,
    /// Column of the next unread byte; starts at 1, advances one per byte consumed,
    /// and resets to 1 after a `\n`.
    column: usize,
}
39
40impl<'a> Lexer<'a> {
41    pub fn new(input: &'a str) -> Self {
42        Self {
43            input: input.as_bytes(),
44            offset: 0,
45            line: 1,
46            column: 1,
47        }
48    }
49
50    fn span(&self) -> Span {
51        Span {
52            offset: self.offset,
53            line: self.line,
54            column: self.column,
55        }
56    }
57
58    fn error(&self, kind: ParseErrorKind) -> ParseError {
59        ParseError {
60            kind,
61            span: self.span(),
62        }
63    }
64
65    fn peek(&self) -> Option<u8> {
66        self.input.get(self.offset).copied()
67    }
68
69    fn bump(&mut self) -> Option<u8> {
70        let byte = self.peek()?;
71        self.offset += 1;
72
73        if byte == b'\n' {
74            self.line += 1;
75            self.column = 1;
76        } else {
77            self.column += 1;
78        }
79
80        Some(byte)
81    }
82
83    fn hex_digit(byte: u8) -> Option<u8> {
84        match byte {
85            b'0'..=b'9' => Some(byte - b'0'),
86            b'a'..=b'f' => Some(byte - b'a' + 10),
87            b'A'..=b'F' => Some(byte - b'A' + 10),
88            _ => None,
89        }
90    }
91
92    fn skip_ignored(&mut self) {
93        loop {
94            //skip whitespace
95            while matches!(self.peek(), Some(b' ' | b'\n' | b'\r' | b'\t')) {
96                self.bump();
97            }
98
99            // line comment with //
100            if self.peek() == Some(b'/') && self.input.get(self.offset + 1) == Some(&b'/') {
101                self.bump();
102                self.bump();
103                while let Some(current_byte) = self.peek() {
104                    self.bump();
105                    if current_byte == b'\n' {
106                        break;
107                    }
108                }
109                continue;
110            }
111
112            // line comment with #
113            if self.peek() == Some(b'#') {
114                self.bump();
115                while let Some(current_byte) = self.peek() {
116                    self.bump();
117                    if current_byte == b'\n' {
118                        break;
119                    }
120                }
121                continue;
122            }
123
124            break;
125        }
126    }
127
128    fn lex_identifier(&mut self) -> TokenKind {
129        let start = self.offset;
130        self.bump();
131
132        while matches!(
133            self.peek(),
134            Some(b'a'..=b'z' | b'A'..=b'Z' | b'0'..=b'9' | b'_')
135        ) {
136            self.bump();
137        }
138
139        let ident = std::str::from_utf8(&self.input[start..self.offset])
140            .unwrap()
141            .to_string();
142
143        match ident.as_str() {
144            "null" => TokenKind::Null,
145            "true" => TokenKind::True,
146            "false" => TokenKind::False,
147            _ => TokenKind::Ident(ident),
148        }
149    }
150
151    fn lex_int(&mut self) -> Result<TokenKind, ParseError> {
152        let start = self.offset;
153
154        // opt leading '-'
155        if self.peek() == Some(b'-') {
156            self.bump();
157        }
158
159        let mut saw_digit = false;
160
161        while matches!(self.peek(), Some(b'0'..=b'9')) {
162            saw_digit = true;
163            self.bump();
164        }
165
166        // must have at least 1 digit
167        if !saw_digit {
168            return Err(self.error(ParseErrorKind::IntegerOutOfRange));
169        }
170
171        let slice = &self.input[start..self.offset];
172        let text = std::str::from_utf8(slice)
173            .map_err(|_| self.error(ParseErrorKind::IntegerOutOfRange))?;
174
175        let parsed_value = text
176            .parse::<i64>()
177            .map_err(|_| self.error(ParseErrorKind::IntegerOutOfRange))?;
178
179        Ok(TokenKind::Int(parsed_value))
180    }
181
182    fn lex_bytes(&mut self) -> Result<TokenKind, ParseError> {
183        // must start with 0x
184        if self.peek() != Some(b'0') || self.input.get(self.offset + 1) != Some(&b'x') {
185            return Err(self.error(ParseErrorKind::MalformedBytesLiteral));
186        }
187
188        // consume 0x
189        self.bump();
190        self.bump();
191
192        let hex_start = self.offset;
193
194        while matches!(self.peek(), Some(b'0'..=b'9' | b'a'..=b'f' | b'A'..=b'F')) {
195            self.bump();
196        }
197
198        let hex_len = self.offset - hex_start;
199
200        // must have even number of hex digits and at least one byte
201        if hex_len == 0 || !hex_len.is_multiple_of(2) {
202            return Err(self.error(ParseErrorKind::MalformedBytesLiteral));
203        }
204
205        let hex = &self.input[hex_start..self.offset];
206        let mut bytes = Vec::with_capacity(hex_len / 2);
207
208        for i in (0..hex_len).step_by(2) {
209            let high_nibble = Self::hex_digit(hex[i])
210                .ok_or_else(|| self.error(ParseErrorKind::MalformedBytesLiteral))?;
211
212            let low_nibble = Self::hex_digit(hex[i + 1])
213                .ok_or_else(|| self.error(ParseErrorKind::MalformedBytesLiteral))?;
214
215            bytes.push((high_nibble << 4) | low_nibble);
216        }
217
218        Ok(TokenKind::Bytes(bytes))
219    }
220
221    fn lex_string(&mut self) -> Result<TokenKind, ParseError> {
222        self.bump(); // opening '"'
223
224        let mut out = String::new();
225
226        while let Some(current_byte) = self.peek() {
227            match current_byte {
228                b'"' => {
229                    // closing quote
230                    self.bump();
231                    return Ok(TokenKind::String(out));
232                }
233
234                b'\\' => {
235                    //escape seq
236                    self.bump();
237                    let escape_char = self
238                        .bump()
239                        .ok_or_else(|| self.error(ParseErrorKind::MalformedBytesLiteral))?;
240
241                    match escape_char {
242                        b'"' => out.push('"'),
243                        b'\\' => out.push('\\'),
244                        b'n' => out.push('\n'),
245                        b'r' => out.push('\r'),
246                        b't' => out.push('\t'),
247
248                        b'u' => {
249                            // \uXXXX
250                            let mut codepoint = 0u32;
251
252                            for _ in 0..4 {
253                                let hex_byte = self.bump().ok_or_else(|| {
254                                    self.error(ParseErrorKind::MalformedBytesLiteral)
255                                })?;
256                                let digit = Self::hex_digit(hex_byte).ok_or_else(|| {
257                                    self.error(ParseErrorKind::MalformedBytesLiteral)
258                                })?;
259                                codepoint = (codepoint << 4) | (digit as u32);
260                            }
261
262                            let unicode_char = char::from_u32(codepoint)
263                                .ok_or_else(|| self.error(ParseErrorKind::MalformedBytesLiteral))?;
264                            out.push(unicode_char);
265                        }
266                        _ => return Err(self.error(ParseErrorKind::MalformedBytesLiteral)),
267                    }
268                }
269
270                b'\n' | b'\r' => {
271                    // strings cant span lines
272                    return Err(self.error(ParseErrorKind::MalformedBytesLiteral));
273                }
274
275                _ => {
276                    if current_byte >= 0x80 {
277                        return Err(self.error(ParseErrorKind::MalformedBytesLiteral));
278                    }
279
280                    out.push(current_byte as char);
281                    self.bump();
282                }
283            }
284        }
285
286        Err(self.error(ParseErrorKind::MalformedBytesLiteral))
287    }
288
289    pub fn next_token(&mut self) -> Result<Token, ParseError> {
290        self.skip_ignored();
291
292        let start = self.span();
293
294        let current_byte = match self.peek() {
295            Some(b) => b,
296            None => {
297                return Ok(Token {
298                    kind: TokenKind::EOF,
299                    span: start,
300                });
301            }
302        };
303
304        let kind = match current_byte {
305            b'{' => {
306                self.bump();
307                TokenKind::LBrace
308            }
309            b'}' => {
310                self.bump();
311                TokenKind::RBrace
312            }
313            b'[' => {
314                self.bump();
315                TokenKind::LBracket
316            }
317            b']' => {
318                self.bump();
319                TokenKind::RBracket
320            }
321            b':' => {
322                self.bump();
323                TokenKind::Colon
324            }
325            b',' => {
326                self.bump();
327                TokenKind::Comma
328            }
329
330            // string literal
331            b'"' => self.lex_string()?,
332
333            // bytes literal
334            b'0' if self.input.get(self.offset + 1) == Some(&b'x') => self.lex_bytes()?,
335
336            // integer literal
337            b'-' | b'0'..=b'9' => self.lex_int()?,
338
339            // identifier or keyword
340            b'a'..=b'z' | b'A'..=b'Z' | b'_' => self.lex_identifier(),
341
342            _ => {
343                return Err(self.error(ParseErrorKind::UnexpectedToken {
344                    expected: "valid token",
345                    found: "invalid character",
346                }));
347            }
348        };
349
350        Ok(Token { kind, span: start })
351    }
352}