1use hugr_core::std_extensions::collections::array::{new_array_op, ArrayOpDef};
2use hugr_core::{
3 builder::{BuildError, Dataflow},
4 extension::simple_op::HasConcrete as _,
5 types::Type,
6 Wire,
7};
8use itertools::Itertools as _;
9
10pub trait ArrayOpBuilder: Dataflow {
11 fn add_new_array(
12 &mut self,
13 elem_ty: Type,
14 values: impl IntoIterator<Item = Wire>,
15 ) -> Result<Wire, BuildError> {
16 let inputs = values.into_iter().collect_vec();
17 let [out] = self
18 .add_dataflow_op(new_array_op(elem_ty, inputs.len() as u64), inputs)?
19 .outputs_arr();
20 Ok(out)
21 }
22
23 fn add_array_get(
24 &mut self,
25 elem_ty: Type,
26 size: u64,
27 input: Wire,
28 index: Wire,
29 ) -> Result<Wire, BuildError> {
30 let op = ArrayOpDef::get
32 .instantiate(&[size.into(), elem_ty.into()])
33 .unwrap();
34 let [out] = self.add_dataflow_op(op, vec![input, index])?.outputs_arr();
35 Ok(out)
36 }
37
38 fn add_array_set(
39 &mut self,
40 elem_ty: Type,
41 size: u64,
42 input: Wire,
43 index: Wire,
44 value: Wire,
45 ) -> Result<Wire, BuildError> {
46 let op = ArrayOpDef::set
48 .instantiate(&[size.into(), elem_ty.into()])
49 .unwrap();
50 let [out] = self
51 .add_dataflow_op(op, vec![input, index, value])?
52 .outputs_arr();
53 Ok(out)
54 }
55
56 fn add_array_swap(
57 &mut self,
58 elem_ty: Type,
59 size: u64,
60 input: Wire,
61 index1: Wire,
62 index2: Wire,
63 ) -> Result<Wire, BuildError> {
64 let op = ArrayOpDef::swap
66 .instantiate(&[size.into(), elem_ty.into()])
67 .unwrap();
68 let [out] = self
69 .add_dataflow_op(op, vec![input, index1, index2])?
70 .outputs_arr();
71 Ok(out)
72 }
73
74 fn add_array_pop_left(
75 &mut self,
76 elem_ty: Type,
77 size: u64,
78 input: Wire,
79 ) -> Result<Wire, BuildError> {
80 let op = ArrayOpDef::pop_left
82 .instantiate(&[size.into(), elem_ty.into()])
83 .unwrap();
84 Ok(self.add_dataflow_op(op, vec![input])?.out_wire(0))
85 }
86
87 fn add_array_pop_right(
88 &mut self,
89 elem_ty: Type,
90 size: u64,
91 input: Wire,
92 ) -> Result<Wire, BuildError> {
93 let op = ArrayOpDef::pop_right
95 .instantiate(&[size.into(), elem_ty.into()])
96 .unwrap();
97 Ok(self.add_dataflow_op(op, vec![input])?.out_wire(0))
98 }
99
100 fn add_array_discard_empty(&mut self, elem_ty: Type, input: Wire) -> Result<(), BuildError> {
101 self.add_dataflow_op(
103 ArrayOpDef::discard_empty
104 .instantiate(&[elem_ty.into()])
105 .unwrap(),
106 [input],
107 )?;
108 Ok(())
109 }
110}
111
112impl<D: Dataflow> ArrayOpBuilder for D {}
113
#[cfg(test)]
pub mod test {
    use hugr_core::extension::prelude::PRELUDE_ID;
    use hugr_core::extension::ExtensionSet;
    use hugr_core::std_extensions::collections::array::{self, array_type};
    use hugr_core::{
        builder::{DFGBuilder, HugrBuilder},
        extension::prelude::{either_type, option_type, usize_t, ConstUsize, UnwrapBuilder as _},
        types::Signature,
        Hugr,
    };
    use rstest::rstest;

    use super::*;

    /// Fixture that drives every [`ArrayOpBuilder`] method through `builder`.
    ///
    /// It starts from the two-element `usize` array `[1, 2]`, then swaps, gets,
    /// sets, pops from the left and from the right, and finally discards the
    /// resulting empty array. Every sum-typed result is unwrapped at tag 1.
    ///
    /// The default `builder` is a `DFGBuilder` for an empty-to-empty signature
    /// whose extension delta lists the prelude and array extensions.
    #[rstest::fixture]
    #[default(DFGBuilder<Hugr>)]
    pub fn all_array_ops<B: Dataflow>(
        #[default(DFGBuilder::new(Signature::new_endo(Type::EMPTY_TYPEROW)
            .with_extension_delta(ExtensionSet::from_iter([
                PRELUDE_ID,
                array::EXTENSION_ID
            ]))).unwrap())]
        mut builder: B,
    ) -> B {
        let zero = builder.add_load_value(ConstUsize::new(0));
        let one = builder.add_load_value(ConstUsize::new(1));
        let two = builder.add_load_value(ConstUsize::new(2));
        let pair_ty = array_type(2, usize_t());

        let pair = builder.add_new_array(usize_t(), [one, two]).unwrap();

        // swap: either<array, array>
        let swapped = builder.add_array_swap(usize_t(), 2, pair, zero, one).unwrap();
        let swap_ty = either_type(pair_ty.clone(), pair_ty.clone());
        let [pair] = builder.build_unwrap_sum(1, swap_ty, swapped).unwrap();

        // get: option<usize>
        let fetched = builder.add_array_get(usize_t(), 2, pair, zero).unwrap();
        let [head] = builder
            .build_unwrap_sum(1, option_type(usize_t()), fetched)
            .unwrap();

        // set: either<(usize, array), (usize, array)>
        let written = builder.add_array_set(usize_t(), 2, pair, one, head).unwrap();
        let set_row = vec![usize_t(), pair_ty];
        let set_ty = either_type(set_row.clone(), set_row);
        let [_replaced, pair] = builder.build_unwrap_sum(1, set_ty, written).unwrap();

        // pop_left: option<(usize, array[1])>
        let popped_left = builder.add_array_pop_left(usize_t(), 2, pair).unwrap();
        let left_ty = option_type(vec![usize_t(), array_type(1, usize_t())]);
        let [_left, single] = builder.build_unwrap_sum(1, left_ty, popped_left).unwrap();

        // pop_right: option<(usize, array[0])>
        let popped_right = builder.add_array_pop_right(usize_t(), 1, single).unwrap();
        let right_ty = option_type(vec![usize_t(), array_type(0, usize_t())]);
        let [_right, empty] = builder.build_unwrap_sum(1, right_ty, popped_right).unwrap();

        builder.add_array_discard_empty(usize_t(), empty).unwrap();
        builder
    }

    /// The fixture's dataflow graph must finish into a valid HUGR.
    #[rstest]
    fn build_all_ops(all_array_ops: DFGBuilder<Hugr>) {
        let finished = all_array_ops.finish_hugr();
        finished.unwrap();
    }
}