#![warn(clippy::perf, clippy::style, missing_docs)]
#![allow(clippy::excessive_precision)]
#![cfg_attr(coverage_nightly, feature(coverage_attribute))]

use ganesh::core::{MCMCSummary, MinimizationSummary};
#[cfg(feature = "python")]
use pyo3::PyErr;

pub use std::f64;

/// Helpers for running `laddu` across multiple processes with MPI.
#[cfg(feature = "mpi")]
#[cfg_attr(coverage_nightly, coverage(off))]
pub mod mpi {
    use std::sync::atomic::{AtomicBool, Ordering};
    use std::sync::OnceLock;

    use mpi::datatype::PartitionMut;
    use mpi::environment::Universe;
    use mpi::topology::{Process, SimpleCommunicator};
    use mpi::traits::{Communicator, CommunicatorCollectives, Equivalence};
    use parking_lot::RwLock;

    // `AtomicBool::new` is `const`, so a plain `static` suffices here (no
    // `lazy_static!` needed).
    static USE_MPI: AtomicBool = AtomicBool::new(false);

    static MPI_UNIVERSE: OnceLock<RwLock<Option<Universe>>> = OnceLock::new();

    /// The rank of the root process.
    pub const ROOT_RANK: i32 = 0;

    /// Returns `true` if the current process is the root process (or if MPI is not in use).
    pub fn is_root() -> bool {
        if let Some(world) = crate::mpi::get_world() {
            world.rank() == ROOT_RANK
        } else {
            true
        }
    }

    /// Returns the MPI world communicator, if MPI has been initialized via [`use_mpi`].
    pub fn get_world() -> Option<SimpleCommunicator> {
        if let Some(universe_lock) = MPI_UNIVERSE.get() {
            if let Some(universe) = &*universe_lock.read() {
                return Some(universe.world());
            }
        }
        None
    }

    /// Returns the rank of the current process ([`ROOT_RANK`] if MPI is not in use).
    pub fn get_rank() -> i32 {
        get_world().map(|w| w.rank()).unwrap_or(ROOT_RANK)
    }

    /// Returns the number of MPI processes (`1` if MPI is not in use).
    pub fn get_size() -> i32 {
        get_world().map(|w| w.size()).unwrap_or(1)
    }

    /// Globally enables MPI for the current program.
    ///
    /// This should be called once, near the start of `main`, before any datasets are
    /// built. If `trigger` is `false`, this is a no-op, which makes it convenient to
    /// gate MPI on a command-line flag. If only one process is available, MPI is
    /// turned back off and execution falls back to a single process.
    pub fn use_mpi(trigger: bool) {
        if trigger {
            USE_MPI.store(true, Ordering::SeqCst);
            MPI_UNIVERSE.get_or_init(|| {
                #[cfg(feature = "rayon")]
                let threading = mpi::Threading::Funneled;
                #[cfg(not(feature = "rayon"))]
                let threading = mpi::Threading::Single;
                let (universe, _threading) = mpi::initialize_with_threading(threading).unwrap();
                let world = universe.world();
                if world.size() == 1 {
                    eprintln!("Warning: MPI is enabled, but only one process is available. MPI will not be used, but single-CPU parallelism may still be used if enabled.");
                    // `MPI_UNIVERSE` has not been initialized yet inside this closure,
                    // so calling `finalize_mpi()` here would panic on `get().unwrap()`;
                    // dropping the universe directly finalizes MPI instead.
                    drop(universe);
                    USE_MPI.store(false, Ordering::SeqCst);
                    RwLock::new(None)
                } else {
                    RwLock::new(Some(universe))
                }
            });
        }
    }
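
    // A minimal usage sketch (hypothetical `main`, launched with e.g. `mpirun -n 4`;
    // paths assume this crate is consumed as `laddu_core`):
    //
    //     fn main() {
    //         laddu_core::mpi::use_mpi(true);  // enable MPI before touching any data
    //         // ... build datasets, evaluate amplitudes, fit ...
    //         laddu_core::mpi::finalize_mpi(); // drop the universe before exit
    //     }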

    /// Finalizes MPI by dropping the global universe (which calls `MPI_Finalize`).
    ///
    /// This should be called at the end of `main` if [`use_mpi`] was called;
    /// afterwards, [`get_world`] returns `None`.
    pub fn finalize_mpi() {
        if using_mpi() {
            let mut universe = MPI_UNIVERSE.get().unwrap().write();
            *universe = None;
        }
    }

    /// Returns `true` if MPI has been enabled via [`use_mpi`].
    pub fn using_mpi() -> bool {
        USE_MPI.load(Ordering::SeqCst)
    }

    /// Computes per-rank element counts and displacements for partitioning `total`
    /// items (each occupying `stride` buffer slots) across `size` ranks, giving the
    /// first `total % size` ranks one extra item when the division is uneven.
    fn counts_displs(size: usize, total: usize, stride: usize) -> (Vec<i32>, Vec<i32>) {
        let mut counts = vec![0i32; size];
        let mut displs = vec![0i32; size];
        if size == 0 {
            return (counts, displs);
        }
        let base = total / size;
        let remainder = total % size;
        let mut offset = 0i32;
        for rank in 0..size {
            let n = if rank < remainder { base + 1 } else { base };
            let scaled = (n * stride) as i32;
            counts[rank] = scaled;
            displs[rank] = offset;
            offset += scaled;
        }
        (counts, displs)
    }
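
    // For example, 10 items over 3 ranks with stride 2 split as 4/3/3 items, so
    //     counts_displs(3, 10, 2) == (vec![8, 6, 6], vec![0, 8, 14])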

    /// Maps a global item index to `(owning rank, local index)` under the same
    /// partitioning scheme as [`counts_displs`]: the first `total % size` ranks own
    /// `total / size + 1` items each, and the remaining ranks own `total / size`.
    #[inline]
    fn rank_local_from_global(i_global: usize, size: usize, total: usize) -> (usize, usize) {
        assert!(size > 0, "Communicator must have at least one rank");
        assert!(total > 0, "Cannot map global indices when dataset is empty");
        assert!(
            i_global < total,
            "Global index {} out of bounds for {} events",
            i_global,
            total
        );
        let base = total / size;
        let remainder = total % size;
        let big_block = base + 1;
        let threshold = remainder * big_block;
        if i_global < threshold {
            let rank = i_global / big_block;
            let local = i_global % big_block;
            (rank, local)
        } else {
            // `base` cannot be zero here: if `total < size`, then
            // `threshold == total` and this branch is unreachable.
            let adjusted = i_global - threshold;
            let rank = remainder + adjusted / base;
            let local = adjusted % base;
            (rank, local)
        }
    }
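
    // For example, with 10 items over 3 ranks (blocks of 4/3/3):
    //     rank_local_from_global(0, 3, 10) == (0, 0)
    //     rank_local_from_global(5, 3, 10) == (1, 1)
    //     rank_local_from_global(9, 3, 10) == (2, 2)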

    /// Convenience methods for MPI communicators used throughout `laddu`.
    pub trait LadduMPI {
        /// Returns the [`Process`] at the root rank.
        fn process_at_root(&self) -> Process<'_>;
        /// Returns `true` if the current rank is the root rank.
        fn is_root(&self) -> bool;
        /// Gathers `local` slices from every rank into a single buffer of
        /// `total * stride` elements, ordered by rank (`stride` defaults to `1`).
        fn all_gather_partitioned<T: Equivalence + Default + Clone>(
            &self,
            local: &[T],
            total: usize,
            stride: Option<usize>,
        ) -> Vec<T>;
        /// Gathers values for an arbitrary batch of global indices, returning them in
        /// the same order as `global_indices`. Every rank must pass the same
        /// `global_indices`, with `local` holding the values for its own rank.
        fn all_gather_batched_partitioned<T: Equivalence + Default + Clone>(
            &self,
            local: &[T],
            global_indices: &[usize],
            total: usize,
            stride: Option<usize>,
        ) -> Vec<T>;
        /// Returns the `(rank, local index)` that owns the given global index.
        fn owner_of_global_index(&self, global_index: usize, total: usize) -> (i32, usize);
        /// Filters `global_indices` down to the local indices owned by the current rank.
        fn locals_from_globals(&self, global_indices: &[usize], total: usize) -> Vec<usize>;
        /// Computes per-rank counts and displacements for partitioning a buffer of
        /// `buf_len` elements.
        fn get_counts_displs(&self, buf_len: usize) -> (Vec<i32>, Vec<i32>);
        /// Like [`LadduMPI::get_counts_displs`], but for `unflattened_len` items that
        /// each flatten to `internal_len` elements.
        fn get_flattened_counts_displs(
            &self,
            unflattened_len: usize,
            internal_len: usize,
        ) -> (Vec<i32>, Vec<i32>);
    }

    impl LadduMPI for SimpleCommunicator {
        fn process_at_root(&self) -> Process<'_> {
            self.process_at_rank(crate::mpi::ROOT_RANK)
        }

        fn is_root(&self) -> bool {
            self.rank() == crate::mpi::ROOT_RANK
        }

        fn all_gather_partitioned<T: Equivalence + Default + Clone>(
            &self,
            local: &[T],
            total: usize,
            stride: Option<usize>,
        ) -> Vec<T> {
            let size = self.size() as usize;
            let stride = stride.unwrap_or(1);
            assert!(stride > 0, "Stride must be greater than zero");
            let mut out = vec![T::default(); total * stride];
            if total == 0 || size == 0 {
                return out;
            }
            let (counts, displs) = counts_displs(size, total, stride);
            {
                let mut partition = PartitionMut::new(&mut out, counts, displs);
                self.all_gather_varcount_into(local, &mut partition);
            }
            out
        }
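
        // Usage sketch (hypothetical caller and `weight` function): each rank
        // evaluates its locally owned events, and the full, globally ordered
        // buffer is assembled on every rank:
        //
        //     if let Some(world) = crate::mpi::get_world() {
        //         let local: Vec<f64> = my_events.iter().map(weight).collect();
        //         let all = world.all_gather_partitioned(&local, n_events, None);
        //     }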

        fn all_gather_batched_partitioned<T: Equivalence + Default + Clone>(
            &self,
            local: &[T],
            global_indices: &[usize],
            total: usize,
            stride: Option<usize>,
        ) -> Vec<T> {
            let size = self.size() as usize;
            let stride = stride.unwrap_or(1);
            assert!(stride > 0, "Stride must be greater than zero");
            let n_indices = global_indices.len();
            let mut gathered = vec![T::default(); n_indices * stride];
            if n_indices == 0 || size == 0 {
                return gathered;
            }

            assert!(
                total > 0,
                "Cannot gather batched data from an empty dataset"
            );

            // Bucket the requested global indices by owning rank, remembering where
            // each value belongs in the output.
            let mut locals_by_rank = vec![Vec::<usize>::new(); size];
            let mut targets_by_rank = vec![Vec::<usize>::new(); size];
            for (position, &global_index) in global_indices.iter().enumerate() {
                let (rank, local_index) = rank_local_from_global(global_index, size, total);
                locals_by_rank[rank].push(local_index);
                targets_by_rank[rank].push(position);
            }

            let mut counts = vec![0i32; size];
            let mut displs = vec![0i32; size];
            for rank in 0..size {
                counts[rank] = (locals_by_rank[rank].len() * stride) as i32;
                displs[rank] = if rank == 0 {
                    0
                } else {
                    displs[rank - 1] + counts[rank - 1]
                };
            }

            let expected_local = locals_by_rank[self.rank() as usize].len() * stride;
            debug_assert_eq!(
                local.len(),
                expected_local,
                "Local buffer length does not match expected gathered size for rank {}",
                self.rank()
            );

            // Gather every rank's contribution into one rank-ordered buffer.
            {
                let mut partition =
                    PartitionMut::new(&mut gathered, counts.clone(), displs.clone());
                self.all_gather_varcount_into(local, &mut partition);
            }

            // Scatter the gathered values back into the order of `global_indices`.
            let mut result = vec![T::default(); n_indices * stride];
            for rank in 0..size {
                let mut cursor = displs[rank] as usize;
                for &target in &targets_by_rank[rank] {
                    let dst = target * stride;
                    for offset in 0..stride {
                        result[dst + offset] = gathered[cursor + offset].clone();
                    }
                    cursor += stride;
                }
            }

            result
        }
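
        // Usage sketch (hypothetical caller): fetch a scattered subset of events by
        // global index; every rank passes the same `wanted` indices and receives the
        // same, identically ordered result:
        //
        //     let wanted = [3usize, 17, 42];
        //     let locals = world.locals_from_globals(&wanted, n_events);
        //     let local: Vec<f64> = locals.iter().map(|&i| weight(&my_events[i])).collect();
        //     let batch = world.all_gather_batched_partitioned(&local, &wanted, n_events, None);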

        fn owner_of_global_index(&self, global_index: usize, total: usize) -> (i32, usize) {
            assert!(total > 0, "Cannot look up ownership in an empty dataset");
            let size = self.size() as usize;
            let (rank, local) = rank_local_from_global(global_index, size, total);
            (rank as i32, local)
        }

        fn locals_from_globals(&self, global_indices: &[usize], total: usize) -> Vec<usize> {
            let size = self.size() as usize;
            let this_rank = self.rank() as usize;
            let mut locals = Vec::new();
            if total == 0 {
                return locals;
            }
            for &global_index in global_indices {
                let (rank, local_index) = rank_local_from_global(global_index, size, total);
                if rank == this_rank {
                    locals.push(local_index);
                }
            }
            locals
        }
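
        // For example, with 10 events over 3 ranks (blocks of 4/3/3), rank 1 owns
        // globals 4..7, so on rank 1:
        //     world.locals_from_globals(&[2, 5, 6, 9], 10) == vec![1, 2]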

        fn get_counts_displs(&self, buf_len: usize) -> (Vec<i32>, Vec<i32>) {
            let mut counts = vec![0; self.size() as usize];
            let mut displs = vec![0; self.size() as usize];
            let chunk_size = buf_len / self.size() as usize;
            let surplus = buf_len % self.size() as usize;
            for i in 0..self.size() as usize {
                counts[i] = if i < surplus {
                    chunk_size + 1
                } else {
                    chunk_size
                } as i32;
                displs[i] = if i == 0 {
                    0
                } else {
                    displs[i - 1] + counts[i - 1]
                };
            }
            (counts, displs)
        }
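
        // For example, with `buf_len == 10` on 3 ranks:
        //     world.get_counts_displs(10) == (vec![4, 3, 3], vec![0, 4, 7])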

        fn get_flattened_counts_displs(
            &self,
            unflattened_len: usize,
            internal_len: usize,
        ) -> (Vec<i32>, Vec<i32>) {
            let mut counts = vec![0; self.size() as usize];
            let mut displs = vec![0; self.size() as usize];
            let chunk_size = unflattened_len / self.size() as usize;
            let surplus = unflattened_len % self.size() as usize;
            for i in 0..self.size() as usize {
                counts[i] = if i < surplus {
                    (chunk_size + 1) * internal_len
                } else {
                    chunk_size * internal_len
                } as i32;
                displs[i] = if i == 0 {
                    0
                } else {
                    displs[i - 1] + counts[i - 1]
                };
            }
            (counts, displs)
        }
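
        // For example, 10 items that each flatten to 4 elements, on 3 ranks:
        //     world.get_flattened_counts_displs(10, 4) == (vec![16, 12, 12], vec![0, 16, 28])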
    }
}

use thiserror::Error;

/// Amplitudes and methods for constructing and evaluating them.
pub mod amplitudes;
/// Structures and methods for loading and manipulating datasets.
pub mod data;
/// Caches, parameters, and other resources used when evaluating amplitudes.
pub mod resources;
/// Utilities: kinematic variables, vectors, and enums.
pub mod utils;
/// Commonly used traits, re-exported for convenience.
pub mod traits {
    pub use crate::amplitudes::Amplitude;
    pub use crate::utils::variables::Variable;
    pub use crate::ReadWrite;
}

pub use crate::data::{
    BinnedDataset, Dataset, DatasetMetadata, DatasetReadOptions, Event, EventData,
};
pub use crate::resources::{
    Cache, ComplexMatrixID, ComplexScalarID, ComplexVectorID, MatrixID, ParameterID, Parameters,
    Resources, ScalarID, VectorID,
};
pub use crate::utils::enums::{Channel, Frame, Sign};
pub use crate::utils::variables::{
    Angles, CosTheta, Mandelstam, Mass, Phi, PolAngle, PolMagnitude, Polarization,
};
pub use crate::utils::vectors::{Vec3, Vec4};
pub use amplitudes::{constant, parameter, AmplitudeID, Evaluator, Expression, ParameterLike};

/// The mathematical constant `π`.
pub const PI: f64 = std::f64::consts::PI;

/// A convenience type alias for a [`Result`] with [`LadduError`] as its error type.
pub type LadduResult<T> = Result<T, LadduError>;

/// The main error type used throughout `laddu`.
#[derive(Error, Debug)]
pub enum LadduError {
    /// An error which occurred during a standard I/O operation.
    #[error("IO Error: {0}")]
    IOError(#[from] std::io::Error),
    /// An error which occurred while reading or writing Parquet files.
    #[error("Parquet Error: {0}")]
    ParquetError(#[from] parquet::errors::ParquetError),
    /// An error which occurred in the Arrow crate.
    #[error("Arrow Error: {0}")]
    ArrowError(#[from] arrow::error::ArrowError),
    /// An error which occurred while expanding environment variables in a path.
    #[error("Failed to expand path: {0}")]
    LookupError(#[from] shellexpand::LookupError<std::env::VarError>),
    /// An error returned when an amplitude is registered under a name already in use.
    #[error("An amplitude by the name \"{name}\" is already registered!")]
    RegistrationError {
        /// The name which is already registered.
        name: String,
    },
    /// An error returned when a lookup is attempted for an unregistered amplitude.
    #[error("No registered amplitude with name \"{name}\"!")]
    AmplitudeNotFoundError {
        /// The name which could not be found.
        name: String,
    },
    /// An error returned when a string fails to parse into the expected object.
    #[error("Failed to parse string: \"{name}\" does not correspond to a valid \"{object}\"!")]
    ParseError {
        /// The string which failed to parse.
        name: String,
        /// The name of the object it failed to parse into.
        object: String,
    },
    /// An error returned by the `bincode` encoder.
    #[error("Encoder error: {0}")]
    EncodeError(#[from] bincode::error::EncodeError),
    /// An error returned by the `bincode` decoder.
    #[error("Decoder error: {0}")]
    DecodeError(#[from] bincode::error::DecodeError),
    /// An error returned while converting to or from Python's pickle format.
    #[error("Pickle conversion error: {0}")]
    PickleError(#[from] serde_pickle::Error),
    /// An error returned when a parameter is registered in conflicting ways.
    #[error("Parameter \"{name}\" conflict: {reason}")]
    ParameterConflict {
        /// The name of the conflicting parameter.
        name: String,
        /// The reason for the conflict.
        reason: String,
    },
    /// An error returned when a parameter cannot be registered.
    #[error("Parameter \"{name}\" could not be registered: {reason}")]
    UnregisteredParameter {
        /// The name of the parameter.
        name: String,
        /// The reason registration failed.
        reason: String,
    },
    /// An error returned when a `rayon` thread pool fails to build.
    #[cfg(feature = "rayon")]
    #[error("Error building thread pool: {0}")]
    ThreadPoolError(#[from] rayon::ThreadPoolBuildError),
    /// An error returned when converting between `numpy` arrays and `Vec`s.
    #[cfg(feature = "numpy")]
    #[error("Numpy error: {0}")]
    NumpyError(#[from] numpy::FromVecError),
    /// An error returned when a required column is missing from a dataset.
    #[error("Required column \"{name}\" was not found in the dataset")]
    MissingColumn {
        /// The name of the missing column.
        name: String,
    },
    /// An error returned when a dataset column has an unsupported data type.
    #[error("Column \"{name}\" has unsupported type \"{datatype}\"")]
    InvalidColumnType {
        /// The name of the column.
        name: String,
        /// The unsupported data type of the column.
        datatype: String,
    },
    /// An error returned when a name is provided more than once.
    #[error("Duplicate {category} name \"{name}\" provided")]
    DuplicateName {
        /// The category of the named object.
        category: &'static str,
        /// The duplicated name.
        name: String,
    },
    /// An error returned when an unknown name is provided.
    #[error("Unknown {category} name \"{name}\"")]
    UnknownName {
        /// The category of the named object.
        category: &'static str,
        /// The unknown name.
        name: String,
    },
    /// A custom error message.
    #[error("{0}")]
    Custom(String),
}

impl Clone for LadduError {
    fn clone(&self) -> Self {
        // Most of the wrapped error types are not `Clone`, so cloning flattens
        // any error into its string representation via the `Custom` variant.
        let err_string = self.to_string();
        LadduError::Custom(err_string)
    }
}

#[cfg(feature = "python")]
impl From<LadduError> for PyErr {
    fn from(err: LadduError) -> Self {
        use pyo3::exceptions::*;
        let err_string = err.to_string();
        match err {
            LadduError::LookupError(_)
            | LadduError::RegistrationError { .. }
            | LadduError::AmplitudeNotFoundError { .. }
            | LadduError::ParseError { .. } => PyValueError::new_err(err_string),
            LadduError::ParquetError(_)
            | LadduError::ArrowError(_)
            | LadduError::IOError(_)
            | LadduError::EncodeError(_)
            | LadduError::DecodeError(_)
            | LadduError::PickleError(_) => PyIOError::new_err(err_string),
            LadduError::MissingColumn { .. } | LadduError::UnknownName { .. } => {
                PyKeyError::new_err(err_string)
            }
            LadduError::InvalidColumnType { .. }
            | LadduError::DuplicateName { .. }
            | LadduError::ParameterConflict { .. }
            | LadduError::UnregisteredParameter { .. } => PyValueError::new_err(err_string),
            LadduError::Custom(_) => PyException::new_err(err_string),
            #[cfg(feature = "rayon")]
            LadduError::ThreadPoolError(_) => PyException::new_err(err_string),
            #[cfg(feature = "numpy")]
            LadduError::NumpyError(_) => PyException::new_err(err_string),
        }
    }
}

use serde::{de::DeserializeOwned, Serialize};
use std::fmt::Debug;

/// A trait for objects that can be serialized to and deserialized from disk.
pub trait ReadWrite: Serialize + DeserializeOwned {
    /// Creates a placeholder ("null") instance of the type, used as a shell
    /// into which serialized data can be loaded.
    fn create_null() -> Self;
}
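
// A minimal sketch of implementing [`ReadWrite`] for a custom serializable type
// (the `FitMeta` struct here is hypothetical):
//
//     #[derive(Serialize, Deserialize, Default)]
//     struct FitMeta {
//         label: String,
//     }
//
//     impl ReadWrite for FitMeta {
//         fn create_null() -> Self {
//             FitMeta::default()
//         }
//     }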
impl ReadWrite for MCMCSummary {
    fn create_null() -> Self {
        MCMCSummary::default()
    }
}

impl ReadWrite for MinimizationSummary {
    fn create_null() -> Self {
        MinimizationSummary::default()
    }
}