1mod scope;
9
10use hugr_core::Hugr;
11pub use scope::{InScope, PassScope, Preserve};
12
13use std::{error::Error, marker::PhantomData};
14
15use hugr_core::core::HugrNode;
16use hugr_core::hugr::{ValidationError, hugrmut::HugrMut};
17use itertools::Either;
18
19pub trait ComposablePass<H: HugrMut>: Sized {
27 type Error: Error;
29 type Result; fn run(&self, hugr: &mut H) -> Result<Self::Result, Self::Error>;
34
35 fn with_scope_internal(self, scope: impl Into<PassScope>) -> Self {
46 let _ = scope;
52 self
53 }
54
55 fn map_err<E2: Error>(
58 self,
59 f: impl Fn(Self::Error) -> E2,
60 ) -> impl ComposablePass<H, Result = Self::Result, Error = E2> {
61 ErrMapper::new(self, f)
62 }
63
64 fn then<P: ComposablePass<H>, E: ErrorCombiner<Self::Error, P::Error>>(
76 self,
77 other: P,
78 ) -> impl ComposablePass<H, Result = (Self::Result, P::Result), Error = E> {
79 struct Sequence<E, P1, P2>(P1, P2, PhantomData<E>);
80 impl<H, E, P1, P2> ComposablePass<H> for Sequence<E, P1, P2>
81 where
82 H: HugrMut,
83 P1: ComposablePass<H>,
84 P2: ComposablePass<H>,
85 E: ErrorCombiner<P1::Error, P2::Error>,
86 {
87 type Error = E;
88 type Result = (P1::Result, P2::Result);
89
90 fn run(&self, hugr: &mut H) -> Result<Self::Result, Self::Error> {
91 let res1 = self.0.run(hugr).map_err(E::from_first)?;
92 let res2 = self.1.run(hugr).map_err(E::from_second)?;
93 Ok((res1, res2))
94 }
95
96 fn with_scope_internal(self, scope: impl Into<PassScope>) -> Self {
97 let scope = scope.into();
98 Self(
99 self.0.with_scope_internal(scope.clone()),
100 self.1.with_scope_internal(scope),
101 PhantomData,
102 )
103 }
104 }
105
106 Sequence(self, other, PhantomData)
107 }
108}
109
110pub trait WithScope {
113 fn with_scope(self, scope: impl Into<PassScope>) -> Self;
117}
118
119impl<P: ComposablePass<Hugr>> WithScope for P {
120 fn with_scope(self, scope: impl Into<PassScope>) -> Self {
121 self.with_scope_internal(scope)
122 }
123}
124
/// How to build one error type out of the errors of two passes run in sequence.
///
/// `A` is the error type of the pass that runs first and `B` that of the pass
/// that runs second.
pub trait ErrorCombiner<A, B>: Error {
    /// Convert an error raised by the first pass.
    fn from_first(a: A) -> Self;

