laddu_core/
lib.rs

1//! # laddu-core
2//!
3//! This is an internal crate used by `laddu`.
4#![warn(clippy::perf, clippy::style, missing_docs)]
5#![allow(clippy::excessive_precision)]
6
7use bincode::ErrorKind;
8#[cfg(feature = "python")]
9use pyo3::PyErr;
10
11/// MPI backend for `laddu`
12///
13/// Message Passing Interface (MPI) is a protocol which enables communication between multiple
14/// CPUs in a high-performance computing environment. While [`rayon`] can parallelize tasks on a
15/// single CPU, MPI can also parallelize tasks on multiple CPUs by running independent
16/// processes on all CPUs at once (tasks) which are assigned ids (ranks) which tell each
17/// process what to do and where to send results. This backend coordinates processes which would
18/// typically be parallelized over the events in a [`Dataset`](`crate::data::Dataset`).
19///
20/// To use this backend, the library must be built with the `mpi` feature, which requires an
21/// existing implementation of MPI like OpenMPI or MPICH. All processing code should be
22/// sandwiched between calls to [`use_mpi`] and [`finalize_mpi`]:
23/// ```ignore
24/// fn main() {
25///     laddu_core::mpi::use_mpi(true);
26///     // laddu analysis code here
27///     laddu_core::mpi::finalize_mpi();
28/// }
29/// ```
30///
31/// [`finalize_mpi`] must be called to trigger all the methods which clean up the MPI
32/// environment. While these are called by default when the [`Universe`](`mpi::environment::Universe`) is dropped, `laddu` uses a static `Universe` that can be accessed by all of the methods that need it, rather than passing the context to each method. This simplifies the way programs can be converted to use MPI, but means that the `Universe` is not automatically dropped at the end of the program (so it must be dropped manually).
33#[cfg(feature = "mpi")]
34pub mod mpi {
35    use std::sync::atomic::{AtomicBool, Ordering};
36    use std::sync::OnceLock;
37
38    use lazy_static::lazy_static;
39    use mpi::environment::Universe;
40    use mpi::topology::{Process, SimpleCommunicator};
41    use mpi::traits::Communicator;
42    use parking_lot::RwLock;
43
44    lazy_static! {
45        static ref USE_MPI: AtomicBool = AtomicBool::new(false);
46    }
47
48    static MPI_UNIVERSE: OnceLock<RwLock<Option<Universe>>> = OnceLock::new();
49
50    /// The default root rank for MPI processes
51    pub const ROOT_RANK: i32 = 0;
52
53    /// Check if the current MPI process is the root process
54    pub fn is_root() -> bool {
55        if let Some(world) = crate::mpi::get_world() {
56            world.rank() == ROOT_RANK
57        } else {
58            false
59        }
60    }
61
62    /// Shortcut method to just get the global MPI communicator without accessing `size` and `rank`
63    /// directly
64    pub fn get_world() -> Option<SimpleCommunicator> {
65        if let Some(universe_lock) = MPI_UNIVERSE.get() {
66            if let Some(universe) = &*universe_lock.read() {
67                let world = universe.world();
68                if world.size() == 1 {
69                    return None;
70                }
71                return Some(world);
72            }
73        }
74        None
75    }
76
77    /// Get the rank of the current process
78    pub fn get_rank() -> Option<i32> {
79        get_world().map(|w| w.rank())
80    }
81
82    /// Get number of available processes/ranks
83    pub fn get_size() -> Option<i32> {
84        get_world().map(|w| w.size())
85    }
86
87    /// Use the MPI backend
88    ///
89    /// # Notes
90    ///
91    /// You must have MPI installed for this to work, and you must call the program with
92    /// `mpirun <executable>`, or bad things will happen.
93    ///
94    /// MPI runs an identical program on each process, but gives the program an ID called its
95    /// "rank". Only the results of methods on the root process (rank 0) should be
96    /// considered valid, as other processes only contain portions of each dataset. To ensure
97    /// you don't save or print data at other ranks, use the provided [`is_root()`]
98    /// method to check if the process is the root process.
99    ///
100    /// Once MPI is enabled, it cannot be disabled. If MPI could be toggled (which it can't),
101    /// the other processes will still run, but they will be independent of the root process
102    /// and will no longer communicate with it. The root process stores no data, so it would
103    /// be difficult (and convoluted) to get the results which were already processed via
104    /// MPI.
105    ///
106    /// Additionally, MPI must be enabled at the beginning of a script, at least before any
107    /// other `laddu` functions are called.
108    ///
109    /// If [`use_mpi()`] is called multiple times, the subsequent calls will have no
110    /// effect.
111    ///
112    /// <div class="warning">
113    ///
114    /// You **must** call [`finalize_mpi()`] before your program exits for MPI to terminate
115    /// smoothly.
116    ///
117    /// </div>
118    ///
119    /// # Examples
120    ///
121    /// ```ignore
122    /// fn main() {
123    ///     laddu_core::use_mpi();
124    ///
125    ///     // ... your code here ...
126    ///
127    ///     laddu_core::finalize_mpi();
128    /// }
129    ///
130    /// ```
131    pub fn use_mpi(trigger: bool) {
132        if trigger {
133            USE_MPI.store(true, Ordering::SeqCst);
134            MPI_UNIVERSE.get_or_init(|| {
135                #[cfg(feature = "rayon")]
136                let threading = mpi::Threading::Funneled;
137                #[cfg(not(feature = "rayon"))]
138                let threading = mpi::Threading::Single;
139                let (universe, _threading) = mpi::initialize_with_threading(threading).unwrap();
140                let world = universe.world();
141                if world.size() == 1 {
142                    eprintln!("Warning: MPI is enabled, but only one process is available. MPI will not be used, but single-CPU parallelism may still be used if enabled.");
143                    finalize_mpi();
144                    USE_MPI.store(false, Ordering::SeqCst);
145                    RwLock::new(None)
146                } else {
147                    RwLock::new(Some(universe))
148                }
149            });
150        }
151    }
152
153    /// Drop the MPI universe and finalize MPI at the end of a program
154    ///
155    /// This function will do nothing if MPI is not initialized.
156    ///
157    /// <div class="warning">
158    ///
159    /// This should only be called once and should be called at the end of all `laddu`-related
160    /// function calls. This must be called at the end of any program which uses MPI.
161    ///
162    /// </div>
163    pub fn finalize_mpi() {
164        if using_mpi() {
165            let mut universe = MPI_UNIVERSE.get().unwrap().write();
166            *universe = None;
167        }
168    }
169
170    /// Check if MPI backend is enabled
171    pub fn using_mpi() -> bool {
172        USE_MPI.load(Ordering::SeqCst)
173    }
174
175    /// A trait including some useful auxiliary methods for MPI
176    pub trait LadduMPI {
177        /// Get the process at the root rank
178        fn process_at_root(&self) -> Process<'_>;
179        /// Check if the current rank is the root rank
180        fn is_root(&self) -> bool;
181        /// Get the counts/displacements for partitioning a buffer of length
182        /// `buf_len`
183        fn get_counts_displs(&self, buf_len: usize) -> (Vec<i32>, Vec<i32>);
184        /// Get the counts/displacements for partitioning a nested buffer (like
185        /// a [`Vec<Vec<T>>`]). If the internal vectors all have the same length
186        /// `internal_len` and there are `unflattened_len` elements in the
187        /// outer vector, then this will give the correct counts/displacements for a
188        /// flattened version of the nested buffer.
189        fn get_flattened_counts_displs(
190            &self,
191            unflattened_len: usize,
192            internal_len: usize,
193        ) -> (Vec<i32>, Vec<i32>);
194    }
195
196    impl LadduMPI for SimpleCommunicator {
197        fn process_at_root(&self) -> Process<'_> {
198            self.process_at_rank(crate::mpi::ROOT_RANK)
199        }
200
201        fn is_root(&self) -> bool {
202            self.rank() == crate::mpi::ROOT_RANK
203        }
204
205        fn get_counts_displs(&self, buf_len: usize) -> (Vec<i32>, Vec<i32>) {
206            let mut counts = vec![0; self.size() as usize];
207            let mut displs = vec![0; self.size() as usize];
208            let chunk_size = buf_len / self.size() as usize;
209            let surplus = buf_len % self.size() as usize;
210            for i in 0..self.size() as usize {
211                counts[i] = if i < surplus {
212                    chunk_size + 1
213                } else {
214                    chunk_size
215                } as i32;
216                displs[i] = if i == 0 {
217                    0
218                } else {
219                    displs[i - 1] + counts[i - 1]
220                };
221            }
222            (counts, displs)
223        }
224
225        fn get_flattened_counts_displs(
226            &self,
227            unflattened_len: usize,
228            internal_len: usize,
229        ) -> (Vec<i32>, Vec<i32>) {
230            let mut counts = vec![0; self.size() as usize];
231            let mut displs = vec![0; self.size() as usize];
232            let chunk_size = unflattened_len / self.size() as usize;
233            let surplus = unflattened_len % self.size() as usize;
234            for i in 0..self.size() as usize {
235                counts[i] = if i < surplus {
236                    (chunk_size + 1) * internal_len
237                } else {
238                    chunk_size * internal_len
239                } as i32;
240                displs[i] = if i == 0 {
241                    0
242                } else {
243                    displs[i - 1] + counts[i - 1]
244                };
245            }
246            (counts, displs)
247        }
248    }
249}
250
251use thiserror::Error;
252
253/// [`Amplitude`](crate::amplitudes::Amplitude)s and methods for making and evaluating them.
254pub mod amplitudes;
255/// Methods for loading and manipulating [`Event`]-based data.
256pub mod data;
257/// Structures for manipulating the cache and free parameters.
258pub mod resources;
259/// Utility functions, enums, and traits
260pub mod utils;
261/// Useful traits for all crate structs
262pub mod traits {
263    pub use crate::amplitudes::Amplitude;
264    pub use crate::utils::variables::Variable;
265    pub use crate::ReadWrite;
266}
267
268pub use crate::data::{open, BinnedDataset, Dataset, Event};
269pub use crate::resources::{
270    Cache, ComplexMatrixID, ComplexScalarID, ComplexVectorID, MatrixID, ParameterID, Parameters,
271    Resources, ScalarID, VectorID,
272};
273pub use crate::utils::enums::{Channel, Frame, Sign};
274pub use crate::utils::variables::{
275    Angles, CosTheta, Mandelstam, Mass, Phi, PolAngle, PolMagnitude, Polarization,
276};
277pub use crate::utils::vectors::{Vec3, Vec4};
278pub use amplitudes::{
279    constant, parameter, AmplitudeID, Evaluator, Expression, Manager, Model, ParameterLike,
280};
281
282// Re-exports
283pub use ganesh::{mcmc::Ensemble, Bound, Status};
284pub use nalgebra::DVector;
285pub use num::Complex;
286
/// The floating-point type used by the crate: [`f64`] unless the `f32` feature is enabled.
#[cfg(not(feature = "f32"))]
pub type Float = f64;

