Skip to main content

laddu_python/
lib.rs

1#![warn(clippy::perf, clippy::style)]
2#![cfg_attr(coverage_nightly, feature(coverage_attribute))]
3use laddu_core::ThreadPoolManager;
4use pyo3::prelude::*;
5use pyo3::types::PyDict;
6
7/// Returns the number of CPUs (logical cores) available for use by ``laddu``.
8///
9#[pyfunction]
10pub fn available_parallelism() -> usize {
11    num_cpus::get()
12}
13
14/// Set the process-global default thread count for Python APIs that accept ``threads=``.
15///
16/// Parameters
17/// ----------
18/// n_threads : int, optional
19///     The default number of threads to use for omitted thread arguments and for ``threads=0``.
20///     Setting this to ``0`` or ``None`` resets the default to "all available CPUs".
21///
22/// Notes
23/// -----
24/// Explicit positive ``threads=`` arguments on individual calls override this default.
25///
26#[pyfunction]
27pub fn set_threads(n_threads: Option<usize>) {
28    ThreadPoolManager::set_global_thread_count(n_threads.unwrap_or(0));
29}
30
31/// Return the process-global default thread count used by omitted or zero-valued thread requests.
32///
33/// Returns ``0`` when the default is "all available CPUs".
34#[pyfunction]
35pub fn get_threads() -> usize {
36    ThreadPoolManager::global_thread_count().unwrap_or(0)
37}
38
39#[cfg_attr(coverage_nightly, coverage(off))]
40pub mod amplitudes;
41#[cfg_attr(coverage_nightly, coverage(off))]
42pub mod data;
43#[cfg_attr(coverage_nightly, coverage(off))]
44pub mod utils;
45
46#[cfg_attr(coverage_nightly, coverage(off))]
47pub mod mpi {
48    #[cfg(not(feature = "mpi"))]
49    use pyo3::exceptions::PyModuleNotFoundError;
50
51    use super::*;
52    /// Check if ``laddu`` was compiled with MPI support (returns ``True`` if it was).
53    ///
54    /// Since ``laddu-mpi`` has the same namespace as ``laddu`` (they both are imported with
55    /// ``import laddu``), this method can be used to check if MPI capabilities are available
56    /// without actually running any MPI code. While functions in the ``laddu.mpi`` module will
57    /// raise an ``ModuleNotFoundError`` if MPI is not supported, its sometimes convenient to have
58    /// a simple boolean check rather than a try-catch block, and this method provides that.
59    ///
60    #[pyfunction]
61    pub fn is_mpi_available() -> bool {
62        #[cfg(feature = "mpi")]
63        return true;
64        #[cfg(not(feature = "mpi"))]
65        return false;
66    }
67    /// Use the Message Passing Interface (MPI) to run on a distributed system
68    ///
69    /// Parameters
70    /// ----------
71    /// trigger: bool, default=True
72    ///     An optional parameter which allows MPI to only be used under some boolean
73    ///     condition.
74    ///
75    /// Notes
76    /// -----
77    /// You must have MPI installed for this to work, and you must call the program with
78    /// ``mpirun <executable>``, or bad things will happen.
79    ///
80    /// MPI runs an identical program on each process, but gives the program an ID called its
81    /// "rank". Only the results of methods on the root process (rank 0) should be
82    /// considered valid, as other processes only contain portions of each dataset. To ensure
83    /// you don't save or print data at other ranks, use the provided ``laddu.mpi.is_root()``
84    /// method to check if the process is the root process.
85    ///
86    /// Once MPI is enabled, it cannot be disabled. If MPI could be toggled (which it can't),
87    /// the other processes will still run, but they will be independent of the root process
88    /// and will no longer communicate with it. The root process stores no data, so it would
89    /// be difficult (and convoluted) to get the results which were already processed via
90    /// MPI.
91    ///
92    /// Additionally, MPI must be enabled at the beginning of a script, at least before any
93    /// other ``laddu`` functions are called. For this reason, it is suggested that you use the
94    /// context manager ``laddu.mpi.MPI`` to ensure the MPI backend is used properly.
95    ///
96    /// If ``laddu.mpi.use_mpi()`` is called multiple times, the subsequent calls will have no
97    /// effect.
98    ///
99    /// You **must** call ``laddu.mpi.finalize_mpi()`` before your program exits for MPI to terminate
100    /// smoothly.
101    ///
102    /// See Also
103    /// --------
104    /// laddu.mpi.MPI
105    /// laddu.mpi.using_mpi
106    /// laddu.mpi.is_root
107    /// laddu.mpi.get_rank
108    /// laddu.mpi.get_size
109    /// laddu.mpi.finalize_mpi
110    ///
111    #[pyfunction]
112    #[allow(unused_variables)]
113    #[pyo3(signature = (*, trigger=true))]
114    pub fn use_mpi(trigger: bool) -> PyResult<()> {
115        #[cfg(feature = "mpi")]
116        {
117            laddu_core::mpi::use_mpi(trigger);
118            Ok(())
119        }
120        #[cfg(not(feature = "mpi"))]
121        return Err(PyModuleNotFoundError::new_err(
122            "`laddu` was not compiled with MPI support! Please use `laddu-mpi` instead.",
123        ));
124    }
125
126    /// Drop the MPI universe and finalize MPI at the end of a program
127    ///
128    /// This should only be called once and should be called at the end of all ``laddu``-related
129    /// function calls. This **must** be called at the end of any program which uses MPI.
130    ///
131    /// See Also
132    /// --------
133    /// laddu.mpi.use_mpi
134    ///
135    #[pyfunction]
136    pub fn finalize_mpi() -> PyResult<()> {
137        #[cfg(feature = "mpi")]
138        {
139            laddu_core::mpi::finalize_mpi();
140            Ok(())
141        }
142        #[cfg(not(feature = "mpi"))]
143        return Err(PyModuleNotFoundError::new_err(
144            "`laddu` was not compiled with MPI support! Please use `laddu-mpi` instead.",
145        ));
146    }
147
148    /// Check if MPI is enabled
149    ///
150    /// This can be combined with ``laddu.mpi.is_root()`` to ensure valid results are only
151    /// returned from the root rank process on the condition that MPI is enabled.
152    ///
153    /// See Also
154    /// --------
155    /// laddu.mpi.use_mpi
156    /// laddu.mpi.is_root
157    ///
158    #[pyfunction]
159    pub fn using_mpi() -> bool {
160        #[cfg(feature = "mpi")]
161        return laddu_core::mpi::using_mpi();
162        #[cfg(not(feature = "mpi"))]
163        return false;
164    }
165
166    /// Check if the current MPI process is the root process
167    ///
168    /// This can be combined with ``laddu.mpi.using_mpi()`` to ensure valid results are only
169    /// returned from the root rank process on the condition that MPI is enabled.
170    ///
171    /// See Also
172    /// --------
173    /// laddu.mpi.use_mpi
174    /// laddu.mpi.using_mpi
175    ///
176    #[pyfunction]
177    pub fn is_root() -> bool {
178        #[cfg(feature = "mpi")]
179        return laddu_core::mpi::is_root();
180        #[cfg(not(feature = "mpi"))]
181        return true;
182    }
183
184    /// Get the rank of the current MPI process
185    ///
186    /// Returns ``0`` if MPI is not enabled
187    ///
188    /// See Also
189    /// --------
190    /// laddu.mpi.use_mpi
191    ///
192    #[pyfunction]
193    pub fn get_rank() -> i32 {
194        #[cfg(feature = "mpi")]
195        return laddu_core::mpi::get_rank();
196        #[cfg(not(feature = "mpi"))]
197        return 0;
198    }
199
200    /// Get the total number of MPI processes (including the root process)
201    ///
202    /// Returns ``1`` if MPI is not enabled
203    ///
204    /// See Also
205    /// --------
206    /// laddu.mpi.use_mpi
207    ///
208    #[pyfunction]
209    pub fn get_size() -> i32 {
210        #[cfg(feature = "mpi")]
211        return laddu_core::mpi::get_size();
212        #[cfg(not(feature = "mpi"))]
213        return 1;
214    }
215}
216
217pub trait GetStrExtractObj {
218    fn get_extract<T>(&self, key: &str) -> PyResult<Option<T>>
219    where
220        T: for<'a, 'py> FromPyObject<'a, 'py, Error = PyErr>;
221}
222
223#[cfg_attr(coverage_nightly, coverage(off))]
224impl GetStrExtractObj for Bound<'_, PyDict> {
225    fn get_extract<T>(&self, key: &str) -> PyResult<Option<T>>
226    where
227        T: for<'a, 'py> FromPyObject<'a, 'py, Error = PyErr>,
228    {
229        self.get_item(key)?
230            .map(|value| value.extract::<T>())
231            .transpose()
232    }
233}