hugr_core/std_extensions/
ptr.rs1use std::sync::{Arc, LazyLock, Weak};
4
5use strum::{EnumIter, EnumString, IntoStaticStr};
6
7use crate::Wire;
8use crate::builder::{BuildError, Dataflow};
9use crate::extension::TypeDefBound;
10use crate::ops::OpName;
11use crate::types::{CustomType, PolyFuncType, Signature, Type, TypeBound, TypeName};
12use crate::{
13 Extension,
14 extension::{
15 ExtensionId, OpDef, SignatureError, SignatureFunc,
16 simple_op::{
17 HasConcrete, HasDef, MakeExtensionOp, MakeOpDef, MakeRegisteredOp, OpLoadError,
18 },
19 },
20 ops::custom::ExtensionOp,
21 type_row,
22 types::type_param::{TypeArg, TypeParam},
23};
24#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, EnumIter, IntoStaticStr, EnumString)]
25#[allow(missing_docs)]
26#[non_exhaustive]
27pub enum PtrOpDef {
29 New,
31 Read,
33 Write,
35}
36
37impl PtrOpDef {
38 #[must_use]
40 pub fn with_type(self, ty: Type) -> PtrOp {
41 PtrOp::new(self, ty)
42 }
43}
44
45impl MakeOpDef for PtrOpDef {
46 fn opdef_id(&self) -> OpName {
47 <&'static str>::from(self).into()
48 }
49
50 fn from_def(op_def: &OpDef) -> Result<Self, OpLoadError>
51 where
52 Self: Sized,
53 {
54 crate::extension::simple_op::try_from_name(op_def.name(), op_def.extension_id())
55 }
56
57 fn init_signature(&self, extension_ref: &Weak<Extension>) -> SignatureFunc {
58 let ptr_t: Type =
59 ptr_custom_type(Type::new_var_use(0, TypeBound::Copyable), extension_ref).into();
60 let inner_t = Type::new_var_use(0, TypeBound::Copyable);
61 let body = match self {
62 PtrOpDef::New => Signature::new(inner_t, ptr_t),
63 PtrOpDef::Read => Signature::new(ptr_t, inner_t),
64 PtrOpDef::Write => Signature::new(vec![ptr_t, inner_t], type_row![]),
65 };
66
67 PolyFuncType::new(TYPE_PARAMS, body).into()
68 }
69
70 fn extension(&self) -> ExtensionId {
71 EXTENSION_ID
72 }
73
74 fn extension_ref(&self) -> Weak<Extension> {
75 Arc::downgrade(&EXTENSION)
76 }
77
78 fn description(&self) -> String {
79 match self {
80 PtrOpDef::New => "Create a new pointer from a value.".into(),
81 PtrOpDef::Read => "Read a value from a pointer.".into(),
82 PtrOpDef::Write => "Write a value to a pointer, overwriting existing value.".into(),
83 }
84 }
85}
86
87pub const EXTENSION_ID: ExtensionId = ExtensionId::new_unchecked("ptr");
89pub const PTR_TYPE_ID: TypeName = TypeName::new_inline("ptr");
91const TYPE_PARAMS: [TypeParam; 1] = [TypeParam::RuntimeType(TypeBound::Copyable)];
92pub const VERSION: semver::Version = semver::Version::new(0, 1, 0);
94
95fn extension() -> Arc<Extension> {
97 Extension::new_arc(EXTENSION_ID, VERSION, |extension, extension_ref| {
98 extension
99 .add_type(
100 PTR_TYPE_ID,
101 TYPE_PARAMS.into(),
102 "Standard extension pointer type.".into(),
103 TypeDefBound::copyable(),
104 extension_ref,
105 )
106 .unwrap();
107 PtrOpDef::load_all_ops(extension, extension_ref).unwrap();
108 })
109}
110
111pub static EXTENSION: LazyLock<Arc<Extension>> = LazyLock::new(extension);
113
114fn ptr_custom_type(ty: impl Into<Type>, extension_ref: &Weak<Extension>) -> CustomType {
118 let ty = ty.into();
119 CustomType::new(
120 PTR_TYPE_ID,
121 [ty.into()],
122 EXTENSION_ID,
123 TypeBound::Copyable,
124 extension_ref,
125 )
126}
127
128pub fn ptr_type(ty: impl Into<Type>) -> Type {
130 ptr_custom_type(ty, &Arc::<Extension>::downgrade(&EXTENSION)).into()
131}
132
133#[derive(Clone, Debug, PartialEq)]
134pub struct PtrOp {
136 pub def: PtrOpDef,
138 pub ty: Type,
140}
141
142impl PtrOp {
143 fn new(op: PtrOpDef, ty: Type) -> Self {
144 Self { def: op, ty }
145 }
146}
147
148impl MakeExtensionOp for PtrOp {
149 fn op_id(&self) -> OpName {
150 self.def.opdef_id()
151 }
152
153 fn from_extension_op(ext_op: &ExtensionOp) -> Result<Self, OpLoadError> {
154 let def = PtrOpDef::from_def(ext_op.def())?;
155 def.instantiate(ext_op.args())
156 }
157
158 fn type_args(&self) -> Vec<TypeArg> {
159 vec![self.ty.clone().into()]
160 }
161}
162
163impl MakeRegisteredOp for PtrOp {
164 fn extension_id(&self) -> ExtensionId {
165 EXTENSION_ID.clone()
166 }
167
168 fn extension_ref(&self) -> Weak<Extension> {
169 Arc::downgrade(&EXTENSION)
170 }
171}
172
173pub trait PtrOpBuilder: Dataflow {
176 fn add_new_ptr(&mut self, val_wire: Wire) -> Result<Wire, BuildError> {
178 let ty = self.get_wire_type(val_wire)?;
179 let handle = self.add_dataflow_op(PtrOpDef::New.with_type(ty), [val_wire])?;
180
181 Ok(handle.out_wire(0))
182 }
183
184 fn add_read_ptr(&mut self, ptr_wire: Wire, ty: Type) -> Result<Wire, BuildError> {
186 let handle = self.add_dataflow_op(PtrOpDef::Read.with_type(ty.clone()), [ptr_wire])?;
187 Ok(handle.out_wire(0))
188 }
189
190 fn add_write_ptr(&mut self, ptr_wire: Wire, val_wire: Wire) -> Result<(), BuildError> {
192 let ty = self.get_wire_type(val_wire)?;
193
194 let handle = self.add_dataflow_op(PtrOpDef::Write.with_type(ty), [ptr_wire, val_wire])?;
195 debug_assert_eq!(handle.outputs().len(), 0);
196 Ok(())
197 }
198}
199
200impl<D: Dataflow> PtrOpBuilder for D {}
201
202impl HasConcrete for PtrOpDef {
203 type Concrete = PtrOp;
204
205 fn instantiate(&self, type_args: &[TypeArg]) -> Result<Self::Concrete, OpLoadError> {
206 let ty = match type_args {
207 [TypeArg::Runtime(ty)] => ty.clone(),
208 _ => return Err(SignatureError::InvalidTypeArgs.into()),
209 };
210
211 Ok(self.with_type(ty))
212 }
213}
214
215impl HasDef for PtrOp {
216 type Def = PtrOpDef;
217}
218
#[cfg(test)]
pub(crate) mod test {
    use crate::HugrView;
    use crate::builder::DFGBuilder;
    use crate::extension::prelude::bool_t;
    use crate::ops::ExtensionOp;
    use crate::{
        builder::{Dataflow, DataflowHugr},
        std_extensions::arithmetic::int_types::INT_TYPES,
    };
    use cool_asserts::assert_matches;
    use std::sync::Arc;
    use strum::IntoEnumIterator;

