1use std::{error::Error, marker::PhantomData};
4
5use hugr_core::core::HugrNode;
6use hugr_core::hugr::{ValidationError, hugrmut::HugrMut};
7use itertools::Either;
8
9pub trait ComposablePass<H: HugrMut>: Sized {
12 type Error: Error;
14 type Result; fn run(&self, hugr: &mut H) -> Result<Self::Result, Self::Error>;
19
20 fn map_err<E2: Error>(
23 self,
24 f: impl Fn(Self::Error) -> E2,
25 ) -> impl ComposablePass<H, Result = Self::Result, Error = E2> {
26 ErrMapper::new(self, f)
27 }
28
29 fn then<P: ComposablePass<H>, E: ErrorCombiner<Self::Error, P::Error>>(
32 self,
33 other: P,
34 ) -> impl ComposablePass<H, Result = (Self::Result, P::Result), Error = E> {
35 struct Sequence<E, P1, P2>(P1, P2, PhantomData<E>);
36 impl<H, E, P1, P2> ComposablePass<H> for Sequence<E, P1, P2>
37 where
38 H: HugrMut,
39 P1: ComposablePass<H>,
40 P2: ComposablePass<H>,
41 E: ErrorCombiner<P1::Error, P2::Error>,
42 {
43 type Error = E;
44 type Result = (P1::Result, P2::Result);
45
46 fn run(&self, hugr: &mut H) -> Result<Self::Result, Self::Error> {
47 let res1 = self.0.run(hugr).map_err(E::from_first)?;
48 let res2 = self.1.run(hugr).map_err(E::from_second)?;
49 Ok((res1, res2))
50 }
51 }
52
53 Sequence(self, other, PhantomData)
54 }
55}
56
/// Merges the error types of two passes that run one after the other into a
/// single error type.
pub trait ErrorCombiner<A, B>: Error {
    /// Lifts an error raised by the first pass into the combined error type.
    fn from_first(a: A) -> Self;

    /// Lifts an error raised by the second pass into the combined error type.
    fn from_second(b: B) -> Self;
}
65
66impl<A: Error, B: Into<A>> ErrorCombiner<A, B> for A {
67 fn from_first(a: A) -> Self {
68 a
69 }
70
71 fn from_second(b: B) -> Self {
72 b.into()
73 }
74}
75
76impl<A: Error, B: Error> ErrorCombiner<A, B> for Either<A, B> {
77 fn from_first(a: A) -> Self {
78 Either::Left(a)
79 }
80
81 fn from_second(b: B) -> Self {
82 Either::Right(b)
83 }
84}
85
86struct ErrMapper<P, H, E, F>(P, F, PhantomData<(E, H)>);
97
98impl<H: HugrMut, P: ComposablePass<H>, E: Error, F: Fn(P::Error) -> E> ErrMapper<P, H, E, F> {
99 fn new(pass: P, err_fn: F) -> Self {
100 Self(pass, err_fn, PhantomData)
101 }
102}
103
104impl<P: ComposablePass<H>, H: HugrMut, E: Error, F: Fn(P::Error) -> E> ComposablePass<H>
105 for ErrMapper<P, H, E, F>
106{
107 type Error = E;
108 type Result = P::Result;
109
110 fn run(&self, hugr: &mut H) -> Result<P::Result, Self::Error> {
111 self.0.run(hugr).map_err(&self.1)
112 }
113}
114
115#[derive(thiserror::Error, Debug)]
119pub enum ValidatePassError<N, E>
120where
121 N: HugrNode + 'static,
122{
123 #[error("Failed to validate input HUGR: {err}\n{pretty_hugr}")]
125 Input {
126 #[source]
128 err: Box<ValidationError<N>>,
129 pretty_hugr: String,
131 },
132 #[error("Failed to validate output HUGR: {err}\n{pretty_hugr}")]
134 Output {
135 #[source]
137 err: Box<ValidationError<N>>,
138 pretty_hugr: String,
140 },
141 #[error(transparent)]
143 Underlying(Box<E>),
144}
145
146impl<N: HugrNode, E> From<E> for ValidatePassError<N, E> {
147 fn from(err: E) -> Self {
148 Self::Underlying(Box::new(err))
149 }
150}
151
152pub struct ValidatingPass<P, H>(P, PhantomData<H>);
155
156impl<P: ComposablePass<H>, H: HugrMut> ValidatingPass<P, H> {
157 pub fn new(underlying: P) -> Self {
159 Self(underlying, PhantomData)
160 }
161
162 fn validation_impl<E>(
163 &self,
164 hugr: &H,
165 mk_err: impl FnOnce(ValidationError<H::Node>, String) -> ValidatePassError<H::Node, E>,
166 ) -> Result<(), ValidatePassError<H::Node, E>> {
167 hugr.validate()
168 .map_err(|err| mk_err(err, hugr.mermaid_string()))
169 }
170}
171
172impl<P: ComposablePass<H>, H: HugrMut> ComposablePass<H> for ValidatingPass<P, H>
173where
174 H::Node: 'static,
175{
176 type Error = ValidatePassError<H::Node, P::Error>;
177 type Result = P::Result;
178
179 fn run(&self, hugr: &mut H) -> Result<P::Result, Self::Error> {
180 self.validation_impl(hugr, |err, pretty_hugr| ValidatePassError::Input {
181 err: Box::new(err),
182 pretty_hugr,
183 })?;
184 let res = self.0.run(hugr)?;
185 self.validation_impl(hugr, |err, pretty_hugr| ValidatePassError::Output {
186 err: Box::new(err),
187 pretty_hugr,
188 })?;
189 Ok(res)
190 }
191}
192
193pub struct IfThen<E, H, A, B>(A, B, PhantomData<(E, H)>);
198
199impl<
200 A: ComposablePass<H, Result = bool>,
201 B: ComposablePass<H>,
202 H: HugrMut,
203 E: ErrorCombiner<A::Error, B::Error>,
204> IfThen<E, H, A, B>
205{
206 pub fn new(fst: A, opt_snd: B) -> Self {
209 Self(fst, opt_snd, PhantomData)
210 }
211}
212
213impl<
214 A: ComposablePass<H, Result = bool>,
215 B: ComposablePass<H>,
216 H: HugrMut,
217 E: ErrorCombiner<A::Error, B::Error>,
218> ComposablePass<H> for IfThen<E, H, A, B>
219{
220 type Error = E;
221 type Result = Option<B::Result>;
222
223 fn run(&self, hugr: &mut H) -> Result<Self::Result, Self::Error> {
224 let res: bool = self.0.run(hugr).map_err(ErrorCombiner::from_first)?;
225 res.then(|| self.1.run(hugr).map_err(ErrorCombiner::from_second))
226 .transpose()
227 }
228}
229
230pub(crate) fn validate_if_test<P: ComposablePass<H>, H: HugrMut>(
231 pass: P,
232 hugr: &mut H,
233) -> Result<P::Result, ValidatePassError<H::Node, P::Error>> {
234 if cfg!(test) {
235 ValidatingPass::new(pass).run(hugr)
236 } else {
237 Ok(pass.run(hugr)?)
238 }
239}
240
#[cfg(test)]
mod test {
    // Tests for sequencing, error mapping, validation wrapping and conditional
    // sequencing of passes.
    use hugr_core::ops::Value;
    use itertools::{Either, Itertools};

