// sage_parser/lexer.rs

1//! Lexer implementation and public API.
2
3use crate::Token;
4use logos::Logos;
5use miette::{Diagnostic, LabeledSpan, NamedSource, SourceCode};
6use std::sync::Arc;
7use thiserror::Error;
8
/// A token with its source span.
///
/// Offsets are byte indices into the original source, forming the
/// half-open range `start..end`.
#[derive(Debug, Clone, PartialEq)]
pub struct Spanned {
    /// The token.
    pub token: Token,
    /// Start byte offset into the source.
    pub start: usize,
    /// End byte offset (exclusive).
    pub end: usize,
}
19
20impl Spanned {
21    /// Create a new spanned token.
22    #[must_use]
23    pub fn new(token: Token, start: usize, end: usize) -> Self {
24        Self { token, start, end }
25    }
26
27    /// Get the length of this token in bytes.
28    #[must_use]
29    pub fn len(&self) -> usize {
30        self.end - self.start
31    }
32
33    /// Check if this token span is empty.
34    #[must_use]
35    pub fn is_empty(&self) -> bool {
36        self.len() == 0
37    }
38
39    /// Get the span as a tuple.
40    #[must_use]
41    pub fn span(&self) -> (usize, usize) {
42        (self.start, self.end)
43    }
44}
45
/// A single lexer error at a specific location.
///
/// Offsets are byte indices into the source, with `end` exclusive,
/// matching the convention used by [`Spanned`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct LexErrorLocation {
    /// Start byte offset of the invalid token.
    pub start: usize,
    /// End byte offset (exclusive) of the invalid token.
    pub end: usize,
    /// The invalid text that couldn't be lexed.
    pub text: String,
}
56
/// Error type for lexer failures.
///
/// Carries the full source text so `miette` can render labeled spans
/// for every invalid token found during a single pass.
#[derive(Error, Debug)]
#[error("failed to lex source: {count} error(s) found")]
pub struct LexError {
    /// The source code being lexed (named `<input>` unless
    /// [`LexError::with_filename`] is used).
    source_code: NamedSource<String>,

    /// All error locations found during lexing.
    pub errors: Vec<LexErrorLocation>,

