Skip to main content

hugr_passes/
composable.rs

1//! Compiler passes and utilities for composing them.
2//!
3//! The core trait is [`ComposablePass`], which defines a transformation that can
4//! be applied to a HUGR.
5//! See the [`ComposablePass`] trait documentation for more details.
6//!
7
8mod scope;
9
10use hugr_core::Hugr;
11pub use scope::{InScope, PassScope, Preserve};
12
13use std::{error::Error, marker::PhantomData};
14
15use hugr_core::core::HugrNode;
16use hugr_core::hugr::{ValidationError, hugrmut::HugrMut};
17use itertools::Either;
18
19/// An optimization pass that can be sequenced with another and/or wrapped
20/// e.g. by [`ValidatingPass`].
21///
22/// Note it is expected that (simple) passes should make reasonable effort to be
23/// idempotent (i.e. such that after running a pass, rerunning it immediately has
24/// no further effect). However this is *not* a requirement, e.g. a sequence of
25/// idempotent passes created by [ComposablePass::then] may not be idempotent itself.
26pub trait ComposablePass<H: HugrMut>: Sized {
27    /// Error thrown by this pass.
28    type Error: Error;
29    /// Result returned by this pass.
30    type Result; // Would like to default to () but currently unstable
31
32    /// Run the pass on the given HUGR.
33    fn run(&self, hugr: &mut H) -> Result<Self::Result, Self::Error>;
34
35    /// Set the scope configuration used to run the pass.
36    ///
37    /// See [`PassScope`] for more details.
38    ///
39    /// In `hugr 0.25.*`, this configuration is only a guidance, and may be
40    /// ignored by the pass by using the default implementation.
41    ///
42    /// From `hugr >=0.26.0`, passes must respect the scope configuration.
43    //
44    // For hugr passes, this is tracked by <https://github.com/Quantinuum/hugr/issues/2771>
45    fn with_scope_internal(self, scope: impl Into<PassScope>) -> Self {
46        // Currently passes are not required to respect the scope configuration.
47        // <https://github.com/Quantinuum/hugr/issues/2771>
48        //
49        // deprecated: Remove default implementation in hugr 0.26.0,
50        // ensure all passes follow the scope configuration.
51        let _ = scope;
52        self
53    }
54
55    /// Apply a function to the error type of this pass, returning a new
56    /// [`ComposablePass`] that has the same result type.
57    fn map_err<E2: Error>(
58        self,
59        f: impl Fn(Self::Error) -> E2,
60    ) -> impl ComposablePass<H, Result = Self::Result, Error = E2> {
61        ErrMapper::new(self, f)
62    }
63
64    /// Returns a [`ComposablePass`] that does "`self` then `other`", so long as
65    /// `other::Err` can be combined with ours.
66    ///
67    /// Composed passes may have different configured [`PassScope`]s. Use
68    /// [`WithScope::with_scope`] after the composition to override all the
69    /// scope configurations if needed.
70    ///
71    /// Note this is not necessarily idempotent even if both `self` and `other` are.
72    /// (Idempotency would require rerunning the sequence of both until no change;
73    /// since there is no general/efficient reporting of whether a pass has changed
74    /// the hugr, no such checking or looping is done here.)
75    fn then<P: ComposablePass<H>, E: ErrorCombiner<Self::Error, P::Error>>(
76        self,
77        other: P,
78    ) -> impl ComposablePass<H, Result = (Self::Result, P::Result), Error = E> {
79        struct Sequence<E, P1, P2>(P1, P2, PhantomData<E>);
80        impl<H, E, P1, P2> ComposablePass<H> for Sequence<E, P1, P2>
81        where
82            H: HugrMut,
83            P1: ComposablePass<H>,
84            P2: ComposablePass<H>,
85            E: ErrorCombiner<P1::Error, P2::Error>,
86        {
87            type Error = E;
88            type Result = (P1::Result, P2::Result);
89
90            fn run(&self, hugr: &mut H) -> Result<Self::Result, Self::Error> {
91                let res1 = self.0.run(hugr).map_err(E::from_first)?;
92                let res2 = self.1.run(hugr).map_err(E::from_second)?;
93                Ok((res1, res2))
94            }
95
96            fn with_scope_internal(self, scope: impl Into<PassScope>) -> Self {
97                let scope = scope.into();
98                Self(
99                    self.0.with_scope_internal(scope.clone()),
100                    self.1.with_scope_internal(scope),
101                    PhantomData,
102                )
103            }
104        }
105
106        Sequence(self, other, PhantomData)
107    }
108}
109
110/// Extension trait for adding a `with_scope` method to a `ComposablePass` that
111/// does not require instantiating the `H` generic parameter.
112pub trait WithScope {
113    /// Set the scope configuration used to run the pass.
114    ///
115    /// See [`PassScope`] for more details.
116    fn with_scope(self, scope: impl Into<PassScope>) -> Self;
117}
118
119impl<P: ComposablePass<Hugr>> WithScope for P {
120    fn with_scope(self, scope: impl Into<PassScope>) -> Self {
121        self.with_scope_internal(scope)
122    }
123}
124
125/// Trait for combining the error types from two different passes
126/// into a single error.
127pub trait ErrorCombiner<A, B>: Error {
128    /// Create a combined error from the first pass's error.
129    fn from_first(a: A) -> Self;
130    /// Create a combined error from the second pass's error.
131    fn from_second(b: B) -> Self;
132}
133
134impl<A: Error, B: Into<A>> ErrorCombiner<A, B> for A {
135    fn from_first(a: A) -> Self {
136        a
137    }
138
139    fn from_second(b: B) -> Self {
140        b.into()
141    }
142}
143
144impl<A: Error, B: Error> ErrorCombiner<A, B> for Either<A, B> {
145    fn from_first(a: A) -> Self {
146        Either::Left(a)
147    }
148
149    fn from_second(b: B) -> Self {
150        Either::Right(b)
151    }
152}
153
// Note: in the short term we might want two more impls:
//   impl<E: Error> ErrorCombiner<Infallible, E> for E
//   impl<E: Error> ErrorCombiner<E, Infallible> for E
// However, these are not possible, as they conflict with
//   impl<A, B: Into<A>> ErrorCombiner<A, B> for A
// when A = E = Infallible.
// This will become possible, indeed automatic, once Infallible is replaced
// by `!` (never_type), since (unlike Infallible) `!` converts Into anything.
162
163// ErrMapper ------------------------------
164struct ErrMapper<P, H, E, F>(P, F, PhantomData<(E, H)>);
165
166impl<H: HugrMut, P: ComposablePass<H>, E: Error, F: Fn(P::Error) -> E> ErrMapper<P, H, E, F> {
167    fn new(pass: P, err_fn: F) -> Self {
168        Self(pass, err_fn, PhantomData)
169    }
170}
171
172impl<P: ComposablePass<H>, H: HugrMut, E: Error, F: Fn(P::Error) -> E> ComposablePass<H>
173    for ErrMapper<P, H, E, F>
174{
175    type Error = E;
176    type Result = P::Result;
177
178    fn run(&self, hugr: &mut H) -> Result<P::Result, Self::Error> {
179        self.0.run(hugr).map_err(&self.1)
180    }
181
182    fn with_scope_internal(self, scope: impl Into<PassScope>) -> Self {
183        Self(self.0.with_scope_internal(scope), self.1, PhantomData)
184    }
185}
186
187// ValidatingPass ------------------------------
188
189/// Error from a [`ValidatingPass`]
190#[derive(thiserror::Error, Debug)]
191pub enum ValidatePassError<N, E>
192where
193    N: HugrNode + 'static,
194{
195    /// Validation failed on the initial HUGR.
196    #[error("Failed to validate input HUGR: {err}\n{pretty_hugr}")]
197    Input {
198        /// The validation error that occurred.
199        #[source]
200        err: Box<ValidationError<N>>,
201        /// A pretty-printed representation of the HUGR that failed validation.
202        pretty_hugr: String,
203    },
204    /// Validation failed on the final HUGR.
205    #[error("Failed to validate output HUGR: {err}\n{pretty_hugr}")]
206    Output {
207        /// The validation error that occurred.
208        #[source]
209        err: Box<ValidationError<N>>,
210        /// A pretty-printed representation of the HUGR that failed validation.
211        pretty_hugr: String,
212    },
213    /// An error from the underlying pass.
214    #[error(transparent)]
215    Underlying(Box<E>),
216}
217
218impl<N: HugrNode, E> From<E> for ValidatePassError<N, E> {
219    fn from(err: E) -> Self {
220        Self::Underlying(Box::new(err))
221    }
222}
223
224/// Runs an underlying pass, but with validation of the Hugr
225/// both before and afterwards.
226pub struct ValidatingPass<P, H>(P, PhantomData<H>);
227
228impl<P: ComposablePass<H>, H: HugrMut> ValidatingPass<P, H> {
229    /// Return a new [`ValidatingPass`] that wraps the given underlying pass.
230    pub fn new(underlying: P) -> Self {
231        Self(underlying, PhantomData)
232    }
233
234    fn validation_impl<E>(
235        &self,
236        hugr: &H,
237        mk_err: impl FnOnce(ValidationError<H::Node>, String) -> ValidatePassError<H::Node, E>,
238    ) -> Result<(), ValidatePassError<H::Node, E>> {
239        hugr.validate()
240            .map_err(|err| mk_err(err, hugr.mermaid_string()))
241    }
242}
243
244impl<P: ComposablePass<H>, H: HugrMut> ComposablePass<H> for ValidatingPass<P, H>
245where
246    H::Node: 'static,
247{
248    type Error = ValidatePassError<H::Node, P::Error>;
249    type Result = P::Result;
250
251    fn run(&self, hugr: &mut H) -> Result<P::Result, Self::Error> {
252        self.validation_impl(hugr, |err, pretty_hugr| ValidatePassError::Input {
253            err: Box::new(err),
254            pretty_hugr,
255        })?;
256        let res = self.0.run(hugr)?;
257        self.validation_impl(hugr, |err, pretty_hugr| ValidatePassError::Output {
258            err: Box::new(err),
259            pretty_hugr,
260        })?;
261        Ok(res)
262    }
263
264    fn with_scope_internal(self, scope: impl Into<PassScope>) -> Self {
265        Self(self.0.with_scope_internal(scope), self.1)
266    }
267}
268
269// IfThen ------------------------------
270/// [`ComposablePass`] that executes a first pass that returns a `bool`
271/// result; and then, if-and-only-if that first result was true,
272/// executes a second pass
273pub struct IfThen<E, H, A, B>(A, B, PhantomData<(E, H)>);
274
275impl<
276    A: ComposablePass<H, Result = bool>,
277    B: ComposablePass<H>,
278    H: HugrMut,
279    E: ErrorCombiner<A::Error, B::Error>,
280> IfThen<E, H, A, B>
281{
282    /// Make a new instance given the [`ComposablePass`] to run first
283    /// and (maybe) second
284    pub fn new(fst: A, opt_snd: B) -> Self {
285        Self(fst, opt_snd, PhantomData)
286    }
287}
288
289impl<
290    A: ComposablePass<H, Result = bool>,
291    B: ComposablePass<H>,
292    H: HugrMut,
293    E: ErrorCombiner<A::Error, B::Error>,
294> ComposablePass<H> for IfThen<E, H, A, B>
295{
296    type Error = E;
297    type Result = Option<B::Result>;
298
299    fn run(&self, hugr: &mut H) -> Result<Self::Result, Self::Error> {
300        let res: bool = self.0.run(hugr).map_err(ErrorCombiner::from_first)?;
301        res.then(|| self.1.run(hugr).map_err(ErrorCombiner::from_second))
302            .transpose()
303    }
304
305    fn with_scope_internal(self, scope: impl Into<PassScope>) -> Self {
306        let scope = scope.into();
307        Self(
308            self.0.with_scope_internal(scope.clone()),
309            self.1.with_scope_internal(scope),
310            PhantomData,
311        )
312    }
313}
314
315// Note remove when deprecated constant_fold_pass / remove_dead_funcs are removed
316pub(crate) fn validate_if_test<P: ComposablePass<H>, H: HugrMut>(
317    pass: P,
318    hugr: &mut H,
319) -> Result<P::Result, ValidatePassError<H::Node, P::Error>> {
320    if cfg!(test) {
321        ValidatingPass::new(pass).run(hugr)
322    } else {
323        Ok(pass.run(hugr)?)
324    }
325}
326
// Tests for pass composition: sequencing with `then`, error mapping with
// `map_err`, `ValidatingPass`, and `IfThen`.
#[cfg(test)]
pub(crate) mod test {
    use hugr_core::ops::Value;
    use itertools::{Either, Itertools};

