hugr_core/std_extensions/
ptr.rs1use std::sync::{Arc, Weak};
4
5use strum::{EnumIter, EnumString, IntoStaticStr};
6
7use crate::Wire;
8use crate::builder::{BuildError, Dataflow};
9use crate::extension::TypeDefBound;
10use crate::ops::OpName;
11use crate::types::{CustomType, PolyFuncType, Signature, Type, TypeBound, TypeName};
12use crate::{
13 Extension,
14 extension::{
15 ExtensionId, OpDef, SignatureError, SignatureFunc,
16 simple_op::{
17 HasConcrete, HasDef, MakeExtensionOp, MakeOpDef, MakeRegisteredOp, OpLoadError,
18 },
19 },
20 ops::custom::ExtensionOp,
21 type_row,
22 types::type_param::{TypeArg, TypeParam},
23};
24use lazy_static::lazy_static;
/// Definitions of the operations provided by the pointer extension.
///
/// The operation name of each definition is derived from the variant name via the
/// `IntoStaticStr` derive (see `MakeOpDef::opdef_id`), so renaming a variant renames
/// the operation.
#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, EnumIter, IntoStaticStr, EnumString)]
// NOTE(review): presumably required because the `EnumIter` derive emits an
// undocumented public iterator type — confirm before removing.
#[allow(missing_docs)]
#[non_exhaustive]
pub enum PtrOpDef {
    /// Create a new pointer from a value.
    New,
    /// Read a value from a pointer.
    Read,
    /// Write a value to a pointer, overwriting the existing value.
    Write,
}
37
38impl PtrOpDef {
39 #[must_use]
41 pub fn with_type(self, ty: Type) -> PtrOp {
42 PtrOp::new(self, ty)
43 }
44}
45
46impl MakeOpDef for PtrOpDef {
47 fn opdef_id(&self) -> OpName {
48 <&'static str>::from(self).into()
49 }
50
51 fn from_def(op_def: &OpDef) -> Result<Self, OpLoadError>
52 where
53 Self: Sized,
54 {
55 crate::extension::simple_op::try_from_name(op_def.name(), op_def.extension_id())
56 }
57
58 fn init_signature(&self, extension_ref: &Weak<Extension>) -> SignatureFunc {
59 let ptr_t: Type =
60 ptr_custom_type(Type::new_var_use(0, TypeBound::Copyable), extension_ref).into();
61 let inner_t = Type::new_var_use(0, TypeBound::Copyable);
62 let body = match self {
63 PtrOpDef::New => Signature::new(inner_t, ptr_t),
64 PtrOpDef::Read => Signature::new(ptr_t, inner_t),
65 PtrOpDef::Write => Signature::new(vec![ptr_t, inner_t], type_row![]),
66 };
67
68 PolyFuncType::new(TYPE_PARAMS, body).into()
69 }
70
71 fn extension(&self) -> ExtensionId {
72 EXTENSION_ID
73 }
74
75 fn extension_ref(&self) -> Weak<Extension> {
76 Arc::downgrade(&EXTENSION)
77 }
78
79 fn description(&self) -> String {
80 match self {
81 PtrOpDef::New => "Create a new pointer from a value.".into(),
82 PtrOpDef::Read => "Read a value from a pointer.".into(),
83 PtrOpDef::Write => "Write a value to a pointer, overwriting existing value.".into(),
84 }
85 }
86}
87
/// Identifier of the pointer extension.
pub const EXTENSION_ID: ExtensionId = ExtensionId::new_unchecked("ptr");
/// Name of the pointer type defined by this extension.
pub const PTR_TYPE_ID: TypeName = TypeName::new_inline("ptr");
/// Type parameters shared by the pointer type and every pointer operation:
/// a single type argument bounded by `Copyable`.
const TYPE_PARAMS: [TypeParam; 1] = [TypeParam::Type {
    b: TypeBound::Copyable,
}];
/// Version of the pointer extension.
pub const VERSION: semver::Version = semver::Version::new(0, 1, 0);
97
98fn extension() -> Arc<Extension> {
100 Extension::new_arc(EXTENSION_ID, VERSION, |extension, extension_ref| {
101 extension
102 .add_type(
103 PTR_TYPE_ID,
104 TYPE_PARAMS.into(),
105 "Standard extension pointer type.".into(),
106 TypeDefBound::copyable(),
107 extension_ref,
108 )
109 .unwrap();
110 PtrOpDef::load_all_ops(extension, extension_ref).unwrap();
111 })
112}
113
lazy_static! {
    /// The pointer extension, constructed on first access.
    pub static ref EXTENSION: Arc<Extension> = extension();
}
118
119fn ptr_custom_type(ty: impl Into<Type>, extension_ref: &Weak<Extension>) -> CustomType {
123 let ty = ty.into();
124 CustomType::new(
125 PTR_TYPE_ID,
126 [ty.into()],
127 EXTENSION_ID,
128 TypeBound::Copyable,
129 extension_ref,
130 )
131}
132
133pub fn ptr_type(ty: impl Into<Type>) -> Type {
135 ptr_custom_type(ty, &Arc::<Extension>::downgrade(&EXTENSION)).into()
136}
137
/// A concrete pointer operation: a [`PtrOpDef`] instantiated at a specific type.
#[derive(Clone, Debug, PartialEq)]
pub struct PtrOp {
    /// The operation definition being instantiated.
    pub def: PtrOpDef,
    /// The pointee type, i.e. the type of the value stored behind the pointer.
    pub ty: Type,
}
146
147impl PtrOp {
148 fn new(op: PtrOpDef, ty: Type) -> Self {
149 Self { def: op, ty }
150 }
151}
152
153impl MakeExtensionOp for PtrOp {
154 fn op_id(&self) -> OpName {
155 self.def.opdef_id()
156 }
157
158 fn from_extension_op(ext_op: &ExtensionOp) -> Result<Self, OpLoadError> {
159 let def = PtrOpDef::from_def(ext_op.def())?;
160 def.instantiate(ext_op.args())
161 }
162
163 fn type_args(&self) -> Vec<TypeArg> {
164 vec![self.ty.clone().into()]
165 }
166}
167
168impl MakeRegisteredOp for PtrOp {
169 fn extension_id(&self) -> ExtensionId {
170 EXTENSION_ID.clone()
171 }
172
173 fn extension_ref(&self) -> Weak<Extension> {
174 Arc::downgrade(&EXTENSION)
175 }
176}
177
178pub trait PtrOpBuilder: Dataflow {
181 fn add_new_ptr(&mut self, val_wire: Wire) -> Result<Wire, BuildError> {
183 let ty = self.get_wire_type(val_wire)?;
184 let handle = self.add_dataflow_op(PtrOpDef::New.with_type(ty), [val_wire])?;
185
186 Ok(handle.out_wire(0))
187 }
188
189 fn add_read_ptr(&mut self, ptr_wire: Wire, ty: Type) -> Result<Wire, BuildError> {
191 let handle = self.add_dataflow_op(PtrOpDef::Read.with_type(ty.clone()), [ptr_wire])?;
192 Ok(handle.out_wire(0))
193 }
194
195 fn add_write_ptr(&mut self, ptr_wire: Wire, val_wire: Wire) -> Result<(), BuildError> {
197 let ty = self.get_wire_type(val_wire)?;
198
199 let handle = self.add_dataflow_op(PtrOpDef::Write.with_type(ty), [ptr_wire, val_wire])?;
200 debug_assert_eq!(handle.outputs().len(), 0);
201 Ok(())
202 }
203}
204
// Every dataflow builder gets the pointer-operation helpers via the trait's default methods.
impl<D: Dataflow> PtrOpBuilder for D {}
206
207impl HasConcrete for PtrOpDef {
208 type Concrete = PtrOp;
209
210 fn instantiate(&self, type_args: &[TypeArg]) -> Result<Self::Concrete, OpLoadError> {
211 let ty = match type_args {
212 [TypeArg::Type { ty }] => ty.clone(),
213 _ => return Err(SignatureError::InvalidTypeArgs.into()),
214 };
215
216 Ok(self.with_type(ty))
217 }
218}
219
// Links each concrete `PtrOp` back to its definition type `PtrOpDef`.
impl HasDef for PtrOp {
    type Def = PtrOpDef;
}
223
#[cfg(test)]
pub(crate) mod test {
    use crate::HugrView;
    use crate::builder::DFGBuilder;
    use crate::extension::prelude::bool_t;
    use crate::ops::ExtensionOp;
    use crate::{
        builder::{Dataflow, DataflowHugr},
        std_extensions::arithmetic::int_types::INT_TYPES,
    };
    use cool_asserts::assert_matches;
    use std::sync::Arc;
    use strum::IntoEnumIterator;

