calyx_frontend/
parser.rs

1#![allow(clippy::upper_case_acronyms)]
2
3//! Parser for Calyx programs.
4use super::ast::{
5    self, BitNum, Control, GuardComp as GC, GuardExpr, NumType, StaticGuardExpr,
6};
7use super::Attributes;
8use crate::{Attribute, Direction, PortDef, Primitive, Width};
9use calyx_utils::{self, CalyxResult, Id};
10use calyx_utils::{FileIdx, GPosIdx, GlobalPositionTable};
11use pest::pratt_parser::{Assoc, Op, PrattParser};
12use pest_consume::{match_nodes, Error, Parser};
13use std::fs;
14use std::io::Read;
15use std::path::Path;
16use std::str::FromStr;
17
18type ParseResult<T> = Result<T, Error<Rule>>;
19type ComponentDef = ast::ComponentDef;
20
21/// Data associated with parsing the file.
22#[derive(Clone)]
23struct UserData {
24    /// Index to the current file
25    pub file: FileIdx,
26}
27
28// user data is the input program so that we can create Id's
29// that have a reference to the input string
30type Node<'i> = pest_consume::Node<'i, Rule, UserData>;
31
// Include the grammar file so that Cargo knows to rebuild this file on grammar
// changes. The leading underscore keeps the compiler from warning in case the
// constant is otherwise never read.
const _GRAMMAR: &str = include_str!("syntax.pest");
34
35// Define the precedence of binary operations. We use `lazy_static` so that
36// this is only ever constructed once.
37lazy_static::lazy_static! {
38    static ref PRATT: PrattParser<Rule> =
39    PrattParser::new()
40        .op(Op::infix(Rule::guard_or, Assoc::Left))
41        .op(Op::infix(Rule::guard_and, Assoc::Left));
42}
43
/// The Calyx parser. `#[derive(Parser)]` generates the `Rule` enum and the
/// low-level parser from the grammar in `syntax.pest`; the rule handlers in
/// the `#[pest_consume::parser]` impl turn parse trees into AST values.
#[derive(Parser)]
#[grammar = "syntax.pest"]
pub struct CalyxParser;
47
48impl CalyxParser {
49    /// Parse a Calyx program into an AST representation.
50    pub fn parse_file(path: &Path) -> CalyxResult<ast::NamespaceDef> {
51        let time = std::time::Instant::now();
52        let content = &fs::read(path).map_err(|err| {
53            calyx_utils::Error::invalid_file(format!(
54                "Failed to read {}: {err}",
55                path.to_string_lossy(),
56            ))
57        })?;
58        // Add a new file to the position table
59        let string_content = std::str::from_utf8(content)?.to_string();
60        let file = GlobalPositionTable::as_mut()
61            .add_file(path.to_string_lossy().to_string(), string_content);
62        let user_data = UserData { file };
63        let content = GlobalPositionTable::as_ref().get_source(file);
64        // Parse the file
65        let inputs =
66            CalyxParser::parse_with_userdata(Rule::file, content, user_data)
67                .map_err(|e| e.with_path(&path.to_string_lossy()))
68                .map_err(|e| {
69                    calyx_utils::Error::misc(format!(
70                        "Failed to parse `{}`: {err}",
71                        path.to_string_lossy(),
72                        err = e
73                    ))
74                })?;
75        let input = inputs.single().map_err(|e| {
76            calyx_utils::Error::misc(format!(
77                "Failed to parse `{}`: {err}",
78                path.to_string_lossy(),
79                err = e
80            ))
81        })?;
82        let out = CalyxParser::file(input).map_err(|e| {
83            calyx_utils::Error::misc(format!(
84                "Failed to parse `{}`: {err}",
85                path.to_string_lossy(),
86                err = e
87            ))
88        })?;
89        log::info!(
90            "Parsed `{}` in {}ms",
91            path.to_string_lossy(),
92            time.elapsed().as_millis()
93        );
94        Ok(out)
95    }
96
97    pub fn parse<R: Read>(mut r: R) -> CalyxResult<ast::NamespaceDef> {
98        let mut buf = String::new();
99        r.read_to_string(&mut buf).map_err(|err| {
100            calyx_utils::Error::invalid_file(format!(
101                "Failed to parse buffer: {err}",
102            ))
103        })?;
104        // Save the input string to the position table
105        let file =
106            GlobalPositionTable::as_mut().add_file("<stdin>".to_string(), buf);
107        let user_data = UserData { file };
108        let contents = GlobalPositionTable::as_ref().get_source(file);
109        // Parse the input
110        let inputs =
111            CalyxParser::parse_with_userdata(Rule::file, contents, user_data)
112                .map_err(|e| {
113                calyx_utils::Error::misc(
114                    format!("Failed to parse buffer: {e}",),
115                )
116            })?;
117        let input = inputs.single().map_err(|e| {
118            calyx_utils::Error::misc(format!("Failed to parse buffer: {e}",))
119        })?;
120        let out = CalyxParser::file(input).map_err(|e| {
121            calyx_utils::Error::misc(format!("Failed to parse buffer: {e}",))
122        })?;
123        Ok(out)
124    }
125
126    fn get_span(node: &Node) -> GPosIdx {
127        let ud = node.user_data();
128        let sp = node.as_span();
129        let pos = GlobalPositionTable::as_mut().add_pos(
130            ud.file,
131            sp.start(),
132            sp.end(),
133        );
134        GPosIdx(pos)
135    }
136
137    #[allow(clippy::result_large_err)]
138    fn guard_expr_helper(
139        ud: UserData,
140        pairs: pest::iterators::Pairs<Rule>,
141    ) -> ParseResult<Box<GuardExpr>> {
142        PRATT
143            .map_primary(|primary| match primary.as_rule() {
144                Rule::term => {
145                    Self::term(Node::new_with_user_data(primary, ud.clone()))
146                        .map(Box::new)
147                }
148                x => unreachable!("Unexpected rule {:?} for guard_expr", x),
149            })
150            .map_infix(|lhs, op, rhs| {
151                Ok(match op.as_rule() {
152                    Rule::guard_or => Box::new(GuardExpr::Or(lhs?, rhs?)),
153                    Rule::guard_and => Box::new(GuardExpr::And(lhs?, rhs?)),
154                    _ => unreachable!(),
155                })
156            })
157            .parse(pairs)
158    }
159
160    #[allow(clippy::result_large_err)]
161    fn static_guard_expr_helper(
162        ud: UserData,
163        pairs: pest::iterators::Pairs<Rule>,
164    ) -> ParseResult<Box<StaticGuardExpr>> {
165        PRATT
166            .map_primary(|primary| match primary.as_rule() {
167                Rule::static_term => Self::static_term(
168                    Node::new_with_user_data(primary, ud.clone()),
169                )
170                .map(Box::new),
171                x => unreachable!(
172                    "Unexpected rule {:?} for static_guard_expr",
173                    x
174                ),
175            })
176            .map_infix(|lhs, op, rhs| {
177                Ok(match op.as_rule() {
178                    Rule::guard_or => Box::new(StaticGuardExpr::Or(lhs?, rhs?)),
179                    Rule::guard_and => {
180                        Box::new(StaticGuardExpr::And(lhs?, rhs?))
181                    }
182                    _ => unreachable!(),
183                })
184            })
185            .parse(pairs)
186    }
187}
188
/// Presumably one top-level item of a Calyx file.
// NOTE(review): variant meanings below are inferred from the payload types;
// confirm against the rule handler that produces this enum.
#[allow(clippy::large_enum_variant)]
enum ExtOrComp {
    /// A group of externally defined primitives; the `Option<String>` is
    /// presumably the path of the extern source — TODO confirm.
    Ext((Option<String>, Vec<Primitive>)),
    /// A component definition.
    Comp(ComponentDef),
    /// A single primitive defined inline rather than inside an extern group.
    PrimInline(Primitive),
}
195
196#[pest_consume::parser]
197impl CalyxParser {
198    fn EOI(_input: Node) -> ParseResult<()> {
199        Ok(())
200    }
201
202    fn semi(_input: Node) -> ParseResult<()> {
203        Ok(())
204    }
205
206    fn comma_req(_input: Node) -> ParseResult<()> {
207        Ok(())
208    }
209    fn comma(input: Node) -> ParseResult<()> {
210        match_nodes!(
211            input.clone().into_children();
212            [comma_req(_)] => Ok(()),
213            [] => Err(input.error("expected comma"))
214        )
215    }
216
217    fn comb(_input: Node) -> ParseResult<()> {
218        Ok(())
219    }
220
221    fn static_word(_input: Node) -> ParseResult<()> {
222        Ok(())
223    }
224
225    fn reference(_input: Node) -> ParseResult<()> {
226        Ok(())
227    }
228
229    // ================ Literals =====================
230    fn identifier(input: Node) -> ParseResult<Id> {
231        Ok(Id::new(input.as_str()))
232    }
233
234    fn bitwidth(input: Node) -> ParseResult<u64> {
235        input
236            .as_str()
237            .parse::<u64>()
238            .map_err(|_| input.error("Expected valid bitwidth"))
239    }
240
241    fn static_annotation(input: Node) -> ParseResult<std::num::NonZeroU64> {
242        Ok(match_nodes!(
243            input.into_children();
244            [static_word(_), latency_annotation(latency)] => latency,
245        ))
246    }
247
248    fn static_optional_latency(
249        input: Node,
250    ) -> ParseResult<Option<std::num::NonZeroU64>> {
251        Ok(match_nodes!(
252            input.into_children();
253            [static_word(_), latency_annotation(latency)] => Some(latency),
254            [static_word(_)] => None,
255        ))
256    }
257
258    fn both_comb_static(
259        input: Node,
260    ) -> ParseResult<Option<std::num::NonZeroU64>> {
261        Err(input.error("Cannot have both comb and static annotations"))
262    }
263
264    fn comb_or_static(
265        input: Node,
266    ) -> ParseResult<Option<std::num::NonZeroU64>> {
267        match_nodes!(
268            input.clone().into_children();
269            [both_comb_static(_)] => unreachable!("both_comb_static did not error"),
270            [comb(_)] => Ok(None),
271            [static_annotation(latency)] => Ok(Some(latency)),
272        )
273    }
274
275    fn bad_num(input: Node) -> ParseResult<u64> {
276        Err(input.error("Expected number with bitwidth (like 32'd10)."))
277    }
278
279    fn hex(input: Node) -> ParseResult<u64> {
280        u64::from_str_radix(input.as_str(), 16)
281            .map_err(|_| input.error("Expected hexadecimal number"))
282    }
283    fn decimal(input: Node) -> ParseResult<u64> {
284        #[allow(clippy::from_str_radix_10)]
285        u64::from_str_radix(input.as_str(), 10)
286            .map_err(|_| input.error("Expected decimal number"))
287    }
288    fn octal(input: Node) -> ParseResult<u64> {
289        u64::from_str_radix(input.as_str(), 8)
290            .map_err(|_| input.error("Expected octal number"))
291    }
292    fn binary(input: Node) -> ParseResult<u64> {
293        u64::from_str_radix(input.as_str(), 2)
294            .map_err(|_| input.error("Expected binary number"))
295    }
296
297    fn num_lit(input: Node) -> ParseResult<BitNum> {
298        let span = Self::get_span(&input);
299        let num = match_nodes!(
300            input.clone().into_children();
301            [bitwidth(width), decimal(val)] => BitNum {
302                    width,
303                    num_type: NumType::Decimal,
304                    val,
305                    span
306                },
307            [bitwidth(width), hex(val)] => BitNum {
308                    width,
309                    num_type: NumType::Hex,
310                    val,
311                    span
312                },
313            [bitwidth(width), octal(val)] => BitNum {
314                    width,
315                    num_type: NumType::Octal,
316                    val,
317                    span
318                },
319            [bitwidth(width), binary(val)] => BitNum {
320                    width,
321                    num_type: NumType::Binary,
322                    val,
323                    span
324                },
325
326        );
327
328        // the below cast is safe since the width must be less than 64 for
329        // the given literal to be unrepresentable
330        if num.width == 0
331            || (num.width < 64 && u64::pow(2, num.width as u32) <= num.val)
332        {
333            let lit_str = match num.num_type {
334                NumType::Binary => format!("{:b}", num.val),
335                NumType::Decimal => format!("{}", num.val),
336                NumType::Octal => format!("{:o}", num.val),
337                NumType::Hex => format!("{:x}", num.val),
338            };
339            let bit_plural = if num.width == 1 { "bit" } else { "bits" };
340            Err(input.error(format!(
341                "Cannot represent given literal '{}' in {} {}",
342                lit_str, num.width, bit_plural
343            )))
344        } else {
345            Ok(num)
346        }
347    }
348
349    fn char(input: Node) -> ParseResult<&str> {
350        Ok(input.as_str())
351    }
352
353    fn string_lit(input: Node) -> ParseResult<String> {
354        Ok(match_nodes!(
355            input.into_children();
356            [char(c)..] => c.collect::<Vec<_>>().join("")
357        ))
358    }
359
360    // ================ Attributes =====================
361    fn attribute(input: Node) -> ParseResult<(Attribute, u64)> {
362        match_nodes!(
363            input.clone().into_children();
364            [string_lit(key), bitwidth(num)] => Attribute::from_str(&key).map(|attr| (attr, num)).map_err(|e| input.error(format!("{:?}", e)))
365        )
366    }
367    fn attributes(input: Node) -> ParseResult<Attributes> {
368        match_nodes!(
369            input.clone().into_children();
370            [attribute(kvs)..] => kvs.collect::<Vec<_>>().try_into().map_err(|e| input.error(format!("{:?}", e)))
371        )
372    }
373    fn name_with_attribute(input: Node) -> ParseResult<(Id, Attributes)> {
374        Ok(match_nodes!(
375            input.into_children();
376            [identifier(name), attributes(attrs)] => (name, attrs),
377            [identifier(name)] => (name, Attributes::default()),
378        ))
379    }
380
381    fn block_char(input: Node) -> ParseResult<&str> {
382        Ok(input.as_str())
383    }
384
385    fn block_string(input: Node) -> ParseResult<String> {
386        Ok(match_nodes!(
387            input.into_children();
388            [block_char(c)..] => c.collect::<String>().trim().to_string()
389        ))
390    }
391
392    fn attr_val(input: Node) -> ParseResult<u64> {
393        Ok(match_nodes!(
394            input.into_children();
395            [bitwidth(num)] => num
396        ))
397    }
398
399    fn latency_annotation(input: Node) -> ParseResult<std::num::NonZeroU64> {
400        let num = match_nodes!(
401            input.clone().into_children();
402            [bitwidth(value)] => value,
403        );
404        if num == 0 {
405            Err(input.error("latency annotation of 0"))
406        } else {
407            Ok(std::num::NonZeroU64::new(num).unwrap())
408        }
409    }
410
411    fn at_attribute(input: Node) -> ParseResult<(Attribute, u64)> {
412        match_nodes!(
413            input.clone().into_children();
414            [identifier(key), attr_val(num)] => Attribute::from_str(key.as_ref()).map_err(|e| input.error(format!("{:?}", e))).map(|attr| (attr, num)),
415            [identifier(key)] => Attribute::from_str(key.as_ref()).map_err(|e| input.error(format!("{:?}", e))).map(|attr| (attr, 1)),
416        )
417    }
418
419    fn at_attributes(input: Node) -> ParseResult<Attributes> {
420        match_nodes!(
421            input.clone().into_children();
422            [at_attribute(kvs)..] => kvs.collect::<Vec<_>>().try_into().map_err(|e| input.error(format!("{:?}", e)))
423        )
424    }
425
426    // ================ Signature =====================
427    fn params(input: Node) -> ParseResult<Vec<Id>> {
428        Ok(match_nodes!(
429            input.into_children();
430            [identifier(id)..] => id.collect()
431        ))
432    }
433
434    fn args(input: Node) -> ParseResult<Vec<u64>> {
435        Ok(match_nodes!(
436            input.into_children();
437            [bitwidth(bw)..] => bw.collect(),
438            [] => vec![]
439        ))
440    }
441
442    fn io_port(input: Node) -> ParseResult<(Id, Width, Attributes)> {
443        Ok(match_nodes!(
444            input.into_children();
445            [at_attributes(attrs), identifier(id), bitwidth(value)] =>
446                (id, Width::Const { value }, attrs),
447            [at_attributes(attrs), identifier(id), identifier(value)] =>
448                (id, Width::Param { value }, attrs)
449        ))
450    }
451
452    fn inputs(input: Node) -> ParseResult<Vec<PortDef<Width>>> {
453        Ok(match_nodes!(
454            input.into_children();
455            [io_port((name, width, attributes))] => {
456                let pd = PortDef::new(
457                    name, width, Direction::Input, attributes
458                );
459                vec![pd]
460            },
461            [io_port((name, width, attributes)), comma(_), inputs(rest)] => {
462                let pd = PortDef::new(
463                    name, width, Direction::Input, attributes
464                );
465                let mut v = vec![pd];
466                v.extend(rest);
467                v
468            }
469        ))
470    }
471
472    fn outputs(input: Node) -> ParseResult<Vec<PortDef<Width>>> {
473        Ok(match_nodes!(
474            input.into_children();
475            [io_port((name, width, attributes))] => {
476                let pd = PortDef::new(
477                    name, width, Direction::Output, attributes
478                );
479                vec![pd]
480            },
481            [io_port((name, width, attributes)), comma(_), outputs(rest)] => {
482                let pd = PortDef::new(
483                    name, width, Direction::Output, attributes
484                );
485                let mut v = vec![pd];
486                v.extend(rest);
487                v
488            }
489        ))
490    }
491
492    fn signature(input: Node) -> ParseResult<Vec<PortDef<Width>>> {
493        Ok(match_nodes!(
494            input.into_children();
495            // NOTE(rachit): We expect the signature to be extended to have `go`,
496            // `done`, `reset,`, and `clk`.
497            [] => Vec::with_capacity(4),
498            [inputs(ins)] =>  ins ,
499            [outputs(outs)] =>  outs ,
500            [inputs(ins), outputs(outs)] => {
501                ins.into_iter().chain(outs.into_iter()).collect()
502            },
503        ))
504    }
505
506    // ==============Primitives=====================
507    fn sig_with_params(
508        input: Node,
509    ) -> ParseResult<(Vec<Id>, Vec<PortDef<Width>>)> {
510        Ok(match_nodes!(
511            input.into_children();
512            [params(p), signature(s)] => (p, s),
513            [signature(s)] => (vec![], s),
514        ))
515    }
516    fn primitive(input: Node) -> ParseResult<Primitive> {
517        let span = Self::get_span(&input);
518        Ok(match_nodes!(
519            input.into_children();
520            [name_with_attribute((name, attrs)), sig_with_params((p, s))] => Primitive {
521                name,
522                params: p,
523                signature: s,
524                attributes: attrs.add_span(span),
525                is_comb: false,
526                latency: None,
527                body: None,
528            },
529            [comb_or_static(cs_res), name_with_attribute((name, attrs)), sig_with_params((p, s))] => Primitive {
530                name,
531                params: p,
532                signature: s,
533                attributes: attrs.add_span(span),
534                is_comb: cs_res.is_none(),
535                latency: cs_res,
536                body: None,
537            }
538        ))
539    }
540
541    // ================ Cells =====================
542    fn cell_without_semi(input: Node) -> ParseResult<ast::Cell> {
543        let span = Self::get_span(&input);
544        Ok(match_nodes!(
545            input.into_children();
546            [at_attributes(attrs), reference(_), identifier(id), identifier(prim), args(args)] =>
547            ast::Cell::from(id, prim, args, attrs.add_span(span),true),
548            [at_attributes(attrs), identifier(id), identifier(prim), args(args)] =>
549            ast::Cell::from(id, prim, args, attrs.add_span(span),false)
550        ))
551    }
552
553    fn cell(input: Node) -> ParseResult<ast::Cell> {
554        match_nodes!(
555            input.clone().into_children();
556            [cell_without_semi(_)] =>
557                Err(input.error("Declaration is missing `;`")),
558            [cell_without_semi(node), semi(_)] => Ok(node),
559        )
560    }
561
562    fn cells(input: Node) -> ParseResult<Vec<ast::Cell>> {
563        Ok(match_nodes!(
564                input.into_children();
565                [cell(cells)..] => cells.collect()
566        ))
567    }
568
569    // ================ Wires =====================
570    fn port(input: Node) -> ParseResult<ast::Port> {
571        Ok(match_nodes!(
572            input.into_children();
573            [identifier(component), identifier(port)] =>
574                ast::Port::Comp { component, port },
575            [identifier(port)] => ast::Port::This { port }
576        ))
577    }
578
579    fn hole(input: Node) -> ParseResult<ast::Port> {
580        Ok(match_nodes!(
581            input.into_children();
582            [identifier(group), identifier(name)] => ast::Port::Hole { group, name }
583        ))
584    }
585
586    #[allow(clippy::upper_case_acronyms)]
587    fn LHS(input: Node) -> ParseResult<ast::Port> {
588        Ok(match_nodes!(
589            input.into_children();
590            [port(port)] => port,
591            [hole(hole)] => hole
592        ))
593    }
594
595    fn expr(input: Node) -> ParseResult<ast::Atom> {
596        match_nodes!(
597            input.into_children();
598            [LHS(port)] => Ok(ast::Atom::Port(port)),
599            [num_lit(num)] => Ok(ast::Atom::Num(num)),
600            [bad_num(_)] => unreachable!("bad_num returned non-error result"),
601        )
602    }
603
604    fn guard_eq(_input: Node) -> ParseResult<()> {
605        Ok(())
606    }
607    fn guard_neq(_input: Node) -> ParseResult<()> {
608        Ok(())
609    }
610    fn guard_leq(_input: Node) -> ParseResult<()> {
611        Ok(())
612    }
613    fn guard_geq(_input: Node) -> ParseResult<()> {
614        Ok(())
615    }
616    fn guard_lt(_input: Node) -> ParseResult<()> {
617        Ok(())
618    }
619    fn guard_gt(_input: Node) -> ParseResult<()> {
620        Ok(())
621    }
622
623    fn cmp_expr(input: Node) -> ParseResult<ast::CompGuard> {
624        Ok(match_nodes!(
625            input.into_children();
626            [expr(l), guard_eq(_), expr(r)] => (GC::Eq, l, r),
627            [expr(l), guard_neq(_), expr(r)] => (GC::Neq, l, r),
628            [expr(l), guard_geq(_), expr(r)] => (GC::Geq, l, r),
629            [expr(l), guard_leq(_), expr(r)] => (GC::Leq, l, r),
630            [expr(l), guard_gt(_), expr(r)] =>  (GC::Gt, l, r),
631            [expr(l), guard_lt(_), expr(r)] =>  (GC::Lt, l, r),
632        ))
633    }
634
635    fn guard_not(_input: Node) -> ParseResult<()> {
636        Ok(())
637    }
638
639    fn guard_expr(input: Node) -> ParseResult<Box<GuardExpr>> {
640        let ud = input.user_data().clone();
641        Self::guard_expr_helper(ud, input.into_pair().into_inner())
642    }
643
644    fn static_guard_expr(input: Node) -> ParseResult<Box<StaticGuardExpr>> {
645        let ud = input.user_data().clone();
646        Self::static_guard_expr_helper(ud, input.into_pair().into_inner())
647    }
648
649    fn term(input: Node) -> ParseResult<ast::GuardExpr> {
650        Ok(match_nodes!(
651            input.into_children();
652            [guard_expr(guard)] => *guard,
653            [cmp_expr((gc, a1, a2))] => ast::GuardExpr::CompOp((gc, a1, a2)),
654            [expr(e)] => ast::GuardExpr::Atom(e),
655            [guard_not(_), expr(e)] => {
656                ast::GuardExpr::Not(Box::new(ast::GuardExpr::Atom(e)))
657            },
658            [guard_not(_), cmp_expr((gc, a1, a2))] => {
659                ast::GuardExpr::Not(Box::new(ast::GuardExpr::CompOp((gc, a1, a2))))
660            },
661            [guard_not(_), guard_expr(e)] => {
662                ast::GuardExpr::Not(e)
663            },
664            [guard_not(_), expr(e)] =>
665                ast::GuardExpr::Not(Box::new(ast::GuardExpr::Atom(e)))
666        ))
667    }
668
669    fn static_term(input: Node) -> ParseResult<ast::StaticGuardExpr> {
670        Ok(match_nodes!(
671            input.into_children();
672            [static_timing_expr(interval)] => ast::StaticGuardExpr::StaticInfo(interval),
673            [static_guard_expr(guard)] => *guard,
674            [cmp_expr((gc, a1, a2))] => ast::StaticGuardExpr::CompOp((gc, a1, a2)),
675            [expr(e)] => ast::StaticGuardExpr::Atom(e),
676            [guard_not(_), expr(e)] => {
677                ast::StaticGuardExpr::Not(Box::new(ast::StaticGuardExpr::Atom(e)))
678            },
679            [guard_not(_), cmp_expr((gc, a1, a2))] => {
680                ast::StaticGuardExpr::Not(Box::new(ast::StaticGuardExpr::CompOp((gc, a1, a2))))
681            },
682            [guard_not(_), static_guard_expr(e)] => {
683                ast::StaticGuardExpr::Not(e)
684            },
685            [guard_not(_), expr(e)] =>
686                ast::StaticGuardExpr::Not(Box::new(ast::StaticGuardExpr::Atom(e)))
687        ))
688    }
689
690    fn switch_stmt(input: Node) -> ParseResult<ast::Guard> {
691        Ok(match_nodes!(
692            input.into_children();
693            [guard_expr(guard), expr(expr)] => ast::Guard { guard: Some(*guard), expr },
694        ))
695    }
696
697    fn static_switch_stmt(input: Node) -> ParseResult<ast::StaticGuard> {
698        Ok(match_nodes!(
699            input.into_children();
700            [static_guard_expr(guard), expr(expr)] => ast::StaticGuard { guard: Some(*guard), expr },
701        ))
702    }
703
704    fn wire(input: Node) -> ParseResult<ast::Wire> {
705        let span = Self::get_span(&input);
706        Ok(match_nodes!(
707            input.into_children();
708            [at_attributes(attrs), LHS(dest), expr(expr)] => ast::Wire {
709                src: ast::Guard { guard: None, expr },
710                dest,
711                attributes: attrs.add_span(span),
712            },
713            [at_attributes(attrs), LHS(dest), switch_stmt(src)] => ast::Wire {
714                src,
715                dest,
716                attributes: attrs.add_span(span),
717            }
718        ))
719    }
720
721    fn static_wire(input: Node) -> ParseResult<ast::StaticWire> {
722        let span = Self::get_span(&input);
723        Ok(match_nodes!(
724            input.into_children();
725            [at_attributes(attrs), LHS(dest), expr(expr)] => ast::StaticWire {
726                src: ast::StaticGuard { guard: None, expr },
727                dest,
728                attributes: attrs.add_span(span),
729            },
730            [at_attributes(attrs), LHS(dest), static_switch_stmt(src)] => ast::StaticWire {
731                src,
732                dest,
733                attributes: attrs.add_span(span),
734            }
735        ))
736    }
737
738    fn static_timing_expr(input: Node) -> ParseResult<(u64, u64)> {
739        Ok(match_nodes!(
740            input.into_children();
741            [bitwidth(single_num)] => (single_num, single_num+1),
742            [bitwidth(start_interval), bitwidth(end_interval)] => (start_interval, end_interval)
743        ))
744    }
745
746    fn group(input: Node) -> ParseResult<ast::Group> {
747        let span = Self::get_span(&input);
748        Ok(match_nodes!(
749            input.into_children();
750            [name_with_attribute((name, attrs)), wire(wire)..] => ast::Group {
751                name,
752                attributes: attrs.add_span(span),
753                wires: wire.collect(),
754                is_comb: false,
755            },
756            [comb(_), name_with_attribute((name, attrs)), wire(wire)..] => ast::Group {
757                name,
758                attributes: attrs.add_span(span),
759                wires: wire.collect(),
760                is_comb: true,
761            }
762        ))
763    }
764
765    fn static_group(input: Node) -> ParseResult<ast::StaticGroup> {
766        let span = Self::get_span(&input);
767        Ok(match_nodes!(
768            input.into_children();
769            [static_annotation(latency), name_with_attribute((name, attrs)), static_wire(wire)..] => ast::StaticGroup {
770                name,
771                attributes: attrs.add_span(span),
772                wires: wire.collect(),
773                latency,
774            }
775        ))
776    }
777
778    fn connections(
779        input: Node,
780    ) -> ParseResult<(Vec<ast::Wire>, Vec<ast::Group>, Vec<ast::StaticGroup>)>
781    {
782        let mut wires = Vec::new();
783        let mut groups = Vec::new();
784        let mut static_groups = Vec::new();
785        for node in input.into_children() {
786            match node.as_rule() {
787                Rule::wire => wires.push(Self::wire(node)?),
788                Rule::group => groups.push(Self::group(node)?),
789                Rule::static_group => {
790                    static_groups.push(Self::static_group(node)?)
791                }
792                _ => unreachable!(),
793            }
794        }
795        Ok((wires, groups, static_groups))
796    }
797
798    // ================ Control program =====================
799    fn invoke_arg(input: Node) -> ParseResult<(Id, ast::Atom)> {
800        Ok(match_nodes!(
801            input.into_children();
802            [identifier(name), port(p)] => (name, ast::Atom::Port(p)),
803            [identifier(name), num_lit(bn)] => (name, ast::Atom::Num(bn))
804
805        ))
806    }
807
808    fn invoke_args(input: Node) -> ParseResult<Vec<(Id, ast::Atom)>> {
809        Ok(match_nodes!(
810            input.into_children();
811            [invoke_arg(args)..] => args.collect()
812        ))
813    }
814
815    fn invoke_ref_arg(input: Node) -> ParseResult<(Id, Id)> {
816        Ok(match_nodes!(
817            input.into_children();
818            [identifier(outcell), identifier(incell)] => (outcell, incell)
819        ))
820    }
821
822    fn invoke_ref_args(input: Node) -> ParseResult<Vec<(Id, Id)>> {
823        Ok(match_nodes!(
824            input.into_children();
825            [invoke_ref_arg(args)..] => args.collect(),
826            [] => Vec::new()
827        ))
828    }
829
830    fn invoke(input: Node) -> ParseResult<ast::Control> {
831        let span = Self::get_span(&input);
832        Ok(match_nodes!(
833            input.into_children();
834            [at_attributes(attrs), identifier(comp), invoke_ref_args(cells),invoke_args(inputs), invoke_args(outputs)] =>
835                ast::Control::Invoke {
836                    comp,
837                    inputs,
838                    outputs,
839                    attributes: attrs.add_span(span),
840                    comb_group: None,
841                    ref_cells: cells
842                },
843            [at_attributes(attrs), identifier(comp), invoke_ref_args(cells),invoke_args(inputs), invoke_args(outputs), identifier(group)] =>
844                ast::Control::Invoke {
845                    comp,
846                    inputs,
847                    outputs,
848                    attributes: attrs.add_span(span),
849                    comb_group: Some(group),
850                    ref_cells: cells
851                },
852        ))
853    }
854
855    fn static_invoke(input: Node) -> ParseResult<ast::Control> {
856        let span = Self::get_span(&input);
857        Ok(match_nodes!(
858            input.into_children();
859            [at_attributes(attrs), static_optional_latency(latency), identifier(comp), invoke_ref_args(cells),invoke_args(inputs), invoke_args(outputs)] =>
860                ast::Control::StaticInvoke {
861                    comp,
862                    inputs,
863                    outputs,
864                    attributes: attrs.add_span(span),
865                    ref_cells: cells,
866                    latency,
867                    comb_group: None,
868                },
869                [at_attributes(attrs), static_optional_latency(latency), identifier(comp), invoke_ref_args(cells),invoke_args(inputs), invoke_args(outputs), identifier(group)] =>
870                ast::Control::StaticInvoke {
871                    comp,
872                    inputs,
873                    outputs,
874                    attributes: attrs.add_span(span),
875                    ref_cells: cells,
876                    latency,
877                    comb_group: Some(group),
878                },
879        ))
880    }
881
882    fn empty(input: Node) -> ParseResult<ast::Control> {
883        let span = Self::get_span(&input);
884        Ok(match_nodes!(
885            input.into_children();
886            [at_attributes(attrs)] => ast::Control::Empty {
887                attributes: attrs.add_span(span)
888            }
889        ))
890    }
891
892    fn enable(input: Node) -> ParseResult<ast::Control> {
893        let span = Self::get_span(&input);
894        Ok(match_nodes!(
895            input.into_children();
896            [at_attributes(attrs), identifier(name)] => ast::Control::Enable {
897                comp: name,
898                attributes: attrs.add_span(span)
899            }
900        ))
901    }
902
903    fn seq(input: Node) -> ParseResult<ast::Control> {
904        let span = Self::get_span(&input);
905        Ok(match_nodes!(
906            input.into_children();
907            [at_attributes(attrs), stmt(stmt)..] => ast::Control::Seq {
908                stmts: stmt.collect(),
909                attributes: attrs.add_span(span),
910            }
911        ))
912    }
913
914    fn static_seq(input: Node) -> ParseResult<ast::Control> {
915        let span = Self::get_span(&input);
916        Ok(match_nodes!(
917            input.into_children();
918            [at_attributes(attrs), static_optional_latency(latency) ,stmt(stmt)..] => ast::Control::StaticSeq {
919                stmts: stmt.collect(),
920                attributes: attrs.add_span(span),
921                latency,
922            }
923        ))
924    }
925
926    fn par(input: Node) -> ParseResult<ast::Control> {
927        let span = Self::get_span(&input);
928        Ok(match_nodes!(
929            input.into_children();
930            [at_attributes(attrs), stmt(stmt)..] => ast::Control::Par {
931                stmts: stmt.collect(),
932                attributes: attrs.add_span(span),
933            }
934        ))
935    }
936
937    fn static_par(input: Node) -> ParseResult<ast::Control> {
938        let span = Self::get_span(&input);
939        Ok(match_nodes!(
940            input.into_children();
941            [at_attributes(attrs), static_optional_latency(latency) ,stmt(stmt)..] => ast::Control::StaticPar {
942                stmts: stmt.collect(),
943                attributes: attrs.add_span(span),
944                latency,
945            }
946        ))
947    }
948
949    fn port_with(input: Node) -> ParseResult<(ast::Port, Option<Id>)> {
950        Ok(match_nodes!(
951            input.into_children();
952            [port(port), identifier(cond)] => (port, Some(cond)),
953            [port(port)] => (port, None),
954        ))
955    }
956
957    fn if_stmt(input: Node) -> ParseResult<ast::Control> {
958        let span = Self::get_span(&input);
959        Ok(match_nodes!(
960            input.into_children();
961            [at_attributes(attrs), port_with((port, cond)), block(stmt)] => ast::Control::If {
962                port,
963                cond,
964                tbranch: Box::new(stmt),
965                fbranch: Box::new(ast::Control::Empty { attributes: Attributes::default() }),
966                attributes: attrs.add_span(span),
967            },
968            [at_attributes(attrs), port_with((port, cond)), block(tbranch), block(fbranch)] =>
969                ast::Control::If {
970                    port,
971                    cond,
972                    tbranch: Box::new(tbranch),
973                    fbranch: Box::new(fbranch),
974                    attributes: attrs.add_span(span),
975                },
976            [at_attributes(attrs), port_with((port, cond)), block(tbranch), if_stmt(fbranch)] =>
977                ast::Control::If {
978                    port,
979                    cond,
980                    tbranch: Box::new(tbranch),
981                    fbranch: Box::new(fbranch),
982                    attributes: attrs.add_span(span),
983                },
984
985        ))
986    }
987
988    fn static_if_stmt(input: Node) -> ParseResult<ast::Control> {
989        let span = Self::get_span(&input);
990        Ok(match_nodes!(
991            input.into_children();
992            [at_attributes(attrs), static_optional_latency(latency), port(port), block(stmt)] => ast::Control::StaticIf {
993                port,
994                tbranch: Box::new(stmt),
995                fbranch: Box::new(ast::Control::Empty { attributes: Attributes::default() }),
996                attributes: attrs.add_span(span),
997                latency,
998            },
999            [at_attributes(attrs), static_optional_latency(latency), port(port), block(tbranch), block(fbranch)] =>
1000                ast::Control::StaticIf {
1001                    port,
1002                    tbranch: Box::new(tbranch),
1003                    fbranch: Box::new(fbranch),
1004                    attributes: attrs.add_span(span),
1005                    latency,
1006                },
1007            [at_attributes(attrs), static_optional_latency(latency), port(port), block(tbranch), static_if_stmt(fbranch)] =>
1008                ast::Control::StaticIf {
1009                    port,
1010                    tbranch: Box::new(tbranch),
1011                    fbranch: Box::new(fbranch),
1012                    attributes: attrs.add_span(span),
1013                    latency,
1014                }
1015        ))
1016    }
1017
1018    fn while_stmt(input: Node) -> ParseResult<ast::Control> {
1019        let span = Self::get_span(&input);
1020        Ok(match_nodes!(
1021            input.into_children();
1022            [at_attributes(attrs), port_with((port, cond)), block(stmt)] => ast::Control::While {
1023                port,
1024                cond,
1025                body: Box::new(stmt),
1026                attributes: attrs.add_span(span),
1027            }
1028        ))
1029    }
1030
1031    fn repeat_stmt(input: Node) -> ParseResult<ast::Control> {
1032        let span = Self::get_span(&input);
1033        Ok(match_nodes!(
1034            input.into_children();
1035            [at_attributes(attrs), bitwidth(num_repeats) , block(stmt)] => ast::Control::Repeat {
1036                num_repeats,
1037                body: Box::new(stmt),
1038                attributes: attrs.add_span(span),
1039            },
1040            [at_attributes(attrs), static_word(_), bitwidth(num_repeats) , block(stmt)] => ast::Control::StaticRepeat {
1041                num_repeats,
1042                body: Box::new(stmt),
1043                attributes: attrs.add_span(span),
1044            }
1045        ))
1046    }
1047
1048    fn stmt(input: Node) -> ParseResult<ast::Control> {
1049        Ok(match_nodes!(
1050            input.into_children();
1051            [enable(data)] => data,
1052            [empty(data)] => data,
1053            [invoke(data)] => data,
1054            [static_invoke(data)] => data,
1055            [seq(data)] => data,
1056            [static_seq(data)] => data,
1057            [par(data)] => data,
1058            [static_par(data)] => data,
1059            [if_stmt(data)] => data,
1060            [static_if_stmt(data)] => data,
1061            [while_stmt(data)] => data,
1062            [repeat_stmt(data)] => data,
1063        ))
1064    }
1065
1066    fn block(input: Node) -> ParseResult<ast::Control> {
1067        Ok(match_nodes!(
1068            input.into_children();
1069            [stmt(stmt)] => stmt,
1070            [stmts_without_block(seq)] => seq,
1071        ))
1072    }
1073
1074    fn stmts_without_block(input: Node) -> ParseResult<ast::Control> {
1075        match_nodes!(
1076            input.clone().into_children();
1077            [stmt(stmt)..] => Ok(ast::Control::Seq {
1078                stmts: stmt.collect(),
1079                attributes: Attributes::default(),
1080            })
1081        )
1082    }
1083
1084    fn control(input: Node) -> ParseResult<ast::Control> {
1085        Ok(match_nodes!(
1086            input.into_children();
1087            [block(stmt)] => stmt,
1088            [] => ast::Control::empty()
1089        ))
1090    }
1091
1092    fn component(input: Node) -> ParseResult<ComponentDef> {
1093        let span = Self::get_span(&input);
1094        match_nodes!(
1095            input.clone().into_children();
1096            [
1097                comb_or_static(cs_res),
1098                name_with_attribute((name, attributes)),
1099                signature(sig),
1100                cells(cells),
1101                connections(connections)
1102            ] => {
1103                if cs_res.is_some() {
1104                    Err(input.error("Static Component must have defined control"))?;
1105                }
1106                let (continuous_assignments, groups, static_groups) = connections;
1107                let sig = sig.into_iter().map(|pd| {
1108                    if let Width::Const { value } = pd.width {
1109                        Ok(PortDef::new(
1110                            pd.name(),
1111                            value,
1112                            pd.direction,
1113                            pd.attributes
1114                        ))
1115                    } else {
1116                        Err(input.error("Components cannot use parameters"))
1117                    }
1118                }).collect::<Result<_, _>>()?;
1119                Ok(ComponentDef {
1120                    name,
1121                    signature: sig,
1122                    cells,
1123                    groups,
1124                    static_groups,
1125                    continuous_assignments,
1126                    control: Control::empty(),
1127                    attributes: attributes.add_span(span),
1128                    is_comb: true,
1129                    latency: None,
1130                })
1131            },
1132            [
1133                name_with_attribute((name, attributes)),
1134                signature(sig),
1135                cells(cells),
1136                connections(connections),
1137                control(control)
1138            ] => {
1139                let (continuous_assignments, groups, static_groups) = connections;
1140                let sig = sig.into_iter().map(|pd| {
1141                    if let Width::Const { value } = pd.width {
1142                        Ok(PortDef::new(
1143                            pd.name(),
1144                            value,
1145                            pd.direction,
1146                            pd.attributes
1147                        ))
1148                    } else {
1149                        Err(input.error("Components cannot use parameters"))
1150                    }
1151                }).collect::<Result<_, _>>()?;
1152                Ok(ComponentDef {
1153                    name,
1154                    signature: sig,
1155                    cells,
1156                    groups,
1157                    static_groups,
1158                    continuous_assignments,
1159                    control,
1160                    attributes: attributes.add_span(span),
1161                    is_comb: false,
1162                    latency: None,
1163                })
1164            },
1165            [
1166                comb_or_static(cs_res),
1167                name_with_attribute((name, attributes)),
1168                signature(sig),
1169                cells(cells),
1170                connections(connections),
1171                control(control),
1172            ] => {
1173                let (continuous_assignments, groups, static_groups) = connections;
1174                let sig = sig.into_iter().map(|pd| {
1175                    if let Width::Const { value } = pd.width {
1176                        Ok(PortDef::new(
1177                            pd.name(),
1178                            value,
1179                            pd.direction,
1180                            pd.attributes
1181                        ))
1182                    } else {
1183                        Err(input.error("Components cannot use parameters"))
1184                    }
1185                }).collect::<Result<_, _>>()?;
1186                Ok(ComponentDef {
1187                    name,
1188                    signature: sig,
1189                    cells,
1190                    groups,
1191                    static_groups,
1192                    continuous_assignments,
1193                    control,
1194                    attributes: attributes.add_span(span),
1195                    is_comb: cs_res.is_none(),
1196                    latency: cs_res,
1197                })
1198            },
1199        )
1200    }
1201
1202    fn imports(input: Node) -> ParseResult<Vec<String>> {
1203        Ok(match_nodes!(
1204            input.into_children();
1205            [string_lit(path)..] => path.collect()
1206        ))
1207    }
1208
1209    fn ext(input: Node) -> ParseResult<(Option<String>, Vec<Primitive>)> {
1210        Ok(match_nodes!(
1211            input.into_children();
1212            [string_lit(file), primitive(prims)..] => (Some(file), prims.collect())
1213        ))
1214    }
1215
1216    fn prim_inline(input: Node) -> ParseResult<Primitive> {
1217        let span = Self::get_span(&input);
1218        Ok(match_nodes!(
1219            input.into_children();
1220            [name_with_attribute((name, attrs)), sig_with_params((p, s)), block_string(b)] => {
1221            Primitive {
1222                name,
1223                params: p,
1224                signature: s,
1225                attributes: attrs.add_span(span),
1226                is_comb: false,
1227                latency: None,
1228                body: Some(b),
1229            }},
1230            [comb_or_static(cs_res), name_with_attribute((name, attrs)), sig_with_params((p, s)), block_string(b)] => Primitive {
1231                name,
1232                params: p,
1233                signature: s,
1234                attributes: attrs.add_span(span),
1235                is_comb: cs_res.is_none(),
1236                latency: cs_res,
1237                body: Some(b),
1238            }
1239        ))
1240    }
1241
1242    fn extern_or_component(input: Node) -> ParseResult<ExtOrComp> {
1243        Ok(match_nodes!(
1244            input.into_children();
1245            [component(comp)] => ExtOrComp::Comp(comp),
1246            [ext(ext)] => ExtOrComp::Ext(ext),
1247            [prim_inline(prim_inline)] => ExtOrComp::PrimInline(prim_inline),
1248        ))
1249    }
1250
1251    fn externs_and_comps(
1252        input: Node,
1253    ) -> ParseResult<impl Iterator<Item = ExtOrComp>> {
1254        Ok(match_nodes!(input.into_children();
1255            [extern_or_component(e)..] => e
1256        ))
1257    }
1258
1259    fn any_char(input: Node) -> ParseResult<String> {
1260        Ok(input.as_str().into())
1261    }
1262
1263    fn metadata_char(input: Node) -> ParseResult<String> {
1264        Ok(match_nodes!(input.into_children();
1265            [any_char(c)] => c,
1266        ))
1267    }
1268
1269    fn metadata(input: Node) -> ParseResult<String> {
1270        Ok(match_nodes!(input.into_children();
1271            [metadata_char(c)..] => c.collect::<String>().trim().into()
1272        ))
1273    }
1274
1275    fn file(input: Node) -> ParseResult<ast::NamespaceDef> {
1276        Ok(match_nodes!(
1277            input.into_children();
1278            // There really seems to be no straightforward way to resolve this
1279            // duplication
1280            [imports(imports), externs_and_comps(mixed), metadata(m), EOI(_)] => {
1281                let mut namespace =
1282                    ast::NamespaceDef {
1283                        imports,
1284                        components: Vec::new(),
1285                        externs: Vec::new(),
1286                        metadata: if m != *"" { Some(m) } else { None }
1287                    };
1288                for m in mixed {
1289                    match m {
1290                        ExtOrComp::Ext(ext) => namespace.externs.push(ext),
1291                        ExtOrComp::Comp(comp) => namespace.components.push(comp),
1292                        ExtOrComp::PrimInline(prim) => {
1293                            if let Some((_, prim_inlines)) = namespace.externs.iter_mut().find(|(filename, _)| filename.is_none()) {
1294                                prim_inlines.push(prim)
1295                            }
1296                            else{
1297                                namespace.externs.push((None, vec![prim]));
1298                            }
1299                        },
1300                    }
1301                }
1302                namespace
1303            },
1304            [imports(imports), externs_and_comps(mixed), EOI(_)] => {
1305                let mut namespace =
1306                    ast::NamespaceDef {
1307                        imports,
1308                        components: Vec::new(),
1309                        externs: Vec::new(),
1310                        metadata: None
1311                    };
1312                for m in mixed {
1313                    match m {
1314                        ExtOrComp::Ext(ext) => namespace.externs.push(ext),
1315                        ExtOrComp::Comp(comp) => namespace.components.push(comp),
1316                        ExtOrComp::PrimInline(prim) => {
1317                            if let Some((_, prim_inlines)) = namespace.externs.iter_mut().find(|(filename, _)| filename.is_none()) {
1318                                prim_inlines.push(prim)
1319                            }
1320                            else{
1321                                namespace.externs.push((None, vec![prim]));
1322                            }
1323                        },
1324                    }
1325                }
1326                namespace
1327            },
1328
1329        ))
1330    }
1331}