    /// Convert an error raised by the second pass.
    fn from_second(b: B) -> Self;
}
133
134impl<A: Error, B: Into<A>> ErrorCombiner<A, B> for A {
135 fn from_first(a: A) -> Self {
136 a
137 }
138
139 fn from_second(b: B) -> Self {
140 b.into()
141 }
142}
143
144impl<A: Error, B: Error> ErrorCombiner<A, B> for Either<A, B> {
145 fn from_first(a: A) -> Self {
146 Either::Left(a)
147 }
148
149 fn from_second(b: B) -> Self {
150 Either::Right(b)
151 }
152}
153
154struct ErrMapper<P, H, E, F>(P, F, PhantomData<(E, H)>);
165
166impl<H: HugrMut, P: ComposablePass<H>, E: Error, F: Fn(P::Error) -> E> ErrMapper<P, H, E, F> {
167 fn new(pass: P, err_fn: F) -> Self {
168 Self(pass, err_fn, PhantomData)
169 }
170}
171
172impl<P: ComposablePass<H>, H: HugrMut, E: Error, F: Fn(P::Error) -> E> ComposablePass<H>
173 for ErrMapper<P, H, E, F>
174{
175 type Error = E;
176 type Result = P::Result;
177
178 fn run(&self, hugr: &mut H) -> Result<P::Result, Self::Error> {
179 self.0.run(hugr).map_err(&self.1)
180 }
181
182 fn with_scope_internal(self, scope: impl Into<PassScope>) -> Self {
183 Self(self.0.with_scope_internal(scope), self.1, PhantomData)
184 }
185}
186
187#[derive(thiserror::Error, Debug)]
191pub enum ValidatePassError<N, E>
192where
193 N: HugrNode + 'static,
194{
195 #[error("Failed to validate input HUGR: {err}\n{pretty_hugr}")]
197 Input {
198 #[source]
200 err: Box<ValidationError<N>>,
201 pretty_hugr: String,
203 },
204 #[error("Failed to validate output HUGR: {err}\n{pretty_hugr}")]
206 Output {
207 #[source]
209 err: Box<ValidationError<N>>,
210 pretty_hugr: String,
212 },
213 #[error(transparent)]
215 Underlying(Box<E>),
216}
217
218impl<N: HugrNode, E> From<E> for ValidatePassError<N, E> {
219 fn from(err: E) -> Self {
220 Self::Underlying(Box::new(err))
221 }
222}
223
224pub struct ValidatingPass<P, H>(P, PhantomData<H>);
227
228impl<P: ComposablePass<H>, H: HugrMut> ValidatingPass<P, H> {
229 pub fn new(underlying: P) -> Self {
231 Self(underlying, PhantomData)
232 }
233
234 fn validation_impl<E>(
235 &self,
236 hugr: &H,
237 mk_err: impl FnOnce(ValidationError<H::Node>, String) -> ValidatePassError<H::Node, E>,
238 ) -> Result<(), ValidatePassError<H::Node, E>> {
239 hugr.validate()
240 .map_err(|err| mk_err(err, hugr.mermaid_string()))
241 }
242}
243
244impl<P: ComposablePass<H>, H: HugrMut> ComposablePass<H> for ValidatingPass<P, H>
245where
246 H::Node: 'static,
247{
248 type Error = ValidatePassError<H::Node, P::Error>;
249 type Result = P::Result;
250
251 fn run(&self, hugr: &mut H) -> Result<P::Result, Self::Error> {
252 self.validation_impl(hugr, |err, pretty_hugr| ValidatePassError::Input {
253 err: Box::new(err),
254 pretty_hugr,
255 })?;
256 let res = self.0.run(hugr)?;
257 self.validation_impl(hugr, |err, pretty_hugr| ValidatePassError::Output {
258 err: Box::new(err),
259 pretty_hugr,
260 })?;
261 Ok(res)
262 }
263
264 fn with_scope_internal(self, scope: impl Into<PassScope>) -> Self {
265 Self(self.0.with_scope_internal(scope), self.1)
266 }
267}
268
269pub struct IfThen<E, H, A, B>(A, B, PhantomData<(E, H)>);
274
275impl<
276 A: ComposablePass<H, Result = bool>,
277 B: ComposablePass<H>,
278 H: HugrMut,
279 E: ErrorCombiner<A::Error, B::Error>,
280> IfThen<E, H, A, B>
281{
282 pub fn new(fst: A, opt_snd: B) -> Self {
285 Self(fst, opt_snd, PhantomData)
286 }
287}
288
289impl<
290 A: ComposablePass<H, Result = bool>,
291 B: ComposablePass<H>,
292 H: HugrMut,
293 E: ErrorCombiner<A::Error, B::Error>,
294> ComposablePass<H> for IfThen<E, H, A, B>
295{
296 type Error = E;
297 type Result = Option<B::Result>;
298
299 fn run(&self, hugr: &mut H) -> Result<Self::Result, Self::Error> {
300 let res: bool = self.0.run(hugr).map_err(ErrorCombiner::from_first)?;
301 res.then(|| self.1.run(hugr).map_err(ErrorCombiner::from_second))
302 .transpose()
303 }
304
305 fn with_scope_internal(self, scope: impl Into<PassScope>) -> Self {
306 let scope = scope.into();
307 Self(
308 self.0.with_scope_internal(scope.clone()),
309 self.1.with_scope_internal(scope),
310 PhantomData,
311 )
312 }
313}
314
315pub(crate) fn validate_if_test<P: ComposablePass<H>, H: HugrMut>(
317 pass: P,
318 hugr: &mut H,
319) -> Result<P::Result, ValidatePassError<H::Node, P::Error>> {
320 if cfg!(test) {
321 ValidatingPass::new(pass).run(hugr)
322 } else {
323 Ok(pass.run(hugr)?)
324 }
325}
326
/// Unit tests for the pass combinators in this module.
#[cfg(test)]
pub(crate) mod test {
    use hugr_core::ops::Value;
    use itertools::{Either, Itertools};

