laddu_python/amplitudes.rs
1use crate::data::PyDataset;
2use bincode::{deserialize, serialize};
3use laddu_core::{
4 amplitudes::{
5 constant, parameter, Amplitude, AmplitudeID, Evaluator, Expression, Manager, Model,
6 ParameterLike,
7 },
8 traits::ReadWrite,
9 Complex, Float, LadduError,
10};
11use numpy::{PyArray1, PyArray2};
12use pyo3::{
13 exceptions::PyTypeError,
14 prelude::*,
15 types::{PyBytes, PyList},
16};
17#[cfg(feature = "rayon")]
18use rayon::ThreadPoolBuilder;
19
20/// An object which holds a registered ``Amplitude``
21///
22/// See Also
23/// --------
24/// laddu.Manager.register
25///
26#[pyclass(name = "AmplitudeID", module = "laddu")]
27#[derive(Clone)]
28pub struct PyAmplitudeID(AmplitudeID);
29
30/// A mathematical expression formed from AmplitudeIDs
31///
32#[pyclass(name = "Expression", module = "laddu")]
33#[derive(Clone)]
34pub struct PyExpression(Expression);
35
36#[pymethods]
37impl PyAmplitudeID {
38 /// The real part of a complex Amplitude
39 ///
40 /// Returns
41 /// -------
42 /// Expression
43 /// The real part of the given Amplitude
44 ///
45 fn real(&self) -> PyExpression {
46 PyExpression(self.0.real())
47 }
48 /// The imaginary part of a complex Amplitude
49 ///
50 /// Returns
51 /// -------
52 /// Expression
53 /// The imaginary part of the given Amplitude
54 ///
55 fn imag(&self) -> PyExpression {
56 PyExpression(self.0.imag())
57 }
58 /// The norm-squared of a complex Amplitude
59 ///
60 /// This is computed as :math:`AA^*` where :math:`A^*` is the complex conjugate
61 ///
62 /// Returns
63 /// -------
64 /// Expression
65 /// The norm-squared of the given Amplitude
66 ///
67 fn norm_sqr(&self) -> PyExpression {
68 PyExpression(self.0.norm_sqr())
69 }
70 fn __add__(&self, other: &Bound<'_, PyAny>) -> PyResult<PyExpression> {
71 if let Ok(other_aid) = other.extract::<PyRef<PyAmplitudeID>>() {
72 Ok(PyExpression(self.0.clone() + other_aid.0.clone()))
73 } else if let Ok(other_expr) = other.extract::<PyExpression>() {
74 Ok(PyExpression(self.0.clone() + other_expr.0.clone()))
75 } else if let Ok(other_int) = other.extract::<usize>() {
76 if other_int == 0 {
77 Ok(PyExpression(Expression::Amp(self.0.clone())))
78 } else {
79 Err(PyTypeError::new_err(
80 "Addition with an integer for this type is only defined for 0",
81 ))
82 }
83 } else {
84 Err(PyTypeError::new_err("Unsupported operand type for +"))
85 }
86 }
87 fn __radd__(&self, other: &Bound<'_, PyAny>) -> PyResult<PyExpression> {
88 if let Ok(other_aid) = other.extract::<PyRef<PyAmplitudeID>>() {
89 Ok(PyExpression(other_aid.0.clone() + self.0.clone()))
90 } else if let Ok(other_expr) = other.extract::<PyExpression>() {
91 Ok(PyExpression(other_expr.0.clone() + self.0.clone()))
92 } else if let Ok(other_int) = other.extract::<usize>() {
93 if other_int == 0 {
94 Ok(PyExpression(Expression::Amp(self.0.clone())))
95 } else {
96 Err(PyTypeError::new_err(
97 "Addition with an integer for this type is only defined for 0",
98 ))
99 }
100 } else {
101 Err(PyTypeError::new_err("Unsupported operand type for +"))
102 }
103 }
104 fn __mul__(&self, other: &Bound<'_, PyAny>) -> PyResult<PyExpression> {
105 if let Ok(other_aid) = other.extract::<PyRef<PyAmplitudeID>>() {
106 Ok(PyExpression(self.0.clone() * other_aid.0.clone()))
107 } else if let Ok(other_expr) = other.extract::<PyExpression>() {
108 Ok(PyExpression(self.0.clone() * other_expr.0.clone()))
109 } else {
110 Err(PyTypeError::new_err("Unsupported operand type for *"))
111 }
112 }
113 fn __rmul__(&self, other: &Bound<'_, PyAny>) -> PyResult<PyExpression> {
114 if let Ok(other_aid) = other.extract::<PyRef<PyAmplitudeID>>() {
115 Ok(PyExpression(other_aid.0.clone() * self.0.clone()))
116 } else if let Ok(other_expr) = other.extract::<PyExpression>() {
117 Ok(PyExpression(other_expr.0.clone() * self.0.clone()))
118 } else {
119 Err(PyTypeError::new_err("Unsupported operand type for *"))
120 }
121 }
122 fn __str__(&self) -> String {
123 format!("{}", self.0)
124 }
125 fn __repr__(&self) -> String {
126 format!("{:?}", self.0)
127 }
128}
129
130#[pymethods]
131impl PyExpression {
132 /// The real part of a complex Expression
133 ///
134 /// Returns
135 /// -------
136 /// Expression
137 /// The real part of the given Expression
138 ///
139 fn real(&self) -> PyExpression {
140 PyExpression(self.0.real())
141 }
142 /// The imaginary part of a complex Expression
143 ///
144 /// Returns
145 /// -------
146 /// Expression
147 /// The imaginary part of the given Expression
148 ///
149 fn imag(&self) -> PyExpression {
150 PyExpression(self.0.imag())
151 }
152 /// The norm-squared of a complex Expression
153 ///
154 /// This is computed as :math:`AA^*` where :math:`A^*` is the complex conjugate
155 ///
156 /// Returns
157 /// -------
158 /// Expression
159 /// The norm-squared of the given Expression
160 ///
161 fn norm_sqr(&self) -> PyExpression {
162 PyExpression(self.0.norm_sqr())
163 }
164 fn __add__(&self, other: &Bound<'_, PyAny>) -> PyResult<PyExpression> {
165 if let Ok(other_aid) = other.extract::<PyRef<PyAmplitudeID>>() {
166 Ok(PyExpression(self.0.clone() + other_aid.0.clone()))
167 } else if let Ok(other_expr) = other.extract::<PyExpression>() {
168 Ok(PyExpression(self.0.clone() + other_expr.0.clone()))
169 } else if let Ok(other_int) = other.extract::<usize>() {
170 if other_int == 0 {
171 Ok(PyExpression(self.0.clone()))
172 } else {
173 Err(PyTypeError::new_err(
174 "Addition with an integer for this type is only defined for 0",
175 ))
176 }
177 } else {
178 Err(PyTypeError::new_err("Unsupported operand type for +"))
179 }
180 }
181 fn __radd__(&self, other: &Bound<'_, PyAny>) -> PyResult<PyExpression> {
182 if let Ok(other_aid) = other.extract::<PyRef<PyAmplitudeID>>() {
183 Ok(PyExpression(other_aid.0.clone() + self.0.clone()))
184 } else if let Ok(other_expr) = other.extract::<PyExpression>() {
185 Ok(PyExpression(other_expr.0.clone() + self.0.clone()))
186 } else if let Ok(other_int) = other.extract::<usize>() {
187 if other_int == 0 {
188 Ok(PyExpression(self.0.clone()))
189 } else {
190 Err(PyTypeError::new_err(
191 "Addition with an integer for this type is only defined for 0",
192 ))
193 }
194 } else {
195 Err(PyTypeError::new_err("Unsupported operand type for +"))
196 }
197 }
198 fn __mul__(&self, other: &Bound<'_, PyAny>) -> PyResult<PyExpression> {
199 if let Ok(other_aid) = other.extract::<PyRef<PyAmplitudeID>>() {
200 Ok(PyExpression(self.0.clone() * other_aid.0.clone()))
201 } else if let Ok(other_expr) = other.extract::<PyExpression>() {
202 Ok(PyExpression(self.0.clone() * other_expr.0.clone()))
203 } else {
204 Err(PyTypeError::new_err("Unsupported operand type for *"))
205 }
206 }
207 fn __rmul__(&self, other: &Bound<'_, PyAny>) -> PyResult<PyExpression> {
208 if let Ok(other_aid) = other.extract::<PyRef<PyAmplitudeID>>() {
209 Ok(PyExpression(other_aid.0.clone() * self.0.clone()))
210 } else if let Ok(other_expr) = other.extract::<PyExpression>() {
211 Ok(PyExpression(other_expr.0.clone() * self.0.clone()))
212 } else {
213 Err(PyTypeError::new_err("Unsupported operand type for *"))
214 }
215 }
216 fn __str__(&self) -> String {
217 format!("{}", self.0)
218 }
219 fn __repr__(&self) -> String {
220 format!("{:?}", self.0)
221 }
222}
223
224/// A class which can be used to register Amplitudes and store precalculated data
225///
226#[pyclass(name = "Manager", module = "laddu")]
227pub struct PyManager(Manager);
228
229#[pymethods]
230impl PyManager {
231 #[new]
232 fn new() -> Self {
233 Self(Manager::default())
234 }
235 /// The free parameters used by the Manager
236 ///
237 /// Returns
238 /// -------
239 /// parameters : list of str
240 /// The list of parameter names
241 ///
242 #[getter]
243 fn parameters(&self) -> Vec<String> {
244 self.0.parameters()
245 }
246 /// Register an Amplitude with the Manager
247 ///
248 /// Parameters
249 /// ----------
250 /// amplitude : Amplitude
251 /// The Amplitude to register
252 ///
253 /// Returns
254 /// -------
255 /// AmplitudeID
256 /// A reference to the registered `amplitude` that can be used to form complex
257 /// Expressions
258 ///
259 /// Raises
260 /// ------
261 /// ValueError
262 /// If the name of the ``amplitude`` has already been registered
263 ///
264 fn register(&mut self, amplitude: &PyAmplitude) -> PyResult<PyAmplitudeID> {
265 Ok(PyAmplitudeID(self.0.register(amplitude.0.clone())?))
266 }
267 /// Generate a Model from the given expression made of registered Amplitudes
268 ///
269 /// Parameters
270 /// ----------
271 /// expression : Expression or AmplitudeID
272 /// The expression to use in precalculation
273 ///
274 /// Returns
275 /// -------
276 /// Model
277 /// An object which represents the underlying mathematical model and can be loaded with
278 /// a Dataset
279 ///
280 /// Raises
281 /// ------
282 /// TypeError
283 /// If the expression is not convertable to a Model
284 ///
285 /// Notes
286 /// -----
287 /// While the given `expression` will be the one evaluated in the end, all registered
288 /// Amplitudes will be loaded, and all of their parameters will be included in the final
289 /// expression. These parameters will have no effect on evaluation, but they must be
290 /// included in function calls.
291 ///
292 fn model(&self, expression: &Bound<'_, PyAny>) -> PyResult<PyModel> {
293 let expression = if let Ok(expression) = expression.extract::<PyExpression>() {
294 Ok(expression.0)
295 } else if let Ok(aid) = expression.extract::<PyAmplitudeID>() {
296 Ok(Expression::Amp(aid.0))
297 } else {
298 Err(PyTypeError::new_err(
299 "'expression' must either by an Expression or AmplitudeID",
300 ))
301 }?;
302 Ok(PyModel(self.0.model(&expression)))
303 }
304}
305
306/// A class which represents a model composed of registered Amplitudes
307///
308#[pyclass(name = "Model", module = "laddu")]
309pub struct PyModel(pub Model);
310
311#[pymethods]
312impl PyModel {
313 /// The free parameters used by the Manager
314 ///
315 /// Returns
316 /// -------
317 /// parameters : list of str
318 /// The list of parameter names
319 ///
320 #[getter]
321 fn parameters(&self) -> Vec<String> {
322 self.0.parameters()
323 }
324 /// Load a Model by precalculating each term over the given Dataset
325 ///
326 /// Parameters
327 /// ----------
328 /// dataset : Dataset
329 /// The Dataset to use in precalculation
330 ///
331 /// Returns
332 /// -------
333 /// Evaluator
334 /// An object that can be used to evaluate the `expression` over each event in the
335 /// `dataset`
336 ///
337 /// Notes
338 /// -----
339 /// While the given `expression` will be the one evaluated in the end, all registered
340 /// Amplitudes will be loaded, and all of their parameters will be included in the final
341 /// expression. These parameters will have no effect on evaluation, but they must be
342 /// included in function calls.
343 ///
344 fn load(&self, dataset: &PyDataset) -> PyEvaluator {
345 PyEvaluator(self.0.load(&dataset.0))
346 }
347 /// Save the Model to a file
348 ///
349 /// Parameters
350 /// ----------
351 /// path : str
352 /// The path of the new file (overwrites if the file exists!)
353 ///
354 /// Raises
355 /// ------
356 /// IOError
357 /// If anything fails when trying to write the file
358 ///
359 fn save_as(&self, path: &str) -> PyResult<()> {
360 self.0.save_as(path)?;
361 Ok(())
362 }
363 /// Load a Model from a file
364 ///
365 /// Parameters
366 /// ----------
367 /// path : str
368 /// The path of the existing fit file
369 ///
370 /// Returns
371 /// -------
372 /// Model
373 /// The model contained in the file
374 ///
375 /// Raises
376 /// ------
377 /// IOError
378 /// If anything fails when trying to read the file
379 ///
380 #[staticmethod]
381 fn load_from(path: &str) -> PyResult<Self> {
382 Ok(PyModel(Model::load_from(path)?))
383 }
384 #[new]
385 fn new() -> Self {
386 PyModel(Model::create_null())
387 }
388 fn __getstate__<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyBytes>> {
389 Ok(PyBytes::new(
390 py,
391 serialize(&self.0)
392 .map_err(LadduError::SerdeError)?
393 .as_slice(),
394 ))
395 }
396 fn __setstate__(&mut self, state: Bound<'_, PyBytes>) -> PyResult<()> {
397 *self = PyModel(deserialize(state.as_bytes()).map_err(LadduError::SerdeError)?);
398 Ok(())
399 }
400}
401
402/// An Amplitude which can be registered by a Manager
403///
404/// See Also
405/// --------
406/// laddu.Manager
407///
408#[pyclass(name = "Amplitude", module = "laddu")]
409pub struct PyAmplitude(pub Box<dyn Amplitude>);
410
411/// A class which can be used to evaluate a stored Expression
412///
413/// See Also
414/// --------
415/// laddu.Manager.load
416///
417#[pyclass(name = "Evaluator", module = "laddu")]
418#[derive(Clone)]
419pub struct PyEvaluator(pub Evaluator);
420
421#[pymethods]
422impl PyEvaluator {
423 /// The free parameters used by the Evaluator
424 ///
425 /// Returns
426 /// -------
427 /// parameters : list of str
428 /// The list of parameter names
429 ///
430 #[getter]
431 fn parameters(&self) -> Vec<String> {
432 self.0.parameters()
433 }
434 /// Activates Amplitudes in the Expression by name
435 ///
436 /// Parameters
437 /// ----------
438 /// arg : str or list of str
439 /// Names of Amplitudes to be activated
440 ///
441 /// Raises
442 /// ------
443 /// TypeError
444 /// If `arg` is not a str or list of str
445 /// ValueError
446 /// If `arg` or any items of `arg` are not registered Amplitudes
447 ///
448 fn activate(&self, arg: &Bound<'_, PyAny>) -> PyResult<()> {
449 if let Ok(string_arg) = arg.extract::<String>() {
450 self.0.activate(&string_arg)?;
451 } else if let Ok(list_arg) = arg.downcast::<PyList>() {
452 let vec: Vec<String> = list_arg.extract()?;
453 self.0.activate_many(&vec)?;
454 } else {
455 return Err(PyTypeError::new_err(
456 "Argument must be either a string or a list of strings",
457 ));
458 }
459 Ok(())
460 }
461 /// Activates all Amplitudes in the Expression
462 ///
463 fn activate_all(&self) {
464 self.0.activate_all();
465 }
466 /// Deactivates Amplitudes in the Expression by name
467 ///
468 /// Deactivated Amplitudes act as zeros in the Expression
469 ///
470 /// Parameters
471 /// ----------
472 /// arg : str or list of str
473 /// Names of Amplitudes to be deactivated
474 ///
475 /// Raises
476 /// ------
477 /// TypeError
478 /// If `arg` is not a str or list of str
479 /// ValueError
480 /// If `arg` or any items of `arg` are not registered Amplitudes
481 ///
482 fn deactivate(&self, arg: &Bound<'_, PyAny>) -> PyResult<()> {
483 if let Ok(string_arg) = arg.extract::<String>() {
484 self.0.deactivate(&string_arg)?;
485 } else if let Ok(list_arg) = arg.downcast::<PyList>() {
486 let vec: Vec<String> = list_arg.extract()?;
487 self.0.deactivate_many(&vec)?;
488 } else {
489 return Err(PyTypeError::new_err(
490 "Argument must be either a string or a list of strings",
491 ));
492 }
493 Ok(())
494 }
495 /// Deactivates all Amplitudes in the Expression
496 ///
497 fn deactivate_all(&self) {
498 self.0.deactivate_all();
499 }
500 /// Isolates Amplitudes in the Expression by name
501 ///
502 /// Activates the Amplitudes given in `arg` and deactivates the rest
503 ///
504 /// Parameters
505 /// ----------
506 /// arg : str or list of str
507 /// Names of Amplitudes to be isolated
508 ///
509 /// Raises
510 /// ------
511 /// TypeError
512 /// If `arg` is not a str or list of str
513 /// ValueError
514 /// If `arg` or any items of `arg` are not registered Amplitudes
515 ///
516 fn isolate(&self, arg: &Bound<'_, PyAny>) -> PyResult<()> {
517 if let Ok(string_arg) = arg.extract::<String>() {
518 self.0.isolate(&string_arg)?;
519 } else if let Ok(list_arg) = arg.downcast::<PyList>() {
520 let vec: Vec<String> = list_arg.extract()?;
521 self.0.isolate_many(&vec)?;
522 } else {
523 return Err(PyTypeError::new_err(
524 "Argument must be either a string or a list of strings",
525 ));
526 }
527 Ok(())
528 }
529 /// Evaluate the stored Expression over the stored Dataset
530 ///
531 /// Parameters
532 /// ----------
533 /// parameters : list of float
534 /// The values to use for the free parameters
535 /// threads : int, optional
536 /// The number of threads to use (setting this to None will use all available CPUs)
537 ///
538 /// Returns
539 /// -------
540 /// result : array_like
541 /// A ``numpy`` array of complex values for each Event in the Dataset
542 ///
543 /// Raises
544 /// ------
545 /// Exception
546 /// If there was an error building the thread pool
547 ///
548 #[pyo3(signature = (parameters, *, threads=None))]
549 fn evaluate<'py>(
550 &self,
551 py: Python<'py>,
552 parameters: Vec<Float>,
553 threads: Option<usize>,
554 ) -> PyResult<Bound<'py, PyArray1<Complex<Float>>>> {
555 #[cfg(feature = "rayon")]
556 {
557 Ok(PyArray1::from_slice(
558 py,
559 &ThreadPoolBuilder::new()
560 .num_threads(threads.unwrap_or_else(num_cpus::get))
561 .build()
562 .map_err(LadduError::from)?
563 .install(|| self.0.evaluate(¶meters)),
564 ))
565 }
566 #[cfg(not(feature = "rayon"))]
567 {
568 Ok(PyArray1::from_slice(py, &self.0.evaluate(¶meters)))
569 }
570 }
571 /// Evaluate the gradient of the stored Expression over the stored Dataset
572 ///
573 /// Parameters
574 /// ----------
575 /// parameters : list of float
576 /// The values to use for the free parameters
577 /// threads : int, optional
578 /// The number of threads to use (setting this to None will use all available CPUs)
579 ///
580 /// Returns
581 /// -------
582 /// result : array_like
583 /// A ``numpy`` 2D array of complex values for each Event in the Dataset
584 ///
585 /// Raises
586 /// ------
587 /// Exception
588 /// If there was an error building the thread pool or problem creating the resulting
589 /// ``numpy`` array
590 ///
591 #[pyo3(signature = (parameters, *, threads=None))]
592 fn evaluate_gradient<'py>(
593 &self,
594 py: Python<'py>,
595 parameters: Vec<Float>,
596 threads: Option<usize>,
597 ) -> PyResult<Bound<'py, PyArray2<Complex<Float>>>> {
598 #[cfg(feature = "rayon")]
599 {
600 Ok(PyArray2::from_vec2(
601 py,
602 &ThreadPoolBuilder::new()
603 .num_threads(threads.unwrap_or_else(num_cpus::get))
604 .build()
605 .map_err(LadduError::from)?
606 .install(|| {
607 self.0
608 .evaluate_gradient(¶meters)
609 .iter()
610 .map(|grad| grad.data.as_vec().to_vec())
611 .collect::<Vec<Vec<Complex<Float>>>>()
612 }),
613 )
614 .map_err(LadduError::NumpyError)?)
615 }
616 #[cfg(not(feature = "rayon"))]
617 {
618 Ok(PyArray2::from_vec2(
619 py,
620 &self
621 .0
622 .evaluate_gradient(¶meters)
623 .iter()
624 .map(|grad| grad.data.as_vec().to_vec())
625 .collect::<Vec<Vec<Complex<Float>>>>(),
626 )
627 .map_err(LadduError::NumpyError)?)
628 }
629 }
630}
631
632/// A class, typically used to allow Amplitudes to take either free parameters or constants as
633/// inputs
634///
635/// See Also
636/// --------
637/// laddu.parameter
638/// laddu.constant
639///
640#[pyclass(name = "ParameterLike", module = "laddu")]
641#[derive(Clone)]
642pub struct PyParameterLike(pub ParameterLike);
643
644/// A free parameter which floats during an optimization
645///
646/// Parameters
647/// ----------
648/// name : str
649/// The name of the free parameter
650///
651/// Returns
652/// -------
653/// laddu.ParameterLike
654/// An object that can be used as the input for many Amplitude constructors
655///
656/// Notes
657/// -----
658/// Two free parameters with the same name are shared in a fit
659///
660#[pyfunction(name = "parameter")]
661pub fn py_parameter(name: &str) -> PyParameterLike {
662 PyParameterLike(parameter(name))
663}
664
665/// A term which stays constant during an optimization
666///
667/// Parameters
668/// ----------
669/// value : float
670/// The numerical value of the constant
671///
672/// Returns
673/// -------
674/// laddu.ParameterLike
675/// An object that can be used as the input for many Amplitude constructors
676///
677#[pyfunction(name = "constant")]
678pub fn py_constant(value: Float) -> PyParameterLike {
679 PyParameterLike(constant(value))
680}