1use std::str::FromStr;
17use std::sync::Arc;
18
19use base64::Engine as _;
20use base64::prelude::BASE64_STANDARD;
21use ordered_float::OrderedFloat;
22use pest::Parser as _;
23use pest::iterators::{Pair, Pairs};
24use pest_parser::{HugrParser, Rule};
25use smol_str::SmolStr;
26use thiserror::Error;
27
28use crate::v0::ast::{LinkName, Module, Operation, SeqPart, SymbolIdent};
29use crate::v0::{Literal, RegionKind};
30
31use super::{Node, Package, Param, Region, Symbol, VarName, Visibility};
32use super::{SymbolName, Term};
33
// The derive below generates the parser implementation for `HugrParser`; the
// `Rule` enum that the rest of this file imports from this module is produced
// by the same derive. Wrapping both in a private module presumably keeps the
// generated items out of the public API — TODO confirm.
mod pest_parser {
    use pest_derive::Parser;

    // Parser type whose implementation is generated from the grammar file
    // named in the `grammar` attribute.
    #[derive(Parser)]
    #[grammar = "v0/ast/hugr.pest"]
    pub struct HugrParser;
}
44
45pub(super) fn is_bare_symbol_name(name: &str) -> bool {
50 HugrParser::parse(Rule::bare_symbol_name, name)
51 .ok()
52 .and_then(|mut pairs| pairs.next())
53 .is_some_and(|pair| pair.as_span().end() == name.len())
54}
55
56fn parse_symbol_name(pair: Pair<Rule>) -> ParseResult<SymbolName> {
57 Ok(match pair.as_rule() {
58 Rule::symbol_name => parse_symbol_name(pair.into_inner().next().unwrap())?,
59 Rule::bare_symbol_name => SymbolName(pair.as_str().into()),
60 Rule::raw_symbol_name => SymbolName(parse_raw_symbol_name(pair)),
61 _ => unreachable!("expected symbol name"),
62 })
63}
64
65fn parse_version(pair: Pair<Rule>) -> ParseResult<semver::Version> {
66 debug_assert_eq!(Rule::version, pair.as_rule());
67 pair.as_str()
68 .parse()
69 .map_err(|_| ParseError::custom("invalid semver version", pair.as_span()))
70}
71
72fn parse_symbol_ident(pair: Pair<Rule>) -> ParseResult<SymbolIdent> {
73 debug_assert_eq!(Rule::symbol_ident, pair.as_rule());
74 let mut pairs = pair.into_inner();
75 let name = parse_symbol_name(pairs.next().unwrap())?;
76 let version = pairs.next().map(parse_version).transpose()?;
77 Ok(SymbolIdent { name, version })
78}
79
80fn parse_var_name(pair: Pair<Rule>) -> ParseResult<VarName> {
81 debug_assert_eq!(Rule::term_var, pair.as_rule());
82 Ok(VarName(pair.as_str()[1..].into()))
83}
84
85fn parse_link_name(pair: Pair<Rule>) -> ParseResult<LinkName> {
86 debug_assert_eq!(Rule::link_name, pair.as_rule());
87 Ok(LinkName(pair.as_str()[1..].into()))
88}
89
90fn parse_term(pair: Pair<Rule>) -> ParseResult<Term> {
91 debug_assert_eq!(Rule::term, pair.as_rule());
92 let pair = pair.into_inner().next().unwrap();
93
94 Ok(match pair.as_rule() {
95 Rule::term_wildcard => Term::Wildcard,
96 Rule::term_var => Term::Var(parse_var_name(pair)?),
97 Rule::term_apply => {
98 let mut pairs = pair.into_inner();
99 let symbol = parse_symbol_ident(pairs.next().unwrap())?;
100 let terms = pairs.map(parse_term).collect::<ParseResult<_>>()?;
101 Term::Apply(symbol, terms)
102 }
103 Rule::term_list => {
104 let pairs = pair.into_inner();
105 let parts = pairs.map(parse_seq_part).collect::<ParseResult<_>>()?;
106 Term::List(parts)
107 }
108 Rule::term_tuple => {
109 let pairs = pair.into_inner();
110 let parts = pairs.map(parse_seq_part).collect::<ParseResult<_>>()?;
111 Term::Tuple(parts)
112 }
113 Rule::literal => {
114 let literal = parse_literal(pair)?;
115 Term::Literal(literal)
116 }
117 Rule::term_const_func => {
118 let mut pairs = pair.into_inner();
119 let region = parse_region(pairs.next().unwrap())?;
120 Term::Func(Arc::new(region))
121 }
122 _ => unreachable!(),
123 })
124}
125
126fn parse_literal(pair: Pair<Rule>) -> ParseResult<Literal> {
127 debug_assert_eq!(pair.as_rule(), Rule::literal);
128 let pair = pair.into_inner().next().unwrap();
129
130 Ok(match pair.as_rule() {
131 Rule::literal_string => Literal::Str(parse_string(pair)?),
132 Rule::literal_nat => Literal::Nat(parse_nat(pair)?),
133 Rule::literal_bytes => Literal::Bytes(parse_bytes(pair)?),
134 Rule::literal_float => Literal::Float(parse_float(pair)?),
135 _ => unreachable!("expected literal"),
136 })
137}
138
139fn parse_seq_part(pair: Pair<Rule>) -> ParseResult<SeqPart> {
140 debug_assert_eq!(pair.as_rule(), Rule::part);
141 let pair = pair.into_inner().next().unwrap();
142
143 Ok(match pair.as_rule() {
144 Rule::term => SeqPart::Item(parse_term(pair)?),
145 Rule::spliced_term => {
146 let mut pairs = pair.into_inner();
147 let term = parse_term(pairs.next().unwrap())?;
148 SeqPart::Splice(term)
149 }
150 _ => unreachable!("expected term or spliced term"),
151 })
152}
153
154fn parse_package(pair: Pair<Rule>) -> ParseResult<Package> {
155 debug_assert_eq!(pair.as_rule(), Rule::package);
156 let mut pairs = pair.into_inner();
157
158 let modules = take_rule(&mut pairs, Rule::module)
159 .map(parse_module)
160 .collect::<ParseResult<_>>()?;
161
162 Ok(Package { modules })
163}
164
165fn parse_module(pair: Pair<Rule>) -> ParseResult<Module> {
166 debug_assert_eq!(pair.as_rule(), Rule::module);
167 let mut pairs = pair.into_inner();
168 let meta = parse_meta_items(&mut pairs)?;
169 let children = parse_nodes(&mut pairs)?;
170
171 Ok(Module {
172 root: Region {
173 kind: RegionKind::Module,
174 children,
175 meta,
176 ..Default::default()
177 },
178 })
179}
180
181fn parse_region(pair: Pair<Rule>) -> ParseResult<Region> {
182 debug_assert_eq!(pair.as_rule(), Rule::region);
183 let mut pairs = pair.into_inner();
184
185 let kind = parse_region_kind(pairs.next().unwrap())?;
186 let sources = parse_port_list(&mut pairs)?;
187 let targets = parse_port_list(&mut pairs)?;
188 let signature = parse_optional_signature(&mut pairs)?;
189 let meta = parse_meta_items(&mut pairs)?;
190 let children = parse_nodes(&mut pairs)?;
191
192 Ok(Region {
193 kind,
194 sources,
195 targets,
196 children,
197 meta,
198 signature,
199 })
200}
201
202fn parse_region_kind(pair: Pair<Rule>) -> ParseResult<RegionKind> {
203 debug_assert_eq!(pair.as_rule(), Rule::region_kind);
204
205 Ok(match pair.as_str() {
206 "dfg" => RegionKind::DataFlow,
207 "cfg" => RegionKind::ControlFlow,
208 "mod" => RegionKind::Module,
209 _ => unreachable!(),
210 })
211}
212
213fn parse_nodes(pairs: &mut Pairs<Rule>) -> ParseResult<Box<[Node]>> {
214 take_rule(pairs, Rule::node).map(parse_node).collect()
215}
216
217fn parse_node(pair: Pair<Rule>) -> ParseResult<Node> {
218 debug_assert_eq!(pair.as_rule(), Rule::node);
219 let mut pairs = pair.into_inner();
220 let pair = pairs.next().unwrap();
221 let rule = pair.as_rule();
222 let mut pairs = pair.into_inner();
223
224 let operation = match rule {
225 Rule::node_dfg => Operation::Dfg,
226 Rule::node_cfg => Operation::Cfg,
227 Rule::node_block => Operation::Block,
228 Rule::node_tail_loop => Operation::TailLoop,
229 Rule::node_cond => Operation::Conditional,
230
231 Rule::node_import => {
232 let symbol_ident = parse_symbol_ident(pairs.next().unwrap())?;
233 Operation::Import(symbol_ident)
234 }
235
236 Rule::node_custom => {
237 let term = parse_term(pairs.next().unwrap())?;
238 Operation::Custom(term)
239 }
240
241 Rule::node_define_func => {
242 let symbol = parse_symbol(pairs.next().unwrap())?;
243 Operation::DefineFunc(Box::new(symbol))
244 }
245 Rule::node_declare_func => {
246 let symbol = parse_symbol(pairs.next().unwrap())?;
247 Operation::DeclareFunc(Box::new(symbol))
248 }
249 Rule::node_define_alias => {
250 let symbol = parse_symbol(pairs.next().unwrap())?;
251 let value = parse_term(pairs.next().unwrap())?;
252 Operation::DefineAlias(Box::new(symbol), value)
253 }
254 Rule::node_declare_alias => {
255 let symbol = parse_symbol(pairs.next().unwrap())?;
256 Operation::DeclareAlias(Box::new(symbol))
257 }
258 Rule::node_declare_ctr => {
259 let symbol = parse_symbol(pairs.next().unwrap())?;
260 Operation::DeclareConstructor(Box::new(symbol))
261 }
262 Rule::node_declare_operation => {
263 let symbol = parse_symbol(pairs.next().unwrap())?;
264 Operation::DeclareOperation(Box::new(symbol))
265 }
266
267 _ => unreachable!(),
268 };
269
270 let inputs = parse_port_list(&mut pairs)?;
271 let outputs = parse_port_list(&mut pairs)?;
272 let signature = parse_optional_signature(&mut pairs)?;
273 let meta = parse_meta_items(&mut pairs)?;
274 let regions = pairs
275 .map(|pair| parse_region(pair))
276 .collect::<ParseResult<_>>()?;
277
278 Ok(Node {
279 operation,
280 inputs,
281 outputs,
282 regions,
283 meta,
284 signature,
285 })
286}
287
288fn parse_meta_items(pairs: &mut Pairs<Rule>) -> ParseResult<Box<[Term]>> {
289 take_rule(pairs, Rule::meta).map(parse_meta_item).collect()
290}
291
292fn parse_meta_item(pair: Pair<Rule>) -> ParseResult<Term> {
293 debug_assert_eq!(pair.as_rule(), Rule::meta);
294 let mut pairs = pair.into_inner();
295 parse_term(pairs.next().unwrap())
296}
297
298fn parse_optional_signature(pairs: &mut Pairs<Rule>) -> ParseResult<Option<Term>> {
299 match take_rule(pairs, Rule::signature).next() {
300 Some(pair) => Ok(Some(parse_signature(pair)?)),
301 _ => Ok(None),
302 }
303}
304
305fn parse_signature(pair: Pair<Rule>) -> ParseResult<Term> {
306 debug_assert_eq!(Rule::signature, pair.as_rule());
307 let mut pairs = pair.into_inner();
308 parse_term(pairs.next().unwrap())
309}
310
311fn parse_params(pairs: &mut Pairs<Rule>) -> ParseResult<Box<[Param]>> {
312 take_rule(pairs, Rule::param).map(parse_param).collect()
313}
314
315fn parse_param(pair: Pair<Rule>) -> ParseResult<Param> {
316 debug_assert_eq!(Rule::param, pair.as_rule());
317 let mut pairs = pair.into_inner();
318 let name = parse_var_name(pairs.next().unwrap())?;
319 let r#type = parse_term(pairs.next().unwrap())?;
320 Ok(Param { name, r#type })
321}
322
323fn parse_symbol(pair: Pair<Rule>) -> ParseResult<Symbol> {
324 debug_assert_eq!(Rule::symbol, pair.as_rule());
325
326 let mut pairs = pair.into_inner();
327 let visibility = take_rule(&mut pairs, Rule::visibility)
328 .next()
329 .map(|pair| match pair.as_str() {
330 "public" => Ok(Visibility::Public),
331 "private" => Ok(Visibility::Private),
332 _ => unreachable!("Expected 'public' or 'private', got {}", pair.as_str()),
333 })
334 .transpose()?;
335 let SymbolIdent { name, version } = parse_symbol_ident(pairs.next().unwrap())?;
336 let params = parse_params(&mut pairs)?;
337 let constraints = parse_constraints(&mut pairs)?;
338 let signature = parse_term(pairs.next().unwrap())?;
339
340 Ok(Symbol {
341 visibility,
342 name,
343 version,
344 params,
345 constraints,
346 signature,
347 })
348}
349
350fn parse_constraints(pairs: &mut Pairs<Rule>) -> ParseResult<Box<[Term]>> {
351 take_rule(pairs, Rule::where_clause)
352 .map(parse_constraint)
353 .collect()
354}
355
356fn parse_constraint(pair: Pair<Rule>) -> ParseResult<Term> {
357 debug_assert_eq!(Rule::where_clause, pair.as_rule());
358 let mut pairs = pair.into_inner();
359 parse_term(pairs.next().unwrap())
360}
361
362fn parse_port_list(pairs: &mut Pairs<Rule>) -> ParseResult<Box<[LinkName]>> {
363 let Some(pair) = take_rule(pairs, Rule::port_list).next() else {
364 return Ok(Default::default());
365 };
366
367 let pairs = pair.into_inner();
368 pairs.map(parse_link_name).collect()
369}
370
371fn parse_string(pair: Pair<Rule>) -> ParseResult<SmolStr> {
372 debug_assert_eq!(pair.as_rule(), Rule::literal_string);
373
374 let capacity = pair.as_str().len() - 2;
378 let mut string = String::with_capacity(capacity);
379 let pairs = pair.into_inner();
380
381 for pair in pairs {
382 match pair.as_rule() {
383 Rule::literal_string_raw => string.push_str(pair.as_str()),
384 Rule::literal_string_escape => match pair.as_str().chars().nth(1).unwrap() {
385 '"' => string.push('"'),
386 '\\' => string.push('\\'),
387 'n' => string.push('\n'),
388 'r' => string.push('\r'),
389 't' => string.push('\t'),
390 _ => unreachable!(),
391 },
392 Rule::literal_string_unicode => {
393 let token_str = pair.as_str();
394 debug_assert_eq!(&token_str[0..3], r"\u{");
395 debug_assert_eq!(&token_str[token_str.len() - 1..], "}");
396 let code_str = &token_str[3..token_str.len() - 1];
397 let code = u32::from_str_radix(code_str, 16).map_err(|_| {
398 ParseError::custom("invalid unicode escape sequence", pair.as_span())
399 })?;
400 let char = std::char::from_u32(code).ok_or_else(|| {
401 ParseError::custom("invalid unicode code point", pair.as_span())
402 })?;
403 string.push(char);
404 }
405 _ => unreachable!(),
406 }
407 }
408
409 Ok(string.into())
410}
411
412fn parse_raw_symbol_name(pair: Pair<Rule>) -> SmolStr {
413 debug_assert_eq!(pair.as_rule(), Rule::raw_symbol_name);
414 let raw = pair.as_str();
415 let Some(quote_index) = raw.find('"') else {
416 unreachable!("raw symbol names always contain an opening quote")
417 };
418 let hashes = &raw[1..quote_index];
419 let content_start = quote_index + 1;
420 let content_end = raw.len() - hashes.len() - 1;
421 raw[content_start..content_end].into()
422}
423
424fn parse_bytes(pair: Pair<Rule>) -> ParseResult<Arc<[u8]>> {
425 debug_assert_eq!(pair.as_rule(), Rule::literal_bytes);
426 let pair = pair.into_inner().next().unwrap();
427 debug_assert_eq!(pair.as_rule(), Rule::base64_string);
428
429 let slice = pair.as_str().as_bytes();
430
431 let slice = &slice[1..slice.len() - 1];
433
434 let data = BASE64_STANDARD
435 .decode(slice)
436 .map_err(|_| ParseError::custom("invalid base64 encoding", pair.as_span()))?;
437
438 Ok(data.into())
439}
440
441fn parse_nat(pair: Pair<Rule>) -> ParseResult<u64> {
442 debug_assert_eq!(pair.as_rule(), Rule::literal_nat);
443 let value = pair.as_str().trim().parse().unwrap();
444 Ok(value)
445}
446
447fn parse_float(pair: Pair<Rule>) -> ParseResult<OrderedFloat<f64>> {
448 debug_assert_eq!(pair.as_rule(), Rule::literal_float);
449 let value = pair.as_str().trim().parse().unwrap();
450 Ok(OrderedFloat(value))
451}
452
453fn take_rule<'a, 'i>(
454 pairs: &'i mut Pairs<'a, Rule>,
455 rule: Rule,
456) -> impl Iterator<Item = Pair<'a, Rule>> + 'i {
457 std::iter::from_fn(move || {
458 if pairs.peek()?.as_rule() == rule {
459 pairs.next()
460 } else {
461 None
462 }
463 })
464}
465
/// Result type used by the parsing functions in this file; the error is always
/// a [`ParseError`].
type ParseResult<T> = Result<T, ParseError>;
467
/// Error produced when parsing the text format fails.
///
/// Wraps a boxed [`pest::error::Error`]; its `Display` output is that of the
/// wrapped error.
#[derive(Debug, Clone, Error)]
#[error("{0}")]
pub struct ParseError(Box<pest::error::Error<Rule>>);
472
473impl ParseError {
474 fn custom(message: &str, span: pest::Span) -> Self {
475 let error = pest::error::Error::new_from_span(
476 pest::error::ErrorVariant::CustomError {
477 message: message.to_string(),
478 },
479 span,
480 );
481 ParseError(Box::new(error))
482 }
483}
484
/// Implements [`FromStr`] for `$ident` by running the grammar rule `$rule` on
/// the input and passing the first resulting pair to `$parse`.
///
/// After the rule has matched, anything left in the input other than
/// whitespace is rejected with an "unexpected trailing input" error whose span
/// covers the leftover text (falling back to the matched span if that span
/// cannot be constructed).
macro_rules! impl_from_str {
    ($ident:ident, $rule:ident, $parse:expr) => {
        impl FromStr for $ident {
            type Err = ParseError;

            fn from_str(s: &str) -> Result<Self, Self::Err> {
                let mut parsed = match HugrParser::parse(Rule::$rule, s) {
                    Ok(parsed) => parsed,
                    Err(err) => return Err(ParseError(Box::new(err))),
                };
                let root = parsed.next().unwrap();
                let root_span = root.as_span();
                let consumed = root_span.end();

                // Only whitespace may follow the matched rule.
                let leftover = &s[consumed..];
                if leftover.trim().is_empty() {
                    return $parse(root);
                }

                let leftover_span = pest::Span::new(s, consumed, s.len()).unwrap_or(root_span);
                Err(ParseError::custom(
                    "unexpected trailing input",
                    leftover_span,
                ))
            }
        }
    };
}
507
// `FromStr` implementations for the AST types. Each line names the type, the
// grammar rule used as the entry point, and the function that converts the
// matched pair into the type.
impl_from_str!(SymbolName, symbol_name, parse_symbol_name);
impl_from_str!(SymbolIdent, symbol_ident, parse_symbol_ident);
impl_from_str!(VarName, term_var, parse_var_name);
impl_from_str!(LinkName, link_name, parse_link_name);
impl_from_str!(Term, term, parse_term);
impl_from_str!(Node, node, parse_node);
impl_from_str!(Region, region, parse_region);
impl_from_str!(Param, param, parse_param);
impl_from_str!(Package, package, parse_package);
impl_from_str!(Module, module, parse_module);
impl_from_str!(SeqPart, part, parse_seq_part);
impl_from_str!(Literal, literal, parse_literal);
impl_from_str!(Symbol, symbol, parse_symbol);