laddu_core/
lib.rs

//! # laddu-core
//!
//! This is an internal crate used by `laddu`.
#![warn(clippy::perf, clippy::style, missing_docs)]

use bincode::ErrorKind;
#[cfg(feature = "python")]
use pyo3::PyErr;

/// MPI backend for `laddu`
///
/// The Message Passing Interface (MPI) is a protocol which enables communication between multiple
/// CPUs in a high-performance computing environment. While [`rayon`] can parallelize tasks on a
/// single CPU, MPI can parallelize tasks across multiple CPUs by running an independent copy of
/// the program (a task) on each CPU; each task is assigned an id (its rank) which tells the
/// process what to do and where to send results. This backend coordinates processes which would
/// typically be parallelized over the events in a [`Dataset`](`crate::data::Dataset`).
///
/// To use this backend, the library must be built with the `mpi` feature, which requires an
/// existing implementation of MPI like OpenMPI or MPICH. All processing code should be
/// sandwiched between calls to [`use_mpi`] and [`finalize_mpi`]:
/// ```ignore
/// fn main() {
///     laddu_core::mpi::use_mpi(true);
///     // laddu analysis code here
///     laddu_core::mpi::finalize_mpi();
/// }
/// ```
///
/// [`finalize_mpi`] must be called to trigger the methods which clean up the MPI environment.
/// While these are called by default when the [`Universe`](`mpi::environment::Universe`) is
/// dropped, `laddu` uses a static `Universe` that can be accessed by all of the methods that need
/// it, rather than passing the context to each method. This simplifies the way programs can be
/// converted to use MPI, but means that the `Universe` is not automatically dropped at the end of
/// the program (so it must be dropped manually).
#[cfg(feature = "mpi")]
pub mod mpi {
    use std::sync::atomic::{AtomicBool, Ordering};
    use std::sync::OnceLock;

    use lazy_static::lazy_static;
    use mpi::environment::Universe;
    use mpi::topology::{Process, SimpleCommunicator};
    use mpi::traits::Communicator;
    use parking_lot::RwLock;

    lazy_static! {
        static ref USE_MPI: AtomicBool = AtomicBool::new(false);
    }

    static MPI_UNIVERSE: OnceLock<RwLock<Option<Universe>>> = OnceLock::new();

    /// The default root rank for MPI processes
    pub const ROOT_RANK: i32 = 0;

    /// Check if the current MPI process is the root process
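    ///
    /// This is useful for guarding code that should only run once, such as writing output
    /// (a minimal sketch, assuming MPI has already been initialized with [`use_mpi`]):
    /// ```ignore
    /// if laddu_core::mpi::is_root() {
    ///     println!("this only prints on the root rank");
    /// }
    /// ```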
    pub fn is_root() -> bool {
        if let Some(world) = crate::mpi::get_world() {
            world.rank() == ROOT_RANK
        } else {
            false
        }
    }

    /// Shortcut method to get the global MPI communicator without accessing `size` and `rank`
    /// directly
    pub fn get_world() -> Option<SimpleCommunicator> {
        if let Some(universe_lock) = MPI_UNIVERSE.get() {
            if let Some(universe) = &*universe_lock.read() {
                let world = universe.world();
                if world.size() == 1 {
                    return None;
                }
                return Some(world);
            }
        }
        None
    }

    /// Get the rank of the current process
    pub fn get_rank() -> Option<i32> {
        get_world().map(|w| w.rank())
    }

    /// Get the number of available processes/ranks
    pub fn get_size() -> Option<i32> {
        get_world().map(|w| w.size())
    }

    /// Use the MPI backend
    ///
    /// # Notes
    ///
    /// You must have MPI installed for this to work, and you must call the program with
    /// `mpirun <executable>`, or bad things will happen.
    ///
    /// MPI runs an identical program on each process, but gives the program an ID called its
    /// "rank". Only the results of methods on the root process (rank 0) should be
    /// considered valid, as other processes only contain portions of each dataset. To ensure
    /// you don't save or print data at other ranks, use the provided [`is_root()`]
    /// method to check if the process is the root process.
    ///
    /// Once MPI is enabled, it cannot be disabled. If it could be toggled off mid-run (it
    /// cannot), the other processes would still run, but they would be independent of the
    /// root process and would no longer communicate with it. The root process stores no data,
    /// so it would be difficult (and convoluted) to recover results which were already
    /// processed via MPI.
    ///
    /// Additionally, MPI must be enabled at the beginning of a script, at least before any
    /// other `laddu` functions are called.
    ///
    /// If [`use_mpi()`] is called multiple times, the subsequent calls will have no
    /// effect.
    ///
    /// <div class="warning">
    ///
    /// You **must** call [`finalize_mpi()`] before your program exits for MPI to terminate
    /// smoothly.
    ///
    /// </div>
    ///
    /// # Examples
    ///
    /// ```ignore
    /// fn main() {
    ///     laddu_core::mpi::use_mpi(true);
    ///
    ///     // ... your code here ...
    ///
    ///     laddu_core::mpi::finalize_mpi();
    /// }
    /// ```
    pub fn use_mpi(trigger: bool) {
        if trigger {
            USE_MPI.store(true, Ordering::SeqCst);
            MPI_UNIVERSE.get_or_init(|| {
                #[cfg(feature = "rayon")]
                let threading = mpi::Threading::Funneled;
                #[cfg(not(feature = "rayon"))]
                let threading = mpi::Threading::Single;
                let (universe, _threading) = mpi::initialize_with_threading(threading).unwrap();
                let world = universe.world();
                if world.size() == 1 {
                    eprintln!("Warning: MPI is enabled, but only one process is available. MPI will not be used, but single-CPU parallelism may still be used if enabled.");
                    USE_MPI.store(false, Ordering::SeqCst);
                    // `finalize_mpi()` cannot be called here because `MPI_UNIVERSE` is still being
                    // initialized (its `get()` would return `None`); dropping the universe
                    // finalizes MPI instead.
                    drop(universe);
                    RwLock::new(None)
                } else {
                    RwLock::new(Some(universe))
                }
            });
        }
    }

    /// Drop the MPI universe and finalize MPI at the end of a program
    ///
    /// This function will do nothing if MPI is not initialized.
    ///
    /// <div class="warning">
    ///
    /// This should only be called once, after all other `laddu`-related function calls. It must
    /// be called at the end of any program which uses MPI.
    ///
    /// </div>
    pub fn finalize_mpi() {
        if using_mpi() {
            let mut universe = MPI_UNIVERSE.get().unwrap().write();
            *universe = None;
        }
    }

    /// Check if the MPI backend is enabled
    pub fn using_mpi() -> bool {
        USE_MPI.load(Ordering::SeqCst)
    }

    /// A trait including some useful auxiliary methods for MPI
    pub trait LadduMPI {
        /// Get the process at the root rank
        fn process_at_root(&self) -> Process<'_>;
        /// Check if the current rank is the root rank
        fn is_root(&self) -> bool;
        /// Get the counts/displacements for partitioning a buffer of length
        /// `buf_len`
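        ///
        /// For example, with 3 ranks and `buf_len = 10`, the counts are `[4, 3, 3]` and the
        /// displacements are `[0, 4, 7]` (any surplus elements go to the lowest ranks).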
        fn get_counts_displs(&self, buf_len: usize) -> (Vec<i32>, Vec<i32>);
        /// Get the counts/displacements for partitioning a nested buffer (like
        /// a [`Vec<Vec<T>>`]). If the internal vectors all have the same length
        /// `internal_len` and there are `unflattened_len` elements in the
        /// outer vector, then this will give the correct counts/displacements for a
        /// flattened version of the nested buffer.
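        ///
        /// For example, with 3 ranks, `unflattened_len = 5`, and `internal_len = 2`, the
        /// counts are `[4, 4, 2]` and the displacements are `[0, 4, 8]` (in units of the
        /// flattened elements).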
        fn get_flattened_counts_displs(
            &self,
            unflattened_len: usize,
            internal_len: usize,
        ) -> (Vec<i32>, Vec<i32>);
    }

    impl LadduMPI for SimpleCommunicator {
        fn process_at_root(&self) -> Process<'_> {
            self.process_at_rank(crate::mpi::ROOT_RANK)
        }

        fn is_root(&self) -> bool {
            self.rank() == crate::mpi::ROOT_RANK
        }

        fn get_counts_displs(&self, buf_len: usize) -> (Vec<i32>, Vec<i32>) {
            let mut counts = vec![0; self.size() as usize];
            let mut displs = vec![0; self.size() as usize];
            let chunk_size = buf_len / self.size() as usize;
            let surplus = buf_len % self.size() as usize;
            for i in 0..self.size() as usize {
                counts[i] = if i < surplus {
                    chunk_size + 1
                } else {
                    chunk_size
                } as i32;
                displs[i] = if i == 0 {
                    0
                } else {
                    displs[i - 1] + counts[i - 1]
                };
            }
            (counts, displs)
        }

        fn get_flattened_counts_displs(
            &self,
            unflattened_len: usize,
            internal_len: usize,
        ) -> (Vec<i32>, Vec<i32>) {
            let mut counts = vec![0; self.size() as usize];
            let mut displs = vec![0; self.size() as usize];
            let chunk_size = unflattened_len / self.size() as usize;
            let surplus = unflattened_len % self.size() as usize;
            for i in 0..self.size() as usize {
                counts[i] = if i < surplus {
                    (chunk_size + 1) * internal_len
                } else {
                    chunk_size * internal_len
                } as i32;
                displs[i] = if i == 0 {
                    0
                } else {
                    displs[i - 1] + counts[i - 1]
                };
            }
            (counts, displs)
        }
    }
}

use thiserror::Error;

/// [`Amplitude`](crate::amplitudes::Amplitude)s and methods for making and evaluating them.
pub mod amplitudes;
/// Methods for loading and manipulating [`Event`]-based data.
pub mod data;
/// Structures for manipulating the cache and free parameters.
pub mod resources;
/// Utility functions, enums, and traits
pub mod utils;
/// Useful traits for all crate structs
pub mod traits {
    pub use crate::amplitudes::Amplitude;
    pub use crate::utils::variables::Variable;
    pub use crate::ReadWrite;
}

pub use crate::data::{open, BinnedDataset, Dataset, Event};
pub use crate::resources::{
    Cache, ComplexMatrixID, ComplexScalarID, ComplexVectorID, MatrixID, ParameterID, Parameters,
    Resources, ScalarID, VectorID,
};
pub use crate::utils::enums::{Channel, Frame, Sign};
pub use crate::utils::variables::{
    Angles, CosTheta, Mandelstam, Mass, Phi, PolAngle, PolMagnitude, Polarization,
};
pub use crate::utils::vectors::{Vec3, Vec4};
pub use amplitudes::{
    constant, parameter, AmplitudeID, Evaluator, Expression, Manager, Model, ParameterLike,
};

// Re-exports
pub use ganesh::{mcmc::Ensemble, Bound, Status};
pub use nalgebra::DVector;
pub use num::Complex;

/// A floating-point number type (defaults to [`f64`], see `f32` feature).
#[cfg(not(feature = "f32"))]
pub type Float = f64;

/// A floating-point number type (defaults to [`f64`], see `f32` feature).
#[cfg(feature = "f32")]
pub type Float = f32;

/// The mathematical constant $`\pi`$.
#[cfg(not(feature = "f32"))]
pub const PI: Float = std::f64::consts::PI;

/// The mathematical constant $`\pi`$.
#[cfg(feature = "f32")]
pub const PI: Float = std::f32::consts::PI;

/// The error type used by all `laddu` internal methods
#[derive(Error, Debug)]
pub enum LadduError {
    /// An alias for [`std::io::Error`].
    #[error("IO Error: {0}")]
    IOError(#[from] std::io::Error),
    /// An alias for [`parquet::errors::ParquetError`].
    #[error("Parquet Error: {0}")]
    ParquetError(#[from] parquet::errors::ParquetError),
    /// An alias for [`arrow::error::ArrowError`].
    #[error("Arrow Error: {0}")]
    ArrowError(#[from] arrow::error::ArrowError),
    /// An alias for [`shellexpand::LookupError`].
    #[error("Failed to expand path: {0}")]
    LookupError(#[from] shellexpand::LookupError<std::env::VarError>),
    /// An error which occurs when the user tries to register two amplitudes by the same name to
    /// the same [`Manager`].
    #[error("An amplitude by the name \"{name}\" is already registered by this manager!")]
    RegistrationError {
        /// Name of amplitude which is already registered
        name: String,
    },
    /// An error which occurs when the user tries to use an unregistered amplitude.
    #[error("No registered amplitude with name \"{name}\"!")]
    AmplitudeNotFoundError {
        /// Name of amplitude which failed lookup
        name: String,
    },
    /// An error which occurs when the user tries to parse an invalid string of text, typically
    /// into an enum variant.
    #[error("Failed to parse string: \"{name}\" does not correspond to a valid \"{object}\"!")]
    ParseError {
        /// The string which was parsed
        name: String,
        /// The name of the object it failed to parse into
        object: String,
    },
    /// An error returned by the Rust (de)serializer
    #[error("(De)Serialization error: {0}")]
    SerdeError(#[from] Box<ErrorKind>),
    /// An error returned by the Python pickle (de)serializer
    #[error("Pickle conversion error: {0}")]
    PickleError(#[from] serde_pickle::Error),
    /// An error type for [`rayon`] thread pools
    #[cfg(feature = "rayon")]
    #[error("Error building thread pool: {0}")]
    ThreadPoolError(#[from] rayon::ThreadPoolBuildError),
    /// An error type for [`numpy`]-related conversions
    #[cfg(feature = "numpy")]
    #[error("Numpy error: {0}")]
    NumpyError(#[from] numpy::FromVecError),
    /// A custom fallback error for errors too complex or too infrequent to warrant their own error
    /// category.
    #[error("{0}")]
    Custom(String),
}

impl Clone for LadduError {
    // This is a little hack because error types are rarely cloneable, but I need to store them in a
    // cloneable box for minimizers and MCMC methods
    fn clone(&self) -> Self {
        let err_string = self.to_string();
        LadduError::Custom(err_string)
    }
}

#[cfg(feature = "python")]
impl From<LadduError> for PyErr {
    fn from(err: LadduError) -> Self {
        use pyo3::exceptions::*;
        let err_string = err.to_string();
        match err {
            LadduError::LookupError(_)
            | LadduError::RegistrationError { .. }
            | LadduError::AmplitudeNotFoundError { .. }
            | LadduError::ParseError { .. } => PyValueError::new_err(err_string),
            LadduError::ParquetError(_)
            | LadduError::ArrowError(_)
            | LadduError::IOError(_)
            | LadduError::SerdeError(_)
            | LadduError::PickleError(_) => PyIOError::new_err(err_string),
            LadduError::Custom(_) => PyException::new_err(err_string),
            #[cfg(feature = "rayon")]
            LadduError::ThreadPoolError(_) => PyException::new_err(err_string),
            #[cfg(feature = "numpy")]
            LadduError::NumpyError(_) => PyException::new_err(err_string),
        }
    }
}

use serde::{de::DeserializeOwned, Serialize};
use std::{
    fmt::Debug,
    fs::File,
    io::{BufReader, BufWriter},
    path::Path,
};
/// A trait which allows structs with [`Serialize`] and [`Deserialize`](`serde::Deserialize`) to be
/// written to and read from files with a certain set of types/extensions.
///
/// Currently, Python's pickle format is supported, since it's an easy-to-parse standard
/// that handles floating-point values better than JSON or TOML.
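///
/// A minimal sketch of round-tripping a [`Status`] through a file (the file name here is
/// arbitrary):
/// ```ignore
/// use laddu_core::{ReadWrite, Status};
///
/// let status = Status::default();
/// status.save_as("status.pickle")?;
/// let status_reloaded = Status::load_from("status.pickle")?;
/// ```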
pub trait ReadWrite: Serialize + DeserializeOwned {
    /// Create a null version of the object which acts as a shell into which Python's `pickle` module
    /// can load data. This generally shouldn't be used to construct the struct in regular code.
    fn create_null() -> Self;
    /// Save a [`serde`]-object to a file path, using the extension to determine the file format
    fn save_as<T: AsRef<str>>(&self, file_path: T) -> Result<(), LadduError> {
        let expanded_path = shellexpand::full(file_path.as_ref())?;
        let file_path = Path::new(expanded_path.as_ref());
        let file = File::create(file_path)?;
        let mut writer = BufWriter::new(file);
        serde_pickle::to_writer(&mut writer, self, Default::default())?;
        Ok(())
    }
    /// Load a [`serde`]-object from a file path, using the extension to determine the file format
    fn load_from<T: AsRef<str>>(file_path: T) -> Result<Self, LadduError> {
        let file_path = Path::new(&*shellexpand::full(file_path.as_ref())?).canonicalize()?;
        let file = File::open(file_path)?;
        let reader = BufReader::new(file);
        serde_pickle::from_reader(reader, Default::default()).map_err(LadduError::from)
    }
}

impl ReadWrite for Status {
    fn create_null() -> Self {
        Status::default()
    }
}
impl ReadWrite for Ensemble {
    fn create_null() -> Self {
        Ensemble::new(Vec::default())
    }
}
impl ReadWrite for Model {
    fn create_null() -> Self {
        Model {
            manager: Manager::default(),
            expression: Expression::default(),
        }
    }
}