    /// Cached `errors.len()`; kept as a field so the `thiserror`
    /// format string above can interpolate it.
    count: usize,
}
70
71impl LexError {
72    /// Create a new lex error.
73    fn new(source: String, errors: Vec<LexErrorLocation>) -> Self {
74        let count = errors.len();
75        Self {
76            source_code: NamedSource::new("<input>", source),
77            errors,
78            count,
79        }
80    }
81
82    /// Create a new lex error with a filename.
83    #[must_use]
84    pub fn with_filename(mut self, filename: impl Into<String>) -> Self {
85        let source = self.source_code.inner().clone();
86        self.source_code = NamedSource::new(filename.into(), source);
87        self
88    }
89}
90
91impl Diagnostic for LexError {
92    fn code<'a>(&'a self) -> Option<Box<dyn std::fmt::Display + 'a>> {
93        Some(Box::new("sage::lexer::E001"))
94    }
95
96    fn help<'a>(&'a self) -> Option<Box<dyn std::fmt::Display + 'a>> {
97        Some(Box::new("remove or replace invalid characters"))
98    }
99
100    fn source_code(&self) -> Option<&dyn SourceCode> {
101        Some(&self.source_code)
102    }
103
104    fn labels(&self) -> Option<Box<dyn Iterator<Item = LabeledSpan> + '_>> {
105        let labels = self.errors.iter().map(|e| {
106            LabeledSpan::new_with_span(
107                Some(format!("invalid token `{}`", e.text)),
108                (e.start, e.end - e.start),
109            )
110        });
111        Some(Box::new(labels))
112    }
113}
114
/// Result of lexing source code.
///
/// Holds the token stream plus a cheaply-cloneable handle to the source
/// text so later stages can report errors against it.
#[derive(Debug)]
pub struct LexResult {
    /// Successfully lexed tokens, in source order.
    pub tokens: Vec<Spanned>,
    /// The source code (for error reporting); `Arc<str>` clones are O(1).
    pub source: Arc<str>,
}
123
124impl LexResult {
125    /// Get the tokens as a slice.
126    #[must_use]
127    pub fn tokens(&self) -> &[Spanned] {
128        &self.tokens
129    }
130
131    /// Consume self and return just the tokens.
132    #[must_use]
133    pub fn into_tokens(self) -> Vec<Spanned> {
134        self.tokens
135    }
136}
137
138/// Lex source code into tokens.
139///
140/// This function tokenizes the entire source, collecting all errors rather than
141/// stopping at the first one. If any errors are found, they are all reported.
142///
143/// # Arguments
144///
145/// * `source` - The source code to lex.
146///
147/// # Returns
148///
149/// * `Ok(LexResult)` - Successfully lexed tokens with source reference.
150/// * `Err(LexError)` - One or more lex errors occurred.
151///
152/// # Errors
153///
154/// Returns `LexError` if the source contains invalid characters that cannot
155/// be tokenized. All errors are collected and reported together.
156///
157/// # Example
158///
159/// ```
160/// use sage_parser::{lex, Token};
161///
162/// let result = lex("let x = 42").unwrap();
163/// assert_eq!(result.tokens()[0].token, Token::KwLet);
164/// ```
165pub fn lex(source: &str) -> Result<LexResult, LexError> {
166    let source_arc: Arc<str> = Arc::from(source);
167    let mut tokens = Vec::new();
168    let mut errors = Vec::new();
169
170    let lexer = Token::lexer(source);
171
172    for (result, span) in lexer.spanned() {
173        if let Ok(token) = result {
174            tokens.push(Spanned::new(token, span.start, span.end));
175        } else {
176            let text = source[span.start..span.end].to_string();
177            errors.push(LexErrorLocation {
178                start: span.start,
179                end: span.end,
180                text,
181            });
182        }
183    }
184
185    if errors.is_empty() {
186        Ok(LexResult {
187            tokens,
188            source: source_arc,
189        })
190    } else {
191        Err(LexError::new(source.to_string(), errors))
192    }
193}
194
195/// Lex source code, returning tokens even if there are errors.
196///
197/// This is useful for editor tooling where you want partial results.
198/// Errors are collected but don't prevent returning valid tokens.
199///
200/// # Returns
201///
202/// A tuple of (tokens, errors). The tokens vector contains all valid tokens
203/// found, and the errors vector contains all lex errors encountered.
204#[must_use]
205pub fn lex_partial(source: &str) -> (Vec<Spanned>, Vec<LexErrorLocation>) {
206    let mut tokens = Vec::new();
207    let mut errors = Vec::new();
208
209    let lexer = Token::lexer(source);
210
211    for (result, span) in lexer.spanned() {
212        if let Ok(token) = result {
213            tokens.push(Spanned::new(token, span.start, span.end));
214        } else {
215            let text = source[span.start..span.end].to_string();
216            errors.push(LexErrorLocation {
217                start: span.start,
218                end: span.end,
219                text,
220            });
221        }
222    }
223
224    (tokens, errors)
225}
226
#[cfg(test)]
mod tests {
    //! Unit tests for the public lexer API (`lex`, `lex_partial`) and the
    //! `Spanned` / `LexError` helper types. Token-count and span
    //! assertions pin the exact behavior of the `Token` grammar.

    use super::*;

    // Happy path: keyword, identifier, operator and literal tokens.
    #[test]
    fn lex_simple_tokens() {
        let result = lex("let x = 42").unwrap();
        let tokens = result.tokens();

        assert_eq!(tokens.len(), 4);
        assert_eq!(tokens[0].token, Token::KwLet);
        assert_eq!(tokens[1].token, Token::Ident);
        assert_eq!(tokens[2].token, Token::Eq);
        assert_eq!(tokens[3].token, Token::IntLit);
    }

    // Spans must be byte offsets into the original source.
    #[test]
    fn lex_preserves_spans() {
        let result = lex("let x").unwrap();
        let tokens = result.tokens();

        assert_eq!(tokens[0].start, 0);
        assert_eq!(tokens[0].end, 3); // "let"
        assert_eq!(tokens[1].start, 4);
        assert_eq!(tokens[1].end, 5); // "x"
    }

    // Degenerate inputs: empty, whitespace-only, comment-only sources
    // all lex to an empty token stream without error.
    #[test]
    fn lex_empty_source() {
        let result = lex("").unwrap();
        assert!(result.tokens().is_empty());
    }

    #[test]
    fn lex_whitespace_only() {
        let result = lex("   \n\t  ").unwrap();
        assert!(result.tokens().is_empty());
    }

    #[test]
    fn lex_comments_only() {
        let result = lex("// this is a comment").unwrap();
        assert!(result.tokens().is_empty());
    }

    // A single invalid character is reported with its exact text and span.
    #[test]
    fn lex_error_invalid_char() {
        let err = lex("let # = 42").unwrap_err();

        assert_eq!(err.errors.len(), 1);
        assert_eq!(err.errors[0].text, "#");
        assert_eq!(err.errors[0].start, 4);
        assert_eq!(err.errors[0].end, 5);
    }

