1use std::any::Any;
2use std::collections::HashMap;
3use std::sync::Arc;
4
5use qudit_circuit::QuditCircuit;
6use qudit_core::ComplexScalar;
7
8use crate::InstantiationResult;
9use crate::InstantiationTarget;
10
/// Marker trait for values that may be stored in a [`DataMap`].
///
/// Every `'static` type (the `Any` bound) that implements `ToString` qualifies
/// automatically through the blanket implementation below.
pub trait DataItem: ToString + Any {}

impl<T> DataItem for T where T: ToString + Any {}

/// String-keyed collection of auxiliary data handed to instantiaters.
///
/// Values are boxed trait objects that are `Send + Sync`, so the whole map can be
/// shared across threads behind an `Arc`.
pub type DataMap = HashMap<String, Box<dyn DataItem + Send + Sync>>;
16
17pub trait Instantiater<C: ComplexScalar> {
18 fn instantiate(
19 &self,
20 circuit: Arc<QuditCircuit>,
21 target: Arc<InstantiationTarget<C>>,
22 data: Arc<DataMap>,
23 ) -> InstantiationResult<C>;
24
25 fn batched_instantiate(
26 &self,
27 circuit: Arc<QuditCircuit>,
28 targets: &[Arc<InstantiationTarget<C>>],
29 data: Arc<DataMap>,
30 ) -> Vec<InstantiationResult<C>> {
31 targets
32 .iter()
33 .map(|t| self.instantiate(circuit.clone(), t.clone(), data.clone()))
34 .collect()
35 }
36}
37
38#[cfg(feature = "python")]
39pub mod python {
40 use super::*;
41 use dyn_clone::DynClone;
42 use pyo3::{
43 exceptions::{PyNotImplementedError, PyTypeError},
44 prelude::*,
45 types::{PyDict, PyList},
46 };
47 use qudit_core::c64;
48
49 fn pydict_to_datamap(py_dict: Option<&Bound<'_, PyDict>>) -> PyResult<Arc<DataMap>> {
50 let mut data_map = HashMap::new();
51
52 match py_dict {
53 None => Ok(Arc::new(data_map)),
54 Some(py_dict) => {
55 for (key, value) in py_dict.iter() {
56 let key_str: String = key.extract()?;
57 let value_str: String = value.extract()?;
58 data_map.insert(
59 key_str,
60 Box::new(value_str) as Box<dyn DataItem + Send + Sync>,
61 );
62 }
63 Ok(Arc::new(data_map))
64 }
65 }
66 }
67
68 pub trait InstantiaterWrapper: Instantiater<c64> + Send + Sync + DynClone {}
69
70 #[pyclass(name = "NativeInstantiater")]
71 pub struct BoxedInstantiater {
72 pub inner: Box<dyn InstantiaterWrapper>,
73 }
74
75 #[pymethods]
76 impl BoxedInstantiater {
77 #[pyo3(name = "instantiate")]
78 #[pyo3(signature = (circuit, target, data = None))]
79 fn instantiate_python(
80 &self,
81 circuit: QuditCircuit,
82 target: InstantiationTarget<c64>,
83 data: Option<&Bound<'_, PyDict>>,
84 ) -> PyResult<InstantiationResult<c64>> {
85 let data_map = pydict_to_datamap(data)?;
86 let result =
87 Instantiater::instantiate(self, Arc::new(circuit), Arc::new(target), data_map);
88 Ok(result)
89 }
90
91 #[pyo3(name = "batched_instantiate")]
92 #[pyo3(signature = (circuit, targets, data = None))]
93 fn batched_instantiate_python(
94 &self,
95 circuit: QuditCircuit,
96 targets: Vec<InstantiationTarget<c64>>,
97 data: Option<&Bound<'_, PyDict>>,
98 ) -> PyResult<Vec<InstantiationResult<c64>>> {
99 let data_map = pydict_to_datamap(data)?;
100 let target_arcs: Vec<Arc<InstantiationTarget<c64>>> =
101 targets.into_iter().map(Arc::new).collect();
102 let result =
103 Instantiater::batched_instantiate(self, Arc::new(circuit), &target_arcs, data_map);
104 Ok(result)
105 }
106 }
107
108 impl Instantiater<c64> for BoxedInstantiater {
109 fn instantiate(
110 &self,
111 circuit: Arc<QuditCircuit>,
112 target: Arc<InstantiationTarget<c64>>,
113 data: Arc<DataMap>,
114 ) -> InstantiationResult<c64> {
115 self.inner.instantiate(circuit, target, data)
116 }
117
118 fn batched_instantiate(
119 &self,
120 circuit: Arc<QuditCircuit>,
121 targets: &[Arc<InstantiationTarget<c64>>],
122 data: Arc<DataMap>,
123 ) -> Vec<InstantiationResult<c64>> {
124 self.inner.batched_instantiate(circuit, targets, data)
125 }
126 }
127
128 #[pyclass(name = "Instantiater", subclass)]
129 struct PyInstantiaterABC;
130
131 #[pymethods]
132 impl PyInstantiaterABC {
133 fn instantiate(
134 &self,
135 _circuit: QuditCircuit,
136 _target: InstantiationTarget<c64>,
137 _data: &Bound<'_, PyDict>,
138 ) -> PyResult<InstantiationResult<c64>> {
139 Err(PyNotImplementedError::new_err(
140 "Instantiaters must implement the instantiate method.",
141 ))
142 }
143 }
144
145 struct PyInstantiaterTrampoline {
146 instantiater: Py<PyAny>,
147 }
148
149 impl Instantiater<c64> for PyInstantiaterTrampoline {
150 fn instantiate(
151 &self,
152 circuit: Arc<QuditCircuit>,
153 target: Arc<InstantiationTarget<c64>>,
154 data: Arc<DataMap>,
155 ) -> InstantiationResult<c64> {
156 Python::attach(|py| {
158 let py_data = PyDict::new(py);
159 for (key, val) in data.iter() {
160 py_data.set_item(key, val.to_string()).unwrap();
161 }
162
163 self.instantiater
164 .bind(py)
165 .call_method(
166 "instantiate",
167 ((*circuit).clone(), (*target).clone(), py_data),
168 None,
169 )
170 .unwrap()
171 .extract()
172 .expect("Invalid return type from instantiate.")
173 })
174 }
175
176 fn batched_instantiate(
177 &self,
178 circuit: Arc<QuditCircuit>,
179 targets: &[Arc<InstantiationTarget<c64>>],
180 data: Arc<DataMap>,
181 ) -> Vec<InstantiationResult<c64>> {
182 Python::attach(|py| {
184 let bound = self.instantiater.bind(py);
185
186 let py_data = PyDict::new(py);
187 for (key, val) in data.iter() {
188 py_data.set_item(key, val.to_string()).unwrap();
189 }
190
191 if bound.hasattr("batched_instantiate").is_ok_and(|x| x) {
192 let py_targets =
193 PyList::new(py, targets.iter().map(|t| (**t).clone())).unwrap();
194 bound
195 .call_method(
196 "batched_instantiate",
197 ((*circuit).clone(), py_targets, py_data),
198 None,
199 )
200 .unwrap()
201 .extract()
202 .expect("Invalid return type from batched instantiate.")
203 } else {
204 let circuit = (*circuit).clone().into_pyobject(py).unwrap();
205 targets
206 .iter()
207 .map(|t| {
208 bound
209 .call_method(
210 "instantiate",
211 (&circuit, (**t).clone(), &py_data),
212 None,
213 )
214 .unwrap()
215 .extract()
216 .expect("Invalid return type from instantiate.")
217 })
218 .collect()
219 }
220 })
221 }
222 }
223
224 pub enum PyInstantiater {
229 #[allow(private_interfaces)]
230 Python(PyInstantiaterTrampoline),
231 Native(BoxedInstantiater),
232 }
233
234 impl Instantiater<c64> for PyInstantiater {
235 fn instantiate(
236 &self,
237 circuit: Arc<QuditCircuit>,
238 target: Arc<InstantiationTarget<c64>>,
239 data: Arc<DataMap>,
240 ) -> InstantiationResult<c64> {
241 match self {
242 PyInstantiater::Python(inner) => inner.instantiate(circuit, target, data),
243 PyInstantiater::Native(inner) => inner.instantiate(circuit, target, data),
244 }
245 }
246
247 fn batched_instantiate(
248 &self,
249 circuit: Arc<QuditCircuit>,
250 targets: &[Arc<InstantiationTarget<c64>>],
251 data: Arc<DataMap>,
252 ) -> Vec<InstantiationResult<c64>> {
253 match self {
254 PyInstantiater::Python(inner) => inner.batched_instantiate(circuit, targets, data),
255 PyInstantiater::Native(inner) => inner.batched_instantiate(circuit, targets, data),
256 }
257 }
258 }
259
260 impl<'a, 'py> FromPyObject<'a, 'py> for PyInstantiater {
261 type Error = PyErr;
262
263 fn extract(obj: Borrowed<'a, 'py, PyAny>) -> PyResult<Self> {
264 if let Ok(dyn_trait) = obj.extract::<PyRef<BoxedInstantiater>>() {
265 Ok(PyInstantiater::Native(BoxedInstantiater {
266 inner: dyn_clone::clone_box(&*dyn_trait.inner),
267 }))
268 } else if obj.hasattr("instantiate")? {
269 let trampoline = PyInstantiaterTrampoline {
270 instantiater: obj.to_owned().unbind(),
271 };
272 Ok(PyInstantiater::Python(trampoline))
273 } else {
274 Err(PyTypeError::new_err(
275 "Cannot extract an 'Instantiater' during conversion to native code.",
276 ))
277 }
278 }
279 }
280}