// hugr_model/v0/ast/parse.rs

// NOTE: We use the `pest` library for parsing. This library is convenient, but
// its performance is mediocre. If parsing turns out to be too slow, we can
// replace the parser.
4
// NOTE: The `pest` library returns a parse tree (`Pair`s) which we then transform
// into our AST data structures. The `pest` tree is guaranteed to conform to the
// grammar, but this is not visible from the types. Therefore this module contains
// many `unwrap`s and `unreachable!`s, which will not fail on any input unless
// there is a bug in the grammar or in this module. This is perhaps aesthetically
// unsatisfying, but it is aligned with the intended usage pattern of `pest`.
// Conversions that can fail even on grammar-conforming input, such as decoding
// unicode escapes or base64 data, return a `ParseError` instead.
11
// NOTE: The `parse_*` functions are implementation details because their
// signatures mention `pest` data structures. Parsing is exposed publicly via
// implementations of the `FromStr` trait.
15
16use std::str::FromStr;
17use std::sync::Arc;
18
19use base64::prelude::BASE64_STANDARD;
20use base64::Engine as _;
21use ordered_float::OrderedFloat;
22use pest::iterators::{Pair, Pairs};
23use pest::Parser as _;
24use pest_parser::{HugrParser, Rule};
25use smol_str::SmolStr;
26use thiserror::Error;
27
28use crate::v0::ast::{LinkName, Module, Operation, SeqPart};
29use crate::v0::{Literal, RegionKind};
30
31use super::{Node, Param, Region, Symbol, VarName};
32use super::{SymbolName, Term};
33
34mod pest_parser {
35    use pest_derive::Parser;
36
37    // NOTE: The pest derive macro generates a `Rule` enum. We do not want this to be
38    // part of the public API, and so we hide it within this private module.
39
40    #[derive(Parser)]
41    #[grammar = "v0/ast/hugr.pest"]
42    pub struct HugrParser;
43}
44
45fn parse_symbol_name(pair: Pair<Rule>) -> ParseResult<SymbolName> {
46    debug_assert_eq!(Rule::symbol_name, pair.as_rule());
47    Ok(SymbolName(pair.as_str().into()))
48}
49
50fn parse_var_name(pair: Pair<Rule>) -> ParseResult<VarName> {
51    debug_assert_eq!(Rule::term_var, pair.as_rule());
52    Ok(VarName(pair.as_str()[1..].into()))
53}
54
55fn parse_link_name(pair: Pair<Rule>) -> ParseResult<LinkName> {
56    debug_assert_eq!(Rule::link_name, pair.as_rule());
57    Ok(LinkName(pair.as_str()[1..].into()))
58}
59
60fn parse_term(pair: Pair<Rule>) -> ParseResult<Term> {
61    debug_assert_eq!(Rule::term, pair.as_rule());
62    let pair = pair.into_inner().next().unwrap();
63
64    Ok(match pair.as_rule() {
65        Rule::term_wildcard => Term::Wildcard,
66        Rule::term_var => Term::Var(parse_var_name(pair)?),
67        Rule::term_apply => {
68            let mut pairs = pair.into_inner();
69            let symbol = parse_symbol_name(pairs.next().unwrap())?;
70            let terms = pairs.map(parse_term).collect::<ParseResult<_>>()?;
71            Term::Apply(symbol, terms)
72        }
73        Rule::term_list => {
74            let pairs = pair.into_inner();
75            let parts = pairs.map(parse_seq_part).collect::<ParseResult<_>>()?;
76            Term::List(parts)
77        }
78        Rule::term_tuple => {
79            let pairs = pair.into_inner();
80            let parts = pairs.map(parse_seq_part).collect::<ParseResult<_>>()?;
81            Term::Tuple(parts)
82        }
83        Rule::term_ext_set => Term::ExtSet,
84        Rule::literal => {
85            let literal = parse_literal(pair)?;
86            Term::Literal(literal)
87        }
88        Rule::term_const_func => {
89            let mut pairs = pair.into_inner();
90            let region = parse_region(pairs.next().unwrap())?;
91            Term::Func(Arc::new(region))
92        }
93        _ => unreachable!(),
94    })
95}
96
97fn parse_literal(pair: Pair<Rule>) -> ParseResult<Literal> {
98    debug_assert_eq!(pair.as_rule(), Rule::literal);
99    let pair = pair.into_inner().next().unwrap();
100
101    Ok(match pair.as_rule() {
102        Rule::literal_string => Literal::Str(parse_string(pair)?),
103        Rule::literal_nat => Literal::Nat(parse_nat(pair)?),
104        Rule::literal_bytes => Literal::Bytes(parse_bytes(pair)?),
105        Rule::literal_float => Literal::Float(parse_float(pair)?),
106        _ => unreachable!("expected literal"),
107    })
108}
109
110fn parse_seq_part(pair: Pair<Rule>) -> ParseResult<SeqPart> {
111    debug_assert_eq!(pair.as_rule(), Rule::part);
112    let pair = pair.into_inner().next().unwrap();
113
114    Ok(match pair.as_rule() {
115        Rule::term => SeqPart::Item(parse_term(pair)?),
116        Rule::spliced_term => {
117            let mut pairs = pair.into_inner();
118            let term = parse_term(pairs.next().unwrap())?;
119            SeqPart::Splice(term)
120        }
121        _ => unreachable!("expected term or spliced term"),
122    })
123}
124
125fn parse_module(pair: Pair<Rule>) -> ParseResult<Module> {
126    debug_assert_eq!(pair.as_rule(), Rule::module);
127    let mut pairs = pair.into_inner();
128    let meta = parse_meta_items(&mut pairs)?;
129    let children = parse_nodes(&mut pairs)?;
130
131    Ok(Module {
132        root: Region {
133            kind: RegionKind::Module,
134            children,
135            meta,
136            ..Default::default()
137        },
138    })
139}
140
141fn parse_region(pair: Pair<Rule>) -> ParseResult<Region> {
142    debug_assert_eq!(pair.as_rule(), Rule::region);
143    let mut pairs = pair.into_inner();
144
145    let kind = parse_region_kind(pairs.next().unwrap())?;
146    let sources = parse_port_list(&mut pairs)?;
147    let targets = parse_port_list(&mut pairs)?;
148    let signature = parse_optional_signature(&mut pairs)?;
149    let meta = parse_meta_items(&mut pairs)?;
150    let children = parse_nodes(&mut pairs)?;
151
152    Ok(Region {
153        kind,
154        sources,
155        targets,
156        signature,
157        meta,
158        children,
159    })
160}
161
162fn parse_region_kind(pair: Pair<Rule>) -> ParseResult<RegionKind> {
163    debug_assert_eq!(pair.as_rule(), Rule::region_kind);
164
165    Ok(match pair.as_str() {
166        "dfg" => RegionKind::DataFlow,
167        "cfg" => RegionKind::ControlFlow,
168        "mod" => RegionKind::Module,
169        _ => unreachable!(),
170    })
171}
172
173fn parse_nodes(pairs: &mut Pairs<Rule>) -> ParseResult<Box<[Node]>> {
174    take_rule(pairs, Rule::node).map(parse_node).collect()
175}
176
177fn parse_node(pair: Pair<Rule>) -> ParseResult<Node> {
178    debug_assert_eq!(pair.as_rule(), Rule::node);
179    let mut pairs = pair.into_inner();
180    let pair = pairs.next().unwrap();
181    let rule = pair.as_rule();
182    let mut pairs = pair.into_inner();
183
184    let operation = match rule {
185        Rule::node_dfg => Operation::Dfg,
186        Rule::node_cfg => Operation::Cfg,
187        Rule::node_block => Operation::Block,
188        Rule::node_tail_loop => Operation::TailLoop,
189        Rule::node_cond => Operation::Conditional,
190
191        Rule::node_import => {
192            let name = parse_symbol_name(pairs.next().unwrap())?;
193            Operation::Import(name)
194        }
195
196        Rule::node_custom => {
197            let term = parse_term(pairs.next().unwrap())?;
198            Operation::Custom(term)
199        }
200
201        Rule::node_define_func => {
202            let symbol = parse_symbol(pairs.next().unwrap())?;
203            Operation::DefineFunc(Box::new(symbol))
204        }
205        Rule::node_declare_func => {
206            let symbol = parse_symbol(pairs.next().unwrap())?;
207            Operation::DeclareFunc(Box::new(symbol))
208        }
209        Rule::node_define_alias => {
210            let symbol = parse_symbol(pairs.next().unwrap())?;
211            let value = parse_term(pairs.next().unwrap())?;
212            Operation::DefineAlias(Box::new(symbol), value)
213        }
214        Rule::node_declare_alias => {
215            let symbol = parse_symbol(pairs.next().unwrap())?;
216            Operation::DeclareAlias(Box::new(symbol))
217        }
218        Rule::node_declare_ctr => {
219            let symbol = parse_symbol(pairs.next().unwrap())?;
220            Operation::DeclareConstructor(Box::new(symbol))
221        }
222        Rule::node_declare_operation => {
223            let symbol = parse_symbol(pairs.next().unwrap())?;
224            Operation::DeclareOperation(Box::new(symbol))
225        }
226
227        _ => unreachable!(),
228    };
229
230    let inputs = parse_port_list(&mut pairs)?;
231    let outputs = parse_port_list(&mut pairs)?;
232    let signature = parse_optional_signature(&mut pairs)?;
233    let meta = parse_meta_items(&mut pairs)?;
234    let regions = pairs
235        .map(|pair| parse_region(pair))
236        .collect::<ParseResult<_>>()?;
237
238    Ok(Node {
239        operation,
240        inputs,
241        outputs,
242        regions,
243        meta,
244        signature,
245    })
246}
247
248fn parse_meta_items(pairs: &mut Pairs<Rule>) -> ParseResult<Box<[Term]>> {
249    take_rule(pairs, Rule::meta).map(parse_meta_item).collect()
250}
251
252fn parse_meta_item(pair: Pair<Rule>) -> ParseResult<Term> {
253    debug_assert_eq!(pair.as_rule(), Rule::meta);
254    let mut pairs = pair.into_inner();
255    parse_term(pairs.next().unwrap())
256}
257
258fn parse_optional_signature(pairs: &mut Pairs<Rule>) -> ParseResult<Option<Term>> {
259    if let Some(pair) = take_rule(pairs, Rule::signature).next() {
260        Ok(Some(parse_signature(pair)?))
261    } else {
262        Ok(None)
263    }
264}
265
266fn parse_signature(pair: Pair<Rule>) -> ParseResult<Term> {
267    debug_assert_eq!(Rule::signature, pair.as_rule());
268    let mut pairs = pair.into_inner();
269    parse_term(pairs.next().unwrap())
270}
271
272fn parse_params(pairs: &mut Pairs<Rule>) -> ParseResult<Box<[Param]>> {
273    take_rule(pairs, Rule::param).map(parse_param).collect()
274}
275
276fn parse_param(pair: Pair<Rule>) -> ParseResult<Param> {
277    debug_assert_eq!(Rule::param, pair.as_rule());
278    let mut pairs = pair.into_inner();
279    let name = parse_var_name(pairs.next().unwrap())?;
280    let r#type = parse_term(pairs.next().unwrap())?;
281    Ok(Param { name, r#type })
282}
283
284fn parse_symbol(pair: Pair<Rule>) -> ParseResult<Symbol> {
285    debug_assert_eq!(Rule::symbol, pair.as_rule());
286    let mut pairs = pair.into_inner();
287    let name = parse_symbol_name(pairs.next().unwrap())?;
288    let params = parse_params(&mut pairs)?;
289    let constraints = parse_constraints(&mut pairs)?;
290    let signature = parse_term(pairs.next().unwrap())?;
291
292    Ok(Symbol {
293        name,
294        params,
295        constraints,
296        signature,
297    })
298}
299
300fn parse_constraints(pairs: &mut Pairs<Rule>) -> ParseResult<Box<[Term]>> {
301    take_rule(pairs, Rule::where_clause)
302        .map(parse_constraint)
303        .collect()
304}
305
306fn parse_constraint(pair: Pair<Rule>) -> ParseResult<Term> {
307    debug_assert_eq!(Rule::where_clause, pair.as_rule());
308    let mut pairs = pair.into_inner();
309    parse_term(pairs.next().unwrap())
310}
311
312fn parse_port_list(pairs: &mut Pairs<Rule>) -> ParseResult<Box<[LinkName]>> {
313    let Some(pair) = take_rule(pairs, Rule::port_list).next() else {
314        return Ok(Default::default());
315    };
316
317    let pairs = pair.into_inner();
318    pairs.map(parse_link_name).collect()
319}
320
321fn parse_string(pair: Pair<Rule>) -> ParseResult<SmolStr> {
322    debug_assert_eq!(pair.as_rule(), Rule::literal_string);
323
324    // Any escape sequence is longer than the character it represents.
325    // Therefore the length of this token (minus 2 for the quotes on either
326    // side) is an upper bound for the length of the string.
327    let capacity = pair.as_str().len() - 2;
328    let mut string = String::with_capacity(capacity);
329    let pairs = pair.into_inner();
330
331    for pair in pairs {
332        match pair.as_rule() {
333            Rule::literal_string_raw => string.push_str(pair.as_str()),
334            Rule::literal_string_escape => match pair.as_str().chars().nth(1).unwrap() {
335                '"' => string.push('"'),
336                '\\' => string.push('\\'),
337                'n' => string.push('\n'),
338                'r' => string.push('\r'),
339                't' => string.push('\t'),
340                _ => unreachable!(),
341            },
342            Rule::literal_string_unicode => {
343                let token_str = pair.as_str();
344                debug_assert_eq!(&token_str[0..3], r"\u{");
345                debug_assert_eq!(&token_str[token_str.len() - 1..], "}");
346                let code_str = &token_str[3..token_str.len() - 1];
347                let code = u32::from_str_radix(code_str, 16).map_err(|_| {
348                    ParseError::custom("invalid unicode escape sequence", pair.as_span())
349                })?;
350                let char = std::char::from_u32(code).ok_or_else(|| {
351                    ParseError::custom("invalid unicode code point", pair.as_span())
352                })?;
353                string.push(char);
354            }
355            _ => unreachable!(),
356        }
357    }
358
359    Ok(string.into())
360}
361
362fn parse_bytes(pair: Pair<Rule>) -> ParseResult<Arc<[u8]>> {
363    debug_assert_eq!(pair.as_rule(), Rule::literal_bytes);
364    let pair = pair.into_inner().next().unwrap();
365    debug_assert_eq!(pair.as_rule(), Rule::base64_string);
366
367    let slice = pair.as_str().as_bytes();
368
369    // Remove the quotes
370    let slice = &slice[1..slice.len() - 1];
371
372    let data = BASE64_STANDARD
373        .decode(slice)
374        .map_err(|_| ParseError::custom("invalid base64 encoding", pair.as_span()))?;
375
376    Ok(data.into())
377}
378
379fn parse_nat(pair: Pair<Rule>) -> ParseResult<u64> {
380    debug_assert_eq!(pair.as_rule(), Rule::literal_nat);
381    let value = pair.as_str().trim().parse().unwrap();
382    Ok(value)
383}
384
385fn parse_float(pair: Pair<Rule>) -> ParseResult<OrderedFloat<f64>> {
386    debug_assert_eq!(pair.as_rule(), Rule::literal_float);
387    let value = pair.as_str().trim().parse().unwrap();
388    Ok(OrderedFloat(value))
389}
390
391fn take_rule<'a, 'i>(
392    pairs: &'i mut Pairs<'a, Rule>,
393    rule: Rule,
394) -> impl Iterator<Item = Pair<'a, Rule>> + 'i {
395    std::iter::from_fn(move || {
396        if pairs.peek()?.as_rule() == rule {
397            pairs.next()
398        } else {
399            None
400        }
401    })
402}
403
404type ParseResult<T> = Result<T, ParseError>;
405
406/// An error that occurred during parsing.
407#[derive(Debug, Clone, Error)]
408#[error("{0}")]
409pub struct ParseError(Box<pest::error::Error<Rule>>);
410
411impl ParseError {
412    fn custom(message: &str, span: pest::Span) -> Self {
413        let error = pest::error::Error::new_from_span(
414            pest::error::ErrorVariant::CustomError {
415                message: message.to_string(),
416            },
417            span,
418        );
419        ParseError(Box::new(error))
420    }
421}
422
423macro_rules! impl_from_str {
424    ($ident:ident, $rule:ident, $parse:expr) => {
425        impl FromStr for $ident {
426            type Err = ParseError;
427
428            fn from_str(s: &str) -> Result<Self, Self::Err> {
429                let mut pairs =
430                    HugrParser::parse(Rule::$rule, s).map_err(|err| ParseError(Box::new(err)))?;
431                $parse(pairs.next().unwrap())
432            }
433        }
434    };
435}
436
437impl_from_str!(SymbolName, symbol_name, parse_symbol_name);
438impl_from_str!(VarName, term_var, parse_var_name);
439impl_from_str!(LinkName, link_name, parse_link_name);
440impl_from_str!(Term, term, parse_term);
441impl_from_str!(Node, node, parse_node);
442impl_from_str!(Region, region, parse_region);
443impl_from_str!(Param, param, parse_param);
444impl_from_str!(Module, module, parse_module);
445impl_from_str!(SeqPart, part, parse_seq_part);
446impl_from_str!(Literal, literal, parse_literal);