    use hugr_core::builder::{
        Dataflow, DataflowHugr, DataflowSubContainer, FunctionBuilder, HugrBuilder, ModuleBuilder,
    };
    use hugr_core::extension::prelude::{ConstUsize, MakeTuple, UnpackTuple, bool_t, usize_t};
    use hugr_core::hugr::hugrmut::HugrMut;
    use hugr_core::ops::{DFG, Input, OpType, Output, handle::NodeHandle};
    use hugr_core::std_extensions::arithmetic::int_types::INT_TYPES;
    use hugr_core::types::{Signature, TypeRow};
    use hugr_core::{Hugr, HugrView, IncomingPort, Node, NodeIndex};

    use crate::composable::WithScope;
    use crate::const_fold::{ConstFoldError, ConstantFoldPass};
    use crate::dead_code::DeadCodeElimError;
    use crate::untuple::UntupleResult;
    use crate::{DeadCodeElimPass, PassScope, ReplaceTypes, UntuplePass};

    use super::{ComposablePass, IfThen, ValidatePassError, ValidatingPass};

    /// Run `pass` wrapped in a `ValidatingPass`, so `hugr` is validated both
    /// before and after the pass runs.
    pub(crate) fn run_validating<P: ComposablePass<H>, H: HugrMut>(
        pass: P,
        hugr: &mut H,
    ) -> Result<P::Result, ValidatePassError<H::Node, P::Error>> {
        ValidatingPass::new(pass).run(hugr)
    }

