hugr_passes/
composable.rs

1//! Compiler passes and utilities for composing them
2
3use std::{error::Error, marker::PhantomData};
4
5use hugr_core::core::HugrNode;
6use hugr_core::hugr::{ValidationError, hugrmut::HugrMut};
7use itertools::Either;
8
9/// An optimization pass that can be sequenced with another and/or wrapped
10/// e.g. by [`ValidatingPass`]
11pub trait ComposablePass<H: HugrMut>: Sized {
12    type Error: Error;
13    type Result; // Would like to default to () but currently unstable
14
15    fn run(&self, hugr: &mut H) -> Result<Self::Result, Self::Error>;
16
17    fn map_err<E2: Error>(
18        self,
19        f: impl Fn(Self::Error) -> E2,
20    ) -> impl ComposablePass<H, Result = Self::Result, Error = E2> {
21        ErrMapper::new(self, f)
22    }
23
24    /// Returns a [`ComposablePass`] that does "`self` then `other`", so long as
25    /// `other::Err` can be combined with ours.
26    fn then<P: ComposablePass<H>, E: ErrorCombiner<Self::Error, P::Error>>(
27        self,
28        other: P,
29    ) -> impl ComposablePass<H, Result = (Self::Result, P::Result), Error = E> {
30        struct Sequence<E, P1, P2>(P1, P2, PhantomData<E>);
31        impl<H, E, P1, P2> ComposablePass<H> for Sequence<E, P1, P2>
32        where
33            H: HugrMut,
34            P1: ComposablePass<H>,
35            P2: ComposablePass<H>,
36            E: ErrorCombiner<P1::Error, P2::Error>,
37        {
38            type Error = E;
39            type Result = (P1::Result, P2::Result);
40
41            fn run(&self, hugr: &mut H) -> Result<Self::Result, Self::Error> {
42                let res1 = self.0.run(hugr).map_err(E::from_first)?;
43                let res2 = self.1.run(hugr).map_err(E::from_second)?;
44                Ok((res1, res2))
45            }
46        }
47
48        Sequence(self, other, PhantomData)
49    }
50}
51
/// Merges the error types of two passes into one error type.
///
/// `A` is the error type of the pass that runs first and `B` that of the pass
/// that runs second; implementors say how each is turned into `Self`.
pub trait ErrorCombiner<A, B>: Error {
    /// Converts an error raised by the first pass.
    fn from_first(first: A) -> Self;
    /// Converts an error raised by the second pass.
    fn from_second(second: B) -> Self;
}
58
59impl<A: Error, B: Into<A>> ErrorCombiner<A, B> for A {
60    fn from_first(a: A) -> Self {
61        a
62    }
63
64    fn from_second(b: B) -> Self {
65        b.into()
66    }
67}
68
69impl<A: Error, B: Error> ErrorCombiner<A, B> for Either<A, B> {
70    fn from_first(a: A) -> Self {
71        Either::Left(a)
72    }
73
74    fn from_second(b: B) -> Self {
75        Either::Right(b)
76    }
77}
78
// Note: ideally we would also provide these two impls:
//   impl<E: Error> ErrorCombiner<Infallible, E> for E
//   impl<E: Error> ErrorCombiner<E, Infallible> for E
// However, they are not possible, because they conflict with
//   impl<A, B: Into<A>> ErrorCombiner<A, B> for A
// when A = E = Infallible.
// This will become possible, indeed automatic, once `Infallible` is replaced
// by `!` (never_type), since (unlike `Infallible`) `!` converts `Into` anything.
87
88// ErrMapper ------------------------------
89struct ErrMapper<P, H, E, F>(P, F, PhantomData<(E, H)>);
90
91impl<H: HugrMut, P: ComposablePass<H>, E: Error, F: Fn(P::Error) -> E> ErrMapper<P, H, E, F> {
92    fn new(pass: P, err_fn: F) -> Self {
93        Self(pass, err_fn, PhantomData)
94    }
95}
96
97impl<P: ComposablePass<H>, H: HugrMut, E: Error, F: Fn(P::Error) -> E> ComposablePass<H>
98    for ErrMapper<P, H, E, F>
99{
100    type Error = E;
101    type Result = P::Result;
102
103    fn run(&self, hugr: &mut H) -> Result<P::Result, Self::Error> {
104        self.0.run(hugr).map_err(&self.1)
105    }
106}
107
108// ValidatingPass ------------------------------
109
110/// Error from a [`ValidatingPass`]
111#[derive(thiserror::Error, Debug)]
112pub enum ValidatePassError<N, E>
113where
114    N: HugrNode + 'static,
115{
116    #[error("Failed to validate input HUGR: {err}\n{pretty_hugr}")]
117    Input {
118        #[source]
119        err: ValidationError<N>,
120        pretty_hugr: String,
121    },
122    #[error("Failed to validate output HUGR: {err}\n{pretty_hugr}")]
123    Output {
124        #[source]
125        err: ValidationError<N>,
126        pretty_hugr: String,
127    },
128    #[error(transparent)]
129    Underlying(#[from] E),
130}
131
132/// Runs an underlying pass, but with validation of the Hugr
133/// both before and afterwards.
134pub struct ValidatingPass<P, H>(P, PhantomData<H>);
135
136impl<P: ComposablePass<H>, H: HugrMut> ValidatingPass<P, H> {
137    pub fn new(underlying: P) -> Self {
138        Self(underlying, PhantomData)
139    }
140
141    fn validation_impl<E>(
142        &self,
143        hugr: &H,
144        mk_err: impl FnOnce(ValidationError<H::Node>, String) -> ValidatePassError<H::Node, E>,
145    ) -> Result<(), ValidatePassError<H::Node, E>> {
146        hugr.validate()
147            .map_err(|err| mk_err(err, hugr.mermaid_string()))
148    }
149}
150
151impl<P: ComposablePass<H>, H: HugrMut> ComposablePass<H> for ValidatingPass<P, H>
152where
153    H::Node: 'static,
154{
155    type Error = ValidatePassError<H::Node, P::Error>;
156    type Result = P::Result;
157
158    fn run(&self, hugr: &mut H) -> Result<P::Result, Self::Error> {
159        self.validation_impl(hugr, |err, pretty_hugr| ValidatePassError::Input {
160            err,
161            pretty_hugr,
162        })?;
163        let res = self.0.run(hugr).map_err(ValidatePassError::Underlying)?;
164        self.validation_impl(hugr, |err, pretty_hugr| ValidatePassError::Output {
165            err,
166            pretty_hugr,
167        })?;
168        Ok(res)
169    }
170}
171
172// IfThen ------------------------------
173/// [`ComposablePass`] that executes a first pass that returns a `bool`
174/// result; and then, if-and-only-if that first result was true,
175/// executes a second pass
176pub struct IfThen<E, H, A, B>(A, B, PhantomData<(E, H)>);
177
178impl<
179    A: ComposablePass<H, Result = bool>,
180    B: ComposablePass<H>,
181    H: HugrMut,
182    E: ErrorCombiner<A::Error, B::Error>,
183> IfThen<E, H, A, B>
184{
185    /// Make a new instance given the [`ComposablePass`] to run first
186    /// and (maybe) second
187    pub fn new(fst: A, opt_snd: B) -> Self {
188        Self(fst, opt_snd, PhantomData)
189    }
190}
191
192impl<
193    A: ComposablePass<H, Result = bool>,
194    B: ComposablePass<H>,
195    H: HugrMut,
196    E: ErrorCombiner<A::Error, B::Error>,
197> ComposablePass<H> for IfThen<E, H, A, B>
198{
199    type Error = E;
200    type Result = Option<B::Result>;
201
202    fn run(&self, hugr: &mut H) -> Result<Self::Result, Self::Error> {
203        let res: bool = self.0.run(hugr).map_err(ErrorCombiner::from_first)?;
204        res.then(|| self.1.run(hugr).map_err(ErrorCombiner::from_second))
205            .transpose()
206    }
207}
208
209pub(crate) fn validate_if_test<P: ComposablePass<H>, H: HugrMut>(
210    pass: P,
211    hugr: &mut H,
212) -> Result<P::Result, ValidatePassError<H::Node, P::Error>> {
213    if cfg!(test) {
214        ValidatingPass::new(pass).run(hugr)
215    } else {
216        pass.run(hugr).map_err(ValidatePassError::Underlying)
217    }
218}
219
/// Tests for sequencing (`then`), error mapping, validation and `IfThen`.
#[cfg(test)]
mod test {
    use itertools::{Either, Itertools};
    use std::convert::Infallible;

    use hugr_core::builder::{
        Container, Dataflow, DataflowHugr, DataflowSubContainer, FunctionBuilder, HugrBuilder,
        ModuleBuilder,
    };
    use hugr_core::extension::prelude::{ConstUsize, MakeTuple, UnpackTuple, bool_t, usize_t};
    use hugr_core::hugr::hugrmut::HugrMut;
    use hugr_core::ops::{DFG, Input, OpType, Output, handle::NodeHandle};
    use hugr_core::std_extensions::arithmetic::int_types::INT_TYPES;
    use hugr_core::types::{Signature, TypeRow};
    use hugr_core::{Hugr, HugrView, IncomingPort};

    use crate::const_fold::{ConstFoldError, ConstantFoldPass};
    use crate::untuple::{UntupleRecursive, UntupleResult};
    use crate::{DeadCodeElimPass, ReplaceTypes, UntuplePass};

    use super::{ComposablePass, IfThen, ValidatePassError, ValidatingPass, validate_if_test};

    /// Sequencing DCE and constant folding with `then`, combining their errors
    /// either as an `Either` or (after `map_err`) as a single error type.
    #[test]
    fn test_then() {
        // Module with two identity functions on usize: `id1` and `id2`.
        let mut mb = ModuleBuilder::new();
        let id1 = mb
            .define_function("id1", Signature::new_endo(usize_t()))
            .unwrap();
        let inps = id1.input_wires();
        let id1 = id1.finish_with_outputs(inps).unwrap();
        let id2 = mb
            .define_function("id2", Signature::new_endo(usize_t()))
            .unwrap();
        let inps = id2.input_wires();
        let id2 = id2.finish_with_outputs(inps).unwrap();
        let hugr = mb.finish_hugr().unwrap();

        // DCE keeps only what is reachable from `id1`; constant folding targets `id2`.
        let dce = DeadCodeElimPass::default().with_entry_points([id1.node()]);
        let cfold =
            ConstantFoldPass::default().with_inputs(id2.node(), [(0, ConstUsize::new(2).into())]);

        // On its own, constant folding succeeds.
        cfold.run(&mut hugr.clone()).unwrap();

        // After DCE, constant folding reports `id2` as a missing entry point
        // (presumably because DCE removed `id2`). DCE's `Infallible` error and
        // the folding error are kept apart in an `Either`.
        let exp_err = ConstFoldError::MissingEntryPoint { node: id2.node() };
        let r: Result<_, Either<Infallible, ConstFoldError>> =
            dce.clone().then(cfold.clone()).run(&mut hugr.clone());
        assert_eq!(r, Err(Either::Right(exp_err.clone())));

        // Mapping DCE's `Infallible` error (vacuously) into `ConstFoldError`
        // lets the sequence use the single-error-type `ErrorCombiner` impl.
        let r = dce
            .clone()
            .map_err(|inf| match inf {})
            .then(cfold.clone())
            .run(&mut hugr.clone());
        assert_eq!(r, Err(exp_err));

        // In the other order, both passes succeed.
        let r2: Result<_, Either<_, _>> = cfold.then(dce).run(&mut hugr.clone());
        r2.unwrap();
    }

    /// `ValidatingPass` reports an `Input` error for a Hugr that is invalid
    /// before the underlying pass runs.
    #[test]
    fn test_validation() {
        // DFG declared as usize -> bool, whose usize input is wired straight to
        // its bool output: a type mismatch that fails validation.
        let mut h = Hugr::new_with_entrypoint(DFG {
            signature: Signature::new(usize_t(), bool_t()),
        })
        .unwrap();
        let inp = h.add_node_with_parent(
            h.entrypoint(),
            Input {
                types: usize_t().into(),
            },
        );
        let outp = h.add_node_with_parent(
            h.entrypoint(),
            Output {
                types: bool_t().into(),
            },
        );
        h.connect(inp, 0, outp, 0);
        let backup = h.clone();
        let err = backup.validate().unwrap_err();

        // Without validation, constant folding with no inputs runs fine.
        let no_inputs: [(IncomingPort, _); 0] = [];
        let cfold = ConstantFoldPass::default().with_inputs(backup.entrypoint(), no_inputs);
        cfold.run(&mut h).unwrap();
        assert_eq!(h, backup); // Did nothing

        // With validation, the same pass is rejected with the original validation error.
        let r = ValidatingPass::new(cfold).run(&mut h);
        assert!(matches!(r, Err(ValidatePassError::Input { err: e, .. }) if e == err));
    }

    /// `IfThen` runs the second pass only when the first returns `true`.
    #[test]
    fn test_if_then() {
        let tr = TypeRow::from(vec![usize_t(); 2]);

        // Function that packs its two usize inputs into a tuple and
        // immediately unpacks it again.
        let h = {
            let sig = Signature::new_endo(tr.clone());
            let mut fb = FunctionBuilder::new("tupuntup", sig).unwrap();
            let [a, b] = fb.input_wires_arr();
            let tup = fb
                .add_dataflow_op(MakeTuple::new(tr.clone()), [a, b])
                .unwrap();
            let untup = fb
                .add_dataflow_op(UnpackTuple::new(tr.clone()), tup.outputs())
                .unwrap();
            fb.finish_hugr_with_outputs(untup.outputs()).unwrap()
        };

        let untup = UntuplePass::new(UntupleRecursive::Recursive);
        {
            // Change usize_t to INT_TYPES[6], and if that did anything (it will!), then Untuple
            let mut repl = ReplaceTypes::default();
            let usize_custom_t = usize_t().as_extension().unwrap().clone();
            repl.replace_type(usize_custom_t, INT_TYPES[6].clone());
            let ifthen = IfThen::<Either<_, _>, _, _, _>::new(repl, untup.clone());

            // Work on a copy so the original `h` is available for the second case.
            let mut h = h.clone();
            let r = validate_if_test(ifthen, &mut h).unwrap();
            assert_eq!(
                r,
                Some(UntupleResult {
                    rewrites_applied: 1
                })
            );
            // Only two children remain (presumably Input and Output), with both
            // outputs of the first wired directly to the second.
            let [tuple_in, tuple_out] = h.children(h.entrypoint()).collect_array().unwrap();
            assert_eq!(h.output_neighbours(tuple_in).collect_vec(), [tuple_out; 2]);
        }

        // Change INT_TYPES[5] to INT_TYPES[6]; that won't do anything, so don't Untuple
        let mut repl = ReplaceTypes::default();
        let i32_custom_t = INT_TYPES[5].as_extension().unwrap().clone();
        repl.replace_type(i32_custom_t, INT_TYPES[6].clone());
        let ifthen = IfThen::<Either<_, _>, _, _, _>::new(repl, untup);
        let mut h = h;
        let r = validate_if_test(ifthen, &mut h).unwrap();
        assert_eq!(r, None);
        // All four children are still present, and the node fed by the first
        // child (presumably Input) is still the original MakeTuple over usize.
        assert_eq!(h.children(h.entrypoint()).count(), 4);
        let mktup = h
            .output_neighbours(h.first_child(h.entrypoint()).unwrap())
            .next()
            .unwrap();
        assert_eq!(h.get_optype(mktup), &OpType::from(MakeTuple::new(tr)));
    }
}
362}