Skip to main content

hugr_model/v0/ast/
parse.rs

// NOTE: We use the `pest` library for parsing. This library is convenient, but
// its performance is mediocre. If parsing turns out to be too slow, we can
// replace the parser.

// NOTE: The `pest` library returns a parse tree which we then transform into
// our AST data structures. The `pest` tree is guaranteed to conform to the
// grammar, but this is not visible in the types. Therefore this module contains
// many `unwrap`s and `unreachable!`s, which will not fail on any input unless
// there is a bug in the parser or grammar. This is perhaps aesthetically
// unsatisfying, but it is aligned with the intended usage pattern of `pest`.

// NOTE: The `parse_` functions are implementation details since they refer to
// `pest` data structures. We expose parsing via implementations of the
// `FromStr` trait.
15
16use std::str::FromStr;
17use std::sync::Arc;
18
19use base64::Engine as _;
20use base64::prelude::BASE64_STANDARD;
21use ordered_float::OrderedFloat;
22use pest::Parser as _;
23use pest::iterators::{Pair, Pairs};
24use pest_parser::{HugrParser, Rule};
25use smol_str::SmolStr;
26use thiserror::Error;
27
28use crate::v0::ast::{LinkName, Module, Operation, SeqPart, SymbolIdent};
29use crate::v0::{Literal, RegionKind};
30
31use super::{Node, Package, Param, Region, Symbol, VarName, Visibility};
32use super::{SymbolName, Term};
33
34mod pest_parser {
35    use pest_derive::Parser;
36
37    // NOTE: The pest derive macro generates a `Rule` enum. We do not want this to be
38    // part of the public API, and so we hide it within this private module.
39
40    #[derive(Parser)]
41    #[grammar = "v0/ast/hugr.pest"]
42    pub struct HugrParser;
43}
44
45/// Return whether a symbol name matches the bare-symbol syntax.
46///
47/// The grammar rule can match a prefix of the input, so this helper checks that
48/// the parsed token consumes the full name.
49pub(super) fn is_bare_symbol_name(name: &str) -> bool {
50    HugrParser::parse(Rule::bare_symbol_name, name)
51        .ok()
52        .and_then(|mut pairs| pairs.next())
53        .is_some_and(|pair| pair.as_span().end() == name.len())
54}
55
56fn parse_symbol_name(pair: Pair<Rule>) -> ParseResult<SymbolName> {
57    Ok(match pair.as_rule() {
58        Rule::symbol_name => parse_symbol_name(pair.into_inner().next().unwrap())?,
59        Rule::bare_symbol_name => SymbolName(pair.as_str().into()),
60        Rule::raw_symbol_name => SymbolName(parse_raw_symbol_name(pair)),
61        _ => unreachable!("expected symbol name"),
62    })
63}
64
65fn parse_version(pair: Pair<Rule>) -> ParseResult<semver::Version> {
66    debug_assert_eq!(Rule::version, pair.as_rule());
67    pair.as_str()
68        .parse()
69        .map_err(|_| ParseError::custom("invalid semver version", pair.as_span()))
70}
71
72fn parse_symbol_ident(pair: Pair<Rule>) -> ParseResult<SymbolIdent> {
73    debug_assert_eq!(Rule::symbol_ident, pair.as_rule());
74    let mut pairs = pair.into_inner();
75    let name = parse_symbol_name(pairs.next().unwrap())?;
76    let version = pairs.next().map(parse_version).transpose()?;
77    Ok(SymbolIdent { name, version })
78}
79
80fn parse_var_name(pair: Pair<Rule>) -> ParseResult<VarName> {
81    debug_assert_eq!(Rule::term_var, pair.as_rule());
82    Ok(VarName(pair.as_str()[1..].into()))
83}
84
85fn parse_link_name(pair: Pair<Rule>) -> ParseResult<LinkName> {
86    debug_assert_eq!(Rule::link_name, pair.as_rule());
87    Ok(LinkName(pair.as_str()[1..].into()))
88}
89
90fn parse_term(pair: Pair<Rule>) -> ParseResult<Term> {
91    debug_assert_eq!(Rule::term, pair.as_rule());
92    let pair = pair.into_inner().next().unwrap();
93
94    Ok(match pair.as_rule() {
95        Rule::term_wildcard => Term::Wildcard,
96        Rule::term_var => Term::Var(parse_var_name(pair)?),
97        Rule::term_apply => {
98            let mut pairs = pair.into_inner();
99            let symbol = parse_symbol_ident(pairs.next().unwrap())?;
100            let terms = pairs.map(parse_term).collect::<ParseResult<_>>()?;
101            Term::Apply(symbol, terms)
102        }
103        Rule::term_list => {
104            let pairs = pair.into_inner();
105            let parts = pairs.map(parse_seq_part).collect::<ParseResult<_>>()?;
106            Term::List(parts)
107        }
108        Rule::term_tuple => {
109            let pairs = pair.into_inner();
110            let parts = pairs.map(parse_seq_part).collect::<ParseResult<_>>()?;
111            Term::Tuple(parts)
112        }
113        Rule::literal => {
114            let literal = parse_literal(pair)?;
115            Term::Literal(literal)
116        }
117        Rule::term_const_func => {
118            let mut pairs = pair.into_inner();
119            let region = parse_region(pairs.next().unwrap())?;
120            Term::Func(Arc::new(region))
121        }
122        _ => unreachable!(),
123    })
124}
125
126fn parse_literal(pair: Pair<Rule>) -> ParseResult<Literal> {
127    debug_assert_eq!(pair.as_rule(), Rule::literal);
128    let pair = pair.into_inner().next().unwrap();
129
130    Ok(match pair.as_rule() {
131        Rule::literal_string => Literal::Str(parse_string(pair)?),
132        Rule::literal_nat => Literal::Nat(parse_nat(pair)?),
133        Rule::literal_bytes => Literal::Bytes(parse_bytes(pair)?),
134        Rule::literal_float => Literal::Float(parse_float(pair)?),
135        _ => unreachable!("expected literal"),
136    })
137}
138
139fn parse_seq_part(pair: Pair<Rule>) -> ParseResult<SeqPart> {
140    debug_assert_eq!(pair.as_rule(), Rule::part);
141    let pair = pair.into_inner().next().unwrap();
142
143    Ok(match pair.as_rule() {
144        Rule::term => SeqPart::Item(parse_term(pair)?),
145        Rule::spliced_term => {
146            let mut pairs = pair.into_inner();
147            let term = parse_term(pairs.next().unwrap())?;
148            SeqPart::Splice(term)
149        }
150        _ => unreachable!("expected term or spliced term"),
151    })
152}
153
154fn parse_package(pair: Pair<Rule>) -> ParseResult<Package> {
155    debug_assert_eq!(pair.as_rule(), Rule::package);
156    let mut pairs = pair.into_inner();
157
158    let modules = take_rule(&mut pairs, Rule::module)
159        .map(parse_module)
160        .collect::<ParseResult<_>>()?;
161
162    Ok(Package { modules })
163}
164
165fn parse_module(pair: Pair<Rule>) -> ParseResult<Module> {
166    debug_assert_eq!(pair.as_rule(), Rule::module);
167    let mut pairs = pair.into_inner();
168    let meta = parse_meta_items(&mut pairs)?;
169    let children = parse_nodes(&mut pairs)?;
170
171    Ok(Module {
172        root: Region {
173            kind: RegionKind::Module,
174            children,
175            meta,
176            ..Default::default()
177        },
178    })
179}
180
181fn parse_region(pair: Pair<Rule>) -> ParseResult<Region> {
182    debug_assert_eq!(pair.as_rule(), Rule::region);
183    let mut pairs = pair.into_inner();
184
185    let kind = parse_region_kind(pairs.next().unwrap())?;
186    let sources = parse_port_list(&mut pairs)?;
187    let targets = parse_port_list(&mut pairs)?;
188    let signature = parse_optional_signature(&mut pairs)?;
189    let meta = parse_meta_items(&mut pairs)?;
190    let children = parse_nodes(&mut pairs)?;
191
192    Ok(Region {
193        kind,
194        sources,
195        targets,
196        children,
197        meta,
198        signature,
199    })
200}
201
202fn parse_region_kind(pair: Pair<Rule>) -> ParseResult<RegionKind> {
203    debug_assert_eq!(pair.as_rule(), Rule::region_kind);
204
205    Ok(match pair.as_str() {
206        "dfg" => RegionKind::DataFlow,
207        "cfg" => RegionKind::ControlFlow,
208        "mod" => RegionKind::Module,
209        _ => unreachable!(),
210    })
211}
212
213fn parse_nodes(pairs: &mut Pairs<Rule>) -> ParseResult<Box<[Node]>> {
214    take_rule(pairs, Rule::node).map(parse_node).collect()
215}
216
217fn parse_node(pair: Pair<Rule>) -> ParseResult<Node> {
218    debug_assert_eq!(pair.as_rule(), Rule::node);
219    let mut pairs = pair.into_inner();
220    let pair = pairs.next().unwrap();
221    let rule = pair.as_rule();
222    let mut pairs = pair.into_inner();
223
224    let operation = match rule {
225        Rule::node_dfg => Operation::Dfg,
226        Rule::node_cfg => Operation::Cfg,
227        Rule::node_block => Operation::Block,
228        Rule::node_tail_loop => Operation::TailLoop,
229        Rule::node_cond => Operation::Conditional,
230
231        Rule::node_import => {
232            let symbol_ident = parse_symbol_ident(pairs.next().unwrap())?;
233            Operation::Import(symbol_ident)
234        }
235
236        Rule::node_custom => {
237            let term = parse_term(pairs.next().unwrap())?;
238            Operation::Custom(term)
239        }
240
241        Rule::node_define_func => {
242            let symbol = parse_symbol(pairs.next().unwrap())?;
243            Operation::DefineFunc(Box::new(symbol))
244        }
245        Rule::node_declare_func => {
246            let symbol = parse_symbol(pairs.next().unwrap())?;
247            Operation::DeclareFunc(Box::new(symbol))
248        }
249        Rule::node_define_alias => {
250            let symbol = parse_symbol(pairs.next().unwrap())?;
251            let value = parse_term(pairs.next().unwrap())?;
252            Operation::DefineAlias(Box::new(symbol), value)
253        }
254        Rule::node_declare_alias => {
255            let symbol = parse_symbol(pairs.next().unwrap())?;
256            Operation::DeclareAlias(Box::new(symbol))
257        }
258        Rule::node_declare_ctr => {
259            let symbol = parse_symbol(pairs.next().unwrap())?;
260            Operation::DeclareConstructor(Box::new(symbol))
261        }
262        Rule::node_declare_operation => {
263            let symbol = parse_symbol(pairs.next().unwrap())?;
264            Operation::DeclareOperation(Box::new(symbol))
265        }
266
267        _ => unreachable!(),
268    };
269
270    let inputs = parse_port_list(&mut pairs)?;
271    let outputs = parse_port_list(&mut pairs)?;
272    let signature = parse_optional_signature(&mut pairs)?;
273    let meta = parse_meta_items(&mut pairs)?;
274    let regions = pairs
275        .map(|pair| parse_region(pair))
276        .collect::<ParseResult<_>>()?;
277
278    Ok(Node {
279        operation,
280        inputs,
281        outputs,
282        regions,
283        meta,
284        signature,
285    })
286}
287
288fn parse_meta_items(pairs: &mut Pairs<Rule>) -> ParseResult<Box<[Term]>> {
289    take_rule(pairs, Rule::meta).map(parse_meta_item).collect()
290}
291
292fn parse_meta_item(pair: Pair<Rule>) -> ParseResult<Term> {
293    debug_assert_eq!(pair.as_rule(), Rule::meta);
294    let mut pairs = pair.into_inner();
295    parse_term(pairs.next().unwrap())
296}
297
298fn parse_optional_signature(pairs: &mut Pairs<Rule>) -> ParseResult<Option<Term>> {
299    match take_rule(pairs, Rule::signature).next() {
300        Some(pair) => Ok(Some(parse_signature(pair)?)),
301        _ => Ok(None),
302    }
303}
304
305fn parse_signature(pair: Pair<Rule>) -> ParseResult<Term> {
306    debug_assert_eq!(Rule::signature, pair.as_rule());
307    let mut pairs = pair.into_inner();
308    parse_term(pairs.next().unwrap())
309}
310
311fn parse_params(pairs: &mut Pairs<Rule>) -> ParseResult<Box<[Param]>> {
312    take_rule(pairs, Rule::param).map(parse_param).collect()
313}
314
315fn parse_param(pair: Pair<Rule>) -> ParseResult<Param> {
316    debug_assert_eq!(Rule::param, pair.as_rule());
317    let mut pairs = pair.into_inner();
318    let name = parse_var_name(pairs.next().unwrap())?;
319    let r#type = parse_term(pairs.next().unwrap())?;
320    Ok(Param { name, r#type })
321}
322
323fn parse_symbol(pair: Pair<Rule>) -> ParseResult<Symbol> {
324    debug_assert_eq!(Rule::symbol, pair.as_rule());
325
326    let mut pairs = pair.into_inner();
327    let visibility = take_rule(&mut pairs, Rule::visibility)
328        .next()
329        .map(|pair| match pair.as_str() {
330            "public" => Ok(Visibility::Public),
331            "private" => Ok(Visibility::Private),
332            _ => unreachable!("Expected 'public' or 'private', got {}", pair.as_str()),
333        })
334        .transpose()?;
335    let SymbolIdent { name, version } = parse_symbol_ident(pairs.next().unwrap())?;
336    let params = parse_params(&mut pairs)?;
337    let constraints = parse_constraints(&mut pairs)?;
338    let signature = parse_term(pairs.next().unwrap())?;
339
340    Ok(Symbol {
341        visibility,
342        name,
343        version,
344        params,
345        constraints,
346        signature,
347    })
348}
349
350fn parse_constraints(pairs: &mut Pairs<Rule>) -> ParseResult<Box<[Term]>> {
351    take_rule(pairs, Rule::where_clause)
352        .map(parse_constraint)
353        .collect()
354}
355
356fn parse_constraint(pair: Pair<Rule>) -> ParseResult<Term> {
357    debug_assert_eq!(Rule::where_clause, pair.as_rule());
358    let mut pairs = pair.into_inner();
359    parse_term(pairs.next().unwrap())
360}
361
362fn parse_port_list(pairs: &mut Pairs<Rule>) -> ParseResult<Box<[LinkName]>> {
363    let Some(pair) = take_rule(pairs, Rule::port_list).next() else {
364        return Ok(Default::default());
365    };
366
367    let pairs = pair.into_inner();
368    pairs.map(parse_link_name).collect()
369}
370
371fn parse_string(pair: Pair<Rule>) -> ParseResult<SmolStr> {
372    debug_assert_eq!(pair.as_rule(), Rule::literal_string);
373
374    // Any escape sequence is longer than the character it represents.
375    // Therefore the length of this token (minus 2 for the quotes on either
376    // side) is an upper bound for the length of the string.
377    let capacity = pair.as_str().len() - 2;
378    let mut string = String::with_capacity(capacity);
379    let pairs = pair.into_inner();
380
381    for pair in pairs {
382        match pair.as_rule() {
383            Rule::literal_string_raw => string.push_str(pair.as_str()),
384            Rule::literal_string_escape => match pair.as_str().chars().nth(1).unwrap() {
385                '"' => string.push('"'),
386                '\\' => string.push('\\'),
387                'n' => string.push('\n'),
388                'r' => string.push('\r'),
389                't' => string.push('\t'),
390                _ => unreachable!(),
391            },
392            Rule::literal_string_unicode => {
393                let token_str = pair.as_str();
394                debug_assert_eq!(&token_str[0..3], r"\u{");
395                debug_assert_eq!(&token_str[token_str.len() - 1..], "}");
396                let code_str = &token_str[3..token_str.len() - 1];
397                let code = u32::from_str_radix(code_str, 16).map_err(|_| {
398                    ParseError::custom("invalid unicode escape sequence", pair.as_span())
399                })?;
400                let char = std::char::from_u32(code).ok_or_else(|| {
401                    ParseError::custom("invalid unicode code point", pair.as_span())
402                })?;
403                string.push(char);
404            }
405            _ => unreachable!(),
406        }
407    }
408
409    Ok(string.into())
410}
411
412fn parse_raw_symbol_name(pair: Pair<Rule>) -> SmolStr {
413    debug_assert_eq!(pair.as_rule(), Rule::raw_symbol_name);
414    let raw = pair.as_str();
415    let Some(quote_index) = raw.find('"') else {
416        unreachable!("raw symbol names always contain an opening quote")
417    };
418    let hashes = &raw[1..quote_index];
419    let content_start = quote_index + 1;
420    let content_end = raw.len() - hashes.len() - 1;
421    raw[content_start..content_end].into()
422}
423
424fn parse_bytes(pair: Pair<Rule>) -> ParseResult<Arc<[u8]>> {
425    debug_assert_eq!(pair.as_rule(), Rule::literal_bytes);
426    let pair = pair.into_inner().next().unwrap();
427    debug_assert_eq!(pair.as_rule(), Rule::base64_string);
428
429    let slice = pair.as_str().as_bytes();
430
431    // Remove the quotes
432    let slice = &slice[1..slice.len() - 1];
433
434    let data = BASE64_STANDARD
435        .decode(slice)
436        .map_err(|_| ParseError::custom("invalid base64 encoding", pair.as_span()))?;
437
438    Ok(data.into())
439}
440
441fn parse_nat(pair: Pair<Rule>) -> ParseResult<u64> {
442    debug_assert_eq!(pair.as_rule(), Rule::literal_nat);
443    let value = pair.as_str().trim().parse().unwrap();
444    Ok(value)
445}
446
447fn parse_float(pair: Pair<Rule>) -> ParseResult<OrderedFloat<f64>> {
448    debug_assert_eq!(pair.as_rule(), Rule::literal_float);
449    let value = pair.as_str().trim().parse().unwrap();
450    Ok(OrderedFloat(value))
451}
452
453fn take_rule<'a, 'i>(
454    pairs: &'i mut Pairs<'a, Rule>,
455    rule: Rule,
456) -> impl Iterator<Item = Pair<'a, Rule>> + 'i {
457    std::iter::from_fn(move || {
458        if pairs.peek()?.as_rule() == rule {
459            pairs.next()
460        } else {
461            None
462        }
463    })
464}
465
466type ParseResult<T> = Result<T, ParseError>;
467
468/// An error that occurred during parsing.
469#[derive(Debug, Clone, Error)]
470#[error("{0}")]
471pub struct ParseError(Box<pest::error::Error<Rule>>);
472
473impl ParseError {
474    fn custom(message: &str, span: pest::Span) -> Self {
475        let error = pest::error::Error::new_from_span(
476            pest::error::ErrorVariant::CustomError {
477                message: message.to_string(),
478            },
479            span,
480        );
481        ParseError(Box::new(error))
482    }
483}
484
485macro_rules! impl_from_str {
486    ($ident:ident, $rule:ident, $parse:expr) => {
487        impl FromStr for $ident {
488            type Err = ParseError;
489
490            fn from_str(s: &str) -> Result<Self, Self::Err> {
491                let mut pairs =
492                    HugrParser::parse(Rule::$rule, s).map_err(|err| ParseError(Box::new(err)))?;
493                let pair = pairs.next().unwrap();
494                let span = pair.as_span();
495                let end = span.end();
496
497                if !s[end..].trim().is_empty() {
498                    let span = pest::Span::new(s, end, s.len()).unwrap_or(span);
499                    return Err(ParseError::custom("unexpected trailing input", span));
500                }
501
502                $parse(pair)
503            }
504        }
505    };
506}
507
508impl_from_str!(SymbolName, symbol_name, parse_symbol_name);
509impl_from_str!(SymbolIdent, symbol_ident, parse_symbol_ident);
510impl_from_str!(VarName, term_var, parse_var_name);
511impl_from_str!(LinkName, link_name, parse_link_name);
512impl_from_str!(Term, term, parse_term);
513impl_from_str!(Node, node, parse_node);
514impl_from_str!(Region, region, parse_region);
515impl_from_str!(Param, param, parse_param);
516impl_from_str!(Package, package, parse_package);
517impl_from_str!(Module, module, parse_module);
518impl_from_str!(SeqPart, part, parse_seq_part);
519impl_from_str!(Literal, literal, parse_literal);
520impl_from_str!(Symbol, symbol, parse_symbol);