hugr_core/extension/prelude/generic.rs

1use std::str::FromStr;
2use std::sync::{Arc, Weak};
3
4use crate::extension::OpDef;
5use crate::extension::SignatureFunc;
6use crate::extension::prelude::usize_custom_t;
7use crate::extension::simple_op::{
8    HasConcrete, HasDef, MakeExtensionOp, MakeOpDef, MakeRegisteredOp, OpLoadError,
9};
10use crate::extension::{ConstFold, ExtensionId};
11use crate::ops::ExtensionOp;
12use crate::ops::OpName;
13use crate::type_row;
14use crate::types::FuncValueType;
15
16use crate::types::Type;
17
18use crate::extension::SignatureError;
19
20use crate::types::PolyFuncTypeRV;
21
22use crate::Extension;
23use crate::types::type_param::TypeArg;
24
25use super::PRELUDE;
26use super::{ConstUsize, PRELUDE_ID};
27use crate::types::type_param::TypeParam;
28
/// Name of the operation for loading generic `BoundedNat` parameters.
///
/// This is the op definition id reported by [`LoadNatDef`].
pub static LOAD_NAT_OP_ID: OpName = OpName::new_inline("load_nat");
31
/// Definition of the load nat operation.
///
/// A stateless unit struct. Its signature, constant folding and instantiation
/// into the concrete [`LoadNat`] op are provided by the trait impls in this module.
#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq)]
pub struct LoadNatDef;
35
36impl FromStr for LoadNatDef {
37    type Err = ();
38
39    fn from_str(s: &str) -> Result<Self, Self::Err> {
40        if s == Self.op_id() { Ok(Self) } else { Err(()) }
41    }
42}
43
44impl ConstFold for LoadNatDef {
45    fn fold(
46        &self,
47        type_args: &[TypeArg],
48        _consts: &[(crate::IncomingPort, crate::ops::Value)],
49    ) -> crate::extension::ConstFoldResult {
50        let [arg] = type_args else {
51            return None;
52        };
53        let nat = arg.as_nat();
54        if let Some(n) = nat {
55            let n_const = ConstUsize::new(n);
56            Some(vec![(0.into(), n_const.into())])
57        } else {
58            None
59        }
60    }
61}
62
63impl MakeOpDef for LoadNatDef {
64    fn opdef_id(&self) -> OpName {
65        LOAD_NAT_OP_ID.clone()
66    }
67
68    fn from_def(op_def: &OpDef) -> Result<Self, OpLoadError>
69    where
70        Self: Sized,
71    {
72        crate::extension::simple_op::try_from_name(op_def.name(), op_def.extension_id())
73    }
74
75    fn init_signature(&self, _extension_ref: &Weak<Extension>) -> SignatureFunc {
76        let usize_t: Type = usize_custom_t(_extension_ref).into();
77        let params = vec![TypeParam::max_nat_type()];
78        PolyFuncTypeRV::new(params, FuncValueType::new(type_row![], vec![usize_t])).into()
79    }
80
81    fn extension_ref(&self) -> Weak<Extension> {
82        Arc::downgrade(&PRELUDE)
83    }
84
85    fn extension(&self) -> ExtensionId {
86        PRELUDE_ID
87    }
88
89    fn description(&self) -> String {
90        "Loads a generic bounded nat parameter into a usize runtime value.".into()
91    }
92
93    fn post_opdef(&self, def: &mut OpDef) {
94        def.set_constant_folder(*self);
95    }
96}
97
/// Concrete load nat operation.
///
/// Holds the single type argument whose value is loaded as a runtime usize.
#[derive(Clone, Debug, PartialEq)]
pub struct LoadNat {
    // The type argument to load. NOTE(review): presumably a bounded nat;
    // `LoadNat::new` does not validate it.
    nat: TypeArg,
}
103
104impl LoadNat {
105    /// Creates a new [`LoadNat`] operation.
106    #[must_use]
107    pub fn new(nat: TypeArg) -> Self {
108        LoadNat { nat }
109    }
110
111    /// Returns the nat type argument that should be loaded.
112    #[must_use]
113    pub fn get_nat(self) -> TypeArg {
114        self.nat
115    }
116}
117
118impl MakeExtensionOp for LoadNat {
119    fn op_id(&self) -> OpName {
120        LoadNatDef.opdef_id()
121    }
122
123    fn from_extension_op(ext_op: &ExtensionOp) -> Result<Self, OpLoadError>
124    where
125        Self: Sized,
126    {
127        let def = LoadNatDef::from_def(ext_op.def())?;
128        def.instantiate(ext_op.args())
129    }
130
131    fn type_args(&self) -> Vec<TypeArg> {
132        vec![self.nat.clone()]
133    }
134}
135
136impl MakeRegisteredOp for LoadNat {
137    fn extension_id(&self) -> ExtensionId {
138        PRELUDE_ID
139    }
140
141    fn extension_ref(&self) -> Weak<Extension> {
142        Arc::downgrade(&PRELUDE)
143    }
144}
145
/// Links the concrete [`LoadNat`] op back to its definition type [`LoadNatDef`].
impl HasDef for LoadNat {
    type Def = LoadNatDef;
}
149
150impl HasConcrete for LoadNatDef {
151    type Concrete = LoadNat;
152
153    fn instantiate(&self, type_args: &[TypeArg]) -> Result<Self::Concrete, OpLoadError> {
154        match type_args {
155            [n] => Ok(LoadNat::new(n.clone())),
156            _ => Err(SignatureError::InvalidTypeArgs.into()),
157        }
158    }
159}
160
#[cfg(test)]
mod tests {
    use crate::{
        HugrView, OutgoingPort,
        builder::{DFGBuilder, Dataflow, DataflowHugr, inout_sig},
        extension::{
            ConstFold,
            prelude::{ConstUsize, usize_t},
            simple_op::HasConcrete,
        },
        ops::{OpType, constant},
        type_row,
        types::Term,
    };

