1use std::str::FromStr;
17use std::sync::Arc;
18
19use base64::prelude::BASE64_STANDARD;
20use base64::Engine as _;
21use ordered_float::OrderedFloat;
22use pest::iterators::{Pair, Pairs};
23use pest::Parser as _;
24use pest_parser::{HugrParser, Rule};
25use smol_str::SmolStr;
26use thiserror::Error;
27
28use crate::v0::ast::{LinkName, Module, Operation, SeqPart};
29use crate::v0::{Literal, RegionKind};
30
31use super::{Node, Param, Region, Symbol, VarName};
32use super::{SymbolName, Term};
33
// Private module holding the pest-derived parser. Only `HugrParser` is
// declared here by hand; the parent module imports both `HugrParser` and
// `Rule` from this module.
mod pest_parser {
    use pest_derive::Parser;

    // Parser type derived from the grammar file named in the attribute below.
    #[derive(Parser)]
    #[grammar = "v0/ast/hugr.pest"]
    pub struct HugrParser;
}
44
45fn parse_symbol_name(pair: Pair<Rule>) -> ParseResult<SymbolName> {
46 debug_assert_eq!(Rule::symbol_name, pair.as_rule());
47 Ok(SymbolName(pair.as_str().into()))
48}
49
50fn parse_var_name(pair: Pair<Rule>) -> ParseResult<VarName> {
51 debug_assert_eq!(Rule::term_var, pair.as_rule());
52 Ok(VarName(pair.as_str()[1..].into()))
53}
54
55fn parse_link_name(pair: Pair<Rule>) -> ParseResult<LinkName> {
56 debug_assert_eq!(Rule::link_name, pair.as_rule());
57 Ok(LinkName(pair.as_str()[1..].into()))
58}
59
60fn parse_term(pair: Pair<Rule>) -> ParseResult<Term> {
61 debug_assert_eq!(Rule::term, pair.as_rule());
62 let pair = pair.into_inner().next().unwrap();
63
64 Ok(match pair.as_rule() {
65 Rule::term_wildcard => Term::Wildcard,
66 Rule::term_var => Term::Var(parse_var_name(pair)?),
67 Rule::term_apply => {
68 let mut pairs = pair.into_inner();
69 let symbol = parse_symbol_name(pairs.next().unwrap())?;
70 let terms = pairs.map(parse_term).collect::<ParseResult<_>>()?;
71 Term::Apply(symbol, terms)
72 }
73 Rule::term_list => {
74 let pairs = pair.into_inner();
75 let parts = pairs.map(parse_seq_part).collect::<ParseResult<_>>()?;
76 Term::List(parts)
77 }
78 Rule::term_tuple => {
79 let pairs = pair.into_inner();
80 let parts = pairs.map(parse_seq_part).collect::<ParseResult<_>>()?;
81 Term::Tuple(parts)
82 }
83 Rule::term_ext_set => Term::ExtSet,
84 Rule::literal => {
85 let literal = parse_literal(pair)?;
86 Term::Literal(literal)
87 }
88 Rule::term_const_func => {
89 let mut pairs = pair.into_inner();
90 let region = parse_region(pairs.next().unwrap())?;
91 Term::Func(Arc::new(region))
92 }
93 _ => unreachable!(),
94 })
95}
96
97fn parse_literal(pair: Pair<Rule>) -> ParseResult<Literal> {
98 debug_assert_eq!(pair.as_rule(), Rule::literal);
99 let pair = pair.into_inner().next().unwrap();
100
101 Ok(match pair.as_rule() {
102 Rule::literal_string => Literal::Str(parse_string(pair)?),
103 Rule::literal_nat => Literal::Nat(parse_nat(pair)?),
104 Rule::literal_bytes => Literal::Bytes(parse_bytes(pair)?),
105 Rule::literal_float => Literal::Float(parse_float(pair)?),
106 _ => unreachable!("expected literal"),
107 })
108}
109
110fn parse_seq_part(pair: Pair<Rule>) -> ParseResult<SeqPart> {
111 debug_assert_eq!(pair.as_rule(), Rule::part);
112 let pair = pair.into_inner().next().unwrap();
113
114 Ok(match pair.as_rule() {
115 Rule::term => SeqPart::Item(parse_term(pair)?),
116 Rule::spliced_term => {
117 let mut pairs = pair.into_inner();
118 let term = parse_term(pairs.next().unwrap())?;
119 SeqPart::Splice(term)
120 }
121 _ => unreachable!("expected term or spliced term"),
122 })
123}
124
125fn parse_module(pair: Pair<Rule>) -> ParseResult<Module> {
126 debug_assert_eq!(pair.as_rule(), Rule::module);
127 let mut pairs = pair.into_inner();
128 let meta = parse_meta_items(&mut pairs)?;
129 let children = parse_nodes(&mut pairs)?;
130
131 Ok(Module {
132 root: Region {
133 kind: RegionKind::Module,
134 children,
135 meta,
136 ..Default::default()
137 },
138 })
139}
140
141fn parse_region(pair: Pair<Rule>) -> ParseResult<Region> {
142 debug_assert_eq!(pair.as_rule(), Rule::region);
143 let mut pairs = pair.into_inner();
144
145 let kind = parse_region_kind(pairs.next().unwrap())?;
146 let sources = parse_port_list(&mut pairs)?;
147 let targets = parse_port_list(&mut pairs)?;
148 let signature = parse_optional_signature(&mut pairs)?;
149 let meta = parse_meta_items(&mut pairs)?;
150 let children = parse_nodes(&mut pairs)?;
151
152 Ok(Region {
153 kind,
154 sources,
155 targets,
156 signature,
157 meta,
158 children,
159 })
160}
161
162fn parse_region_kind(pair: Pair<Rule>) -> ParseResult<RegionKind> {
163 debug_assert_eq!(pair.as_rule(), Rule::region_kind);
164
165 Ok(match pair.as_str() {
166 "dfg" => RegionKind::DataFlow,
167 "cfg" => RegionKind::ControlFlow,
168 "mod" => RegionKind::Module,
169 _ => unreachable!(),
170 })
171}
172
173fn parse_nodes(pairs: &mut Pairs<Rule>) -> ParseResult<Box<[Node]>> {
174 take_rule(pairs, Rule::node).map(parse_node).collect()
175}
176
177fn parse_node(pair: Pair<Rule>) -> ParseResult<Node> {
178 debug_assert_eq!(pair.as_rule(), Rule::node);
179 let mut pairs = pair.into_inner();
180 let pair = pairs.next().unwrap();
181 let rule = pair.as_rule();
182 let mut pairs = pair.into_inner();
183
184 let operation = match rule {
185 Rule::node_dfg => Operation::Dfg,
186 Rule::node_cfg => Operation::Cfg,
187 Rule::node_block => Operation::Block,
188 Rule::node_tail_loop => Operation::TailLoop,
189 Rule::node_cond => Operation::Conditional,
190
191 Rule::node_import => {
192 let name = parse_symbol_name(pairs.next().unwrap())?;
193 Operation::Import(name)
194 }
195
196 Rule::node_custom => {
197 let term = parse_term(pairs.next().unwrap())?;
198 Operation::Custom(term)
199 }
200
201 Rule::node_define_func => {
202 let symbol = parse_symbol(pairs.next().unwrap())?;
203 Operation::DefineFunc(Box::new(symbol))
204 }
205 Rule::node_declare_func => {
206 let symbol = parse_symbol(pairs.next().unwrap())?;
207 Operation::DeclareFunc(Box::new(symbol))
208 }
209 Rule::node_define_alias => {
210 let symbol = parse_symbol(pairs.next().unwrap())?;
211 let value = parse_term(pairs.next().unwrap())?;
212 Operation::DefineAlias(Box::new(symbol), value)
213 }
214 Rule::node_declare_alias => {
215 let symbol = parse_symbol(pairs.next().unwrap())?;
216 Operation::DeclareAlias(Box::new(symbol))
217 }
218 Rule::node_declare_ctr => {
219 let symbol = parse_symbol(pairs.next().unwrap())?;
220 Operation::DeclareConstructor(Box::new(symbol))
221 }
222 Rule::node_declare_operation => {
223 let symbol = parse_symbol(pairs.next().unwrap())?;
224 Operation::DeclareOperation(Box::new(symbol))
225 }
226
227 _ => unreachable!(),
228 };
229
230 let inputs = parse_port_list(&mut pairs)?;
231 let outputs = parse_port_list(&mut pairs)?;
232 let signature = parse_optional_signature(&mut pairs)?;
233 let meta = parse_meta_items(&mut pairs)?;
234 let regions = pairs
235 .map(|pair| parse_region(pair))
236 .collect::<ParseResult<_>>()?;
237
238 Ok(Node {
239 operation,
240 inputs,
241 outputs,
242 regions,
243 meta,
244 signature,
245 })
246}
247
248fn parse_meta_items(pairs: &mut Pairs<Rule>) -> ParseResult<Box<[Term]>> {
249 take_rule(pairs, Rule::meta).map(parse_meta_item).collect()
250}
251
252fn parse_meta_item(pair: Pair<Rule>) -> ParseResult<Term> {
253 debug_assert_eq!(pair.as_rule(), Rule::meta);
254 let mut pairs = pair.into_inner();
255 parse_term(pairs.next().unwrap())
256}
257
258fn parse_optional_signature(pairs: &mut Pairs<Rule>) -> ParseResult<Option<Term>> {
259 if let Some(pair) = take_rule(pairs, Rule::signature).next() {
260 Ok(Some(parse_signature(pair)?))
261 } else {
262 Ok(None)
263 }
264}
265
266fn parse_signature(pair: Pair<Rule>) -> ParseResult<Term> {
267 debug_assert_eq!(Rule::signature, pair.as_rule());
268 let mut pairs = pair.into_inner();
269 parse_term(pairs.next().unwrap())
270}
271
272fn parse_params(pairs: &mut Pairs<Rule>) -> ParseResult<Box<[Param]>> {
273 take_rule(pairs, Rule::param).map(parse_param).collect()
274}
275
276fn parse_param(pair: Pair<Rule>) -> ParseResult<Param> {
277 debug_assert_eq!(Rule::param, pair.as_rule());
278 let mut pairs = pair.into_inner();
279 let name = parse_var_name(pairs.next().unwrap())?;
280 let r#type = parse_term(pairs.next().unwrap())?;
281 Ok(Param { name, r#type })
282}
283
284fn parse_symbol(pair: Pair<Rule>) -> ParseResult<Symbol> {
285 debug_assert_eq!(Rule::symbol, pair.as_rule());
286 let mut pairs = pair.into_inner();
287 let name = parse_symbol_name(pairs.next().unwrap())?;
288 let params = parse_params(&mut pairs)?;
289 let constraints = parse_constraints(&mut pairs)?;
290 let signature = parse_term(pairs.next().unwrap())?;
291
292 Ok(Symbol {
293 name,
294 params,
295 constraints,
296 signature,
297 })
298}
299
300fn parse_constraints(pairs: &mut Pairs<Rule>) -> ParseResult<Box<[Term]>> {
301 take_rule(pairs, Rule::where_clause)
302 .map(parse_constraint)
303 .collect()
304}
305
306fn parse_constraint(pair: Pair<Rule>) -> ParseResult<Term> {
307 debug_assert_eq!(Rule::where_clause, pair.as_rule());
308 let mut pairs = pair.into_inner();
309 parse_term(pairs.next().unwrap())
310}
311
312fn parse_port_list(pairs: &mut Pairs<Rule>) -> ParseResult<Box<[LinkName]>> {
313 let Some(pair) = take_rule(pairs, Rule::port_list).next() else {
314 return Ok(Default::default());
315 };
316
317 let pairs = pair.into_inner();
318 pairs.map(parse_link_name).collect()
319}
320
321fn parse_string(pair: Pair<Rule>) -> ParseResult<SmolStr> {
322 debug_assert_eq!(pair.as_rule(), Rule::literal_string);
323
324 let capacity = pair.as_str().len() - 2;
328 let mut string = String::with_capacity(capacity);
329 let pairs = pair.into_inner();
330
331 for pair in pairs {
332 match pair.as_rule() {
333 Rule::literal_string_raw => string.push_str(pair.as_str()),
334 Rule::literal_string_escape => match pair.as_str().chars().nth(1).unwrap() {
335 '"' => string.push('"'),
336 '\\' => string.push('\\'),
337 'n' => string.push('\n'),
338 'r' => string.push('\r'),
339 't' => string.push('\t'),
340 _ => unreachable!(),
341 },
342 Rule::literal_string_unicode => {
343 let token_str = pair.as_str();
344 debug_assert_eq!(&token_str[0..3], r"\u{");
345 debug_assert_eq!(&token_str[token_str.len() - 1..], "}");
346 let code_str = &token_str[3..token_str.len() - 1];
347 let code = u32::from_str_radix(code_str, 16).map_err(|_| {
348 ParseError::custom("invalid unicode escape sequence", pair.as_span())
349 })?;
350 let char = std::char::from_u32(code).ok_or_else(|| {
351 ParseError::custom("invalid unicode code point", pair.as_span())
352 })?;
353 string.push(char);
354 }
355 _ => unreachable!(),
356 }
357 }
358
359 Ok(string.into())
360}
361
362fn parse_bytes(pair: Pair<Rule>) -> ParseResult<Arc<[u8]>> {
363 debug_assert_eq!(pair.as_rule(), Rule::literal_bytes);
364 let pair = pair.into_inner().next().unwrap();
365 debug_assert_eq!(pair.as_rule(), Rule::base64_string);
366
367 let slice = pair.as_str().as_bytes();
368
369 let slice = &slice[1..slice.len() - 1];
371
372 let data = BASE64_STANDARD
373 .decode(slice)
374 .map_err(|_| ParseError::custom("invalid base64 encoding", pair.as_span()))?;
375
376 Ok(data.into())
377}
378
379fn parse_nat(pair: Pair<Rule>) -> ParseResult<u64> {
380 debug_assert_eq!(pair.as_rule(), Rule::literal_nat);
381 let value = pair.as_str().trim().parse().unwrap();
382 Ok(value)
383}
384
385fn parse_float(pair: Pair<Rule>) -> ParseResult<OrderedFloat<f64>> {
386 debug_assert_eq!(pair.as_rule(), Rule::literal_float);
387 let value = pair.as_str().trim().parse().unwrap();
388 Ok(OrderedFloat(value))
389}
390
391fn take_rule<'a, 'i>(
392 pairs: &'i mut Pairs<'a, Rule>,
393 rule: Rule,
394) -> impl Iterator<Item = Pair<'a, Rule>> + 'i {
395 std::iter::from_fn(move || {
396 if pairs.peek()?.as_rule() == rule {
397 pairs.next()
398 } else {
399 None
400 }
401 })
402}
403
/// Result type used by the parsing functions in this file, with
/// [`ParseError`] as the error type.
type ParseResult<T> = Result<T, ParseError>;
405
/// Error returned when parsing fails.
///
/// Wraps a boxed `pest::error::Error<Rule>`; its `Display` output is that of
/// the wrapped error.
#[derive(Debug, Clone, Error)]
#[error("{0}")]
pub struct ParseError(Box<pest::error::Error<Rule>>);
410
411impl ParseError {
412 fn custom(message: &str, span: pest::Span) -> Self {
413 let error = pest::error::Error::new_from_span(
414 pest::error::ErrorVariant::CustomError {
415 message: message.to_string(),
416 },
417 span,
418 );
419 ParseError(Box::new(error))
420 }
421}
422
/// Implements `FromStr` for `$ident` by running `HugrParser` with the grammar
/// rule `$rule` and passing the first resulting pair to `$parse`.
///
/// Errors reported by pest are boxed into a `ParseError`.
macro_rules! impl_from_str {
    ($ident:ident, $rule:ident, $parse:expr) => {
        impl FromStr for $ident {
            type Err = ParseError;

            fn from_str(s: &str) -> Result<Self, Self::Err> {
                match HugrParser::parse(Rule::$rule, s) {
                    Ok(mut parsed) => $parse(parsed.next().unwrap()),
                    Err(err) => Err(ParseError(Box::new(err))),
                }
            }
        }
    };
}
436
// `FromStr` implementations for the AST types. Each line names the type, the
// grammar rule used as the entry point, and the function that converts the
// resulting pair.
impl_from_str!(SymbolName, symbol_name, parse_symbol_name);
impl_from_str!(VarName, term_var, parse_var_name);
impl_from_str!(LinkName, link_name, parse_link_name);
impl_from_str!(Term, term, parse_term);
impl_from_str!(Node, node, parse_node);
impl_from_str!(Region, region, parse_region);
impl_from_str!(Param, param, parse_param);
impl_from_str!(Module, module, parse_module);
impl_from_str!(SeqPart, part, parse_seq_part);
impl_from_str!(Literal, literal, parse_literal);