hugr_passes/
composable.rs

1//! Compiler passes and utilities for composing them
2
3use std::{error::Error, marker::PhantomData};
4
5use hugr_core::core::HugrNode;
6use hugr_core::hugr::{ValidationError, hugrmut::HugrMut};
7use itertools::Either;
8
9/// An optimization pass that can be sequenced with another and/or wrapped
10/// e.g. by [`ValidatingPass`]
11pub trait ComposablePass<H: HugrMut>: Sized {
12    /// Error thrown by this pass.
13    type Error: Error;
14    /// Result returned by this pass.
15    type Result; // Would like to default to () but currently unstable
16
17    /// Run the pass on the given HUGR.
18    fn run(&self, hugr: &mut H) -> Result<Self::Result, Self::Error>;
19
20    /// Apply a function to the error type of this pass, returning a new
21    /// [`ComposablePass`] that has the same result type.
22    fn map_err<E2: Error>(
23        self,
24        f: impl Fn(Self::Error) -> E2,
25    ) -> impl ComposablePass<H, Result = Self::Result, Error = E2> {
26        ErrMapper::new(self, f)
27    }
28
29    /// Returns a [`ComposablePass`] that does "`self` then `other`", so long as
30    /// `other::Err` can be combined with ours.
31    fn then<P: ComposablePass<H>, E: ErrorCombiner<Self::Error, P::Error>>(
32        self,
33        other: P,
34    ) -> impl ComposablePass<H, Result = (Self::Result, P::Result), Error = E> {
35        struct Sequence<E, P1, P2>(P1, P2, PhantomData<E>);
36        impl<H, E, P1, P2> ComposablePass<H> for Sequence<E, P1, P2>
37        where
38            H: HugrMut,
39            P1: ComposablePass<H>,
40            P2: ComposablePass<H>,
41            E: ErrorCombiner<P1::Error, P2::Error>,
42        {
43            type Error = E;
44            type Result = (P1::Result, P2::Result);
45
46            fn run(&self, hugr: &mut H) -> Result<Self::Result, Self::Error> {
47                let res1 = self.0.run(hugr).map_err(E::from_first)?;
48                let res2 = self.1.run(hugr).map_err(E::from_second)?;
49                Ok((res1, res2))
50            }
51        }
52
53        Sequence(self, other, PhantomData)
54    }
55}
56
/// Describes how the errors of two different passes are merged into a
/// single error type: `A` is the first pass's error, `B` the second's.
pub trait ErrorCombiner<A, B>: Error {
    /// Builds the combined error from an error raised by the first pass.
    fn from_first(first_err: A) -> Self;
    /// Builds the combined error from an error raised by the second pass.
    fn from_second(second_err: B) -> Self;
}
65
66impl<A: Error, B: Into<A>> ErrorCombiner<A, B> for A {
67    fn from_first(a: A) -> Self {
68        a
69    }
70
71    fn from_second(b: B) -> Self {
72        b.into()
73    }
74}
75
76impl<A: Error, B: Error> ErrorCombiner<A, B> for Either<A, B> {
77    fn from_first(a: A) -> Self {
78        Either::Left(a)
79    }
80
81    fn from_second(b: B) -> Self {
82        Either::Right(b)
83    }
84}
85
86// Note: in the short term we could wish for two more impls:
87//   impl<E:Error> ErrorCombiner<Infallible, E> for E
88//   impl<E:Error> ErrorCombiner<E, Infallible> for E
89// however, these aren't possible as they conflict with
90//   impl<A, B:Into<A>> ErrorCombiner<A,B> for A
91// when A=E=Infallible, boo :-(.
92// However this will become possible, indeed automatic, when Infallible is replaced
93// by ! (never_type) as (unlike Infallible) ! converts Into anything
94
95// ErrMapper ------------------------------
96struct ErrMapper<P, H, E, F>(P, F, PhantomData<(E, H)>);
97
98impl<H: HugrMut, P: ComposablePass<H>, E: Error, F: Fn(P::Error) -> E> ErrMapper<P, H, E, F> {
99    fn new(pass: P, err_fn: F) -> Self {
100        Self(pass, err_fn, PhantomData)
101    }
102}
103
104impl<P: ComposablePass<H>, H: HugrMut, E: Error, F: Fn(P::Error) -> E> ComposablePass<H>
105    for ErrMapper<P, H, E, F>
106{
107    type Error = E;
108    type Result = P::Result;
109
110    fn run(&self, hugr: &mut H) -> Result<P::Result, Self::Error> {
111        self.0.run(hugr).map_err(&self.1)
112    }
113}
114
115// ValidatingPass ------------------------------
116
117/// Error from a [`ValidatingPass`]
118#[derive(thiserror::Error, Debug)]
119pub enum ValidatePassError<N, E>
120where
121    N: HugrNode + 'static,
122{
123    /// Validation failed on the initial HUGR.
124    #[error("Failed to validate input HUGR: {err}\n{pretty_hugr}")]
125    Input {
126        /// The validation error that occurred.
127        #[source]
128        err: Box<ValidationError<N>>,
129        /// A pretty-printed representation of the HUGR that failed validation.
130        pretty_hugr: String,
131    },
132    /// Validation failed on the final HUGR.
133    #[error("Failed to validate output HUGR: {err}\n{pretty_hugr}")]
134    Output {
135        /// The validation error that occurred.
136        #[source]
137        err: Box<ValidationError<N>>,
138        /// A pretty-printed representation of the HUGR that failed validation.
139        pretty_hugr: String,
140    },
141    /// An error from the underlying pass.
142    #[error(transparent)]
143    Underlying(Box<E>),
144}
145
146impl<N: HugrNode, E> From<E> for ValidatePassError<N, E> {
147    fn from(err: E) -> Self {
148        Self::Underlying(Box::new(err))
149    }
150}
151
152/// Runs an underlying pass, but with validation of the Hugr
153/// both before and afterwards.
154pub struct ValidatingPass<P, H>(P, PhantomData<H>);
155
156impl<P: ComposablePass<H>, H: HugrMut> ValidatingPass<P, H> {
157    /// Return a new [`ValidatingPass`] that wraps the given underlying pass.
158    pub fn new(underlying: P) -> Self {
159        Self(underlying, PhantomData)
160    }
161
162    fn validation_impl<E>(
163        &self,
164        hugr: &H,
165        mk_err: impl FnOnce(ValidationError<H::Node>, String) -> ValidatePassError<H::Node, E>,
166    ) -> Result<(), ValidatePassError<H::Node, E>> {
167        hugr.validate()
168            .map_err(|err| mk_err(err, hugr.mermaid_string()))
169    }
170}
171
172impl<P: ComposablePass<H>, H: HugrMut> ComposablePass<H> for ValidatingPass<P, H>
173where
174    H::Node: 'static,
175{
176    type Error = ValidatePassError<H::Node, P::Error>;
177    type Result = P::Result;
178
179    fn run(&self, hugr: &mut H) -> Result<P::Result, Self::Error> {
180        self.validation_impl(hugr, |err, pretty_hugr| ValidatePassError::Input {
181            err: Box::new(err),
182            pretty_hugr,
183        })?;
184        let res = self.0.run(hugr)?;
185        self.validation_impl(hugr, |err, pretty_hugr| ValidatePassError::Output {
186            err: Box::new(err),
187            pretty_hugr,
188        })?;
189        Ok(res)
190    }
191}
192
193// IfThen ------------------------------
194/// [`ComposablePass`] that executes a first pass that returns a `bool`
195/// result; and then, if-and-only-if that first result was true,
196/// executes a second pass
197pub struct IfThen<E, H, A, B>(A, B, PhantomData<(E, H)>);
198
199impl<
200    A: ComposablePass<H, Result = bool>,
201    B: ComposablePass<H>,
202    H: HugrMut,
203    E: ErrorCombiner<A::Error, B::Error>,
204> IfThen<E, H, A, B>
205{
206    /// Make a new instance given the [`ComposablePass`] to run first
207    /// and (maybe) second
208    pub fn new(fst: A, opt_snd: B) -> Self {
209        Self(fst, opt_snd, PhantomData)
210    }
211}
212
213impl<
214    A: ComposablePass<H, Result = bool>,
215    B: ComposablePass<H>,
216    H: HugrMut,
217    E: ErrorCombiner<A::Error, B::Error>,
218> ComposablePass<H> for IfThen<E, H, A, B>
219{
220    type Error = E;
221    type Result = Option<B::Result>;
222
223    fn run(&self, hugr: &mut H) -> Result<Self::Result, Self::Error> {
224        let res: bool = self.0.run(hugr).map_err(ErrorCombiner::from_first)?;
225        res.then(|| self.1.run(hugr).map_err(ErrorCombiner::from_second))
226            .transpose()
227    }
228}
229
230pub(crate) fn validate_if_test<P: ComposablePass<H>, H: HugrMut>(
231    pass: P,
232    hugr: &mut H,
233) -> Result<P::Result, ValidatePassError<H::Node, P::Error>> {
234    if cfg!(test) {
235        ValidatingPass::new(pass).run(hugr)
236    } else {
237        Ok(pass.run(hugr)?)
238    }
239}
240
241#[cfg(test)]
242mod test {
243    use hugr_core::ops::Value;
244    use itertools::{Either, Itertools};
245
246    use hugr_core::builder::{
247        Dataflow, DataflowHugr, DataflowSubContainer, FunctionBuilder, HugrBuilder, ModuleBuilder,
248    };
249    use hugr_core::extension::prelude::{ConstUsize, MakeTuple, UnpackTuple, bool_t, usize_t};
250    use hugr_core::hugr::hugrmut::HugrMut;
251    use hugr_core::ops::{DFG, Input, OpType, Output, handle::NodeHandle};
252    use hugr_core::std_extensions::arithmetic::int_types::INT_TYPES;
253    use hugr_core::types::{Signature, TypeRow};
254    use hugr_core::{Hugr, HugrView, IncomingPort, Node, NodeIndex};
255
256    use crate::const_fold::{ConstFoldError, ConstantFoldPass};
257    use crate::dead_code::DeadCodeElimError;
258    use crate::untuple::{UntupleRecursive, UntupleResult};
259    use crate::{DeadCodeElimPass, ReplaceTypes, UntuplePass};
260
261    use super::{ComposablePass, IfThen, ValidatePassError, ValidatingPass, validate_if_test};
262
263    #[test]
264    fn test_then() {
265        let mut mb = ModuleBuilder::new();
266        let id1 = mb
267            .define_function("id1", Signature::new_endo(usize_t()))
268            .unwrap();
269        let inps = id1.input_wires();
270        let id1 = id1.finish_with_outputs(inps).unwrap();
271        let id2 = mb
272            .define_function("id2", Signature::new_endo(usize_t()))
273            .unwrap();
274        let inps = id2.input_wires();
275        let id2 = id2.finish_with_outputs(inps).unwrap();
276        let hugr = mb.finish_hugr().unwrap();
277
278        let c_usz = Value::from(ConstUsize::new(2));
279        let not_a_node = Node::from(portgraph::NodeIndex::new(
280            hugr.nodes().map(Node::index).max().unwrap() + 1,
281        ));
282        assert!(!hugr.contains_node(not_a_node));
283        let dce = DeadCodeElimPass::default().with_entry_points([not_a_node]);
284        let cfold = ConstantFoldPass::default().with_inputs(id2.node(), [(0, c_usz.clone())]);
285
286        cfold.run(&mut hugr.clone()).unwrap();
287
288        let dce_err = DeadCodeElimError::NodeNotFound(not_a_node);
289        let r: Result<_, Either<DeadCodeElimError, ConstFoldError>> =
290            dce.clone().then(cfold.clone()).run(&mut hugr.clone());
291        assert_eq!(r, Err(Either::Left(dce_err.clone())));
292
293        let r: Result<_, Either<_, _>> = cfold
294            .clone()
295            .with_inputs(id1.node(), [(0, c_usz)])
296            .then(dce.clone())
297            .run(&mut hugr.clone());
298        assert_eq!(r, Err(Either::Right(dce_err)));
299
300        // Avoid wrapping in Either by mapping both to same Error
301        let r = dce
302            .map_err(|e| match e {
303                DeadCodeElimError::NodeNotFound(node) => ConstFoldError::MissingEntryPoint { node },
304            })
305            .then(cfold.clone())
306            .run(&mut hugr.clone());
307        assert_eq!(
308            r,
309            Err(ConstFoldError::MissingEntryPoint { node: not_a_node })
310        );
311
312        // Or where second supports Into first
313        let v = ValidatingPass::new(cfold.clone());
314        let r: Result<_, ValidatePassError<Node, ConstFoldError>> =
315            v.then(cfold).run(&mut hugr.clone());
316        r.unwrap();
317    }
318
319    #[test]
320    fn test_validation() {
321        let mut h = Hugr::new_with_entrypoint(DFG {
322            signature: Signature::new(usize_t(), bool_t()),
323        })
324        .unwrap();
325        let inp = h.add_node_with_parent(
326            h.entrypoint(),
327            Input {
328                types: usize_t().into(),
329            },
330        );
331        let outp = h.add_node_with_parent(
332            h.entrypoint(),
333            Output {
334                types: bool_t().into(),
335            },
336        );
337        h.connect(inp, 0, outp, 0);
338        let backup = h.clone();
339        let err = backup.validate().unwrap_err();
340
341        let no_inputs: [(IncomingPort, _); 0] = [];
342        let cfold = ConstantFoldPass::default().with_inputs(backup.entrypoint(), no_inputs);
343        cfold.run(&mut h).unwrap();
344        assert_eq!(h, backup); // Did nothing
345
346        let r = ValidatingPass::new(cfold).run(&mut h);
347        assert!(matches!(r, Err(ValidatePassError::Input { err: e, .. }) if *e == err));
348    }
349
350    #[test]
351    fn test_if_then() {
352        let tr = TypeRow::from(vec![usize_t(); 2]);
353
354        let h = {
355            let sig = Signature::new_endo(tr.clone());
356            let mut fb = FunctionBuilder::new("tupuntup", sig).unwrap();
357            let [a, b] = fb.input_wires_arr();
358            let tup = fb
359                .add_dataflow_op(MakeTuple::new(tr.clone()), [a, b])
360                .unwrap();
361            let untup = fb
362                .add_dataflow_op(UnpackTuple::new(tr.clone()), tup.outputs())
363                .unwrap();
364            fb.finish_hugr_with_outputs(untup.outputs()).unwrap()
365        };
366
367        let untup = UntuplePass::new(UntupleRecursive::Recursive);
368        {
369            // Change usize_t to INT_TYPES[6], and if that did anything (it will!), then Untuple
370            let mut repl = ReplaceTypes::default();
371            let usize_custom_t = usize_t().as_extension().unwrap().clone();
372            repl.replace_type(usize_custom_t, INT_TYPES[6].clone());
373            let ifthen = IfThen::<Either<_, _>, _, _, _>::new(repl, untup.clone());
374
375            let mut h = h.clone();
376            let r = validate_if_test(ifthen, &mut h).unwrap();
377            assert_eq!(
378                r,
379                Some(UntupleResult {
380                    rewrites_applied: 1
381                })
382            );
383            let [tuple_in, tuple_out] = h.children(h.entrypoint()).collect_array().unwrap();
384            assert_eq!(h.output_neighbours(tuple_in).collect_vec(), [tuple_out; 2]);
385        }
386
387        // Change INT_TYPES[5] to INT_TYPES[6]; that won't do anything, so don't Untuple
388        let mut repl = ReplaceTypes::default();
389        let i32_custom_t = INT_TYPES[5].as_extension().unwrap().clone();
390        repl.replace_type(i32_custom_t, INT_TYPES[6].clone());
391        let ifthen = IfThen::<Either<_, _>, _, _, _>::new(repl, untup);
392        let mut h = h;
393        let r = validate_if_test(ifthen, &mut h).unwrap();
394        assert_eq!(r, None);
395        assert_eq!(h.children(h.entrypoint()).count(), 4);
396        let mktup = h
397            .output_neighbours(h.first_child(h.entrypoint()).unwrap())
398            .next()
399            .unwrap();
400        assert_eq!(h.get_optype(mktup), &OpType::from(MakeTuple::new(tr)));
401    }
402}