cas_parser/parser/ast/
literal.rs

1use cas_error::Error;
2use crate::{
3    parser::{
4        ast::{expr::Expr, helper::{SquareDelimited, Surrounded}},
5        error::{EmptyRadixLiteral, InvalidRadixBase, InvalidRadixDigit, UnexpectedToken},
6        fmt::Latex,
7        token::{
8            Boolean,
9            CloseParen,
10            Name,
11            Int,
12            OpenParen,
13            OpenSquare,
14            Semicolon,
15            Quote,
16        },
17        Parse,
18        Parser,
19        ParseResult,
20    },
21    tokenizer::TokenKind,
22    return_if_ok,
23};
24use std::{collections::HashSet, fmt, ops::Range};
25
26#[cfg(feature = "serde")]
27use serde::{Deserialize, Serialize};
28
/// An integer literal, represented as a [`String`].
///
/// The digits are kept as text exactly as they were collected from the source tokens.
#[derive(Debug, Clone, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct LitInt {
    /// The value of the integer literal as a string, e.g. `"42"`.
    pub value: String,

    /// The region of the source code that this literal was parsed from.
    pub span: Range<usize>,
}

impl std::fmt::Display for LitInt {
    /// Writes the integer's digits verbatim.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        write!(f, "{}", self.value)
    }
}
45
46impl Latex for LitInt {
47    fn fmt_latex(&self, f: &mut fmt::Formatter) -> fmt::Result {
48        write!(f, "{}", self.value)
49    }
50}
51
/// A floating-point literal whose text (digits and decimal point) is stored as a [`String`].
#[derive(Debug, Clone, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct LitFloat {
    /// The literal's text, e.g. `"1.5"`, `"1."` or `".5"`.
    pub value: String,

    /// The region of the source code that this literal was parsed from.
    pub span: Range<usize>,
}

impl std::fmt::Display for LitFloat {
    /// Echoes the stored text unchanged.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        f.write_str(&self.value)
    }
}
68
69impl Latex for LitFloat {
70    fn fmt_latex(&self, f: &mut fmt::Formatter) -> fmt::Result {
71        write!(f, "{}", self.value)
72    }
73}
74
75/// Parse either an integer or a floating-point number.
76fn parse_ascii_number(input: &mut Parser) -> Result<Literal, Vec<Error>> {
77    struct Num {
78        value: String,
79        span: Range<usize>,
80        has_digits: bool,
81        has_decimal: bool,
82    }
83
84    let mut result: Option<Num> = None;
85    input.advance_past_whitespace();
86
87    while let Ok(token) = input.next_token_raw() {
88        match token.kind {
89            TokenKind::Int => {
90                let mut has_decimal = false;
91                result = result.map_or_else(|| {
92                    Some(Num {
93                        value: token.lexeme.to_owned(),
94                        span: token.span.clone(),
95                        has_digits: true,
96                        has_decimal: false,
97                    })
98                }, |mut num| {
99                    num.value.push_str(token.lexeme);
100                    num.span.end = token.span.end;
101                    num.has_digits = true;
102                    has_decimal = num.has_decimal;
103                    Some(num)
104                });
105
106                if has_decimal {
107                    break;
108                }
109            },
110            TokenKind::Dot => {
111                result = result.map_or_else(|| {
112                    Some(Num {
113                        value: ".".to_owned(),
114                        span: token.span.clone(),
115                        has_digits: false,
116                        has_decimal: true,
117                    })
118                }, |mut num| {
119                    num.value.push('.');
120                    num.span.end = token.span.end;
121                    num.has_decimal = true;
122                    Some(num)
123                });
124            },
125            _ => {
126                input.prev();
127                break;
128            },
129        }
130    }
131
132    let num = result.ok_or_else(|| {
133        // clone to emulate peeking
134        let mut input_ahead = input.clone();
135        match input_ahead.next_token() {
136            Ok(token) => vec![Error::new(vec![token.span], UnexpectedToken {
137                expected: &[TokenKind::Int],
138                found: token.kind,
139            })],
140            Err(e) => vec![e],
141        }
142    })?;
143
144    if !num.has_digits {
145        // could have only encountered a single dot for `num` to be `Some`, yet have no digits
146        return Err(vec![Error::new(vec![num.span.clone()], UnexpectedToken {
147            expected: &[TokenKind::Int],
148            found: TokenKind::Dot,
149        })]);
150    }
151
152    if num.has_decimal {
153        Ok(Literal::Float(LitFloat {
154            value: num.value,
155            span: num.span,
156        }))
157    } else {
158        Ok(Literal::Integer(LitInt {
159            value: num.value,
160            span: num.span,
161        }))
162    }
163}
164
/// The digits in base 64, in order of increasing value.
///
/// A radix `b` uses the first `b` of these digits; base 16, for example, uses `0`-`9` followed by
/// `a`-`f`.
pub const DIGITS: [char; 64] = [
    // values 0 through 9
    '0', '1', '2', '3', '4', '5', '6', '7', '8', '9',
    // values 10 through 35
    'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm',
    'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z',
    // values 36 through 61
    'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M',
    'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z',
    // values 62 and 63
    '+', '/',
];
172
/// The raw digit string of a radix literal (the part after the `'`), as collected from the source.
///
/// Nothing here checks that the characters are valid digits for any particular radix.
#[derive(Debug, Clone, PartialEq, Eq)]
struct RadixWord {
    /// The digit characters, concatenated from the consumed tokens.
    pub value: String,

