hugr_core/std_extensions/
ptr.rs

1//! Pointer type and operations.
2
3use std::sync::{Arc, Weak};
4
5use strum::{EnumIter, EnumString, IntoStaticStr};
6
7use crate::Wire;
8use crate::builder::{BuildError, Dataflow};
9use crate::extension::TypeDefBound;
10use crate::ops::OpName;
11use crate::types::{CustomType, PolyFuncType, Signature, Type, TypeBound, TypeName};
12use crate::{
13    Extension,
14    extension::{
15        ExtensionId, OpDef, SignatureError, SignatureFunc,
16        simple_op::{
17            HasConcrete, HasDef, MakeExtensionOp, MakeOpDef, MakeRegisteredOp, OpLoadError,
18        },
19    },
20    ops::custom::ExtensionOp,
21    type_row,
22    types::type_param::{TypeArg, TypeParam},
23};
24use lazy_static::lazy_static;
25#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, EnumIter, IntoStaticStr, EnumString)]
26#[allow(missing_docs)]
27#[non_exhaustive]
28/// Pointer operation definitions.
29pub enum PtrOpDef {
30    /// Create a new pointer.
31    New,
32    /// Read a value from a pointer.
33    Read,
34    /// Write a value to a pointer.
35    Write,
36}
37
38impl PtrOpDef {
39    /// Create a new concrete pointer operation with the given value type.
40    #[must_use]
41    pub fn with_type(self, ty: Type) -> PtrOp {
42        PtrOp::new(self, ty)
43    }
44}
45
46impl MakeOpDef for PtrOpDef {
47    fn opdef_id(&self) -> OpName {
48        <&'static str>::from(self).into()
49    }
50
51    fn from_def(op_def: &OpDef) -> Result<Self, OpLoadError>
52    where
53        Self: Sized,
54    {
55        crate::extension::simple_op::try_from_name(op_def.name(), op_def.extension_id())
56    }
57
58    fn init_signature(&self, extension_ref: &Weak<Extension>) -> SignatureFunc {
59        let ptr_t: Type =
60            ptr_custom_type(Type::new_var_use(0, TypeBound::Copyable), extension_ref).into();
61        let inner_t = Type::new_var_use(0, TypeBound::Copyable);
62        let body = match self {
63            PtrOpDef::New => Signature::new(inner_t, ptr_t),
64            PtrOpDef::Read => Signature::new(ptr_t, inner_t),
65            PtrOpDef::Write => Signature::new(vec![ptr_t, inner_t], type_row![]),
66        };
67
68        PolyFuncType::new(TYPE_PARAMS, body).into()
69    }
70
71    fn extension(&self) -> ExtensionId {
72        EXTENSION_ID
73    }
74
75    fn extension_ref(&self) -> Weak<Extension> {
76        Arc::downgrade(&EXTENSION)
77    }
78
79    fn description(&self) -> String {
80        match self {
81            PtrOpDef::New => "Create a new pointer from a value.".into(),
82            PtrOpDef::Read => "Read a value from a pointer.".into(),
83            PtrOpDef::Write => "Write a value to a pointer, overwriting existing value.".into(),
84        }
85    }
86}
87
88/// Name of pointer extension.
89pub const EXTENSION_ID: ExtensionId = ExtensionId::new_unchecked("ptr");
90/// Name of pointer type.
91pub const PTR_TYPE_ID: TypeName = TypeName::new_inline("ptr");
92const TYPE_PARAMS: [TypeParam; 1] = [TypeParam::RuntimeType(TypeBound::Copyable)];
93/// Extension version.
94pub const VERSION: semver::Version = semver::Version::new(0, 1, 0);
95
96/// Extension for pointer operations.
97fn extension() -> Arc<Extension> {
98    Extension::new_arc(EXTENSION_ID, VERSION, |extension, extension_ref| {
99        extension
100            .add_type(
101                PTR_TYPE_ID,
102                TYPE_PARAMS.into(),
103                "Standard extension pointer type.".into(),
104                TypeDefBound::copyable(),
105                extension_ref,
106            )
107            .unwrap();
108        PtrOpDef::load_all_ops(extension, extension_ref).unwrap();
109    })
110}
111
112lazy_static! {
113    /// Reference to the pointer Extension.
114    pub static ref EXTENSION: Arc<Extension> = extension();
115}
116
117/// Integer type of a given bit width (specified by the `TypeArg`).  Depending on
118/// the operation, the semantic interpretation may be unsigned integer, signed
119/// integer or bit string.
120fn ptr_custom_type(ty: impl Into<Type>, extension_ref: &Weak<Extension>) -> CustomType {
121    let ty = ty.into();
122    CustomType::new(
123        PTR_TYPE_ID,
124        [ty.into()],
125        EXTENSION_ID,
126        TypeBound::Copyable,
127        extension_ref,
128    )
129}
130
131/// Integer type of a given bit width (specified by the `TypeArg`).
132pub fn ptr_type(ty: impl Into<Type>) -> Type {
133    ptr_custom_type(ty, &Arc::<Extension>::downgrade(&EXTENSION)).into()
134}
135
136#[derive(Clone, Debug, PartialEq)]
137/// A concrete pointer operation.
138pub struct PtrOp {
139    /// The operation definition.
140    pub def: PtrOpDef,
141    /// Type of the value being pointed to.
142    pub ty: Type,
143}
144
145impl PtrOp {
146    fn new(op: PtrOpDef, ty: Type) -> Self {
147        Self { def: op, ty }
148    }
149}
150
151impl MakeExtensionOp for PtrOp {
152    fn op_id(&self) -> OpName {
153        self.def.opdef_id()
154    }
155
156    fn from_extension_op(ext_op: &ExtensionOp) -> Result<Self, OpLoadError> {
157        let def = PtrOpDef::from_def(ext_op.def())?;
158        def.instantiate(ext_op.args())
159    }
160
161    fn type_args(&self) -> Vec<TypeArg> {
162        vec![self.ty.clone().into()]
163    }
164}
165
166impl MakeRegisteredOp for PtrOp {
167    fn extension_id(&self) -> ExtensionId {
168        EXTENSION_ID.clone()
169    }
170
171    fn extension_ref(&self) -> Weak<Extension> {
172        Arc::downgrade(&EXTENSION)
173    }
174}
175
176/// An extension trait for [Dataflow] providing methods to add pointer
177/// operations.
178pub trait PtrOpBuilder: Dataflow {
179    /// Add a "ptr.New" op.
180    fn add_new_ptr(&mut self, val_wire: Wire) -> Result<Wire, BuildError> {
181        let ty = self.get_wire_type(val_wire)?;
182        let handle = self.add_dataflow_op(PtrOpDef::New.with_type(ty), [val_wire])?;
183
184        Ok(handle.out_wire(0))
185    }
186
187    /// Add a "ptr.Read" op.
188    fn add_read_ptr(&mut self, ptr_wire: Wire, ty: Type) -> Result<Wire, BuildError> {
189        let handle = self.add_dataflow_op(PtrOpDef::Read.with_type(ty.clone()), [ptr_wire])?;
190        Ok(handle.out_wire(0))
191    }
192
193    /// Add a "ptr.Write" op.
194    fn add_write_ptr(&mut self, ptr_wire: Wire, val_wire: Wire) -> Result<(), BuildError> {
195        let ty = self.get_wire_type(val_wire)?;
196
197        let handle = self.add_dataflow_op(PtrOpDef::Write.with_type(ty), [ptr_wire, val_wire])?;
198        debug_assert_eq!(handle.outputs().len(), 0);
199        Ok(())
200    }
201}
202
203impl<D: Dataflow> PtrOpBuilder for D {}
204
205impl HasConcrete for PtrOpDef {
206    type Concrete = PtrOp;
207
208    fn instantiate(&self, type_args: &[TypeArg]) -> Result<Self::Concrete, OpLoadError> {
209        let ty = match type_args {
210            [TypeArg::Runtime(ty)] => ty.clone(),
211            _ => return Err(SignatureError::InvalidTypeArgs.into()),
212        };
213
214        Ok(self.with_type(ty))
215    }
216}
217
218impl HasDef for PtrOp {
219    type Def = PtrOpDef;
220}
221
#[cfg(test)]
pub(crate) mod test {
    use crate::HugrView;
    use crate::builder::DFGBuilder;
    use crate::extension::prelude::bool_t;
    use crate::ops::ExtensionOp;
    use crate::{
        builder::{Dataflow, DataflowHugr},
        std_extensions::arithmetic::int_types::INT_TYPES,
    };
    use cool_asserts::assert_matches;
    use std::sync::Arc;
    use strum::IntoEnumIterator;

