laddu_core/
lib.rs

1//! # laddu-core
2//!
3//! This is an internal crate used by `laddu`.
4#![warn(clippy::perf, clippy::style, missing_docs)]
5#![allow(clippy::excessive_precision)]
6#![cfg_attr(coverage_nightly, feature(coverage_attribute))]
7
8use ganesh::core::{MCMCSummary, MinimizationSummary};
9#[cfg(feature = "python")]
10use pyo3::PyErr;
11
12/// Re-exported alias for `std::f64` to ease dependent crates transitioning to the 64-bit
13/// floating point API.
14pub use std::f64;
15
/// MPI backend for `laddu`
///
/// Message Passing Interface (MPI) is a protocol which enables communication between multiple
/// processes in a high-performance computing environment. While [`rayon`] can parallelize tasks
/// across the cores available to a single process, MPI can parallelize tasks across many CPUs
/// (and many machines) by running independent copies of the program at once. Each copy is
/// assigned an id (its rank) which tells that process what to do and where to send results. This
/// backend coordinates processes which would typically be parallelized over the events in a
/// [`Dataset`](`crate::data::Dataset`).
///
/// To use this backend, the library must be built with the `mpi` feature, which requires an
/// existing implementation of MPI like OpenMPI or MPICH. All processing code should be
/// sandwiched between calls to [`use_mpi`] and [`finalize_mpi`]:
/// ```ignore
/// fn main() {
///     laddu_core::mpi::use_mpi(true);
///     // laddu analysis code here
///     laddu_core::mpi::finalize_mpi();
/// }
/// ```
///
/// [`finalize_mpi`] must be called to trigger all the methods which clean up the MPI
/// environment. These would normally run when the [`Universe`](`mpi::environment::Universe`) is
/// dropped, but `laddu` keeps the `Universe` in a static so that every method which needs it can
/// reach it without the context being passed around. This makes it simpler to convert programs
/// to use MPI, but statics are never dropped, so the `Universe` must be dropped manually by
/// calling [`finalize_mpi`] at the end of the program.
#[cfg(feature = "mpi")]
#[cfg_attr(coverage_nightly, coverage(off))]
pub mod mpi {
    use std::sync::atomic::{AtomicBool, Ordering};
    use std::sync::OnceLock;

    use mpi::datatype::PartitionMut;
    use mpi::environment::Universe;
    use mpi::topology::{Process, SimpleCommunicator};
    use mpi::traits::{Communicator, CommunicatorCollectives, Equivalence};
    use parking_lot::RwLock;

    /// Whether the MPI backend is active. Only set once a multi-process `Universe` has actually
    /// been stored in [`MPI_UNIVERSE`]. `AtomicBool::new` is `const`, so no lazy initialization
    /// is required.
    static USE_MPI: AtomicBool = AtomicBool::new(false);

    /// The global MPI universe. `None` inside the lock means MPI was either not usable (only one
    /// process) or has already been finalized.
    static MPI_UNIVERSE: OnceLock<RwLock<Option<Universe>>> = OnceLock::new();

    /// The default root rank for MPI processes
    pub const ROOT_RANK: i32 = 0;

    /// Check if the current MPI process is the root process
    ///
    /// Returns `true` when MPI is not active, since a lone process is trivially the root.
    pub fn is_root() -> bool {
        if let Some(world) = crate::mpi::get_world() {
            world.rank() == ROOT_RANK
        } else {
            true
        }
    }

    /// Shortcut method to just get the global MPI communicator without accessing `size` and `rank`
    /// directly
    ///
    /// Returns `None` if MPI has not been initialized, was not usable, or has been finalized.
    pub fn get_world() -> Option<SimpleCommunicator> {
        if let Some(universe_lock) = MPI_UNIVERSE.get() {
            if let Some(universe) = &*universe_lock.read() {
                return Some(universe.world());
            }
        }
        None
    }

    /// Get the rank of the current process ([`ROOT_RANK`] when MPI is not active)
    pub fn get_rank() -> i32 {
        get_world().map(|w| w.rank()).unwrap_or(ROOT_RANK)
    }

    /// Get number of available processes/ranks (`1` when MPI is not active)
    pub fn get_size() -> i32 {
        get_world().map(|w| w.size()).unwrap_or(1)
    }

