hugr_core/extension/prelude/
generic.rs1use std::str::FromStr;
2use std::sync::{Arc, Weak};
3
4use crate::extension::prelude::usize_custom_t;
5use crate::extension::simple_op::{
6 HasConcrete, HasDef, MakeExtensionOp, MakeOpDef, MakeRegisteredOp, OpLoadError,
7};
8use crate::extension::OpDef;
9use crate::extension::SignatureFunc;
10use crate::extension::{ConstFold, ExtensionId};
11use crate::ops::ExtensionOp;
12use crate::ops::NamedOp;
13use crate::ops::OpName;
14use crate::type_row;
15use crate::types::FuncValueType;
16
17use crate::types::Type;
18
19use crate::extension::SignatureError;
20
21use crate::types::PolyFuncTypeRV;
22
23use crate::types::type_param::TypeArg;
24use crate::Extension;
25
26use super::PRELUDE;
27use super::{ConstUsize, PRELUDE_ID};
28use crate::types::type_param::TypeParam;
29
30pub const LOAD_NAT_OP_ID: OpName = OpName::new_inline("load_nat");
32
/// Definition of the `load_nat` operation.
///
/// The operation takes a single nat type parameter, has no dataflow inputs,
/// and produces one `usize` output (see its `MakeOpDef::init_signature`).
#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq)]
pub struct LoadNatDef;
36
37impl NamedOp for LoadNatDef {
38 fn name(&self) -> OpName {
39 LOAD_NAT_OP_ID
40 }
41}
42
43impl FromStr for LoadNatDef {
44 type Err = ();
45
46 fn from_str(s: &str) -> Result<Self, Self::Err> {
47 if s == LoadNatDef.name() {
48 Ok(Self)
49 } else {
50 Err(())
51 }
52 }
53}
54
55impl ConstFold for LoadNatDef {
56 fn fold(
57 &self,
58 type_args: &[TypeArg],
59 _consts: &[(crate::IncomingPort, crate::ops::Value)],
60 ) -> crate::extension::ConstFoldResult {
61 let [arg] = type_args else {
62 return None;
63 };
64 let nat = arg.as_nat();
65 if let Some(n) = nat {
66 let n_const = ConstUsize::new(n);
67 Some(vec![(0.into(), n_const.into())])
68 } else {
69 None
70 }
71 }
72}
73
74impl MakeOpDef for LoadNatDef {
75 fn from_def(op_def: &OpDef) -> Result<Self, OpLoadError>
76 where
77 Self: Sized,
78 {
79 crate::extension::simple_op::try_from_name(op_def.name(), op_def.extension_id())
80 }
81
82 fn init_signature(&self, _extension_ref: &Weak<Extension>) -> SignatureFunc {
83 let usize_t: Type = usize_custom_t(_extension_ref).into();
84 let params = vec![TypeParam::max_nat()];
85 PolyFuncTypeRV::new(params, FuncValueType::new(type_row![], vec![usize_t])).into()
86 }
87
88 fn extension_ref(&self) -> Weak<Extension> {
89 Arc::downgrade(&PRELUDE)
90 }
91
92 fn extension(&self) -> ExtensionId {
93 PRELUDE_ID
94 }
95
96 fn description(&self) -> String {
97 "Loads a generic bounded nat parameter into a usize runtime value.".into()
98 }
99
100 fn post_opdef(&self, def: &mut OpDef) {
101 def.set_constant_folder(*self);
102 }
103}
104
105#[derive(Clone, Debug, PartialEq)]
107pub struct LoadNat {
108 nat: TypeArg,
109}
110
111impl LoadNat {
112 pub fn new(nat: TypeArg) -> Self {
114 LoadNat { nat }
115 }
116
117 pub fn get_nat(self) -> TypeArg {
119 self.nat
120 }
121}
122
123impl NamedOp for LoadNat {
124 fn name(&self) -> OpName {
125 LOAD_NAT_OP_ID
126 }
127}
128
129impl MakeExtensionOp for LoadNat {
130 fn from_extension_op(ext_op: &ExtensionOp) -> Result<Self, OpLoadError>
131 where
132 Self: Sized,
133 {
134 let def = LoadNatDef::from_def(ext_op.def())?;
135 def.instantiate(ext_op.args())
136 }
137
138 fn type_args(&self) -> Vec<TypeArg> {
139 vec![self.nat.clone()]
140 }
141}
142
143impl MakeRegisteredOp for LoadNat {
144 fn extension_id(&self) -> ExtensionId {
145 PRELUDE_ID
146 }
147
148 fn extension_ref(&self) -> Weak<Extension> {
149 Arc::downgrade(&PRELUDE)
150 }
151}
152
153impl HasDef for LoadNat {
154 type Def = LoadNatDef;
155}
156
157impl HasConcrete for LoadNatDef {
158 type Concrete = LoadNat;
159
160 fn instantiate(&self, type_args: &[TypeArg]) -> Result<Self::Concrete, OpLoadError> {
161 match type_args {
162 [n] => Ok(LoadNat::new(n.clone())),
163 _ => Err(SignatureError::InvalidTypeArgs.into()),
164 }
165 }
166}
167
#[cfg(test)]
mod tests {
    use crate::{
        builder::{inout_sig, DFGBuilder, Dataflow, DataflowHugr},
        extension::prelude::{usize_t, ConstUsize},
        extension::simple_op::HasConcrete,
        ops::{constant, OpType},
        type_row,
        types::TypeArg,
        HugrView, OutgoingPort,
    };