    /// `then` reports the error of whichever pass fails, either as an `Either`,
    /// as a shared error type produced by `map_err`, or via an `Into` conversion.
    #[test]
    fn test_then() {
        // Build a module containing two identity functions on `usize`.
        let mut mb = ModuleBuilder::new();
        let id1 = mb
            .define_function("id1", Signature::new_endo(usize_t()))
            .unwrap();
        let inps = id1.input_wires();
        let id1 = id1.finish_with_outputs(inps).unwrap();
        let id2 = mb
            .define_function("id2", Signature::new_endo(usize_t()))
            .unwrap();
        let inps = id2.input_wires();
        let id2 = id2.finish_with_outputs(inps).unwrap();
        let hugr = mb.finish_hugr().unwrap();

        let c_usz = Value::from(ConstUsize::new(2));
        // One past the largest node index in the HUGR, so not a node of it.
        let not_a_node = Node::from(portgraph::NodeIndex::new(
            hugr.nodes().map(Node::index).max().unwrap() + 1,
        ));
        assert!(!hugr.contains_node(not_a_node));
        // DCE with a missing entry point; expected to fail with `NodeNotFound` below.
        let dce = DeadCodeElimPass::default().with_entry_points([not_a_node]);
        let cfold = ConstantFoldPass::default().with_inputs(id2.node(), [(0, c_usz.clone())]);

        // Constant folding on its own succeeds.
        cfold.run(&mut hugr.clone()).unwrap();

        // DCE runs first and fails, so its error is the `Left` variant.
        let dce_err = DeadCodeElimError::NodeNotFound(not_a_node);
        let r: Result<_, Either<DeadCodeElimError, ConstFoldError>> =
            dce.clone().then(cfold.clone()).run(&mut hugr.clone());
        assert_eq!(r, Err(Either::Left(dce_err.clone())));

        // Constant folding runs first without error, then DCE fails: `Right` variant.
        let r: Result<_, Either<_, _>> = cfold
            .clone()
            .with_inputs(id1.node(), [(0, c_usz)])
            .then(dce.clone())
            .run(&mut hugr.clone());
        assert_eq!(r, Err(Either::Right(dce_err)));

        // Avoid wrapping in Either by mapping both to same Error
        let r = dce
            .map_err(|e| match e {
                DeadCodeElimError::NodeNotFound(node) => ConstFoldError::MissingEntryPoint { node },
            })
            .then(cfold.clone())
            .run(&mut hugr.clone());
        assert_eq!(
            r,
            Err(ConstFoldError::MissingEntryPoint { node: not_a_node })
        );

        // Or where second supports Into first
        // (`ConstFoldError` converts into `ValidatePassError` via its `From` impl).
        let v = ValidatingPass::new(cfold.clone());
        let r: Result<_, ValidatePassError<Node, ConstFoldError>> =
            v.then(cfold).run(&mut hugr.clone());
        r.unwrap();
    }