    /// Use the MPI backend
    ///
    /// # Arguments
    ///
    /// * `trigger` - if `false`, this function does nothing.
    ///
    /// # Notes
    ///
    /// You must have MPI installed for this to work, and you must call the program with
    /// `mpirun <executable>`, or bad things will happen.
    ///
    /// MPI runs an identical program on each process, but gives the program an ID called its
    /// "rank". Only the results of methods on the root process (rank 0) should be
    /// considered valid, as each process only holds its own portion of each dataset. To ensure
    /// you don't save or print data at other ranks, use the provided [`is_root()`]
    /// method to check if the process is the root process.
    ///
    /// Once MPI is enabled, it cannot be disabled. If it could be toggled, the other processes
    /// would keep running independently of the root process and would no longer communicate
    /// with it, and since the data is split across all processes it would be difficult (and
    /// convoluted) to recover results which were already processed via MPI.
    ///
    /// Additionally, MPI must be enabled at the beginning of a script, at least before any
    /// other `laddu` functions are called.
    ///
    /// If MPI is started with only a single process, a warning is printed, MPI is finalized
    /// immediately, and [`using_mpi()`] keeps returning `false`.
    ///
    /// If [`use_mpi()`] is called multiple times, the subsequent calls will have no
    /// effect.
    ///
    /// <div class="warning">
    ///
    /// You **must** call [`finalize_mpi()`] before your program exits for MPI to terminate
    /// smoothly.
    ///
    /// </div>
    ///
    /// # Panics
    ///
    /// Panics if MPI cannot be initialized (for instance, if it was already initialized by
    /// other code).
    ///
    /// # Examples
    ///
    /// ```ignore
    /// fn main() {
    ///     laddu_core::mpi::use_mpi(true);
    ///
    ///     // ... your code here ...
    ///
    ///     laddu_core::mpi::finalize_mpi();
    /// }
    /// ```
    pub fn use_mpi(trigger: bool) {
        if !trigger {
            return;
        }
        MPI_UNIVERSE.get_or_init(|| {
            #[cfg(feature = "rayon")]
            let threading = mpi::Threading::Funneled;
            #[cfg(not(feature = "rayon"))]
            let threading = mpi::Threading::Single;
            let (universe, _threading) = mpi::initialize_with_threading(threading)
                .expect("MPI could not be initialized (it may already have been initialized)");
            let world_size = universe.world().size();
            if world_size == 1 {
                eprintln!("Warning: MPI is enabled, but only one process is available. MPI will not be used, but single-CPU parallelism may still be used if enabled.");
                // Dropping the `Universe` finalizes MPI. `finalize_mpi()` must NOT be called
                // here: `MPI_UNIVERSE` is still uninitialized while this closure runs, so
                // `MPI_UNIVERSE.get()` would return `None`.
                drop(universe);
                RwLock::new(None)
            } else {
                // Only flag MPI as active once a usable universe is actually stored, so that
                // repeated calls cannot re-enable the flag without a universe.
                USE_MPI.store(true, Ordering::SeqCst);
                RwLock::new(Some(universe))
            }
        });
    }

    /// Drop the MPI universe and finalize MPI at the end of a program
    ///
    /// This function will do nothing if MPI is not initialized.
    ///
    /// <div class="warning">
    ///
    /// This should only be called once and should be called at the end of all `laddu`-related
    /// function calls. This must be called at the end of any program which uses MPI.
    ///
    /// </div>
    pub fn finalize_mpi() {
        if let Some(universe_lock) = MPI_UNIVERSE.get() {
            // Replacing the stored universe with `None` drops it, which finalizes MPI.
            *universe_lock.write() = None;
        }
    }

    /// Check if MPI backend is enabled
    pub fn using_mpi() -> bool {
        USE_MPI.load(Ordering::SeqCst)
    }

    /// Convert a buffer length or offset into the `i32` used by MPI count/displacement arrays.
    ///
    /// # Panics
    ///
    /// Panics if `value` exceeds `i32::MAX`; truncating it would silently corrupt the partition
    /// handed to MPI.
    fn mpi_count(value: usize) -> i32 {
        i32::try_from(value).unwrap_or_else(|_| {
            panic!(
                "Buffer length {value} does not fit in an MPI count (max {})",
                i32::MAX
            )
        })
    }

