hugr_core/std_extensions/
ptr.rs

1//! Pointer type and operations.
2
3use std::sync::{Arc, Weak};
4
5use strum::{EnumIter, EnumString, IntoStaticStr};
6
7use crate::builder::{BuildError, Dataflow};
8use crate::extension::TypeDefBound;
9use crate::ops::OpName;
10use crate::types::{CustomType, PolyFuncType, Signature, Type, TypeBound, TypeName};
11use crate::Wire;
12use crate::{
13    extension::{
14        simple_op::{
15            HasConcrete, HasDef, MakeExtensionOp, MakeOpDef, MakeRegisteredOp, OpLoadError,
16        },
17        ExtensionId, OpDef, SignatureError, SignatureFunc,
18    },
19    ops::{custom::ExtensionOp, NamedOp},
20    type_row,
21    types::type_param::{TypeArg, TypeParam},
22    Extension,
23};
24use lazy_static::lazy_static;
25#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, EnumIter, IntoStaticStr, EnumString)]
26#[allow(missing_docs)]
27#[non_exhaustive]
28/// Pointer operation definitions.
29pub enum PtrOpDef {
30    /// Create a new pointer.
31    New,
32    /// Read a value from a pointer.
33    Read,
34    /// Write a value to a pointer.
35    Write,
36}
37
38impl PtrOpDef {
39    /// Create a new concrete pointer operation with the given value type.
40    pub fn with_type(self, ty: Type) -> PtrOp {
41        PtrOp::new(self, ty)
42    }
43}
44
45impl MakeOpDef for PtrOpDef {
46    fn from_def(op_def: &OpDef) -> Result<Self, OpLoadError>
47    where
48        Self: Sized,
49    {
50        crate::extension::simple_op::try_from_name(op_def.name(), op_def.extension_id())
51    }
52
53    fn init_signature(&self, extension_ref: &Weak<Extension>) -> SignatureFunc {
54        let ptr_t: Type =
55            ptr_custom_type(Type::new_var_use(0, TypeBound::Copyable), extension_ref).into();
56        let inner_t = Type::new_var_use(0, TypeBound::Copyable);
57        let body = match self {
58            PtrOpDef::New => Signature::new(inner_t, ptr_t),
59            PtrOpDef::Read => Signature::new(ptr_t, inner_t),
60            PtrOpDef::Write => Signature::new(vec![ptr_t, inner_t], type_row![]),
61        };
62
63        PolyFuncType::new(TYPE_PARAMS, body).into()
64    }
65
66    fn extension(&self) -> ExtensionId {
67        EXTENSION_ID
68    }
69
70    fn extension_ref(&self) -> Weak<Extension> {
71        Arc::downgrade(&EXTENSION)
72    }
73
74    fn description(&self) -> String {
75        match self {
76            PtrOpDef::New => "Create a new pointer from a value.".into(),
77            PtrOpDef::Read => "Read a value from a pointer.".into(),
78            PtrOpDef::Write => "Write a value to a pointer, overwriting existing value.".into(),
79        }
80    }
81}
82
83/// Name of pointer extension.
84pub const EXTENSION_ID: ExtensionId = ExtensionId::new_unchecked("ptr");
85/// Name of pointer type.
86pub const PTR_TYPE_ID: TypeName = TypeName::new_inline("ptr");
87const TYPE_PARAMS: [TypeParam; 1] = [TypeParam::Type {
88    b: TypeBound::Copyable,
89}];
90/// Extension version.
91pub const VERSION: semver::Version = semver::Version::new(0, 1, 0);
92
93/// Extension for pointer operations.
94fn extension() -> Arc<Extension> {
95    Extension::new_arc(EXTENSION_ID, VERSION, |extension, extension_ref| {
96        extension
97            .add_type(
98                PTR_TYPE_ID,
99                TYPE_PARAMS.into(),
100                "Standard extension pointer type.".into(),
101                TypeDefBound::copyable(),
102                extension_ref,
103            )
104            .unwrap();
105        PtrOpDef::load_all_ops(extension, extension_ref).unwrap();
106    })
107}
108
lazy_static! {
    /// Reference to the pointer Extension.
    ///
    /// Built by `extension()` on first access. The op definitions, registered
    /// ops and `ptr_type` in this module derive their `Weak<Extension>`
    /// handles from this value via `Arc::downgrade`.
    pub static ref EXTENSION: Arc<Extension> = extension();
}
113
114/// Integer type of a given bit width (specified by the TypeArg).  Depending on
115/// the operation, the semantic interpretation may be unsigned integer, signed
116/// integer or bit string.
117fn ptr_custom_type(ty: impl Into<Type>, extension_ref: &Weak<Extension>) -> CustomType {
118    let ty = ty.into();
119    CustomType::new(
120        PTR_TYPE_ID,
121        [ty.into()],
122        EXTENSION_ID,
123        TypeBound::Copyable,
124        extension_ref,
125    )
126}
127
128/// Integer type of a given bit width (specified by the TypeArg).
129pub fn ptr_type(ty: impl Into<Type>) -> Type {
130    ptr_custom_type(ty, &Arc::<Extension>::downgrade(&EXTENSION)).into()
131}
132
133#[derive(Clone, Debug, PartialEq)]
134/// A concrete pointer operation.
135pub struct PtrOp {
136    /// The operation definition.
137    pub def: PtrOpDef,
138    /// Type of the value being pointed to.
139    pub ty: Type,
140}
141
142impl PtrOp {
143    fn new(op: PtrOpDef, ty: Type) -> Self {
144        Self { def: op, ty }
145    }
146}
147
148impl NamedOp for PtrOp {
149    fn name(&self) -> OpName {
150        self.def.name()
151    }
152}
153
154impl MakeExtensionOp for PtrOp {
155    fn from_extension_op(ext_op: &ExtensionOp) -> Result<Self, OpLoadError> {
156        let def = PtrOpDef::from_def(ext_op.def())?;
157        def.instantiate(ext_op.args())
158    }
159
160    fn type_args(&self) -> Vec<TypeArg> {
161        vec![self.ty.clone().into()]
162    }
163}
164
165impl MakeRegisteredOp for PtrOp {
166    fn extension_id(&self) -> ExtensionId {
167        EXTENSION_ID.to_owned()
168    }
169
170    fn extension_ref(&self) -> Weak<Extension> {
171        Arc::downgrade(&EXTENSION)
172    }
173}
174
175/// An extension trait for [Dataflow] providing methods to add pointer
176/// operations.
177pub trait PtrOpBuilder: Dataflow {
178    /// Add a "ptr.New" op.
179    fn add_new_ptr(&mut self, val_wire: Wire) -> Result<Wire, BuildError> {
180        let ty = self.get_wire_type(val_wire)?;
181        let handle = self.add_dataflow_op(PtrOpDef::New.with_type(ty), [val_wire])?;
182
183        Ok(handle.out_wire(0))
184    }
185
186    /// Add a "ptr.Read" op.
187    fn add_read_ptr(&mut self, ptr_wire: Wire, ty: Type) -> Result<Wire, BuildError> {
188        let handle = self.add_dataflow_op(PtrOpDef::Read.with_type(ty.clone()), [ptr_wire])?;
189        Ok(handle.out_wire(0))
190    }
191
192    /// Add a "ptr.Write" op.
193    fn add_write_ptr(&mut self, ptr_wire: Wire, val_wire: Wire) -> Result<(), BuildError> {
194        let ty = self.get_wire_type(val_wire)?;
195
196        let handle = self.add_dataflow_op(PtrOpDef::Write.with_type(ty), [ptr_wire, val_wire])?;
197        debug_assert_eq!(handle.outputs().len(), 0);
198        Ok(())
199    }
200}
201
202impl<D: Dataflow> PtrOpBuilder for D {}
203
204impl HasConcrete for PtrOpDef {
205    type Concrete = PtrOp;
206
207    fn instantiate(&self, type_args: &[TypeArg]) -> Result<Self::Concrete, OpLoadError> {
208        let ty = match type_args {
209            [TypeArg::Type { ty }] => ty.clone(),
210            _ => return Err(SignatureError::InvalidTypeArgs.into()),
211        };
212
213        Ok(self.with_type(ty))
214    }
215}
216
217impl HasDef for PtrOp {
218    type Def = PtrOpDef;
219}
220
#[cfg(test)]
pub(crate) mod test {
    use crate::builder::DFGBuilder;
    use crate::extension::prelude::bool_t;
    use crate::ops::ExtensionOp;
    use crate::{
        builder::{Dataflow, DataflowHugr},
        ops::NamedOp,
        std_extensions::arithmetic::int_types::INT_TYPES,
    };
    use cool_asserts::assert_matches;
    use std::sync::Arc;
    use strum::IntoEnumIterator;