    /// `ValidatingPass` reports an invalid input HUGR as `ValidatePassError::Input`.
    #[test]
    fn test_validation() {
        // A DFG from `usize` to `bool` whose Input is wired straight to its
        // Output. The resulting HUGR fails validation (checked below),
        // presumably because of the usize/bool mismatch.
        let mut h = Hugr::new_with_entrypoint(DFG {
            signature: Signature::new(usize_t(), bool_t()),
        })
        .unwrap();
        let inp = h.add_node_with_parent(
            h.entrypoint(),
            Input {
                types: usize_t().into(),
            },
        );
        let outp = h.add_node_with_parent(
            h.entrypoint(),
            Output {
                types: bool_t().into(),
            },
        );
        h.connect(inp, 0, outp, 0);
        let backup = h.clone();
        let err = backup.validate().unwrap_err();

        // Without validation, constant folding succeeds and leaves the HUGR unchanged.
        let no_inputs: [(IncomingPort, _); 0] = [];
        let cfold = ConstantFoldPass::default().with_inputs(backup.entrypoint(), no_inputs);
        cfold.run(&mut h).unwrap();
        assert_eq!(h, backup); // Did nothing

        // With validation, the same validation error is reported against the input.
        let r = ValidatingPass::new(cfold).run(&mut h);
        assert!(matches!(r, Err(ValidatePassError::Input { err: e, .. }) if *e == err));
    }