    use hugr_core::builder::{
        Dataflow, DataflowHugr, DataflowSubContainer, FunctionBuilder, HugrBuilder, ModuleBuilder,
    };
    use hugr_core::extension::prelude::{ConstUsize, MakeTuple, UnpackTuple, bool_t, usize_t};
    use hugr_core::hugr::hugrmut::HugrMut;
    use hugr_core::ops::{DFG, Input, OpType, Output, handle::NodeHandle};
    use hugr_core::std_extensions::arithmetic::int_types::INT_TYPES;
    use hugr_core::types::{Signature, TypeRow};
    use hugr_core::{Hugr, HugrView, IncomingPort, Node, NodeIndex};

    use crate::const_fold::{ConstFoldError, ConstantFoldPass};
    use crate::dead_code::DeadCodeElimError;
    use crate::untuple::{UntupleRecursive, UntupleResult};
    use crate::{DeadCodeElimPass, ReplaceTypes, UntuplePass};

    use super::{ComposablePass, IfThen, ValidatePassError, ValidatingPass, validate_if_test};

    #[test]
    fn test_then() {
        // A module with two identity functions on `usize`.
        let mut module = ModuleBuilder::new();
        let [id1, id2] = ["id1", "id2"].map(|name| {
            let func = module
                .define_function(name, Signature::new_endo(usize_t()))
                .unwrap();
            let inputs = func.input_wires();
            func.finish_with_outputs(inputs).unwrap()
        });
        let hugr = module.finish_hugr().unwrap();

        let two = Value::from(ConstUsize::new(2));
        let max_index = hugr.nodes().map(Node::index).max().unwrap();
        let missing = Node::from(portgraph::NodeIndex::new(max_index + 1));
        assert!(!hugr.contains_node(missing));

        // DCE rooted at a node that does not exist always fails.
        let dce = DeadCodeElimPass::default().with_entry_points([missing]);
        let cfold = ConstantFoldPass::default().with_inputs(id2.node(), [(0, two.clone())]);

        // Constant folding on its own succeeds.
        cfold.run(&mut hugr.clone()).unwrap();

        let expected_dce_err = DeadCodeElimError::NodeNotFound(missing);

        // A failing first pass is reported as `Left`.
        let dce_first: Result<_, Either<DeadCodeElimError, ConstFoldError>> =
            dce.clone().then(cfold.clone()).run(&mut hugr.clone());
        assert_eq!(dce_first, Err(Either::Left(expected_dce_err.clone())));

        // A failing second pass is reported as `Right`.
        let dce_second: Result<_, Either<_, _>> = cfold
            .clone()
            .with_inputs(id1.node(), [(0, two)])
            .then(dce.clone())
            .run(&mut hugr.clone());
        assert_eq!(dce_second, Err(Either::Right(expected_dce_err)));

        // Mapping DCE's error into `ConstFoldError` lets both passes share one error type.
        let mapped = dce
            .map_err(|e| match e {
                DeadCodeElimError::NodeNotFound(node) => {
                    ConstFoldError::MissingEntryPoint { node }
                }
            })
            .then(cfold.clone())
            .run(&mut hugr.clone());
        assert_eq!(
            mapped,
            Err(ConstFoldError::MissingEntryPoint { node: missing })
        );

        // The underlying error converts into `ValidatePassError`, so the two can be sequenced.
        let validated = ValidatingPass::new(cfold.clone());
        let outcome: Result<_, ValidatePassError<Node, ConstFoldError>> =
            validated.then(cfold).run(&mut hugr.clone());
        outcome.unwrap();
    }