    // All errors are collected in one pass, not just the first.
    #[test]
    fn lex_error_multiple_invalid() {
        let err = lex("let # x $ y").unwrap_err();

        assert_eq!(err.errors.len(), 2);
        assert_eq!(err.errors[0].text, "#");
        assert_eq!(err.errors[1].text, "$");
    }

    // `lex_partial` returns both the valid tokens and the errors.
    #[test]
    fn lex_partial_with_errors() {
        let (tokens, errors) = lex_partial("let # x = 42");

        // Should have valid tokens
        assert_eq!(tokens.len(), 4); // let, x, =, 42
        assert_eq!(tokens[0].token, Token::KwLet);

        // And the error
        assert_eq!(errors.len(), 1);
        assert_eq!(errors[0].text, "#");
    }

    // Smoke test over a realistic multi-line agent declaration.
    #[test]
    fn lex_agent_declaration() {
        let source = r#"
agent Researcher {
    belief topic: String

    on start {
        let result = infer("test")
        emit(result)
    }
}

run Researcher
"#;
        let result = lex(source).unwrap();
        let tokens = result.tokens();

        // Verify key tokens are present
        assert_eq!(tokens[0].token, Token::KwAgent);
        assert_eq!(tokens[1].token, Token::Ident);
        assert_eq!(tokens[2].token, Token::LBrace);
        assert_eq!(tokens[3].token, Token::KwBelief);
    }

    // `into_tokens` consumes the result and yields the owned vector.
    #[test]
    fn lex_result_into_tokens() {
        let result = lex("42").unwrap();
        let tokens = result.into_tokens();

        assert_eq!(tokens.len(), 1);
        assert_eq!(tokens[0].token, Token::IntLit);
    }

    // `Spanned` accessors: length, emptiness, and tuple form.
    #[test]
    fn spanned_len() {
        let spanned = Spanned::new(Token::KwLet, 0, 3);
        assert_eq!(spanned.len(), 3);
        assert!(!spanned.is_empty());
    }

    #[test]
    fn spanned_span() {
        let spanned = Spanned::new(Token::KwLet, 5, 8);
        assert_eq!(spanned.span(), (5, 8));
    }

    // The thiserror-generated Display message mentions the failure.
    #[test]
    fn lex_error_display() {
        let err = lex("#").unwrap_err();
        let display = format!("{err}");
        assert!(display.contains("failed to lex"));
    }

    #[test]
    fn lex_string_interpolation_markers() {
        // String literals with {ident} patterns should lex as single string tokens
        let result = lex(r#""Hello {name}!""#).unwrap();
        let tokens = result.tokens();

        assert_eq!(tokens.len(), 1);
        assert_eq!(tokens[0].token, Token::StringLit);
    }

    // Operator tokenization over a compound boolean expression.
    #[test]
    fn lex_complex_expression() {
        let result = lex("a + b * c == d && e || !f").unwrap();
        let tokens = result.tokens();

        assert_eq!(tokens.len(), 12);
        assert_eq!(tokens[0].token, Token::Ident); // a
        assert_eq!(tokens[1].token, Token::Plus); // +
        assert_eq!(tokens[2].token, Token::Ident); // b
        assert_eq!(tokens[3].token, Token::Star); // *
        assert_eq!(tokens[4].token, Token::Ident); // c
        assert_eq!(tokens[5].token, Token::EqEq); // ==
        assert_eq!(tokens[6].token, Token::Ident); // d
        assert_eq!(tokens[7].token, Token::And); // &&
        assert_eq!(tokens[8].token, Token::Ident); // e
        assert_eq!(tokens[9].token, Token::Or); // ||
        assert_eq!(tokens[10].token, Token::Bang); // !
        assert_eq!(tokens[11].token, Token::Ident); // f
    }

    // Generic type syntax lexes as distinct tokens, not one blob.
    #[test]
    fn lex_list_type() {
        let result = lex("List<String>").unwrap();
        let tokens = result.tokens();

        assert_eq!(tokens.len(), 4);
        assert_eq!(tokens[0].token, Token::TyList);
        assert_eq!(tokens[1].token, Token::Lt);
        assert_eq!(tokens[2].token, Token::TyString);
        assert_eq!(tokens[3].token, Token::Gt);
    }

    // Renaming the source via `with_filename` keeps the error data intact.
    #[test]
    fn lex_error_with_filename() {
        let err = lex("#").unwrap_err().with_filename("test.sg");
        // The error should still work
        assert_eq!(err.errors.len(), 1);
    }

    #[test]
    fn lex_error_is_diagnostic() {
        use miette::Diagnostic;

        let err = lex("#").unwrap_err();

        // Check that Diagnostic trait methods work
        assert!(err.code().is_some());
        assert!(err.help().is_some());
        assert!(err.source_code().is_some());
        assert!(err.labels().is_some());
    }
}