    /// Compute per-rank counts and displacements for splitting `total` items of `stride`
    /// elements each across `size` ranks. The first `total % size` ranks receive one extra item.
    fn counts_displs(size: usize, total: usize, stride: usize) -> (Vec<i32>, Vec<i32>) {
        let mut counts = Vec::with_capacity(size);
        let mut displs = Vec::with_capacity(size);
        if size == 0 {
            return (counts, displs);
        }
        let base = total / size;
        let remainder = total % size;
        let mut offset = 0usize;
        for rank in 0..size {
            let n = if rank < remainder { base + 1 } else { base };
            let scaled = n * stride;
            counts.push(mpi_count(scaled));
            displs.push(mpi_count(offset));
            offset += scaled;
        }
        (counts, displs)
    }

    /// Map a global index into `(rank, local_index)` under the same partition used by
    /// `counts_displs` (the first `total % size` ranks own `total / size + 1` items).
    #[inline]
    fn rank_local_from_global(i_global: usize, size: usize, total: usize) -> (usize, usize) {
        assert!(size > 0, "Communicator must have at least one rank");
        assert!(total > 0, "Cannot map global indices when dataset is empty");
        assert!(
            i_global < total,
            "Global index {} out of bounds for {} events",
            i_global,
            total
        );
        let base = total / size;
        let remainder = total % size;
        let big_block = base + 1;
        let threshold = remainder * big_block;
        if i_global < threshold {
            let rank = i_global / big_block;
            let local = i_global % big_block;
            (rank, local)
        } else {
            // Unreachable when `base == 0`: then `threshold == total > i_global`.
            let adjusted = i_global - threshold;
            let rank = remainder + adjusted / base;
            let local = adjusted % base;
            (rank, local)
        }
    }

    /// A trait including some useful auxiliary methods for MPI
    pub trait LadduMPI {
        /// Get the process at the root rank
        fn process_at_root(&self) -> Process<'_>;
        /// Check if the current rank is the root rank
        fn is_root(&self) -> bool;
        /// Gather arbitrarily-sized local slices into a buffer ordered by the
        /// canonical dataset partition.
        fn all_gather_partitioned<T: Equivalence + Default + Clone>(
            &self,
            local: &[T],
            total: usize,
            stride: Option<usize>,
        ) -> Vec<T>;
        /// Gather batches corresponding to arbitrary global indices while
        /// preserving the order of `global_indices`.
        fn all_gather_batched_partitioned<T: Equivalence + Default + Clone>(
            &self,
            local: &[T],
            global_indices: &[usize],
            total: usize,
            stride: Option<usize>,
        ) -> Vec<T>;
        /// Return the `(rank, local_index)` pair owning `global_index` in a
        /// dataset containing `total` events.
        fn owner_of_global_index(&self, global_index: usize, total: usize) -> (i32, usize);
        /// Translate a list of global dataset indices into the corresponding
        /// local indices owned by this rank, preserving their original order.
        fn locals_from_globals(&self, global_indices: &[usize], total: usize) -> Vec<usize>;
        /// Get the counts/displacements for partitioning a buffer of length
        /// `buf_len`
        fn get_counts_displs(&self, buf_len: usize) -> (Vec<i32>, Vec<i32>);
        /// Get the counts/displacements for partitioning a nested buffer (like
        /// a [`Vec<Vec<T>>`]). If the internal vectors all have the same length
        /// `internal_len` and there are `unflattened_len` elements in the
        /// outer vector, then this will give the correct counts/displacements for a
        /// flattened version of the nested buffer.
        fn get_flattened_counts_displs(
            &self,
            unflattened_len: usize,
            internal_len: usize,
        ) -> (Vec<i32>, Vec<i32>);
    }

    impl LadduMPI for SimpleCommunicator {
        fn process_at_root(&self) -> Process<'_> {
            self.process_at_rank(crate::mpi::ROOT_RANK)
        }

        fn is_root(&self) -> bool {
            self.rank() == crate::mpi::ROOT_RANK
        }

