qudit_inst/
instantiater.rs

1use std::any::Any;
2use std::collections::HashMap;
3use std::sync::Arc;
4
5use qudit_circuit::QuditCircuit;
6use qudit_core::ComplexScalar;
7
8use crate::InstantiationResult;
9use crate::InstantiationTarget;
10
/// Marker trait for values that may be stored in a [`DataMap`].
///
/// It is blanket-implemented for every `'static` type that implements
/// [`ToString`], so it never needs to be implemented by hand.
pub trait DataItem: Any + ToString {}

impl<T> DataItem for T where T: Any + ToString {}

/// String-keyed collection of auxiliary values handed to an [`Instantiater`].
///
/// Values are boxed trait objects bounded by `Send + Sync`, which keeps the map
/// itself `Send + Sync` so it can be shared behind an `Arc`.
pub type DataMap = HashMap<String, Box<dyn DataItem + Send + Sync>>;
16
17pub trait Instantiater<C: ComplexScalar> {
18    fn instantiate(
19        &self,
20        circuit: Arc<QuditCircuit>,
21        target: Arc<InstantiationTarget<C>>,
22        data: Arc<DataMap>,
23    ) -> InstantiationResult<C>;
24
25    fn batched_instantiate(
26        &self,
27        circuit: Arc<QuditCircuit>,
28        targets: &[Arc<InstantiationTarget<C>>],
29        data: Arc<DataMap>,
30    ) -> Vec<InstantiationResult<C>> {
31        targets
32            .iter()
33            .map(|t| self.instantiate(circuit.clone(), t.clone(), data.clone()))
34            .collect()
35    }
36}
37
#[cfg(feature = "python")]
pub mod python {
    use super::*;
    use dyn_clone::DynClone;
    use pyo3::{
        exceptions::{PyNotImplementedError, PyTypeError},
        prelude::*,
        types::{PyDict, PyList, PyTuple},
    };
    use qudit_core::c64;

    /// Converts an optional Python `dict` of `str -> str` into a shared [`DataMap`].
    ///
    /// `None` yields an empty map.
    ///
    /// # Errors
    ///
    /// Returns the PyO3 extraction error if any key or value is not a Python `str`.
    fn pydict_to_datamap(py_dict: Option<&Bound<'_, PyDict>>) -> PyResult<Arc<DataMap>> {
        let Some(py_dict) = py_dict else {
            return Ok(Arc::new(DataMap::new()));
        };

        let mut data_map = DataMap::with_capacity(py_dict.len());
        for (key, value) in py_dict.iter() {
            let key_str: String = key.extract()?;
            let value_str: String = value.extract()?;
            data_map.insert(key_str, Box::new(value_str));
        }
        Ok(Arc::new(data_map))
    }

