1use std::str::FromStr;
17use std::sync::Arc;
18
19use base64::Engine as _;
20use base64::prelude::BASE64_STANDARD;
21use ordered_float::OrderedFloat;
22use pest::Parser as _;
23use pest::iterators::{Pair, Pairs};
24use pest_parser::{HugrParser, Rule};
25use smol_str::SmolStr;
26use thiserror::Error;
27
28use crate::v0::ast::{LinkName, Module, Operation, SeqPart};
29use crate::v0::{Literal, RegionKind};
30
31use super::{Node, Package, Param, Region, Symbol, VarName, Visibility};
32use super::{SymbolName, Term};
33
34mod pest_parser {
35 use pest_derive::Parser;
36
37 #[derive(Parser)]
41 #[grammar = "v0/ast/hugr.pest"]
42 pub struct HugrParser;
43}
44
45fn parse_symbol_name(pair: Pair<Rule>) -> ParseResult<SymbolName> {
46 debug_assert_eq!(Rule::symbol_name, pair.as_rule());
47 Ok(SymbolName(pair.as_str().into()))
48}
49
50fn parse_var_name(pair: Pair<Rule>) -> ParseResult<VarName> {
51 debug_assert_eq!(Rule::term_var, pair.as_rule());
52 Ok(VarName(pair.as_str()[1..].into()))
53}
54
55fn parse_link_name(pair: Pair<Rule>) -> ParseResult<LinkName> {
56 debug_assert_eq!(Rule::link_name, pair.as_rule());
57 Ok(LinkName(pair.as_str()[1..].into()))
58}
59
60fn parse_term(pair: Pair<Rule>) -> ParseResult<Term> {
61 debug_assert_eq!(Rule::term, pair.as_rule());
62 let pair = pair.into_inner().next().unwrap();
63
64 Ok(match pair.as_rule() {
65 Rule::term_wildcard => Term::Wildcard,
66 Rule::term_var => Term::Var(parse_var_name(pair)?),
67 Rule::term_apply => {
68 let mut pairs = pair.into_inner();
69 let symbol = parse_symbol_name(pairs.next().unwrap())?;
70 let terms = pairs.map(parse_term).collect::<ParseResult<_>>()?;
71 Term::Apply(symbol, terms)
72 }
73 Rule::term_list => {
74 let pairs = pair.into_inner();
75 let parts = pairs.map(parse_seq_part).collect::<ParseResult<_>>()?;
76 Term::List(parts)
77 }
78 Rule::term_tuple => {
79 let pairs = pair.into_inner();
80 let parts = pairs.map(parse_seq_part).collect::<ParseResult<_>>()?;
81 Term::Tuple(parts)
82 }
83 Rule::literal => {
84 let literal = parse_literal(pair)?;
85 Term::Literal(literal)
86 }
87 Rule::term_const_func => {
88 let mut pairs = pair.into_inner();
89 let region = parse_region(pairs.next().unwrap())?;
90 Term::Func(Arc::new(region))
91 }
92 _ => unreachable!(),
93 })
94}
95
96fn parse_literal(pair: Pair<Rule>) -> ParseResult<Literal> {
97 debug_assert_eq!(pair.as_rule(), Rule::literal);
98 let pair = pair.into_inner().next().unwrap();
99
100 Ok(match pair.as_rule() {
101 Rule::literal_string => Literal::Str(parse_string(pair)?),
102 Rule::literal_nat => Literal::Nat(parse_nat(pair)?),
103 Rule::literal_bytes => Literal::Bytes(parse_bytes(pair)?),
104 Rule::literal_float => Literal::Float(parse_float(pair)?),
105 _ => unreachable!("expected literal"),
106 })
107}
108
109fn parse_seq_part(pair: Pair<Rule>) -> ParseResult<SeqPart> {
110 debug_assert_eq!(pair.as_rule(), Rule::part);
111 let pair = pair.into_inner().next().unwrap();
112
113 Ok(match pair.as_rule() {
114 Rule::term => SeqPart::Item(parse_term(pair)?),
115 Rule::spliced_term => {
116 let mut pairs = pair.into_inner();
117 let term = parse_term(pairs.next().unwrap())?;
118 SeqPart::Splice(term)
119 }
120 _ => unreachable!("expected term or spliced term"),
121 })
122}
123
124fn parse_package(pair: Pair<Rule>) -> ParseResult<Package> {
125 debug_assert_eq!(pair.as_rule(), Rule::package);
126 let mut pairs = pair.into_inner();
127
128 let modules = take_rule(&mut pairs, Rule::module)
129 .map(parse_module)
130 .collect::<ParseResult<_>>()?;
131
132 Ok(Package { modules })
133}
134
135fn parse_module(pair: Pair<Rule>) -> ParseResult<Module> {
136 debug_assert_eq!(pair.as_rule(), Rule::module);
137 let mut pairs = pair.into_inner();
138 let meta = parse_meta_items(&mut pairs)?;
139 let children = parse_nodes(&mut pairs)?;
140
141 Ok(Module {
142 root: Region {
143 kind: RegionKind::Module,
144 children,
145 meta,
146 ..Default::default()
147 },
148 })
149}
150
151fn parse_region(pair: Pair<Rule>) -> ParseResult<Region> {
152 debug_assert_eq!(pair.as_rule(), Rule::region);
153 let mut pairs = pair.into_inner();
154
155 let kind = parse_region_kind(pairs.next().unwrap())?;
156 let sources = parse_port_list(&mut pairs)?;
157 let targets = parse_port_list(&mut pairs)?;
158 let signature = parse_optional_signature(&mut pairs)?;
159 let meta = parse_meta_items(&mut pairs)?;
160 let children = parse_nodes(&mut pairs)?;
161
162 Ok(Region {
163 kind,
164 sources,
165 targets,
166 children,
167 meta,
168 signature,
169 })
170}
171
172fn parse_region_kind(pair: Pair<Rule>) -> ParseResult<RegionKind> {
173 debug_assert_eq!(pair.as_rule(), Rule::region_kind);
174
175 Ok(match pair.as_str() {
176 "dfg" => RegionKind::DataFlow,
177 "cfg" => RegionKind::ControlFlow,
178 "mod" => RegionKind::Module,
179 _ => unreachable!(),
180 })
181}
182
183fn parse_nodes(pairs: &mut Pairs<Rule>) -> ParseResult<Box<[Node]>> {
184 take_rule(pairs, Rule::node).map(parse_node).collect()
185}
186
187fn parse_node(pair: Pair<Rule>) -> ParseResult<Node> {
188 debug_assert_eq!(pair.as_rule(), Rule::node);
189 let mut pairs = pair.into_inner();
190 let pair = pairs.next().unwrap();
191 let rule = pair.as_rule();
192 let mut pairs = pair.into_inner();
193
194 let operation = match rule {
195 Rule::node_dfg => Operation::Dfg,
196 Rule::node_cfg => Operation::Cfg,
197 Rule::node_block => Operation::Block,
198 Rule::node_tail_loop => Operation::TailLoop,
199 Rule::node_cond => Operation::Conditional,
200
201 Rule::node_import => {
202 let name = parse_symbol_name(pairs.next().unwrap())?;
203 Operation::Import(name)
204 }
205
206 Rule::node_custom => {
207 let term = parse_term(pairs.next().unwrap())?;
208 Operation::Custom(term)
209 }
210
211 Rule::node_define_func => {
212 let symbol = parse_symbol(pairs.next().unwrap())?;
213 Operation::DefineFunc(Box::new(symbol))
214 }
215 Rule::node_declare_func => {
216 let symbol = parse_symbol(pairs.next().unwrap())?;
217 Operation::DeclareFunc(Box::new(symbol))
218 }
219 Rule::node_define_alias => {
220 let symbol = parse_symbol(pairs.next().unwrap())?;
221 let value = parse_term(pairs.next().unwrap())?;
222 Operation::DefineAlias(Box::new(symbol), value)
223 }
224 Rule::node_declare_alias => {
225 let symbol = parse_symbol(pairs.next().unwrap())?;
226 Operation::DeclareAlias(Box::new(symbol))
227 }
228 Rule::node_declare_ctr => {
229 let symbol = parse_symbol(pairs.next().unwrap())?;
230 Operation::DeclareConstructor(Box::new(symbol))
231 }
232 Rule::node_declare_operation => {
233 let symbol = parse_symbol(pairs.next().unwrap())?;
234 Operation::DeclareOperation(Box::new(symbol))
235 }
236
237 _ => unreachable!(),
238 };
239
240 let inputs = parse_port_list(&mut pairs)?;
241 let outputs = parse_port_list(&mut pairs)?;
242 let signature = parse_optional_signature(&mut pairs)?;
243 let meta = parse_meta_items(&mut pairs)?;
244 let regions = pairs
245 .map(|pair| parse_region(pair))
246 .collect::<ParseResult<_>>()?;
247
248 Ok(Node {
249 operation,
250 inputs,
251 outputs,
252 regions,
253 meta,
254 signature,
255 })
256}
257
258fn parse_meta_items(pairs: &mut Pairs<Rule>) -> ParseResult<Box<[Term]>> {
259 take_rule(pairs, Rule::meta).map(parse_meta_item).collect()
260}
261
262fn parse_meta_item(pair: Pair<Rule>) -> ParseResult<Term> {
263 debug_assert_eq!(pair.as_rule(), Rule::meta);
264 let mut pairs = pair.into_inner();
265 parse_term(pairs.next().unwrap())
266}
267
268fn parse_optional_signature(pairs: &mut Pairs<Rule>) -> ParseResult<Option<Term>> {
269 match take_rule(pairs, Rule::signature).next() {
270 Some(pair) => Ok(Some(parse_signature(pair)?)),
271 _ => Ok(None),
272 }
273}
274
275fn parse_signature(pair: Pair<Rule>) -> ParseResult<Term> {
276 debug_assert_eq!(Rule::signature, pair.as_rule());
277 let mut pairs = pair.into_inner();
278 parse_term(pairs.next().unwrap())
279}
280
281fn parse_params(pairs: &mut Pairs<Rule>) -> ParseResult<Box<[Param]>> {
282 take_rule(pairs, Rule::param).map(parse_param).collect()
283}
284
285fn parse_param(pair: Pair<Rule>) -> ParseResult<Param> {
286 debug_assert_eq!(Rule::param, pair.as_rule());
287 let mut pairs = pair.into_inner();
288 let name = parse_var_name(pairs.next().unwrap())?;
289 let r#type = parse_term(pairs.next().unwrap())?;
290 Ok(Param { name, r#type })
291}
292
293fn parse_symbol(pair: Pair<Rule>) -> ParseResult<Symbol> {
294 debug_assert_eq!(Rule::symbol, pair.as_rule());
295
296 let mut pairs = pair.into_inner();
297 let visibility = take_rule(&mut pairs, Rule::visibility)
298 .next()
299 .map(|pair| match pair.as_str() {
300 "public" => Ok(Visibility::Public),
301 "private" => Ok(Visibility::Private),
302 _ => unreachable!("Expected 'public' or 'private', got {}", pair.as_str()),
303 })
304 .transpose()?;
305 let name = parse_symbol_name(pairs.next().unwrap())?;
306 let params = parse_params(&mut pairs)?;
307 let constraints = parse_constraints(&mut pairs)?;
308 let signature = parse_term(pairs.next().unwrap())?;
309
310 Ok(Symbol {
311 visibility,
312 name,
313 params,
314 constraints,
315 signature,
316 })
317}
318
319fn parse_constraints(pairs: &mut Pairs<Rule>) -> ParseResult<Box<[Term]>> {
320 take_rule(pairs, Rule::where_clause)
321 .map(parse_constraint)
322 .collect()
323}
324
325fn parse_constraint(pair: Pair<Rule>) -> ParseResult<Term> {
326 debug_assert_eq!(Rule::where_clause, pair.as_rule());
327 let mut pairs = pair.into_inner();
328 parse_term(pairs.next().unwrap())
329}
330
331fn parse_port_list(pairs: &mut Pairs<Rule>) -> ParseResult<Box<[LinkName]>> {
332 let Some(pair) = take_rule(pairs, Rule::port_list).next() else {
333 return Ok(Default::default());
334 };
335
336 let pairs = pair.into_inner();
337 pairs.map(parse_link_name).collect()
338}
339
340fn parse_string(pair: Pair<Rule>) -> ParseResult<SmolStr> {
341 debug_assert_eq!(pair.as_rule(), Rule::literal_string);
342
343 let capacity = pair.as_str().len() - 2;
347 let mut string = String::with_capacity(capacity);
348 let pairs = pair.into_inner();
349
350 for pair in pairs {
351 match pair.as_rule() {
352 Rule::literal_string_raw => string.push_str(pair.as_str()),
353 Rule::literal_string_escape => match pair.as_str().chars().nth(1).unwrap() {
354 '"' => string.push('"'),
355 '\\' => string.push('\\'),
356 'n' => string.push('\n'),
357 'r' => string.push('\r'),
358 't' => string.push('\t'),
359 _ => unreachable!(),
360 },
361 Rule::literal_string_unicode => {
362 let token_str = pair.as_str();
363 debug_assert_eq!(&token_str[0..3], r"\u{");
364 debug_assert_eq!(&token_str[token_str.len() - 1..], "}");
365 let code_str = &token_str[3..token_str.len() - 1];
366 let code = u32::from_str_radix(code_str, 16).map_err(|_| {
367 ParseError::custom("invalid unicode escape sequence", pair.as_span())
368 })?;
369 let char = std::char::from_u32(code).ok_or_else(|| {
370 ParseError::custom("invalid unicode code point", pair.as_span())
371 })?;
372 string.push(char);
373 }
374 _ => unreachable!(),
375 }
376 }
377
378 Ok(string.into())
379}
380
381fn parse_bytes(pair: Pair<Rule>) -> ParseResult<Arc<[u8]>> {
382 debug_assert_eq!(pair.as_rule(), Rule::literal_bytes);
383 let pair = pair.into_inner().next().unwrap();
384 debug_assert_eq!(pair.as_rule(), Rule::base64_string);
385
386 let slice = pair.as_str().as_bytes();
387
388 let slice = &slice[1..slice.len() - 1];
390
391 let data = BASE64_STANDARD
392 .decode(slice)
393 .map_err(|_| ParseError::custom("invalid base64 encoding", pair.as_span()))?;
394
395 Ok(data.into())
396}
397
398fn parse_nat(pair: Pair<Rule>) -> ParseResult<u64> {
399 debug_assert_eq!(pair.as_rule(), Rule::literal_nat);
400 let value = pair.as_str().trim().parse().unwrap();
401 Ok(value)
402}
403
404fn parse_float(pair: Pair<Rule>) -> ParseResult<OrderedFloat<f64>> {
405 debug_assert_eq!(pair.as_rule(), Rule::literal_float);
406 let value = pair.as_str().trim().parse().unwrap();
407 Ok(OrderedFloat(value))
408}
409
410fn take_rule<'a, 'i>(
411 pairs: &'i mut Pairs<'a, Rule>,
412 rule: Rule,
413) -> impl Iterator<Item = Pair<'a, Rule>> + 'i {
414 std::iter::from_fn(move || {
415 if pairs.peek()?.as_rule() == rule {
416 pairs.next()
417 } else {
418 None
419 }
420 })
421}
422
423type ParseResult<T> = Result<T, ParseError>;
424
425#[derive(Debug, Clone, Error)]
427#[error("{0}")]
428pub struct ParseError(Box<pest::error::Error<Rule>>);
429
430impl ParseError {
431 fn custom(message: &str, span: pest::Span) -> Self {
432 let error = pest::error::Error::new_from_span(
433 pest::error::ErrorVariant::CustomError {
434 message: message.to_string(),
435 },
436 span,
437 );
438 ParseError(Box::new(error))
439 }
440}
441
/// Implements [`FromStr`] for a type.
///
/// Arguments: the target type, the name of the `Rule` variant to start
/// parsing from, and the function that converts the first resulting pair into
/// the target type. Errors from the pest parser are wrapped in [`ParseError`].
macro_rules! impl_from_str {
    ($target:ident, $start:ident, $convert:expr) => {
        impl FromStr for $target {
            type Err = ParseError;

            fn from_str(s: &str) -> Result<Self, Self::Err> {
                match HugrParser::parse(Rule::$start, s) {
                    Ok(mut pairs) => $convert(pairs.next().unwrap()),
                    Err(err) => Err(ParseError(Box::new(err))),
                }
            }
        }
    };
}
455
456impl_from_str!(SymbolName, symbol_name, parse_symbol_name);
457impl_from_str!(VarName, term_var, parse_var_name);
458impl_from_str!(LinkName, link_name, parse_link_name);
459impl_from_str!(Term, term, parse_term);
460impl_from_str!(Node, node, parse_node);
461impl_from_str!(Region, region, parse_region);
462impl_from_str!(Param, param, parse_param);
463impl_from_str!(Package, package, parse_package);
464impl_from_str!(Module, module, parse_module);
465impl_from_str!(SeqPart, part, parse_seq_part);
466impl_from_str!(Literal, literal, parse_literal);
467impl_from_str!(Symbol, symbol, parse_symbol);