    use super::*;
    use crate::std_extensions::arithmetic::float_types::float64_type;

    /// Look up an operation definition in the pointer extension by name.
    fn get_opdef(op: impl Into<&'static str>) -> Option<&'static Arc<OpDef>> {
        let name: &'static str = op.into();
        EXTENSION.get_op(name)
    }

    /// The extension carries the expected id and every `PtrOpDef` variant
    /// round-trips through its registered `OpDef`.
    #[test]
    fn create_extension() {
        assert_eq!(EXTENSION.name(), &EXTENSION_ID);

        for def in PtrOpDef::iter() {
            let registered = get_opdef(def).unwrap();
            assert_eq!(PtrOpDef::from_def(registered), Ok(def));
        }
    }

    /// Concrete ops survive conversion to an `ExtensionOp` and back.
    #[test]
    fn test_ops() {
        let cases = [
            (PtrOpDef::New, bool_t()),
            (PtrOpDef::Read, float64_type()),
            (PtrOpDef::Write, INT_TYPES[5].clone()),
        ];
        for (def, ty) in cases {
            let op = PtrOp::new(def, ty);
            let ext_op: ExtensionOp = op.clone().to_extension_op().unwrap();

            let recovered_def = PtrOpDef::from_op(&ext_op).unwrap();
            assert_eq!(op.def, recovered_def);

            let recovered_op = PtrOp::from_op(&ext_op).unwrap();
            assert_eq!(recovered_op, op);
        }
    }

    /// A dataflow graph using new/read/write on each input validates.
    #[test]
    fn test_build() {
        let in_row = vec![bool_t(), float64_type()];
        let mut builder = DFGBuilder::new(Signature::new(in_row.clone(), type_row![])).unwrap();

        let in_wires: [Wire; 2] = builder.input_wires_arr();
        for (ty, wire) in in_row.into_iter().zip(in_wires) {
            let ptr = builder.add_new_ptr(wire).unwrap();
            let value = builder.add_read_ptr(ptr, ty).unwrap();
            builder.add_write_ptr(ptr, value).unwrap();
        }

        let hugr = builder.finish_hugr_with_outputs([]).unwrap();
        assert_matches!(hugr.validate(), Ok(()));
    }
}