    #[test]
    fn test_validation() {
        // A DFG with signature usize -> bool whose usize input is wired straight
        // to the bool output; `validate` rejects it below.
        let mut h = Hugr::new_with_entrypoint(DFG {
            signature: Signature::new(usize_t(), bool_t()),
        })
        .unwrap();
        let dfg = h.entrypoint();
        let inp = h.add_node_with_parent(dfg, Input { types: usize_t().into() });
        let outp = h.add_node_with_parent(dfg, Output { types: bool_t().into() });
        h.connect(inp, 0, outp, 0);

        let backup = h.clone();
        let expected = backup.validate().unwrap_err();

        let no_inputs: [(IncomingPort, _); 0] = [];
        let cfold = ConstantFoldPass::default().with_inputs(backup.entrypoint(), no_inputs);

        // Without validation, the pass succeeds and leaves the HUGR unchanged.
        cfold.run(&mut h).unwrap();
        assert_eq!(h, backup);

        // With validation, the invalid input is reported as an `Input` error.
        let r = ValidatingPass::new(cfold).run(&mut h);
        assert!(matches!(r, Err(ValidatePassError::Input { err: e, .. }) if *e == expected));
    }

    #[test]
    fn test_if_then() {
        let pair = TypeRow::from(vec![usize_t(); 2]);

        // `tupuntup`: packs both inputs into a tuple and immediately unpacks it.
        let h = {
            let mut fb =
                FunctionBuilder::new("tupuntup", Signature::new_endo(pair.clone())).unwrap();
            let [a, b] = fb.input_wires_arr();
            let tuple = fb
                .add_dataflow_op(MakeTuple::new(pair.clone()), [a, b])
                .unwrap();
            let untuple = fb
                .add_dataflow_op(UnpackTuple::new(pair.clone()), tuple.outputs())
                .unwrap();
            fb.finish_hugr_with_outputs(untuple.outputs()).unwrap()
        };

        let untup = UntuplePass::new(UntupleRecursive::Recursive);

        // Replacing `usize` (present in the HUGR) returns `true`, so untupling runs.
        {
            let mut repl = ReplaceTypes::default();
            let usize_custom_t = usize_t().as_extension().unwrap().clone();
            repl.replace_type(usize_custom_t, INT_TYPES[6].clone());
            let ifthen = IfThen::<Either<_, _>, _, _, _>::new(repl, untup.clone());

            let mut h = h.clone();
            let r = validate_if_test(ifthen, &mut h).unwrap();
            assert_eq!(r, Some(UntupleResult { rewrites_applied: 1 }));
            let [tuple_in, tuple_out] = h.children(h.entrypoint()).collect_array().unwrap();
            assert_eq!(h.output_neighbours(tuple_in).collect_vec(), [tuple_out; 2]);
        }

        // Replacing `INT_TYPES[5]` returns `false`, so untupling is skipped (`None`).
        let mut repl = ReplaceTypes::default();
        let i32_custom_t = INT_TYPES[5].as_extension().unwrap().clone();
        repl.replace_type(i32_custom_t, INT_TYPES[6].clone());
        let ifthen = IfThen::<Either<_, _>, _, _, _>::new(repl, untup);

        let mut h = h;
        let r = validate_if_test(ifthen, &mut h).unwrap();
        assert_eq!(r, None);
        assert_eq!(h.children(h.entrypoint()).count(), 4);
        let mktup = h
            .output_neighbours(h.first_child(h.entrypoint()).unwrap())
            .next()
            .unwrap();
        assert_eq!(h.get_optype(mktup), &OpType::from(MakeTuple::new(pair)));
    }
}