    /// `IfThen` runs the second pass only when the first one returns `true`.
    #[test]
    fn test_if_then() {
        let tr = TypeRow::from(vec![usize_t(); 2]);

        // A function on two `usize`s that packs them into a tuple and
        // immediately unpacks it again.
        let h = {
            let sig = Signature::new_endo(tr.clone());
            let mut fb = FunctionBuilder::new("tupuntup", sig).unwrap();
            let [a, b] = fb.input_wires_arr();
            let tup = fb
                .add_dataflow_op(MakeTuple::new(tr.clone()), [a, b])
                .unwrap();
            let untup = fb
                .add_dataflow_op(UnpackTuple::new(tr.clone()), tup.outputs())
                .unwrap();
            fb.finish_hugr_with_outputs(untup.outputs()).unwrap()
        };

        let untup = UntuplePass::default().with_scope(PassScope::EntrypointRecursive);
        {
            // Change usize_t to INT_TYPES[6], and if that did anything (it will!), then Untuple
            let mut repl = ReplaceTypes::default();
            let usize_custom_t = usize_t().as_extension().unwrap().clone();
            repl.set_replace_type(usize_custom_t, INT_TYPES[6].clone());
            let ifthen = IfThen::<Either<_, _>, _, _, _>::new(repl, untup.clone());

            let mut h = h.clone();
            let r = run_validating(ifthen, &mut h).unwrap();
            assert_eq!(
                r,
                Some(UntupleResult {
                    rewrites_applied: 1
                })
            );
            // Only two children remain under the entrypoint, and both outputs of
            // the first are wired to the second: the tuple round-trip is gone.
            let [tuple_in, tuple_out] = h.children(h.entrypoint()).collect_array().unwrap();
            assert_eq!(h.output_neighbours(tuple_in).collect_vec(), [tuple_out; 2]);
        }

        // Change INT_TYPES[5] to INT_TYPES[6]; that won't do anything, so don't Untuple
        let mut repl = ReplaceTypes::default();
        let i32_custom_t = INT_TYPES[5].as_extension().unwrap().clone();
        repl.set_replace_type(i32_custom_t, INT_TYPES[6].clone());
        let ifthen = IfThen::<Either<_, _>, _, _, _>::new(repl, untup);
        let mut h = h;
        let r = run_validating(ifthen, &mut h).unwrap();
        assert_eq!(r, None);
        // Untupling was skipped: all four children remain, and the first
        // child still feeds into the MakeTuple.
        assert_eq!(h.children(h.entrypoint()).count(), 4);
        let mktup = h
            .output_neighbours(h.first_child(h.entrypoint()).unwrap())
            .next()
            .unwrap();
        assert_eq!(h.get_optype(mktup), &OpType::from(MakeTuple::new(tr)));
    }
}