Skip to main content

laddu_extensions/likelihood/
expression.rs

1use std::{
2    collections::HashMap,
3    fmt::{Debug, Display},
4};
5
6use auto_ops::*;
7use laddu_core::{amplitude::ParameterMap, LadduError, LadduResult};
8use nalgebra::DVector;
9
10use super::term::LikelihoodTerm;
11
/// Evaluated value of every registered term, indexed by the term's position
/// in the registry (the same index stored in `LikelihoodNode::Term`).
#[derive(Debug)]
struct LikelihoodValues(Vec<f64>);
14
15#[derive(Debug)]
16struct LikelihoodGradients(Vec<DVector<f64>>);
17
/// A node of the arithmetic tree that combines likelihood terms.
#[derive(Clone, Default)]
enum LikelihoodNode {
    /// The constant `0.0`; this is the default node.
    #[default]
    Zero,
    /// The constant `1.0`.
    One,
    /// A leaf referring to a term by its index in the registry.
    Term(usize),
    /// Sum of two subtrees.
    Add(Box<Self>, Box<Self>),
    /// Product of two subtrees.
    Mul(Box<Self>, Box<Self>),
}
27
28impl LikelihoodNode {
29    fn remap(&self, mapping: &[usize]) -> Self {
30        match self {
31            Self::Term(idx) => Self::Term(mapping[*idx]),
32            Self::Add(a, b) => Self::Add(Box::new(a.remap(mapping)), Box::new(b.remap(mapping))),
33            Self::Mul(a, b) => Self::Mul(Box::new(a.remap(mapping)), Box::new(b.remap(mapping))),
34            Self::Zero => Self::Zero,
35            Self::One => Self::One,
36        }
37    }
38
39    fn evaluate(&self, likelihood_values: &LikelihoodValues) -> f64 {
40        match self {
41            LikelihoodNode::Zero => 0.0,
42            LikelihoodNode::One => 1.0,
43            LikelihoodNode::Term(idx) => likelihood_values.0[*idx],
44            LikelihoodNode::Add(a, b) => {
45                a.evaluate(likelihood_values) + b.evaluate(likelihood_values)
46            }
47            LikelihoodNode::Mul(a, b) => {
48                a.evaluate(likelihood_values) * b.evaluate(likelihood_values)
49            }
50        }
51    }
52
53    fn evaluate_gradient(
54        &self,
55        likelihood_values: &LikelihoodValues,
56        likelihood_gradients: &LikelihoodGradients,
57    ) -> DVector<f64> {
58        match self {
59            LikelihoodNode::Zero => DVector::zeros(0),
60            LikelihoodNode::One => DVector::zeros(0),
61            LikelihoodNode::Term(idx) => likelihood_gradients.0[*idx].clone(),
62            LikelihoodNode::Add(a, b) => {
63                a.evaluate_gradient(likelihood_values, likelihood_gradients)
64                    + b.evaluate_gradient(likelihood_values, likelihood_gradients)
65            }
66            LikelihoodNode::Mul(a, b) => {
67                a.evaluate_gradient(likelihood_values, likelihood_gradients)
68                    * b.evaluate(likelihood_values)
69                    + b.evaluate_gradient(likelihood_values, likelihood_gradients)
70                        * a.evaluate(likelihood_values)
71            }
72        }
73    }
74
75    fn write_tree(
76        &self,
77        f: &mut std::fmt::Formatter<'_>,
78        parent_prefix: &str,
79        immediate_prefix: &str,
80        parent_suffix: &str,
81    ) -> std::fmt::Result {
82        let display_string = match self {
83            Self::Zero => "0".to_string(),
84            Self::One => "1".to_string(),
85            Self::Term(idx) => format!("term({idx})"),
86            Self::Add(_, _) => "+".to_string(),
87            Self::Mul(_, _) => "*".to_string(),
88        };
89        writeln!(f, "{}{}{}", parent_prefix, immediate_prefix, display_string)?;
90        match self {
91            Self::Term(_) | Self::Zero | Self::One => {}
92            Self::Add(a, b) | Self::Mul(a, b) => {
93                let terms = [a, b];
94                let mut it = terms.iter().peekable();
95                let child_prefix = format!("{}{}", parent_prefix, parent_suffix);
96                while let Some(child) = it.next() {
97                    match it.peek() {
98                        Some(_) => child.write_tree(f, &child_prefix, "├─ ", "│  ")?,
99                        None => child.write_tree(f, &child_prefix, "└─ ", "   ")?,
100                    }
101                }
102            }
103        }
104        Ok(())
105    }
106}
107
108/// A combination of [`LikelihoodTerm`]s as well as sums and products of them.
109///
110/// # Notes
111/// When multiple terms provide parameters with the same name, the term earliest in the expression
112/// (or argument list) defines the fixed/free status and default value.
113#[derive(Clone, Default)]
114pub struct LikelihoodExpression {
115    registry: LikelihoodRegistry,
116    tree: LikelihoodNode,
117}
118
119impl LikelihoodExpression {
120    /// Build a [`LikelihoodExpression`] from a single [`LikelihoodTerm`].
121    pub fn from_term(term: Box<dyn LikelihoodTerm>) -> LadduResult<Self> {
122        let registry = LikelihoodRegistry::singleton(term)?;
123        Ok(Self {
124            registry,
125            tree: LikelihoodNode::Term(0),
126        })
127    }
128
129    /// Create an expression representing zero, the additive identity.
130    pub fn zero() -> Self {
131        Self {
132            registry: LikelihoodRegistry::default(),
133            tree: LikelihoodNode::Zero,
134        }
135    }
136
137    /// Create an expression representing one, the multiplicative identity.
138    pub fn one() -> Self {
139        Self {
140            registry: LikelihoodRegistry::default(),
141            tree: LikelihoodNode::One,
142        }
143    }
144
145    fn binary_op(
146        a: &LikelihoodExpression,
147        b: &LikelihoodExpression,
148        build: impl Fn(Box<LikelihoodNode>, Box<LikelihoodNode>) -> LikelihoodNode,
149    ) -> LikelihoodExpression {
150        let (registry, left_map, right_map) = a.registry.merge(&b.registry);
151        let left_tree = a.tree.remap(&left_map);
152        let right_tree = b.tree.remap(&right_map);
153        LikelihoodExpression {
154            registry,
155            tree: build(Box::new(left_tree), Box::new(right_tree)),
156        }
157    }
158
159    fn write_tree(
160        &self,
161        f: &mut std::fmt::Formatter<'_>,
162        parent_prefix: &str,
163        immediate_prefix: &str,
164        parent_suffix: &str,
165    ) -> std::fmt::Result {
166        self.tree
167            .write_tree(f, parent_prefix, immediate_prefix, parent_suffix)
168    }
169
170    /// The parameters referenced across all terms in this expression.
171    pub fn parameters(&self) -> ParameterMap {
172        self.registry.global_parameter_map().clone()
173    }
174
175    /// Number of free parameters.
176    pub fn n_free(&self) -> usize {
177        self.registry.global_parameter_map().free().len()
178    }
179
180    /// Number of fixed parameters.
181    pub fn n_fixed(&self) -> usize {
182        self.registry.global_parameter_map().fixed().len()
183    }
184
185    /// Total number of parameters (free + fixed).
186    pub fn n_parameters(&self) -> usize {
187        self.registry.global_parameter_map().len()
188    }
189
190    /// Evaluate the sum/product of all terms.
191    pub fn evaluate(&self, parameters: &[f64]) -> LadduResult<f64> {
192        let layout = self.registry.global_layout()?;
193        layout.global_map.assemble(parameters)?; // NOTE: just a check
194        let likelihood_values = LikelihoodValues(
195            self.registry
196                .terms
197                .iter()
198                .zip(layout.layouts.iter())
199                .map(|(term, term_layout)| {
200                    term.evaluate(
201                        &term_layout
202                            .iter()
203                            .map(|&global_idx| parameters[global_idx])
204                            .collect::<Vec<_>>(),
205                    )
206                })
207                .collect::<LadduResult<Vec<_>>>()?,
208        );
209        Ok(self.tree.evaluate(&likelihood_values))
210    }
211
212    /// Evaluate the gradient.
213    pub fn evaluate_gradient(&self, parameters: &[f64]) -> LadduResult<DVector<f64>> {
214        let free_parameter_count = parameters.len();
215        let layout = self.registry.global_layout()?;
216        layout.global_map.assemble(parameters)?; // NOTE: just a check
217        let parameter_sets = layout
218            .layouts
219            .iter()
220            .map(|term_layout| {
221                term_layout
222                    .iter()
223                    .map(|&global_idx| parameters[global_idx])
224                    .collect::<Vec<_>>()
225            })
226            .collect::<Vec<_>>();
227        let likelihood_values = LikelihoodValues(
228            self.registry
229                .terms
230                .iter()
231                .zip(parameter_sets.iter())
232                .map(|(term, term_parameters)| term.evaluate(term_parameters))
233                .collect::<LadduResult<Vec<_>>>()?,
234        );
235        let mut gradient_buffers: Vec<DVector<f64>> = (0..self.registry.terms.len())
236            .map(|_| DVector::zeros(parameters.len()))
237            .collect();
238        for (((term, term_parameters), gradient_buffer), layout) in self
239            .registry
240            .terms
241            .iter()
242            .zip(parameter_sets.iter())
243            .zip(gradient_buffers.iter_mut())
244            .zip(layout.layouts.iter())
245        {
246            let term_gradient = term.evaluate_gradient(term_parameters)?; // This has a local layout
247            for (term_idx, &buffer_idx) in layout.iter().enumerate() {
248                gradient_buffer[buffer_idx] = term_gradient[term_idx] // This has a global layout
249            }
250        }
251        let likelihood_gradients = LikelihoodGradients(gradient_buffers);
252        let full_gradient = self
253            .tree
254            .evaluate_gradient(&likelihood_values, &likelihood_gradients);
255        let mut reduced = DVector::zeros(free_parameter_count);
256        for (out_idx, &global_idx) in layout
257            .global_map
258            .free_parameter_indices()
259            .iter()
260            .enumerate()
261        {
262            reduced[out_idx] = full_gradient[global_idx];
263        }
264        Ok(reduced)
265    }
266}
267
268impl LikelihoodTerm for LikelihoodExpression {
269    fn evaluate(&self, parameters: &[f64]) -> LadduResult<f64> {
270        LikelihoodExpression::evaluate(self, parameters)
271    }
272    fn evaluate_gradient(&self, parameters: &[f64]) -> LadduResult<DVector<f64>> {
273        LikelihoodExpression::evaluate_gradient(self, parameters)
274    }
275    fn update(&self) {
276        self.registry.terms.iter().for_each(|term| term.update())
277    }
278    fn fix_parameter(&self, name: &str, value: f64) -> LadduResult<()> {
279        self.registry.fix_parameter(name, value)
280    }
281
282    fn free_parameter(&self, name: &str) -> LadduResult<()> {
283        self.registry.free_parameter(name)
284    }
285
286    fn rename_parameter(&self, old: &str, new: &str) -> LadduResult<()> {
287        self.registry.rename_parameter(old, new)
288    }
289
290    fn rename_parameters(&self, mapping: &HashMap<String, String>) -> LadduResult<()> {
291        self.registry.rename_parameters(mapping)
292    }
293
294    fn parameter_map(&self) -> ParameterMap {
295        self.registry.global_parameter_map().clone()
296    }
297}
298
299impl Debug for LikelihoodExpression {
300    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
301        self.write_tree(f, "", "", "")
302    }
303}
304
305impl Display for LikelihoodExpression {
306    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
307        self.write_tree(f, "", "", "")
308    }
309}
310
311impl_op_ex!(+ |a: &LikelihoodExpression, b: &LikelihoodExpression| -> LikelihoodExpression {
312    LikelihoodExpression::binary_op(a, b, LikelihoodNode::Add)
313});
314impl_op_ex!(
315    *|a: &LikelihoodExpression, b: &LikelihoodExpression| -> LikelihoodExpression {
316        LikelihoodExpression::binary_op(a, b, LikelihoodNode::Mul)
317    }
318);
319
320struct GlobalParameterLayout {
321    global_map: ParameterMap,
322    layouts: Vec<Vec<usize>>,
323}
324
325#[derive(Clone, Default)]
326struct LikelihoodRegistry {
327    terms: Vec<Box<dyn LikelihoodTerm>>,
328}
329
330impl LikelihoodRegistry {
331    fn singleton(term: Box<dyn LikelihoodTerm>) -> LadduResult<Self> {
332        let mut registry = Self::default();
333        registry.push_term(term);
334        Ok(registry)
335    }
336
337    fn push_term(&mut self, term: Box<dyn LikelihoodTerm>) -> usize {
338        let term_idx = self.terms.len();
339        self.terms.push(term);
340        term_idx
341    }
342
343    fn merge(&self, other: &Self) -> (Self, Vec<usize>, Vec<usize>) {
344        let mut registry = Self::default();
345        let mut left_map = Vec::with_capacity(self.terms.len());
346        for term in &self.terms {
347            let idx = registry.push_term(dyn_clone::clone_box(&**term));
348            left_map.push(idx);
349        }
350        let mut right_map = Vec::with_capacity(other.terms.len());
351        for term in &other.terms {
352            let idx = registry.push_term(dyn_clone::clone_box(&**term));
353            right_map.push(idx);
354        }
355        (registry, left_map, right_map)
356    }
357
358    fn global_parameter_map(&self) -> ParameterMap {
359        let mut global = ParameterMap::default();
360        for term in &self.terms {
361            (global, _, _) = global.merge(&term.parameter_map());
362        }
363        global
364    }
365
366    fn global_layout(&self) -> LadduResult<GlobalParameterLayout> {
367        let global_map = self.global_parameter_map();
368        let global_free_index: HashMap<String, usize> = global_map
369            .free()
370            .names()
371            .into_iter()
372            .enumerate()
373            .map(|(idx, name)| (name, idx))
374            .collect();
375
376        let layouts = self
377            .terms
378            .iter()
379            .map(|term| {
380                term.parameter_map()
381                    .free()
382                    .names()
383                    .into_iter()
384                    .map(|name| {
385                        global_free_index.get(&name).copied().ok_or_else(|| {
386                            LadduError::UnregisteredParameter {
387                                name,
388                                reason: "free parameter missing in global parameter map"
389                                    .to_string(),
390                            }
391                        })
392                    })
393                    .collect()
394            })
395            .collect::<LadduResult<Vec<_>>>()?;
396
397        Ok(GlobalParameterLayout {
398            global_map,
399            layouts,
400        })
401    }
402
403    fn fix_parameter(&self, name: &str, value: f64) -> LadduResult<()> {
404        for term in &self.terms {
405            if term.parameter_map().contains_key(name) {
406                term.parameter_map().fix_parameter(name, value)?;
407            }
408        }
409        Ok(())
410    }
411
412    fn free_parameter(&self, name: &str) -> LadduResult<()> {
413        for term in &self.terms {
414            if term.parameter_map().contains_key(name) {
415                term.parameter_map().free_parameter(name)?;
416            }
417        }
418        Ok(())
419    }
420
421    fn rename_parameter(&self, old: &str, new: &str) -> LadduResult<()> {
422        for term in &self.terms {
423            if term.parameter_map().contains_key(new) {
424                return Err(LadduError::ParameterConflict {
425                    name: new.to_string(),
426                    reason: "rename target already exists".to_string(),
427                });
428            }
429        }
430        for term in &self.terms {
431            if term.parameter_map().contains_key(old) {
432                term.rename_parameter(old, new)?;
433            }
434        }
435        Ok(())
436    }
437
438    fn rename_parameters(&self, mapping: &HashMap<String, String>) -> LadduResult<()> {
439        for (old, new) in mapping {
440            self.rename_parameter(old, new)?;
441        }
442        Ok(())
443    }
444}
445
446#[cfg(test)]
447mod tests {
448    use std::sync::Arc;
449
450    use approx::assert_relative_eq;
451    #[cfg(feature = "mpi")]
452    use laddu_core::mpi::{finalize_mpi, get_world, use_mpi, LadduMPI};
453    use laddu_core::{
454        amplitude::{Amplitude, AmplitudeID, ExpressionDependence, Parameter},
455        data::{Dataset, DatasetMetadata, EventData},
456        parameter,
457        resources::{Cache, ParameterID, Parameters, Resources, ScalarID},
458        vectors::Vec4,
459        Expression, LadduError, LadduResult,
460    };
461    #[cfg(feature = "mpi")]
462    use mpi::topology::{Communicator, SimpleCommunicator};
463    #[cfg(feature = "mpi")]
464    use mpi_test::mpi_test;
465    use nalgebra::DVector;
466    use num::complex::Complex64;
467    use serde::{Deserialize, Serialize};
468
469    use crate::likelihood::{LikelihoodScalar, LikelihoodTerm, NLL};
470
471    const LENGTH_MISMATCH_MESSAGE_FRAGMENT: &str = "length mismatch";
472    const AMPLITUDE_NOT_FOUND_MESSAGE_FRAGMENT: &str = "No registered amplitude";
473
474    #[derive(Clone, Serialize, Deserialize)]
475    struct ConstantAmplitude {
476        name: String,
477        parameter: Parameter,
478        pid: ParameterID,
479    }
480
481    impl ConstantAmplitude {
482        #[allow(clippy::new_ret_no_self)]
483        fn new(name: &str, parameter: Parameter) -> LadduResult<Expression> {
484            Self {
485                name: name.to_string(),
486                parameter,
487                pid: ParameterID::default(),
488            }
489            .into_expression()
490        }
491    }
492
493    #[typetag::serde]
494    impl Amplitude for ConstantAmplitude {
495        fn register(&mut self, resources: &mut Resources) -> LadduResult<AmplitudeID> {
496            self.pid = resources.register_parameter(&self.parameter)?;
497            resources.register_amplitude(&self.name)
498        }
499
500        fn dependence_hint(&self) -> ExpressionDependence {
501            ExpressionDependence::ParameterOnly
502        }
503
504        fn compute(&self, parameters: &Parameters, _cache: &Cache) -> Complex64 {
505            Complex64::new(parameters.get(self.pid), 0.0)
506        }
507
508        fn compute_gradient(
509            &self,
510            parameters: &Parameters,
511            _cache: &Cache,
512            gradient: &mut DVector<Complex64>,
513        ) {
514            if let Some(index) = parameters.free_index(self.pid) {
515                gradient[index] = Complex64::ONE;
516            }
517        }
518    }
519
520    #[derive(Clone, Serialize, Deserialize)]
521    struct CachedBeamScaleAmplitude {
522        name: String,
523        parameter: Parameter,
524        pid: ParameterID,
525        sid: ScalarID,
526        p4_index: usize,
527    }
528
529    impl CachedBeamScaleAmplitude {
530        #[allow(clippy::new_ret_no_self)]
531        fn new(name: &str, parameter: Parameter, p4_index: usize) -> LadduResult<Expression> {
532            Self {
533                name: name.to_string(),
534                parameter,
535                pid: ParameterID::default(),
536                sid: ScalarID::default(),
537                p4_index,
538            }
539            .into_expression()
540        }
541    }
542
543    #[typetag::serde]
544    impl Amplitude for CachedBeamScaleAmplitude {
545        fn register(&mut self, resources: &mut Resources) -> LadduResult<AmplitudeID> {
546            self.pid = resources.register_parameter(&self.parameter)?;
547            self.sid = resources.register_scalar(Some(&format!("{}.beam_energy", self.name)));
548            resources.register_amplitude(&self.name)
549        }
550
551        fn dependence_hint(&self) -> ExpressionDependence {
552            ExpressionDependence::Mixed
553        }
554
555        fn precompute(&self, event: &laddu_core::data::Event<'_>, cache: &mut Cache) {
556            cache.store_scalar(self.sid, event.p4_at(self.p4_index).e());
557        }
558
559        fn compute(&self, parameters: &Parameters, cache: &Cache) -> Complex64 {
560            Complex64::new(parameters.get(self.pid), 0.0) * cache.get_scalar(self.sid)
561        }
562
563        fn compute_gradient(
564            &self,
565            parameters: &Parameters,
566            cache: &Cache,
567            gradient: &mut DVector<Complex64>,
568        ) {
569            if let Some(index) = parameters.free_index(self.pid) {
570                gradient[index] = Complex64::new(cache.get_scalar(self.sid), 0.0);
571            }
572        }
573    }
574
575    #[derive(Clone, Serialize, Deserialize)]
576    struct CacheOnlyBeamAmplitude {
577        name: String,
578        sid: ScalarID,
579        p4_index: usize,
580    }
581
582    impl CacheOnlyBeamAmplitude {
583        #[allow(clippy::new_ret_no_self)]
584        fn new(name: &str, p4_index: usize) -> LadduResult<Expression> {
585            Self {
586                name: name.to_string(),
587                sid: ScalarID::default(),
588                p4_index,
589            }
590            .into_expression()
591        }
592    }
593
594    #[typetag::serde]
595    impl Amplitude for CacheOnlyBeamAmplitude {
596        fn register(&mut self, resources: &mut Resources) -> LadduResult<AmplitudeID> {
597            self.sid = resources.register_scalar(Some(&format!("{}.beam_energy", self.name)));
598            resources.register_amplitude(&self.name)
599        }
600
601        fn dependence_hint(&self) -> ExpressionDependence {
602            ExpressionDependence::CacheOnly
603        }
604
605        fn precompute(&self, event: &laddu_core::data::Event<'_>, cache: &mut Cache) {
606            cache.store_scalar(self.sid, event.p4_at(self.p4_index).e());
607        }
608
609        fn compute(&self, _parameters: &Parameters, cache: &Cache) -> Complex64 {
610            Complex64::new(cache.get_scalar(self.sid), 0.0)
611        }
612    }
613
614    fn dataset_with_weights(weights: &[f64]) -> Arc<Dataset> {
615        let metadata = Arc::new(DatasetMetadata::default());
616        let events = weights
617            .iter()
618            .map(|&weight| {
619                Arc::new(EventData {
620                    p4s: vec![Vec4::new(0.0, 0.0, 0.0, 1.0)],
621                    aux: vec![],
622                    weight,
623                })
624            })
625            .collect();
626        Arc::new(Dataset::new_with_metadata(events, metadata))
627    }
628
629    fn dataset_with_two_p4_and_weights(
630        beam_energies: &[(f64, f64)],
631        weights: &[f64],
632    ) -> Arc<Dataset> {
633        assert_eq!(beam_energies.len(), weights.len());
634        let metadata = Arc::new(DatasetMetadata::default());
635        let events = beam_energies
636            .iter()
637            .zip(weights.iter())
638            .map(|(&(e0, e1), &weight)| {
639                Arc::new(EventData {
640                    p4s: vec![Vec4::new(0.0, 0.0, 0.0, e0), Vec4::new(0.0, 0.0, 0.0, e1)],
641                    aux: vec![],
642                    weight,
643                })
644            })
645            .collect();
646        Arc::new(Dataset::new_with_metadata(events, metadata))
647    }
648
649    #[cfg(feature = "mpi")]
650    fn read_resident_rss_kb() -> Option<u64> {
651        #[cfg(target_os = "linux")]
652        {
653            let status = std::fs::read_to_string("/proc/self/status").ok()?;
654            let vm_rss = status
655                .lines()
656                .find(|line| line.starts_with("VmRSS:"))?
657                .split_whitespace()
658                .nth(1)?;
659            vm_rss.parse::<u64>().ok()
660        }
661
662        #[cfg(not(target_os = "linux"))]
663        {
664            None
665        }
666    }
667
668    #[cfg(feature = "mpi")]
669    fn generated_two_p4_dataset(
670        n_events: usize,
671        base_energy: f64,
672        weight_scale: f64,
673    ) -> Arc<Dataset> {
674        let metadata = Arc::new(DatasetMetadata::default());
675        let events = (0..n_events)
676            .map(|index| {
677                let idx = index as f64;
678                let beam_e0 = base_energy + (idx % 17.0) * 0.35 + idx * 0.0025;
679                let beam_e1 = 0.5 * base_energy + (idx % 11.0) * 0.2 + idx * 0.0015;
680                let weight = 0.75 + weight_scale * (1.0 + (index % 9) as f64);
681                Arc::new(EventData {
682                    p4s: vec![
683                        Vec4::new(0.0, 0.0, 0.0, beam_e0),
684                        Vec4::new(0.0, 0.0, 0.0, beam_e1),
685                    ],
686                    aux: vec![],
687                    weight,
688                })
689            })
690            .collect();
691        Arc::new(Dataset::new_with_metadata(events, metadata))
692    }
693
694    fn make_constant_nll() -> (Box<NLL>, Vec<f64>) {
695        let amp = ConstantAmplitude::new("amp", parameter!("scale")).unwrap();
696        let expr = amp.norm_sqr();
697        let data = dataset_with_weights(&[1.0, 2.0]);
698        let mc = dataset_with_weights(&[0.5, 1.5]);
699        let nll = NLL::new(&expr, &data, &mc, None).unwrap();
700        (nll, vec![2.0])
701    }
702
703    fn make_two_parameter_nll() -> (Box<NLL>, Vec<f64>) {
704        let amp_a = ConstantAmplitude::new("amp_a", parameter!("alpha")).unwrap();
705        let amp_b = ConstantAmplitude::new("amp_b", parameter!("beta")).unwrap();
706        let expr = (amp_a + amp_b).norm_sqr();
707        let data = dataset_with_weights(&[1.0, 2.0, 3.0, 1.0]);
708        let mc = dataset_with_weights(&[0.5, 1.5, 2.5, 0.5]);
709        let nll = NLL::new(&expr, &data, &mc, None).unwrap();
710        (nll, vec![0.75, -1.25])
711    }
712
713    #[test]
714    fn nll_handles_reused_amplitudes_in_coherent_expression() {
715        let amp_a = ConstantAmplitude::new("amp_a", parameter!("alpha")).unwrap();
716        let amp_b = ConstantAmplitude::new("amp_b", parameter!("beta")).unwrap();
717
718        let coherent_plus = amp_a.clone() + amp_b.clone();
719        let coherent_minus = amp_a - amp_b;
720        let expr = coherent_plus.norm_sqr() + coherent_minus.norm_sqr();
721
722        let data = dataset_with_weights(&[1.0, 2.0, 3.0]);
723        let mc = dataset_with_weights(&[0.5, 1.5, 2.5]);
724        let params = vec![0.75, -1.25];
725
726        let evaluator = expr.load(&data).unwrap();
727        let direct_values = evaluator.evaluate(&params).unwrap();
728        assert_eq!(direct_values.len(), 3);
729
730        let nll = NLL::new(&expr, &data, &mc, None).unwrap();
731        let value = nll.evaluate(&params).unwrap();
732        assert!(value.is_finite());
733
734        let gradient = nll.evaluate_gradient(&params).unwrap();
735        assert_eq!(gradient.len(), params.len());
736        assert!(gradient.iter().all(|value| value.is_finite()));
737
738        let projection = nll.project_weights(&params, None).unwrap();
739        assert_eq!(projection.len(), mc.n_events());
740        assert!(projection.iter().all(|value| value.is_finite()));
741
742        let (_, projection_gradient) = nll.project_weights_and_gradients(&params, None).unwrap();
743        assert_eq!(projection_gradient.len(), mc.n_events());
744        assert!(projection_gradient
745            .iter()
746            .all(|gradient| gradient.iter().all(|value| value.is_finite())));
747    }
748
749    #[test]
750    fn nll_exposes_expression_and_current_compiled_expression() {
751        let (nll, _) = make_two_parameter_nll();
752
753        let expression_display = nll.expression().compiled_expression().to_string();
754        assert!(expression_display.contains("amp_a(id=0)"));
755        assert!(expression_display.contains("amp_b(id=1)"));
756
757        nll.deactivate("amp_b");
758        let compiled = nll.compiled_expression().to_string();
759        assert!(compiled.contains("amp_a(id=0)"));
760        assert!(!compiled.contains("amp_b(id=1)"));
761        assert!(!compiled.contains("const 0"));
762        assert!(!compiled.contains("+"));
763    }
764
765    #[test]
766    fn stochastic_nll_exposes_expression_and_current_compiled_expression() {
767        let (nll, _) = make_two_parameter_nll();
768        let stochastic = nll
769            .to_stochastic(2, Some(0))
770            .expect("stochastic NLL should build");
771
772        assert!(stochastic
773            .expression()
774            .compiled_expression()
775            .to_string()
776            .contains("amp_a(id=0)"));
777        assert!(stochastic
778            .compiled_expression()
779            .to_string()
780            .contains("amp_b(id=1)"));
781    }
782
783    #[derive(Clone, Copy)]
784    enum DeterministicModelKind {
785        Separable,
786        Partial,
787        NonSeparable,
788    }
789
790    struct DeterministicNllFixture {
791        nll: Box<NLL>,
792        parameters: Vec<f64>,
793    }
794
795    const DETERMINISTIC_STRICT_ABS_TOL: f64 = 1e-12;
796    const DETERMINISTIC_STRICT_REL_TOL: f64 = 1e-10;
797
798    fn assert_nll_fixture_matches_weighted_baseline(fixture: &DeterministicNllFixture) {
799        let expected_value = crate::likelihood::nll::evaluate_weighted_expression_sum_local(
800            &fixture.nll.data_evaluator,
801            &fixture.parameters,
802            |l| f64::ln(l.re),
803        )
804        .expect("evaluate should succeed");
805        let expected_mc_term = fixture
806            .nll
807            .accmc_evaluator
808            .evaluate_weighted_value_sum_local(&fixture.parameters)
809            .expect("evaluate should succeed");
810        let expected_value = -2.0 * (expected_value - expected_mc_term / fixture.nll.n_mc);
811
812        let expected_data_gradient = fixture
813            .nll
814            .evaluate_data_gradient_term_local(&fixture.parameters)
815            .expect("evaluate should succeed");
816        let expected_mc_gradient = fixture
817            .nll
818            .accmc_evaluator
819            .evaluate_weighted_gradient_sum_local(&fixture.parameters)
820            .expect("evaluate should succeed");
821        let expected_gradient =
822            -2.0 * (expected_data_gradient - expected_mc_gradient / fixture.nll.n_mc);
823
824        let actual_value = fixture
825            .nll
826            .evaluate_local(&fixture.parameters)
827            .expect("evaluate should succeed");
828        assert_relative_eq!(
829            actual_value,
830            expected_value,
831            epsilon = DETERMINISTIC_STRICT_ABS_TOL,
832            max_relative = DETERMINISTIC_STRICT_REL_TOL
833        );
834
835        let actual_gradient = fixture
836            .nll
837            .evaluate_gradient_local(&fixture.parameters)
838            .expect("evaluate should succeed");
839        assert_eq!(
840            actual_gradient.len(),
841            expected_gradient.len(),
842            "fixture NLL gradient length mismatch (actual={}, expected={})",
843            actual_gradient.len(),
844            expected_gradient.len()
845        );
846        for (actual_item, expected_item) in actual_gradient.iter().zip(expected_gradient.iter()) {
847            assert_relative_eq!(
848                *actual_item,
849                *expected_item,
850                epsilon = DETERMINISTIC_STRICT_ABS_TOL,
851                max_relative = DETERMINISTIC_STRICT_REL_TOL
852            );
853        }
854    }
855
856    #[cfg(feature = "mpi")]
857    fn assert_nll_fixture_matches_mpi_reduced_baseline(
858        fixture: &DeterministicNllFixture,
859        world: &SimpleCommunicator,
860    ) {
861        let data_term_local = crate::likelihood::nll::evaluate_weighted_expression_sum_local(
862            &fixture.nll.data_evaluator,
863            &fixture.parameters,
864            |l| f64::ln(l.re),
865        )
866        .expect("evaluate should succeed");
867        let mc_term_local = fixture
868            .nll
869            .accmc_evaluator
870            .evaluate_weighted_value_sum_local(&fixture.parameters)
871            .expect("evaluate should succeed");
872        let data_term = crate::likelihood::nll::reduce_scalar(world, data_term_local);
873        let mc_term = crate::likelihood::nll::reduce_scalar(world, mc_term_local);
874        let expected_value = -2.0 * (data_term - mc_term / fixture.nll.n_mc);
875        let mpi_value = fixture
876            .nll
877            .evaluate_mpi(&fixture.parameters, world)
878            .expect("evaluate should succeed");
879        assert_relative_eq!(
880            mpi_value,
881            expected_value,
882            epsilon = DETERMINISTIC_STRICT_ABS_TOL,
883            max_relative = DETERMINISTIC_STRICT_REL_TOL
884        );
885
886        let data_gradient_local = fixture
887            .nll
888            .evaluate_data_gradient_term_local(&fixture.parameters)
889            .expect("evaluate should succeed");
890        let mc_gradient_local = fixture
891            .nll
892            .accmc_evaluator
893            .evaluate_weighted_gradient_sum_local(&fixture.parameters)
894            .expect("evaluate should succeed");
895        let data_gradient = crate::likelihood::nll::reduce_gradient(world, &data_gradient_local);
896        let mc_gradient = crate::likelihood::nll::reduce_gradient(world, &mc_gradient_local);
897        let expected_gradient = -2.0 * (data_gradient - mc_gradient / fixture.nll.n_mc);
898        let mpi_gradient = fixture
899            .nll
900            .evaluate_gradient_mpi(&fixture.parameters, world)
901            .expect("evaluate should succeed");
902        assert_eq!(
903            mpi_gradient.len(),
904            expected_gradient.len(),
905            "fixture MPI gradient length mismatch (actual={}, expected={})",
906            mpi_gradient.len(),
907            expected_gradient.len()
908        );
909        for (actual_item, expected_item) in mpi_gradient.iter().zip(expected_gradient.iter()) {
910            assert_relative_eq!(
911                *actual_item,
912                *expected_item,
913                epsilon = DETERMINISTIC_STRICT_ABS_TOL,
914                max_relative = DETERMINISTIC_STRICT_REL_TOL
915            );
916        }
917    }
918
919    fn make_deterministic_nll_fixture(kind: DeterministicModelKind) -> DeterministicNllFixture {
920        let data = dataset_with_two_p4_and_weights(
921            &[
922                (1.0, 0.8),
923                (2.5, 1.7),
924                (4.0, 2.4),
925                (3.3, 1.1),
926                (5.2, 2.8),
927                (1.7, 0.9),
928            ],
929            &[0.7, 1.2, 0.9, 1.5, 0.8, 1.1],
930        );
931        let mc = dataset_with_two_p4_and_weights(
932            &[
933                (1.5, 1.0),
934                (3.0, 2.1),
935                (5.5, 2.9),
936                (2.0, 1.2),
937                (4.2, 1.8),
938                (2.8, 1.4),
939            ],
940            &[0.8, 1.4, 0.6, 1.1, 0.75, 1.25],
941        );
942
943        match kind {
944            DeterministicModelKind::Separable => {
945                let p1 = ConstantAmplitude::new("p1", parameter!("p1"))
946                    .expect("separable p1 should build");
947                let p2 = ConstantAmplitude::new("p2", parameter!("p2"))
948                    .expect("separable p2 should build");
949                let c1 = CacheOnlyBeamAmplitude::new("c1", 0).expect("separable c1 should build");
950                let c2 = CacheOnlyBeamAmplitude::new("c2", 1).expect("separable c2 should build");
951                let expression = (&p1 * &c1) + &(&p2 * &c2);
952                DeterministicNllFixture {
953                    nll: NLL::new(&expression, &data, &mc, None)
954                        .expect("separable NLL should build"),
955                    parameters: vec![0.4, 0.2],
956                }
957            }
958            DeterministicModelKind::Partial => {
959                let p =
960                    ConstantAmplitude::new("p", parameter!("p")).expect("partial p should build");
961                let c = CacheOnlyBeamAmplitude::new("c", 0).expect("partial c should build");
962                let m = CachedBeamScaleAmplitude::new("m", parameter!("m"), 1)
963                    .expect("partial m should build");
964                let expression = (&p * &c) + &m;
965                DeterministicNllFixture {
966                    nll: NLL::new(&expression, &data, &mc, None).expect("partial NLL should build"),
967                    parameters: vec![0.35, 0.25],
968                }
969            }
970            DeterministicModelKind::NonSeparable => {
971                let m1 = CachedBeamScaleAmplitude::new("m1", parameter!("m1"), 0)
972                    .expect("non-separable m1 should build");
973                let m2 = CachedBeamScaleAmplitude::new("m2", parameter!("m2"), 1)
974                    .expect("non-separable m2 should build");
975                let expression = &m1 * &m2;
976                DeterministicNllFixture {
977                    nll: NLL::new(&expression, &data, &mc, None)
978                        .expect("non-separable NLL should build"),
979                    parameters: vec![0.2, 0.15],
980                }
981            }
982        }
983    }
984
985    #[cfg(feature = "mpi")]
986    fn make_mixed_workload_nll_fixture(n_events: usize) -> DeterministicNllFixture {
987        let data = generated_two_p4_dataset(n_events, 1.4, 0.08);
988        let mc = generated_two_p4_dataset(n_events, 1.9, 0.11);
989        let p =
990            ConstantAmplitude::new("p", parameter!("p")).expect("mixed-workload p should build");
991        let c = CacheOnlyBeamAmplitude::new("c", 0)
992            .expect("mixed-workload cache amplitude should build");
993        let m = CachedBeamScaleAmplitude::new("m", parameter!("m"), 1)
994            .expect("mixed-workload beam amplitude should build");
995        let expression = (&p * &c) + &m;
996        DeterministicNllFixture {
997            nll: NLL::new(&expression, &data, &mc, None).expect("mixed-workload NLL should build"),
998            parameters: vec![0.35, 0.25],
999        }
1000    }
1001
1002    fn case_nll_evaluate_short(nll: &NLL) -> LadduResult<()> {
1003        nll.evaluate(&[]).map(|_| ())
1004    }
1005
1006    fn case_nll_evaluate_gradient_long(nll: &NLL) -> LadduResult<()> {
1007        nll.evaluate_gradient(&[1.0, 2.0]).map(|_| ())
1008    }
1009
1010    fn case_nll_project_short(nll: &NLL) -> LadduResult<()> {
1011        nll.project_weights(&[], None).map(|_| ())
1012    }
1013
1014    fn case_nll_project_weights_and_gradients_long(nll: &NLL) -> LadduResult<()> {
1015        nll.project_weights_and_gradients(&[1.0, 2.0], None)
1016            .map(|_| ())
1017    }
1018
1019    fn case_nll_project_weights_subset_short(nll: &NLL) -> LadduResult<()> {
1020        nll.project_weights_subset_local::<&str>(&[], &["missing_amplitude"], None)
1021            .map(|_| ())
1022    }
1023
1024    fn case_nll_project_weights_and_gradients_subset_long(nll: &NLL) -> LadduResult<()> {
1025        nll.project_weights_and_gradients_subset_local::<&str>(
1026            &[1.0, 2.0],
1027            &["missing_amplitude"],
1028            None,
1029        )
1030        .map(|_| ())
1031    }
1032
1033    fn case_likelihood_evaluate_short() -> LadduResult<()> {
1034        let alpha = LikelihoodScalar::new("alpha")?;
1035        alpha.evaluate(&[]).map(|_| ())
1036    }
1037
1038    fn case_likelihood_gradient_long() -> LadduResult<()> {
1039        let alpha = LikelihoodScalar::new("alpha")?;
1040        alpha.evaluate_gradient(&[1.0, 2.0]).map(|_| ())
1041    }
1042
1043    #[test]
1044    fn table_driven_length_mismatch_errors() {
1045        let (nll, _) = make_constant_nll();
1046        let cases: [(&str, LadduResult<()>); 8] = [
1047            ("nll.evaluate short", case_nll_evaluate_short(nll.as_ref())),
1048            (
1049                "nll.evaluate_gradient long",
1050                case_nll_evaluate_gradient_long(nll.as_ref()),
1051            ),
1052            (
1053                "nll.project_weights short",
1054                case_nll_project_short(nll.as_ref()),
1055            ),
1056            (
1057                "nll.project_weights_and_gradients long",
1058                case_nll_project_weights_and_gradients_long(nll.as_ref()),
1059            ),
1060            (
1061                "nll.project_weights_subset short",
1062                case_nll_project_weights_subset_short(nll.as_ref()),
1063            ),
1064            (
1065                "nll.project_weights_and_gradients_subset long",
1066                case_nll_project_weights_and_gradients_subset_long(nll.as_ref()),
1067            ),
1068            (
1069                "likelihood.evaluate short",
1070                case_likelihood_evaluate_short(),
1071            ),
1072            (
1073                "likelihood.evaluate_gradient long",
1074                case_likelihood_gradient_long(),
1075            ),
1076        ];
1077        for (label, result) in cases {
1078            let err = result.unwrap_err();
1079            assert!(
1080                matches!(err, LadduError::LengthMismatch { .. }),
1081                "expected LengthMismatch for {label}, got {err:?}"
1082            );
1083            assert!(
1084                err.to_string().contains(LENGTH_MISMATCH_MESSAGE_FRAGMENT),
1085                "expected message containing \"{LENGTH_MISMATCH_MESSAGE_FRAGMENT}\" for {label}, got {}",
1086                err
1087            );
1088        }
1089    }
1090
1091    #[test]
1092    fn table_driven_unknown_amplitude_errors() {
1093        let (nll, params) = make_constant_nll();
1094        let cases: [(&str, LadduResult<()>); 4] = [
1095            (
1096                "activate_strict unknown",
1097                nll.activate_strict("missing_amplitude"),
1098            ),
1099            (
1100                "isolate_strict unknown",
1101                nll.isolate_strict("missing_amplitude"),
1102            ),
1103            (
1104                "project_weights_subset unknown",
1105                nll.project_weights_subset_local_strict::<&str>(
1106                    &params,
1107                    &["missing_amplitude"],
1108                    None,
1109                )
1110                .map(|_| ()),
1111            ),
1112            (
1113                "project_weights_and_gradients_subset unknown",
1114                nll.project_weights_and_gradients_subset_local_strict::<&str>(
1115                    &params,
1116                    &["missing_amplitude"],
1117                    None,
1118                )
1119                .map(|_| ()),
1120            ),
1121        ];
1122        for (label, result) in cases {
1123            let err = result.unwrap_err();
1124            assert!(
1125                matches!(err, LadduError::AmplitudeNotFoundError { .. }),
1126                "expected AmplitudeNotFoundError for {label}, got {err:?}"
1127            );
1128            assert!(
1129                err.to_string()
1130                    .contains(AMPLITUDE_NOT_FOUND_MESSAGE_FRAGMENT),
1131                "expected message containing \"{AMPLITUDE_NOT_FOUND_MESSAGE_FRAGMENT}\" for {label}, got {}",
1132                err
1133            );
1134        }
1135    }
1136
1137    #[test]
1138    fn likelihood_expression_evaluates_scalar_sum() {
1139        let alpha = LikelihoodScalar::new("alpha").unwrap();
1140        let beta = LikelihoodScalar::new("beta").unwrap();
1141        let expr = &alpha + &beta;
1142        assert_eq!(expr.parameters().names(), vec!["alpha", "beta"]);
1143        let params = vec![2.0, 3.0];
1144        assert_relative_eq!(expr.evaluate(&params).unwrap(), 5.0);
1145        let grad = expr.evaluate_gradient(&params).unwrap();
1146        assert_relative_eq!(grad[0], 1.0);
1147        assert_relative_eq!(grad[1], 1.0);
1148    }
1149
1150    #[test]
1151    fn likelihood_expression_evaluates_scalar_product() {
1152        let alpha = LikelihoodScalar::new("alpha").unwrap();
1153        let beta = LikelihoodScalar::new("beta").unwrap();
1154        let expr = &alpha * &beta;
1155        let params = vec![2.0, 3.0];
1156        assert_relative_eq!(expr.evaluate(&params).unwrap(), 6.0);
1157        let grad = expr.evaluate_gradient(&params).unwrap();
1158        assert_relative_eq!(grad[0], 3.0);
1159        assert_relative_eq!(grad[1], 2.0);
1160    }
1161
1162    #[test]
1163    fn likelihood_expression_tracks_fixed_parameters() {
1164        let alpha = LikelihoodScalar::new("alpha").unwrap();
1165        let beta = LikelihoodScalar::new("beta").unwrap();
1166        let expr = &alpha + &beta;
1167        expr.fix_parameter("alpha", 1.5).unwrap();
1168        assert_eq!(expr.parameters().names(), vec!["alpha", "beta"]);
1169        assert_eq!(expr.parameters().free().names(), vec!["beta"]);
1170        assert_eq!(expr.parameters().fixed().names(), vec!["alpha"]);
1171        let params_free = vec![2.0];
1172        assert_relative_eq!(expr.evaluate(&params_free).unwrap(), 3.5);
1173        let grad_free = expr.evaluate_gradient(&params_free).unwrap();
1174        assert_eq!(grad_free.len(), 1);
1175        assert_relative_eq!(grad_free[0], 1.0);
1176    }
1177
1178    #[test]
1179    fn likelihood_expression_handles_term_local_fixed_parameters() {
1180        let alpha = LikelihoodScalar::new("alpha").unwrap();
1181        alpha.fix_parameter("alpha", 1.5).unwrap();
1182        let beta = LikelihoodScalar::new("beta").unwrap();
1183        let expr = &alpha + &beta;
1184        assert_eq!(expr.parameters().names(), vec!["alpha", "beta"]);
1185        assert_eq!(expr.parameters().free().names(), vec!["beta"]);
1186        assert_eq!(expr.parameters().fixed().names(), vec!["alpha"]);
1187
1188        let params_free = vec![2.0];
1189        assert_relative_eq!(expr.evaluate(&params_free).unwrap(), 3.5);
1190        let grad_free = expr.evaluate_gradient(&params_free).unwrap();
1191        assert_eq!(grad_free.len(), 1);
1192        assert_relative_eq!(grad_free[0], 1.0);
1193    }
1194
1195    #[test]
1196    fn likelihood_product_handles_term_local_fixed_parameters() {
1197        let alpha = LikelihoodScalar::new("alpha").unwrap();
1198        alpha.fix_parameter("alpha", 1.5).unwrap();
1199        let beta = LikelihoodScalar::new("beta").unwrap();
1200        let expr = &alpha * &beta;
1201        assert_eq!(expr.parameters().names(), vec!["alpha", "beta"]);
1202        assert_eq!(expr.parameters().free().names(), vec!["beta"]);
1203        assert_eq!(expr.parameters().fixed().names(), vec!["alpha"]);
1204
1205        let params_free = vec![2.0];
1206        assert_relative_eq!(expr.evaluate(&params_free).unwrap(), 3.0);
1207        let grad_free = expr.evaluate_gradient(&params_free).unwrap();
1208        assert_eq!(grad_free.len(), 1);
1209        assert_relative_eq!(grad_free[0], 1.5);
1210    }
1211
1212    #[test]
1213    fn nll_evaluate_and_gradient_match_closed_form() {
1214        let (nll, params) = make_constant_nll();
1215        let intensity = params[0] * params[0];
1216        let weight_sum = 3.0;
1217        let expected = -2.0 * (weight_sum * intensity.ln() - intensity);
1218        assert_relative_eq!(nll.evaluate(&params).unwrap(), expected, epsilon = 1e-12);
1219        let grad = nll.evaluate_gradient(&params).unwrap();
1220        let expected_grad = -4.0 * (weight_sum / params[0] - params[0]);
1221        assert_relative_eq!(grad[0], expected_grad, epsilon = 1e-12);
1222    }
1223
1224    #[cfg(feature = "rayon")]
1225    #[test]
1226    fn gradient_scratch_reuse_is_thread_safe_across_parallel_calls() {
1227        let (nll_single, params_single) = make_constant_nll();
1228        let (nll_multi, params_multi) = make_two_parameter_nll();
1229        let nll_single = Arc::new(*nll_single);
1230        let nll_multi = Arc::new(*nll_multi);
1231        let expected_single = nll_single
1232            .evaluate_gradient(&params_single)
1233            .expect("single-parameter gradient should evaluate");
1234        let expected_multi = nll_multi
1235            .evaluate_gradient(&params_multi)
1236            .expect("two-parameter gradient should evaluate");
1237        std::thread::scope(|scope| {
1238            for _ in 0..8 {
1239                let nll_single = Arc::clone(&nll_single);
1240                let nll_multi = Arc::clone(&nll_multi);
1241                let params_single = params_single.clone();
1242                let params_multi = params_multi.clone();
1243                let expected_single = expected_single.clone();
1244                let expected_multi = expected_multi.clone();
1245                scope.spawn(move || {
1246                    for _ in 0..100 {
1247                        let single_gradient = nll_single
1248                            .evaluate_gradient(&params_single)
1249                            .expect("single-parameter gradient should evaluate");
1250                        assert_relative_eq!(
1251                            single_gradient[0],
1252                            expected_single[0],
1253                            epsilon = 1e-12
1254                        );
1255                        let multi_gradient = nll_multi
1256                            .evaluate_gradient(&params_multi)
1257                            .expect("two-parameter gradient should evaluate");
1258                        assert_eq!(multi_gradient.len(), expected_multi.len());
1259                        for index in 0..expected_multi.len() {
1260                            assert_relative_eq!(
1261                                multi_gradient[index],
1262                                expected_multi[index],
1263                                epsilon = 1e-12
1264                            );
1265                        }
1266                    }
1267                });
1268            }
1269        });
1270    }
1271
1272    #[test]
1273    fn nll_value_matches_mixed_scale_weighted_closed_form() {
1274        let amp = ConstantAmplitude::new("amp", parameter!("scale")).unwrap();
1275        let expr = amp.norm_sqr();
1276        let data = dataset_with_weights(&[1.0e12, 1.0e-12, 3.5, 7.25e4, 2.0e-3]);
1277        let mc = dataset_with_weights(&[4.0e9, 9.0e-6, 1.25, 2.5e2, 8.0e-4]);
1278        let nll = NLL::new(&expr, &data, &mc, None).unwrap();
1279        let params = vec![1.125];
1280
1281        let intensity: f64 = params[0] * params[0];
1282        let data_weight_sum = data.weights_local().iter().copied().sum::<f64>();
1283        let mc_weight_sum = mc.weights_local().iter().copied().sum::<f64>();
1284        let n_mc = mc.n_events_weighted();
1285        let expected = -2.0 * (data_weight_sum * intensity.ln() - mc_weight_sum * intensity / n_mc);
1286
1287        let value = nll.evaluate(&params).unwrap();
1288        assert_relative_eq!(value, expected, epsilon = 1e-9, max_relative = 1e-12);
1289    }
1290
1291    #[test]
1292    fn nll_evaluate_and_gradient_match_hardcoded_weighted_reference() {
1293        let amp_a = CachedBeamScaleAmplitude::new("amp_a", parameter!("alpha"), 0).unwrap();
1294        let amp_b = CachedBeamScaleAmplitude::new("amp_b", parameter!("beta"), 1).unwrap();
1295        let expr = (&amp_a + &amp_b).norm_sqr();
1296        let data = dataset_with_two_p4_and_weights(
1297            &[(1.0, 0.8), (2.5, 1.7), (4.0, 2.4), (3.3, 1.1)],
1298            &[0.7, 1.2, 0.9, 1.5],
1299        );
1300        let mc = dataset_with_two_p4_and_weights(
1301            &[(1.5, 1.0), (3.0, 2.1), (5.5, 2.9), (2.0, 1.2), (4.2, 1.8)],
1302            &[0.8, 1.4, 0.6, 1.1, 0.75],
1303        );
1304        let nll = NLL::new(&expr, &data, &mc, None).unwrap();
1305        let params = vec![0.6, 1.1];
1306        assert_eq!(nll.parameters().free().names(), vec!["alpha", "beta"]);
1307
1308        let value = nll.evaluate(&params).unwrap();
1309        assert_relative_eq!(value, 12.242296380697244, epsilon = 1e-12);
1310
1311        let gradient = nll.evaluate_gradient(&params).unwrap();
1312        assert_eq!(gradient.len(), 2);
1313        assert_relative_eq!(gradient[0], 37.78259267741666, epsilon = 1e-12);
1314        assert_relative_eq!(gradient[1], 21.8538272590435, epsilon = 1e-12);
1315    }
1316
1317    #[test]
1318    fn nll_deterministic_fixtures_cover_separable_partial_and_non_separable_models() {
1319        let separable = make_deterministic_nll_fixture(DeterministicModelKind::Separable);
1320        let partial = make_deterministic_nll_fixture(DeterministicModelKind::Partial);
1321        let non_separable = make_deterministic_nll_fixture(DeterministicModelKind::NonSeparable);
1322
1323        for fixture in [separable, partial, non_separable] {
1324            assert_nll_fixture_matches_weighted_baseline(&fixture);
1325        }
1326    }
1327
1328    #[test]
1329    fn nll_deterministic_fixture_matches_baseline_across_activation_toggles() {
1330        let fixture = make_deterministic_nll_fixture(DeterministicModelKind::Partial);
1331        assert_nll_fixture_matches_weighted_baseline(&fixture);
1332
1333        fixture.nll.isolate_many(&["p", "c"]);
1334        assert_nll_fixture_matches_weighted_baseline(&fixture);
1335
1336        fixture.nll.activate_all();
1337        assert_nll_fixture_matches_weighted_baseline(&fixture);
1338    }
1339
1340    #[test]
1341    fn nll_project_returns_weighted_intensity() {
1342        let (nll, params) = make_constant_nll();
1343        let projection = nll.project_weights_local(&params, None).unwrap();
1344        assert_relative_eq!(projection[0], 1.0, epsilon = 1e-12);
1345        assert_relative_eq!(projection[1], 3.0, epsilon = 1e-12);
1346    }
1347
1348    #[test]
1349    fn nll_project_reports_structured_length_error() {
1350        let (nll, _) = make_constant_nll();
1351        let err = nll.project_weights(&[], None).unwrap_err();
1352        assert!(matches!(
1353            err,
1354            LadduError::LengthMismatch {
1355                expected: 1,
1356                actual: 0,
1357                ..
1358            }
1359        ));
1360    }
1361
1362    #[test]
1363    fn nll_project_weights_subset_reports_structured_missing_amplitude_error() {
1364        let (nll, params) = make_constant_nll();
1365        let err = nll
1366            .project_weights_subset_local_strict::<&str>(&params, &["missing_amplitude"], None)
1367            .unwrap_err();
1368        assert!(matches!(err, LadduError::AmplitudeNotFoundError { .. }));
1369    }
1370
1371    #[test]
1372    fn nll_project_weights_subsets_matches_repeated_project_weights_subset_calls() {
1373        let (nll, params) = make_two_parameter_nll();
1374        let subsets = vec![
1375            vec!["amp_a".to_string()],
1376            vec!["amp_b".to_string()],
1377            vec!["amp_a".to_string(), "amp_b".to_string()],
1378        ];
1379        let batched = nll
1380            .project_weights_subsets_local(&params, &subsets, None)
1381            .expect("batched projection should evaluate");
1382        let repeated = subsets
1383            .iter()
1384            .map(|subset| {
1385                nll.project_weights_subset_local(&params, subset, None)
1386                    .expect("single subset projection should evaluate")
1387            })
1388            .collect::<Vec<_>>();
1389        assert_eq!(batched.len(), repeated.len());
1390        for (lhs, rhs) in batched.iter().zip(repeated.iter()) {
1391            assert_eq!(lhs.len(), rhs.len());
1392            for (lhs_value, rhs_value) in lhs.iter().zip(rhs.iter()) {
1393                assert_relative_eq!(lhs_value, rhs_value, epsilon = 1e-12);
1394            }
1395        }
1396    }
1397
1398    #[test]
1399    fn nll_project_weights_subsets_handles_empty_and_duplicate_subsets() {
1400        let (nll, params) = make_two_parameter_nll();
1401        let empty: Vec<Vec<String>> = Vec::new();
1402        let empty_projection = nll
1403            .project_weights_subsets_local(&params, &empty, None)
1404            .expect("empty subset list should evaluate");
1405        assert!(empty_projection.is_empty());
1406
1407        let subsets = vec![
1408            vec!["amp_b".to_string()],
1409            vec!["amp_a".to_string()],
1410            vec!["amp_a".to_string(), "amp_b".to_string()],
1411            vec!["amp_a".to_string()],
1412            vec!["amp_b".to_string()],
1413        ];
1414        let batched = nll
1415            .project_weights_subsets_local(&params, &subsets, None)
1416            .expect("batched projection should evaluate");
1417        let repeated = subsets
1418            .iter()
1419            .map(|subset| {
1420                nll.project_weights_subset_local(&params, subset, None)
1421                    .expect("single subset projection should evaluate")
1422            })
1423            .collect::<Vec<_>>();
1424        assert_eq!(batched.len(), repeated.len());
1425        for (lhs, rhs) in batched.iter().zip(repeated.iter()) {
1426            assert_eq!(lhs.len(), rhs.len());
1427            for (lhs_value, rhs_value) in lhs.iter().zip(rhs.iter()) {
1428                assert_relative_eq!(lhs_value, rhs_value, epsilon = 1e-12);
1429            }
1430        }
1431    }
1432
1433    #[test]
1434    fn nll_project_weights_subsets_reports_missing_amplitude_error() {
1435        let (nll, params) = make_two_parameter_nll();
1436        let subsets = vec![vec!["amp_a".to_string()], vec!["missing".to_string()]];
1437        let err = nll
1438            .project_weights_subsets_local_strict(&params, &subsets, None)
1439            .expect_err("missing amplitude should fail");
1440        assert!(matches!(err, LadduError::AmplitudeNotFoundError { .. }));
1441    }
1442
1443    #[test]
1444    fn nll_project_weights_and_gradients_subset_matches_repeated_calls() {
1445        let (nll, params) = make_two_parameter_nll();
1446        let subsets = vec![
1447            vec!["amp_b".to_string()],
1448            vec!["amp_a".to_string()],
1449            vec!["amp_a".to_string(), "amp_b".to_string()],
1450            vec!["amp_a".to_string()],
1451        ];
1452        for subset in subsets {
1453            let (weights_local, gradients_local) = nll
1454                .project_weights_and_gradients_subset_local(&params, &subset, None)
1455                .expect("local gradient projection should evaluate");
1456            let (weights_auto, gradients_auto) = nll
1457                .project_weights_and_gradients_subset(&params, &subset, None)
1458                .expect("auto gradient projection should evaluate");
1459            assert_eq!(weights_local.len(), weights_auto.len());
1460            assert_eq!(gradients_local.len(), gradients_auto.len());
1461            for (lhs, rhs) in weights_local.iter().zip(weights_auto.iter()) {
1462                assert_relative_eq!(lhs, rhs, epsilon = 1e-12);
1463            }
1464            for (lhs, rhs) in gradients_local.iter().zip(gradients_auto.iter()) {
1465                assert_relative_eq!(lhs, rhs, epsilon = 1e-12);
1466            }
1467        }
1468    }
1469
1470    #[test]
1471    fn nll_activation_changes_invalidate_projection_mask_cache() {
1472        let (nll, params) = make_constant_nll();
1473        assert!(nll.projection_active_mask_cache.lock().is_empty());
1474
1475        let _ = nll
1476            .project_weights_subset_local::<&str>(&params, &["amp"], None)
1477            .unwrap();
1478        assert!(!nll.projection_active_mask_cache.lock().is_empty());
1479
1480        nll.deactivate("amp");
1481        assert!(nll.projection_active_mask_cache.lock().is_empty());
1482
1483        let projection = nll
1484            .project_weights_subset_local::<&str>(&params, &["amp"], None)
1485            .unwrap();
1486        assert_relative_eq!(projection[0], 1.0, epsilon = 1e-12);
1487        assert_relative_eq!(projection[1], 3.0, epsilon = 1e-12);
1488    }
1489
1490    #[test]
1491    fn nll_project_weights_subset_validates_length_before_isolation() {
1492        let (nll, _) = make_constant_nll();
1493        let err = nll
1494            .project_weights_subset_local::<&str>(&[], &["missing_amplitude"], None)
1495            .unwrap_err();
1496        assert!(matches!(
1497            err,
1498            LadduError::LengthMismatch {
1499                expected: 1,
1500                actual: 0,
1501                ..
1502            }
1503        ));
1504    }
1505
1506    #[test]
1507    fn nll_project_weights_and_gradients_subset_validates_length_before_isolation() {
1508        let (nll, _) = make_constant_nll();
1509        let err = nll
1510            .project_weights_and_gradients_subset_local::<&str>(
1511                &[1.0, 2.0],
1512                &["missing_amplitude"],
1513                None,
1514            )
1515            .unwrap_err();
1516        assert!(matches!(
1517            err,
1518            LadduError::LengthMismatch {
1519                expected: 1,
1520                actual: 2,
1521                ..
1522            }
1523        ));
1524    }
1525
1526    #[test]
1527    fn stochastic_nll_validates_batch_size() {
1528        let (nll, _params) = make_constant_nll();
1529        let err_zero = match nll.to_stochastic(0, Some(0)) {
1530            Ok(_) => panic!("expected batch_size=0 to return an error"),
1531            Err(err) => err,
1532        };
1533        assert!(matches!(
1534            err_zero,
1535            LadduError::LengthMismatch {
1536                expected: 2,
1537                actual: 0,
1538                ..
1539            }
1540        ));
1541
1542        let err_large = match nll.to_stochastic(3, Some(0)) {
1543            Ok(_) => panic!("expected oversized batch to return an error"),
1544            Err(err) => err,
1545        };
1546        assert!(matches!(
1547            err_large,
1548            LadduError::LengthMismatch {
1549                expected: 2,
1550                actual: 3,
1551                ..
1552            }
1553        ));
1554    }
1555
1556    #[test]
1557    fn stochastic_nll_accepts_full_dataset_batch() {
1558        let (nll, params) = make_constant_nll();
1559        let stochastic = nll.to_stochastic(2, Some(0)).unwrap();
1560        let value = stochastic.evaluate(&params).unwrap();
1561        assert!(value.is_finite());
1562    }
1563
1564    #[test]
1565    fn stochastic_nll_matches_closed_form_on_full_batch() {
1566        let (nll, params) = make_constant_nll();
1567        let stochastic = nll
1568            .to_stochastic(nll.data_evaluator.dataset.n_events(), Some(0))
1569            .unwrap();
1570        let stochastic_value = stochastic.evaluate(&params).unwrap();
1571        let deterministic_value = nll.evaluate(&params).unwrap();
1572        assert_relative_eq!(stochastic_value, deterministic_value, epsilon = 1e-12);
1573    }
1574
1575    #[test]
1576    fn likelihood_evaluator_reports_length_mismatch() {
1577        let alpha = LikelihoodScalar::new("alpha").unwrap();
1578
1579        let err_short = alpha.evaluate(&[]).unwrap_err();
1580        assert!(matches!(
1581            err_short,
1582            LadduError::LengthMismatch {
1583                expected: 1,
1584                actual: 0,
1585                ..
1586            }
1587        ));
1588
1589        let err_long = alpha.evaluate_gradient(&[1.0, 2.0]).unwrap_err();
1590        assert!(matches!(
1591            err_long,
1592            LadduError::LengthMismatch {
1593                expected: 1,
1594                actual: 2,
1595                ..
1596            }
1597        ));
1598    }
1599
1600    #[cfg(feature = "mpi")]
1601    #[mpi_test(np = [2])]
1602    fn mpi_negative_paths_report_structured_errors() {
1603        use_mpi(true);
1604        let world = get_world().expect("MPI world should be initialized");
1605        let (nll, params) = make_constant_nll();
1606
1607        let err_len = nll.project_weights_mpi(&[], None, &world).unwrap_err();
1608        assert!(matches!(
1609            err_len,
1610            LadduError::LengthMismatch {
1611                expected: 1,
1612                actual: 0,
1613                ..
1614            }
1615        ));
1616
1617        let err_amp = nll
1618            .project_weights_subset_mpi_strict::<&str>(
1619                &params,
1620                &["missing_amplitude"],
1621                None,
1622                &world,
1623            )
1624            .unwrap_err();
1625        assert!(matches!(err_amp, LadduError::AmplitudeNotFoundError { .. }));
1626        finalize_mpi();
1627    }
1628
1629    #[cfg(feature = "mpi")]
1630    #[mpi_test(np = [2])]
1631    fn mpi_value_and_gradient_match_total_non_mpi() {
1632        use_mpi(true);
1633        let world = get_world().expect("MPI world should be initialized");
1634        let (nll, params) = make_constant_nll();
1635        let data_term_local = crate::likelihood::nll::evaluate_weighted_expression_sum_local(
1636            &nll.data_evaluator,
1637            &params,
1638            |l| f64::ln(l.re),
1639        )
1640        .expect("evaluate should succeed");
1641        let mc_term_local = nll
1642            .accmc_evaluator
1643            .evaluate_weighted_value_sum_local(&params)
1644            .expect("evaluate should succeed");
1645        let data_term = crate::likelihood::nll::reduce_scalar(&world, data_term_local);
1646        let mc_term = crate::likelihood::nll::reduce_scalar(&world, mc_term_local);
1647        let expected_value = -2.0 * (data_term - mc_term / nll.n_mc);
1648
1649        let mpi_value = nll
1650            .evaluate_mpi(&params, &world)
1651            .expect("evaluate should succeed");
1652        assert_relative_eq!(mpi_value, expected_value);
1653
1654        let data_gradient_local = nll
1655            .evaluate_data_gradient_term_local(&params)
1656            .expect("evaluate should succeed");
1657        let mc_gradient_local = nll
1658            .accmc_evaluator
1659            .evaluate_weighted_gradient_sum_local(&params)
1660            .expect("evaluate should succeed");
1661        let data_gradient = crate::likelihood::nll::reduce_gradient(&world, &data_gradient_local);
1662        let mc_gradient = crate::likelihood::nll::reduce_gradient(&world, &mc_gradient_local);
1663        let expected_gradient = -2.0 * (data_gradient - mc_gradient / nll.n_mc);
1664        let mpi_gradient = nll
1665            .evaluate_gradient_mpi(&params, &world)
1666            .expect("evaluate should succeed");
1667        assert_relative_eq!(mpi_gradient, expected_gradient);
1668
1669        finalize_mpi();
1670    }
1671
1672    #[cfg(feature = "mpi")]
1673    #[mpi_test(np = [2])]
1674    fn mpi_deterministic_fixture_matches_local_and_reduced_baselines_across_activation_toggles() {
1675        use_mpi(true);
1676        let world = get_world().expect("MPI world should be initialized");
1677
1678        let fixture = make_deterministic_nll_fixture(DeterministicModelKind::Partial);
1679        assert_nll_fixture_matches_weighted_baseline(&fixture);
1680        assert_nll_fixture_matches_mpi_reduced_baseline(&fixture, &world);
1681
1682        fixture.nll.isolate_many(&["p", "c"]);
1683        assert_nll_fixture_matches_weighted_baseline(&fixture);
1684        assert_nll_fixture_matches_mpi_reduced_baseline(&fixture, &world);
1685
1686        fixture.nll.activate_all();
1687        assert_nll_fixture_matches_weighted_baseline(&fixture);
1688        assert_nll_fixture_matches_mpi_reduced_baseline(&fixture, &world);
1689
1690        finalize_mpi();
1691    }
1692
1693    #[cfg(feature = "mpi")]
1694    #[mpi_test(np = [2])]
1695    fn mpi_mixed_scale_value_matches_local_evaluate() {
1696        use_mpi(true);
1697        let world = get_world().expect("MPI world should be initialized");
1698        let amp_a = CachedBeamScaleAmplitude::new("amp_a", parameter!("scale_a"), 0).unwrap();
1699        let amp_b = CachedBeamScaleAmplitude::new("amp_b", parameter!("scale_b"), 1).unwrap();
1700        let expr = (amp_a + amp_b).norm_sqr();
1701        let data = dataset_with_two_p4_and_weights(
1702            &[(1.0, 0.5), (10.0, 1.0), (3.0, 5.0), (1.0e2, 2.0e-1)],
1703            &[1.0e12, 1.0e-12, 3.5, 7.25e4],
1704        );
1705        let mc = dataset_with_two_p4_and_weights(
1706            &[(4.0, 0.1), (6.0, 2.0), (8.0, 1.5), (1.0e1, 3.0)],
1707            &[4.0e9, 9.0e-6, 1.25, 2.5e2],
1708        );
1709        let nll = NLL::new(&expr, &data, &mc, None).unwrap();
1710        let params = vec![1.125, -0.375];
1711
1712        let data_local = nll
1713            .data_evaluator
1714            .evaluate_local(&params)
1715            .expect("evaluate should succeed");
1716        let mc_local = nll
1717            .accmc_evaluator
1718            .evaluate_local(&params)
1719            .expect("evaluate should succeed");
1720        let data_term_local: f64 = data_local
1721            .iter()
1722            .zip(nll.data_evaluator.dataset.weights_local().iter())
1723            .map(|(value, event)| *event * value.re.ln())
1724            .sum();
1725        let mc_term_local: f64 = mc_local
1726            .iter()
1727            .zip(nll.accmc_evaluator.dataset.weights_local().iter())
1728            .map(|(value, event)| *event * value.re)
1729            .sum();
1730        let data_term = crate::likelihood::nll::reduce_scalar(&world, data_term_local);
1731        let mc_term = crate::likelihood::nll::reduce_scalar(&world, mc_term_local);
1732        let expected = -2.0 * (data_term - mc_term / nll.n_mc);
1733        let mpi_value = nll
1734            .evaluate_mpi(&params, &world)
1735            .expect("evaluate should succeed");
1736        assert_relative_eq!(mpi_value, expected, epsilon = 1e-9, max_relative = 1e-12);
1737        finalize_mpi();
1738    }
1739
1740    #[cfg(feature = "mpi")]
1741    #[mpi_test(np = [2])]
1742    fn mpi_projection_paths_are_explicit_global_gathers() {
1743        use_mpi(true);
1744        let world = get_world().expect("MPI world should be initialized");
1745        let (nll, params) = make_constant_nll();
1746
1747        let local_projection = nll
1748            .project_weights_local(&params, None)
1749            .expect("local projection should evaluate");
1750        let gathered_projection = nll
1751            .project_weights_mpi(&params, None, &world)
1752            .expect("mpi projection should gather global projection");
1753        let local_len = nll.accmc_evaluator.dataset.n_events_local();
1754        let total_len = nll.accmc_evaluator.dataset.n_events();
1755        assert_eq!(local_projection.len(), local_len);
1756        assert_eq!(gathered_projection.len(), total_len);
1757
1758        let (counts, displs) = world.get_counts_displs(total_len);
1759        let rank = world.rank() as usize;
1760        let start = displs[rank] as usize;
1761        let end = start + counts[rank] as usize;
1762        assert_eq!(
1763            &gathered_projection[start..end],
1764            local_projection.as_slice()
1765        );
1766
1767        let (local_weights, local_gradients) = nll
1768            .project_weights_and_gradients_local(&params, None)
1769            .expect("local projection gradient should evaluate");
1770        let (gathered_weights, gathered_gradients) = nll
1771            .project_weights_and_gradients_mpi(&params, None, &world)
1772            .expect("mpi projection gradient should gather global projection");
1773        assert_eq!(local_weights.len(), local_len);
1774        assert_eq!(local_gradients.len(), local_len);
1775        assert_eq!(gathered_weights.len(), total_len);
1776        assert_eq!(gathered_gradients.len(), total_len);
1777        assert_eq!(&gathered_weights[start..end], local_weights.as_slice());
1778
1779        let local_grad_slice = &gathered_gradients[start..end];
1780        for (lhs, rhs) in local_grad_slice.iter().zip(local_gradients.iter()) {
1781            assert_relative_eq!(lhs, rhs);
1782        }
1783        finalize_mpi();
1784    }
1785
1786    #[cfg(feature = "mpi")]
1787    #[mpi_test(np = [2])]
1788    fn mpi_project_weights_subsets_matches_repeated_project_weights_subset_mpi() {
1789        use_mpi(true);
1790        let world = get_world().expect("MPI world should be initialized");
1791        let (nll, params) = make_two_parameter_nll();
1792        let subsets = vec![
1793            vec!["amp_b".to_string()],
1794            vec!["amp_a".to_string()],
1795            vec!["amp_a".to_string(), "amp_b".to_string()],
1796            vec!["amp_a".to_string()],
1797        ];
1798        let batched = nll
1799            .project_weights_subsets_mpi(&params, &subsets, None, &world)
1800            .expect("batched mpi projection should evaluate");
1801        let repeated = subsets
1802            .iter()
1803            .map(|subset| {
1804                nll.project_weights_subset_mpi(&params, subset, None, &world)
1805                    .expect("single mpi subset projection should evaluate")
1806            })
1807            .collect::<Vec<_>>();
1808        assert_eq!(batched.len(), repeated.len());
1809        for (lhs, rhs) in batched.iter().zip(repeated.iter()) {
1810            assert_eq!(lhs.len(), rhs.len());
1811            for (lhs_value, rhs_value) in lhs.iter().zip(rhs.iter()) {
1812                assert_relative_eq!(lhs_value, rhs_value, epsilon = 1e-12);
1813            }
1814        }
1815        finalize_mpi();
1816    }
1817
1818    #[cfg(feature = "mpi")]
1819    #[mpi_test(np = [2])]
1820    fn mpi_project_weights_and_gradients_subset_matches_repeated_project_weights_and_gradients_subset_mpi(
1821    ) {
1822        use_mpi(true);
1823        let world = get_world().expect("MPI world should be initialized");
1824        let (nll, params) = make_two_parameter_nll();
1825        let subsets = vec![
1826            vec!["amp_b".to_string()],
1827            vec!["amp_a".to_string()],
1828            vec!["amp_a".to_string(), "amp_b".to_string()],
1829        ];
1830        for subset in subsets {
1831            let (weights_mpi, gradients_mpi) = nll
1832                .project_weights_and_gradients_subset_mpi(&params, &subset, None, &world)
1833                .expect("mpi gradient projection should evaluate");
1834            let (weights_auto, gradients_auto) = nll
1835                .project_weights_and_gradients_subset(&params, &subset, None)
1836                .expect("auto gradient projection should evaluate");
1837            assert_eq!(weights_mpi.len(), weights_auto.len());
1838            assert_eq!(gradients_mpi.len(), gradients_auto.len());
1839            for (lhs, rhs) in weights_mpi.iter().zip(weights_auto.iter()) {
1840                assert_relative_eq!(lhs, rhs, epsilon = 1e-12);
1841            }
1842            for (lhs, rhs) in gradients_mpi.iter().zip(gradients_auto.iter()) {
1843                assert_relative_eq!(lhs, rhs, epsilon = 1e-12);
1844            }
1845        }
1846        finalize_mpi();
1847    }
1848
1849    #[cfg(feature = "mpi")]
1850    #[mpi_test(np = [2])]
1851    fn mpi_mixed_workload_rss_stays_bounded() {
1852        use_mpi(true);
1853        let world = get_world().expect("MPI world should be initialized");
1854        let fixture = make_mixed_workload_nll_fixture(2_048);
1855
1856        let baseline_value = fixture
1857            .nll
1858            .evaluate_mpi(&fixture.parameters, &world)
1859            .expect("evaluate should succeed");
1860        let baseline_gradient = fixture
1861            .nll
1862            .evaluate_gradient_mpi(&fixture.parameters, &world)
1863            .expect("evaluate should succeed");
1864        let baseline_weights = fixture
1865            .nll
1866            .project_weights_mpi(&fixture.parameters, None, &world)
1867            .expect("baseline MPI projection should evaluate");
1868        let (baseline_projection_weights, baseline_projection_gradients) = fixture
1869            .nll
1870            .project_weights_and_gradients_mpi(&fixture.parameters, None, &world)
1871            .expect("baseline MPI projection gradient should evaluate");
1872        let mut post_warmup_rss_kb = Vec::new();
1873
1874        assert_relative_eq!(
1875            baseline_weights.as_slice(),
1876            baseline_projection_weights.as_slice(),
1877            epsilon = DETERMINISTIC_STRICT_ABS_TOL,
1878            max_relative = DETERMINISTIC_STRICT_REL_TOL
1879        );
1880
1881        for pass_index in 0..24 {
1882            let value = fixture
1883                .nll
1884                .evaluate_mpi(&fixture.parameters, &world)
1885                .expect("evaluate should succeed");
1886            assert_relative_eq!(
1887                value,
1888                baseline_value,
1889                epsilon = DETERMINISTIC_STRICT_ABS_TOL,
1890                max_relative = DETERMINISTIC_STRICT_REL_TOL
1891            );
1892
1893            let gradient = fixture
1894                .nll
1895                .evaluate_gradient_mpi(&fixture.parameters, &world)
1896                .expect("evaluate should succeed");
1897            assert_eq!(
1898                gradient.len(),
1899                baseline_gradient.len(),
1900                "mixed-workload MPI gradient length should remain stable"
1901            );
1902            for (actual_item, expected_item) in gradient.iter().zip(baseline_gradient.iter()) {
1903                assert_relative_eq!(
1904                    *actual_item,
1905                    *expected_item,
1906                    epsilon = DETERMINISTIC_STRICT_ABS_TOL,
1907                    max_relative = DETERMINISTIC_STRICT_REL_TOL
1908                );
1909            }
1910
1911            let weights = fixture
1912                .nll
1913                .project_weights_mpi(&fixture.parameters, None, &world)
1914                .expect("MPI projection should remain evaluable");
1915            assert_eq!(
1916                weights.len(),
1917                baseline_weights.len(),
1918                "mixed-workload MPI projection length should remain stable"
1919            );
1920            for (actual_item, expected_item) in weights.iter().zip(baseline_weights.iter()) {
1921                assert_relative_eq!(
1922                    *actual_item,
1923                    *expected_item,
1924                    epsilon = DETERMINISTIC_STRICT_ABS_TOL,
1925                    max_relative = DETERMINISTIC_STRICT_REL_TOL
1926                );
1927            }
1928
1929            let (projection_weights, projection_gradients) = fixture
1930                .nll
1931                .project_weights_and_gradients_mpi(&fixture.parameters, None, &world)
1932                .expect("MPI projection gradients should remain evaluable");
1933            assert_eq!(
1934                projection_weights.len(),
1935                baseline_projection_weights.len(),
1936                "mixed-workload MPI projection-gradient weight length should remain stable"
1937            );
1938            assert_eq!(
1939                projection_gradients.len(),
1940                baseline_projection_gradients.len(),
1941                "mixed-workload MPI projection-gradient length should remain stable"
1942            );
1943            for (actual_item, expected_item) in projection_weights
1944                .iter()
1945                .zip(baseline_projection_weights.iter())
1946            {
1947                assert_relative_eq!(
1948                    *actual_item,
1949                    *expected_item,
1950                    epsilon = DETERMINISTIC_STRICT_ABS_TOL,
1951                    max_relative = DETERMINISTIC_STRICT_REL_TOL
1952                );
1953            }
1954            for (actual_gradient, expected_gradient) in projection_gradients
1955                .iter()
1956                .zip(baseline_projection_gradients.iter())
1957            {
1958                assert_eq!(
1959                    actual_gradient.len(),
1960                    expected_gradient.len(),
1961                    "mixed-workload MPI projection-gradient vector length should remain stable"
1962                );
1963                for (actual_item, expected_item) in
1964                    actual_gradient.iter().zip(expected_gradient.iter())
1965                {
1966                    assert_relative_eq!(
1967                        *actual_item,
1968                        *expected_item,
1969                        epsilon = DETERMINISTIC_STRICT_ABS_TOL,
1970                        max_relative = DETERMINISTIC_STRICT_REL_TOL
1971                    );
1972                }
1973            }
1974
1975            if pass_index >= 3 {
1976                if let Some(rss_kb) = read_resident_rss_kb() {
1977                    post_warmup_rss_kb.push(rss_kb);
1978                }
1979            }
1980        }
1981
1982        if let Some((&first_rss_kb, rest_rss_kb)) = post_warmup_rss_kb.split_first() {
1983            let last_rss_kb = *rest_rss_kb.last().unwrap_or(&first_rss_kb);
1984            let min_rss_kb = post_warmup_rss_kb
1985                .iter()
1986                .copied()
1987                .min()
1988                .expect("post-warmup RSS sample should exist");
1989            let max_rss_kb = post_warmup_rss_kb
1990                .iter()
1991                .copied()
1992                .max()
1993                .expect("post-warmup RSS sample should exist");
1994            const MAX_POST_WARMUP_RSS_GROWTH_KB: u64 = 64 * 1024;
1995            const MAX_POST_WARMUP_RSS_SPREAD_KB: u64 = 64 * 1024;
1996            assert!(
1997                last_rss_kb.saturating_sub(first_rss_kb) <= MAX_POST_WARMUP_RSS_GROWTH_KB,
1998                "mixed-workload post-warmup RSS grew by {} KiB (first={} KiB, last={} KiB)",
1999                last_rss_kb.saturating_sub(first_rss_kb),
2000                first_rss_kb,
2001                last_rss_kb
2002            );
2003            assert!(
2004                max_rss_kb.saturating_sub(min_rss_kb) <= MAX_POST_WARMUP_RSS_SPREAD_KB,
2005                "mixed-workload post-warmup RSS spread was {} KiB (min={} KiB, max={} KiB)",
2006                max_rss_kb.saturating_sub(min_rss_kb),
2007                min_rss_kb,
2008                max_rss_kb
2009            );
2010        }
2011
2012        finalize_mpi();
2013    }
2014}