    use hugr_core::builder::{
        Dataflow, DataflowHugr, DataflowSubContainer, FunctionBuilder, HugrBuilder, ModuleBuilder,
    };
    use hugr_core::extension::prelude::{ConstUsize, MakeTuple, UnpackTuple, bool_t, usize_t};
    use hugr_core::hugr::hugrmut::HugrMut;
    use hugr_core::ops::{DFG, Input, OpType, Output, handle::NodeHandle};
    use hugr_core::std_extensions::arithmetic::int_types::INT_TYPES;
    use hugr_core::types::{Signature, TypeRow};
    use hugr_core::{Hugr, HugrView, IncomingPort, Node, NodeIndex};

    use crate::composable::WithScope;
    use crate::const_fold::{ConstFoldError, ConstantFoldPass};
    use crate::dead_code::DeadCodeElimError;
    use crate::untuple::UntupleResult;
    use crate::{DeadCodeElimPass, PassScope, ReplaceTypes, UntuplePass};

    use super::{ComposablePass, IfThen, ValidatePassError, ValidatingPass};

    /// Run `pass` on `hugr` wrapped in a [`ValidatingPass`], so the Hugr is
    /// validated before and after the pass.
    pub(crate) fn run_validating<P: ComposablePass<H>, H: HugrMut>(
        pass: P,
        hugr: &mut H,
    ) -> Result<P::Result, ValidatePassError<H::Node, P::Error>> {
        ValidatingPass::new(pass).run(hugr)
    }

    /// Checks how errors propagate through `then` and `map_err`.
    #[test]
    fn test_then() {
        // A module with two identity functions over `usize`.
        let mut mb = ModuleBuilder::new();
        let id1 = mb
            .define_function("id1", Signature::new_endo(usize_t()))
            .unwrap();
        let inps = id1.input_wires();
        let id1 = id1.finish_with_outputs(inps).unwrap();
        let id2 = mb
            .define_function("id2", Signature::new_endo(usize_t()))
            .unwrap();
        let inps = id2.input_wires();
        let id2 = id2.finish_with_outputs(inps).unwrap();
        let hugr = mb.finish_hugr().unwrap();

        let c_usz = Value::from(ConstUsize::new(2));
        // One past the largest node index in `hugr`, so not a node of it.
        let not_a_node = Node::from(portgraph::NodeIndex::new(
            hugr.nodes().map(Node::index).max().unwrap() + 1,
        ));
        assert!(!hugr.contains_node(not_a_node));
        // DCE with a nonexistent entry point; the assertions below expect it
        // to fail with `NodeNotFound`.
        let dce = DeadCodeElimPass::default().with_entry_points([not_a_node]);
        let cfold = ConstantFoldPass::default().with_inputs(id2.node(), [(0, c_usz.clone())]);

        // Constant folding on its own succeeds.
        cfold.run(&mut hugr.clone()).unwrap();

        // The first pass fails, so its error is reported as `Left`.
        let dce_err = DeadCodeElimError::NodeNotFound(not_a_node);
        let r: Result<_, Either<DeadCodeElimError, ConstFoldError>> =
            dce.clone().then(cfold.clone()).run(&mut hugr.clone());
        assert_eq!(r, Err(Either::Left(dce_err.clone())));

        // Constant folding (now also given inputs for `id1`) succeeds and the
        // second pass fails, so its error is reported as `Right`.
        let r: Result<_, Either<_, _>> = cfold
            .clone()
            .with_inputs(id1.node(), [(0, c_usz)])
            .then(dce.clone())
            .run(&mut hugr.clone());
        assert_eq!(r, Err(Either::Right(dce_err)));

        // After `map_err` both passes share `ConstFoldError`, which combines
        // with itself via the blanket `ErrorCombiner` impl.
        let r = dce
            .map_err(|e| match e {
                DeadCodeElimError::NodeNotFound(node) => ConstFoldError::MissingEntryPoint { node },
            })
            .then(cfold.clone())
            .run(&mut hugr.clone());
        assert_eq!(
            r,
            Err(ConstFoldError::MissingEntryPoint { node: not_a_node })
        );

        // `ConstFoldError` converts into `ValidatePassError` via `From`, so the
        // validating pass's error type can absorb the second pass's error.
        let v = ValidatingPass::new(cfold.clone());
        let r: Result<_, ValidatePassError<Node, ConstFoldError>> =
            v.then(cfold).run(&mut hugr.clone());
        r.unwrap();
    }