    use super::*;
    use crate::std_extensions::arithmetic::float_types::float64_type;

    /// Look up an operation definition of the pointer extension by name.
    fn get_opdef(op: impl Into<&'static str>) -> Option<&'static Arc<OpDef>> {
        EXTENSION.get_op(op.into())
    }

    /// The extension carries the expected id, and every `PtrOpDef` variant
    /// round-trips through its registered `OpDef`.
    #[test]
    fn create_extension() {
        assert_eq!(EXTENSION.name(), &EXTENSION_ID);

        for o in PtrOpDef::iter() {
            assert_eq!(PtrOpDef::from_def(get_opdef(o).unwrap()), Ok(o));
        }
    }

    /// Concrete ops survive conversion to `ExtensionOp` and back.
    #[test]
    fn test_ops() {
        let ops = [
            PtrOp::new(PtrOpDef::New, bool_t()),
            PtrOp::new(PtrOpDef::Read, float64_type()),
            PtrOp::new(PtrOpDef::Write, INT_TYPES[5].clone()),
        ];
        for op in ops {
            let op_t: ExtensionOp = op.clone().to_extension_op().unwrap();
            let def_op = PtrOpDef::from_op(&op_t).unwrap();
            assert_eq!(op.def, def_op);
            let new_op = PtrOp::from_op(&op_t).unwrap();
            assert_eq!(new_op, op);
        }
    }

    /// A new/read/write chain built for each input of a DFG yields a valid HUGR.
    #[test]
    fn test_build() {
        let in_row = vec![bool_t(), float64_type()];

        let hugr = {
            let mut builder = DFGBuilder::new(Signature::new(in_row.clone(), type_row![])).unwrap();

            let in_wires: [Wire; 2] = builder.input_wires_arr();
            for (ty, w) in in_row.into_iter().zip(in_wires) {
                let new_ptr = builder.add_new_ptr(w).unwrap();
                let read = builder.add_read_ptr(new_ptr, ty).unwrap();
                builder.add_write_ptr(new_ptr, read).unwrap();
            }

            builder.finish_hugr_with_outputs([]).unwrap()
        };
        assert_matches!(hugr.validate(), Ok(()));
    }
}
282}