Skip to main content

laddu_extensions/likelihood/
nll.rs

1#[cfg(feature = "rayon")]
2use std::cell::RefCell;
3#[cfg(feature = "rayon")]
4use std::fmt::Debug;
5use std::{collections::HashMap, sync::Arc};
6
7use accurate::{sum::Klein, traits::*};
8use fastrand::Rng;
9#[cfg(feature = "mpi")]
10use laddu_core::mpi::LadduMPI;
11use laddu_core::{
12    amplitude::{CompiledExpression, Evaluator, Expression, ParameterMap},
13    data::Dataset,
14    validate_free_parameter_len, LadduError, LadduResult,
15};
16#[cfg(feature = "mpi")]
17use mpi::{
18    collective::SystemOperation, datatype::PartitionMut, topology::SimpleCommunicator, traits::*,
19};
20use nalgebra::DVector;
21use num::complex::Complex64;
22use parking_lot::Mutex;
23#[cfg(feature = "rayon")]
24use rayon::prelude::*;
25
26use super::term::LikelihoodTerm;
27use crate::random::RngSubsetExtension;
28
/// Key of the projection mask cache: the `strict` flag paired with the list of amplitude
/// names (sorted and de-duplicated when built through `NLL::projection_cache_key`).
29pub(crate) type ProjectionMaskCacheKey = (bool, Vec<String>);
30
31pub(crate) fn validate_stochastic_batch_size(
32    batch_size: usize,
33    n_events: usize,
34) -> LadduResult<()> {
35    if n_events == 0 {
36        return Err(LadduError::Custom(
37            "stochastic batch_size requires a non-empty dataset".to_string(),
38        ));
39    }
40    if batch_size == 0 || batch_size > n_events {
41        return Err(LadduError::LengthMismatch {
42            context: format!("stochastic batch_size (valid range: 1..={n_events})"),
43            expected: n_events,
44            actual: batch_size,
45        });
46    }
47    Ok(())
48}
49
/// Sum-reduces `value` over `world` with an all-reduce and returns the reduced value.
#[cfg(feature = "mpi")]
pub(crate) fn reduce_scalar(world: &SimpleCommunicator, value: f64) -> f64 {
    let mut total = 0.0_f64;
    world.all_reduce_into(&value, &mut total, SystemOperation::sum());
    total
}
56
/// Element-wise sum-reduces `gradient` over `world` with an all-reduce and returns the
/// reduced vector, which has the same length as the input.
#[cfg(feature = "mpi")]
pub(crate) fn reduce_gradient(world: &SimpleCommunicator, gradient: &DVector<f64>) -> DVector<f64> {
    let local = gradient.as_slice();
    let mut summed = vec![0.0_f64; local.len()];
    world.all_reduce_into(local, &mut summed, SystemOperation::sum());
    DVector::from_vec(summed)
}
63
64pub(crate) fn evaluate_weighted_expression_sum_local<F>(
65    evaluator: &Evaluator,
66    parameters: &[f64],
67    value_map: F,
68) -> LadduResult<f64>
69where
70    F: Fn(Complex64) -> f64 + Copy + Send + Sync,
71{
72    let resources = evaluator.resources.read();
73    let parameters = resources.parameter_map.assemble(parameters)?;
74    let amplitude_len = evaluator.amplitude_value_slot_count();
75    let active_indices = resources.active_indices().to_vec();
76    let program_snapshot = evaluator.expression_value_program_snapshot();
77    let slot_count = evaluator.expression_value_program_snapshot_slot_count(&program_snapshot);
78    #[cfg(feature = "rayon")]
79    {
80        Ok(resources
81            .caches
82            .par_iter()
83            .zip(evaluator.dataset.weights_local().par_iter())
84            .map_init(
85                || {
86                    (
87                        vec![Complex64::ZERO; amplitude_len],
88                        vec![Complex64::ZERO; slot_count],
89                    )
90                },
91                |(amplitude_values, expr_slots), (cache, event)| {
92                    evaluator.fill_amplitude_values(
93                        amplitude_values,
94                        &active_indices,
95                        &parameters,
96                        cache,
97                    );
98                    let l = evaluator.evaluate_expression_value_with_program_snapshot(
99                        &program_snapshot,
100                        amplitude_values,
101                        expr_slots,
102                    );
103                    *event * value_map(l)
104                },
105            )
106            .parallel_sum_with_accumulator::<Klein<f64>>())
107    }
108    #[cfg(not(feature = "rayon"))]
109    {
110        let mut amplitude_values = vec![Complex64::ZERO; amplitude_len];
111        let mut expr_slots = vec![Complex64::ZERO; slot_count];
112        Ok(resources
113            .caches
114            .iter()
115            .zip(evaluator.dataset.weights_local().iter())
116            .map(|(cache, event)| {
117                evaluator.fill_amplitude_values(
118                    &mut amplitude_values,
119                    &active_indices,
120                    &parameters,
121                    cache,
122                );
123                let l = evaluator.evaluate_expression_value_with_program_snapshot(
124                    &program_snapshot,
125                    &amplitude_values,
126                    &mut expr_slots,
127                );
128                *event * value_map(l)
129            })
130            .sum_with_accumulator::<Klein<f64>>())
131    }
132}
133
134pub(crate) fn project_weights_local_from_evaluator(
135    evaluator: &Evaluator,
136    parameters: &[f64],
137    n_mc: f64,
138) -> LadduResult<Vec<f64>> {
139    let resources = evaluator.resources.read();
140    let parameters = resources.parameter_map.assemble(parameters)?;
141    let amplitude_len = evaluator.amplitude_value_slot_count();
142    let active_indices = resources.active_indices().to_vec();
143    let program_snapshot = evaluator.expression_value_program_snapshot();
144    let slot_count = evaluator.expression_value_program_snapshot_slot_count(&program_snapshot);
145    #[cfg(feature = "rayon")]
146    {
147        Ok(resources
148            .caches
149            .par_iter()
150            .zip(evaluator.dataset.weights_local().par_iter())
151            .map_init(
152                || {
153                    (
154                        vec![Complex64::ZERO; amplitude_len],
155                        vec![Complex64::ZERO; slot_count],
156                    )
157                },
158                |(amplitude_values, expr_slots), (cache, event)| {
159                    evaluator.fill_amplitude_values(
160                        amplitude_values,
161                        &active_indices,
162                        &parameters,
163                        cache,
164                    );
165                    let value = evaluator.evaluate_expression_value_with_program_snapshot(
166                        &program_snapshot,
167                        amplitude_values,
168                        expr_slots,
169                    );
170                    *event * value.re / n_mc
171                },
172            )
173            .collect())
174    }
175    #[cfg(not(feature = "rayon"))]
176    {
177        let mut amplitude_values = vec![Complex64::ZERO; amplitude_len];
178        let mut expr_slots = vec![Complex64::ZERO; slot_count];
179        Ok(resources
180            .caches
181            .iter()
182            .zip(evaluator.dataset.weights_local().iter())
183            .map(|(cache, event)| {
184                evaluator.fill_amplitude_values(
185                    &mut amplitude_values,
186                    &active_indices,
187                    &parameters,
188                    cache,
189                );
190                let value = evaluator.evaluate_expression_value_with_program_snapshot(
191                    &program_snapshot,
192                    &amplitude_values,
193                    &mut expr_slots,
194                );
195                *event * value.re / n_mc
196            })
197            .collect())
198    }
199}
200
201pub(crate) fn project_weights_local_from_resolved_mask(
202    evaluator: &Evaluator,
203    parameters: &[f64],
204    n_mc: f64,
205    resolved_mask: &[bool],
206) -> LadduResult<Vec<f64>> {
207    let resources = evaluator.resources.read();
208    let parameters = resources.parameter_map.assemble(parameters)?;
209    let amplitude_len = evaluator.amplitude_value_slot_count();
210    let active_indices = resolved_mask
211        .iter()
212        .enumerate()
213        .filter_map(|(index, &active)| if active { Some(index) } else { None })
214        .collect::<Vec<_>>();
215    let program_snapshot =
216        evaluator.expression_value_program_snapshot_for_active_mask(resolved_mask)?;
217    let slot_count = evaluator.expression_value_program_snapshot_slot_count(&program_snapshot);
218    #[cfg(feature = "rayon")]
219    {
220        Ok(resources
221            .caches
222            .par_iter()
223            .zip(evaluator.dataset.weights_local().par_iter())
224            .map_init(
225                || {
226                    (
227                        vec![Complex64::ZERO; amplitude_len],
228                        vec![Complex64::ZERO; slot_count],
229                    )
230                },
231                |(amplitude_values, expr_slots), (cache, event)| {
232                    evaluator.fill_amplitude_values(
233                        amplitude_values,
234                        &active_indices,
235                        &parameters,
236                        cache,
237                    );
238                    let value = evaluator.evaluate_expression_value_with_program_snapshot(
239                        &program_snapshot,
240                        amplitude_values,
241                        expr_slots,
242                    );
243                    *event * value.re / n_mc
244                },
245            )
246            .collect())
247    }
248    #[cfg(not(feature = "rayon"))]
249    {
250        let mut amplitude_values = vec![Complex64::ZERO; amplitude_len];
251        let mut expr_slots = vec![Complex64::ZERO; slot_count];
252        Ok(resources
253            .caches
254            .iter()
255            .zip(evaluator.dataset.weights_local().iter())
256            .map(|(cache, event)| {
257                evaluator.fill_amplitude_values(
258                    &mut amplitude_values,
259                    &active_indices,
260                    &parameters,
261                    cache,
262                );
263                let value = evaluator.evaluate_expression_value_with_program_snapshot(
264                    &program_snapshot,
265                    &amplitude_values,
266                    &mut expr_slots,
267                );
268                *event * value.re / n_mc
269            })
270            .collect())
271    }
272}
273
274pub(crate) fn project_weights_and_gradients_local_from_evaluator(
275    evaluator: &Evaluator,
276    parameters: &[f64],
277    n_mc: f64,
278) -> LadduResult<(Vec<f64>, Vec<DVector<f64>>)> {
279    let resources = evaluator.resources.read();
280    let parameters = resources.parameter_map.assemble(parameters)?;
281    let amplitude_len = evaluator.amplitude_value_slot_count();
282    let grad_dim = parameters.len();
283    let active_indices = resources.active_indices().to_vec();
284    let active_mask = resources.active.clone();
285    let slot_count = evaluator.expression_value_gradient_slot_count_public();
286
287    #[cfg(feature = "rayon")]
288    {
289        let weighted = resources
290            .caches
291            .par_iter()
292            .zip(evaluator.dataset.weights_local().par_iter())
293            .map_init(
294                || {
295                    (
296                        vec![Complex64::ZERO; amplitude_len],
297                        vec![DVector::zeros(grad_dim); amplitude_len],
298                        vec![Complex64::ZERO; slot_count],
299                        vec![DVector::zeros(grad_dim); slot_count],
300                    )
301                },
302                |(amplitude_values, gradient_values, value_slots, gradient_slots),
303                 (cache, event)| {
304                    evaluator.fill_amplitude_values_and_gradients(
305                        amplitude_values,
306                        gradient_values,
307                        &active_indices,
308                        &active_mask,
309                        &parameters,
310                        cache,
311                    );
312                    let (value, gradient) = evaluator
313                        .evaluate_expression_value_gradient_with_scratch(
314                            amplitude_values,
315                            gradient_values,
316                            value_slots,
317                            gradient_slots,
318                        );
319                    (
320                        *event * value.re / n_mc,
321                        gradient.map(|g| g.re).scale(*event / n_mc),
322                    )
323                },
324            )
325            .collect::<Vec<_>>();
326        Ok(weighted.into_iter().unzip())
327    }
328    #[cfg(not(feature = "rayon"))]
329    {
330        let mut amplitude_values = vec![Complex64::ZERO; amplitude_len];
331        let mut gradient_values = vec![DVector::zeros(grad_dim); amplitude_len];
332        let mut value_slots = vec![Complex64::ZERO; slot_count];
333        let mut gradient_slots = vec![DVector::zeros(grad_dim); slot_count];
334        Ok(resources
335            .caches
336            .iter()
337            .zip(evaluator.dataset.weights_local().iter())
338            .map(|(cache, event)| {
339                evaluator.fill_amplitude_values_and_gradients(
340                    &mut amplitude_values,
341                    &mut gradient_values,
342                    &active_indices,
343                    &active_mask,
344                    &parameters,
345                    cache,
346                );
347                let (value, gradient) = evaluator.evaluate_expression_value_gradient_with_scratch(
348                    &amplitude_values,
349                    &gradient_values,
350                    &mut value_slots,
351                    &mut gradient_slots,
352                );
353                (
354                    *event * value.re / n_mc,
355                    gradient.map(|g| g.re).scale(*event / n_mc),
356                )
357            })
358            .unzip())
359    }
360}
361
/// Adds up all vectors produced by a parallel iterator.
///
/// `len` is the length of the zero vector used as the reduction identity, so it must equal
/// the length of every item; an empty iterator yields a zero vector of that length.
#[cfg(feature = "rayon")]
pub(crate) fn sum_dvectors_parallel(
    iter: impl rayon::iter::ParallelIterator<Item = DVector<f64>>,
    len: usize,
) -> DVector<f64> {
    let zero = || DVector::zeros(len);
    let add = |total: DVector<f64>, term: DVector<f64>| total + term;
    iter.reduce(zero, add)
}
375
/// Sizes that determine the layout of a [`GradientScratchWorkspace`]; used as the key of the
/// thread-local scratch pool.
376#[cfg(feature = "rayon")]
377#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq)]
378pub(crate) struct GradientScratchKey {
    /// Length of every gradient vector held by the workspace.
379    n_parameters: usize,
    /// Number of entries in the amplitude value and amplitude gradient buffers.
380    n_amplitudes: usize,
    /// Number of entries in the expression value and expression gradient slot buffers.
381    n_expression_slots: usize,
382}
383
/// Set of scratch buffers whose sizes are fixed by a [`GradientScratchKey`].
384#[cfg(feature = "rayon")]
385pub(crate) struct GradientScratchWorkspace {
    /// `n_amplitudes` complex values.
386    amplitude_values: Vec<Complex64>,
    /// `n_amplitudes` complex vectors, each of length `n_parameters`.
387    gradient_values: Vec<DVector<Complex64>>,
    /// `n_expression_slots` complex values.
388    value_slots: Vec<Complex64>,
    /// `n_expression_slots` complex vectors, each of length `n_parameters`.
389    gradient_slots: Vec<DVector<Complex64>>,
390}
391
#[cfg(feature = "rayon")]
impl GradientScratchWorkspace {
    /// Allocates zero-filled buffers with the sizes described by `key`.
    fn new(key: GradientScratchKey) -> Self {
        let GradientScratchKey {
            n_parameters,
            n_amplitudes,
            n_expression_slots,
        } = key;
        Self {
            amplitude_values: vec![Complex64::ZERO; n_amplitudes],
            gradient_values: vec![DVector::zeros(n_parameters); n_amplitudes],
            value_slots: vec![Complex64::ZERO; n_expression_slots],
            gradient_slots: vec![DVector::zeros(n_parameters); n_expression_slots],
        }
    }