    /// The region of the source code that this literal was parsed from.
    pub span: Range<usize>,
}
182
183impl RadixWord {
184    fn parse(input: &mut Parser) -> Self {
185        let mut value = String::new();
186        let mut span = 0..0;
187        while let Ok(token) = input.next_token_raw() {
188            match token.kind {
189                TokenKind::Add
190                    | TokenKind::Name
191                    | TokenKind::Int
192                    | TokenKind::Div => value.push_str(token.lexeme),
193                _ => {
194                    input.prev();
195                    break;
196                },
197            }
198
199            if span.start == 0 {
200                span.start = token.span.start;
201            }
202            span.end = token.span.end;
203        }
204
205        Self {
206            value,
207            span,
208        }
209    }
210}
211
212/// Helper function to ensure the given string represents a valid base for radix notation.
213fn validate_radix_base(num: &Int) -> ParseResult<u8> {
214    match num.lexeme.parse() {
215        Ok(base) if (2..=64).contains(&base) => ParseResult::Ok(base),
216        Ok(base) if base < 2 => ParseResult::Recoverable(
217            64, // use base 64 to limit invalid radix digit errors
218            vec![Error::new(vec![num.span.clone()], InvalidRadixBase { too_large: false })],
219        ),
220        _ => ParseResult::Recoverable(
221            64,
222            vec![Error::new(vec![num.span.clone()], InvalidRadixBase { too_large: true })],
223        ),
224    }
225}
226
/// An integer written in radix notation, `base'digits` (e.g. `2'1010` or `16'ff`), which lets users
/// express integers in a base other than 10.
#[derive(Debug, Clone, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct LitRadix {
    /// The radix of the literal, between 2 and 64 inclusive.
    pub base: u8,

    /// The digits of the number, written in the given radix.
    pub value: String,

    /// The region of the source code that this literal was parsed from.
    pub span: Range<usize>,
}
241
242impl<'source> Parse<'source> for LitRadix {
243    fn std_parse(
244        input: &mut Parser<'source>,
245        recoverable_errors: &mut Vec<Error>
246    ) -> Result<Self, Vec<Error>> {
247        let num = input.try_parse().forward_errors(recoverable_errors)?;
248        let quote = input.try_parse::<Quote>().forward_errors(recoverable_errors)?;
249
250        let base = validate_radix_base(&num).forward_errors(recoverable_errors)?;
251        let word = RadixWord::parse(input);
252        if word.value.is_empty() {
253            recoverable_errors.push(Error::new(vec![quote.span], EmptyRadixLiteral {
254                radix: base,
255                allowed: &DIGITS[..base as usize],
256            }));
257        }
258
259        // ensure that the number is valid for this radix
260        let allowed_digits = &DIGITS[..base as usize];
261        let mut bad_digits = HashSet::new();
262        let mut bad_digit_spans: Vec<Range<usize>> = Vec::new();
263        for (i, c) in word.value.chars().enumerate() {
264            // if we find a digit that isn't allowed, that is fatal
265            // but continue to find all the bad digits so we can report them all at once
266            if !allowed_digits.contains(&c) {
267                let char_start = word.span.start + i;
268                if let Some(last_span) = bad_digit_spans.last_mut() {
269                    // merge adjacent spans
270                    if last_span.end == char_start {
271                        last_span.end += 1;
272                    } else {
273                        bad_digit_spans.push(char_start..char_start + 1);
274                    }
275                } else {
276                    bad_digit_spans.push(char_start..char_start + 1);
277                }
278
279                bad_digits.insert(c);
280                continue;
281            }
282        }
283
284        if !bad_digit_spans.is_empty() {
285            recoverable_errors.push(Error::new(bad_digit_spans, InvalidRadixDigit {
286                radix: base,
287                allowed: allowed_digits,
288                digits: bad_digits,
289                last_op_digit: {
290                    if let Some(ch) = word.value.chars().last() {
291                        ['+', '/'].into_iter()
292                            .find(|&op| op == ch)
293                            .map(|op| (op, word.span.end - 1..word.span.end))
294                    } else {
295                        None
296                    }
297                },
298            }));
299        }
300
301        Ok(Self {
302            base,
303            value: word.value,
304            span: num.span.start..word.span.end,
305        })
306    }
307}
308
309impl std::fmt::Display for LitRadix {
310    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
311        write!(f, "{}'{}", self.base, self.value)
312    }
313}
314
315impl Latex for LitRadix {
316    fn fmt_latex(&self, f: &mut fmt::Formatter) -> fmt::Result {
317        write!(f, "{}'{}", self.base, self.value)
318    }
319}
320
/// A boolean literal: the keyword `true` or `false`.
#[derive(Debug, Clone, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct LitBool {
    /// Which of the two boolean values the literal denotes.
    pub value: bool,

    /// The region of the source code that this literal was parsed from.
    pub span: Range<usize>,
}
331
332impl<'source> Parse<'source> for LitBool {
333    fn std_parse(
334        input: &mut Parser<'source>,
335        recoverable_errors: &mut Vec<Error>
336    ) -> Result<Self, Vec<Error>> {
337        input.try_parse::<Boolean>()
338            .map(|boolean| Self {
339                value: match boolean.lexeme {
340                    "true" => true,
341                    "false" => false,
342                    _ => unreachable!(),
343                },
344                span: boolean.span,
345            })
346            .forward_errors(recoverable_errors)
347    }
348}
349
350impl std::fmt::Display for LitBool {
351    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
352        write!(f, "{}", self.value)
353    }
354}
355
356impl Latex for LitBool {
357    fn fmt_latex(&self, f: &mut fmt::Formatter) -> fmt::Result {
358        write!(f, "{}", self.value)
359    }
360}
361
/// A symbol (identifier) literal, used to name variables and functions.
#[derive(Debug, Clone, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct LitSym {
    /// The identifier text of the symbol.
    pub name: String,

    /// The region of the source code that this literal was parsed from.
    pub span: Range<usize>,
}
372
373impl<'source> Parse<'source> for LitSym {
374    fn std_parse(
375        input: &mut Parser<'source>,
376        recoverable_errors: &mut Vec<Error>
377    ) -> Result<Self, Vec<Error>> {
378        // TODO: it would be nice if we could report an error if the symbol is a keyword
379        //
380        // for example:
381        // break(x) = x
382        // ^^^^^ error: `break` is a keyword and cannot be used as a symbol
383        //
384        // unfortunately this is hard since CalcScript is context-sensitive and we would have to
385        // to parse further ahead to determine if this error should be reported
386        // maybe we should require a `let` keyword to declare variables?
387        input.try_parse::<Name>()
388            .map(|name| Self {
389                name: name.lexeme.to_owned(),
390                span: name.span,
391            })
392            .forward_errors(recoverable_errors)
393    }
394}
395
396impl std::fmt::Display for LitSym {
397    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
398        write!(f, "{}", self.name)
399    }
400}
401
402impl Latex for LitSym {
403    fn fmt_latex(&self, f: &mut fmt::Formatter) -> fmt::Result {
404        match self.name.as_str() {
405            "tau" | "pi" | "phi" | "theta" => write!(f, "\\{} ", self.name),
406            _ => write!(f, "{}", self.name),
407        }
408    }
409}
410
/// The unit value, written `()`. Functions that do not return a value return it by default.
#[derive(Debug, Clone, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct LitUnit {
    /// The region of the source code that this literal was parsed from.
    pub span: Range<usize>,
}
419
420impl<'source> Parse<'source> for LitUnit {
421    fn std_parse(
422        input: &mut Parser<'source>,
423        recoverable_errors: &mut Vec<Error>
424    ) -> Result<Self, Vec<Error>> {
425        let open = input.try_parse::<OpenParen>().forward_errors(recoverable_errors)?;
426        let close = input.try_parse::<CloseParen>().forward_errors(recoverable_errors)?;
427        Ok(Self {
428            span: open.span.start..close.span.end,
429        })
430    }
431}
432
433impl std::fmt::Display for LitUnit {
434    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
435        write!(f, "()")
436    }
437}
438
439impl Latex for LitUnit {
440    fn fmt_latex(&self, f: &mut fmt::Formatter) -> fmt::Result {
441        write!(f, "()")
442    }
443}
444
445/// The list type, consisting of a list of expressions surrounded by square brackets and delimited by
446/// commas: `[expr1, expr2, ...]`.
447#[derive(Debug, Clone, PartialEq, Eq)]
448#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
449pub struct LitList {
450    /// The list of expressions.
451    pub values: Vec<Expr>,
452
453    /// The region of the source code that this literal was parsed from.
454    pub span: Range<usize>,
455}
456
457impl<'source> Parse<'source> for LitList {
458    fn std_parse(
459        input: &mut Parser<'source>,
460        recoverable_errors: &mut Vec<Error>
461    ) -> Result<Self, Vec<Error>> {
462        let surrounded = input.try_parse::<SquareDelimited<_>>().forward_errors(recoverable_errors)?;
463
464        Ok(Self {
465            values: surrounded.value.values,
466            span: surrounded.open.span.start..surrounded.close.span.end,
467        })
468    }
469}
470
471impl std::fmt::Display for LitList {
472    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
473        write!(f, "[")?;
474        for (i, value) in self.values.iter().enumerate() {
475            if i > 0 {
476                write!(f, ", ")?;
477            }
478            write!(f, "{}", value)?;
479        }
480        write!(f, "]")
481    }
482}
483
484impl Latex for LitList {
485    fn fmt_latex(&self, f: &mut fmt::Formatter) -> fmt::Result {
486        write!(f, "[")?;
487        for (i, value) in self.values.iter().enumerate() {
488            if i > 0 {
489                write!(f, ", ")?;
490            }
491            value.fmt_latex(f)?;
492        }
493        write!(f, "]")
494    }
495}
496
497/// The list type, formed by repeating the given expression `n` times: `[expr; n]`.
498#[derive(Debug, Clone, PartialEq, Eq)]
499#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
500pub struct LitListRepeat {
501    /// The expression to repeat.
502    pub value: Box<Expr>,
503
504    /// The number of times to repeat the expression.
505    pub count: Box<Expr>,
506
507    /// The region of the source code that this literal was parsed from.
508    pub span: Range<usize>,
509}
510
511impl<'source> Parse<'source> for LitListRepeat {
512    fn std_parse(
513        input: &mut Parser<'source>,
514        recoverable_errors: &mut Vec<Error>
515    ) -> Result<Self, Vec<Error>> {
516        /// Inner struct representing the contents of a repeated list so that we can use the
517        /// [`Surrounded`] helper with it.
518        #[derive(Debug)]
519        struct LitListRepeatInner {
520            /// The expression to repeat.
521            value: Expr,
522
523            /// The number of times to repeat the expression.
524            count: Expr,
525        }
526
527        impl<'source> Parse<'source> for LitListRepeatInner {
528            fn std_parse(
529                input: &mut Parser<'source>,
530                recoverable_errors: &mut Vec<Error>
531            ) -> Result<Self, Vec<Error>> {
532                let value = input.try_parse().forward_errors(recoverable_errors)?;
533                input.try_parse::<Semicolon>().forward_errors(recoverable_errors)?;
534                let count = input.try_parse().forward_errors(recoverable_errors)?;
535                Ok(Self { value, count })
536            }
537        }
538
539        let inner = input.try_parse::<Surrounded<OpenSquare, LitListRepeatInner>>().forward_errors(recoverable_errors)?;
540
541        Ok(Self {
542            value: Box::new(inner.value.value),
543            count: Box::new(inner.value.count),
544            span: inner.open.span.start..inner.close.span.end,
545        })
546    }
547}
548
549impl std::fmt::Display for LitListRepeat {
550    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
551        write!(f, "[{}; {}]", self.value, self.count)
552    }
553}
554
555impl Latex for LitListRepeat {
556    fn fmt_latex(&self, f: &mut fmt::Formatter) -> fmt::Result {
557        write!(f, "[{}; {}]", self.value, self.count)
558    }
559}
560
561/// Represents a literal value in CalcScript.
562///
563/// A literal is any value that can is written directly into the source code. For example, the
564/// number `1` is a literal (it is currently the only literal type supported by CalcScript).
565#[derive(Debug, Clone, PartialEq, Eq)]
566#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
567pub enum Literal {
568    /// An integer literal.
569    Integer(LitInt),
570
571    /// A floating-point literal.
572    Float(LitFloat),
573
574    /// A number written in radix notation. Radix notation allows users to express integers in a
575    /// base other than base 10.
576    Radix(LitRadix),
577
578    /// A boolean literal, either `true` or `false`.
579    Boolean(LitBool),
580
581    /// A symbol / identifier literal. Symbols are used to represent variables and functions.
582    Symbol(LitSym),
583
584    /// The unit type, written as `()`. The unit type is by-default returned by functions that do
585    /// not return a value.
586    Unit(LitUnit),
587
588    /// The list type, consisting of a list of expressions surrounded by square brackets and
589    /// delimited by commas: `[expr1, expr2, ...]`.
590    List(LitList),
591
592    /// The list type, formed by repeating the given expression `n` times: `[expr; n]`.
593    ListRepeat(LitListRepeat),
594}
595
596impl Literal {
597    /// Returns the span of the literal.
598    pub fn span(&self) -> Range<usize> {
599        match self {
600            Literal::Integer(int) => int.span.clone(),
601            Literal::Float(float) => float.span.clone(),
602            Literal::Radix(radix) => radix.span.clone(),
603            Literal::Boolean(boolean) => boolean.span.clone(),
604            Literal::Symbol(name) => name.span.clone(),
605            Literal::Unit(unit) => unit.span.clone(),
606            Literal::List(list) => list.span.clone(),
607            Literal::ListRepeat(repeat) => repeat.span.clone(),
608        }
609    }
610}
611
612impl<'source> Parse<'source> for Literal {
613    fn std_parse(
614        input: &mut Parser<'source>,
615        recoverable_errors: &mut Vec<Error>
616    ) -> Result<Self, Vec<Error>> {
617        let _ = return_if_ok!(input.try_parse().map(Literal::Boolean).forward_errors(recoverable_errors));
618        let _ = return_if_ok!(input.try_parse().map(Literal::Radix).forward_errors(recoverable_errors));
619        let _ = return_if_ok!(parse_ascii_number(input));
620        let _ = return_if_ok!(input.try_parse().map(Literal::Symbol).forward_errors(recoverable_errors));
621        let _ = return_if_ok!(input.try_parse().map(Literal::Unit).forward_errors(recoverable_errors));
622        let _ = return_if_ok!(input.try_parse().map(Literal::List).forward_errors(recoverable_errors));
623        input.try_parse().map(Literal::ListRepeat).forward_errors(recoverable_errors)
624    }
625}
626
627impl std::fmt::Display for Literal {
628    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
629        match self {
630            Literal::Integer(int) => int.fmt(f),
631            Literal::Float(float) => float.fmt(f),
632            Literal::Radix(radix) => radix.fmt(f),
633            Literal::Boolean(boolean) => boolean.fmt(f),
634            Literal::Symbol(name) => name.fmt(f),
635            Literal::Unit(unit) => unit.fmt(f),
636            Literal::List(list) => list.fmt(f),
637            Literal::ListRepeat(repeat) => repeat.fmt(f),
638        }
639    }
640}
641
642impl Latex for Literal {
643    fn fmt_latex(&self, f: &mut fmt::Formatter) -> fmt::Result {
644        match self {
645            Literal::Integer(int) => int.fmt_latex(f),
646            Literal::Float(float) => float.fmt_latex(f),
647            Literal::Radix(radix) => radix.fmt_latex(f),
648            Literal::Boolean(boolean) => boolean.fmt_latex(f),
649            Literal::Symbol(name) => name.fmt_latex(f),
650            Literal::Unit(unit) => unit.fmt_latex(f),
651            Literal::List(list) => list.fmt_latex(f),
652            Literal::ListRepeat(repeat) => repeat.fmt_latex(f),
653        }
654    }
655}