gpwgpu_core/
parser.rs

1use std::{borrow::Cow, collections::{HashMap, hash_map::Entry}, fmt::Write};
2
3use nom::{
4    branch::alt,
5    bytes::complete::{take_till, take_till1, take_until},
6    character::complete::{alpha1, alphanumeric1, char, line_ending, multispace0, space0},
7    combinator::{cut, map, opt, recognize},
8    error::{ErrorKind, ParseError},
9    multi::{many0, many0_count},
10    number::complete::double,
11    sequence::{delimited, pair, preceded, tuple},
12    IResult,
13};
14
15use nom_supreme::tag::complete::tag;
16
/// Error type returned by every parser in this module: a `nom_supreme` error tree whose
/// locations are slices of the `&str` being parsed.
pub type NomError<'a> = nom_supreme::error::ErrorTree<&'a str>;
18
/// Expression AST produced by [`parse_expr`] and reduced by [`Expr::simplify`].
///
/// Identifiers borrow from the parsed source (`'a`) until converted with `into_owned`.
/// Variants are grouped from tightest to loosest binding, matching the parser's precedence
/// chain (`parse_factor` → `parse_mul_div` → `parse_add_sub` → `parse_comparison` →
/// `parse_and` → `parse_or`).
#[derive(Debug, PartialEq, serde::Serialize, serde::Deserialize, Clone)]
pub enum Expr<'a> {
    /// `true` / `false` literal.
    Bool(bool),
    /// Numeric literal; all numbers are stored as `f64`.
    Num(f64),
    /// A named value, resolved by the `lookup` passed to [`Expr::simplify`].
    Ident(Cow<'a, str>),

    /// Unary minus: `-x`.
    Neg(Box<Expr<'a>>),
    /// Logical not: `!x`.
    Not(Box<Expr<'a>>),

    /// `a * b`.
    Mul(Box<Expr<'a>>, Box<Expr<'a>>),
    /// `a / b`.
    Div(Box<Expr<'a>>, Box<Expr<'a>>),

    /// `a + b`.
    Add(Box<Expr<'a>>, Box<Expr<'a>>),
    /// `a - b`.
    Sub(Box<Expr<'a>>, Box<Expr<'a>>),

    /// `a < b`.
    LessThan(Box<Expr<'a>>, Box<Expr<'a>>),
    /// `a > b`.
    GreaterThan(Box<Expr<'a>>, Box<Expr<'a>>),
    /// `a <= b`.
    LessThanOrEqual(Box<Expr<'a>>, Box<Expr<'a>>),
    /// `a >= b`.
    GreaterThanOrEqual(Box<Expr<'a>>, Box<Expr<'a>>),
    /// `a == b` (numbers or booleans).
    Equal(Box<Expr<'a>>, Box<Expr<'a>>),
    /// `a != b` (numbers or booleans).
    NotEqual(Box<Expr<'a>>, Box<Expr<'a>>),

    /// `a && b`.
    And(Box<Expr<'a>>, Box<Expr<'a>>),

    /// `a || b`.
    Or(Box<Expr<'a>>, Box<Expr<'a>>),
}
45
46#[derive(Debug, thiserror::Error)]
47pub enum EvalError {
48    #[error("A number was encountered in an expression where a boolean was expected")]
49    NumberInLogic,
50    #[error("A boolean was encountered in an expression where a number was expected")]
51    BoolInMath,
52    #[error("The identifier was not found {}", 0)]
53    IdentNotFound(String),
54}
55
56impl<'a> Expr<'a> {
57    fn into_owned(self) -> Expr<'static> {
58        use Expr::*;
59        match self {
60            Bool(b) => Bool(b),
61            Num(n) => Num(n),
62            Ident(cow) => Ident(Cow::Owned(cow.into_owned())),
63            Neg(e) => Neg(Box::new(e.into_owned())),
64            Not(e) => Not(Box::new(e.into_owned())),
65            Mul(e1, e2) => Mul(Box::new(e1.into_owned()), Box::new(e2.into_owned())),
66            Div(e1, e2) => Div(Box::new(e1.into_owned()), Box::new(e2.into_owned())),
67            Add(e1, e2) => Add(Box::new(e1.into_owned()), Box::new(e2.into_owned())),
68            Sub(e1, e2) => Sub(Box::new(e1.into_owned()), Box::new(e2.into_owned())),
69            LessThan(e1, e2) => LessThan(Box::new(e1.into_owned()), Box::new(e2.into_owned())),
70            GreaterThan(e1, e2) => {
71                GreaterThan(Box::new(e1.into_owned()), Box::new(e2.into_owned()))
72            }
73            LessThanOrEqual(e1, e2) => {
74                LessThanOrEqual(Box::new(e1.into_owned()), Box::new(e2.into_owned()))
75            }
76            GreaterThanOrEqual(e1, e2) => {
77                GreaterThanOrEqual(Box::new(e1.into_owned()), Box::new(e2.into_owned()))
78            }
79            Equal(e1, e2) => Equal(Box::new(e1.into_owned()), Box::new(e2.into_owned())),
80            NotEqual(e1, e2) => NotEqual(Box::new(e1.into_owned()), Box::new(e2.into_owned())),
81            And(e1, e2) => And(Box::new(e1.into_owned()), Box::new(e2.into_owned())),
82            Or(e1, e2) => Or(Box::new(e1.into_owned()), Box::new(e2.into_owned())),
83        }
84    }
85}
86
87impl<'a> Expr<'a> {
88    pub fn simplify_without_ident(self) -> Result<Expr<'a>, EvalError> {
89        self.simplify(|ident| Some(Expr::Ident(ident.into())))
90    }
91
92    pub fn simplify(
93        self,
94        lookup: impl Fn(Cow<'a, str>) -> Option<Expr<'a>>,
95    ) -> Result<Expr<'a>, EvalError> {
96        self.simplify_internal(&lookup)
97    }
98
99    fn simplify_internal(
100        self,
101        lookup: &impl Fn(Cow<'a, str>) -> Option<Expr<'a>>,
102    ) -> Result<Expr<'a>, EvalError> {
103        use Expr::*;
104        let out = match self {
105            Bool(b) => Bool(b),
106            Num(n) => Num(n),
107            Ident(ref name) => {
108                let expr = lookup(name.clone())
109                    .ok_or_else(|| EvalError::IdentNotFound(name.to_string()))?;
110                let expr = if expr != self {
111                    expr.simplify_internal(lookup)?
112                } else {
113                    expr
114                };
115                expr
116            }
117
118            Neg(inner) => {
119                let inner = inner.simplify_internal(lookup)?;
120                match inner {
121                    Num(n) => Num(-n),
122                    Bool(_) => return Err(EvalError::BoolInMath),
123                    _ => Neg(Box::new(inner)),
124                }
125            }
126            Not(inner) => {
127                let inner = inner.simplify_internal(lookup)?;
128                match inner {
129                    Num(_) => return Err(EvalError::NumberInLogic),
130                    Bool(b) => Bool(!b),
131                    _ => Not(Box::new(inner)),
132                }
133            }
134
135            Mul(left, right) => {
136                let left = left.simplify_internal(lookup)?;
137                let right = right.simplify_internal(lookup)?;
138                match (left, right) {
139                    (Num(n), e) | (e, Num(n)) if n == 1.0 => e,
140                    (Num(n1), Num(n2)) => Num(n1 * n2),
141                    (Bool(_), _) | (_, Bool(_)) => return Err(EvalError::BoolInMath),
142                    (left, right) => Mul(Box::new(left), Box::new(right)),
143                }
144            }
145            Div(left, right) => {
146                let left = left.simplify_internal(lookup)?;
147                let right = right.simplify_internal(lookup)?;
148                match (left, right) {
149                    (e, Num(n)) if n == 1.0 => e,
150                    (Num(n1), Num(n2)) => Num(n1 / n2),
151                    (Bool(_), _) | (_, Bool(_)) => return Err(EvalError::BoolInMath),
152                    (left, right) => Div(Box::new(left), Box::new(right)),
153                }
154            }
155
156            Add(left, right) => {
157                let left = left.simplify_internal(lookup)?;
158                let right = right.simplify_internal(lookup)?;
159                match (left, right) {
160                    (Num(n), e) | (e, Num(n)) if n == 0.0 => e,
161                    (Num(n1), Num(n2)) => Num(n1 + n2),
162                    (Bool(_), _) | (_, Bool(_)) => return Err(EvalError::BoolInMath),
163                    (left, right) => Add(Box::new(left), Box::new(right)),
164                }
165            }
166            Sub(left, right) => {
167                let left = left.simplify_internal(lookup)?;
168                let right = right.simplify_internal(lookup)?;
169                match (left, right) {
170                    (Num(n), e) if n == 0.0 => Neg(Box::new(e)),
171                    (e, Num(n)) if n == 0.0 => e,
172                    (Num(n1), Num(n2)) => Num(n1 - n2),
173                    (Bool(_), _) | (_, Bool(_)) => return Err(EvalError::BoolInMath),
174                    (left, right) => Sub(Box::new(left), Box::new(right)),
175                }
176            }
177
178            LessThan(left, right) => {
179                let left = left.simplify_internal(lookup)?;
180                let right = right.simplify_internal(lookup)?;
181                match (left, right) {
182                    (Num(n1), Num(n2)) => Bool(n1 < n2),
183                    (Bool(_), _) | (_, Bool(_)) => return Err(EvalError::BoolInMath),
184                    (left, right) => LessThan(Box::new(left), Box::new(right)),
185                }
186            }
187            GreaterThan(left, right) => {
188                let left = left.simplify_internal(lookup)?;
189                let right = right.simplify_internal(lookup)?;
190                match (left, right) {
191                    (Num(n1), Num(n2)) => Bool(n1 > n2),
192                    (Bool(_), _) | (_, Bool(_)) => return Err(EvalError::BoolInMath),
193                    (left, right) => GreaterThan(Box::new(left), Box::new(right)),
194                }
195            }
196            LessThanOrEqual(left, right) => {
197                let left = left.simplify_internal(lookup)?;
198                let right = right.simplify_internal(lookup)?;
199                match (left, right) {
200                    (Num(n1), Num(n2)) => Bool(n1 <= n2),
201                    (Bool(_), _) | (_, Bool(_)) => return Err(EvalError::BoolInMath),
202                    (left, right) => LessThanOrEqual(Box::new(left), Box::new(right)),
203                }
204            }
205            GreaterThanOrEqual(left, right) => {
206                let left = left.simplify_internal(lookup)?;
207                let right = right.simplify_internal(lookup)?;
208                match (left, right) {
209                    (Num(n1), Num(n2)) => Bool(n1 >= n2),
210                    (Bool(_), _) | (_, Bool(_)) => return Err(EvalError::BoolInMath),
211                    (left, right) => GreaterThanOrEqual(Box::new(left), Box::new(right)),
212                }
213            }
214            Equal(left, right) => {
215                let left = left.simplify_internal(lookup)?;
216                let right = right.simplify_internal(lookup)?;
217                match (left, right) {
218                    (Num(n1), Num(n2)) => Bool(n1 == n2),
219                    (Bool(b1), Bool(b2)) => Bool(b1 == b2),
220                    (Num(_), Bool(_)) => return Err(EvalError::BoolInMath),
221                    (Bool(_), Num(_)) => return Err(EvalError::NumberInLogic),
222                    (left, right) => Equal(Box::new(left), Box::new(right)),
223                }
224            }
225            NotEqual(left, right) => {
226                let left = left.simplify_internal(lookup)?;
227                let right = right.simplify_internal(lookup)?;
228                match (left, right) {
229                    (Num(n1), Num(n2)) => Bool(n1 != n2),
230                    (Bool(b1), Bool(b2)) => Bool(b1 != b2),
231                    (left, right) => NotEqual(Box::new(left), Box::new(right)),
232                }
233            }
234
235            And(left, right) => {
236                let left = left.simplify_internal(lookup)?;
237                if matches!(left, Bool(false)) {
238                    Bool(false)
239                } else {
240                    let right = right.simplify_internal(lookup)?;
241                    match (left, right) {
242                        (Num(_), _) | (_, Num(_)) => return Err(EvalError::NumberInLogic),
243                        (Bool(b1), Bool(b2)) => Bool(b1 && b2),
244                        (left, right) => And(Box::new(left), Box::new(right)),
245                    }
246                }
247            }
248
249            Or(left, right) => {
250                let left = left.simplify_internal(lookup)?;
251                if matches!(left, Bool(true)) {
252                    Bool(true)
253                } else {
254                    let right = right.simplify_internal(lookup)?;
255                    match (left, right) {
256                        (Num(_), _) | (_, Num(_)) => return Err(EvalError::NumberInLogic),
257                        (Bool(b1), Bool(b2)) => Bool(b1 || b2),
258                        (left, right) => Or(Box::new(left), Box::new(right)),
259                    }
260                }
261            }
262        };
263        return Ok(out);
264    }
265}
266
267fn parse_bool(input: &str) -> IResult<&str, Expr, NomError> {
268    let (input, result) = preceded(
269        multispace0,
270        alt((
271            map(tag("true"), |_| Expr::Bool(true)),
272            map(tag("false"), |_| Expr::Bool(false)),
273        )),
274    )(input)?;
275
276    Ok((input, result))
277}
278
279fn parse_num(input: &str) -> IResult<&str, Expr, NomError> {
280    preceded(multispace0, map(double, Expr::Num))(input)
281}
282
283fn parse_ident(input: &str) -> IResult<&str, Expr, NomError> {
284    preceded(
285        multispace0,
286        map(
287            recognize(pair(
288                alt((alpha1, tag("_"))),
289                many0_count(alt((alphanumeric1, tag("_")))),
290            )),
291            |s: &str| Expr::Ident(s.into()),
292        ),
293    )(input)
294}
295
296fn parse_parens(input: &str) -> IResult<&str, Expr, NomError> {
297    preceded(
298        multispace0,
299        delimited(
300            preceded(multispace0, char('(')),
301            parse_expr,
302            preceded(multispace0, char(')')),
303        ),
304    )(input)
305}
306
307fn parse_neg(input: &str) -> IResult<&str, Expr, NomError> {
308    map(
309        preceded(multispace0, pair(char('-'), parse_factor)),
310        |(_, expr)| Expr::Neg(Box::new(expr)),
311    )(input)
312}
313
314fn parse_not(input: &str) -> IResult<&str, Expr, NomError> {
315    map(
316        preceded(multispace0, pair(char('!'), parse_factor)),
317        |(_, expr)| Expr::Not(Box::new(expr)),
318    )(input)
319}
320
/// Parses the tightest-binding operand of an expression.
///
/// Alternatives are tried in this order: boolean literal, identifier, number, unary minus,
/// logical not, and parenthesised sub-expression. The order matters: `parse_bool` must
/// run before `parse_ident`, otherwise `true`/`false` would be read as identifiers.
/// Unary operators recurse into this function, so they bind to a single factor.
fn parse_factor(input: &str) -> IResult<&str, Expr, NomError> {
    alt((
        parse_bool,
        parse_ident,
        parse_num,
        parse_neg,
        parse_not,
        parse_parens,
    ))(input)
}
331
332fn parse_mul_div(input: &str) -> IResult<&str, Expr, NomError> {
333    let (input, init) = parse_factor(input)?;
334    let (input, ops) = nom::multi::many0(pair(
335        preceded(multispace0, alt((char('*'), char('/')))),
336        preceded(multispace0, parse_factor),
337    ))(input)?;
338
339    let expr = ops.into_iter().fold(init, |acc, (op, factor)| {
340        if op == '*' {
341            Expr::Mul(Box::new(acc), Box::new(factor))
342        } else {
343            Expr::Div(Box::new(acc), Box::new(factor))
344        }
345    });
346
347    Ok((input, expr))
348}
349
350fn parse_add_sub(input: &str) -> IResult<&str, Expr, NomError> {
351    let (input, init) = parse_mul_div(input)?;
352    let (input, ops) = nom::multi::many0(pair(
353        preceded(multispace0, alt((char('+'), char('-')))),
354        preceded(multispace0, parse_mul_div),
355    ))(input)?;
356
357    let expr = ops.into_iter().fold(init, |acc, (op, term)| {
358        if op == '+' {
359            Expr::Add(Box::new(acc), Box::new(term))
360        } else {
361            Expr::Sub(Box::new(acc), Box::new(term))
362        }
363    });
364
365    Ok((input, expr))
366}
367
368fn parse_comparison(input: &str) -> IResult<&str, Expr, NomError> {
369    let (input, initial) = parse_add_sub(input)?;
370
371    let (input, mut comparisons) = many0(pair(
372        preceded(
373            multispace0,
374            alt((
375                tag("<="),
376                tag(">="),
377                tag("<"),
378                tag(">"),
379                tag("=="),
380                tag("!="),
381            )),
382        ),
383        preceded(multispace0, parse_add_sub),
384    ))(input)?;
385
386    let (op, expr) = match comparisons.len() {
387        0 => return Ok((input, initial)),
388        1 => comparisons.remove(0),
389        _ => {
390            return Err(nom::Err::Failure(NomError::from_error_kind(
391                input,
392                ErrorKind::TooLarge,
393            )));
394        }
395    };
396
397    use Expr::*;
398    let result = match op {
399        "<=" => LessThanOrEqual(Box::new(initial), Box::new(expr)),
400        ">=" => GreaterThanOrEqual(Box::new(initial), Box::new(expr)),
401        "<" => LessThan(Box::new(initial), Box::new(expr)),
402        ">" => GreaterThan(Box::new(initial), Box::new(expr)),
403        "==" => Equal(Box::new(initial), Box::new(expr)),
404        "!=" => NotEqual(Box::new(initial), Box::new(expr)),
405        _ => unreachable!(),
406    };
407
408    Ok((input, result))
409}
410
411fn parse_and(input: &str) -> IResult<&str, Expr, NomError> {
412    let (input, init) = parse_comparison(input)?;
413    let (input, ops) = nom::multi::many0(pair(
414        preceded(multispace0, tag("&&")),
415        preceded(multispace0, parse_comparison),
416    ))(input)?;
417
418    let expr = ops
419        .into_iter()
420        .fold(init, |acc, term| Expr::And(Box::new(acc), Box::new(term.1)));
421
422    Ok((input, expr))
423}
424
425fn parse_or(input: &str) -> IResult<&str, Expr, NomError> {
426    let (input, init) = parse_and(input)?;
427    let (input, ops) = nom::multi::many0(pair(
428        preceded(multispace0, tag("||")),
429        preceded(multispace0, parse_and),
430    ))(input)?;
431
432    let expr = ops
433        .into_iter()
434        .fold(init, |acc, term| Expr::Or(Box::new(acc), Box::new(term.1)));
435
436    Ok((input, expr))
437}
438
/// Parses an expression, starting from the loosest-binding operator.
///
/// Precedence from loosest to tightest: `||`, `&&`, a single non-chaining comparison,
/// `+`/`-`, `*`/`/`, then factors (literals, identifiers, unary `-`/`!`, parentheses).
/// Any input after the expression is returned unconsumed.
pub fn parse_expr(input: &str) -> IResult<&str, Expr, NomError> {
    parse_or(input)
}
442
443pub fn parse_token_expr(input: &str) -> IResult<&str, Token, NomError> {
444    let (input, _) = preceded(multispace0, tag("expr"))(input)?;
445
446    let (input, inner) = cut(get_inner)(input)?;
447
448    let (_, expr) = cut(parse_expr)(inner)?;
449
450    let expr = Token::Expr(expr);
451
452    Ok((input, expr))
453}
454
455// https://stackoverflow.com/questions/70630556/parse-allowing-nested-parentheses-in-nom
456pub fn take_until_unbalanced(
457    opening_bracket: char,
458    closing_bracket: char,
459) -> impl Fn(&str) -> IResult<&str, &str, NomError> {
460    move |i: &str| {
461        let mut index = 0;
462        let mut bracket_counter = 0;
463        while let Some(n) = &i[index..].find(&[opening_bracket, closing_bracket, '\\'][..]) {
464            index += n;
465            let mut it = i[index..].chars();
466            match it.next().unwrap_or_default() {
467                c if c == '\\' => {
468                    // Skip the escape char `\`.
469                    index += '\\'.len_utf8();
470                    // Skip also the following char.
471                    let c = it.next().unwrap_or_default();
472                    index += c.len_utf8();
473                }
474                c if c == opening_bracket => {
475                    bracket_counter += 1;
476                    index += opening_bracket.len_utf8();
477                }
478                c if c == closing_bracket => {
479                    // Closing bracket.
480                    bracket_counter -= 1;
481                    index += closing_bracket.len_utf8();
482                }
483                // Can not happen.
484                _ => unreachable!(),
485            };
486            // We found the unmatched closing bracket.
487            if bracket_counter == -1 {
488                // We do not consume it.
489                index -= closing_bracket.len_utf8();
490                return Ok((&i[index..], &i[0..index]));
491            };
492        }
493
494        if bracket_counter == 0 {
495            Ok(("", i))
496        } else {
497            Err(nom::Err::Error(NomError::from_error_kind(
498                i,
499                ErrorKind::TakeUntil,
500            )))
501        }
502    }
503}
504
/// Bounds of a `#for` / `#concat` loop, as parsed by `parse_range`.
#[derive(Debug, serde::Serialize, serde::Deserialize, Clone)]
pub enum Range<'a> {
    /// `start..end`
    #[serde(borrow)]
    Exclusive((Expr<'a>, Expr<'a>)),
    /// `start..=end`
    Inclusive((Expr<'a>, Expr<'a>)),
}
511
512impl<'a> Range<'a> {
513    fn into_owned(self) -> Range<'static> {
514        match self {
515            Range::Exclusive((expr1, expr2)) => {
516                Range::Exclusive((expr1.into_owned(), expr2.into_owned()))
517            }
518            Range::Inclusive((expr1, expr2)) => {
519                Range::Inclusive((expr1.into_owned(), expr2.into_owned()))
520            }
521        }
522    }
523}
524
/// The branch following `#else` in an `#if` directive.
#[derive(Debug, serde::Serialize, serde::Deserialize, Clone)]
pub enum Else<'a> {
    /// `#else { ... }`: a plain block of tokens.
    #[serde(borrow)]
    Block(Vec<Token<'a>>),
    /// `#else if ...`: a chained conditional.
    If(Box<If<'a>>),
}
531
532impl<'a> Else<'a> {
533    fn into_owned(self) -> Else<'static> {
534        match self {
535            Else::Block(tokens) => Else::Block(vec_to_owned(tokens)),
536            Else::If(if_tok) => Else::If(Box::new((*if_tok).into_owned())),
537        }
538    }
539}
540
/// An `#if <condition> { ... }` directive with an optional `#else` branch.
#[derive(Debug, serde::Serialize, serde::Deserialize, Clone)]
pub struct If<'a> {
    /// Condition, already constant-folded at parse time (identifiers remain symbolic).
    #[serde(borrow)]
    condition: Expr<'a>,
    /// Tokens inside the `{ ... }` body.
    tokens: Vec<Token<'a>>,
    /// Optional `#else` block or `#else if` chain.
    else_tokens: Option<Else<'a>>,
}
548
549impl<'a> If<'a> {
550    fn into_owned(self) -> If<'static> {
551        If {
552            condition: self.condition.into_owned(),
553            tokens: vec_to_owned(self.tokens),
554            else_tokens: self.else_tokens.map(|e| e.into_owned()),
555        }
556    }
557}
558
/// A `#for ident in range { ... }` directive.
#[derive(Debug, serde::Serialize, serde::Deserialize, Clone)]
pub struct For<'a> {
    /// Loop variable name.
    #[serde(borrow)]
    ident: Cow<'a, str>,
    /// Bounds of the loop.
    range: Range<'a>,
    /// Body tokens.
    tokens: Vec<Token<'a>>,
}
566
567impl<'a> For<'a> {
568    fn into_owned(self) -> For<'static> {
569        For {
570            ident: Cow::Owned(self.ident.into_owned()),
571            range: self.range.into_owned(),
572            tokens: vec_to_owned(self.tokens),
573        }
574    }
575}
576
/// A `#nest running = total { ... } #pre { ... } #inner { ... }` directive.
#[derive(Debug, serde::Serialize, serde::Deserialize, Clone)]
pub struct NestedFor<'a> {
    /// Identifier to the left of `=`.
    #[serde(borrow)]
    running_ident: Cow<'a, str>,
    /// Identifier to the right of `=`.
    total_ident: Cow<'a, str>,
    /// Tokens from the first `{ ... }` block.
    to_nest: Vec<Token<'a>>,
    /// Tokens from the `#pre { ... }` block.
    pre: Vec<Token<'a>>,
    /// Tokens from the `#inner { ... }` block.
    inner: Vec<Token<'a>>,
}
586
587impl<'a> NestedFor<'a> {
588    fn into_owned(self) -> NestedFor<'static> {
589        NestedFor {
590            running_ident: Cow::Owned(self.running_ident.into_owned()),
591            total_ident: Cow::Owned(self.total_ident.into_owned()),
592            to_nest: vec_to_owned(self.to_nest),
593            pre: vec_to_owned(self.pre),
594            inner: vec_to_owned(self.inner),
595        }
596    }
597}
598
/// A `#concat ident in range { body } { separator }` directive.
#[derive(Debug, serde::Serialize, serde::Deserialize, Clone)]
pub struct Concat<'a> {
    /// Loop variable name.
    #[serde(borrow)]
    ident: Cow<'a, str>,
    /// Bounds of the loop.
    range: Range<'a>,
    /// Tokens from the first `{ ... }` block.
    tokens: Vec<Token<'a>>,
    /// Tokens from the second `{ ... }` block.
    separator: Vec<Token<'a>>,
}
607
608impl<'a> Concat<'a> {
609    fn into_owned(self) -> Concat<'static> {
610        Concat {
611            ident: Cow::Owned(self.ident.into_owned()),
612            range: self.range.into_owned(),
613            tokens: vec_to_owned(self.tokens),
614            separator: vec_to_owned(self.separator),
615        }
616    }
617}
618
/// An `#export name { ... }` directive.
#[derive(Debug, serde::Serialize, serde::Deserialize, Clone)]
pub struct Export<'a> {
    /// Name the block is exported under.
    #[serde(borrow)]
    name: Cow<'a, str>,
    /// Tokens inside the `{ ... }` block.
    tokens: Vec<Token<'a>>,
}
625
626impl<'a> Export<'a> {
627    fn into_owned(self) -> Export<'static> {
628        Export {
629            name: Cow::Owned(self.name.into_owned()),
630            tokens: vec_to_owned(self.tokens),
631        }
632    }
633}
634
/// One element of a parsed shader template, as produced by `parse_tokens`.
#[derive(Debug, serde::Serialize, serde::Deserialize, Clone)]
pub enum Token<'a> {
    /// Plain shader text between directives.
    #[serde(borrow)]
    Code(Cow<'a, str>),
    /// `#name`: a bare identifier directive.
    Ident(Cow<'a, str>),
    /// `#expr { ... }`.
    Expr(Expr<'a>),
    /// `#if ... { ... }`, optionally with `#else`.
    If(If<'a>),
    /// `#for ident in range { ... }`.
    For(For<'a>),
    /// `#nest running = total { ... } #pre { ... } #inner { ... }`.
    NestedFor(NestedFor<'a>),
    /// `#concat ident in range { ... } { ... }`.
    Concat(Concat<'a>),
    /// `#import name`.
    Import(Cow<'a, str>),
    /// `#export name { ... }`.
    Export(Export<'a>),
}
648
649pub fn vec_to_owned<'a>(tokens: Vec<Token<'a>>) -> Vec<Token<'static>> {
650    tokens.into_iter().map(|token| token.into_owned()).collect()
651}
652
653impl<'a> Token<'a> {
654    fn into_owned(self) -> Token<'static> {
655        use Token::*;
656        match self {
657            Code(cow) => Code(Cow::Owned(cow.to_string())),
658            Ident(cow) => Ident(Cow::Owned(cow.to_string())),
659            Expr(expr) => Expr(expr.into_owned()),
660            If(if_tok) => If(if_tok.into_owned()),
661            For(for_tok) => For(for_tok.into_owned()),
662            NestedFor(nested) => NestedFor(nested.into_owned()),
663            Concat(concat) => Concat(concat.into_owned()),
664            Import(cow) => Import(Cow::Owned(cow.to_string())),
665            Export(export) => Export(export.into_owned()),
666        }
667    }
668}
669
670pub fn parse_comment(input: &str) -> IResult<&str, &str, NomError> {
671    recognize(tuple((
672        tag("//"),
673        take_till(|c| c == '\n'),
674        opt(char('\n')),
675    )))(input)
676}
677
678fn parse_shader_code(input: &str) -> IResult<&str, Option<Token>, NomError> {
679    let (input, code) = recognize(many0(alt((
680        take_till1(|c| c == '#' || c == '/'),
681        alt((parse_comment, tag("/"))),
682    ))))(input)?;
683
684    if code.is_empty() {
685        Ok((input, None))
686    } else {
687        let code = trim_trailing_spaces(code);
688        Ok((input, Some(Token::Code(code.into()))))
689    }
690}
691
692fn parse_ident_token(input: &str) -> IResult<&str, Token, NomError> {
693    preceded(
694        multispace0,
695        map(
696            recognize(pair(
697                alt((alpha1, tag("_"))),
698                many0_count(alt((alphanumeric1, tag("_")))),
699            )),
700            |s: &str| Token::Ident(s.into()),
701        ),
702    )(input)
703}
704
705fn eat_newline(input: &str) -> IResult<&str, (), NomError> {
706    let (input, _) = opt(preceded(space0, line_ending))(input)?;
707    Ok((input, ()))
708}
709
/// Removes the final line of `input` when that line consists only of spaces and tabs.
///
/// Only the text after the last `\n` is inspected (the whole input if there is no `\n`).
/// If it is empty or made up solely of `' '`/`'\t'`, it is dropped and the `\n` itself is
/// kept. Otherwise `input` is returned unchanged, so trailing spaces that follow other text
/// on the same line are preserved.
pub fn trim_trailing_spaces(input: &str) -> &str {
    let last_line_start = input.rfind('\n').map_or(0, |newline| newline + 1);
    let (head, last_line) = input.split_at(last_line_start);
    if last_line.chars().all(|c| matches!(c, ' ' | '\t')) {
        head
    } else {
        input
    }
}
726
727fn get_inner(input: &str) -> IResult<&str, &str, NomError> {
728    let (input, inner) = delimited(
729        preceded(multispace0, tag("{")),
730        preceded(eat_newline, take_until_unbalanced('{', '}')),
731        tag("}"),
732    )(input)?;
733
734    let (inner, _) = eat_newline(trim_trailing_spaces(inner))?;
735
736    Ok((input, inner))
737}
738
739fn parse_inner(input: &str) -> IResult<&str, Vec<Token>, NomError> {
740    let (input, inner) = cut(get_inner)(input)?;
741
742    let (_, inner_tokens) = cut(parse_tokens)(inner)?;
743
744    Ok((input, inner_tokens))
745}
746
747fn parse_if(input: &str) -> IResult<&str, Token, NomError> {
748    let (input, _) = tag("if")(input)?;
749
750    let (input, condition) = cut(parse_expr)(input)?;
751
752    // FIXME unwrap on simplify error. This should be propagated.
753    let condition = condition.simplify_without_ident().unwrap();
754
755    let (input, inner_tokens) = cut(parse_inner)(input)?;
756
757    let (input, else_tag) = opt(preceded(multispace0, tag("#else")))(input)?;
758
759    let (input, else_tokens) = match else_tag {
760        Some(_) => cut(alt((
761            map(preceded(multispace0, parse_if), |res| {
762                let Token::If(res) = res else { unreachable!() };
763                Some(Else::If(Box::new(res)))
764            }),
765            map(parse_inner, |res| Some(Else::Block(res))),
766        )))(input)?,
767        None => (input, None),
768    };
769
770    Ok((
771        input,
772        Token::If(If {
773            condition,
774            tokens: inner_tokens,
775            else_tokens,
776        }),
777    ))
778}
779
780fn parse_range(input: &str) -> IResult<&str, Range, NomError> {
781    let (input, first_expr_str) = cut(take_until(".."))(input)?;
782
783    let (_, exp1) = parse_expr(first_expr_str)?;
784    let (input, ty) = cut(alt((tag("..="), tag(".."))))(input)?;
785    let (input, exp2) = parse_expr(input)?;
786    Ok((
787        input,
788        match ty {
789            "..=" => Range::Inclusive((exp1, exp2)),
790            ".." => Range::Exclusive((exp1, exp2)),
791            _ => unreachable!(),
792        },
793    ))
794}
795
796fn parse_for(input: &str) -> IResult<&str, Token, NomError> {
797    let (input, _) = tag("for")(input)?;
798
799    let (input, Token::Ident(ident)) = cut(parse_ident_token)(input)? else { unreachable!() };
800
801    let (input, _) = cut(preceded(multispace0, tag("in")))(input)?;
802
803    let (input, range) = cut(parse_range)(input)?;
804
805    let (input, inner) = cut(parse_inner)(input)?;
806
807    let result = Token::For(For {
808        ident,
809        range,
810        tokens: inner,
811    });
812    Ok((input, result))
813}
814
815fn parse_concat(input: &str) -> IResult<&str, Token, NomError> {
816    let (input, _) = tag("concat")(input)?;
817
818    let (input, Token::Ident(ident)) = cut(parse_ident_token)(input)? else { unreachable!() };
819
820    let (input, _) = cut(preceded(multispace0, tag("in")))(input)?;
821
822    let (input, range) = cut(parse_range)(input)?;
823
824    let (input, inner) = cut(parse_inner)(input)?;
825
826    let (input, separator) = cut(parse_inner)(input)?;
827
828    let result = Token::Concat(Concat {
829        ident,
830        range,
831        tokens: inner,
832        separator,
833    });
834    Ok((input, result))
835}
836
837fn parse_nest(input: &str) -> IResult<&str, Token, NomError> {
838    let (input, _) = tag("nest")(input)?;
839
840    let (input, Token::Ident(running_ident)) = cut(parse_ident_token)(input)? else { unreachable!() };
841
842    let (input, _) = cut(preceded(multispace0, char('=')))(input)?;
843
844    let (input, Token::Ident(total_ident)) = cut(parse_ident_token)(input)? else { unreachable!() };
845
846    let (input, to_nest) = cut(parse_inner)(input)?;
847
848    let (input, _) = cut(preceded(multispace0, tag("#pre")))(input)?;
849
850    let (input, pre) = cut(parse_inner)(input)?;
851
852    let (input, _) = cut(preceded(multispace0, tag("#inner")))(input)?;
853
854    let (input, inner) = cut(parse_inner)(input)?;
855
856    let result = Token::NestedFor(NestedFor {
857        running_ident,
858        total_ident,
859        to_nest,
860        pre,
861        inner,
862    });
863    Ok((input, result))
864}
865
866fn parse_import(input: &str) -> IResult<&str, Token, NomError> {
867    let (input, _) = tag("import")(input)?;
868
869    let (input, Token::Ident(name)) = cut(parse_ident_token)(input)? else { unreachable!() };
870
871    Ok((input, Token::Import(name)))
872}
873
874fn parse_export(input: &str) -> IResult<&str, Token, NomError> {
875    let (input, _) = tag("export")(input)?;
876
877    let (input, Token::Ident(name)) = cut(parse_ident_token)(input)? else { unreachable!() };
878
879    let (input, tokens) = cut(parse_inner)(input)?;
880
881    let export = Export { name, tokens };
882
883    Ok((input, Token::Export(export)))
884}
885
886pub fn parse_tokens(mut input: &str) -> IResult<&str, Vec<Token>, NomError> {
887    // |mut input: &str| {
888    let mut out = Vec::new();
889
890    // Consume initial shader code, up to the first "#"
891    let (new_input, code) = parse_shader_code(input)?;
892    if let Some(code) = code {
893        out.push(code);
894    }
895    input = new_input;
896
897    while !input.is_empty() {
898        let (new_input, _) = char('#')(input)?;
899        input = new_input;
900        // Parse directive
901        let (new_input, token) = alt((
902            parse_token_expr,
903            parse_if,
904            parse_for,
905            parse_concat,
906            parse_nest,
907            parse_import,
908            parse_export,
909            parse_ident_token,
910        ))(input)?;
911        out.push(token);
912        let (new_input, code) = parse_shader_code(new_input)?;
913        if let Some(code) = code {
914            out.push(code);
915        }
916        input = new_input;
917    }
918    Ok((input, out))
919    // }
920}
921
922#[derive(Debug, thiserror::Error)]
923pub enum ExpansionError {
924    #[error("The identifier was not found")]
925    IdentNotFound(String),
926    #[error("Attempted to import a piece of code that was not found exported anywhere")]
927    ImportNotFound(String),
928    #[error("There was a problem with evaluating an expression")]
929    SimplifyError(#[from] EvalError),
930    #[error("A condition simplified to something that wasn't a boolean")]
931    NonBoolCondition(Expr<'static>),
932    #[error("A range contained something that wasn't a number")]
933    NonNumRange(Range<'static>),
934    #[error("An explicit expression contained something that wasn't a boolean or a number")]
935    NonBoolOrNumExpr(Expr<'static>),
936    #[error("A number was expected for this definition")]
937    ExpectedNumber(Definition<'static>),
938}
939
/// A value bound to an identifier during expansion.
///
/// When an identifier token is substituted, the definition is rendered to text
/// by `insert_in_string` (see the per-variant notes below). Numeric and boolean
/// definitions can also take part in expression evaluation.
#[derive(Clone, PartialEq, Debug)]
pub enum Definition<'def> {
    /// Rendered as `true` / `false`.
    Bool(bool),
    /// Rendered as a plain integer literal, e.g. `-4`.
    Int(i32),
    /// Rendered with a `u` suffix, e.g. `3u`.
    UInt(u32),
    /// Arbitrary text, rendered as-is.
    Any(Cow<'def, str>),
    /// Rendered at full precision and always with a decimal point for finite
    /// whole values, e.g. `2.0`, `0.25`.
    Float(f32),
}

impl<'def> Definition<'def> {
    /// Returns a copy of this definition that owns all of its data, so it can
    /// outlive the text it was borrowed from.
    fn new_owned(&self) -> Definition<'static> {
        match self {
            Definition::Bool(v) => Definition::Bool(*v),
            Definition::Int(v) => Definition::Int(*v),
            Definition::UInt(v) => Definition::UInt(*v),
            Definition::Any(cow) => Definition::Any(Cow::Owned(cow.to_string())),
            Definition::Float(v) => Definition::Float(*v),
        }
    }

    /// Appends the textual form of this definition to `target`.
    ///
    /// Floats used to be written with `{:.1}`, which rounded every value to a
    /// single decimal (`0.25` became `0.2`). They are now written at full
    /// precision. A `.0` is appended only when the shortest representation has
    /// no fractional part, so the result still reads as a float literal rather
    /// than an integer.
    ///
    /// # Errors
    /// Only propagates `std::fmt::Error` from `write!`; writing into a `String`
    /// does not fail in practice.
    fn insert_in_string(&self, target: &mut String) -> Result<(), std::fmt::Error> {
        match self {
            Definition::Bool(def) => write!(target, "{def}"),
            Definition::Int(def) => write!(target, "{def}"),
            Definition::UInt(def) => write!(target, "{def}u"),
            Definition::Float(def) => {
                write!(target, "{def}")?;
                // `Display` prints whole floats without a decimal point ("2"),
                // so add one back. Non-finite values have no literal form either way.
                if def.is_finite() && def.fract() == 0.0 {
                    target.push_str(".0");
                }
                Ok(())
            }
            Definition::Any(def) => write!(target, "{def}"),
        }
    }
}

impl<'a> From<bool> for Definition<'a> {
    fn from(value: bool) -> Self {
        Definition::Bool(value)
    }
}

impl<'a> From<i32> for Definition<'a> {
    fn from(value: i32) -> Self {
        Definition::Int(value)
    }
}

impl<'a> From<u32> for Definition<'a> {
    fn from(value: u32) -> Self {
        Definition::UInt(value)
    }
}

impl<'a> From<&'a str> for Definition<'a> {
    fn from(value: &'a str) -> Self {
        Definition::Any(value.into())
    }
}

impl<'a> From<String> for Definition<'a> {
    fn from(value: String) -> Self {
        Definition::Any(value.into())
    }
}

impl<'a> From<f32> for Definition<'a> {
    fn from(value: f32) -> Self {
        Definition::Float(value)
    }
}

/// The default definition is empty text.
impl<'def> Default for Definition<'def> {
    fn default() -> Self {
        Self::Any("".into())
    }
}

/// Renders a definition to the text that replaces its identifier.
impl<'def> From<Definition<'def>> for String {
    fn from(value: Definition<'def>) -> Self {
        let mut out = String::new();
        value
            .insert_in_string(&mut out)
            .expect("formatting into a String cannot fail");
        out
    }
}
1023
/// Builds a closure that adapts a `Definition` lookup into an `Expr` lookup for
/// expression simplification.
///
/// `$func` must be callable as `$func(Cow<str>) -> Option<Definition>`.
///
/// Conversion rules:
/// - `Bool` becomes `Expr::Bool`.
/// - `Int`, `UInt` and `Float` become `Expr::Num`.
/// - `Any` text is interpreted where possible: `true`/`false` (after trimming)
///   become booleans, numeric text becomes `Expr::Num`, and anything else
///   yields `None`.
///
/// `Any` used to `panic!` here. Returning `None` lets the caller's simplifier
/// report the failure instead of aborting; presumably it surfaces as an
/// identifier-lookup error — confirm against `simplify_internal`.
///
/// NOTE: this is a macro rather than a function so the closure can name the
/// caller's `'a` lifetime directly. A lifetime `'a` must therefore be in scope
/// at every call site.
macro_rules! make_expr_lookup {
    ($func:ident) => {
        |s: Cow<'a, str>| -> Option<Expr<'a>> {
            match $func(s)? {
                Definition::Bool(val) => Some(Expr::Bool(val)),
                Definition::Int(val) => Some(Expr::Num(val as f64)),
                Definition::UInt(val) => Some(Expr::Num(val as f64)),
                Definition::Float(val) => Some(Expr::Num(val as f64)),
                Definition::Any(val) => match val.trim() {
                    "true" => Some(Expr::Bool(true)),
                    "false" => Some(Expr::Bool(false)),
                    text => text.parse::<f64>().ok().map(Expr::Num),
                },
            }
        }
    };
}
1041
1042fn process_if<'a, 'def>(
1043    input: If<'a>,
1044    result: &mut String,
1045    lookup: &impl Fn(Cow<str>) -> Option<Definition<'def>>,
1046    exports: &impl Fn(Cow<'a, str>) -> Option<Vec<Token<'a>>>,
1047) -> Result<(), ExpansionError> {
1048    let expr_lookup = make_expr_lookup!(lookup);
1049
1050    let If {
1051        condition,
1052        tokens,
1053        else_tokens,
1054    } = input;
1055
1056    let condition = condition.simplify_internal(&expr_lookup)?;
1057    match condition {
1058        Expr::Bool(true) => process_internal(tokens, result, lookup, exports),
1059        Expr::Bool(false) => match else_tokens {
1060            Some(Else::Block(tokens)) => process_internal(tokens, result, lookup, exports),
1061            Some(Else::If(new_if)) => process_if(*new_if, result, lookup, exports),
1062            None => Ok(()),
1063        },
1064        _ => return Err(ExpansionError::NonBoolCondition(condition.into_owned())),
1065    }
1066}
1067
1068fn process_for<'a, 'def>(
1069    input: For<'a>,
1070    result: &mut String,
1071    lookup: &impl Fn(Cow<str>) -> Option<Definition<'def>>,
1072    exports: &impl Fn(Cow<'a, str>) -> Option<Vec<Token<'a>>>,
1073) -> Result<(), ExpansionError> {
1074    let For {
1075        ident,
1076        range,
1077        tokens,
1078    } = input;
1079
1080    let expr_lookup = make_expr_lookup!(lookup);
1081
1082    let range = match range {
1083        Range::Exclusive((expr1, expr2)) => {
1084            let expr1 = expr1.simplify(expr_lookup)?;
1085            let expr2 = expr2.simplify(expr_lookup)?;
1086            Range::Exclusive((expr1, expr2))
1087        }
1088        Range::Inclusive((expr1, expr2)) => {
1089            let expr1 = expr1.simplify(expr_lookup)?;
1090            let expr2 = expr2.simplify(expr_lookup)?;
1091            Range::Inclusive((expr1, expr2))
1092        }
1093    };
1094
1095    let iter = match range {
1096        Range::Exclusive((Expr::Num(start), Expr::Num(end))) => {
1097            Box::new(start as isize..end as isize) as Box<dyn Iterator<Item = isize>>
1098        }
1099        Range::Inclusive((Expr::Num(start), Expr::Num(end))) => {
1100            Box::new(start as isize..=end as isize) as Box<dyn Iterator<Item = isize>>
1101        }
1102        _ => return Err(ExpansionError::NonNumRange(range.into_owned())),
1103    };
1104
1105    for val in iter {
1106        let new_lookup = Box::new(|s: Cow<str>| -> Option<Definition<'def>> {
1107            if s == ident {
1108                Some(Definition::Int(val as i32))
1109            } else {
1110                lookup(s)
1111            }
1112        }) as Box<dyn Fn(Cow<str>) -> Option<Definition<'def>>>;
1113        process_internal(tokens.clone(), result, &new_lookup, exports)?;
1114    }
1115    Ok(())
1116}
1117
1118fn process_nested_for<'a, 'def>(
1119    input: NestedFor<'a>,
1120    result: &mut String,
1121    lookup: &impl Fn(Cow<str>) -> Option<Definition<'def>>,
1122    exports: &impl Fn(Cow<'a, str>) -> Option<Vec<Token<'a>>>,
1123) -> Result<(), ExpansionError> {
1124    let NestedFor {
1125        running_ident,
1126        total_ident,
1127        to_nest,
1128        pre,
1129        inner,
1130    } = input;
1131
1132    let total_depth = lookup(total_ident.clone())
1133        .ok_or_else(|| ExpansionError::IdentNotFound(total_ident.to_string()))?;
1134
1135    let total_depth = match total_depth {
1136        Definition::Bool(_) | Definition::Any(_) => {
1137            return Err(ExpansionError::ExpectedNumber(total_depth.new_owned()))
1138        }
1139        Definition::Int(num) => num as usize,
1140        Definition::UInt(num) => num as usize,
1141        Definition::Float(num) => num as usize,
1142    };
1143
1144    for val in 0..total_depth {
1145        let new_lookup = Box::new(|s: Cow<str>| -> Option<Definition<'def>> {
1146            if s == running_ident {
1147                Some(Definition::Int(val as i32))
1148            } else {
1149                lookup(s)
1150            }
1151        }) as Box<dyn Fn(Cow<str>) -> Option<Definition<'def>>>;
1152
1153        process_internal(to_nest.clone(), result, &new_lookup, exports)?;
1154        result.push_str("{\n");
1155        process_internal(pre.clone(), result, &new_lookup, exports)?;
1156        result.push('\n');
1157    }
1158
1159    process_internal(inner, result, lookup, exports)?;
1160
1161    for _ in 0..total_depth {
1162        result.push_str("\n}")
1163    }
1164
1165    Ok(())
1166}
1167
1168fn process_concat<'a, 'def>(
1169    input: Concat<'a>,
1170    result: &mut String,
1171    lookup: &impl Fn(Cow<str>) -> Option<Definition<'def>>,
1172    exports: &impl Fn(Cow<'a, str>) -> Option<Vec<Token<'a>>>,
1173) -> Result<(), ExpansionError> {
1174    let Concat {
1175        ident,
1176        range,
1177        tokens,
1178        separator,
1179    } = input;
1180
1181    let expr_lookup = make_expr_lookup!(lookup);
1182
1183    let range = match range {
1184        Range::Exclusive((expr1, expr2)) => {
1185            let expr1 = expr1.simplify(expr_lookup)?;
1186            let expr2 = expr2.simplify(expr_lookup)?;
1187            Range::Exclusive((expr1, expr2))
1188        }
1189        Range::Inclusive((expr1, expr2)) => {
1190            let expr1 = expr1.simplify(expr_lookup)?;
1191            let expr2 = expr2.simplify(expr_lookup)?;
1192            Range::Inclusive((expr1, expr2))
1193        }
1194    };
1195
1196    let iter = match range {
1197        Range::Exclusive((Expr::Num(start), Expr::Num(end))) => {
1198            Box::new(start as isize..end as isize) as Box<dyn Iterator<Item = isize>>
1199        }
1200        Range::Inclusive((Expr::Num(start), Expr::Num(end))) => {
1201            Box::new(start as isize..=end as isize) as Box<dyn Iterator<Item = isize>>
1202        }
1203        _ => return Err(ExpansionError::NonNumRange(range.into_owned())),
1204    };
1205    let mut iter = iter.peekable();
1206
1207    while let Some(val) = iter.next() {
1208        let new_lookup = Box::new(|s: Cow<str>| -> Option<Definition<'def>> {
1209            if s == ident {
1210                Some(Definition::Int(val as i32))
1211            } else {
1212                lookup(s)
1213            }
1214        }) as Box<dyn Fn(Cow<str>) -> Option<Definition<'def>>>;
1215
1216        process_internal(tokens.clone(), result, &new_lookup, exports)?;
1217
1218        if iter.peek().is_some() {
1219            process_internal(separator.clone(), result, &new_lookup, exports)?;
1220        }
1221    }
1222    Ok(())
1223}
1224
1225fn process_internal<'a, 'def>(
1226    tokens: Vec<Token<'a>>,
1227    result: &mut String,
1228    lookup: &impl Fn(Cow<str>) -> Option<Definition<'def>>,
1229    exports: &impl Fn(Cow<'a, str>) -> Option<Vec<Token<'a>>>,
1230) -> Result<(), ExpansionError> {
1231    for token in tokens {
1232        match token {
1233            Token::Code(code) => result.push_str(&code),
1234            Token::Ident(name) => {
1235                let Some(shader_def) = lookup(name.clone()) else { return Err(ExpansionError::IdentNotFound(name.to_string()))};
1236                let string = String::from(shader_def);
1237                if let Ok((_, tokens)) = parse_tokens(&string) {
1238                    let tokens = vec_to_owned(tokens);
1239                    process_internal(tokens, result, lookup, exports)?;
1240                } else {
1241                    write!(result, "{}", string).unwrap();
1242                }
1243            }
1244            Token::Expr(expr) => {
1245                let expr_lookup = make_expr_lookup!(lookup);
1246                let simplified_expr = expr.simplify_internal(&expr_lookup)?;
1247                match simplified_expr {
1248                    Expr::Bool(b) => write!(result, "{}", b).unwrap(),
1249                    Expr::Num(n) => write!(result, "{}", n).unwrap(),
1250                    _ => {
1251                        return Err(ExpansionError::NonBoolOrNumExpr(
1252                            simplified_expr.into_owned(),
1253                        ))
1254                    }
1255                }
1256            }
1257            Token::If(if_tokens) => process_if(if_tokens, result, lookup, exports)?,
1258            Token::For(for_tokens) => process_for(for_tokens, result, lookup, exports)?,
1259            Token::NestedFor(nested_for) => {
1260                process_nested_for(nested_for, result, lookup, exports)?
1261            }
1262            Token::Concat(concat) => process_concat(concat, result, lookup, exports)?,
1263            Token::Import(name) => {
1264                let Some(tokens) = exports(name.clone()) else { return Err(ExpansionError::ImportNotFound(name.to_string())) };
1265                process_internal(tokens, result, lookup, exports)?;
1266            },
1267            Token::Export(Export { name: _, tokens }) => {
1268                process_internal(tokens, result, lookup, exports)?
1269            }
1270        }
1271    }
1272    Ok(())
1273}
1274
1275#[derive(thiserror::Error, Debug)]
1276#[error("The name .0 was already exported from another location.")]
1277pub struct ExportedMoreThanOnce(String);
1278
1279pub fn get_exports<'a>(
1280    tokens: &Vec<Token<'a>>,
1281    exports: &mut HashMap<Cow<'a, str>, Vec<Token<'a>>>,
1282) -> Result<(), ExportedMoreThanOnce> {
1283    for token in tokens {
1284        match token {
1285            Token::Export(Export { name, tokens }) => {
1286                match exports.entry(name.clone()){
1287                    Entry::Occupied(_) => return Err(ExportedMoreThanOnce(name.to_string())),
1288                    Entry::Vacant(vacant) => vacant.insert(tokens.clone()),
1289                };
1290            }
1291            _ => {}
1292        }
1293    }
1294    Ok(())
1295}
1296
1297pub fn process<'a, 'def>(
1298    tokens: Vec<Token<'a>>,
1299    lookup: impl Fn(Cow<str>) -> Option<Definition<'def>>,
1300    exports: impl Fn(Cow<'a, str>) -> Option<Vec<Token<'a>>>,
1301) -> Result<String, ExpansionError> {
1302    let mut result = String::new();
1303
1304    process_internal(tokens, &mut result, &lookup, &exports)?;
1305
1306    Ok(result)
1307}