/// The mathematical constant $`\pi`$ as a [`Float`].
#[cfg(not(feature = "f32"))]
pub const PI: Float = std::f64::consts::PI;

/// The floating-point type used by the crate: [`f32`] because the `f32` feature is enabled.
#[cfg(feature = "f32")]
pub type Float = f32;

/// The mathematical constant $`\pi`$ as a [`Float`].
#[cfg(feature = "f32")]
pub const PI: Float = std::f32::consts::PI;
302
303/// The error type used by all `laddu` internal methods
304#[derive(Error, Debug)]
305pub enum LadduError {
306    /// An alias for [`std::io::Error`].
307    #[error("IO Error: {0}")]
308    IOError(#[from] std::io::Error),
309    /// An alias for [`parquet::errors::ParquetError`].
310    #[error("Parquet Error: {0}")]
311    ParquetError(#[from] parquet::errors::ParquetError),
312    /// An alias for [`arrow::error::ArrowError`].
313    #[error("Arrow Error: {0}")]
314    ArrowError(#[from] arrow::error::ArrowError),
315    /// An alias for [`shellexpand::LookupError`].
316    #[error("Failed to expand path: {0}")]
317    LookupError(#[from] shellexpand::LookupError<std::env::VarError>),
318    /// An error which occurs when the user tries to register two amplitudes by the same name to
319    /// the same [`Manager`].
320    #[error("An amplitude by the name \"{name}\" is already registered by this manager!")]
321    RegistrationError {
322        /// Name of amplitude which is already registered
323        name: String,
324    },
325    /// An error which occurs when the user tries to use an unregistered amplitude.
326    #[error("No registered amplitude with name \"{name}\"!")]
327    AmplitudeNotFoundError {
328        /// Name of amplitude which failed lookup
329        name: String,
330    },
331    /// An error which occurs when the user tries to parse an invalid string of text, typically
332    /// into an enum variant.
333    #[error("Failed to parse string: \"{name}\" does not correspond to a valid \"{object}\"!")]
334    ParseError {
335        /// The string which was parsed
336        name: String,
337        /// The name of the object it failed to parse into
338        object: String,
339    },
340    /// An error returned by the Rust de(serializer)
341    #[error("(De)Serialization error: {0}")]
342    SerdeError(#[from] Box<ErrorKind>),
343    /// An error returned by the Python pickle (de)serializer
344    #[error("Pickle conversion error: {0}")]
345    PickleError(#[from] serde_pickle::Error),
346    /// An error type for [`rayon`] thread pools
347    #[cfg(feature = "rayon")]
348    #[error("Error building thread pool: {0}")]
349    ThreadPoolError(#[from] rayon::ThreadPoolBuildError),
350    /// An error type for [`numpy`]-related conversions
351    #[cfg(feature = "numpy")]
352    #[error("Numpy error: {0}")]
353    NumpyError(#[from] numpy::FromVecError),
354    /// A custom fallback error for errors too complex or too infrequent to warrant their own error
355    /// category.
356    #[error("{0}")]
357    Custom(String),
358}
359
360impl Clone for LadduError {
361    // This is a little hack because error types are rarely cloneable, but I need to store them in a
362    // cloneable box for minimizers and MCMC methods
363    fn clone(&self) -> Self {
364        let err_string = self.to_string();
365        LadduError::Custom(err_string)
366    }
367}
368
#[cfg(feature = "python")]
impl From<LadduError> for PyErr {
    /// Map each [`LadduError`] variant onto a Python exception class, carrying the rendered
    /// error message along.
    fn from(err: LadduError) -> Self {
        use pyo3::exceptions::{PyException, PyIOError, PyValueError};
        let message = err.to_string();
        match err {
            // Failures while reading, writing or (de)serializing data become `IOError`.
            LadduError::IOError(_)
            | LadduError::ParquetError(_)
            | LadduError::ArrowError(_)
            | LadduError::SerdeError(_)
            | LadduError::PickleError(_) => PyIOError::new_err(message),
            // Invalid names, paths or strings become `ValueError`.
            LadduError::LookupError(_)
            | LadduError::RegistrationError { .. }
            | LadduError::AmplitudeNotFoundError { .. }
            | LadduError::ParseError { .. } => PyValueError::new_err(message),
            // Everything else becomes a generic `Exception`.
            #[cfg(feature = "rayon")]
            LadduError::ThreadPoolError(_) => PyException::new_err(message),
            #[cfg(feature = "numpy")]
            LadduError::NumpyError(_) => PyException::new_err(message),
            LadduError::Custom(_) => PyException::new_err(message),
        }
    }
}
392
393use serde::{de::DeserializeOwned, Serialize};
394use std::{
395    fmt::Debug,
396    fs::File,
397    io::{BufReader, BufWriter},
398    path::Path,
399};
400/// A trait which allows structs with [`Serialize`] and [`Deserialize`](`serde::Deserialize`) to be
401/// written and read from files with a certain set of types/extensions.
402///
403/// Currently, Python's pickle format is supported supported, since it's an easy-to-parse standard
404/// that supports floating point values better that JSON or TOML
405pub trait ReadWrite: Serialize + DeserializeOwned {
406    /// Create a null version of the object which acts as a shell into which Python's `pickle` module
407    /// can load data. This generally shouldn't be used to construct the struct in regular code.
408    fn create_null() -> Self;
409    /// Save a [`serde`]-object to a file path, using the extension to determine the file format
410    fn save_as<T: AsRef<str>>(&self, file_path: T) -> Result<(), LadduError> {
411        let expanded_path = shellexpand::full(file_path.as_ref())?;
412        let file_path = Path::new(expanded_path.as_ref());
413        let file = File::create(file_path)?;
414        let mut writer = BufWriter::new(file);
415        serde_pickle::to_writer(&mut writer, self, Default::default())?;
416        Ok(())
417    }
418    /// Load a [`serde`]-object from a file path, using the extension to determine the file format
419    fn load_from<T: AsRef<str>>(file_path: T) -> Result<Self, LadduError> {
420        let file_path = Path::new(&*shellexpand::full(file_path.as_ref())?).canonicalize()?;
421        let file = File::open(file_path)?;
422        let reader = BufReader::new(file);
423        serde_pickle::from_reader(reader, Default::default()).map_err(LadduError::from)
424    }
425}
426
427impl ReadWrite for Status {
428    fn create_null() -> Self {
429        Status::default()
430    }
431}
432impl ReadWrite for Ensemble {
433    fn create_null() -> Self {
434        Ensemble::new(Vec::default())
435    }
436}
437impl ReadWrite for Model {
438    fn create_null() -> Self {
439        Model {
440            manager: Manager::default(),
441            expression: Expression::default(),
442        }
443    }
444}