hugr_core/extension/prelude/
generic.rs

1use std::str::FromStr;
2use std::sync::{Arc, Weak};
3
4use crate::extension::prelude::usize_custom_t;
5use crate::extension::simple_op::{
6    HasConcrete, HasDef, MakeExtensionOp, MakeOpDef, MakeRegisteredOp, OpLoadError,
7};
8use crate::extension::OpDef;
9use crate::extension::SignatureFunc;
10use crate::extension::{ConstFold, ExtensionId};
11use crate::ops::ExtensionOp;
12use crate::ops::NamedOp;
13use crate::ops::OpName;
14use crate::type_row;
15use crate::types::FuncValueType;
16
17use crate::types::Type;
18
19use crate::extension::SignatureError;
20
21use crate::types::PolyFuncTypeRV;
22
23use crate::types::type_param::TypeArg;
24use crate::Extension;
25
26use super::PRELUDE;
27use super::{ConstUsize, PRELUDE_ID};
28use crate::types::type_param::TypeParam;
29
30/// Name of the operation for loading generic BoundedNat parameters.
31pub const LOAD_NAT_OP_ID: OpName = OpName::new_inline("load_nat");
32
/// Operation *definition* for `load_nat`.
///
/// This is a field-less marker type: every value of it is equal to every
/// other, and it is freely copyable. The concrete, instantiated operation is
/// a separate type.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct LoadNatDef;
36
37impl NamedOp for LoadNatDef {
38    fn name(&self) -> OpName {
39        LOAD_NAT_OP_ID
40    }
41}
42
43impl FromStr for LoadNatDef {
44    type Err = ();
45
46    fn from_str(s: &str) -> Result<Self, Self::Err> {
47        if s == LoadNatDef.name() {
48            Ok(Self)
49        } else {
50            Err(())
51        }
52    }
53}
54
55impl ConstFold for LoadNatDef {
56    fn fold(
57        &self,
58        type_args: &[TypeArg],
59        _consts: &[(crate::IncomingPort, crate::ops::Value)],
60    ) -> crate::extension::ConstFoldResult {
61        let [arg] = type_args else {
62            return None;
63        };
64        let nat = arg.as_nat();
65        if let Some(n) = nat {
66            let n_const = ConstUsize::new(n);
67            Some(vec![(0.into(), n_const.into())])
68        } else {
69            None
70        }
71    }
72}
73
74impl MakeOpDef for LoadNatDef {
75    fn from_def(op_def: &OpDef) -> Result<Self, OpLoadError>
76    where
77        Self: Sized,
78    {
79        crate::extension::simple_op::try_from_name(op_def.name(), op_def.extension_id())
80    }
81
82    fn init_signature(&self, _extension_ref: &Weak<Extension>) -> SignatureFunc {
83        let usize_t: Type = usize_custom_t(_extension_ref).into();
84        let params = vec![TypeParam::max_nat()];
85        PolyFuncTypeRV::new(params, FuncValueType::new(type_row![], vec![usize_t])).into()
86    }
87
88    fn extension_ref(&self) -> Weak<Extension> {
89        Arc::downgrade(&PRELUDE)
90    }
91
92    fn extension(&self) -> ExtensionId {
93        PRELUDE_ID
94    }
95
96    fn description(&self) -> String {
97        "Loads a generic bounded nat parameter into a usize runtime value.".into()
98    }
99
100    fn post_opdef(&self, def: &mut OpDef) {
101        def.set_constant_folder(*self);
102    }
103}
104
105/// Concrete load nat operation.
106#[derive(Clone, Debug, PartialEq)]
107pub struct LoadNat {
108    nat: TypeArg,
109}
110
111impl LoadNat {
112    /// Creates a new [LoadNat] operation.
113    pub fn new(nat: TypeArg) -> Self {
114        LoadNat { nat }
115    }
116
117    /// Returns the nat type argument that should be loaded.
118    pub fn get_nat(self) -> TypeArg {
119        self.nat
120    }
121}
122
123impl NamedOp for LoadNat {
124    fn name(&self) -> OpName {
125        LOAD_NAT_OP_ID
126    }
127}
128
129impl MakeExtensionOp for LoadNat {
130    fn from_extension_op(ext_op: &ExtensionOp) -> Result<Self, OpLoadError>
131    where
132        Self: Sized,
133    {
134        let def = LoadNatDef::from_def(ext_op.def())?;
135        def.instantiate(ext_op.args())
136    }
137
138    fn type_args(&self) -> Vec<TypeArg> {
139        vec![self.nat.clone()]
140    }
141}
142
143impl MakeRegisteredOp for LoadNat {
144    fn extension_id(&self) -> ExtensionId {
145        PRELUDE_ID
146    }
147
148    fn extension_ref(&self) -> Weak<Extension> {
149        Arc::downgrade(&PRELUDE)
150    }
151}
152
153impl HasDef for LoadNat {
154    type Def = LoadNatDef;
155}
156
157impl HasConcrete for LoadNatDef {
158    type Concrete = LoadNat;
159
160    fn instantiate(&self, type_args: &[TypeArg]) -> Result<Self::Concrete, OpLoadError> {
161        match type_args {
162            [n] => Ok(LoadNat::new(n.clone())),
163            _ => Err(SignatureError::InvalidTypeArgs.into()),
164        }
165    }
166}
167
#[cfg(test)]
mod tests {
    use crate::{
        builder::{inout_sig, DFGBuilder, Dataflow, DataflowHugr},
        extension::prelude::{usize_t, ConstUsize},
        ops::{constant, OpType},
        type_row,
        types::TypeArg,
        HugrView, OutgoingPort,
    };

    use super::LoadNat;

    /// Building a dataflow graph around a single `LoadNat` must leave exactly
    /// that op, besides the Input and Output nodes, under the root.
    #[test]
    fn test_load_nat() {
        let mut b = DFGBuilder::new(inout_sig(type_row![], vec![usize_t()])).unwrap();

        let arg = TypeArg::BoundedNat { n: 4 };
        let op = LoadNat::new(arg);

        let out = b.add_dataflow_op(op.clone(), []).unwrap();

        let result = b.finish_hugr_with_outputs(out.outputs()).unwrap();

        let exp_optype: OpType = op.into();

        // Collect every child that is neither Input nor Output. Comparing the
        // whole list (rather than asserting inside a filtered loop) makes the
        // test fail if the LoadNat node is missing or if extra nodes appear.
        let non_io_optypes: Vec<&OpType> = result
            .children(result.root())
            .map(|child| result.get_optype(child))
            .filter(|optype| !optype.is_input() && !optype.is_output())
            .collect();

        assert_eq!(non_io_optypes, vec![&exp_optype]);
    }

    /// Constant folding a `LoadNat` over `BoundedNat { n: 5 }` must produce a
    /// `ConstUsize` of 5 on output port 0.
    #[test]
    fn test_load_nat_fold() {
        let arg = TypeArg::BoundedNat { n: 5 };
        let op = LoadNat::new(arg);

        let optype: OpType = op.into();

        match optype {
            OpType::ExtensionOp(ext_op) => {
                let result = ext_op.constant_fold(&[]);
                let exp_port: OutgoingPort = 0.into();
                let exp_val: constant::Value = ConstUsize::new(5).into();
                assert_eq!(result, Some(vec![(exp_port, exp_val)]));
            }
            other => panic!(
                "LoadNat should convert into OpType::ExtensionOp, got {:?}",
                other
            ),
        }
    }
}