    /// Returns `true` when every buffer, and every gradient vector inside the buffers, has
    /// exactly the size described by `key`.
    fn matches_key(&self, key: GradientScratchKey) -> bool {
        let has_parameter_len = |vectors: &[DVector<Complex64>]| {
            vectors.iter().all(|vector| vector.len() == key.n_parameters)
        };
        self.amplitude_values.len() == key.n_amplitudes
            && self.gradient_values.len() == key.n_amplitudes
            && self.value_slots.len() == key.n_expression_slots
            && self.gradient_slots.len() == key.n_expression_slots
            && has_parameter_len(&self.gradient_values)
            && has_parameter_len(&self.gradient_slots)
    }
}
418
/// A workspace checked out of the thread-local scratch pool. Dropping the lease puts the
/// workspace back into the pool of the dropping thread under `key`.
419#[cfg(feature = "rayon")]
420pub(crate) struct GradientScratchLease {
    /// Pool key the workspace is stored under when the lease is dropped.
421    key: GradientScratchKey,
    /// The leased workspace; it is taken out (leaving `None`) when the lease is dropped.
422    workspace: Option<GradientScratchWorkspace>,
423}
424
#[cfg(feature = "rayon")]
impl GradientScratchLease {
    /// Gives mutable access to the leased workspace.
    ///
    /// # Panics
    ///
    /// Panics if the workspace has already been taken out of the lease.
    fn workspace_mut(&mut self) -> &mut GradientScratchWorkspace {
        match self.workspace.as_mut() {
            Some(workspace) => workspace,
            None => panic!("gradient scratch workspace must be available while leased"),
        }
    }
}
433
#[cfg(feature = "rayon")]
impl Drop for GradientScratchLease {
    /// Hands the workspace back to the thread-local pool of the current thread, replacing
    /// any workspace already stored under the same key.
    fn drop(&mut self) {
        let key = self.key;
        if let Some(returned) = self.workspace.take() {
            TLS_GRADIENT_SCRATCH_POOL.with(|pool| {
                let mut pool = pool.borrow_mut();
                pool.insert(key, returned);
            });
        }
    }
}
444
/// Leases a scratch workspace sized for `key` from the current thread's pool.
///
/// A pooled workspace stored under `key` is reused when its buffers still match the key;
/// otherwise (nothing pooled, or a size mismatch) a fresh zero-filled workspace is
/// allocated. The pooled entry is removed from the pool for the duration of the lease.
#[cfg(feature = "rayon")]
pub(crate) fn acquire_gradient_scratch(key: GradientScratchKey) -> GradientScratchLease {
    let pooled = TLS_GRADIENT_SCRATCH_POOL.with(|pool| pool.borrow_mut().remove(&key));
    let workspace = match pooled {
        Some(candidate) if candidate.matches_key(key) => candidate,
        _ => GradientScratchWorkspace::new(key),
    };
    GradientScratchLease {
        key,
        workspace: Some(workspace),
    }
}
460
// Per-thread pool of gradient scratch workspaces, keyed by their sizes. Entries are removed
// by `acquire_gradient_scratch` and re-inserted when a `GradientScratchLease` is dropped.
461#[cfg(feature = "rayon")]
462thread_local! {
463    static TLS_GRADIENT_SCRATCH_POOL: RefCell<HashMap<GradientScratchKey, GradientScratchWorkspace>> =
464        RefCell::new(HashMap::new());
465}
466
467/// An extended, unbinned negative log-likelihood evaluator.
468#[derive(Clone)]
469pub struct NLL {
470    /// The internal [`Evaluator`] for data
471    pub data_evaluator: Evaluator,
472    /// The internal [`Evaluator`] for accepted Monte Carlo
473    pub accmc_evaluator: Evaluator,
    /// Divisor applied to projected Monte Carlo weights (`weight * value / n_mc`).
474    pub(crate) n_mc: f64,
    /// Cache of resolved projection masks keyed by [`ProjectionMaskCacheKey`]. The map sits
    /// behind an `Arc`, so clones of this [`NLL`] share the same cache.
475    pub(crate) projection_active_mask_cache: Arc<Mutex<HashMap<ProjectionMaskCacheKey, Vec<bool>>>>,
476}
477
478impl NLL {
479    /// Construct an [`NLL`] from an [`Expression`] and two [`Dataset`]s (data and Monte Carlo). This mirrors loading a model but starts from
480    /// the expression directly. The number of Monte Carlo events used in the denominator of the
481    /// normalization integral may also be specified (uses the weighted number of accepted Monte
482    /// Carlo events if None is given).
483    pub fn new(
484        expression: &Expression,
485        ds_data: &Arc<Dataset>,
486        ds_accmc: &Arc<Dataset>,
487        n_mc: Option<f64>,
488    ) -> LadduResult<Box<Self>> {
489        let data_evaluator = expression.load(ds_data)?;
490        let accmc_evaluator = expression.load(ds_accmc)?;
491        Ok(Self {
492            data_evaluator,
493            n_mc: n_mc.unwrap_or(accmc_evaluator.dataset.n_events_weighted()),
494            accmc_evaluator,
495            projection_active_mask_cache: Arc::new(Mutex::new(HashMap::new())),
496        }
497        .into())
498    }
499
/// Converts `names` into owned strings, sorted ascending with duplicates removed, so that
/// the same set of names always produces the same key regardless of order.
fn normalized_projection_key<T: AsRef<str>>(names: &[T]) -> Vec<String> {
    let mut normalized: Vec<String> = Vec::with_capacity(names.len());
    normalized.extend(names.iter().map(|name| name.as_ref().to_owned()));
    normalized.sort_unstable();
    // After sorting, equal names are adjacent, so `dedup` removes every duplicate.
    normalized.dedup();
    normalized
}
509
510    fn projection_cache_key<T: AsRef<str>>(names: &[T], strict: bool) -> ProjectionMaskCacheKey {
511        (strict, Self::normalized_projection_key(names))
512    }
513
514    fn resolve_projection_active_mask_for_evaluator<T: AsRef<str>>(
515        evaluator: &Evaluator,
516        names: &[T],
517        strict: bool,
518    ) -> LadduResult<Vec<bool>> {
519        let current_active_mask = evaluator.active_mask();
520        let isolate_result = if strict {
521            evaluator.isolate_many_strict(names)
522        } else {
523            evaluator.isolate_many(names);
524            Ok(())
525        };
526        if let Err(err) = isolate_result {
527            evaluator.set_active_mask(&current_active_mask)?;
528            return Err(err);
529        }
530        let resolved_mask = evaluator.active_mask();
531        evaluator.set_active_mask(&current_active_mask)?;
532        Ok(resolved_mask)
533    }
534
535    fn get_or_build_projection_active_mask<T: AsRef<str>>(
536        &self,
537        names: &[T],
538        strict: bool,
539    ) -> LadduResult<Vec<bool>> {
540        let key = Self::projection_cache_key(names, strict);
541        if let Some(mask) = self.projection_active_mask_cache.lock().get(&key).cloned() {
542            return Ok(mask);
543        }
544
545        let resolved_mask = Self::resolve_projection_active_mask_for_evaluator(
546            &self.accmc_evaluator,
547            names,
548            strict,
549        )?;
550        self.projection_active_mask_cache
551            .lock()
552            .insert(key, resolved_mask.clone());
553        Ok(resolved_mask)
554    }
555
556    fn invalidate_projection_mask_cache(&self) {
557        self.projection_active_mask_cache.lock().clear();
558    }
559
560    /// The parameters for this NLL, as reported by the data evaluator.
561    pub fn parameters(&self) -> ParameterMap {
562        self.data_evaluator.parameters()
563    }
564
565    /// Number of free parameters, as reported by the data evaluator.
566    pub fn n_free(&self) -> usize {
567        self.data_evaluator.n_free()
568    }
569
570    /// Number of fixed parameters, as reported by the data evaluator.
571    pub fn n_fixed(&self) -> usize {
572        self.data_evaluator.n_fixed()
573    }
574
575    /// Total number of parameters, as reported by the data evaluator.
576    pub fn n_parameters(&self) -> usize {
577        self.data_evaluator.n_parameters()
578    }
579
580    /// Returns the expression represented by this NLL, obtained from the data evaluator.
581    pub fn expression(&self) -> Expression {
582        self.data_evaluator.expression()
583    }
584
585    /// Returns a tree-like diagnostic snapshot of the compiled expression for this NLL's current
586    /// active-amplitude mask.
    ///
    /// The snapshot is obtained from the data evaluator.
587    pub fn compiled_expression(&self) -> CompiledExpression {
588        self.data_evaluator.compiled_expression()
589    }
590
591    /// Create a new [`StochasticNLL`] from a clone of this [`NLL`].
    ///
    /// `batch_size` and `seed` are forwarded unchanged to [`StochasticNLL::new`], and its
    /// result (including any error) is returned as-is.
592    pub fn to_stochastic(
593        &self,
594        batch_size: usize,
595        seed: Option<usize>,
596    ) -> LadduResult<StochasticNLL> {
597        StochasticNLL::new(self.clone(), batch_size, seed)
598    }
599    /// Activate an [`Amplitude`](`laddu_core::amplitude::Amplitude`) by tag, skipping missing entries.
600    pub fn activate<T: AsRef<str>>(&self, name: T) {
601        self.invalidate_projection_mask_cache();
602        self.data_evaluator.activate(&name);
603        self.accmc_evaluator.activate(name);
604    }
605    /// Activate an [`Amplitude`](`laddu_core::amplitude::Amplitude`) by tag and return an error if it is missing.
606    pub fn activate_strict<T: AsRef<str>>(&self, name: T) -> LadduResult<()> {
607        self.invalidate_projection_mask_cache();
608        self.data_evaluator.activate_strict(&name)?;
609        self.accmc_evaluator.activate_strict(name)?;
610        Ok(())
611    }
612    /// Activate several [`Amplitude`](`laddu_core::amplitude::Amplitude`)s by tag, skipping missing entries.
613    pub fn activate_many<T: AsRef<str>>(&self, names: &[T]) {
614        self.invalidate_projection_mask_cache();
615        self.data_evaluator.activate_many(names);
616        self.accmc_evaluator.activate_many(names);
617    }
618    /// Activate several [`Amplitude`](`laddu_core::amplitude::Amplitude`)s by tag and return an error if any are missing.
619    pub fn activate_many_strict<T: AsRef<str>>(&self, names: &[T]) -> LadduResult<()> {
620        self.invalidate_projection_mask_cache();
621        self.data_evaluator.activate_many_strict(names)?;
622        self.accmc_evaluator.activate_many_strict(names)?;
623        Ok(())
624    }
625    /// Activate all registered [`Amplitude`](`laddu_core::amplitude::Amplitude`)s.
626    pub fn activate_all(&self) {
627        self.invalidate_projection_mask_cache();
628        self.data_evaluator.activate_all();
629        self.accmc_evaluator.activate_all();
630    }
631    /// Deactivate an [`Amplitude`](`laddu_core::amplitude::Amplitude`) by tag, skipping missing entries.
632    pub fn deactivate<T: AsRef<str>>(&self, name: T) {
633        self.invalidate_projection_mask_cache();
634        self.data_evaluator.deactivate(&name);
635        self.accmc_evaluator.deactivate(name);
636    }
637    /// Deactivate an [`Amplitude`](`laddu_core::amplitude::Amplitude`) by tag and return an error if it is missing.
638    pub fn deactivate_strict<T: AsRef<str>>(&self, name: T) -> LadduResult<()> {
639        self.invalidate_projection_mask_cache();
640        self.data_evaluator.deactivate_strict(&name)?;
641        self.accmc_evaluator.deactivate_strict(name)?;
642        Ok(())
643    }
644    /// Deactivate several [`Amplitude`](`laddu_core::amplitude::Amplitude`)s by tag, skipping missing entries.
645    pub fn deactivate_many<T: AsRef<str>>(&self, names: &[T]) {
646        self.invalidate_projection_mask_cache();
647        self.data_evaluator.deactivate_many(names);
648        self.accmc_evaluator.deactivate_many(names);
649    }
650    /// Deactivate several [`Amplitude`](`laddu_core::amplitude::Amplitude`)s by tag and return an error if any are missing.
651    pub fn deactivate_many_strict<T: AsRef<str>>(&self, names: &[T]) -> LadduResult<()> {
652        self.invalidate_projection_mask_cache();
653        self.data_evaluator.deactivate_many_strict(names)?;
654        self.accmc_evaluator.deactivate_many_strict(names)?;
655        Ok(())
656    }
657    /// Deactivate all registered [`Amplitude`](`laddu_core::amplitude::Amplitude`)s.
658    pub fn deactivate_all(&self) {
659        self.invalidate_projection_mask_cache();
660        self.data_evaluator.deactivate_all();
661        self.accmc_evaluator.deactivate_all();
662    }
663    /// Isolate an [`Amplitude`](`laddu_core::amplitude::Amplitude`) by tag (deactivate the rest), skipping missing entries.
664    pub fn isolate<T: AsRef<str>>(&self, name: T) {
665        self.invalidate_projection_mask_cache();
666        self.data_evaluator.isolate(&name);
667        self.accmc_evaluator.isolate(name);
668    }
669    /// Isolate an [`Amplitude`](`laddu_core::amplitude::Amplitude`) by tag (deactivate the rest) and return an error if it is missing.
670    pub fn isolate_strict<T: AsRef<str>>(&self, name: T) -> LadduResult<()> {
671        self.invalidate_projection_mask_cache();
672        self.data_evaluator.isolate_strict(&name)?;
673        self.accmc_evaluator.isolate_strict(name)?;
674        Ok(())
675    }
676    /// Isolate several [`Amplitude`](`laddu_core::amplitude::Amplitude`)s by tag (deactivate the rest), skipping missing entries.
677    pub fn isolate_many<T: AsRef<str>>(&self, names: &[T]) {
678        self.invalidate_projection_mask_cache();
679        self.data_evaluator.isolate_many(names);
680        self.accmc_evaluator.isolate_many(names);
681    }
682    /// Isolate several [`Amplitude`](`laddu_core::amplitude::Amplitude`)s by tag (deactivate the rest) and return an error if any are missing.
683    pub fn isolate_many_strict<T: AsRef<str>>(&self, names: &[T]) -> LadduResult<()> {
684        self.invalidate_projection_mask_cache();
685        self.data_evaluator.isolate_many_strict(names)?;
686        self.accmc_evaluator.isolate_many_strict(names)?;
687        Ok(())
688    }
689
690    /// Project the stored [`Expression`] over the events in the [`Dataset`] stored by the
691    /// [`Evaluator`] with the given values for free parameters to obtain weights for each
692    /// Monte-Carlo event (non-MPI version).
693    ///
694    /// # Notes
695    ///
696    /// This method is not intended to be called in analyses but rather in writing methods
697    /// that have `mpi`-feature-gated versions. Most users will want to call [`NLL::project_weights`] instead.
698    pub fn project_weights_local(
699        &self,
700        parameters: &[f64],
701        mc_evaluator: Option<Evaluator>,
702    ) -> LadduResult<Vec<f64>> {
703        validate_free_parameter_len(parameters.len(), self.n_free())?;
704        if let Some(mc_evaluator) = mc_evaluator {
705            project_weights_local_from_evaluator(&mc_evaluator, parameters, self.n_mc)
706        } else {
707            project_weights_local_from_evaluator(&self.accmc_evaluator, parameters, self.n_mc)
708        }
709    }
710
/// Project the stored [`Expression`] over the events of a Monte Carlo [`Dataset`] with the
/// given values for free parameters to obtain a weight for each event (MPI-compatible
/// version).
///
/// The local projection is computed first and the per-rank results are then gathered into
/// one buffer with an entry for each of the dataset's `n_events()` events.
///
/// # Notes
///
/// This method is not intended to be called in analyses but rather in writing methods
/// that have `mpi`-feature-gated versions. Most users will want to call [`NLL::project_weights`] instead.
///
/// # Errors
///
/// Propagates any error from the local projection.
#[cfg(feature = "mpi")]
pub fn project_weights_mpi(
    &self,
    parameters: &[f64],
    mc_evaluator: Option<Evaluator>,
    world: &SimpleCommunicator,
) -> LadduResult<Vec<f64>> {
    // Read the event count before `mc_evaluator` is moved into the local projection.
    let n_events = match mc_evaluator.as_ref() {
        Some(evaluator) => evaluator.dataset.n_events(),
        None => self.accmc_evaluator.dataset.n_events(),
    };
    let local_weights = self.project_weights_local(parameters, mc_evaluator)?;
    let mut gathered: Vec<f64> = vec![0.0; n_events];
    let (counts, displs) = world.get_counts_displs(n_events);
    {
        // NOTE: gather is required because projection returns per-event global outputs.
        // Use all-reduce only for aggregate scalar/vector reductions.
        let mut partitions = PartitionMut::new(&mut gathered, counts, displs);
        world.all_gather_varcount_into(&local_weights, &mut partitions);
    }
    Ok(gathered)
}
742
743    /// Project the stored [`Expression`] over the events in the [`Dataset`] stored by the
744    /// [`Evaluator`] with the given values for free parameters to obtain weights for each
745    /// Monte-Carlo event. This method takes the real part of the given expression (discarding
746    /// the imaginary part entirely, which does not matter if expressions are coherent sums
747    /// wrapped in [`Expression::norm_sqr`](`laddu_core::Expression::norm_sqr`).
748    /// Event weights are determined by the following formula:
749    ///
750    /// ```math
751    /// \text{weight}(\vec{p}; e) = \text{weight}(e) \mathcal{L}(e) / N_{\text{MC}}
752    /// ```
753    ///
754    /// Note that $`N_{\text{MC}}`$ will always be the number of accepted Monte Carlo events,
755    /// regardless of the `mc_evaluator`.
756    pub fn project_weights(
757        &self,
758        parameters: &[f64],
759        mc_evaluator: Option<Evaluator>,
760    ) -> LadduResult<Vec<f64>> {
761        #[cfg(feature = "mpi")]
762        {
763            if let Some(world) = laddu_core::mpi::get_world() {
764                return self.project_weights_mpi(parameters, mc_evaluator, &world);
765            }
766        }
767        self.project_weights_local(parameters, mc_evaluator)
768    }
769
770    /// Project the stored [`Expression`] over the events in the [`Dataset`] stored by the
771    /// [`Evaluator`] with the given values for free parameters to obtain weights and gradients of
772    /// those weights for each Monte-Carlo event (non-MPI version).
773    ///
774    /// # Notes
775    ///
776    /// This method is not intended to be called in analyses but rather in writing methods
777    /// that have `mpi`-feature-gated versions. Most users will want to call [`NLL::project_weights_and_gradients`] instead.
778    pub fn project_weights_and_gradients_local(
779        &self,
780        parameters: &[f64],
781        mc_evaluator: Option<Evaluator>,
782    ) -> LadduResult<(Vec<f64>, Vec<DVector<f64>>)> {
783        validate_free_parameter_len(parameters.len(), self.n_free())?;
784        if let Some(mc_evaluator) = mc_evaluator {
785            project_weights_and_gradients_local_from_evaluator(&mc_evaluator, parameters, self.n_mc)
786        } else {
787            project_weights_and_gradients_local_from_evaluator(
788                &self.accmc_evaluator,
789                parameters,
790                self.n_mc,
791            )
792        }
793    }
794
795    /// Project the stored [`Expression`] over the events in the [`Dataset`] stored by the
796    /// [`Evaluator`] with the given values for free parameters to obtain weights and gradients of
797    /// those weights for each Monte-Carlo event (MPI-compatible version).
798    ///
799    /// # Notes
800    ///
801    /// This method is not intended to be called in analyses but rather in writing methods
802    /// that have `mpi`-feature-gated versions. Most users will want to call [`NLL::project_weights_and_gradients`] instead.
803    #[cfg(feature = "mpi")]
804    pub fn project_weights_and_gradients_mpi(
805        &self,
806        parameters: &[f64],
807        mc_evaluator: Option<Evaluator>,
808        world: &SimpleCommunicator,
809    ) -> LadduResult<(Vec<f64>, Vec<DVector<f64>>)> {
810        let n_events = mc_evaluator
811            .as_ref()
812            .unwrap_or(&self.accmc_evaluator)
813            .dataset
814            .n_events();
815        let (local_projection, local_gradient_projection) =
816            self.project_weights_and_gradients_local(parameters, mc_evaluator)?;
817        let mut projection_result: Vec<f64> = vec![0.0; n_events];
818        let (counts, displs) = world.get_counts_displs(n_events);
819        {
820            // NOTE: gather is required because projection-gradient returns per-event global outputs.
821            let mut partitioned_buffer = PartitionMut::new(&mut projection_result, counts, displs);
822            world.all_gather_varcount_into(&local_projection, &mut partitioned_buffer);
823        }
824
825        let flattened_local_gradient_projection = local_gradient_projection
826            .iter()
827            .flat_map(|g| g.data.as_vec().to_vec())
828            .collect::<Vec<f64>>();
829        let (counts, displs) = world.get_flattened_counts_displs(n_events, parameters.len());
830        let mut flattened_result_buffer = vec![0.0; n_events * parameters.len()];
831        let mut partitioned_flattened_result_buffer =
832            PartitionMut::new(&mut flattened_result_buffer, counts, displs);
833        // NOTE: gather is required because projection-gradient returns full per-event gradients.
834        world.all_gather_varcount_into(
835            &flattened_local_gradient_projection,
836            &mut partitioned_flattened_result_buffer,
837        );
838        let gradient_projection_result = flattened_result_buffer
839            .chunks(parameters.len())
840            .map(DVector::from_row_slice)
841            .collect();
842        Ok((projection_result, gradient_projection_result))
843    }
844    /// Project the stored [`Expression`] over the events in the [`Dataset`] stored by the
845    /// [`Evaluator`] with the given values for free parameters to obtain weights and gradients of
846    /// those weights for each Monte-Carlo event. This method takes the real part of the given
847    /// expression (discarding the imaginary part entirely, which does not matter if expressions
848    /// are coherent sums wrapped in [`Expression::norm_sqr`](`laddu_core::Expression::norm_sqr`).
849    /// Event weights are determined by the following formula:
850    ///
851    /// ```math
852    /// \text{weight}(\vec{p}; e) = \text{weight}(e) \mathcal{L}(e) / N_{\text{MC}}
853    /// ```
854    ///
855    /// Note that $`N_{\text{MC}}`$ will always be the number of accepted Monte Carlo events,
856    /// regardless of the `mc_evaluator`.
857    pub fn project_weights_and_gradients(
858        &self,
859        parameters: &[f64],
860        mc_evaluator: Option<Evaluator>,
861    ) -> LadduResult<(Vec<f64>, Vec<DVector<f64>>)> {
862        #[cfg(feature = "mpi")]
863        {
864            if let Some(world) = laddu_core::mpi::get_world() {
865                return self.project_weights_and_gradients_mpi(parameters, mc_evaluator, &world);
866            }
867        }
868        self.project_weights_and_gradients_local(parameters, mc_evaluator)
869    }
870
871    /// Project the stored [`Expression`] over the events in the [`Dataset`] stored by the
872    /// [`Evaluator`] with the given values for free parameters to obtain weights for each Monte-Carlo event. This method differs from the standard
873    /// [`NLL::project_weights`] in that it first isolates the selected [`Amplitude`](`laddu_core::amplitude::Amplitude`)s
874    /// by tag, but returns the [`NLL`] to its prior state after calculation (non-MPI version).
875    ///
876    /// # Notes
877    ///
878    /// This method is not intended to be called in analyses but rather in writing methods
879    /// that have `mpi`-feature-gated versions. Most users will want to call [`NLL::project_weights_subset`] instead.
880    fn project_weights_subset_local_with_strict<T: AsRef<str>>(
881        &self,
882        parameters: &[f64],
883        names: &[T],
884        mc_evaluator: Option<Evaluator>,
885        strict: bool,
886    ) -> LadduResult<Vec<f64>> {
887        validate_free_parameter_len(parameters.len(), self.n_free())?;
888        if let Some(mc_evaluator) = mc_evaluator.as_ref() {
889            let resolved_mask =
890                Self::resolve_projection_active_mask_for_evaluator(mc_evaluator, names, strict)?;
891            project_weights_local_from_resolved_mask(
892                mc_evaluator,
893                parameters,
894                self.n_mc,
895                &resolved_mask,
896            )
897        } else {
898            let resolved_mask = self.get_or_build_projection_active_mask(names, strict)?;
899            project_weights_local_from_resolved_mask(
900                &self.accmc_evaluator,
901                parameters,
902                self.n_mc,
903                &resolved_mask,
904            )
905        }
906    }
907
908    /// Project the model over one isolated amplitude subset in local execution, skipping
909    /// missing amplitude tags.
910    pub fn project_weights_subset_local<T: AsRef<str>>(
911        &self,
912        parameters: &[f64],
913        names: &[T],
914        mc_evaluator: Option<Evaluator>,
915    ) -> LadduResult<Vec<f64>> {
916        self.project_weights_subset_local_with_strict(parameters, names, mc_evaluator, false)
917    }
918
919    /// Project the model over one isolated amplitude subset in local execution and return
920    /// an error if any requested amplitude tag is missing.
921    pub fn project_weights_subset_local_strict<T: AsRef<str>>(
922        &self,
923        parameters: &[f64],
924        names: &[T],
925        mc_evaluator: Option<Evaluator>,
926    ) -> LadduResult<Vec<f64>> {
927        self.project_weights_subset_local_with_strict(parameters, names, mc_evaluator, true)
928    }
929
930    /// Project the stored [`Expression`] over the events in the [`Dataset`] stored by the
931    /// [`Evaluator`] with the given values for free parameters to obtain weights for each Monte-Carlo event. This method differs from the standard
932    /// [`NLL::project_weights`] in that it first isolates the selected [`Amplitude`](`laddu_core::amplitude::Amplitude`)s
933    /// by tag, but returns the [`NLL`] to its prior state after calculation (MPI-compatible version).
934    ///
935    /// # Notes
936    ///
937    /// This method is not intended to be called in analyses but rather in writing methods
938    /// that have `mpi`-feature-gated versions. Most users will want to call [`NLL::project_weights_subset`] instead.
939    #[cfg(feature = "mpi")]
940    fn project_weights_subset_mpi_with_strict<T: AsRef<str>>(
941        &self,
942        parameters: &[f64],
943        names: &[T],
944        mc_evaluator: Option<Evaluator>,
945        world: &SimpleCommunicator,
946        strict: bool,
947    ) -> LadduResult<Vec<f64>> {
948        let n_events = mc_evaluator
949            .as_ref()
950            .unwrap_or(&self.accmc_evaluator)
951            .dataset
952            .n_events();
953        let local_projection =
954            self.project_weights_subset_local_with_strict(parameters, names, mc_evaluator, strict)?;
955        let mut buffer: Vec<f64> = vec![0.0; n_events];
956        let (counts, displs) = world.get_counts_displs(n_events);
957        {
958            // NOTE: gather is required because projection returns per-event global outputs.
959            let mut partitioned_buffer = PartitionMut::new(&mut buffer, counts, displs);
960            world.all_gather_varcount_into(&local_projection, &mut partitioned_buffer);
961        }
962        Ok(buffer)
963    }
964
965    #[cfg(feature = "mpi")]
966    /// Project the model over one isolated amplitude subset in MPI execution, skipping
967    /// missing amplitude tags.
968    pub fn project_weights_subset_mpi<T: AsRef<str>>(
969        &self,
970        parameters: &[f64],
971        names: &[T],
972        mc_evaluator: Option<Evaluator>,
973        world: &SimpleCommunicator,
974    ) -> LadduResult<Vec<f64>> {
975        self.project_weights_subset_mpi_with_strict(parameters, names, mc_evaluator, world, false)
976    }
977
978    #[cfg(feature = "mpi")]
979    /// Project the model over one isolated amplitude subset in MPI execution and return
980    /// an error if any requested amplitude tag is missing.
981    pub fn project_weights_subset_mpi_strict<T: AsRef<str>>(
982        &self,
983        parameters: &[f64],
984        names: &[T],
985        mc_evaluator: Option<Evaluator>,
986        world: &SimpleCommunicator,
987    ) -> LadduResult<Vec<f64>> {
988        self.project_weights_subset_mpi_with_strict(parameters, names, mc_evaluator, world, true)
989    }
990
991    /// Project the stored [`Expression`] over the events in the [`Dataset`] stored by the
992    /// [`Evaluator`] with the given values for free parameters to obtain weights for each Monte-Carlo event. This method differs from the standard
993    /// [`NLL::project_weights`] in that it first isolates the selected [`Amplitude`](`laddu_core::amplitude::Amplitude`)s
994    /// by tag, but returns the [`NLL`] to its prior state after calculation.
995    ///
996    /// This method takes the real part of the given expression (discarding
997    /// the imaginary part entirely, which does not matter if expressions are coherent sums
998    /// wrapped in [`Expression::norm_sqr`](`laddu_core::Expression::norm_sqr`).
999    /// Event weights are determined by the following formula:
1000    ///
1001    /// ```math
1002    /// \text{weight}(\vec{p}; e) = \text{weight}(e) \mathcal{L}(e) / N_{\text{MC}}
1003    /// ```
1004    ///
1005    /// Note that $`N_{\text{MC}}`$ will always be the number of accepted Monte Carlo events,
1006    /// regardless of the `mc_evaluator`.
1007    fn project_weights_subset_with_strict<T: AsRef<str>>(
1008        &self,
1009        parameters: &[f64],
1010        names: &[T],
1011        mc_evaluator: Option<Evaluator>,
1012        strict: bool,
1013    ) -> LadduResult<Vec<f64>> {
1014        #[cfg(feature = "mpi")]
1015        {
1016            if let Some(world) = laddu_core::mpi::get_world() {
1017                return self.project_weights_subset_mpi_with_strict(
1018                    parameters,
1019                    names,
1020                    mc_evaluator,
1021                    &world,
1022                    strict,
1023                );
1024            }
1025        }
1026        self.project_weights_subset_local_with_strict(parameters, names, mc_evaluator, strict)
1027    }
1028
1029    /// Project the model over one isolated amplitude subset, skipping missing amplitude
1030    /// names.
1031    pub fn project_weights_subset<T: AsRef<str>>(
1032        &self,
1033        parameters: &[f64],
1034        names: &[T],
1035        mc_evaluator: Option<Evaluator>,
1036    ) -> LadduResult<Vec<f64>> {
1037        self.project_weights_subset_with_strict(parameters, names, mc_evaluator, false)
1038    }
1039
1040    /// Project the model over one isolated amplitude subset and return an error if any
1041    /// requested amplitude tag is missing.
1042    pub fn project_weights_subset_strict<T: AsRef<str>>(
1043        &self,
1044        parameters: &[f64],
1045        names: &[T],
1046        mc_evaluator: Option<Evaluator>,
1047    ) -> LadduResult<Vec<f64>> {
1048        self.project_weights_subset_with_strict(parameters, names, mc_evaluator, true)
1049    }
1050
1051    /// Project the stored model over multiple isolated amplitude subsets (non-MPI version).
1052    fn project_weights_subsets_local_with_strict<T: AsRef<str>>(
1053        &self,
1054        parameters: &[f64],
1055        subsets: &[Vec<T>],
1056        mc_evaluator: Option<Evaluator>,
1057        strict: bool,
1058    ) -> LadduResult<Vec<Vec<f64>>> {
1059        validate_free_parameter_len(parameters.len(), self.n_free())?;
1060        if subsets.is_empty() {
1061            return Ok(Vec::new());
1062        }
1063        if let Some(mc_evaluator) = mc_evaluator.as_ref() {
1064            let resolved_masks = subsets
1065                .iter()
1066                .map(|names| {
1067                    Self::resolve_projection_active_mask_for_evaluator(mc_evaluator, names, strict)
1068                })
1069                .collect::<LadduResult<Vec<_>>>()?;
1070            resolved_masks
1071                .iter()
1072                .map(|mask| {
1073                    project_weights_local_from_resolved_mask(
1074                        mc_evaluator,
1075                        parameters,
1076                        self.n_mc,
1077                        mask,
1078                    )
1079                })
1080                .collect()
1081        } else {
1082            let resolved_masks = subsets
1083                .iter()
1084                .map(|names| self.get_or_build_projection_active_mask(names, strict))
1085                .collect::<LadduResult<Vec<_>>>()?;
1086            resolved_masks
1087                .iter()
1088                .map(|mask| {
1089                    project_weights_local_from_resolved_mask(
1090                        &self.accmc_evaluator,
1091                        parameters,
1092                        self.n_mc,
1093                        mask,
1094                    )
1095                })
1096                .collect()
1097        }
1098    }
1099
1100    /// Project the model over multiple isolated amplitude subsets in local execution,
1101    /// skipping missing amplitude tags within each subset.
1102    pub fn project_weights_subsets_local<T: AsRef<str>>(
1103        &self,
1104        parameters: &[f64],
1105        subsets: &[Vec<T>],
1106        mc_evaluator: Option<Evaluator>,
1107    ) -> LadduResult<Vec<Vec<f64>>> {
1108        self.project_weights_subsets_local_with_strict(parameters, subsets, mc_evaluator, false)
1109    }
1110
1111    /// Project the model over multiple isolated amplitude subsets in local execution and
1112    /// return an error if any requested amplitude tag is missing.
1113    pub fn project_weights_subsets_local_strict<T: AsRef<str>>(
1114        &self,
1115        parameters: &[f64],
1116        subsets: &[Vec<T>],
1117        mc_evaluator: Option<Evaluator>,
1118    ) -> LadduResult<Vec<Vec<f64>>> {
1119        self.project_weights_subsets_local_with_strict(parameters, subsets, mc_evaluator, true)
1120    }
1121
1122    /// Project the stored model over multiple isolated amplitude subsets (MPI-compatible version).
1123    #[cfg(feature = "mpi")]
1124    fn project_weights_subsets_mpi_with_strict<T: AsRef<str>>(
1125        &self,
1126        parameters: &[f64],
1127        subsets: &[Vec<T>],
1128        mc_evaluator: Option<Evaluator>,
1129        world: &SimpleCommunicator,
1130        strict: bool,
1131    ) -> LadduResult<Vec<Vec<f64>>> {
1132        let n_events = mc_evaluator
1133            .as_ref()
1134            .unwrap_or(&self.accmc_evaluator)
1135            .dataset
1136            .n_events();
1137        let local_projections = self.project_weights_subsets_local_with_strict(
1138            parameters,
1139            subsets,
1140            mc_evaluator,
1141            strict,
1142        )?;
1143        let (counts, displs) = world.get_counts_displs(n_events);
1144        let mut gathered = Vec::with_capacity(local_projections.len());
1145        for local_projection in local_projections {
1146            let mut buffer = vec![0.0; n_events];
1147            {
1148                let mut partitioned_buffer =
1149                    PartitionMut::new(&mut buffer, counts.clone(), displs.clone());
1150                world.all_gather_varcount_into(&local_projection, &mut partitioned_buffer);
1151            }
1152            gathered.push(buffer);
1153        }
1154        Ok(gathered)
1155    }
1156
1157    #[cfg(feature = "mpi")]
1158    /// Project the model over multiple isolated amplitude subsets in MPI execution,
1159    /// skipping missing amplitude tags within each subset.
1160    pub fn project_weights_subsets_mpi<T: AsRef<str>>(
1161        &self,
1162        parameters: &[f64],
1163        subsets: &[Vec<T>],
1164        mc_evaluator: Option<Evaluator>,
1165        world: &SimpleCommunicator,
1166    ) -> LadduResult<Vec<Vec<f64>>> {
1167        self.project_weights_subsets_mpi_with_strict(
1168            parameters,
1169            subsets,
1170            mc_evaluator,
1171            world,
1172            false,
1173        )
1174    }
1175
1176    #[cfg(feature = "mpi")]
1177    /// Project the model over multiple isolated amplitude subsets in MPI execution and
1178    /// return an error if any requested amplitude tag is missing.
1179    pub fn project_weights_subsets_mpi_strict<T: AsRef<str>>(
1180        &self,
1181        parameters: &[f64],
1182        subsets: &[Vec<T>],
1183        mc_evaluator: Option<Evaluator>,
1184        world: &SimpleCommunicator,
1185    ) -> LadduResult<Vec<Vec<f64>>> {
1186        self.project_weights_subsets_mpi_with_strict(parameters, subsets, mc_evaluator, world, true)
1187    }
1188
1189    /// Project the stored model over multiple isolated amplitude subsets.
1190    fn project_weights_subsets_with_strict<T: AsRef<str>>(
1191        &self,
1192        parameters: &[f64],
1193        subsets: &[Vec<T>],
1194        mc_evaluator: Option<Evaluator>,
1195        strict: bool,
1196    ) -> LadduResult<Vec<Vec<f64>>> {
1197        #[cfg(feature = "mpi")]
1198        {
1199            if let Some(world) = laddu_core::mpi::get_world() {
1200                return self.project_weights_subsets_mpi_with_strict(
1201                    parameters,
1202                    subsets,
1203                    mc_evaluator,
1204                    &world,
1205                    strict,
1206                );
1207            }
1208        }
1209        self.project_weights_subsets_local_with_strict(parameters, subsets, mc_evaluator, strict)
1210    }
1211
1212    /// Project the model over multiple isolated amplitude subsets, skipping missing
1213    /// amplitude tags within each subset.
1214    pub fn project_weights_subsets<T: AsRef<str>>(
1215        &self,
1216        parameters: &[f64],
1217        subsets: &[Vec<T>],
1218        mc_evaluator: Option<Evaluator>,
1219    ) -> LadduResult<Vec<Vec<f64>>> {
1220        self.project_weights_subsets_with_strict(parameters, subsets, mc_evaluator, false)
1221    }
1222
1223    /// Project the model over multiple isolated amplitude subsets and return an error if
1224    /// any requested amplitude tag is missing.
1225    pub fn project_weights_subsets_strict<T: AsRef<str>>(
1226        &self,
1227        parameters: &[f64],
1228        subsets: &[Vec<T>],
1229        mc_evaluator: Option<Evaluator>,
1230    ) -> LadduResult<Vec<Vec<f64>>> {
1231        self.project_weights_subsets_with_strict(parameters, subsets, mc_evaluator, true)
1232    }
1233
1234    /// Project the stored [`Expression`] over the events in the [`Dataset`] stored by the
1235    /// [`Evaluator`] with the given values for free parameters to obtain weights and gradients of
1236    /// those weights for each Monte-Carlo event. This method differs from the standard
1237    /// [`NLL::project_weights_and_gradients`] in that it first isolates the selected [`Amplitude`](`laddu_core::amplitude::Amplitude`)s
1238    /// by tag, but returns the [`NLL`] to its prior state after calculation (non-MPI version).
1239    ///
1240    /// # Notes
1241    ///
1242    /// This method is not intended to be called in analyses but rather in writing methods
1243    /// that have `mpi`-feature-gated versions. Most users will want to call [`NLL::project_weights_subset`] instead.
1244    fn project_weights_and_gradients_subset_local_with_strict<T: AsRef<str>>(
1245        &self,
1246        parameters: &[f64],
1247        names: &[T],
1248        mc_evaluator: Option<Evaluator>,
1249        strict: bool,
1250    ) -> LadduResult<(Vec<f64>, Vec<DVector<f64>>)> {
1251        validate_free_parameter_len(parameters.len(), self.n_free())?;
1252        let evaluator = mc_evaluator.as_ref().unwrap_or(&self.accmc_evaluator);
1253        let resolved_mask = if let Some(mc_evaluator) = mc_evaluator.as_ref() {
1254            Self::resolve_projection_active_mask_for_evaluator(mc_evaluator, names, strict)?
1255        } else {
1256            self.get_or_build_projection_active_mask(names, strict)?
1257        };
1258        let mc_dataset = &evaluator.dataset;
1259        let result =
1260            evaluator.evaluate_with_gradient_local_with_active_mask(parameters, &resolved_mask)?;
1261        #[cfg(feature = "rayon")]
1262        let (res, res_gradient) = {
1263            (
1264                result
1265                    .par_iter()
1266                    .zip(mc_dataset.weights_local().par_iter())
1267                    .map(|((l, _), e)| e * l.re / self.n_mc)
1268                    .collect(),
1269                result
1270                    .par_iter()
1271                    .zip(mc_dataset.weights_local().par_iter())
1272                    .map(|((_, grad_l), e)| grad_l.map(|g| g.re).scale(e / self.n_mc))
1273                    .collect(),
1274            )
1275        };
1276        #[cfg(not(feature = "rayon"))]
1277        let (res, res_gradient) = {
1278            (
1279                result
1280                    .iter()
1281                    .zip(mc_dataset.weights_local().iter())
1282                    .map(|((l, _), e)| e * l.re / self.n_mc)
1283                    .collect(),
1284                result
1285                    .iter()
1286                    .zip(mc_dataset.weights_local().iter())
1287                    .map(|((_, grad_l), e)| grad_l.map(|g| g.re).scale(e / self.n_mc))
1288                    .collect(),
1289            )
1290        };
1291        Ok((res, res_gradient))
1292    }
1293
1294    /// Project the model and parameter gradients over one isolated amplitude subset in
1295    /// local execution, skipping missing amplitude tags.
1296    pub fn project_weights_and_gradients_subset_local<T: AsRef<str>>(
1297        &self,
1298        parameters: &[f64],
1299        names: &[T],
1300        mc_evaluator: Option<Evaluator>,
1301    ) -> LadduResult<(Vec<f64>, Vec<DVector<f64>>)> {
1302        self.project_weights_and_gradients_subset_local_with_strict(
1303            parameters,
1304            names,
1305            mc_evaluator,
1306            false,
1307        )
1308    }
1309
1310    /// Project the model and parameter gradients over one isolated amplitude subset in
1311    /// local execution and return an error if any requested amplitude tag is missing.
1312    pub fn project_weights_and_gradients_subset_local_strict<T: AsRef<str>>(
1313        &self,
1314        parameters: &[f64],
1315        names: &[T],
1316        mc_evaluator: Option<Evaluator>,
1317    ) -> LadduResult<(Vec<f64>, Vec<DVector<f64>>)> {
1318        self.project_weights_and_gradients_subset_local_with_strict(
1319            parameters,
1320            names,
1321            mc_evaluator,
1322            true,
1323        )
1324    }
1325
1326    /// Project the stored [`Expression`] over the events in the [`Dataset`] stored by the
1327    /// [`Evaluator`] with the given values for free parameters to obtain weights and gradients of
1328    /// those weights for each Monte-Carlo event. This method differs from the standard
1329    /// [`NLL::project_weights_and_gradients`] in that it first isolates the selected [`Amplitude`](`laddu_core::amplitude::Amplitude`)s
1330    /// by tag, but returns the [`NLL`] to its prior state after calculation (MPI-compatible version).
1331    ///
1332    /// # Notes
1333    ///
1334    /// This method is not intended to be called in analyses but rather in writing methods
1335    /// that have `mpi`-feature-gated versions. Most users will want to call [`NLL::project_weights_subset`] instead.
1336    #[cfg(feature = "mpi")]
1337    fn project_weights_and_gradients_subset_mpi_with_strict<T: AsRef<str>>(
1338        &self,
1339        parameters: &[f64],
1340        names: &[T],
1341        mc_evaluator: Option<Evaluator>,
1342        world: &SimpleCommunicator,
1343        strict: bool,
1344    ) -> LadduResult<(Vec<f64>, Vec<DVector<f64>>)> {
1345        let n_events = mc_evaluator
1346            .as_ref()
1347            .unwrap_or(&self.accmc_evaluator)
1348            .dataset
1349            .n_events();
1350        let (local_projection, local_gradient_projection) = self
1351            .project_weights_and_gradients_subset_local_with_strict(
1352                parameters,
1353                names,
1354                mc_evaluator,
1355                strict,
1356            )?;
1357        let mut projection_result: Vec<f64> = vec![0.0; n_events];
1358        let (counts, displs) = world.get_counts_displs(n_events);
1359        {
1360            // NOTE: gather is required because projection-gradient returns per-event global outputs.
1361            let mut partitioned_buffer = PartitionMut::new(&mut projection_result, counts, displs);
1362            world.all_gather_varcount_into(&local_projection, &mut partitioned_buffer);
1363        }
1364
1365        let flattened_local_gradient_projection = local_gradient_projection
1366            .iter()
1367            .flat_map(|g| g.data.as_vec().to_vec())
1368            .collect::<Vec<f64>>();
1369        let (counts, displs) = world.get_flattened_counts_displs(n_events, parameters.len());
1370        let mut flattened_result_buffer = vec![0.0; n_events * parameters.len()];
1371        let mut partitioned_flattened_result_buffer =
1372            PartitionMut::new(&mut flattened_result_buffer, counts, displs);
1373        // NOTE: gather is required because projection-gradient returns full per-event gradients.
1374        world.all_gather_varcount_into(
1375            &flattened_local_gradient_projection,
1376            &mut partitioned_flattened_result_buffer,
1377        );
1378        let gradient_projection_result = flattened_result_buffer
1379            .chunks(parameters.len())
1380            .map(DVector::from_row_slice)
1381            .collect();
1382        Ok((projection_result, gradient_projection_result))
1383    }
1384
1385    #[cfg(feature = "mpi")]
1386    /// Project the model and parameter gradients over one isolated amplitude subset in
1387    /// MPI execution, skipping missing amplitude tags.
1388    pub fn project_weights_and_gradients_subset_mpi<T: AsRef<str>>(
1389        &self,
1390        parameters: &[f64],
1391        names: &[T],
1392        mc_evaluator: Option<Evaluator>,
1393        world: &SimpleCommunicator,
1394    ) -> LadduResult<(Vec<f64>, Vec<DVector<f64>>)> {
1395        self.project_weights_and_gradients_subset_mpi_with_strict(
1396            parameters,
1397            names,
1398            mc_evaluator,
1399            world,
1400            false,
1401        )
1402    }
1403
1404    #[cfg(feature = "mpi")]
1405    /// Project the model and parameter gradients over one isolated amplitude subset in
1406    /// MPI execution and return an error if any requested amplitude tag is missing.
1407    pub fn project_weights_and_gradients_subset_mpi_strict<T: AsRef<str>>(
1408        &self,
1409        parameters: &[f64],
1410        names: &[T],
1411        mc_evaluator: Option<Evaluator>,
1412        world: &SimpleCommunicator,
1413    ) -> LadduResult<(Vec<f64>, Vec<DVector<f64>>)> {
1414        self.project_weights_and_gradients_subset_mpi_with_strict(
1415            parameters,
1416            names,
1417            mc_evaluator,
1418            world,
1419            true,
1420        )
1421    }
1422    /// Project the stored [`Expression`] over the events in the [`Dataset`] stored by the
1423    /// [`Evaluator`] with the given values for free parameters to obtain weights and gradients of
1424    /// those weights for each
1425    /// Monte-Carlo event. This method differs from the standard [`NLL::project_weights_and_gradients`] in that it first
1426    /// isolates the selected [`Amplitude`](`laddu_core::amplitude::Amplitude`)s by tag, but returns
1427    /// the [`NLL`] to its prior state after calculation.
1428    ///
1429    /// This method takes the real part of the given expression (discarding
1430    /// the imaginary part entirely, which does not matter if expressions are coherent sums
1431    /// wrapped in [`Expression::norm_sqr`](`laddu_core::Expression::norm_sqr`).
1432    /// Event weights are determined by the following formula:
1433    ///
1434    /// ```math
1435    /// \text{weight}(\vec{p}; e) = \text{weight}(e) \mathcal{L}(e) / N_{\text{MC}}
1436    /// ```
1437    ///
1438    /// Note that $`N_{\text{MC}}`$ will always be the number of accepted Monte Carlo events,
1439    /// regardless of the `mc_evaluator`.
1440    fn project_weights_and_gradients_subset_with_strict<T: AsRef<str>>(
1441        &self,
1442        parameters: &[f64],
1443        names: &[T],
1444        mc_evaluator: Option<Evaluator>,
1445        strict: bool,
1446    ) -> LadduResult<(Vec<f64>, Vec<DVector<f64>>)> {
1447        #[cfg(feature = "mpi")]
1448        {
1449            if let Some(world) = laddu_core::mpi::get_world() {
1450                return self.project_weights_and_gradients_subset_mpi_with_strict(
1451                    parameters,
1452                    names,
1453                    mc_evaluator,
1454                    &world,
1455                    strict,
1456                );
1457            }
1458        }
1459        self.project_weights_and_gradients_subset_local_with_strict(
1460            parameters,
1461            names,
1462            mc_evaluator,
1463            strict,
1464        )
1465    }
1466
1467    /// Project the model and parameter gradients over one isolated amplitude subset,
1468    /// skipping missing amplitude tags.
1469    pub fn project_weights_and_gradients_subset<T: AsRef<str>>(
1470        &self,
1471        parameters: &[f64],
1472        names: &[T],
1473        mc_evaluator: Option<Evaluator>,
1474    ) -> LadduResult<(Vec<f64>, Vec<DVector<f64>>)> {
1475        self.project_weights_and_gradients_subset_with_strict(
1476            parameters,
1477            names,
1478            mc_evaluator,
1479            false,
1480        )
1481    }
1482
1483    /// Project the model and parameter gradients over one isolated amplitude subset and
1484    /// return an error if any requested amplitude tag is missing.
1485    pub fn project_weights_and_gradients_subset_strict<T: AsRef<str>>(
1486        &self,
1487        parameters: &[f64],
1488        names: &[T],
1489        mc_evaluator: Option<Evaluator>,
1490    ) -> LadduResult<(Vec<f64>, Vec<DVector<f64>>)> {
1491        self.project_weights_and_gradients_subset_with_strict(parameters, names, mc_evaluator, true)
1492    }
1493
1494    fn evaluate_data_term_local(&self, parameters: &[f64]) -> LadduResult<f64> {
1495        evaluate_weighted_expression_sum_local(&self.data_evaluator, parameters, |l| f64::ln(l.re))
1496    }
1497
1498    fn evaluate_mc_term_local(&self, parameters: &[f64]) -> LadduResult<f64> {
1499        self.accmc_evaluator
1500            .evaluate_weighted_value_sum_local(parameters)
1501    }
1502
1503    #[doc(hidden)]
1504    pub fn profile_data_term_local_value(&self, parameters: &[f64]) -> LadduResult<f64> {
1505        self.evaluate_data_term_local(parameters)
1506    }
1507
1508    #[doc(hidden)]
1509    pub fn profile_mc_term_local_value(&self, parameters: &[f64]) -> LadduResult<f64> {
1510        self.evaluate_mc_term_local(parameters)
1511    }
1512
1513    pub(crate) fn evaluate_local(&self, parameters: &[f64]) -> LadduResult<f64> {
1514        let data_term = self.evaluate_data_term_local(parameters)?;
1515        let mc_term = self.evaluate_mc_term_local(parameters)?;
1516        Ok(-2.0 * (data_term - mc_term / self.n_mc))
1517    }
1518
1519    #[cfg(feature = "mpi")]
1520    #[doc(hidden)]
1521    pub fn evaluate_mpi(&self, parameters: &[f64], world: &SimpleCommunicator) -> LadduResult<f64> {
1522        let data_term_local = self.evaluate_data_term_local(parameters)?;
1523        let data_term = reduce_scalar(world, data_term_local);
1524        let mc_term = self
1525            .accmc_evaluator
1526            .evaluate_weighted_value_sum_mpi(parameters, world)?;
1527        Ok(-2.0 * (data_term - mc_term / self.n_mc))
1528    }
1529
1530    pub(crate) fn evaluate_data_gradient_term_local(
1531        &self,
1532        parameters: &[f64],
1533    ) -> LadduResult<DVector<f64>> {
1534        let data_resources = self.data_evaluator.resources.read();
1535        let data_parameters = data_resources.parameter_map.assemble(parameters)?;
1536        let data_active_indices = data_resources.active_indices().to_vec();
1537        let data_active_mask = data_resources.active.clone();
1538        #[cfg(feature = "rayon")]
1539        let n_parameters = parameters.len();
1540        #[cfg(feature = "rayon")]
1541        let data_scratch_key = GradientScratchKey {
1542            n_parameters,
1543            n_amplitudes: self.data_evaluator.amplitude_value_slot_count(),
1544            n_expression_slots: self.data_evaluator.expression_slot_count(),
1545        };
1546        #[cfg(feature = "rayon")]
1547        let data_term: DVector<f64> = sum_dvectors_parallel(
1548            self.data_evaluator
1549                .dataset
1550                .weights_local()
1551                .par_iter()
1552                .zip(data_resources.caches.par_iter())
1553                .map_init(
1554                    || acquire_gradient_scratch(data_scratch_key),
1555                    |scratch, (event, cache)| {
1556                        let workspace = scratch.workspace_mut();
1557                        let amp_vals = &mut workspace.amplitude_values;
1558                        let grad_vals = &mut workspace.gradient_values;
1559                        self.data_evaluator.fill_amplitude_values_and_gradients(
1560                            amp_vals,
1561                            grad_vals,
1562                            &data_active_indices,
1563                            &data_active_mask,
1564                            &data_parameters,
1565                            cache,
1566                        );
1567                        let (value, gradient) = self
1568                            .data_evaluator
1569                            .evaluate_expression_value_gradient_with_scratch(
1570                                amp_vals,
1571                                grad_vals,
1572                                &mut workspace.value_slots,
1573                                &mut workspace.gradient_slots,
1574                            );
1575                        (*event, value, gradient)
1576                    },
1577                )
1578                .map(|(w, l, g)| g.map(|gi| gi.re * w / l.re)),
1579            n_parameters,
1580        );
1581        #[cfg(not(feature = "rayon"))]
1582        let data_term: DVector<f64> = {
1583            let amplitude_len = self.data_evaluator.amplitude_value_slot_count();
1584            let mut amp_vals = vec![Complex64::ZERO; amplitude_len];
1585            let mut grad_vals = vec![DVector::zeros(parameters.len()); amplitude_len];
1586            let mut value_slots =
1587                vec![Complex64::ZERO; self.data_evaluator.expression_slot_count()];
1588            let mut gradient_slots =
1589                vec![DVector::zeros(parameters.len()); self.data_evaluator.expression_slot_count()];
1590            self.data_evaluator
1591                .dataset
1592                .weights_local()
1593                .iter()
1594                .zip(data_resources.caches.iter())
1595                .map(|(event, cache)| {
1596                    self.data_evaluator.fill_amplitude_values_and_gradients(
1597                        &mut amp_vals,
1598                        &mut grad_vals,
1599                        &data_active_indices,
1600                        &data_active_mask,
1601                        &data_parameters,
1602                        cache,
1603                    );
1604                    let (value, gradient) = self
1605                        .data_evaluator
1606                        .evaluate_expression_value_gradient_with_scratch(
1607                            &amp_vals,
1608                            &grad_vals,
1609                            &mut value_slots,
1610                            &mut gradient_slots,
1611                        );
1612                    (*event, value, gradient)
1613                })
1614                .map(|(w, l, g)| g.map(|gi| gi.re * w / l.re))
1615                .sum()
1616        };
1617        Ok(data_term)
1618    }
1619
1620    #[doc(hidden)]
1621    pub fn evaluate_gradient_local(&self, parameters: &[f64]) -> LadduResult<DVector<f64>> {
1622        let data_term = self.evaluate_data_gradient_term_local(parameters)?;
1623        let mc_term = self
1624            .accmc_evaluator
1625            .evaluate_weighted_gradient_sum_local(parameters)?;
1626        Ok(-2.0 * (data_term - mc_term / self.n_mc))
1627    }
1628
1629    #[cfg(feature = "mpi")]
1630    #[doc(hidden)]
1631    pub fn evaluate_gradient_mpi(
1632        &self,
1633        parameters: &[f64],
1634        world: &SimpleCommunicator,
1635    ) -> LadduResult<DVector<f64>> {
1636        let data_term_local = self.evaluate_data_gradient_term_local(parameters)?;
1637        let data_term = reduce_gradient(world, &data_term_local);
1638        let mc_term = self
1639            .accmc_evaluator
1640            .evaluate_weighted_gradient_sum_mpi(parameters, world)?;
1641        Ok(-2.0 * (data_term - mc_term / self.n_mc))
1642    }
1643}
1644
1645impl LikelihoodTerm for NLL {
1646    fn evaluate(&self, parameters: &[f64]) -> LadduResult<f64> {
1647        validate_free_parameter_len(parameters.len(), self.n_free())?;
1648        #[cfg(feature = "mpi")]
1649        {
1650            if let Some(world) = laddu_core::mpi::get_world() {
1651                return self.evaluate_mpi(parameters, &world);
1652            }
1653        }
1654        self.evaluate_local(parameters)
1655    }
1656    fn evaluate_gradient(&self, parameters: &[f64]) -> LadduResult<DVector<f64>> {
1657        validate_free_parameter_len(parameters.len(), self.n_free())?;
1658        #[cfg(feature = "mpi")]
1659        {
1660            if let Some(world) = laddu_core::mpi::get_world() {
1661                return self.evaluate_gradient_mpi(parameters, &world);
1662            }
1663        }
1664        self.evaluate_gradient_local(parameters)
1665    }
1666    fn fix_parameter(&self, name: &str, value: f64) -> LadduResult<()> {
1667        self.data_evaluator.fix_parameter(name, value)?;
1668        self.accmc_evaluator.fix_parameter(name, value)?;
1669        Ok(())
1670    }
1671    fn free_parameter(&self, name: &str) -> LadduResult<()> {
1672        self.data_evaluator.free_parameter(name)?;
1673        self.accmc_evaluator.free_parameter(name)?;
1674        Ok(())
1675    }
1676    fn rename_parameter(&self, old: &str, new: &str) -> LadduResult<()> {
1677        self.data_evaluator.rename_parameter(old, new)?;
1678        self.accmc_evaluator.rename_parameter(old, new)?;
1679        Ok(())
1680    }
1681    fn rename_parameters(&self, mapping: &HashMap<String, String>) -> LadduResult<()> {
1682        self.data_evaluator.rename_parameters(mapping)?;
1683        self.accmc_evaluator.rename_parameters(mapping)?;
1684        Ok(())
1685    }
1686    fn parameter_map(&self) -> ParameterMap {
1687        self.data_evaluator.resources.read().parameter_map.clone()
1688    }
1689}
1690
/// A stochastic [`NLL`] term.
///
/// While a regular [`NLL`] will operate over the entire dataset, this term will only operate over
/// a random subset of the data, determined by the `batch_size` parameter. This will make the
/// objective function faster to evaluate at the cost of adding random noise to the likelihood.
///
/// The batch indices and the random number generator are stored behind `Arc<Mutex<..>>`, so
/// clones produced by the derived `Clone` share the same batch and generator.
#[derive(Clone)]
pub struct StochasticNLL {
    /// A handle to the original [`NLL`] term.
    pub nll: NLL,
    // Number of events reported by the data dataset's `n_events()` when the term was built.
    n: usize,
    // Number of indices drawn for each batch; checked at construction to lie in `1..=n`.
    batch_size: usize,
    // Indices of the current batch; replaced as a whole by `resample`.
    batch_indices: Arc<Mutex<Vec<usize>>>,
    // Generator used to draw batches (seeded at construction when a seed is given).
    rng: Arc<Mutex<Rng>>,
}
1705
1706impl LikelihoodTerm for StochasticNLL {
1707    fn evaluate(&self, parameters: &[f64]) -> LadduResult<f64> {
1708        validate_free_parameter_len(parameters.len(), self.nll.n_free())?;
1709        let indices = self.batch_indices.lock();
1710        #[cfg(feature = "mpi")]
1711        {
1712            if let Some(world) = laddu_core::mpi::get_world() {
1713                return self.evaluate_mpi(parameters, &indices, &world);
1714            }
1715        }
1716        #[cfg(feature = "rayon")]
1717        let n_data_batch_local = indices
1718            .par_iter()
1719            .map(|&i| self.nll.data_evaluator.dataset.weights_local()[i])
1720            .parallel_sum_with_accumulator::<Klein<f64>>();
1721        #[cfg(not(feature = "rayon"))]
1722        let n_data_batch_local = indices
1723            .iter()
1724            .map(|&i| self.nll.data_evaluator.dataset.weights_local()[i])
1725            .sum_with_accumulator::<Klein<f64>>();
1726        self.evaluate_local(parameters, &indices, n_data_batch_local)
1727    }
1728    fn evaluate_gradient(&self, parameters: &[f64]) -> LadduResult<DVector<f64>> {
1729        validate_free_parameter_len(parameters.len(), self.nll.n_free())?;
1730        let indices = self.batch_indices.lock();
1731        #[cfg(feature = "mpi")]
1732        {
1733            if let Some(world) = laddu_core::mpi::get_world() {
1734                return self.evaluate_gradient_mpi(parameters, &indices, &world);
1735            }
1736        }
1737        #[cfg(feature = "rayon")]
1738        let n_data_batch_local = indices
1739            .par_iter()
1740            .map(|&i| self.nll.data_evaluator.dataset.weights_local()[i])
1741            .parallel_sum_with_accumulator::<Klein<f64>>();
1742        #[cfg(not(feature = "rayon"))]
1743        let n_data_batch_local = indices
1744            .iter()
1745            .map(|&i| self.nll.data_evaluator.dataset.weights_local()[i])
1746            .sum_with_accumulator::<Klein<f64>>();
1747        self.evaluate_gradient_local(parameters, &indices, n_data_batch_local)
1748    }
1749    fn fix_parameter(&self, name: &str, value: f64) -> LadduResult<()> {
1750        self.nll.fix_parameter(name, value)
1751    }
1752    fn free_parameter(&self, name: &str) -> LadduResult<()> {
1753        self.nll.free_parameter(name)
1754    }
1755    fn rename_parameter(&self, old: &str, new: &str) -> LadduResult<()> {
1756        self.nll.rename_parameter(old, new)
1757    }
1758    fn rename_parameters(&self, mapping: &HashMap<String, String>) -> LadduResult<()> {
1759        self.nll.rename_parameters(mapping)
1760    }
1761    fn update(&self) {
1762        self.resample();
1763    }
1764    fn parameter_map(&self) -> ParameterMap {
1765        self.nll.parameter_map()
1766    }
1767}
1768
1769impl StochasticNLL {
1770    /// Generate a new [`StochasticNLL`] with the given [`NLL`], batch size, and optional random seed
1771    ///
1772    /// # See Also
1773    ///
1774    /// [`NLL::to_stochastic`]
1775    pub fn new(nll: NLL, batch_size: usize, seed: Option<usize>) -> LadduResult<Self> {
1776        let mut rng = seed.map_or_else(Rng::new, |seed| Rng::with_seed(seed as u64));
1777        let n = nll.data_evaluator.dataset.n_events();
1778        validate_stochastic_batch_size(batch_size, n)?;
1779        let batch_indices = rng.subset(batch_size, n);
1780        Ok(Self {
1781            nll,
1782            n,
1783            batch_size,
1784            batch_indices: Arc::new(Mutex::new(batch_indices)),
1785            rng: Arc::new(Mutex::new(rng)),
1786        })
1787    }
1788    /// Resample the batch indices used in evaluation
1789    pub fn resample(&self) {
1790        let mut rng = self.rng.lock();
1791        *self.batch_indices.lock() = rng.subset(self.batch_size, self.n);
1792    }
1793
1794    /// The parameters for this stochastic NLL.
1795    pub fn parameters(&self) -> ParameterMap {
1796        self.nll.parameters()
1797    }
1798
1799    /// Number of free parameters.
1800    pub fn n_free(&self) -> usize {
1801        self.nll.n_free()
1802    }
1803
1804    /// Number of fixed parameters.
1805    pub fn n_fixed(&self) -> usize {
1806        self.nll.n_fixed()
1807    }
1808
1809    /// Total number of parameters.
1810    pub fn n_parameters(&self) -> usize {
1811        self.nll.n_parameters()
1812    }
1813
1814    /// Returns the expression represented by this stochastic NLL.
1815    pub fn expression(&self) -> Expression {
1816        self.nll.expression()
1817    }
1818
1819    /// Returns a tree-like diagnostic snapshot of the compiled expression for this stochastic
1820    /// NLL's current active-amplitude mask.
1821    pub fn compiled_expression(&self) -> CompiledExpression {
1822        self.nll.compiled_expression()
1823    }
1824
1825    #[cfg(feature = "mpi")]
1826    fn data_batch_weight_local(&self, indices: &[usize]) -> f64 {
1827        #[cfg(feature = "rayon")]
1828        return indices
1829            .par_iter()
1830            .map(|&i| self.nll.data_evaluator.dataset.weights_local()[i])
1831            .parallel_sum_with_accumulator::<Klein<f64>>();
1832        #[cfg(not(feature = "rayon"))]
1833        return indices
1834            .iter()
1835            .map(|&i| self.nll.data_evaluator.dataset.weights_local()[i])
1836            .sum_with_accumulator::<Klein<f64>>();
1837    }
1838
1839    fn evaluate_data_term_local(&self, parameters: &[f64], indices: &[usize]) -> LadduResult<f64> {
1840        let data_result = self
1841            .nll
1842            .data_evaluator
1843            .evaluate_batch_local(parameters, indices)?;
1844        #[cfg(feature = "rayon")]
1845        {
1846            Ok(indices
1847                .par_iter()
1848                .zip(data_result.par_iter())
1849                .map(|(&i, &l)| {
1850                    let e = &self.nll.data_evaluator.dataset.weights_local()[i];
1851                    e * l.re.ln()
1852                })
1853                .parallel_sum_with_accumulator::<Klein<f64>>())
1854        }
1855        #[cfg(not(feature = "rayon"))]
1856        {
1857            Ok(indices
1858                .iter()
1859                .zip(data_result.iter())
1860                .map(|(&i, &l)| {
1861                    let e = &self.nll.data_evaluator.dataset.weights_local()[i];
1862                    e * l.re.ln()
1863                })
1864                .sum_with_accumulator::<Klein<f64>>())
1865        }
1866    }
1867
1868    fn evaluate_local(
1869        &self,
1870        parameters: &[f64],
1871        indices: &[usize],
1872        n_data_batch: f64,
1873    ) -> LadduResult<f64> {
1874        let n_data_total = self.nll.data_evaluator.dataset.n_events_weighted();
1875        let data_term = self.evaluate_data_term_local(parameters, indices)?;
1876        let mc_term = self
1877            .nll
1878            .accmc_evaluator
1879            .evaluate_weighted_value_sum_local(parameters)?;
1880        Ok(-2.0 * (data_term * n_data_total / n_data_batch - mc_term / self.nll.n_mc))
1881    }
1882
1883    #[cfg(feature = "mpi")]
1884    fn evaluate_mpi(
1885        &self,
1886        parameters: &[f64],
1887        indices: &[usize],
1888        world: &SimpleCommunicator,
1889    ) -> LadduResult<f64> {
1890        let total = self.nll.data_evaluator.dataset.n_events();
1891        let locals = world.locals_from_globals(indices, total);
1892        let n_data_batch_local = self.data_batch_weight_local(&locals);
1893        let n_data_total = self.nll.data_evaluator.dataset.n_events_weighted();
1894        let data_term_local = self.evaluate_data_term_local(parameters, &locals)?;
1895        let n_data_batch = reduce_scalar(world, n_data_batch_local);
1896        let data_term = reduce_scalar(world, data_term_local);
1897        let mc_term = self
1898            .nll
1899            .accmc_evaluator
1900            .evaluate_weighted_value_sum_mpi(parameters, world)?;
1901        Ok(-2.0 * (data_term * n_data_total / n_data_batch - mc_term / self.nll.n_mc))
1902    }
1903
1904    fn evaluate_data_gradient_term_local(
1905        &self,
1906        parameters: &[f64],
1907        indices: &[usize],
1908    ) -> LadduResult<DVector<f64>> {
1909        let data_resources = self.nll.data_evaluator.resources.read();
1910        let data_parameters = data_resources.parameter_map.assemble(parameters)?;
1911        let data_active_indices = data_resources.active_indices().to_vec();
1912        let data_active_mask = data_resources.active.clone();
1913        #[cfg(feature = "rayon")]
1914        let n_parameters = parameters.len();
1915        #[cfg(feature = "rayon")]
1916        let data_scratch_key = GradientScratchKey {
1917            n_parameters,
1918            n_amplitudes: self.nll.data_evaluator.amplitude_value_slot_count(),
1919            n_expression_slots: self.nll.data_evaluator.expression_slot_count(),
1920        };
1921        #[cfg(feature = "rayon")]
1922        let data_term: DVector<f64> = sum_dvectors_parallel(
1923            indices
1924                .par_iter()
1925                .map_init(
1926                    || acquire_gradient_scratch(data_scratch_key),
1927                    |scratch, &idx| {
1928                        let workspace = scratch.workspace_mut();
1929                        let amp_vals = &mut workspace.amplitude_values;
1930                        let grad_vals = &mut workspace.gradient_values;
1931                        let event = &self.nll.data_evaluator.dataset.weights_local()[idx];
1932                        let cache = &data_resources.caches[idx];
1933                        self.nll.data_evaluator.fill_amplitude_values_and_gradients(
1934                            amp_vals,
1935                            grad_vals,
1936                            &data_active_indices,
1937                            &data_active_mask,
1938                            &data_parameters,
1939                            cache,
1940                        );
1941                        let (value, gradient) = self
1942                            .nll
1943                            .data_evaluator
1944                            .evaluate_expression_value_gradient_with_scratch(
1945                                amp_vals,
1946                                grad_vals,
1947                                &mut workspace.value_slots,
1948                                &mut workspace.gradient_slots,
1949                            );
1950                        (*event, value, gradient)
1951                    },
1952                )
1953                .map(|(w, l, g)| g.map(|gi| gi.re * w / l.re)),
1954            n_parameters,
1955        );
1956        #[cfg(not(feature = "rayon"))]
1957        let data_term: DVector<f64> = {
1958            let amplitude_len = self.nll.data_evaluator.amplitude_value_slot_count();
1959            let mut amp_vals = vec![Complex64::ZERO; amplitude_len];
1960            let mut grad_vals = vec![DVector::zeros(parameters.len()); amplitude_len];
1961            let mut value_slots =
1962                vec![Complex64::ZERO; self.nll.data_evaluator.expression_slot_count()];
1963            let mut gradient_slots = vec![
1964                DVector::zeros(parameters.len());
1965                self.nll.data_evaluator.expression_slot_count()
1966            ];
1967            indices
1968                .iter()
1969                .map(|&idx| {
1970                    let event = &self.nll.data_evaluator.dataset.weights_local()[idx];
1971                    let cache = &data_resources.caches[idx];
1972                    self.nll.data_evaluator.fill_amplitude_values_and_gradients(
1973                        &mut amp_vals,
1974                        &mut grad_vals,
1975                        &data_active_indices,
1976                        &data_active_mask,
1977                        &data_parameters,
1978                        cache,
1979                    );
1980                    let (value, gradient) = self
1981                        .nll
1982                        .data_evaluator
1983                        .evaluate_expression_value_gradient_with_scratch(
1984                            &amp_vals,
1985                            &grad_vals,
1986                            &mut value_slots,
1987                            &mut gradient_slots,
1988                        );
1989                    (*event, value, gradient)
1990                })
1991                .map(|(w, l, g)| g.map(|gi| gi.re * w / l.re))
1992                .sum()
1993        };
1994        Ok(data_term)
1995    }
1996
1997    fn evaluate_gradient_local(
1998        &self,
1999        parameters: &[f64],
2000        indices: &[usize],
2001        n_data_batch: f64,
2002    ) -> LadduResult<DVector<f64>> {
2003        let n_data_total = self.nll.data_evaluator.dataset.n_events_weighted();
2004        let data_term = self.evaluate_data_gradient_term_local(parameters, indices)?;
2005        let mc_term = self
2006            .nll
2007            .accmc_evaluator
2008            .evaluate_weighted_gradient_sum_local(parameters)?;
2009        Ok(-2.0 * (data_term * n_data_total / n_data_batch - mc_term / self.nll.n_mc))
2010    }
2011
2012    #[cfg(feature = "mpi")]
2013    fn evaluate_gradient_mpi(
2014        &self,
2015        parameters: &[f64],
2016        indices: &[usize],
2017        world: &SimpleCommunicator,
2018    ) -> LadduResult<DVector<f64>> {
2019        let total = self.nll.data_evaluator.dataset.n_events();
2020        let locals = world.locals_from_globals(indices, total);
2021        let n_data_batch_local = self.data_batch_weight_local(&locals);
2022        let n_data_total = self.nll.data_evaluator.dataset.n_events_weighted();
2023        let data_term_local = self.evaluate_data_gradient_term_local(parameters, &locals)?;
2024        let n_data_batch = reduce_scalar(world, n_data_batch_local);
2025        let data_term = reduce_gradient(world, &data_term_local);
2026        let mc_term = self
2027            .nll
2028            .accmc_evaluator
2029            .evaluate_weighted_gradient_sum_mpi(parameters, world)?;
2030        Ok(-2.0 * (data_term * n_data_total / n_data_batch - mc_term / self.nll.n_mc))
2031    }
2032}