hugr_core/extension/prelude/
generic.rs1use std::str::FromStr;
2use std::sync::{Arc, Weak};
3
4use crate::extension::OpDef;
5use crate::extension::SignatureFunc;
6use crate::extension::prelude::usize_custom_t;
7use crate::extension::simple_op::{
8 HasConcrete, HasDef, MakeExtensionOp, MakeOpDef, MakeRegisteredOp, OpLoadError,
9};
10use crate::extension::{ConstFold, ExtensionId};
11use crate::ops::ExtensionOp;
12use crate::ops::OpName;
13use crate::type_row;
14use crate::types::FuncValueType;
15
16use crate::types::Type;
17
18use crate::extension::SignatureError;
19
20use crate::types::PolyFuncTypeRV;
21
22use crate::Extension;
23use crate::types::type_param::TypeArg;
24
25use super::PRELUDE;
26use super::{ConstUsize, PRELUDE_ID};
27use crate::types::type_param::TypeParam;
28
29pub static LOAD_NAT_OP_ID: OpName = OpName::new_inline("load_nat");
31
/// Definition of the `load_nat` operation.
///
/// This is a field-less marker type: it carries no type arguments itself and is
/// turned into a concrete [`LoadNat`] via `HasConcrete::instantiate`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct LoadNatDef;
35
36impl FromStr for LoadNatDef {
37 type Err = ();
38
39 fn from_str(s: &str) -> Result<Self, Self::Err> {
40 if s == Self.op_id() { Ok(Self) } else { Err(()) }
41 }
42}
43
44impl ConstFold for LoadNatDef {
45 fn fold(
46 &self,
47 type_args: &[TypeArg],
48 _consts: &[(crate::IncomingPort, crate::ops::Value)],
49 ) -> crate::extension::ConstFoldResult {
50 let [arg] = type_args else {
51 return None;
52 };
53 let nat = arg.as_nat();
54 if let Some(n) = nat {
55 let n_const = ConstUsize::new(n);
56 Some(vec![(0.into(), n_const.into())])
57 } else {
58 None
59 }
60 }
61}
62
63impl MakeOpDef for LoadNatDef {
64 fn opdef_id(&self) -> OpName {
65 LOAD_NAT_OP_ID.clone()
66 }
67
68 fn from_def(op_def: &OpDef) -> Result<Self, OpLoadError>
69 where
70 Self: Sized,
71 {
72 crate::extension::simple_op::try_from_name(op_def.name(), op_def.extension_id())
73 }
74
75 fn init_signature(&self, _extension_ref: &Weak<Extension>) -> SignatureFunc {
76 let usize_t: Type = usize_custom_t(_extension_ref).into();
77 let params = vec![TypeParam::max_nat_type()];
78 PolyFuncTypeRV::new(params, FuncValueType::new(type_row![], vec![usize_t])).into()
79 }
80
81 fn extension_ref(&self) -> Weak<Extension> {
82 Arc::downgrade(&PRELUDE)
83 }
84
85 fn extension(&self) -> ExtensionId {
86 PRELUDE_ID
87 }
88
89 fn description(&self) -> String {
90 "Loads a generic bounded nat parameter into a usize runtime value.".into()
91 }
92
93 fn post_opdef(&self, def: &mut OpDef) {
94 def.set_constant_folder(*self);
95 }
96}
97
98#[derive(Clone, Debug, PartialEq)]
100pub struct LoadNat {
101 nat: TypeArg,
102}
103
104impl LoadNat {
105 #[must_use]
107 pub fn new(nat: TypeArg) -> Self {
108 LoadNat { nat }
109 }
110
111 #[must_use]
113 pub fn get_nat(self) -> TypeArg {
114 self.nat
115 }
116}
117
118impl MakeExtensionOp for LoadNat {
119 fn op_id(&self) -> OpName {
120 LoadNatDef.opdef_id()
121 }
122
123 fn from_extension_op(ext_op: &ExtensionOp) -> Result<Self, OpLoadError>
124 where
125 Self: Sized,
126 {
127 let def = LoadNatDef::from_def(ext_op.def())?;
128 def.instantiate(ext_op.args())
129 }
130
131 fn type_args(&self) -> Vec<TypeArg> {
132 vec![self.nat.clone()]
133 }
134}
135
136impl MakeRegisteredOp for LoadNat {
137 fn extension_id(&self) -> ExtensionId {
138 PRELUDE_ID
139 }
140
141 fn extension_ref(&self) -> Weak<Extension> {
142 Arc::downgrade(&PRELUDE)
143 }
144}
145
146impl HasDef for LoadNat {
147 type Def = LoadNatDef;
148}
149
150impl HasConcrete for LoadNatDef {
151 type Concrete = LoadNat;
152
153 fn instantiate(&self, type_args: &[TypeArg]) -> Result<Self::Concrete, OpLoadError> {
154 match type_args {
155 [n] => Ok(LoadNat::new(n.clone())),
156 _ => Err(SignatureError::InvalidTypeArgs.into()),
157 }
158 }
159}
160
#[cfg(test)]
mod tests {
    use crate::{
        HugrView, OutgoingPort,
        builder::{DFGBuilder, Dataflow, DataflowHugr, inout_sig},
        extension::prelude::{ConstUsize, usize_t},
        ops::{OpType, constant},
        type_row,
        types::Term,
    };

    use super::LoadNat;

    /// Builds a dataflow graph containing a `LoadNat` op and checks that every
    /// child of the entrypoint other than the input/output nodes carries that
    /// op type.
    #[test]
    fn test_load_nat() {
        let signature = inout_sig(type_row![], vec![usize_t()]);
        let mut builder = DFGBuilder::new(signature).unwrap();

        let load_nat = LoadNat::new(Term::from(4u64));
        let handle = builder.add_dataflow_op(load_nat.clone(), []).unwrap();
        let hugr = builder.finish_hugr_with_outputs(handle.outputs()).unwrap();

        let expected: OpType = load_nat.into();

        hugr.children(hugr.entrypoint())
            .map(|node| hugr.get_optype(node))
            // Skip the input and output nodes of the dataflow graph.
            .filter(|optype| !(optype.is_input() || optype.is_output()))
            .for_each(|optype| assert_eq!(optype, &expected));
    }

    /// Constant-folding a `LoadNat` over the nat `5` must yield the `usize`
    /// constant `5` on output port 0.
    #[test]
    fn test_load_nat_fold() {
        let optype: OpType = LoadNat::new(Term::from(5u64)).into();

        let OpType::ExtensionOp(ext_op) = optype else {
            panic!()
        };

        let folded = ext_op.constant_fold(&[]);
        let expected: (OutgoingPort, constant::Value) = (0.into(), ConstUsize::new(5).into());
        assert_eq!(folded, Some(vec![expected]));
    }
}