hugr_llvm/utils/array_op_builder.rs

1use hugr_core::std_extensions::collections::array::{new_array_op, ArrayOpDef};
2use hugr_core::{
3    builder::{BuildError, Dataflow},
4    extension::simple_op::HasConcrete as _,
5    types::Type,
6    Wire,
7};
8use itertools::Itertools as _;
9
10pub trait ArrayOpBuilder: Dataflow {
11    fn add_new_array(
12        &mut self,
13        elem_ty: Type,
14        values: impl IntoIterator<Item = Wire>,
15    ) -> Result<Wire, BuildError> {
16        let inputs = values.into_iter().collect_vec();
17        let [out] = self
18            .add_dataflow_op(new_array_op(elem_ty, inputs.len() as u64), inputs)?
19            .outputs_arr();
20        Ok(out)
21    }
22
23    fn add_array_get(
24        &mut self,
25        elem_ty: Type,
26        size: u64,
27        input: Wire,
28        index: Wire,
29    ) -> Result<Wire, BuildError> {
30        // TODO Add an OpLoadError variant to BuildError.
31        let op = ArrayOpDef::get
32            .instantiate(&[size.into(), elem_ty.into()])
33            .unwrap();
34        let [out] = self.add_dataflow_op(op, vec![input, index])?.outputs_arr();
35        Ok(out)
36    }
37
38    fn add_array_set(
39        &mut self,
40        elem_ty: Type,
41        size: u64,
42        input: Wire,
43        index: Wire,
44        value: Wire,
45    ) -> Result<Wire, BuildError> {
46        // TODO Add an OpLoadError variant to BuildError
47        let op = ArrayOpDef::set
48            .instantiate(&[size.into(), elem_ty.into()])
49            .unwrap();
50        let [out] = self
51            .add_dataflow_op(op, vec![input, index, value])?
52            .outputs_arr();
53        Ok(out)
54    }
55
56    fn add_array_swap(
57        &mut self,
58        elem_ty: Type,
59        size: u64,
60        input: Wire,
61        index1: Wire,
62        index2: Wire,
63    ) -> Result<Wire, BuildError> {
64        // TODO Add an OpLoadError variant to BuildError
65        let op = ArrayOpDef::swap
66            .instantiate(&[size.into(), elem_ty.into()])
67            .unwrap();
68        let [out] = self
69            .add_dataflow_op(op, vec![input, index1, index2])?
70            .outputs_arr();
71        Ok(out)
72    }
73
74    fn add_array_pop_left(
75        &mut self,
76        elem_ty: Type,
77        size: u64,
78        input: Wire,
79    ) -> Result<Wire, BuildError> {
80        // TODO Add an OpLoadError variant to BuildError
81        let op = ArrayOpDef::pop_left
82            .instantiate(&[size.into(), elem_ty.into()])
83            .unwrap();
84        Ok(self.add_dataflow_op(op, vec![input])?.out_wire(0))
85    }
86
87    fn add_array_pop_right(
88        &mut self,
89        elem_ty: Type,
90        size: u64,
91        input: Wire,
92    ) -> Result<Wire, BuildError> {
93        // TODO Add an OpLoadError variant to BuildError
94        let op = ArrayOpDef::pop_right
95            .instantiate(&[size.into(), elem_ty.into()])
96            .unwrap();
97        Ok(self.add_dataflow_op(op, vec![input])?.out_wire(0))
98    }
99
100    fn add_array_discard_empty(&mut self, elem_ty: Type, input: Wire) -> Result<(), BuildError> {
101        // TODO Add an OpLoadError variant to BuildError
102        self.add_dataflow_op(
103            ArrayOpDef::discard_empty
104                .instantiate(&[elem_ty.into()])
105                .unwrap(),
106            [input],
107        )?;
108        Ok(())
109    }
110}
111
112impl<D: Dataflow> ArrayOpBuilder for D {}
113
#[cfg(test)]
pub mod test {
    use hugr_core::extension::prelude::PRELUDE_ID;
    use hugr_core::extension::ExtensionSet;
    use hugr_core::std_extensions::collections::array::{self, array_type};
    use hugr_core::{
        builder::{DFGBuilder, HugrBuilder},
        extension::prelude::{either_type, option_type, usize_t, ConstUsize, UnwrapBuilder as _},
        types::Signature,
        Hugr,
    };
    use rstest::rstest;

    use super::*;

    /// Fixture: adds one of each [`ArrayOpBuilder`] operation to `builder` and
    /// hands the builder back.
    ///
    /// By default `builder` is a [`DFGBuilder`] with an empty endo-signature whose
    /// extension delta lists the prelude and the `array` extension. Every
    /// sum-typed op result is unwrapped at tag 1 with `build_unwrap_sum`.
    #[rstest::fixture]
    #[default(DFGBuilder<Hugr>)]
    pub fn all_array_ops<B: Dataflow>(
        #[default(DFGBuilder::new(Signature::new_endo(Type::EMPTY_TYPEROW)
            .with_extension_delta(ExtensionSet::from_iter([
                PRELUDE_ID,
                array::EXTENSION_ID
        ]))).unwrap())]
        mut builder: B,
    ) -> B {
        // Type of the two-element arrays flowing through swap/get/set.
        let pair_ty = array_type(2, usize_t());

        // Constants used both as indices and as element values.
        let zero = builder.add_load_value(ConstUsize::new(0));
        let one = builder.add_load_value(ConstUsize::new(1));
        let two = builder.add_load_value(ConstUsize::new(2));

        // The array [1, 2].
        let fresh = builder.add_new_array(usize_t(), [one, two]).unwrap();

        // swap yields a sum whose variants both carry the array.
        let swap_out = builder
            .add_array_swap(usize_t(), 2, fresh, zero, one)
            .unwrap();
        let [swapped] = builder
            .build_unwrap_sum(1, either_type(pair_ty.clone(), pair_ty.clone()), swap_out)
            .unwrap();

        // get yields an optional element.
        let get_out = builder
            .add_array_get(usize_t(), 2, swapped, zero)
            .unwrap();
        let [first] = builder
            .build_unwrap_sum(1, option_type(usize_t()), get_out)
            .unwrap();

        // set yields a sum whose variants both carry (element, array).
        let set_out = builder
            .add_array_set(usize_t(), 2, swapped, one, first)
            .unwrap();
        let set_row = vec![usize_t(), pair_ty];
        let [_replaced, updated] = builder
            .build_unwrap_sum(1, either_type(set_row.clone(), set_row), set_out)
            .unwrap();

        // pop_left shrinks the array from 2 elements to 1.
        let left_out = builder
            .add_array_pop_left(usize_t(), 2, updated)
            .unwrap();
        let left_row = vec![usize_t(), array_type(1, usize_t())];
        let [_head, rest] = builder
            .build_unwrap_sum(1, option_type(left_row), left_out)
            .unwrap();

        // pop_right shrinks the array from 1 element to 0.
        let right_out = builder
            .add_array_pop_right(usize_t(), 1, rest)
            .unwrap();
        let right_row = vec![usize_t(), array_type(0, usize_t())];
        let [_last, empty] = builder
            .build_unwrap_sum(1, option_type(right_row), right_out)
            .unwrap();

        builder.add_array_discard_empty(usize_t(), empty).unwrap();
        builder
    }

    /// The fixture's graph must finish into a HUGR without error.
    #[rstest]
    fn build_all_ops(all_array_ops: DFGBuilder<Hugr>) {
        let finished = all_array_ops.finish_hugr();
        finished.unwrap();
    }
}