hugr_model/v0/ast/
parse.rs

1// NOTE: We use the `pest` library for parsing. This library is convenient, but
2// performance is mediocre. In the case that we find that parsing is too slow,
3// we can replace the parser.
4
5// NOTE: The `pest` library returns a parsed AST which we then transform into
6// our AST data structures. The `pest` AST is guaranteed to conform to the
7// grammar, but this is not automatically visible from the types. Therefore this
8// module contains many `unwrap`s and `unreachable!`s, which will not fail on
9// any input unless there is a bug in the parser. This is perhaps aesthetically
10// unsatisfying but it is aligned with the intended usage pattern of `pest`.
11
12// NOTE: The `parse_` functions are implementation details since they refer to
13// `pest` data structures. We expose parsing via implementations of the
14// `FromStr` trait.
15
16use std::str::FromStr;
17use std::sync::Arc;
18
19use base64::Engine as _;
20use base64::prelude::BASE64_STANDARD;
21use ordered_float::OrderedFloat;
22use pest::Parser as _;
23use pest::iterators::{Pair, Pairs};
24use pest_parser::{HugrParser, Rule};
25use smol_str::SmolStr;
26use thiserror::Error;
27
28use crate::v0::ast::{LinkName, Module, Operation, SeqPart};
29use crate::v0::{Literal, RegionKind};
30
31use super::{Node, Package, Param, Region, Symbol, VarName, Visibility};
32use super::{SymbolName, Term};
33
34mod pest_parser {
35    use pest_derive::Parser;
36
37    // NOTE: The pest derive macro generates a `Rule` enum. We do not want this to be
38    // part of the public API, and so we hide it within this private module.
39
40    #[derive(Parser)]
41    #[grammar = "v0/ast/hugr.pest"]
42    pub struct HugrParser;
43}
44
45fn parse_symbol_name(pair: Pair<Rule>) -> ParseResult<SymbolName> {
46    debug_assert_eq!(Rule::symbol_name, pair.as_rule());
47    Ok(SymbolName(pair.as_str().into()))
48}
49
50fn parse_var_name(pair: Pair<Rule>) -> ParseResult<VarName> {
51    debug_assert_eq!(Rule::term_var, pair.as_rule());
52    Ok(VarName(pair.as_str()[1..].into()))
53}
54
55fn parse_link_name(pair: Pair<Rule>) -> ParseResult<LinkName> {
56    debug_assert_eq!(Rule::link_name, pair.as_rule());
57    Ok(LinkName(pair.as_str()[1..].into()))
58}
59
60fn parse_term(pair: Pair<Rule>) -> ParseResult<Term> {
61    debug_assert_eq!(Rule::term, pair.as_rule());
62    let pair = pair.into_inner().next().unwrap();
63
64    Ok(match pair.as_rule() {
65        Rule::term_wildcard => Term::Wildcard,
66        Rule::term_var => Term::Var(parse_var_name(pair)?),
67        Rule::term_apply => {
68            let mut pairs = pair.into_inner();
69            let symbol = parse_symbol_name(pairs.next().unwrap())?;
70            let terms = pairs.map(parse_term).collect::<ParseResult<_>>()?;
71            Term::Apply(symbol, terms)
72        }
73        Rule::term_list => {
74            let pairs = pair.into_inner();
75            let parts = pairs.map(parse_seq_part).collect::<ParseResult<_>>()?;
76            Term::List(parts)
77        }
78        Rule::term_tuple => {
79            let pairs = pair.into_inner();
80            let parts = pairs.map(parse_seq_part).collect::<ParseResult<_>>()?;
81            Term::Tuple(parts)
82        }
83        Rule::literal => {
84            let literal = parse_literal(pair)?;
85            Term::Literal(literal)
86        }
87        Rule::term_const_func => {
88            let mut pairs = pair.into_inner();
89            let region = parse_region(pairs.next().unwrap())?;
90            Term::Func(Arc::new(region))
91        }
92        _ => unreachable!(),
93    })
94}
95
96fn parse_literal(pair: Pair<Rule>) -> ParseResult<Literal> {
97    debug_assert_eq!(pair.as_rule(), Rule::literal);
98    let pair = pair.into_inner().next().unwrap();
99
100    Ok(match pair.as_rule() {
101        Rule::literal_string => Literal::Str(parse_string(pair)?),
102        Rule::literal_nat => Literal::Nat(parse_nat(pair)?),
103        Rule::literal_bytes => Literal::Bytes(parse_bytes(pair)?),
104        Rule::literal_float => Literal::Float(parse_float(pair)?),
105        _ => unreachable!("expected literal"),
106    })
107}
108
109fn parse_seq_part(pair: Pair<Rule>) -> ParseResult<SeqPart> {
110    debug_assert_eq!(pair.as_rule(), Rule::part);
111    let pair = pair.into_inner().next().unwrap();
112
113    Ok(match pair.as_rule() {
114        Rule::term => SeqPart::Item(parse_term(pair)?),
115        Rule::spliced_term => {
116            let mut pairs = pair.into_inner();
117            let term = parse_term(pairs.next().unwrap())?;
118            SeqPart::Splice(term)
119        }
120        _ => unreachable!("expected term or spliced term"),
121    })
122}
123
124fn parse_package(pair: Pair<Rule>) -> ParseResult<Package> {
125    debug_assert_eq!(pair.as_rule(), Rule::package);
126    let mut pairs = pair.into_inner();
127
128    let modules = take_rule(&mut pairs, Rule::module)
129        .map(parse_module)
130        .collect::<ParseResult<_>>()?;
131
132    Ok(Package { modules })
133}
134
135fn parse_module(pair: Pair<Rule>) -> ParseResult<Module> {
136    debug_assert_eq!(pair.as_rule(), Rule::module);
137    let mut pairs = pair.into_inner();
138    let meta = parse_meta_items(&mut pairs)?;
139    let children = parse_nodes(&mut pairs)?;
140
141    Ok(Module {
142        root: Region {
143            kind: RegionKind::Module,
144            children,
145            meta,
146            ..Default::default()
147        },
148    })
149}
150
151fn parse_region(pair: Pair<Rule>) -> ParseResult<Region> {
152    debug_assert_eq!(pair.as_rule(), Rule::region);
153    let mut pairs = pair.into_inner();
154
155    let kind = parse_region_kind(pairs.next().unwrap())?;
156    let sources = parse_port_list(&mut pairs)?;
157    let targets = parse_port_list(&mut pairs)?;
158    let signature = parse_optional_signature(&mut pairs)?;
159    let meta = parse_meta_items(&mut pairs)?;
160    let children = parse_nodes(&mut pairs)?;
161
162    Ok(Region {
163        kind,
164        sources,
165        targets,
166        children,
167        meta,
168        signature,
169    })
170}
171
172fn parse_region_kind(pair: Pair<Rule>) -> ParseResult<RegionKind> {
173    debug_assert_eq!(pair.as_rule(), Rule::region_kind);
174
175    Ok(match pair.as_str() {
176        "dfg" => RegionKind::DataFlow,
177        "cfg" => RegionKind::ControlFlow,
178        "mod" => RegionKind::Module,
179        _ => unreachable!(),
180    })
181}
182
183fn parse_nodes(pairs: &mut Pairs<Rule>) -> ParseResult<Box<[Node]>> {
184    take_rule(pairs, Rule::node).map(parse_node).collect()
185}
186
187fn parse_node(pair: Pair<Rule>) -> ParseResult<Node> {
188    debug_assert_eq!(pair.as_rule(), Rule::node);
189    let mut pairs = pair.into_inner();
190    let pair = pairs.next().unwrap();
191    let rule = pair.as_rule();
192    let mut pairs = pair.into_inner();
193
194    let operation = match rule {
195        Rule::node_dfg => Operation::Dfg,
196        Rule::node_cfg => Operation::Cfg,
197        Rule::node_block => Operation::Block,
198        Rule::node_tail_loop => Operation::TailLoop,
199        Rule::node_cond => Operation::Conditional,
200
201        Rule::node_import => {
202            let name = parse_symbol_name(pairs.next().unwrap())?;
203            Operation::Import(name)
204        }
205
206        Rule::node_custom => {
207            let term = parse_term(pairs.next().unwrap())?;
208            Operation::Custom(term)
209        }
210
211        Rule::node_define_func => {
212            let symbol = parse_symbol(pairs.next().unwrap())?;
213            Operation::DefineFunc(Box::new(symbol))
214        }
215        Rule::node_declare_func => {
216            let symbol = parse_symbol(pairs.next().unwrap())?;
217            Operation::DeclareFunc(Box::new(symbol))
218        }
219        Rule::node_define_alias => {
220            let symbol = parse_symbol(pairs.next().unwrap())?;
221            let value = parse_term(pairs.next().unwrap())?;
222            Operation::DefineAlias(Box::new(symbol), value)
223        }
224        Rule::node_declare_alias => {
225            let symbol = parse_symbol(pairs.next().unwrap())?;
226            Operation::DeclareAlias(Box::new(symbol))
227        }
228        Rule::node_declare_ctr => {
229            let symbol = parse_symbol(pairs.next().unwrap())?;
230            Operation::DeclareConstructor(Box::new(symbol))
231        }
232        Rule::node_declare_operation => {
233            let symbol = parse_symbol(pairs.next().unwrap())?;
234            Operation::DeclareOperation(Box::new(symbol))
235        }
236
237        _ => unreachable!(),
238    };
239
240    let inputs = parse_port_list(&mut pairs)?;
241    let outputs = parse_port_list(&mut pairs)?;
242    let signature = parse_optional_signature(&mut pairs)?;
243    let meta = parse_meta_items(&mut pairs)?;
244    let regions = pairs
245        .map(|pair| parse_region(pair))
246        .collect::<ParseResult<_>>()?;
247
248    Ok(Node {
249        operation,
250        inputs,
251        outputs,
252        regions,
253        meta,
254        signature,
255    })
256}
257
258fn parse_meta_items(pairs: &mut Pairs<Rule>) -> ParseResult<Box<[Term]>> {
259    take_rule(pairs, Rule::meta).map(parse_meta_item).collect()
260}
261
262fn parse_meta_item(pair: Pair<Rule>) -> ParseResult<Term> {
263    debug_assert_eq!(pair.as_rule(), Rule::meta);
264    let mut pairs = pair.into_inner();
265    parse_term(pairs.next().unwrap())
266}
267
268fn parse_optional_signature(pairs: &mut Pairs<Rule>) -> ParseResult<Option<Term>> {
269    match take_rule(pairs, Rule::signature).next() {
270        Some(pair) => Ok(Some(parse_signature(pair)?)),
271        _ => Ok(None),
272    }
273}
274
275fn parse_signature(pair: Pair<Rule>) -> ParseResult<Term> {
276    debug_assert_eq!(Rule::signature, pair.as_rule());
277    let mut pairs = pair.into_inner();
278    parse_term(pairs.next().unwrap())
279}
280
281fn parse_params(pairs: &mut Pairs<Rule>) -> ParseResult<Box<[Param]>> {
282    take_rule(pairs, Rule::param).map(parse_param).collect()
283}
284
285fn parse_param(pair: Pair<Rule>) -> ParseResult<Param> {
286    debug_assert_eq!(Rule::param, pair.as_rule());
287    let mut pairs = pair.into_inner();
288    let name = parse_var_name(pairs.next().unwrap())?;
289    let r#type = parse_term(pairs.next().unwrap())?;
290    Ok(Param { name, r#type })
291}
292
293fn parse_symbol(pair: Pair<Rule>) -> ParseResult<Symbol> {
294    debug_assert_eq!(Rule::symbol, pair.as_rule());
295
296    let mut pairs = pair.into_inner();
297    let visibility = take_rule(&mut pairs, Rule::visibility)
298        .next()
299        .map(|pair| match pair.as_str() {
300            "public" => Ok(Visibility::Public),
301            "private" => Ok(Visibility::Private),
302            _ => unreachable!("Expected 'public' or 'private', got {}", pair.as_str()),
303        })
304        .transpose()?;
305    let name = parse_symbol_name(pairs.next().unwrap())?;
306    let params = parse_params(&mut pairs)?;
307    let constraints = parse_constraints(&mut pairs)?;
308    let signature = parse_term(pairs.next().unwrap())?;
309
310    Ok(Symbol {
311        visibility,
312        name,
313        params,
314        constraints,
315        signature,
316    })
317}
318
319fn parse_constraints(pairs: &mut Pairs<Rule>) -> ParseResult<Box<[Term]>> {
320    take_rule(pairs, Rule::where_clause)
321        .map(parse_constraint)
322        .collect()
323}
324
325fn parse_constraint(pair: Pair<Rule>) -> ParseResult<Term> {
326    debug_assert_eq!(Rule::where_clause, pair.as_rule());
327    let mut pairs = pair.into_inner();
328    parse_term(pairs.next().unwrap())
329}
330
331fn parse_port_list(pairs: &mut Pairs<Rule>) -> ParseResult<Box<[LinkName]>> {
332    let Some(pair) = take_rule(pairs, Rule::port_list).next() else {
333        return Ok(Default::default());
334    };
335
336    let pairs = pair.into_inner();
337    pairs.map(parse_link_name).collect()
338}
339
340fn parse_string(pair: Pair<Rule>) -> ParseResult<SmolStr> {
341    debug_assert_eq!(pair.as_rule(), Rule::literal_string);
342
343    // Any escape sequence is longer than the character it represents.
344    // Therefore the length of this token (minus 2 for the quotes on either
345    // side) is an upper bound for the length of the string.
346    let capacity = pair.as_str().len() - 2;
347    let mut string = String::with_capacity(capacity);
348    let pairs = pair.into_inner();
349
350    for pair in pairs {
351        match pair.as_rule() {
352            Rule::literal_string_raw => string.push_str(pair.as_str()),
353            Rule::literal_string_escape => match pair.as_str().chars().nth(1).unwrap() {
354                '"' => string.push('"'),
355                '\\' => string.push('\\'),
356                'n' => string.push('\n'),
357                'r' => string.push('\r'),
358                't' => string.push('\t'),
359                _ => unreachable!(),
360            },
361            Rule::literal_string_unicode => {
362                let token_str = pair.as_str();
363                debug_assert_eq!(&token_str[0..3], r"\u{");
364                debug_assert_eq!(&token_str[token_str.len() - 1..], "}");
365                let code_str = &token_str[3..token_str.len() - 1];
366                let code = u32::from_str_radix(code_str, 16).map_err(|_| {
367                    ParseError::custom("invalid unicode escape sequence", pair.as_span())
368                })?;
369                let char = std::char::from_u32(code).ok_or_else(|| {
370                    ParseError::custom("invalid unicode code point", pair.as_span())
371                })?;
372                string.push(char);
373            }
374            _ => unreachable!(),
375        }
376    }
377
378    Ok(string.into())
379}
380
381fn parse_bytes(pair: Pair<Rule>) -> ParseResult<Arc<[u8]>> {
382    debug_assert_eq!(pair.as_rule(), Rule::literal_bytes);
383    let pair = pair.into_inner().next().unwrap();
384    debug_assert_eq!(pair.as_rule(), Rule::base64_string);
385
386    let slice = pair.as_str().as_bytes();
387
388    // Remove the quotes
389    let slice = &slice[1..slice.len() - 1];
390
391    let data = BASE64_STANDARD
392        .decode(slice)
393        .map_err(|_| ParseError::custom("invalid base64 encoding", pair.as_span()))?;
394
395    Ok(data.into())
396}
397
398fn parse_nat(pair: Pair<Rule>) -> ParseResult<u64> {
399    debug_assert_eq!(pair.as_rule(), Rule::literal_nat);
400    let value = pair.as_str().trim().parse().unwrap();
401    Ok(value)
402}
403
404fn parse_float(pair: Pair<Rule>) -> ParseResult<OrderedFloat<f64>> {
405    debug_assert_eq!(pair.as_rule(), Rule::literal_float);
406    let value = pair.as_str().trim().parse().unwrap();
407    Ok(OrderedFloat(value))
408}
409
410fn take_rule<'a, 'i>(
411    pairs: &'i mut Pairs<'a, Rule>,
412    rule: Rule,
413) -> impl Iterator<Item = Pair<'a, Rule>> + 'i {
414    std::iter::from_fn(move || {
415        if pairs.peek()?.as_rule() == rule {
416            pairs.next()
417        } else {
418            None
419        }
420    })
421}
422
423type ParseResult<T> = Result<T, ParseError>;
424
425/// An error that occurred during parsing.
426#[derive(Debug, Clone, Error)]
427#[error("{0}")]
428pub struct ParseError(Box<pest::error::Error<Rule>>);
429
430impl ParseError {
431    fn custom(message: &str, span: pest::Span) -> Self {
432        let error = pest::error::Error::new_from_span(
433            pest::error::ErrorVariant::CustomError {
434                message: message.to_string(),
435            },
436            span,
437        );
438        ParseError(Box::new(error))
439    }
440}
441
/// Implements `FromStr` for an AST type.
///
/// The generated impl parses the input with the grammar rule `$rule`, then
/// converts the first resulting pair with `$parse`. Grammar failures are
/// wrapped into a [`ParseError`].
///
/// NOTE(review): whether trailing input after the parsed item is rejected
/// depends on whether `$rule` ends in `EOI` in `hugr.pest` — confirm.
macro_rules! impl_from_str {
    ($target:ident, $rule:ident, $parse:expr) => {
        impl FromStr for $target {
            type Err = ParseError;

            fn from_str(s: &str) -> Result<Self, Self::Err> {
                let mut pairs = match HugrParser::parse(Rule::$rule, s) {
                    Ok(pairs) => pairs,
                    Err(err) => return Err(ParseError(Box::new(err))),
                };
                let top = pairs.next().unwrap();
                $parse(top)
            }
        }
    };
}
455
456impl_from_str!(SymbolName, symbol_name, parse_symbol_name);
457impl_from_str!(VarName, term_var, parse_var_name);
458impl_from_str!(LinkName, link_name, parse_link_name);
459impl_from_str!(Term, term, parse_term);
460impl_from_str!(Node, node, parse_node);
461impl_from_str!(Region, region, parse_region);
462impl_from_str!(Param, param, parse_param);
463impl_from_str!(Package, package, parse_package);
464impl_from_str!(Module, module, parse_module);
465impl_from_str!(SeqPart, part, parse_seq_part);
466impl_from_str!(Literal, literal, parse_literal);
467impl_from_str!(Symbol, symbol, parse_symbol);