    use super::*;
    use crate::std_extensions::arithmetic::float_types::float64_type;

    /// Look up the op definition registered in [`EXTENSION`] under `op`'s name.
    fn get_opdef(op: impl NamedOp) -> Option<&'static Arc<OpDef>> {
        EXTENSION.get_op(&op.name())
    }

    /// The extension has the expected id, and every `PtrOpDef` variant is
    /// registered and round-trips through its `OpDef`.
    #[test]
    fn create_extension() {
        assert_eq!(EXTENSION.name(), &EXTENSION_ID);

        for o in PtrOpDef::iter() {
            assert_eq!(PtrOpDef::from_def(get_opdef(o).unwrap()), Ok(o));
        }
    }

    /// Concrete ops survive conversion to `ExtensionOp` and back, both as a
    /// definition and as a full concrete op.
    #[test]
    fn test_ops() {
        let ops = [
            // `bool_t()` already returns an owned `Type`; no clone needed.
            PtrOp::new(PtrOpDef::New, bool_t()),
            PtrOp::new(PtrOpDef::Read, float64_type()),
            PtrOp::new(PtrOpDef::Write, INT_TYPES[5].clone()),
        ];
        for op in ops {
            let op_t: ExtensionOp = op.clone().to_extension_op().unwrap();
            let def_op = PtrOpDef::from_op(&op_t).unwrap();
            assert_eq!(op.def, def_op);
            let new_op = PtrOp::from_op(&op_t).unwrap();
            assert_eq!(new_op, op);
        }
    }

    /// For each input, build new-pointer, read and write ops via the builder
    /// trait, then check that the resulting HUGR validates.
    #[test]
    fn test_build() {
        let in_row = vec![bool_t(), float64_type()];

        let hugr = {
            let mut builder = DFGBuilder::new(
                Signature::new(in_row.clone(), type_row![]).with_extension_delta(EXTENSION_ID),
            )
            .unwrap();

            let in_wires: [Wire; 2] = builder.input_wires_arr();
            for (ty, w) in in_row.into_iter().zip(in_wires.iter()) {
                let new_ptr = builder.add_new_ptr(*w).unwrap();
                let read = builder.add_read_ptr(new_ptr, ty).unwrap();
                builder.add_write_ptr(new_ptr, read).unwrap();
            }

            builder.finish_hugr_with_outputs([]).unwrap()
        };
        assert_matches!(hugr.validate(), Ok(_));
    }
}