        /// Gather arbitrarily-sized local slices into a buffer ordered by the
        /// canonical dataset partition.
        fn all_gather_partitioned<T: Equivalence + Default + Clone>(
            &self,
            local: &[T],
            total: usize,
            stride: Option<usize>,
        ) -> Vec<T> {
            let size = self.size() as usize;
            let stride = stride.unwrap_or(1);
            assert!(stride > 0, "Stride must be greater than zero");
            let mut out = vec![T::default(); total * stride];
            if total == 0 || size == 0 {
                return out;
            }
            let (counts, displs) = counts_displs(size, total, stride);
            debug_assert_eq!(
                local.len(),
                counts[self.rank() as usize] as usize,
                "Local buffer length does not match expected gathered size for rank {}",
                self.rank()
            );
            {
                let mut partition = PartitionMut::new(&mut out, counts, displs);
                self.all_gather_varcount_into(local, &mut partition);
            }
            out
        }

        /// Gather batches corresponding to arbitrary global indices while
        /// preserving the order of `global_indices`.
        fn all_gather_batched_partitioned<T: Equivalence + Default + Clone>(
            &self,
            local: &[T],
            global_indices: &[usize],
            total: usize,
            stride: Option<usize>,
        ) -> Vec<T> {
            let size = self.size() as usize;
            let stride = stride.unwrap_or(1);
            assert!(stride > 0, "Stride must be greater than zero");
            let n_indices = global_indices.len();
            let mut gathered = vec![T::default(); n_indices * stride];
            if n_indices == 0 || size == 0 {
                return gathered;
            }

            assert!(
                total > 0,
                "Cannot gather batched data from an empty dataset"
            );

            // For each rank: which of its local indices are requested, and where each one
            // belongs in the output (position within `global_indices`).
            let mut locals_by_rank = vec![Vec::<usize>::new(); size];
            let mut targets_by_rank = vec![Vec::<usize>::new(); size];
            for (position, &global_index) in global_indices.iter().enumerate() {
                let (rank, local_index) = rank_local_from_global(global_index, size, total);
                locals_by_rank[rank].push(local_index);
                targets_by_rank[rank].push(position);
            }

            let mut counts = Vec::with_capacity(size);
            let mut displs = Vec::with_capacity(size);
            let mut offset = 0usize;
            for locals in &locals_by_rank {
                let scaled = locals.len() * stride;
                counts.push(mpi_count(scaled));
                displs.push(mpi_count(offset));
                offset += scaled;
            }

            let expected_local = locals_by_rank[self.rank() as usize].len() * stride;
            debug_assert_eq!(
                local.len(),
                expected_local,
                "Local buffer length does not match expected gathered size for rank {}",
                self.rank()
            );

            {
                // `displs` is still needed below to locate each rank's segment.
                let mut partition = PartitionMut::new(&mut gathered, counts, displs.clone());
                self.all_gather_varcount_into(local, &mut partition);
            }

            // `gathered` is grouped by rank; scatter each item back to its requested position.
            let mut result = vec![T::default(); n_indices * stride];
            for rank in 0..size {
                let mut cursor = displs[rank] as usize;
                for &target in &targets_by_rank[rank] {
                    let dst = target * stride;
                    for offset in 0..stride {
                        result[dst + offset] = gathered[cursor + offset].clone();
                    }
                    cursor += stride;
                }
            }

            result
        }

        fn owner_of_global_index(&self, global_index: usize, total: usize) -> (i32, usize) {
            assert!(total > 0, "Cannot look up ownership in an empty dataset");
            let size = self.size() as usize;
            let (rank, local) = rank_local_from_global(global_index, size, total);
            (rank as i32, local)
        }

        /// Translate a list of global dataset indices into the corresponding
        /// local indices owned by this rank, preserving their original order.
        fn locals_from_globals(&self, global_indices: &[usize], total: usize) -> Vec<usize> {
            let size = self.size() as usize;
            let this_rank = self.rank() as usize;
            let mut locals = Vec::new();
            if total == 0 {
                return locals;
            }
            for &global_index in global_indices {
                let (rank, local_index) = rank_local_from_global(global_index, size, total);
                if rank == this_rank {
                    locals.push(local_index);
                }
            }
            locals
        }

        fn get_counts_displs(&self, buf_len: usize) -> (Vec<i32>, Vec<i32>) {
            counts_displs(self.size() as usize, buf_len, 1)
        }

