1use std::{error::Error, marker::PhantomData};
4
5use hugr_core::core::HugrNode;
6use hugr_core::hugr::{ValidationError, hugrmut::HugrMut};
7use itertools::Either;
8
9pub trait ComposablePass<H: HugrMut>: Sized {
12 type Error: Error;
13 type Result; fn run(&self, hugr: &mut H) -> Result<Self::Result, Self::Error>;
16
17 fn map_err<E2: Error>(
18 self,
19 f: impl Fn(Self::Error) -> E2,
20 ) -> impl ComposablePass<H, Result = Self::Result, Error = E2> {
21 ErrMapper::new(self, f)
22 }
23
24 fn then<P: ComposablePass<H>, E: ErrorCombiner<Self::Error, P::Error>>(
27 self,
28 other: P,
29 ) -> impl ComposablePass<H, Result = (Self::Result, P::Result), Error = E> {
30 struct Sequence<E, P1, P2>(P1, P2, PhantomData<E>);
31 impl<H, E, P1, P2> ComposablePass<H> for Sequence<E, P1, P2>
32 where
33 H: HugrMut,
34 P1: ComposablePass<H>,
35 P2: ComposablePass<H>,
36 E: ErrorCombiner<P1::Error, P2::Error>,
37 {
38 type Error = E;
39 type Result = (P1::Result, P2::Result);
40
41 fn run(&self, hugr: &mut H) -> Result<Self::Result, Self::Error> {
42 let res1 = self.0.run(hugr).map_err(E::from_first)?;
43 let res2 = self.1.run(hugr).map_err(E::from_second)?;
44 Ok((res1, res2))
45 }
46 }
47
48 Sequence(self, other, PhantomData)
49 }
50}
51
/// Describes how errors of two types, `A` and `B`, are merged into the single
/// error type implementing this trait. It is the error type of passes built by
/// combining two other passes.
pub trait ErrorCombiner<A, B>: Error {
    /// Builds the combined error from an error of the first type, `A`.
    fn from_first(a: A) -> Self;
    /// Builds the combined error from an error of the second type, `B`.
    fn from_second(b: B) -> Self;
}
58
59impl<A: Error, B: Into<A>> ErrorCombiner<A, B> for A {
60 fn from_first(a: A) -> Self {
61 a
62 }
63
64 fn from_second(b: B) -> Self {
65 b.into()
66 }
67}
68
69impl<A: Error, B: Error> ErrorCombiner<A, B> for Either<A, B> {
70 fn from_first(a: A) -> Self {
71 Either::Left(a)
72 }
73
74 fn from_second(b: B) -> Self {
75 Either::Right(b)
76 }
77}
78
/// Helper returned by [`ComposablePass::map_err`]: holds a pass of type `P`
/// together with a function of type `F`. The `PhantomData` only records the
/// otherwise unused type parameters `E` and `H`.
struct ErrMapper<P, H, E, F>(P, F, PhantomData<(E, H)>);
90
91impl<H: HugrMut, P: ComposablePass<H>, E: Error, F: Fn(P::Error) -> E> ErrMapper<P, H, E, F> {
92 fn new(pass: P, err_fn: F) -> Self {
93 Self(pass, err_fn, PhantomData)
94 }
95}
96
97impl<P: ComposablePass<H>, H: HugrMut, E: Error, F: Fn(P::Error) -> E> ComposablePass<H>
98 for ErrMapper<P, H, E, F>
99{
100 type Error = E;
101 type Result = P::Result;
102
103 fn run(&self, hugr: &mut H) -> Result<P::Result, Self::Error> {
104 self.0.run(hugr).map_err(&self.1)
105 }
106}
107
/// Error type of a [`ValidatingPass`]: either the HUGR failed validation
/// (before or after the underlying pass ran), or the underlying pass itself
/// failed with an error of type `E`.
#[derive(thiserror::Error, Debug)]
pub enum ValidatePassError<N, E>
where
    N: HugrNode + 'static,
{
    /// Validation failed on the HUGR before the underlying pass was run.
    #[error("Failed to validate input HUGR: {err}\n{pretty_hugr}")]
    Input {
        /// The validation error that was reported.
        #[source]
        err: ValidationError<N>,
        /// Mermaid rendering of the HUGR that failed validation.
        pretty_hugr: String,
    },
    /// Validation failed on the HUGR after the underlying pass completed successfully.
    #[error("Failed to validate output HUGR: {err}\n{pretty_hugr}")]
    Output {
        /// The validation error that was reported.
        #[source]
        err: ValidationError<N>,
        /// Mermaid rendering of the HUGR that failed validation.
        pretty_hugr: String,
    },
    /// The underlying pass returned an error; it is forwarded transparently.
    #[error(transparent)]
    Underlying(#[from] E),
}
131
/// Wraps an underlying pass `P` so that the HUGR is validated both before and
/// after the pass runs. `H` is the HUGR type the pass operates on and is only
/// recorded through `PhantomData`.
pub struct ValidatingPass<P, H>(P, PhantomData<H>);
135
136impl<P: ComposablePass<H>, H: HugrMut> ValidatingPass<P, H> {
137 pub fn new(underlying: P) -> Self {
138 Self(underlying, PhantomData)
139 }
140
141 fn validation_impl<E>(
142 &self,
143 hugr: &H,
144 mk_err: impl FnOnce(ValidationError<H::Node>, String) -> ValidatePassError<H::Node, E>,
145 ) -> Result<(), ValidatePassError<H::Node, E>> {
146 hugr.validate()
147 .map_err(|err| mk_err(err, hugr.mermaid_string()))
148 }
149}
150
151impl<P: ComposablePass<H>, H: HugrMut> ComposablePass<H> for ValidatingPass<P, H>
152where
153 H::Node: 'static,
154{
155 type Error = ValidatePassError<H::Node, P::Error>;
156 type Result = P::Result;
157
158 fn run(&self, hugr: &mut H) -> Result<P::Result, Self::Error> {
159 self.validation_impl(hugr, |err, pretty_hugr| ValidatePassError::Input {
160 err,
161 pretty_hugr,
162 })?;
163 let res = self.0.run(hugr).map_err(ValidatePassError::Underlying)?;
164 self.validation_impl(hugr, |err, pretty_hugr| ValidatePassError::Output {
165 err,
166 pretty_hugr,
167 })?;
168 Ok(res)
169 }
170}
171
/// Pair of passes `A` and `B` where `B` is run only if `A` succeeds and
/// returns `true`. `E` is the combined error type and `H` the HUGR type; both
/// are only recorded through `PhantomData`.
pub struct IfThen<E, H, A, B>(A, B, PhantomData<(E, H)>);
177
178impl<
179 A: ComposablePass<H, Result = bool>,
180 B: ComposablePass<H>,
181 H: HugrMut,
182 E: ErrorCombiner<A::Error, B::Error>,
183> IfThen<E, H, A, B>
184{
185 pub fn new(fst: A, opt_snd: B) -> Self {
188 Self(fst, opt_snd, PhantomData)
189 }
190}
191
192impl<
193 A: ComposablePass<H, Result = bool>,
194 B: ComposablePass<H>,
195 H: HugrMut,
196 E: ErrorCombiner<A::Error, B::Error>,
197> ComposablePass<H> for IfThen<E, H, A, B>
198{
199 type Error = E;
200 type Result = Option<B::Result>;
201
202 fn run(&self, hugr: &mut H) -> Result<Self::Result, Self::Error> {
203 let res: bool = self.0.run(hugr).map_err(ErrorCombiner::from_first)?;
204 res.then(|| self.1.run(hugr).map_err(ErrorCombiner::from_second))
205 .transpose()
206 }
207}
208
209pub(crate) fn validate_if_test<P: ComposablePass<H>, H: HugrMut>(
210 pass: P,
211 hugr: &mut H,
212) -> Result<P::Result, ValidatePassError<H::Node, P::Error>> {
213 if cfg!(test) {
214 ValidatingPass::new(pass).run(hugr)
215 } else {
216 pass.run(hugr).map_err(ValidatePassError::Underlying)
217 }
218}
219
// Unit tests for the pass combinators: `then` / `map_err`, `ValidatingPass`
// and `IfThen` (the latter driven through `validate_if_test`).
#[cfg(test)]
mod test {
    use itertools::{Either, Itertools};
    use std::convert::Infallible;

    use hugr_core::builder::{
        Container, Dataflow, DataflowHugr, DataflowSubContainer, FunctionBuilder, HugrBuilder,
        ModuleBuilder,
    };
    use hugr_core::extension::prelude::{ConstUsize, MakeTuple, UnpackTuple, bool_t, usize_t};
    use hugr_core::hugr::hugrmut::HugrMut;
    use hugr_core::ops::{DFG, Input, OpType, Output, handle::NodeHandle};
    use hugr_core::std_extensions::arithmetic::int_types::INT_TYPES;
    use hugr_core::types::{Signature, TypeRow};
    use hugr_core::{Hugr, HugrView, IncomingPort};

    use crate::const_fold::{ConstFoldError, ConstantFoldPass};
    use crate::untuple::{UntupleRecursive, UntupleResult};
    use crate::{DeadCodeElimPass, ReplaceTypes, UntuplePass};

    use super::{ComposablePass, IfThen, ValidatePassError, ValidatingPass, validate_if_test};

    /// Sequencing two passes with `then`, with both ways of combining their errors.
    #[test]
    fn test_then() {
        // A module containing two identity functions over `usize`: `id1` and `id2`.
        let mut mb = ModuleBuilder::new();
        let id1 = mb
            .define_function("id1", Signature::new_endo(usize_t()))
            .unwrap();
        let inps = id1.input_wires();
        let id1 = id1.finish_with_outputs(inps).unwrap();
        let id2 = mb
            .define_function("id2", Signature::new_endo(usize_t()))
            .unwrap();
        let inps = id2.input_wires();
        let id2 = id2.finish_with_outputs(inps).unwrap();
        let hugr = mb.finish_hugr().unwrap();

        // `dce` has `id1` as its only entry point; `cfold` is given a constant input on `id2`.
        let dce = DeadCodeElimPass::default().with_entry_points([id1.node()]);
        let cfold =
            ConstantFoldPass::default().with_inputs(id2.node(), [(0, ConstUsize::new(2).into())]);

        // On its own (on a clone of the HUGR), constant folding succeeds.
        cfold.run(&mut hugr.clone()).unwrap();

        // Run after `dce`, `cfold` is expected to report `id2` as a missing entry point
        // (presumably because `dce` has removed `id2` by then -- confirm against
        // `DeadCodeElimPass`). With `Either` as the combined error type, the failure of
        // the second pass shows up as `Either::Right`.
        let exp_err = ConstFoldError::MissingEntryPoint { node: id2.node() };
        let r: Result<_, Either<Infallible, ConstFoldError>> =
            dce.clone().then(cfold.clone()).run(&mut hugr.clone());
        assert_eq!(r, Err(Either::Right(exp_err.clone())));

        // Mapping away `dce`'s `Infallible` error (the empty match) lets the combined
        // error be the `ConstFoldError` itself rather than an `Either`.
        let r = dce
            .clone()
            .map_err(|inf| match inf {})
            .then(cfold.clone())
            .run(&mut hugr.clone());
        assert_eq!(r, Err(exp_err));

        // In the opposite order (`cfold` first, then `dce`) the sequence succeeds.
        let r2: Result<_, Either<_, _>> = cfold.then(dce).run(&mut hugr.clone());
        r2.unwrap();
    }

    /// `ValidatingPass` rejects a HUGR that is already invalid on input.
    #[test]
    fn test_validation() {
        // A DFG declared as `usize -> bool` whose `usize` input is wired straight to
        // the `bool` output.
        let mut h = Hugr::new_with_entrypoint(DFG {
            signature: Signature::new(usize_t(), bool_t()),
        })
        .unwrap();
        let inp = h.add_node_with_parent(
            h.entrypoint(),
            Input {
                types: usize_t().into(),
            },
        );
        let outp = h.add_node_with_parent(
            h.entrypoint(),
            Output {
                types: bool_t().into(),
            },
        );
        h.connect(inp, 0, outp, 0);
        let backup = h.clone();
        // This HUGR fails validation; keep the error to compare against below.
        let err = backup.validate().unwrap_err();

        let no_inputs: [(IncomingPort, _); 0] = [];
        let cfold = ConstantFoldPass::default().with_inputs(backup.entrypoint(), no_inputs);
        // Without validation the pass runs without error and leaves the HUGR unchanged.
        cfold.run(&mut h).unwrap();
        assert_eq!(h, backup);
        // Wrapped in `ValidatingPass`, the same pass fails with an `Input` error
        // carrying the same validation error.
        let r = ValidatingPass::new(cfold).run(&mut h);
        assert!(matches!(r, Err(ValidatePassError::Input { err: e, .. }) if e == err));
    }

    /// `IfThen` runs its second pass only when the first one returns `true`.
    #[test]
    fn test_if_then() {
        let tr = TypeRow::from(vec![usize_t(); 2]);

        // A function over two `usize`s that packs its inputs into a tuple and then
        // immediately unpacks that tuple again.
        let h = {
            let sig = Signature::new_endo(tr.clone());
            let mut fb = FunctionBuilder::new("tupuntup", sig).unwrap();
            let [a, b] = fb.input_wires_arr();
            let tup = fb
                .add_dataflow_op(MakeTuple::new(tr.clone()), [a, b])
                .unwrap();
            let untup = fb
                .add_dataflow_op(UnpackTuple::new(tr.clone()), tup.outputs())
                .unwrap();
            fb.finish_hugr_with_outputs(untup.outputs()).unwrap()
        };

        let untup = UntuplePass::new(UntupleRecursive::Recursive);
        {
            // First pass: replace `usize` by `INT_TYPES[6]`. The assertions below expect
            // it to report a change, so that the untuple pass runs and applies one rewrite.
            let mut repl = ReplaceTypes::default();
            let usize_custom_t = usize_t().as_extension().unwrap().clone();
            repl.replace_type(usize_custom_t, INT_TYPES[6].clone());
            let ifthen = IfThen::<Either<_, _>, _, _, _>::new(repl, untup.clone());

            let mut h = h.clone();
            let r = validate_if_test(ifthen, &mut h).unwrap();
            assert_eq!(
                r,
                Some(UntupleResult {
                    rewrites_applied: 1
                })
            );
            // Exactly two nodes remain under the entrypoint, and both output
            // neighbours of the first are the second.
            let [tuple_in, tuple_out] = h.children(h.entrypoint()).collect_array().unwrap();
            assert_eq!(h.output_neighbours(tuple_in).collect_vec(), [tuple_out; 2]);
        }

        // First pass: replace `INT_TYPES[5]` by `INT_TYPES[6]`. The assertions below
        // expect `None`, i.e. the untuple pass is not run.
        let mut repl = ReplaceTypes::default();
        let i32_custom_t = INT_TYPES[5].as_extension().unwrap().clone();
        repl.replace_type(i32_custom_t, INT_TYPES[6].clone());
        let ifthen = IfThen::<Either<_, _>, _, _, _>::new(repl, untup);
        let mut h = h;
        let r = validate_if_test(ifthen, &mut h).unwrap();
        assert_eq!(r, None);
        // All four children are still there, and the first child still feeds the `MakeTuple`.
        assert_eq!(h.children(h.entrypoint()).count(), 4);
        let mktup = h
            .output_neighbours(h.first_child(h.entrypoint()).unwrap())
            .next()
            .unwrap();
        assert_eq!(h.get_optype(mktup), &OpType::from(MakeTuple::new(tr)));
    }
}