    use super::{LoadNat, LoadNatDef};

    #[test]
    fn test_load_nat() {
        let mut b = DFGBuilder::new(inout_sig(type_row![], vec![usize_t()])).unwrap();

        let arg = Term::from(4u64);
        let op = LoadNat::new(arg);

        let out = b.add_dataflow_op(op.clone(), []).unwrap();

        let result = b.finish_hugr_with_outputs(out.outputs()).unwrap();

        let exp_optype: OpType = op.into();

        // Besides Input and Output, the DFG must contain exactly one node: the LoadNat op.
        // Asserting on the collected list (instead of per child inside a loop) makes the
        // test fail when the node is missing, rather than passing vacuously.
        let non_io: Vec<&OpType> = result
            .children(result.entrypoint())
            .map(|child| result.get_optype(child))
            .filter(|optype| !optype.is_input() && !optype.is_output())
            .collect();
        assert_eq!(non_io, vec![&exp_optype]);
    }

    #[test]
    fn test_load_nat_fold() {
        let arg = Term::from(5u64);
        let op = LoadNat::new(arg);

        let optype: OpType = op.into();

        let OpType::ExtensionOp(ext_op) = optype else {
            panic!("LoadNat should convert into an OpType::ExtensionOp");
        };

        let result = ext_op.constant_fold(&[]);
        let exp_port: OutgoingPort = 0.into();
        let exp_val: constant::Value = ConstUsize::new(5).into();
        assert_eq!(result, Some(vec![(exp_port, exp_val)]));
    }

    #[test]
    fn test_load_nat_def_from_str() {
        assert_eq!("load_nat".parse::<LoadNatDef>(), Ok(LoadNatDef));
        assert!("not_load_nat".parse::<LoadNatDef>().is_err());
    }

    #[test]
    fn test_load_nat_arity() {
        let arg = Term::from(3u64);
        assert_eq!(
            LoadNatDef.instantiate(&[arg.clone()]).unwrap(),
            LoadNat::new(arg)
        );
        // Wrong number of type arguments: instantiation errors, folding declines.
        assert!(LoadNatDef.instantiate(&[]).is_err());
        assert!(LoadNatDef.fold(&[], &[]).is_none());
    }
}
212}