Skip to main content

laddu_core/
lib.rs

1//! # laddu-core
2//!
3//! This is an internal crate used by `laddu`.
4#![warn(clippy::perf, clippy::style, missing_docs)]
5#![allow(clippy::excessive_precision)]
6#![cfg_attr(coverage_nightly, feature(coverage_attribute))]
7
8/// Re-exported alias for `std::f64` to ease dependent crates transitioning to the 64-bit
9/// floating point API.
10pub use std::f64;
11
12#[cfg(feature = "python")]
13use pyo3::PyErr;
14
15/// MPI backend for `laddu`
16///
/// Message Passing Interface (MPI) is a standard for communication between independent
/// processes in a high-performance computing environment. While [`rayon`] parallelizes tasks
/// across the cores of a single machine, MPI can also parallelize work across multiple
/// machines by running the same program as several independent processes at once. Each
/// process is assigned an id (its rank) which tells it which part of the work to do and where
/// to send results. This backend coordinates processes which would typically be parallelized
/// over the events in a [`Dataset`](`crate::data::Dataset`).
23///
24/// To use this backend, the library must be built with the `mpi` feature, which requires an
25/// existing implementation of MPI like OpenMPI or MPICH. All processing code should be
26/// sandwiched between calls to [`mpi::use_mpi`] and [`mpi::finalize_mpi`]:
27/// ```ignore
28/// fn main() {
29///     laddu_core::mpi::use_mpi(true);
30///     // laddu analysis code here
31///     laddu_core::mpi::finalize_mpi();
32/// }
33/// ```
34///
35/// [`mpi::finalize_mpi`] must be called to trigger all the methods which clean up the MPI
36/// environment. While these are called by default when MPI's `Universe` is dropped, `laddu`
37/// uses a static `Universe` that can be accessed by all of the methods that need it, rather
38/// than passing the context to each method. This simplifies the way programs can be converted
39/// to use MPI, but means that the `Universe` is not automatically dropped at the end of the
40/// program (so it must be dropped manually).
41#[cfg(feature = "mpi")]
42#[cfg_attr(coverage_nightly, coverage(off))]
43pub mod mpi {
44    use std::{
45        ops::Range,
46        sync::{
47            atomic::{AtomicBool, Ordering},
48            OnceLock,
49        },
50    };
51
52    use lazy_static::lazy_static;
53    use mpi::{
54        datatype::PartitionMut,
55        environment::Universe,
56        topology::{Process, SimpleCommunicator},
57        traits::{Communicator, CommunicatorCollectives, Equivalence},
58    };
59    use parking_lot::RwLock;
60
61    lazy_static! {
62        static ref USE_MPI: AtomicBool = AtomicBool::new(false);
63    }
64
65    static MPI_UNIVERSE: OnceLock<RwLock<Option<Universe>>> = OnceLock::new();
66
67    /// The default root rank for MPI processes
68    pub const ROOT_RANK: i32 = 0;
69
70    /// Check if the current MPI process is the root process
71    pub fn is_root() -> bool {
72        if let Some(world) = crate::mpi::get_world() {
73            world.rank() == ROOT_RANK
74        } else {
75            true
76        }
77    }
78
79    /// Shortcut method to just get the global MPI communicator without accessing `size` and `rank`
80    /// directly
81    pub fn get_world() -> Option<SimpleCommunicator> {
82        if let Some(universe_lock) = MPI_UNIVERSE.get() {
83            if let Some(universe) = &*universe_lock.read() {
84                return Some(universe.world());
85            }
86        }
87        None
88    }
89
90    /// Get the rank of the current process
91    pub fn get_rank() -> i32 {
92        get_world().map(|w| w.rank()).unwrap_or(ROOT_RANK)
93    }
94
95    /// Get number of available processes/ranks
96    pub fn get_size() -> i32 {
97        get_world().map(|w| w.size()).unwrap_or(1)
98    }
99
100    /// Use the MPI backend
101    ///
102    /// # Notes
103    ///
104    /// You must have MPI installed for this to work, and you must call the program with
105    /// `mpirun <executable>`, or bad things will happen.
106    ///
107    /// MPI runs an identical program on each process, but gives the program an ID called its
108    /// "rank". Only the results of methods on the root process (rank 0) should be
109    /// considered valid, as other processes only contain portions of each dataset. To ensure
110    /// you don't save or print data at other ranks, use the provided [`is_root()`]
111    /// method to check if the process is the root process.
112    ///
113    /// Once MPI is enabled, it cannot be disabled. If MPI could be toggled (which it can't),
114    /// the other processes will still run, but they will be independent of the root process
115    /// and will no longer communicate with it. The root process stores no data, so it would
116    /// be difficult (and convoluted) to get the results which were already processed via
117    /// MPI.
118    ///
119    /// Additionally, MPI must be enabled at the beginning of a script, at least before any
120    /// other `laddu` functions are called.
121    ///
122    /// If [`use_mpi()`] is called multiple times, the subsequent calls will have no
123    /// effect.
124    ///
125    /// <div class="warning">
126    ///
127    /// You **must** call [`finalize_mpi()`] before your program exits for MPI to terminate
128    /// smoothly.
129    ///
130    /// </div>
131    ///
132    /// # Examples
133    ///
134    /// ```ignore
135    /// fn main() {
136    ///     laddu_core::use_mpi();
137    ///
138    ///     // ... your code here ...
139    ///
140    ///     laddu_core::finalize_mpi();
141    /// }
142    ///
143    /// ```
144    pub fn use_mpi(trigger: bool) {
145        if trigger {
146            USE_MPI.store(true, Ordering::SeqCst);
147            MPI_UNIVERSE.get_or_init(|| {
148                #[cfg(feature = "rayon")]
149                let threading = mpi::Threading::Funneled;
150                #[cfg(not(feature = "rayon"))]
151                let threading = mpi::Threading::Single;
152                let (universe, _threading) = mpi::initialize_with_threading(threading).unwrap();
153                let world = universe.world();
154                if world.size() == 1 {
155                    eprintln!("Warning: MPI is enabled, but only one process is available. MPI will not be used, but single-CPU parallelism may still be used if enabled.");
156                    finalize_mpi();
157                    USE_MPI.store(false, Ordering::SeqCst);
158                    RwLock::new(None)
159                } else {
160                    RwLock::new(Some(universe))
161                }
162            });
163        }
164    }
165
166    /// Drop the MPI universe and finalize MPI at the end of a program
167    ///
168    /// This function will do nothing if MPI is not initialized.
169    ///
170    /// <div class="warning">
171    ///
172    /// This should only be called once and should be called at the end of all `laddu`-related
173    /// function calls. This must be called at the end of any program which uses MPI.
174    ///
175    /// </div>
176    pub fn finalize_mpi() {
177        if get_world().is_some() {
178            if let Some(universe_lock) = MPI_UNIVERSE.get() {
179                let mut universe = universe_lock.write();
180                *universe = None;
181            }
182        }
183        USE_MPI.store(false, Ordering::SeqCst);
184    }
185
186    /// Check if MPI backend is enabled
187    pub fn using_mpi() -> bool {
188        USE_MPI.load(Ordering::SeqCst)
189    }
190
191    fn counts_displs(size: usize, total: usize, stride: usize) -> (Vec<i32>, Vec<i32>) {
192        let mut counts = vec![0i32; size];
193        let mut displs = vec![0i32; size];
194        if size == 0 {
195            return (counts, displs);
196        }
197        let base = total / size;
198        let remainder = total % size;
199        let mut offset = 0i32;
200        for rank in 0..size {
201            let n = if rank < remainder { base + 1 } else { base };
202            let scaled = (n * stride) as i32;
203            counts[rank] = scaled;
204            displs[rank] = offset;
205            offset += scaled;
206        }
207        (counts, displs)
208    }
209
210    #[inline]
211    fn rank_local_from_global(i_global: usize, size: usize, total: usize) -> (usize, usize) {
212        assert!(size > 0, "Communicator must have at least one rank");
213        assert!(total > 0, "Cannot map global indices when dataset is empty");
214        assert!(
215            i_global < total,
216            "Global index {} out of bounds for {} events",
217            i_global,
218            total
219        );
220        let base = total / size;
221        let remainder = total % size;
222        let big_block = base + 1;
223        let threshold = remainder * big_block;
224        if i_global < threshold {
225            let rank = i_global / big_block;
226            let local = i_global % big_block;
227            (rank, local)
228        } else {
229            let adjusted = i_global - threshold;
230            let rank = remainder + adjusted / base;
231            let local = adjusted % base;
232            (rank, local)
233        }
234    }
235
236    /// Canonical partitioning information for distributing a dataset across MPI ranks.
237    #[derive(Clone, Debug)]
238    pub struct Partition {
239        counts: Vec<i32>,
240        displs: Vec<i32>,
241        total: usize,
242    }
243
244    impl Partition {
245        /// Build a new distribution for `total` items across `size` ranks.
246        pub fn new(size: usize, total: usize) -> Self {
247            assert!(size > 0, "Communicator must have at least one rank");
248            let (counts, displs) = counts_displs(size, total, 1);
249            Self {
250                counts,
251                displs,
252                total,
253            }
254        }
255
256        /// Total number of items tracked by this partition.
257        pub fn total(&self) -> usize {
258            self.total
259        }
260
261        /// Number of ranks described by this partition.
262        pub fn n_ranks(&self) -> usize {
263            self.counts.len()
264        }
265
266        /// Number of items assigned to `rank`.
267        pub fn len_for_rank(&self, rank: usize) -> usize {
268            self.counts[rank] as usize
269        }
270
271        /// Starting global index for `rank`.
272        pub fn start_for_rank(&self, rank: usize) -> usize {
273            self.displs[rank] as usize
274        }
275
276        /// Contiguous global range owned by `rank`.
277        pub fn range_for_rank(&self, rank: usize) -> Range<usize> {
278            let start = self.start_for_rank(rank);
279            start..start + self.len_for_rank(rank)
280        }
281
282        /// Determine the owning rank and local index for a global dataset index.
283        pub fn owner_of(&self, global_index: usize) -> (usize, usize) {
284            assert!(
285                self.total > 0,
286                "Cannot map global indices when dataset is empty"
287            );
288            rank_local_from_global(global_index, self.n_ranks(), self.total)
289        }
290
291        /// Convert into raw `(counts, displacements)` buffers.
292        pub fn into_raw(self) -> (Vec<i32>, Vec<i32>) {
293            (self.counts, self.displs)
294        }
295    }
296
297    /// A trait including some useful auxiliary methods for MPI
298    pub trait LadduMPI {
299        /// Get the process at the root rank
300        fn process_at_root(&self) -> Process<'_>;
301        /// Check if the current rank is the root rank
302        fn is_root(&self) -> bool;
303        /// Gather arbitrarily-sized local slices into a buffer ordered by the
304        /// canonical dataset partition.
305        fn all_gather_partitioned<T: Equivalence + Default + Clone>(
306            &self,
307            local: &[T],
308            total: usize,
309            stride: Option<usize>,
310        ) -> Vec<T>;
311        /// Gather local slices into a buffer using explicit
312        /// `(counts, displacements)` in element units.
313        fn all_gather_with_counts<T: Equivalence + Default + Clone>(
314            &self,
315            local: &[T],
316            counts: &[i32],
317            displs: &[i32],
318        ) -> Vec<T>;
319        /// Gather batches corresponding to arbitrary global indices while
320        /// preserving the order of `global_indices`.
321        fn all_gather_batched_partitioned<T: Equivalence + Default + Clone>(
322            &self,
323            local: &[T],
324            global_indices: &[usize],
325            total: usize,
326            stride: Option<usize>,
327        ) -> Vec<T>;
328        /// Return the `(rank, local_index)` pair owning `global_index` in a
329        /// dataset containing `total` events.
330        fn owner_of_global_index(&self, global_index: usize, total: usize) -> (i32, usize);
331        /// Translate a list of global dataset indices into the corresponding
332        /// local indices owned by this rank, preserving their original order.
333        fn locals_from_globals(&self, global_indices: &[usize], total: usize) -> Vec<usize>;
334        /// Get the counts/displacements for partitioning a buffer of length
335        /// `buf_len`
336        fn get_counts_displs(&self, buf_len: usize) -> (Vec<i32>, Vec<i32>);
337        /// Build a [`Partition`] describing how `total` items are distributed
338        /// across ranks.
339        fn partition(&self, total: usize) -> Partition;
340        /// Get the counts/displacements for partitioning a nested buffer (like
341        /// a [`Vec<Vec<T>>`]). If the internal vectors all have the same length
342        /// `internal_len` and there are `unflattened_len` elements in the
343        /// outer vector, then this will give the correct counts/displacements for a
344        /// flattened version of the nested buffer.
345        fn get_flattened_counts_displs(
346            &self,
347            unflattened_len: usize,
348            internal_len: usize,
349        ) -> (Vec<i32>, Vec<i32>);
350    }
351
352    impl LadduMPI for SimpleCommunicator {
353        fn process_at_root(&self) -> Process<'_> {
354            self.process_at_rank(crate::mpi::ROOT_RANK)
355        }
356
357        fn is_root(&self) -> bool {
358            self.rank() == crate::mpi::ROOT_RANK
359        }
360
361        /// Gather arbitrarily-sized local slices into a buffer ordered by the
362        /// canonical dataset partition.
363        fn all_gather_partitioned<T: Equivalence + Default + Clone>(
364            &self,
365            local: &[T],
366            total: usize,
367            stride: Option<usize>,
368        ) -> Vec<T> {
369            let size = self.size() as usize;
370            let stride = stride.unwrap_or(1);
371            assert!(stride > 0, "Stride must be greater than zero");
372            let mut out = vec![T::default(); total * stride];
373            if total == 0 || size == 0 {
374                return out;
375            }
376            let (counts, displs) = counts_displs(size, total, stride);
377            {
378                let mut partition = PartitionMut::new(&mut out, counts, displs);
379                self.all_gather_varcount_into(local, &mut partition);
380            }
381            out
382        }
383
384        fn all_gather_with_counts<T: Equivalence + Default + Clone>(
385            &self,
386            local: &[T],
387            counts: &[i32],
388            displs: &[i32],
389        ) -> Vec<T> {
390            assert_eq!(
391                counts.len(),
392                displs.len(),
393                "Counts and displacements must have the same length"
394            );
395            assert_eq!(
396                counts.len(),
397                self.size() as usize,
398                "Counts/displacements must match communicator size"
399            );
400            let total = counts.iter().map(|count| *count as usize).sum();
401            let mut out = vec![T::default(); total];
402            {
403                let mut partition = PartitionMut::new(&mut out, counts.to_vec(), displs.to_vec());
404                self.all_gather_varcount_into(local, &mut partition);
405            }
406            out
407        }
408
409        /// Gather batches corresponding to arbitrary global indices while
410        /// preserving the order of `global_indices`.
411        fn all_gather_batched_partitioned<T: Equivalence + Default + Clone>(
412            &self,
413            local: &[T],
414            global_indices: &[usize],
415            total: usize,
416            stride: Option<usize>,
417        ) -> Vec<T> {
418            let size = self.size() as usize;
419            let stride = stride.unwrap_or(1);
420            assert!(stride > 0, "Stride must be greater than zero");
421            let n_indices = global_indices.len();
422            let mut gathered = vec![T::default(); n_indices * stride];
423            if n_indices == 0 || size == 0 {
424                return gathered;
425            }
426
427            assert!(
428                total > 0,
429                "Cannot gather batched data from an empty dataset"
430            );
431
432            let partition = Partition::new(size, total);
433            let mut locals_by_rank = vec![Vec::<usize>::new(); size];
434            let mut targets_by_rank = vec![Vec::<usize>::new(); size];
435            for (position, &global_index) in global_indices.iter().enumerate() {
436                let (rank, local_index) = partition.owner_of(global_index);
437                locals_by_rank[rank].push(local_index);
438                targets_by_rank[rank].push(position);
439            }
440
441            let mut counts = vec![0i32; size];
442            let mut displs = vec![0i32; size];
443            for rank in 0..size {
444                counts[rank] = (locals_by_rank[rank].len() * stride) as i32;
445                displs[rank] = if rank == 0 {
446                    0
447                } else {
448                    displs[rank - 1] + counts[rank - 1]
449                };
450            }
451
452            let expected_local = locals_by_rank[self.rank() as usize].len() * stride;
453            debug_assert_eq!(
454                local.len(),
455                expected_local,
456                "Local buffer length does not match expected gathered size for rank {}",
457                self.rank()
458            );
459
460            {
461                let mut partition =
462                    PartitionMut::new(&mut gathered, counts.clone(), displs.clone());
463                self.all_gather_varcount_into(local, &mut partition);
464            }
465
466            let mut result = vec![T::default(); n_indices * stride];
467            for rank in 0..size {
468                let mut cursor = displs[rank] as usize;
469                for &target in &targets_by_rank[rank] {
470                    let dst = target * stride;
471                    result[dst..(stride + dst)]
472                        .clone_from_slice(&gathered[cursor..(stride + cursor)]);
473                    cursor += stride;
474                }
475            }
476
477            result
478        }
479
480        fn owner_of_global_index(&self, global_index: usize, total: usize) -> (i32, usize) {
481            let partition = Partition::new(self.size() as usize, total);
482            let (rank, local) = partition.owner_of(global_index);
483            (rank as i32, local)
484        }
485
486        /// Translate a list of global dataset indices into the corresponding
487        /// local indices owned by this rank, preserving their original order.
488        fn locals_from_globals(&self, global_indices: &[usize], total: usize) -> Vec<usize> {
489            let partition = Partition::new(self.size() as usize, total);
490            let this_rank = self.rank() as usize;
491            let mut locals = Vec::new();
492            if total == 0 {
493                return locals;
494            }
495            for &global_index in global_indices {
496                let (rank, local_index) = partition.owner_of(global_index);
497                if rank == this_rank {
498                    locals.push(local_index);
499                }
500            }
501            locals
502        }
503        fn get_counts_displs(&self, buf_len: usize) -> (Vec<i32>, Vec<i32>) {
504            self.partition(buf_len).into_raw()
505        }
506
507        fn partition(&self, total: usize) -> Partition {
508            Partition::new(self.size() as usize, total)
509        }
510
511        fn get_flattened_counts_displs(
512            &self,
513            unflattened_len: usize,
514            internal_len: usize,
515        ) -> (Vec<i32>, Vec<i32>) {
516            let mut counts = vec![0; self.size() as usize];
517            let mut displs = vec![0; self.size() as usize];
518            let chunk_size = unflattened_len / self.size() as usize;
519            let surplus = unflattened_len % self.size() as usize;
520            for i in 0..self.size() as usize {
521                counts[i] = if i < surplus {
522                    (chunk_size + 1) * internal_len
523                } else {
524                    chunk_size * internal_len
525                } as i32;
526                displs[i] = if i == 0 {
527                    0
528                } else {
529                    displs[i - 1] + counts[i - 1]
530                };
531            }
532            (counts, displs)
533        }
534    }
535}
536
537use thiserror::Error;
538
539/// Core amplitude traits, identifiers, and expression-facing compatibility exports.
540pub mod amplitude;
541/// Methods for loading and manipulating event datasets.
542pub mod data;
543/// Execution-policy and thread-pool coordination helpers.
544pub mod execution;
545/// Expression trees, compiled diagnostics, and evaluator interfaces.
546pub mod expression;
547/// Kinematic frame helpers and angle containers.
548pub mod kinematics;
549/// Special functions and numerical helpers.
550pub mod math;
551/// Parameter handles, identifiers, and assembled parameter storage.
552pub mod parameters;
553/// Quantum-number helpers and discrete analysis enums.
554pub mod quantum;
555/// Reaction topology, particles, and decay-node helpers.
556pub mod reaction;
557/// Structures for manipulating the cache and free parameters.
558pub mod resources;
559/// Event variables derived from reactions and particle selections.
560pub mod variables;
561/// Three- and four-vector types used throughout the library.
562pub mod vectors;
563/// Useful traits for all crate structs
564pub mod traits {
565    pub use crate::{amplitude::Amplitude, variables::Variable};
566}
567
568pub use amplitude::{Amplitude, AmplitudeID, AmplitudeSemanticField, AmplitudeSemanticKey};
569
570#[cfg(feature = "execution-context-prototype")]
571pub use crate::execution::{ExecutionContext, ScratchAllocator, ThreadPolicy};
572pub use crate::{
573    data::{
574        BinnedDataset, Dataset, DatasetMetadata, DatasetReadOptions, Event, EventData, OwnedEvent,
575    },
576    execution::ThreadPoolManager,
577    expression::{CompiledExpression, CompiledExpressionNode, Evaluator, Expression},
578    kinematics::{DecayAngles, FrameAxes, RestFrame},
579    parameters::{Parameter, ParameterID, ParameterMap, Parameters},
580    quantum::{
581        allowed_projections, AllowedPartialWave, AngularMomentum, Channel, Charge, Frame, Isospin,
582        OrbitalAngularMomentum, Parity, PartialWave, ParticleProperties, Projection, Reflectivity,
583        RuleSet, SelectionRules, SpinState, Statistics,
584    },
585    reaction::{
586        Decay, Particle, ParticleGraph, ParticleSource, Production, Reaction, ReactionTopology,
587        ResolvedTwoToTwo, TwoToTwoReaction,
588    },
589    resources::{
590        Cache, ComplexMatrixID, ComplexScalarID, ComplexVectorID, MatrixID, Resources, ScalarID,
591        VectorID,
592    },
593    variables::{
594        Angles, CosTheta, IntoP4Selection, Mandelstam, Mass, P4Selection, Phi, PolAngle,
595        PolMagnitude, Polarization,
596    },
597    vectors::{Vec3, Vec4},
598};
599
/// The circle constant $`\pi`$, taken directly from the standard library's `f64` constants.
pub const PI: f64 = ::std::f64::consts::PI;
602
603/// A [`Result`] type alias for [`LadduError`]s.
604pub type LadduResult<T> = Result<T, LadduError>;
605
606/// The error type used by all `laddu` internal methods
607#[derive(Error, Debug)]
608pub enum LadduError {
609    /// An alias for [`std::io::Error`].
610    #[error(transparent)]
611    IOError(#[from] std::io::Error),
612    /// An alias for [`parquet::errors::ParquetError`].
613    #[error(transparent)]
614    ParquetError(#[from] parquet::errors::ParquetError),
615    /// An alias for [`arrow::error::ArrowError`].
616    #[error(transparent)]
617    ArrowError(#[from] arrow::error::ArrowError),
618    /// An alias for [`shellexpand::LookupError`].
619    #[error(transparent)]
620    LookupError(#[from] shellexpand::LookupError<std::env::VarError>),
621    /// An error which occurs when the user tries to register two amplitudes by the same name.
622    #[error("An amplitude by the name \"{name}\" is already registered!")]
623    RegistrationError {
624        /// Name of amplitude which is already registered
625        name: String,
626    },
627    /// An error which occurs when the user tries to select an unregistered amplitude tag.
628    #[error("No registered amplitude with tag \"{name}\"!")]
629    AmplitudeNotFoundError {
630        /// Name of amplitude which failed lookup
631        name: String,
632    },
633    /// An error which occurs when the user tries to parse an invalid string of text, typically
634    /// into an enum variant.
635    #[error("Failed to parse string: \"{name}\" does not correspond to a valid \"{object}\"!")]
636    ParseError {
637        /// The string which was parsed
638        name: String,
639        /// The name of the object it failed to parse into
640        object: String,
641    },
642    /// An error returned by internal bitcode serialization
643    #[error(transparent)]
644    BitcodeError(#[from] bitcode::Error),
645    /// An error returned by the Python pickle (de)serializer
646    #[error(transparent)]
647    PickleError(#[from] serde_pickle::Error),
648    /// An error which occurs when parameter definitions conflict or clash.
649    #[error("Parameter \"{name}\" conflict: {reason}")]
650    ParameterConflict {
651        /// Name of parameter
652        name: String,
653        /// Description of conflict
654        reason: String,
655    },
656    /// An error which occurs when attempting to use an unregistered or unnamed parameter.
657    #[error("Parameter \"{name}\" could not be registered: {reason}")]
658    UnregisteredParameter {
659        /// Name of parameter
660        name: String,
661        /// Reason for failure
662        reason: String,
663    },
664    /// An error which occurs during execution-context setup.
665    #[error("Execution context setup failed: {reason}")]
666    ExecutionContextError {
667        /// Description of setup failure
668        reason: String,
669    },
670    /// An error type for [`rayon`] thread pools
671    #[cfg(feature = "rayon")]
672    #[error(transparent)]
673    ThreadPoolError(#[from] rayon::ThreadPoolBuildError),
674    /// An error type for [`numpy`]-related conversions
675    #[cfg(feature = "numpy")]
676    #[error(transparent)]
677    NumpyError(#[from] numpy::FromVecError),
678    /// A required column was not found in the input
679    #[error("Required column \"{name}\" was not found in the dataset")]
680    MissingColumn {
681        /// Name of the missing column
682        name: String,
683    },
684    /// A column has an unsupported type
685    #[error("Column \"{name}\" has unsupported type \"{datatype}\"")]
686    InvalidColumnType {
687        /// Column name
688        name: String,
689        /// Detected data type
690        datatype: String,
691    },
692    /// A value has an unexpected length.
693    #[error("{context} length mismatch: expected {expected}, received {actual}")]
694    LengthMismatch {
695        /// Description of what length was validated.
696        context: String,
697        /// Expected length.
698        expected: usize,
699        /// Actual length observed.
700        actual: usize,
701    },
702    /// A duplicate name was provided for p4 or aux data
703    #[error("Duplicate {category} name \"{name}\" provided")]
704    DuplicateName {
705        /// Category (p4 or aux)
706        category: &'static str,
707        /// Duplicate name
708        name: String,
709    },
710    /// An unknown name was referenced (e.g., for boosts)
711    #[error("Unknown {category} name \"{name}\"")]
712    UnknownName {
713        /// Category (p4 or aux)
714        category: &'static str,
715        /// Name that could not be resolved
716        name: String,
717    },
718    /// A particle is missing the requested property
719    #[error("Particle is missing the requested property \"{property}\"")]
720    MissingParticleProperty {
721        /// The name of the missing property
722        property: &'static str,
723    },
724    /// A custom fallback error for errors too complex or too infrequent to warrant their own error
725    /// category.
726    #[error("{0}")]
727    Custom(String),
728}
729
730/// Validate the number of free parameters supplied to a public entrypoint.
731pub fn validate_free_parameter_len(input_len: usize, expected_len: usize) -> LadduResult<()> {
732    if input_len != expected_len {
733        return Err(LadduError::LengthMismatch {
734            context: "free parameter vector".to_string(),
735            expected: expected_len,
736            actual: input_len,
737        });
738    }
739    Ok(())
740}
741
742impl Clone for LadduError {
743    // This is a little hack because error types are rarely cloneable, but I need to store them in a
744    // cloneable box for minimizers and MCMC methods
745    fn clone(&self) -> Self {
746        let err_string = self.to_string();
747        LadduError::Custom(err_string)
748    }
749}
750
#[cfg(feature = "python")]
impl From<LadduError> for PyErr {
    /// Map a [`LadduError`] onto the closest built-in Python exception type, using the
    /// error's `Display` text as the exception message.
    fn from(err: LadduError) -> Self {
        use pyo3::exceptions::{PyIOError, PyKeyError, PyRuntimeError, PyValueError};
        let message = err.to_string();
        match err {
            // I/O and (de)serialization failures
            LadduError::IOError(_)
            | LadduError::ParquetError(_)
            | LadduError::ArrowError(_)
            | LadduError::BitcodeError(_)
            | LadduError::PickleError(_) => PyIOError::new_err(message),
            // Lookups of columns or names that do not exist
            LadduError::MissingColumn { .. } | LadduError::UnknownName { .. } => {
                PyKeyError::new_err(message)
            }
            // Invalid values supplied by the caller
            LadduError::LookupError(_)
            | LadduError::RegistrationError { .. }
            | LadduError::AmplitudeNotFoundError { .. }
            | LadduError::ParseError { .. }
            | LadduError::InvalidColumnType { .. }
            | LadduError::LengthMismatch { .. }
            | LadduError::DuplicateName { .. }
            | LadduError::ParameterConflict { .. }
            | LadduError::UnregisteredParameter { .. }
            | LadduError::MissingParticleProperty { .. } => PyValueError::new_err(message),
            #[cfg(feature = "numpy")]
            LadduError::NumpyError(_) => PyValueError::new_err(message),
            // Execution failures and uncategorized errors
            LadduError::ExecutionContextError { .. } | LadduError::Custom(_) => {
                PyRuntimeError::new_err(message)
            }
            #[cfg(feature = "rayon")]
            LadduError::ThreadPoolError(_) => PyRuntimeError::new_err(message),
        }
    }
}