        fn get_flattened_counts_displs(
            &self,
            unflattened_len: usize,
            internal_len: usize,
        ) -> (Vec<i32>, Vec<i32>) {
            counts_displs(self.size() as usize, unflattened_len, internal_len)
        }
    }
}
439
440use thiserror::Error;
441
442/// [`Amplitude`](crate::amplitudes::Amplitude)s and methods for making and evaluating them.
443pub mod amplitudes;
444/// Methods for loading and manipulating [`EventData`]-based data.
445pub mod data;
446/// Structures for manipulating the cache and free parameters.
447pub mod resources;
448/// Utility functions, enums, and traits
449pub mod utils;
450/// Useful traits for all crate structs
451pub mod traits {
452    pub use crate::amplitudes::Amplitude;
453    pub use crate::utils::variables::Variable;
454    pub use crate::ReadWrite;
455}
456
457pub use crate::data::{
458    BinnedDataset, Dataset, DatasetMetadata, DatasetReadOptions, Event, EventData,
459};
460pub use crate::resources::{
461    Cache, ComplexMatrixID, ComplexScalarID, ComplexVectorID, MatrixID, ParameterID, Parameters,
462    Resources, ScalarID, VectorID,
463};
464pub use crate::utils::enums::{Channel, Frame, Sign};
465pub use crate::utils::variables::{
466    Angles, CosTheta, Mandelstam, Mass, Phi, PolAngle, PolMagnitude, Polarization,
467};
468pub use crate::utils::vectors::{Vec3, Vec4};
469pub use amplitudes::{constant, parameter, AmplitudeID, Evaluator, Expression, ParameterLike};
470
/// The circle constant $`\pi`$ (ratio of a circle's circumference to its diameter), exposed at
/// the crate root for convenience.
pub const PI: f64 = core::f64::consts::PI;
473
474/// A [`Result`] type alias for [`LadduError`]s.
475pub type LadduResult<T> = Result<T, LadduError>;
476
477/// The error type used by all `laddu` internal methods
478#[derive(Error, Debug)]
479pub enum LadduError {
480    /// An alias for [`std::io::Error`].
481    #[error("IO Error: {0}")]
482    IOError(#[from] std::io::Error),
483    /// An alias for [`parquet::errors::ParquetError`].
484    #[error("Parquet Error: {0}")]
485    ParquetError(#[from] parquet::errors::ParquetError),
486    /// An alias for [`arrow::error::ArrowError`].
487    #[error("Arrow Error: {0}")]
488    ArrowError(#[from] arrow::error::ArrowError),
489    /// An alias for [`shellexpand::LookupError`].
490    #[error("Failed to expand path: {0}")]
491    LookupError(#[from] shellexpand::LookupError<std::env::VarError>),
492    /// An error which occurs when the user tries to register two amplitudes by the same name.
493    #[error("An amplitude by the name \"{name}\" is already registered!")]
494    RegistrationError {
495        /// Name of amplitude which is already registered
496        name: String,
497    },
498    /// An error which occurs when the user tries to use an unregistered amplitude.
499    #[error("No registered amplitude with name \"{name}\"!")]
500    AmplitudeNotFoundError {
501        /// Name of amplitude which failed lookup
502        name: String,
503    },
504    /// An error which occurs when the user tries to parse an invalid string of text, typically
505    /// into an enum variant.
506    #[error("Failed to parse string: \"{name}\" does not correspond to a valid \"{object}\"!")]
507    ParseError {
508        /// The string which was parsed
509        name: String,
510        /// The name of the object it failed to parse into
511        object: String,
512    },
513    /// An error returned by the Rust encoder
514    #[error("Encoder error: {0}")]
515    EncodeError(#[from] bincode::error::EncodeError),
516    /// An error returned by the Rust decoder
517    #[error("Decoder error: {0}")]
518    DecodeError(#[from] bincode::error::DecodeError),
519    /// An error returned by the Python pickle (de)serializer
520    #[error("Pickle conversion error: {0}")]
521    PickleError(#[from] serde_pickle::Error),
522    /// An error which occurs when parameter definitions conflict or clash.
523    #[error("Parameter \"{name}\" conflict: {reason}")]
524    ParameterConflict {
525        /// Name of parameter
526        name: String,
527        /// Description of conflict
528        reason: String,
529    },
530    /// An error which occurs when attempting to use an unregistered or unnamed parameter.
531    #[error("Parameter \"{name}\" could not be registered: {reason}")]
532    UnregisteredParameter {
533        /// Name of parameter
534        name: String,
535        /// Reason for failure
536        reason: String,
537    },
538    /// An error type for [`rayon`] thread pools
539    #[cfg(feature = "rayon")]
540    #[error("Error building thread pool: {0}")]
541    ThreadPoolError(#[from] rayon::ThreadPoolBuildError),
542    /// An error type for [`numpy`]-related conversions
543    #[cfg(feature = "numpy")]
544    #[error("Numpy error: {0}")]
545    NumpyError(#[from] numpy::FromVecError),
546    /// A required column was not found in the input
547    #[error("Required column \"{name}\" was not found in the dataset")]
548    MissingColumn {
549        /// Name of the missing column
550        name: String,
551    },
552    /// A column has an unsupported type
553    #[error("Column \"{name}\" has unsupported type \"{datatype}\"")]
554    InvalidColumnType {
555        /// Column name
556        name: String,
557        /// Detected data type
558        datatype: String,
559    },
560    /// A duplicate name was provided for p4 or aux data
561    #[error("Duplicate {category} name \"{name}\" provided")]
562    DuplicateName {
563        /// Category (p4 or aux)
564        category: &'static str,
565        /// Duplicate name
566        name: String,
567    },
568    /// An unknown name was referenced (e.g., for boosts)
569    #[error("Unknown {category} name \"{name}\"")]
570    UnknownName {
571        /// Category (p4 or aux)
572        category: &'static str,
573        /// Name that could not be resolved
574        name: String,
575    },
576    /// A custom fallback error for errors too complex or too infrequent to warrant their own error
577    /// category.
578    #[error("{0}")]
579    Custom(String),
580}
581
582impl Clone for LadduError {
583    // This is a little hack because error types are rarely cloneable, but I need to store them in a
584    // cloneable box for minimizers and MCMC methods
585    fn clone(&self) -> Self {
586        let err_string = self.to_string();
587        LadduError::Custom(err_string)
588    }
589}
590
#[cfg(feature = "python")]
impl From<LadduError> for PyErr {
    /// Map each [`LadduError`] category onto the closest built-in Python exception type, using
    /// the error's `Display` text as the exception message.
    fn from(err: LadduError) -> Self {
        use pyo3::exceptions::{PyException, PyIOError, PyKeyError, PyValueError};
        let message = err.to_string();
        match err {
            // Failures while reading, writing, or (de)serializing data.
            LadduError::IOError(_)
            | LadduError::ParquetError(_)
            | LadduError::ArrowError(_)
            | LadduError::EncodeError(_)
            | LadduError::DecodeError(_)
            | LadduError::PickleError(_) => PyIOError::new_err(message),
            // Names that could not be found.
            LadduError::MissingColumn { .. } | LadduError::UnknownName { .. } => {
                PyKeyError::new_err(message)
            }
            // Invalid values, names, or definitions supplied by the user.
            LadduError::LookupError(_)
            | LadduError::RegistrationError { .. }
            | LadduError::AmplitudeNotFoundError { .. }
            | LadduError::ParseError { .. }
            | LadduError::InvalidColumnType { .. }
            | LadduError::DuplicateName { .. }
            | LadduError::ParameterConflict { .. }
            | LadduError::UnregisteredParameter { .. } => PyValueError::new_err(message),
            // Everything else becomes a plain `Exception`.
            LadduError::Custom(_) => PyException::new_err(message),
            #[cfg(feature = "rayon")]
            LadduError::ThreadPoolError(_) => PyException::new_err(message),
            #[cfg(feature = "numpy")]
            LadduError::NumpyError(_) => PyException::new_err(message),
        }
    }
}
622
623use serde::{de::DeserializeOwned, Serialize};
624use std::fmt::Debug;
625/// A trait which allows structs with [`Serialize`] and [`Deserialize`](`serde::Deserialize`) to
626/// have a null constructor which Python can fill with data. This allows such structs to be
627/// pickle-able from the Python API.
628pub trait ReadWrite: Serialize + DeserializeOwned {
629    /// Create a null version of the object which acts as a shell into which Python's `pickle` module
630    /// can load data. This generally shouldn't be used to construct the struct in regular code.
631    fn create_null() -> Self;
632}
633impl ReadWrite for MCMCSummary {
634    fn create_null() -> Self {
635        MCMCSummary::default()
636    }
637}
638impl ReadWrite for MinimizationSummary {
639    fn create_null() -> Self {
640        MinimizationSummary::default()
641    }
642}