    /// Checks that `ValidatingPass` reports an invalid input Hugr even when the
    /// wrapped pass itself would succeed.
    #[test]
    fn test_validation() {
        // A DFG whose `usize` input is wired straight to a `bool` output, so
        // the Hugr fails validation.
        let mut h = Hugr::new_with_entrypoint(DFG {
            signature: Signature::new(usize_t(), bool_t()),
        })
        .unwrap();
        let inp = h.add_node_with_parent(
            h.entrypoint(),
            Input {
                types: usize_t().into(),
            },
        );
        let outp = h.add_node_with_parent(
            h.entrypoint(),
            Output {
                types: bool_t().into(),
            },
        );
        h.connect(inp, 0, outp, 0);
        let backup = h.clone();
        let err = backup.validate().unwrap_err();

        // Unwrapped, constant folding runs and leaves this Hugr unchanged.
        let no_inputs: [(IncomingPort, _); 0] = [];
        let cfold = ConstantFoldPass::default().with_inputs(backup.entrypoint(), no_inputs);
        cfold.run(&mut h).unwrap();
        assert_eq!(h, backup);
        // Wrapped, the same validation error is reported on the input.
        let r = ValidatingPass::new(cfold).run(&mut h);
        assert!(matches!(r, Err(ValidatePassError::Input { err: e, .. }) if *e == err));
    }

    /// Checks that `IfThen` runs the second pass only when the first returns `true`.
    #[test]
    fn test_if_then() {
        let tr = TypeRow::from(vec![usize_t(); 2]);

        // A function that packs its two `usize` inputs into a tuple and
        // immediately unpacks it again.
        let h = {
            let sig = Signature::new_endo(tr.clone());
            let mut fb = FunctionBuilder::new("tupuntup", sig).unwrap();
            let [a, b] = fb.input_wires_arr();
            let tup = fb
                .add_dataflow_op(MakeTuple::new(tr.clone()), [a, b])
                .unwrap();
            let untup = fb
                .add_dataflow_op(UnpackTuple::new(tr.clone()), tup.outputs())
                .unwrap();
            fb.finish_hugr_with_outputs(untup.outputs()).unwrap()
        };

        let untup = UntuplePass::default().with_scope(PassScope::EntrypointRecursive);
        {
            // Replacing `usize` (which the Hugr uses): the untuple pass runs.
            let mut repl = ReplaceTypes::default();
            let usize_custom_t = usize_t().as_extension().unwrap().clone();
            repl.set_replace_type(usize_custom_t, INT_TYPES[6].clone());
            let ifthen = IfThen::<Either<_, _>, _, _, _>::new(repl, untup.clone());

            let mut h = h.clone();
            let r = run_validating(ifthen, &mut h).unwrap();
            assert_eq!(
                r,
                Some(UntupleResult {
                    rewrites_applied: 1
                })
            );
            // Only two children remain, and both outputs of the first feed the second.
            let [tuple_in, tuple_out] = h.children(h.entrypoint()).collect_array().unwrap();
            assert_eq!(h.output_neighbours(tuple_in).collect_vec(), [tuple_out; 2]);
        }

        // Replacing a 32-bit int type that the Hugr does not use: `r` is `None`,
        // so the untuple pass is skipped and the tuple ops remain.
        let mut repl = ReplaceTypes::default();
        let i32_custom_t = INT_TYPES[5].as_extension().unwrap().clone();
        repl.set_replace_type(i32_custom_t, INT_TYPES[6].clone());
        let ifthen = IfThen::<Either<_, _>, _, _, _>::new(repl, untup);
        let mut h = h;
        let r = run_validating(ifthen, &mut h).unwrap();
        assert_eq!(r, None);
        assert_eq!(h.children(h.entrypoint()).count(), 4);
        let mktup = h
            .output_neighbours(h.first_child(h.entrypoint()).unwrap())
            .next()
            .unwrap();
        assert_eq!(h.get_optype(mktup), &OpType::from(MakeTuple::new(tr)));
    }
}