laddu_core/
lib.rs

//! # laddu-core
//!
//! This is an internal crate used by `laddu`.
#![warn(clippy::perf, clippy::style, missing_docs)]
#![allow(clippy::excessive_precision)]

use ganesh::swarms::{Particle, SwarmPositionInitializer};
use ganesh::{Point, Swarm};
#[cfg(feature = "python")]
use pyo3::PyErr;

/// MPI backend for `laddu`
///
/// The Message Passing Interface (MPI) is a protocol which enables communication between multiple
/// CPUs in a high-performance computing environment. While [`rayon`] can parallelize tasks on a
/// single CPU, MPI can parallelize tasks across multiple CPUs by running an independent process
/// (task) on each one. Each process is assigned an ID (its rank) which tells it what to do and
/// where to send results. This backend coordinates processes which would typically be
/// parallelized over the events in a [`Dataset`](`crate::data::Dataset`).
///
/// To use this backend, the library must be built with the `mpi` feature, which requires an
/// existing implementation of MPI like OpenMPI or MPICH. All processing code should be
/// sandwiched between calls to [`use_mpi`] and [`finalize_mpi`]:
/// ```ignore
/// fn main() {
///     laddu_core::mpi::use_mpi(true);
///     // laddu analysis code here
///     laddu_core::mpi::finalize_mpi();
/// }
/// ```
///
/// [`finalize_mpi`] must be called to trigger all the methods which clean up the MPI
/// environment. While these are called by default when the [`Universe`](`mpi::environment::Universe`)
/// is dropped, `laddu` uses a static `Universe` that can be accessed by all of the methods that
/// need it, rather than passing the context to each method. This simplifies the way programs can
/// be converted to use MPI, but it means the `Universe` is not automatically dropped at the end
/// of the program, so it must be dropped manually.
#[cfg(feature = "mpi")]
pub mod mpi {
    use std::sync::atomic::{AtomicBool, Ordering};
    use std::sync::OnceLock;

    use lazy_static::lazy_static;
    use mpi::environment::Universe;
    use mpi::topology::{Process, SimpleCommunicator};
    use mpi::traits::Communicator;
    use parking_lot::RwLock;

    lazy_static! {
        static ref USE_MPI: AtomicBool = AtomicBool::new(false);
    }

    static MPI_UNIVERSE: OnceLock<RwLock<Option<Universe>>> = OnceLock::new();

    /// The default root rank for MPI processes
    pub const ROOT_RANK: i32 = 0;

    /// Check if the current MPI process is the root process
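    ///
    /// A minimal sketch of guarding output so it only happens once (assumes MPI was
    /// enabled with [`use_mpi`]):
    ///
    /// ```ignore
    /// if laddu_core::mpi::is_root() {
    ///     println!("this only prints on the root rank");
    /// }
    /// ```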
    pub fn is_root() -> bool {
        if let Some(world) = crate::mpi::get_world() {
            world.rank() == ROOT_RANK
        } else {
            false
        }
    }

    /// Shortcut method to get the global MPI communicator without accessing `size` and `rank`
    /// directly
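    ///
    /// A minimal usage sketch (assumes the `mpi` feature is enabled and more than one
    /// process is running):
    ///
    /// ```ignore
    /// use mpi::traits::Communicator;
    /// if let Some(world) = laddu_core::mpi::get_world() {
    ///     println!("process {} of {}", world.rank(), world.size());
    /// }
    /// ```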
    pub fn get_world() -> Option<SimpleCommunicator> {
        if let Some(universe_lock) = MPI_UNIVERSE.get() {
            if let Some(universe) = &*universe_lock.read() {
                let world = universe.world();
                if world.size() == 1 {
                    return None;
                }
                return Some(world);
            }
        }
        None
    }

    /// Get the rank of the current process
    pub fn get_rank() -> Option<i32> {
        get_world().map(|w| w.rank())
    }

    /// Get the number of available processes/ranks
    pub fn get_size() -> Option<i32> {
        get_world().map(|w| w.size())
    }

    /// Use the MPI backend
    ///
    /// # Notes
    ///
    /// You must have MPI installed for this to work, and you must launch the program with
    /// `mpirun <executable>`, or bad things will happen.
    ///
    /// MPI runs an identical program on each process, but gives each one an ID called its
    /// "rank". Only the results of methods on the root process (rank 0) should be
    /// considered valid, as other processes only contain portions of each dataset. To ensure
    /// you don't save or print data at other ranks, use the provided [`is_root()`]
    /// method to check if the process is the root process.
    ///
    /// Once MPI is enabled, it cannot be disabled. If it could be toggled off (it can't),
    /// the other processes would still run, but they would be independent of the root process
    /// and would no longer communicate with it. The root process stores no data, so it would
    /// be difficult (and convoluted) to recover the results which were already processed via
    /// MPI.
    ///
    /// Additionally, MPI must be enabled at the beginning of a script, at least before any
    /// other `laddu` functions are called.
    ///
    /// If [`use_mpi()`] is called multiple times, the subsequent calls will have no
    /// effect.
    ///
    /// <div class="warning">
    ///
    /// You **must** call [`finalize_mpi()`] before your program exits for MPI to terminate
    /// smoothly.
    ///
    /// </div>
    ///
    /// # Examples
    ///
    /// ```ignore
    /// fn main() {
    ///     laddu_core::mpi::use_mpi(true);
    ///
    ///     // ... your code here ...
    ///
    ///     laddu_core::mpi::finalize_mpi();
    /// }
    /// ```
    pub fn use_mpi(trigger: bool) {
        if trigger {
            USE_MPI.store(true, Ordering::SeqCst);
            MPI_UNIVERSE.get_or_init(|| {
                // With `rayon` enabled, worker threads coexist with MPI, so request the
                // `Funneled` threading level (only the main thread makes MPI calls);
                // otherwise `Single` suffices.
                #[cfg(feature = "rayon")]
                let threading = mpi::Threading::Funneled;
                #[cfg(not(feature = "rayon"))]
                let threading = mpi::Threading::Single;
                let (universe, _threading) = mpi::initialize_with_threading(threading).unwrap();
                let world = universe.world();
                if world.size() == 1 {
                    eprintln!("Warning: MPI is enabled, but only one process is available. MPI will not be used, but single-CPU parallelism may still be used if enabled.");
                    USE_MPI.store(false, Ordering::SeqCst);
                    // Dropping the `Universe` finalizes MPI; calling `finalize_mpi()` here
                    // would panic, since `MPI_UNIVERSE` is still being initialized and its
                    // `OnceLock` reads as empty.
                    drop(universe);
                    RwLock::new(None)
                } else {
                    RwLock::new(Some(universe))
                }
            });
        }
    }

    /// Drop the MPI universe and finalize MPI at the end of a program
    ///
    /// This function will do nothing if MPI is not initialized.
    ///
    /// <div class="warning">
    ///
    /// This should only be called once, after all other `laddu`-related function calls.
    /// It must be called at the end of any program which uses MPI.
    ///
    /// </div>
    pub fn finalize_mpi() {
        if using_mpi() {
            let mut universe = MPI_UNIVERSE.get().unwrap().write();
            *universe = None;
        }
    }

    /// Check if the MPI backend is enabled
    pub fn using_mpi() -> bool {
        USE_MPI.load(Ordering::SeqCst)
    }

    /// A trait including some useful auxiliary methods for MPI
    pub trait LadduMPI {
        /// Get the process at the root rank
        fn process_at_root(&self) -> Process<'_>;
        /// Check if the current rank is the root rank
        fn is_root(&self) -> bool;
        /// Get the counts/displacements for partitioning a buffer of length
        /// `buf_len`
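        ///
        /// A worked sketch (assuming `world` is a 3-rank communicator with [`LadduMPI`]
        /// in scope): 10 items split as 4/3/3, with the surplus going to the lowest ranks.
        ///
        /// ```ignore
        /// let (counts, displs) = world.get_counts_displs(10);
        /// assert_eq!(counts, vec![4, 3, 3]);
        /// assert_eq!(displs, vec![0, 4, 7]);
        /// ```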
        fn get_counts_displs(&self, buf_len: usize) -> (Vec<i32>, Vec<i32>);
        /// Get the counts/displacements for partitioning a nested buffer (like
        /// a [`Vec<Vec<T>>`]). If the internal vectors all have the same length
        /// `internal_len` and there are `unflattened_len` elements in the
        /// outer vector, then this will give the correct counts/displacements for a
        /// flattened version of the nested buffer.
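        ///
        /// A worked sketch (assuming `world` is a 2-rank communicator with [`LadduMPI`]
        /// in scope): 5 inner vectors of length 2 flatten to 10 elements, split as 6/4.
        ///
        /// ```ignore
        /// let (counts, displs) = world.get_flattened_counts_displs(5, 2);
        /// assert_eq!(counts, vec![6, 4]);
        /// assert_eq!(displs, vec![0, 6]);
        /// ```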
        fn get_flattened_counts_displs(
            &self,
            unflattened_len: usize,
            internal_len: usize,
        ) -> (Vec<i32>, Vec<i32>);
    }

    impl LadduMPI for SimpleCommunicator {
        fn process_at_root(&self) -> Process<'_> {
            self.process_at_rank(crate::mpi::ROOT_RANK)
        }

        fn is_root(&self) -> bool {
            self.rank() == crate::mpi::ROOT_RANK
        }

        fn get_counts_displs(&self, buf_len: usize) -> (Vec<i32>, Vec<i32>) {
            let mut counts = vec![0; self.size() as usize];
            let mut displs = vec![0; self.size() as usize];
            let chunk_size = buf_len / self.size() as usize;
            let surplus = buf_len % self.size() as usize;
            // The first `surplus` ranks each take one extra element; each displacement
            // is the running total of the counts before it.
            for i in 0..self.size() as usize {
                counts[i] = if i < surplus {
                    chunk_size + 1
                } else {
                    chunk_size
                } as i32;
                displs[i] = if i == 0 {
                    0
                } else {
                    displs[i - 1] + counts[i - 1]
                };
            }
            (counts, displs)
        }

        fn get_flattened_counts_displs(
            &self,
            unflattened_len: usize,
            internal_len: usize,
        ) -> (Vec<i32>, Vec<i32>) {
            let mut counts = vec![0; self.size() as usize];
            let mut displs = vec![0; self.size() as usize];
            let chunk_size = unflattened_len / self.size() as usize;
            let surplus = unflattened_len % self.size() as usize;
            // Same partitioning as `get_counts_displs`, but with counts scaled by the
            // length of each inner vector.
            for i in 0..self.size() as usize {
                counts[i] = if i < surplus {
                    (chunk_size + 1) * internal_len
                } else {
                    chunk_size * internal_len
                } as i32;
                displs[i] = if i == 0 {
                    0
                } else {
                    displs[i - 1] + counts[i - 1]
                };
            }
            (counts, displs)
        }
    }
}

use thiserror::Error;

/// [`Amplitude`](crate::amplitudes::Amplitude)s and methods for making and evaluating them.
pub mod amplitudes;
/// Methods for loading and manipulating [`Event`]-based data.
pub mod data;
/// Structures for manipulating the cache and free parameters.
pub mod resources;
/// Utility functions, enums, and traits
pub mod utils;
/// Useful traits for all crate structs
pub mod traits {
    pub use crate::amplitudes::Amplitude;
    pub use crate::utils::variables::Variable;
    pub use crate::ReadWrite;
}

pub use crate::data::{open, BinnedDataset, Dataset, Event};
pub use crate::resources::{
    Cache, ComplexMatrixID, ComplexScalarID, ComplexVectorID, MatrixID, ParameterID, Parameters,
    Resources, ScalarID, VectorID,
};
pub use crate::utils::enums::{Channel, Frame, Sign};
pub use crate::utils::variables::{
    Angles, CosTheta, Mandelstam, Mass, Phi, PolAngle, PolMagnitude, Polarization,
};
pub use crate::utils::vectors::{Vec3, Vec4};
pub use amplitudes::{
    constant, parameter, AmplitudeID, Evaluator, Expression, Manager, Model, ParameterLike,
};

// Re-exports
pub use ganesh::{Bound, Ensemble, Status};
pub use nalgebra::DVector;
pub use num::Complex;

/// A floating-point number type (defaults to [`f64`], see `f32` feature).
#[cfg(not(feature = "f32"))]
pub type Float = f64;

/// A floating-point number type (defaults to [`f64`], see `f32` feature).
#[cfg(feature = "f32")]
pub type Float = f32;

/// The mathematical constant $`\pi`$.
#[cfg(not(feature = "f32"))]
pub const PI: Float = std::f64::consts::PI;

/// The mathematical constant $`\pi`$.
#[cfg(feature = "f32")]
pub const PI: Float = std::f32::consts::PI;

/// The error type used by all `laddu` internal methods
#[derive(Error, Debug)]
pub enum LadduError {
    /// An alias for [`std::io::Error`].
    #[error("IO Error: {0}")]
    IOError(#[from] std::io::Error),
    /// An alias for [`parquet::errors::ParquetError`].
    #[error("Parquet Error: {0}")]
    ParquetError(#[from] parquet::errors::ParquetError),
    /// An alias for [`arrow::error::ArrowError`].
    #[error("Arrow Error: {0}")]
    ArrowError(#[from] arrow::error::ArrowError),
    /// An alias for [`shellexpand::LookupError`].
    #[error("Failed to expand path: {0}")]
    LookupError(#[from] shellexpand::LookupError<std::env::VarError>),
    /// An error which occurs when the user tries to register two amplitudes by the same name to
    /// the same [`Manager`].
    #[error("An amplitude by the name \"{name}\" is already registered by this manager!")]
    RegistrationError {
        /// Name of amplitude which is already registered
        name: String,
    },
    /// An error which occurs when the user tries to use an unregistered amplitude.
    #[error("No registered amplitude with name \"{name}\"!")]
    AmplitudeNotFoundError {
        /// Name of amplitude which failed lookup
        name: String,
    },
    /// An error which occurs when the user tries to parse an invalid string of text, typically
    /// into an enum variant.
    #[error("Failed to parse string: \"{name}\" does not correspond to a valid \"{object}\"!")]
    ParseError {
        /// The string which was parsed
        name: String,
        /// The name of the object it failed to parse into
        object: String,
    },
    /// An error returned by the Rust encoder
    #[error("Encoder error: {0}")]
    EncodeError(#[from] bincode::error::EncodeError),
    /// An error returned by the Rust decoder
    #[error("Decoder error: {0}")]
    DecodeError(#[from] bincode::error::DecodeError),
    /// An error returned by the Python pickle (de)serializer
    #[error("Pickle conversion error: {0}")]
    PickleError(#[from] serde_pickle::Error),
    /// An error type for [`rayon`] thread pools
    #[cfg(feature = "rayon")]
    #[error("Error building thread pool: {0}")]
    ThreadPoolError(#[from] rayon::ThreadPoolBuildError),
    /// An error type for [`numpy`]-related conversions
    #[cfg(feature = "numpy")]
    #[error("Numpy error: {0}")]
    NumpyError(#[from] numpy::FromVecError),
    /// A custom fallback error for errors too complex or too infrequent to warrant their own error
    /// category.
    #[error("{0}")]
    Custom(String),
}

impl Clone for LadduError {
    // This is a little hack because error types are rarely cloneable, but I need to store them in a
    // cloneable box for minimizers and MCMC methods
    fn clone(&self) -> Self {
        let err_string = self.to_string();
        LadduError::Custom(err_string)
    }
}

#[cfg(feature = "python")]
impl From<LadduError> for PyErr {
    fn from(err: LadduError) -> Self {
        use pyo3::exceptions::*;
        let err_string = err.to_string();
        match err {
            LadduError::LookupError(_)
            | LadduError::RegistrationError { .. }
            | LadduError::AmplitudeNotFoundError { .. }
            | LadduError::ParseError { .. } => PyValueError::new_err(err_string),
            LadduError::ParquetError(_)
            | LadduError::ArrowError(_)
            | LadduError::IOError(_)
            | LadduError::EncodeError(_)
            | LadduError::DecodeError(_)
            | LadduError::PickleError(_) => PyIOError::new_err(err_string),
            LadduError::Custom(_) => PyException::new_err(err_string),
            #[cfg(feature = "rayon")]
            LadduError::ThreadPoolError(_) => PyException::new_err(err_string),
            #[cfg(feature = "numpy")]
            LadduError::NumpyError(_) => PyException::new_err(err_string),
        }
    }
}

use serde::{de::DeserializeOwned, Serialize};
use std::{
    fmt::Debug,
    fs::File,
    io::{BufReader, BufWriter},
    path::Path,
};
/// A trait which allows structs with [`Serialize`] and [`Deserialize`](`serde::Deserialize`) to be
/// written and read from files with a certain set of types/extensions.
///
/// Currently, Python's pickle format is supported, since it's an easy-to-parse standard
/// that supports floating-point values better than JSON or TOML.
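///
/// A minimal round-trip sketch (the file name here is hypothetical; [`Status`]
/// implements this trait below):
///
/// ```ignore
/// use laddu_core::{ReadWrite, Status};
/// let status = Status::default();
/// status.save_as("status.pickle")?;
/// let restored = Status::load_from("status.pickle")?;
/// ```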
pub trait ReadWrite: Serialize + DeserializeOwned {
    /// Create a null version of the object which acts as a shell into which Python's `pickle` module
    /// can load data. This generally shouldn't be used to construct the struct in regular code.
    fn create_null() -> Self;
    /// Save a [`serde`]-object to a file path (currently always written in Python's pickle format)
    fn save_as<T: AsRef<str>>(&self, file_path: T) -> Result<(), LadduError> {
        let expanded_path = shellexpand::full(file_path.as_ref())?;
        let file_path = Path::new(expanded_path.as_ref());
        let file = File::create(file_path)?;
        let mut writer = BufWriter::new(file);
        serde_pickle::to_writer(&mut writer, self, Default::default())?;
        Ok(())
    }
    /// Load a [`serde`]-object from a file path (the file is assumed to be in Python's pickle format)
    fn load_from<T: AsRef<str>>(file_path: T) -> Result<Self, LadduError> {
        let file_path = Path::new(&*shellexpand::full(file_path.as_ref())?).canonicalize()?;
        let file = File::open(file_path)?;
        let reader = BufReader::new(file);
        serde_pickle::from_reader(reader, Default::default()).map_err(LadduError::from)
    }
}

impl ReadWrite for Status {
    fn create_null() -> Self {
        Status::default()
    }
}
impl ReadWrite for Ensemble {
    fn create_null() -> Self {
        Ensemble::new(Vec::default())
    }
}
impl ReadWrite for Point {
    fn create_null() -> Self {
        Point::default()
    }
}
impl ReadWrite for Particle {
    fn create_null() -> Self {
        Particle::default()
    }
}
impl ReadWrite for Swarm {
    fn create_null() -> Self {
        Swarm::new(SwarmPositionInitializer::Zero {
            n_particles: 0,
            n_dimensions: 0,
        })
    }
}
impl ReadWrite for Model {
    fn create_null() -> Self {
        Model {
            manager: Manager::default(),
            expression: Expression::default(),
        }
    }
}