    use super::{LoadNat, LoadNatDef};

    /// A DFG built around a single `LoadNat` has exactly one non-I/O child,
    /// and that child's optype is the `LoadNat` op that was added.
    #[test]
    fn test_load_nat() {
        let mut b = DFGBuilder::new(inout_sig(type_row![], vec![usize_t()])).unwrap();

        let arg = TypeArg::BoundedNat { n: 4 };
        let op = LoadNat::new(arg);

        let out = b.add_dataflow_op(op.clone(), []).unwrap();

        let result = b.finish_hugr_with_outputs(out.outputs()).unwrap();

        let exp_optype: OpType = op.into();

        // Collect instead of asserting inside a loop: a per-node loop would
        // pass vacuously if the op were missing from the hugr.
        let non_io_ops: Vec<&OpType> = result
            .children(result.root())
            .map(|child| result.get_optype(child))
            .filter(|optype| !optype.is_input() && !optype.is_output())
            .collect();
        assert_eq!(non_io_ops, vec![&exp_optype]);
    }

    /// Constant folding `LoadNat` with argument 5 yields `ConstUsize(5)` on port 0.
    #[test]
    fn test_load_nat_fold() {
        let arg = TypeArg::BoundedNat { n: 5 };
        let op = LoadNat::new(arg);

        let optype: OpType = op.into();

        let OpType::ExtensionOp(ext_op) = optype else {
            panic!("LoadNat should convert into OpType::ExtensionOp");
        };
        let result = ext_op.constant_fold(&[]);
        let exp_port: OutgoingPort = 0.into();
        let exp_val: constant::Value = ConstUsize::new(5).into();
        assert_eq!(result, Some(vec![(exp_port, exp_val)]));
    }

    /// Instantiation requires exactly one type argument.
    #[test]
    fn test_load_nat_instantiate_arity() {
        assert!(LoadNatDef.instantiate(&[]).is_err());
        let two = [TypeArg::BoundedNat { n: 1 }, TypeArg::BoundedNat { n: 2 }];
        assert!(LoadNatDef.instantiate(&two).is_err());
        let one = [TypeArg::BoundedNat { n: 3 }];
        assert_eq!(
            LoadNatDef.instantiate(&one).ok(),
            Some(LoadNat::new(TypeArg::BoundedNat { n: 3 }))
        );
    }
}