    /// Copies a native [`DataMap`] into a new Python `dict`.
    ///
    /// Each value is stringified with [`ToString`].
    fn datamap_to_pydict<'py>(py: Python<'py>, data: &DataMap) -> PyResult<Bound<'py, PyDict>> {
        let py_data = PyDict::new(py);
        for (key, val) in data.iter() {
            py_data.set_item(key, val.to_string())?;
        }
        Ok(py_data)
    }

    /// Bundles the bounds a native instantiater needs in order to live inside a
    /// Python object.
    ///
    /// The bounds are: an `Instantiater<c64>`, `Send + Sync`, and clonable
    /// through `dyn_clone`.
    pub trait InstantiaterWrapper: Instantiater<c64> + Send + Sync + DynClone {}

    /// Python handle (`NativeInstantiater`) around a Rust-implemented instantiater.
    #[pyclass(name = "NativeInstantiater")]
    pub struct BoxedInstantiater {
        pub inner: Box<dyn InstantiaterWrapper>,
    }

    #[pymethods]
    impl BoxedInstantiater {
        /// Python: `instantiate(circuit, target, data=None)`.
        ///
        /// `data`, if given, must map `str` keys to `str` values.
        #[pyo3(name = "instantiate")]
        #[pyo3(signature = (circuit, target, data = None))]
        fn instantiate_python(
            &self,
            circuit: QuditCircuit,
            target: InstantiationTarget<c64>,
            data: Option<&Bound<'_, PyDict>>,
        ) -> PyResult<InstantiationResult<c64>> {
            let data_map = pydict_to_datamap(data)?;
            let result =
                Instantiater::instantiate(self, Arc::new(circuit), Arc::new(target), data_map);
            Ok(result)
        }

        /// Python: `batched_instantiate(circuit, targets, data=None)`.
        ///
        /// Returns one result per entry of `targets`, in the same order.
        #[pyo3(name = "batched_instantiate")]
        #[pyo3(signature = (circuit, targets, data = None))]
        fn batched_instantiate_python(
            &self,
            circuit: QuditCircuit,
            targets: Vec<InstantiationTarget<c64>>,
            data: Option<&Bound<'_, PyDict>>,
        ) -> PyResult<Vec<InstantiationResult<c64>>> {
            let data_map = pydict_to_datamap(data)?;
            let target_arcs: Vec<Arc<InstantiationTarget<c64>>> =
                targets.into_iter().map(Arc::new).collect();
            let result =
                Instantiater::batched_instantiate(self, Arc::new(circuit), &target_arcs, data_map);
            Ok(result)
        }
    }

    impl Instantiater<c64> for BoxedInstantiater {
        /// Forwards directly to the boxed native instantiater.
        fn instantiate(
            &self,
            circuit: Arc<QuditCircuit>,
            target: Arc<InstantiationTarget<c64>>,
            data: Arc<DataMap>,
        ) -> InstantiationResult<c64> {
            self.inner.instantiate(circuit, target, data)
        }

        /// Forwards directly to the boxed native instantiater, so its own
        /// batched override (if any) is used.
        fn batched_instantiate(
            &self,
            circuit: Arc<QuditCircuit>,
            targets: &[Arc<InstantiationTarget<c64>>],
            data: Arc<DataMap>,
        ) -> Vec<InstantiationResult<c64>> {
            self.inner.batched_instantiate(circuit, targets, data)
        }
    }

    /// Python-visible base class `Instantiater`.
    ///
    /// Python code subclasses it and overrides `instantiate`, and optionally
    /// `batched_instantiate`.
    #[pyclass(name = "Instantiater", subclass)]
    struct PyInstantiaterABC;

    #[pymethods]
    impl PyInstantiaterABC {
        /// Makes the class, and therefore its Python subclasses, constructible.
        ///
        /// Without a `#[new]`, PyO3 installs a `tp_new` that raises
        /// `TypeError: No constructor defined`, and Python subclasses inherit it.
        /// Arbitrary arguments are accepted and ignored so that subclasses may
        /// define their own `__init__` signature.
        #[new]
        #[pyo3(signature = (*_args, **_kwargs))]
        fn new(_args: &Bound<'_, PyTuple>, _kwargs: Option<&Bound<'_, PyDict>>) -> Self {
            PyInstantiaterABC
        }

        /// Abstract method: always raises `NotImplementedError`.
        ///
        /// `data` defaults to `None`, matching `NativeInstantiater.instantiate`.
        #[pyo3(signature = (_circuit, _target, _data = None))]
        fn instantiate(
            &self,
            _circuit: QuditCircuit,
            _target: InstantiationTarget<c64>,
            _data: Option<&Bound<'_, PyDict>>,
        ) -> PyResult<InstantiationResult<c64>> {
            Err(PyNotImplementedError::new_err(
                "Instantiaters must implement the instantiate method.",
            ))
        }
    }

    /// Adapts a Python object with an `instantiate` method (and optionally a
    /// `batched_instantiate` method) to the native [`Instantiater`] trait.
    ///
    /// # Panics
    ///
    /// The [`Instantiater`] trait has no error channel. The trampoline therefore
    /// panics, with the Python error in the message, in either of these cases:
    /// - the wrapped object raises an exception;
    /// - a value cannot be converted to or from Python.
    struct PyInstantiaterTrampoline {
        instantiater: Py<PyAny>,
    }

    impl Instantiater<c64> for PyInstantiaterTrampoline {
        fn instantiate(
            &self,
            circuit: Arc<QuditCircuit>,
            target: Arc<InstantiationTarget<c64>>,
            data: Arc<DataMap>,
        ) -> InstantiationResult<c64> {
            // TODO: propagate Python errors instead of panicking once the trait
            // has an error channel.
            Python::attach(|py| {
                let py_data = datamap_to_pydict(py, &data).unwrap_or_else(|err| {
                    panic!("Failed to convert instantiation data to Python: {err}")
                });

                self.instantiater
                    .bind(py)
                    .call_method(
                        "instantiate",
                        ((*circuit).clone(), (*target).clone(), py_data),
                        None,
                    )
                    .unwrap_or_else(|err| panic!("Python `instantiate` raised an exception: {err}"))
                    .extract()
                    .expect("Invalid return type from instantiate.")
            })
        }

        fn batched_instantiate(
            &self,
            circuit: Arc<QuditCircuit>,
            targets: &[Arc<InstantiationTarget<c64>>],
            data: Arc<DataMap>,
        ) -> Vec<InstantiationResult<c64>> {
            // TODO: propagate Python errors instead of panicking once the trait
            // has an error channel.
            Python::attach(|py| {
                let bound = self.instantiater.bind(py);
                let py_data = datamap_to_pydict(py, &data).unwrap_or_else(|err| {
                    panic!("Failed to convert instantiation data to Python: {err}")
                });

                // Prefer a Python-side batched implementation when present. An
                // error while probing for the attribute is treated as "absent".
                if bound.hasattr("batched_instantiate").unwrap_or(false) {
                    let py_targets = PyList::new(py, targets.iter().map(|t| (**t).clone()))
                        .unwrap_or_else(|err| {
                            panic!("Failed to convert instantiation targets to Python: {err}")
                        });
                    bound
                        .call_method(
                            "batched_instantiate",
                            ((*circuit).clone(), py_targets, py_data),
                            None,
                        )
                        .unwrap_or_else(|err| {
                            panic!("Python `batched_instantiate` raised an exception: {err}")
                        })
                        .extract()
                        .expect("Invalid return type from batched instantiate.")
                } else {
                    // Fall back to one `instantiate` call per target. The same
                    // Python circuit object and data dict are passed to every
                    // call. NOTE(review): in-place mutations by the Python side
                    // would therefore be visible to later targets — confirm
                    // this is intended.
                    let circuit = (*circuit)
                        .clone()
                        .into_pyobject(py)
                        .unwrap_or_else(|err| panic!("Failed to convert circuit to Python: {err:?}"));
                    targets
                        .iter()
                        .map(|t| {
                            bound
                                .call_method(
                                    "instantiate",
                                    (&circuit, (**t).clone(), &py_data),
                                    None,
                                )
                                .unwrap_or_else(|err| {
                                    panic!("Python `instantiate` raised an exception: {err}")
                                })
                                .extract()
                                .expect("Invalid return type from instantiate.")
                        })
                        .collect()
                }
            })
        }
    }

    /// An instantiater received from Python, either Python-defined or native.
    ///
    /// Other pyo3 code can use `PyInstantiater` as a parameter type in
    /// pyfunctions and pymethods. It can be populated with either a
    /// Python-defined instantiater or a boxed Rust one. This object does not
    /// hold the GIL. Calls to the native variant go straight through the box
    /// without attaching to Python. Calls to the Python variant attach only for
    /// the duration of the call.
    pub enum PyInstantiater {
        // The trampoline type is deliberately private. External code either
        // uses this enum through its `Instantiater` impl or matches `Native`.
        #[allow(private_interfaces)]
        Python(PyInstantiaterTrampoline),
        Native(BoxedInstantiater),
    }

    impl Instantiater<c64> for PyInstantiater {
        fn instantiate(
            &self,
            circuit: Arc<QuditCircuit>,
            target: Arc<InstantiationTarget<c64>>,
            data: Arc<DataMap>,
        ) -> InstantiationResult<c64> {
            match self {
                PyInstantiater::Python(inner) => inner.instantiate(circuit, target, data),
                PyInstantiater::Native(inner) => inner.instantiate(circuit, target, data),
            }
        }

        fn batched_instantiate(
            &self,
            circuit: Arc<QuditCircuit>,
            targets: &[Arc<InstantiationTarget<c64>>],
            data: Arc<DataMap>,
        ) -> Vec<InstantiationResult<c64>> {
            match self {
                PyInstantiater::Python(inner) => inner.batched_instantiate(circuit, targets, data),
                PyInstantiater::Native(inner) => inner.batched_instantiate(circuit, targets, data),
            }
        }
    }

    impl<'a, 'py> FromPyObject<'a, 'py> for PyInstantiater {
        type Error = PyErr;

        /// Extracts a `NativeInstantiater` by cloning its boxed Rust instantiater.
        ///
        /// Otherwise, any object with an `instantiate` attribute is wrapped in a
        /// Python trampoline. Objects matching neither case raise `TypeError`.
        fn extract(obj: Borrowed<'a, 'py, PyAny>) -> PyResult<Self> {
            if let Ok(dyn_trait) = obj.extract::<PyRef<BoxedInstantiater>>() {
                Ok(PyInstantiater::Native(BoxedInstantiater {
                    inner: dyn_clone::clone_box(&*dyn_trait.inner),
                }))
            } else if obj.hasattr("instantiate")? {
                let trampoline = PyInstantiaterTrampoline {
                    instantiater: obj.to_owned().unbind(),
                };
                Ok(PyInstantiater::Python(trampoline))
            } else {
                Err(PyTypeError::new_err(
                    "Cannot extract an 'Instantiater' during conversion to native code.",
                ))
            }
        }
    }
}