    use super::*;
    use crate::std_extensions::arithmetic::float_types::float64_type;

    /// Look up the registered [`OpDef`] for a pointer operation by name.
    fn get_opdef(op: impl Into<&'static str>) -> Option<&'static Arc<OpDef>> {
        EXTENSION.get_op(op.into())
    }

    #[test]
    fn create_extension() {
        assert_eq!(EXTENSION.name(), &EXTENSION_ID);

        // Every definition round-trips through the `OpDef` registered for it.
        for o in PtrOpDef::iter() {
            assert_eq!(PtrOpDef::from_def(get_opdef(o).unwrap()), Ok(o));
        }
    }

    #[test]
    fn test_ops() {
        let ops = [
            PtrOp::new(PtrOpDef::New, bool_t()),
            PtrOp::new(PtrOpDef::Read, float64_type()),
            PtrOp::new(PtrOpDef::Write, INT_TYPES[5].clone()),
        ];
        // Each concrete op round-trips through its `ExtensionOp` form.
        for op in ops {
            let op_t: ExtensionOp = op.clone().to_extension_op().unwrap();
            let def_op = PtrOpDef::from_op(&op_t).unwrap();
            assert_eq!(op.def, def_op);
            let new_op = PtrOp::from_op(&op_t).unwrap();
            assert_eq!(new_op, op);
        }
    }

    #[test]
    fn test_build() {
        let in_row = vec![bool_t(), float64_type()];

        let hugr = {
            let mut builder = DFGBuilder::new(Signature::new(in_row.clone(), type_row![])).unwrap();

            let in_wires: [Wire; 2] = builder.input_wires_arr();
            // For each input: wrap it in a pointer, read it back, then write it again.
            for (ty, w) in in_row.into_iter().zip(in_wires) {
                let new_ptr = builder.add_new_ptr(w).unwrap();
                let read = builder.add_read_ptr(new_ptr, ty).unwrap();
                builder.add_write_ptr(new_ptr, read).unwrap();
            }

            builder.finish_hugr_with_outputs([]).unwrap()
        };
        assert_matches!(hugr.validate(), Ok(()));
    }
}