1use std::{
2 collections::HashMap,
3 fmt::{Debug, Display},
4};
5
6use auto_ops::*;
7use laddu_core::{amplitude::ParameterMap, LadduError, LadduResult};
8use nalgebra::DVector;
9
10use super::term::LikelihoodTerm;
11
/// Precomputed scalar value of every registered likelihood term.
///
/// The inner vector is indexed by the `usize` stored in `LikelihoodNode::Term`.
#[derive(Debug)]
struct LikelihoodValues(Vec<f64>);
14
/// Precomputed gradient vector of every registered likelihood term.
///
/// The inner vector is indexed by the `usize` stored in `LikelihoodNode::Term`.
#[derive(Debug)]
struct LikelihoodGradients(Vec<DVector<f64>>);
17
/// Node of the arithmetic tree that combines likelihood terms.
#[derive(Clone, Default)]
enum LikelihoodNode {
    /// The constant `0.0` (the default node).
    #[default]
    Zero,
    /// The constant `1.0`.
    One,
    /// A reference to a registered term, by index.
    Term(usize),
    /// Sum of two subtrees.
    Add(Box<LikelihoodNode>, Box<LikelihoodNode>),
    /// Product of two subtrees.
    Mul(Box<LikelihoodNode>, Box<LikelihoodNode>),
}
27
28impl LikelihoodNode {
29 fn remap(&self, mapping: &[usize]) -> Self {
30 match self {
31 Self::Term(idx) => Self::Term(mapping[*idx]),
32 Self::Add(a, b) => Self::Add(Box::new(a.remap(mapping)), Box::new(b.remap(mapping))),
33 Self::Mul(a, b) => Self::Mul(Box::new(a.remap(mapping)), Box::new(b.remap(mapping))),
34 Self::Zero => Self::Zero,
35 Self::One => Self::One,
36 }
37 }
38
39 fn evaluate(&self, likelihood_values: &LikelihoodValues) -> f64 {
40 match self {
41 LikelihoodNode::Zero => 0.0,
42 LikelihoodNode::One => 1.0,
43 LikelihoodNode::Term(idx) => likelihood_values.0[*idx],
44 LikelihoodNode::Add(a, b) => {
45 a.evaluate(likelihood_values) + b.evaluate(likelihood_values)
46 }
47 LikelihoodNode::Mul(a, b) => {
48 a.evaluate(likelihood_values) * b.evaluate(likelihood_values)
49 }
50 }
51 }
52
53 fn evaluate_gradient(
54 &self,
55 likelihood_values: &LikelihoodValues,
56 likelihood_gradients: &LikelihoodGradients,
57 ) -> DVector<f64> {
58 match self {
59 LikelihoodNode::Zero => DVector::zeros(0),
60 LikelihoodNode::One => DVector::zeros(0),
61 LikelihoodNode::Term(idx) => likelihood_gradients.0[*idx].clone(),
62 LikelihoodNode::Add(a, b) => {
63 a.evaluate_gradient(likelihood_values, likelihood_gradients)
64 + b.evaluate_gradient(likelihood_values, likelihood_gradients)
65 }
66 LikelihoodNode::Mul(a, b) => {
67 a.evaluate_gradient(likelihood_values, likelihood_gradients)
68 * b.evaluate(likelihood_values)
69 + b.evaluate_gradient(likelihood_values, likelihood_gradients)
70 * a.evaluate(likelihood_values)
71 }
72 }
73 }
74
75 fn write_tree(
76 &self,
77 f: &mut std::fmt::Formatter<'_>,
78 parent_prefix: &str,
79 immediate_prefix: &str,
80 parent_suffix: &str,
81 ) -> std::fmt::Result {
82 let display_string = match self {
83 Self::Zero => "0".to_string(),
84 Self::One => "1".to_string(),
85 Self::Term(idx) => format!("term({idx})"),
86 Self::Add(_, _) => "+".to_string(),
87 Self::Mul(_, _) => "*".to_string(),
88 };
89 writeln!(f, "{}{}{}", parent_prefix, immediate_prefix, display_string)?;
90 match self {
91 Self::Term(_) | Self::Zero | Self::One => {}
92 Self::Add(a, b) | Self::Mul(a, b) => {
93 let terms = [a, b];
94 let mut it = terms.iter().peekable();
95 let child_prefix = format!("{}{}", parent_prefix, parent_suffix);
96 while let Some(child) = it.next() {
97 match it.peek() {
98 Some(_) => child.write_tree(f, &child_prefix, "├─ ", "│ ")?,
99 None => child.write_tree(f, &child_prefix, "└─ ", " ")?,
100 }
101 }
102 }
103 }
104 Ok(())
105 }
106}
107
/// An arithmetic expression over likelihood terms.
///
/// `registry` owns the terms; `tree` describes how their values are combined and
/// refers to them by index. The default value has an empty registry and a `Zero` tree.
#[derive(Clone, Default)]
pub struct LikelihoodExpression {
    registry: LikelihoodRegistry,
    tree: LikelihoodNode,
}
118
119impl LikelihoodExpression {
120 pub fn from_term(term: Box<dyn LikelihoodTerm>) -> LadduResult<Self> {
122 let registry = LikelihoodRegistry::singleton(term)?;
123 Ok(Self {
124 registry,
125 tree: LikelihoodNode::Term(0),
126 })
127 }
128
129 pub fn zero() -> Self {
131 Self {
132 registry: LikelihoodRegistry::default(),
133 tree: LikelihoodNode::Zero,
134 }
135 }
136
137 pub fn one() -> Self {
139 Self {
140 registry: LikelihoodRegistry::default(),
141 tree: LikelihoodNode::One,
142 }
143 }
144
145 fn binary_op(
146 a: &LikelihoodExpression,
147 b: &LikelihoodExpression,
148 build: impl Fn(Box<LikelihoodNode>, Box<LikelihoodNode>) -> LikelihoodNode,
149 ) -> LikelihoodExpression {
150 let (registry, left_map, right_map) = a.registry.merge(&b.registry);
151 let left_tree = a.tree.remap(&left_map);
152 let right_tree = b.tree.remap(&right_map);
153 LikelihoodExpression {
154 registry,
155 tree: build(Box::new(left_tree), Box::new(right_tree)),
156 }
157 }
158
159 fn write_tree(
160 &self,
161 f: &mut std::fmt::Formatter<'_>,
162 parent_prefix: &str,
163 immediate_prefix: &str,
164 parent_suffix: &str,
165 ) -> std::fmt::Result {
166 self.tree
167 .write_tree(f, parent_prefix, immediate_prefix, parent_suffix)
168 }
169
170 pub fn parameters(&self) -> ParameterMap {
172 self.registry.global_parameter_map().clone()
173 }
174
175 pub fn n_free(&self) -> usize {
177 self.registry.global_parameter_map().free().len()
178 }
179
180 pub fn n_fixed(&self) -> usize {
182 self.registry.global_parameter_map().fixed().len()
183 }
184
185 pub fn n_parameters(&self) -> usize {
187 self.registry.global_parameter_map().len()
188 }
189
190 pub fn evaluate(&self, parameters: &[f64]) -> LadduResult<f64> {
192 let layout = self.registry.global_layout()?;
193 layout.global_map.assemble(parameters)?; let likelihood_values = LikelihoodValues(
195 self.registry
196 .terms
197 .iter()
198 .zip(layout.layouts.iter())
199 .map(|(term, term_layout)| {
200 term.evaluate(
201 &term_layout
202 .iter()
203 .map(|&global_idx| parameters[global_idx])
204 .collect::<Vec<_>>(),
205 )
206 })
207 .collect::<LadduResult<Vec<_>>>()?,
208 );
209 Ok(self.tree.evaluate(&likelihood_values))
210 }
211
212 pub fn evaluate_gradient(&self, parameters: &[f64]) -> LadduResult<DVector<f64>> {
214 let free_parameter_count = parameters.len();
215 let layout = self.registry.global_layout()?;
216 layout.global_map.assemble(parameters)?; let parameter_sets = layout
218 .layouts
219 .iter()
220 .map(|term_layout| {
221 term_layout
222 .iter()
223 .map(|&global_idx| parameters[global_idx])
224 .collect::<Vec<_>>()
225 })
226 .collect::<Vec<_>>();
227 let likelihood_values = LikelihoodValues(
228 self.registry
229 .terms
230 .iter()
231 .zip(parameter_sets.iter())
232 .map(|(term, term_parameters)| term.evaluate(term_parameters))
233 .collect::<LadduResult<Vec<_>>>()?,
234 );
235 let mut gradient_buffers: Vec<DVector<f64>> = (0..self.registry.terms.len())
236 .map(|_| DVector::zeros(parameters.len()))
237 .collect();
238 for (((term, term_parameters), gradient_buffer), layout) in self
239 .registry
240 .terms
241 .iter()
242 .zip(parameter_sets.iter())
243 .zip(gradient_buffers.iter_mut())
244 .zip(layout.layouts.iter())
245 {
246 let term_gradient = term.evaluate_gradient(term_parameters)?; for (term_idx, &buffer_idx) in layout.iter().enumerate() {
248 gradient_buffer[buffer_idx] = term_gradient[term_idx] }
250 }
251 let likelihood_gradients = LikelihoodGradients(gradient_buffers);
252 let full_gradient = self
253 .tree
254 .evaluate_gradient(&likelihood_values, &likelihood_gradients);
255 let mut reduced = DVector::zeros(free_parameter_count);
256 for (out_idx, &global_idx) in layout
257 .global_map
258 .free_parameter_indices()
259 .iter()
260 .enumerate()
261 {
262 reduced[out_idx] = full_gradient[global_idx];
263 }
264 Ok(reduced)
265 }
266}
267
268impl LikelihoodTerm for LikelihoodExpression {
269 fn evaluate(&self, parameters: &[f64]) -> LadduResult<f64> {
270 LikelihoodExpression::evaluate(self, parameters)
271 }
272 fn evaluate_gradient(&self, parameters: &[f64]) -> LadduResult<DVector<f64>> {
273 LikelihoodExpression::evaluate_gradient(self, parameters)
274 }
275 fn update(&self) {
276 self.registry.terms.iter().for_each(|term| term.update())
277 }
278 fn fix_parameter(&self, name: &str, value: f64) -> LadduResult<()> {
279 self.registry.fix_parameter(name, value)
280 }
281
282 fn free_parameter(&self, name: &str) -> LadduResult<()> {
283 self.registry.free_parameter(name)
284 }
285
286 fn rename_parameter(&self, old: &str, new: &str) -> LadduResult<()> {
287 self.registry.rename_parameter(old, new)
288 }
289
290 fn rename_parameters(&self, mapping: &HashMap<String, String>) -> LadduResult<()> {
291 self.registry.rename_parameters(mapping)
292 }
293
294 fn parameter_map(&self) -> ParameterMap {
295 self.registry.global_parameter_map().clone()
296 }
297}
298
299impl Debug for LikelihoodExpression {
300 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
301 self.write_tree(f, "", "", "")
302 }
303}
304
305impl Display for LikelihoodExpression {
306 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
307 self.write_tree(f, "", "", "")
308 }
309}
310
// `+` on likelihood expressions: merges both registries and joins the trees with an
// `Add` node. NOTE(review): `impl_op_ex!` (auto_ops) presumably also generates the
// owned/borrowed operand combinations — confirm against the auto_ops documentation.
impl_op_ex!(+ |a: &LikelihoodExpression, b: &LikelihoodExpression| -> LikelihoodExpression {
    LikelihoodExpression::binary_op(a, b, LikelihoodNode::Add)
});
// `*` on likelihood expressions: same as above, joined with a `Mul` node.
impl_op_ex!(
    *|a: &LikelihoodExpression, b: &LikelihoodExpression| -> LikelihoodExpression {
        LikelihoodExpression::binary_op(a, b, LikelihoodNode::Mul)
    }
);
319
/// Result of `LikelihoodRegistry::global_layout`.
struct GlobalParameterLayout {
    /// Parameter map merged over all registered terms.
    global_map: ParameterMap,
    /// For each term (in registry order), the position of each of its free parameters
    /// in the merged map's free-parameter name list.
    layouts: Vec<Vec<usize>>,
}
324
/// Owns the likelihood terms referenced (by index) from a `LikelihoodNode` tree.
#[derive(Clone, Default)]
struct LikelihoodRegistry {
    /// Registered terms; a term's position is its index in `LikelihoodNode::Term`.
    terms: Vec<Box<dyn LikelihoodTerm>>,
}
329
330impl LikelihoodRegistry {
331 fn singleton(term: Box<dyn LikelihoodTerm>) -> LadduResult<Self> {
332 let mut registry = Self::default();
333 registry.push_term(term);
334 Ok(registry)
335 }
336
337 fn push_term(&mut self, term: Box<dyn LikelihoodTerm>) -> usize {
338 let term_idx = self.terms.len();
339 self.terms.push(term);
340 term_idx
341 }
342
343 fn merge(&self, other: &Self) -> (Self, Vec<usize>, Vec<usize>) {
344 let mut registry = Self::default();
345 let mut left_map = Vec::with_capacity(self.terms.len());
346 for term in &self.terms {
347 let idx = registry.push_term(dyn_clone::clone_box(&**term));
348 left_map.push(idx);
349 }
350 let mut right_map = Vec::with_capacity(other.terms.len());
351 for term in &other.terms {
352 let idx = registry.push_term(dyn_clone::clone_box(&**term));
353 right_map.push(idx);
354 }
355 (registry, left_map, right_map)
356 }
357
358 fn global_parameter_map(&self) -> ParameterMap {
359 let mut global = ParameterMap::default();
360 for term in &self.terms {
361 (global, _, _) = global.merge(&term.parameter_map());
362 }
363 global
364 }
365
366 fn global_layout(&self) -> LadduResult<GlobalParameterLayout> {
367 let global_map = self.global_parameter_map();
368 let global_free_index: HashMap<String, usize> = global_map
369 .free()
370 .names()
371 .into_iter()
372 .enumerate()
373 .map(|(idx, name)| (name, idx))
374 .collect();
375
376 let layouts = self
377 .terms
378 .iter()
379 .map(|term| {
380 term.parameter_map()
381 .free()
382 .names()
383 .into_iter()
384 .map(|name| {
385 global_free_index.get(&name).copied().ok_or_else(|| {
386 LadduError::UnregisteredParameter {
387 name,
388 reason: "free parameter missing in global parameter map"
389 .to_string(),
390 }
391 })
392 })
393 .collect()
394 })
395 .collect::<LadduResult<Vec<_>>>()?;
396
397 Ok(GlobalParameterLayout {
398 global_map,
399 layouts,
400 })
401 }
402
403 fn fix_parameter(&self, name: &str, value: f64) -> LadduResult<()> {
404 for term in &self.terms {
405 if term.parameter_map().contains_key(name) {
406 term.parameter_map().fix_parameter(name, value)?;
407 }
408 }
409 Ok(())
410 }
411
412 fn free_parameter(&self, name: &str) -> LadduResult<()> {
413 for term in &self.terms {
414 if term.parameter_map().contains_key(name) {
415 term.parameter_map().free_parameter(name)?;
416 }
417 }
418 Ok(())
419 }
420
421 fn rename_parameter(&self, old: &str, new: &str) -> LadduResult<()> {
422 for term in &self.terms {
423 if term.parameter_map().contains_key(new) {
424 return Err(LadduError::ParameterConflict {
425 name: new.to_string(),
426 reason: "rename target already exists".to_string(),
427 });
428 }
429 }
430 for term in &self.terms {
431 if term.parameter_map().contains_key(old) {
432 term.rename_parameter(old, new)?;
433 }
434 }
435 Ok(())
436 }
437
438 fn rename_parameters(&self, mapping: &HashMap<String, String>) -> LadduResult<()> {
439 for (old, new) in mapping {
440 self.rename_parameter(old, new)?;
441 }
442 Ok(())
443 }
444}
445
446#[cfg(test)]
447mod tests {
448 use std::sync::Arc;
449
450 use approx::assert_relative_eq;
451 #[cfg(feature = "mpi")]
452 use laddu_core::mpi::{finalize_mpi, get_world, use_mpi, LadduMPI};
453 use laddu_core::{
454 amplitude::{Amplitude, AmplitudeID, ExpressionDependence, Parameter},
455 data::{Dataset, DatasetMetadata, EventData},
456 parameter,
457 resources::{Cache, ParameterID, Parameters, Resources, ScalarID},
458 vectors::Vec4,
459 Expression, LadduError, LadduResult,
460 };
461 #[cfg(feature = "mpi")]
462 use mpi::topology::{Communicator, SimpleCommunicator};
463 #[cfg(feature = "mpi")]
464 use mpi_test::mpi_test;
465 use nalgebra::DVector;
466 use num::complex::Complex64;
467 use serde::{Deserialize, Serialize};
468
469 use crate::likelihood::{LikelihoodScalar, LikelihoodTerm, NLL};
470
    /// Substring the tests expect in the display text of `LadduError::LengthMismatch`.
    const LENGTH_MISMATCH_MESSAGE_FRAGMENT: &str = "length mismatch";
    /// Substring the tests expect in the display text of `LadduError::AmplitudeNotFoundError`.
    const AMPLITUDE_NOT_FOUND_MESSAGE_FRAGMENT: &str = "No registered amplitude";
473
    /// Test amplitude whose value is its single parameter (as a real number).
    #[derive(Clone, Serialize, Deserialize)]
    struct ConstantAmplitude {
        /// Name the amplitude is registered under.
        name: String,
        /// Parameter definition passed to `register_parameter`.
        parameter: Parameter,
        /// Parameter handle assigned in `register`; default-initialised before that.
        pid: ParameterID,
    }
480
481 impl ConstantAmplitude {
482 #[allow(clippy::new_ret_no_self)]
483 fn new(name: &str, parameter: Parameter) -> LadduResult<Expression> {
484 Self {
485 name: name.to_string(),
486 parameter,
487 pid: ParameterID::default(),
488 }
489 .into_expression()
490 }
491 }
492
493 #[typetag::serde]
494 impl Amplitude for ConstantAmplitude {
495 fn register(&mut self, resources: &mut Resources) -> LadduResult<AmplitudeID> {
496 self.pid = resources.register_parameter(&self.parameter)?;
497 resources.register_amplitude(&self.name)
498 }
499
500 fn dependence_hint(&self) -> ExpressionDependence {
501 ExpressionDependence::ParameterOnly
502 }
503
504 fn compute(&self, parameters: &Parameters, _cache: &Cache) -> Complex64 {
505 Complex64::new(parameters.get(self.pid), 0.0)
506 }
507
508 fn compute_gradient(
509 &self,
510 parameters: &Parameters,
511 _cache: &Cache,
512 gradient: &mut DVector<Complex64>,
513 ) {
514 if let Some(index) = parameters.free_index(self.pid) {
515 gradient[index] = Complex64::ONE;
516 }
517 }
518 }
519
    /// Test amplitude whose value is its parameter multiplied by a cached per-event scalar
    /// (the `e()` component of the four-momentum at `p4_index`).
    #[derive(Clone, Serialize, Deserialize)]
    struct CachedBeamScaleAmplitude {
        /// Name the amplitude is registered under.
        name: String,
        /// Parameter definition passed to `register_parameter`.
        parameter: Parameter,
        /// Parameter handle assigned in `register`.
        pid: ParameterID,
        /// Cache slot handle assigned in `register`.
        sid: ScalarID,
        /// Index of the four-momentum read in `precompute`.
        p4_index: usize,
    }
528
529 impl CachedBeamScaleAmplitude {
530 #[allow(clippy::new_ret_no_self)]
531 fn new(name: &str, parameter: Parameter, p4_index: usize) -> LadduResult<Expression> {
532 Self {
533 name: name.to_string(),
534 parameter,
535 pid: ParameterID::default(),
536 sid: ScalarID::default(),
537 p4_index,
538 }
539 .into_expression()
540 }
541 }
542
543 #[typetag::serde]
544 impl Amplitude for CachedBeamScaleAmplitude {
545 fn register(&mut self, resources: &mut Resources) -> LadduResult<AmplitudeID> {
546 self.pid = resources.register_parameter(&self.parameter)?;
547 self.sid = resources.register_scalar(Some(&format!("{}.beam_energy", self.name)));
548 resources.register_amplitude(&self.name)
549 }
550
551 fn dependence_hint(&self) -> ExpressionDependence {
552 ExpressionDependence::Mixed
553 }
554
555 fn precompute(&self, event: &laddu_core::data::Event<'_>, cache: &mut Cache) {
556 cache.store_scalar(self.sid, event.p4_at(self.p4_index).e());
557 }
558
559 fn compute(&self, parameters: &Parameters, cache: &Cache) -> Complex64 {
560 Complex64::new(parameters.get(self.pid), 0.0) * cache.get_scalar(self.sid)
561 }
562
563 fn compute_gradient(
564 &self,
565 parameters: &Parameters,
566 cache: &Cache,
567 gradient: &mut DVector<Complex64>,
568 ) {
569 if let Some(index) = parameters.free_index(self.pid) {
570 gradient[index] = Complex64::new(cache.get_scalar(self.sid), 0.0);
571 }
572 }
573 }
574
    /// Test amplitude without parameters whose value is a cached per-event scalar
    /// (the `e()` component of the four-momentum at `p4_index`).
    #[derive(Clone, Serialize, Deserialize)]
    struct CacheOnlyBeamAmplitude {
        /// Name the amplitude is registered under.
        name: String,
        /// Cache slot handle assigned in `register`.
        sid: ScalarID,
        /// Index of the four-momentum read in `precompute`.
        p4_index: usize,
    }
581
582 impl CacheOnlyBeamAmplitude {
583 #[allow(clippy::new_ret_no_self)]
584 fn new(name: &str, p4_index: usize) -> LadduResult<Expression> {
585 Self {
586 name: name.to_string(),
587 sid: ScalarID::default(),
588 p4_index,
589 }
590 .into_expression()
591 }
592 }
593
594 #[typetag::serde]
595 impl Amplitude for CacheOnlyBeamAmplitude {
596 fn register(&mut self, resources: &mut Resources) -> LadduResult<AmplitudeID> {
597 self.sid = resources.register_scalar(Some(&format!("{}.beam_energy", self.name)));
598 resources.register_amplitude(&self.name)
599 }
600
601 fn dependence_hint(&self) -> ExpressionDependence {
602 ExpressionDependence::CacheOnly
603 }
604
605 fn precompute(&self, event: &laddu_core::data::Event<'_>, cache: &mut Cache) {
606 cache.store_scalar(self.sid, event.p4_at(self.p4_index).e());
607 }
608
609 fn compute(&self, _parameters: &Parameters, cache: &Cache) -> Complex64 {
610 Complex64::new(cache.get_scalar(self.sid), 0.0)
611 }
612 }
613
614 fn dataset_with_weights(weights: &[f64]) -> Arc<Dataset> {
615 let metadata = Arc::new(DatasetMetadata::default());
616 let events = weights
617 .iter()
618 .map(|&weight| {
619 Arc::new(EventData {
620 p4s: vec![Vec4::new(0.0, 0.0, 0.0, 1.0)],
621 aux: vec![],
622 weight,
623 })
624 })
625 .collect();
626 Arc::new(Dataset::new_with_metadata(events, metadata))
627 }
628
629 fn dataset_with_two_p4_and_weights(
630 beam_energies: &[(f64, f64)],
631 weights: &[f64],
632 ) -> Arc<Dataset> {
633 assert_eq!(beam_energies.len(), weights.len());
634 let metadata = Arc::new(DatasetMetadata::default());
635 let events = beam_energies
636 .iter()
637 .zip(weights.iter())
638 .map(|(&(e0, e1), &weight)| {
639 Arc::new(EventData {
640 p4s: vec![Vec4::new(0.0, 0.0, 0.0, e0), Vec4::new(0.0, 0.0, 0.0, e1)],
641 aux: vec![],
642 weight,
643 })
644 })
645 .collect();
646 Arc::new(Dataset::new_with_metadata(events, metadata))
647 }
648
649 #[cfg(feature = "mpi")]
650 fn read_resident_rss_kb() -> Option<u64> {
651 #[cfg(target_os = "linux")]
652 {
653 let status = std::fs::read_to_string("/proc/self/status").ok()?;
654 let vm_rss = status
655 .lines()
656 .find(|line| line.starts_with("VmRSS:"))?
657 .split_whitespace()
658 .nth(1)?;
659 vm_rss.parse::<u64>().ok()
660 }
661
662 #[cfg(not(target_os = "linux"))]
663 {
664 None
665 }
666 }
667
668 #[cfg(feature = "mpi")]
669 fn generated_two_p4_dataset(
670 n_events: usize,
671 base_energy: f64,
672 weight_scale: f64,
673 ) -> Arc<Dataset> {
674 let metadata = Arc::new(DatasetMetadata::default());
675 let events = (0..n_events)
676 .map(|index| {
677 let idx = index as f64;
678 let beam_e0 = base_energy + (idx % 17.0) * 0.35 + idx * 0.0025;
679 let beam_e1 = 0.5 * base_energy + (idx % 11.0) * 0.2 + idx * 0.0015;
680 let weight = 0.75 + weight_scale * (1.0 + (index % 9) as f64);
681 Arc::new(EventData {
682 p4s: vec![
683 Vec4::new(0.0, 0.0, 0.0, beam_e0),
684 Vec4::new(0.0, 0.0, 0.0, beam_e1),
685 ],
686 aux: vec![],
687 weight,
688 })
689 })
690 .collect();
691 Arc::new(Dataset::new_with_metadata(events, metadata))
692 }
693
694 fn make_constant_nll() -> (Box<NLL>, Vec<f64>) {
695 let amp = ConstantAmplitude::new("amp", parameter!("scale")).unwrap();
696 let expr = amp.norm_sqr();
697 let data = dataset_with_weights(&[1.0, 2.0]);
698 let mc = dataset_with_weights(&[0.5, 1.5]);
699 let nll = NLL::new(&expr, &data, &mc, None).unwrap();
700 (nll, vec![2.0])
701 }
702
703 fn make_two_parameter_nll() -> (Box<NLL>, Vec<f64>) {
704 let amp_a = ConstantAmplitude::new("amp_a", parameter!("alpha")).unwrap();
705 let amp_b = ConstantAmplitude::new("amp_b", parameter!("beta")).unwrap();
706 let expr = (amp_a + amp_b).norm_sqr();
707 let data = dataset_with_weights(&[1.0, 2.0, 3.0, 1.0]);
708 let mc = dataset_with_weights(&[0.5, 1.5, 2.5, 0.5]);
709 let nll = NLL::new(&expr, &data, &mc, None).unwrap();
710 (nll, vec![0.75, -1.25])
711 }
712
713 #[test]
714 fn nll_handles_reused_amplitudes_in_coherent_expression() {
715 let amp_a = ConstantAmplitude::new("amp_a", parameter!("alpha")).unwrap();
716 let amp_b = ConstantAmplitude::new("amp_b", parameter!("beta")).unwrap();
717
718 let coherent_plus = amp_a.clone() + amp_b.clone();
719 let coherent_minus = amp_a - amp_b;
720 let expr = coherent_plus.norm_sqr() + coherent_minus.norm_sqr();
721
722 let data = dataset_with_weights(&[1.0, 2.0, 3.0]);
723 let mc = dataset_with_weights(&[0.5, 1.5, 2.5]);
724 let params = vec![0.75, -1.25];
725
726 let evaluator = expr.load(&data).unwrap();
727 let direct_values = evaluator.evaluate(¶ms).unwrap();
728 assert_eq!(direct_values.len(), 3);
729
730 let nll = NLL::new(&expr, &data, &mc, None).unwrap();
731 let value = nll.evaluate(¶ms).unwrap();
732 assert!(value.is_finite());
733
734 let gradient = nll.evaluate_gradient(¶ms).unwrap();
735 assert_eq!(gradient.len(), params.len());
736 assert!(gradient.iter().all(|value| value.is_finite()));
737
738 let projection = nll.project_weights(¶ms, None).unwrap();
739 assert_eq!(projection.len(), mc.n_events());
740 assert!(projection.iter().all(|value| value.is_finite()));
741
742 let (_, projection_gradient) = nll.project_weights_and_gradients(¶ms, None).unwrap();
743 assert_eq!(projection_gradient.len(), mc.n_events());
744 assert!(projection_gradient
745 .iter()
746 .all(|gradient| gradient.iter().all(|value| value.is_finite())));
747 }
748
749 #[test]
750 fn nll_exposes_expression_and_current_compiled_expression() {
751 let (nll, _) = make_two_parameter_nll();
752
753 let expression_display = nll.expression().compiled_expression().to_string();
754 assert!(expression_display.contains("amp_a(id=0)"));
755 assert!(expression_display.contains("amp_b(id=1)"));
756
757 nll.deactivate("amp_b");
758 let compiled = nll.compiled_expression().to_string();
759 assert!(compiled.contains("amp_a(id=0)"));
760 assert!(!compiled.contains("amp_b(id=1)"));
761 assert!(!compiled.contains("const 0"));
762 assert!(!compiled.contains("+"));
763 }
764
765 #[test]
766 fn stochastic_nll_exposes_expression_and_current_compiled_expression() {
767 let (nll, _) = make_two_parameter_nll();
768 let stochastic = nll
769 .to_stochastic(2, Some(0))
770 .expect("stochastic NLL should build");
771
772 assert!(stochastic
773 .expression()
774 .compiled_expression()
775 .to_string()
776 .contains("amp_a(id=0)"));
777 assert!(stochastic
778 .compiled_expression()
779 .to_string()
780 .contains("amp_b(id=1)"));
781 }
782
    /// Selects which model `make_deterministic_nll_fixture` builds.
    #[derive(Clone, Copy)]
    enum DeterministicModelKind {
        /// Sum of two (parameter-only amplitude x cache-only amplitude) products.
        Separable,
        /// One such product plus an amplitude that depends on both parameter and cache.
        Partial,
        /// Product of two amplitudes that each depend on both parameter and cache.
        NonSeparable,
    }
789
    /// An NLL together with the parameter vector it is evaluated at in the tests.
    struct DeterministicNllFixture {
        nll: Box<NLL>,
        parameters: Vec<f64>,
    }
794
    /// Absolute tolerance (`epsilon`) used by the fixture comparisons.
    const DETERMINISTIC_STRICT_ABS_TOL: f64 = 1e-12;
    /// Relative tolerance (`max_relative`) used by the fixture comparisons.
    const DETERMINISTIC_STRICT_REL_TOL: f64 = 1e-10;
797
798 fn assert_nll_fixture_matches_weighted_baseline(fixture: &DeterministicNllFixture) {
799 let expected_value = crate::likelihood::nll::evaluate_weighted_expression_sum_local(
800 &fixture.nll.data_evaluator,
801 &fixture.parameters,
802 |l| f64::ln(l.re),
803 )
804 .expect("evaluate should succeed");
805 let expected_mc_term = fixture
806 .nll
807 .accmc_evaluator
808 .evaluate_weighted_value_sum_local(&fixture.parameters)
809 .expect("evaluate should succeed");
810 let expected_value = -2.0 * (expected_value - expected_mc_term / fixture.nll.n_mc);
811
812 let expected_data_gradient = fixture
813 .nll
814 .evaluate_data_gradient_term_local(&fixture.parameters)
815 .expect("evaluate should succeed");
816 let expected_mc_gradient = fixture
817 .nll
818 .accmc_evaluator
819 .evaluate_weighted_gradient_sum_local(&fixture.parameters)
820 .expect("evaluate should succeed");
821 let expected_gradient =
822 -2.0 * (expected_data_gradient - expected_mc_gradient / fixture.nll.n_mc);
823
824 let actual_value = fixture
825 .nll
826 .evaluate_local(&fixture.parameters)
827 .expect("evaluate should succeed");
828 assert_relative_eq!(
829 actual_value,
830 expected_value,
831 epsilon = DETERMINISTIC_STRICT_ABS_TOL,
832 max_relative = DETERMINISTIC_STRICT_REL_TOL
833 );
834
835 let actual_gradient = fixture
836 .nll
837 .evaluate_gradient_local(&fixture.parameters)
838 .expect("evaluate should succeed");
839 assert_eq!(
840 actual_gradient.len(),
841 expected_gradient.len(),
842 "fixture NLL gradient length mismatch (actual={}, expected={})",
843 actual_gradient.len(),
844 expected_gradient.len()
845 );
846 for (actual_item, expected_item) in actual_gradient.iter().zip(expected_gradient.iter()) {
847 assert_relative_eq!(
848 *actual_item,
849 *expected_item,
850 epsilon = DETERMINISTIC_STRICT_ABS_TOL,
851 max_relative = DETERMINISTIC_STRICT_REL_TOL
852 );
853 }
854 }
855
856 #[cfg(feature = "mpi")]
857 fn assert_nll_fixture_matches_mpi_reduced_baseline(
858 fixture: &DeterministicNllFixture,
859 world: &SimpleCommunicator,
860 ) {
861 let data_term_local = crate::likelihood::nll::evaluate_weighted_expression_sum_local(
862 &fixture.nll.data_evaluator,
863 &fixture.parameters,
864 |l| f64::ln(l.re),
865 )
866 .expect("evaluate should succeed");
867 let mc_term_local = fixture
868 .nll
869 .accmc_evaluator
870 .evaluate_weighted_value_sum_local(&fixture.parameters)
871 .expect("evaluate should succeed");
872 let data_term = crate::likelihood::nll::reduce_scalar(world, data_term_local);
873 let mc_term = crate::likelihood::nll::reduce_scalar(world, mc_term_local);
874 let expected_value = -2.0 * (data_term - mc_term / fixture.nll.n_mc);
875 let mpi_value = fixture
876 .nll
877 .evaluate_mpi(&fixture.parameters, world)
878 .expect("evaluate should succeed");
879 assert_relative_eq!(
880 mpi_value,
881 expected_value,
882 epsilon = DETERMINISTIC_STRICT_ABS_TOL,
883 max_relative = DETERMINISTIC_STRICT_REL_TOL
884 );
885
886 let data_gradient_local = fixture
887 .nll
888 .evaluate_data_gradient_term_local(&fixture.parameters)
889 .expect("evaluate should succeed");
890 let mc_gradient_local = fixture
891 .nll
892 .accmc_evaluator
893 .evaluate_weighted_gradient_sum_local(&fixture.parameters)
894 .expect("evaluate should succeed");
895 let data_gradient = crate::likelihood::nll::reduce_gradient(world, &data_gradient_local);
896 let mc_gradient = crate::likelihood::nll::reduce_gradient(world, &mc_gradient_local);
897 let expected_gradient = -2.0 * (data_gradient - mc_gradient / fixture.nll.n_mc);
898 let mpi_gradient = fixture
899 .nll
900 .evaluate_gradient_mpi(&fixture.parameters, world)
901 .expect("evaluate should succeed");
902 assert_eq!(
903 mpi_gradient.len(),
904 expected_gradient.len(),
905 "fixture MPI gradient length mismatch (actual={}, expected={})",
906 mpi_gradient.len(),
907 expected_gradient.len()
908 );
909 for (actual_item, expected_item) in mpi_gradient.iter().zip(expected_gradient.iter()) {
910 assert_relative_eq!(
911 *actual_item,
912 *expected_item,
913 epsilon = DETERMINISTIC_STRICT_ABS_TOL,
914 max_relative = DETERMINISTIC_STRICT_REL_TOL
915 );
916 }
917 }
918
919 fn make_deterministic_nll_fixture(kind: DeterministicModelKind) -> DeterministicNllFixture {
920 let data = dataset_with_two_p4_and_weights(
921 &[
922 (1.0, 0.8),
923 (2.5, 1.7),
924 (4.0, 2.4),
925 (3.3, 1.1),
926 (5.2, 2.8),
927 (1.7, 0.9),
928 ],
929 &[0.7, 1.2, 0.9, 1.5, 0.8, 1.1],
930 );
931 let mc = dataset_with_two_p4_and_weights(
932 &[
933 (1.5, 1.0),
934 (3.0, 2.1),
935 (5.5, 2.9),
936 (2.0, 1.2),
937 (4.2, 1.8),
938 (2.8, 1.4),
939 ],
940 &[0.8, 1.4, 0.6, 1.1, 0.75, 1.25],
941 );
942
943 match kind {
944 DeterministicModelKind::Separable => {
945 let p1 = ConstantAmplitude::new("p1", parameter!("p1"))
946 .expect("separable p1 should build");
947 let p2 = ConstantAmplitude::new("p2", parameter!("p2"))
948 .expect("separable p2 should build");
949 let c1 = CacheOnlyBeamAmplitude::new("c1", 0).expect("separable c1 should build");
950 let c2 = CacheOnlyBeamAmplitude::new("c2", 1).expect("separable c2 should build");
951 let expression = (&p1 * &c1) + &(&p2 * &c2);
952 DeterministicNllFixture {
953 nll: NLL::new(&expression, &data, &mc, None)
954 .expect("separable NLL should build"),
955 parameters: vec![0.4, 0.2],
956 }
957 }
958 DeterministicModelKind::Partial => {
959 let p =
960 ConstantAmplitude::new("p", parameter!("p")).expect("partial p should build");
961 let c = CacheOnlyBeamAmplitude::new("c", 0).expect("partial c should build");
962 let m = CachedBeamScaleAmplitude::new("m", parameter!("m"), 1)
963 .expect("partial m should build");
964 let expression = (&p * &c) + &m;
965 DeterministicNllFixture {
966 nll: NLL::new(&expression, &data, &mc, None).expect("partial NLL should build"),
967 parameters: vec![0.35, 0.25],
968 }
969 }
970 DeterministicModelKind::NonSeparable => {
971 let m1 = CachedBeamScaleAmplitude::new("m1", parameter!("m1"), 0)
972 .expect("non-separable m1 should build");
973 let m2 = CachedBeamScaleAmplitude::new("m2", parameter!("m2"), 1)
974 .expect("non-separable m2 should build");
975 let expression = &m1 * &m2;
976 DeterministicNllFixture {
977 nll: NLL::new(&expression, &data, &mc, None)
978 .expect("non-separable NLL should build"),
979 parameters: vec![0.2, 0.15],
980 }
981 }
982 }
983 }
984
985 #[cfg(feature = "mpi")]
986 fn make_mixed_workload_nll_fixture(n_events: usize) -> DeterministicNllFixture {
987 let data = generated_two_p4_dataset(n_events, 1.4, 0.08);
988 let mc = generated_two_p4_dataset(n_events, 1.9, 0.11);
989 let p =
990 ConstantAmplitude::new("p", parameter!("p")).expect("mixed-workload p should build");
991 let c = CacheOnlyBeamAmplitude::new("c", 0)
992 .expect("mixed-workload cache amplitude should build");
993 let m = CachedBeamScaleAmplitude::new("m", parameter!("m"), 1)
994 .expect("mixed-workload beam amplitude should build");
995 let expression = (&p * &c) + &m;
996 DeterministicNllFixture {
997 nll: NLL::new(&expression, &data, &mc, None).expect("mixed-workload NLL should build"),
998 parameters: vec![0.35, 0.25],
999 }
1000 }
1001
1002 fn case_nll_evaluate_short(nll: &NLL) -> LadduResult<()> {
1003 nll.evaluate(&[]).map(|_| ())
1004 }
1005
1006 fn case_nll_evaluate_gradient_long(nll: &NLL) -> LadduResult<()> {
1007 nll.evaluate_gradient(&[1.0, 2.0]).map(|_| ())
1008 }
1009
1010 fn case_nll_project_short(nll: &NLL) -> LadduResult<()> {
1011 nll.project_weights(&[], None).map(|_| ())
1012 }
1013
1014 fn case_nll_project_weights_and_gradients_long(nll: &NLL) -> LadduResult<()> {
1015 nll.project_weights_and_gradients(&[1.0, 2.0], None)
1016 .map(|_| ())
1017 }
1018
1019 fn case_nll_project_weights_subset_short(nll: &NLL) -> LadduResult<()> {
1020 nll.project_weights_subset_local::<&str>(&[], &["missing_amplitude"], None)
1021 .map(|_| ())
1022 }
1023
1024 fn case_nll_project_weights_and_gradients_subset_long(nll: &NLL) -> LadduResult<()> {
1025 nll.project_weights_and_gradients_subset_local::<&str>(
1026 &[1.0, 2.0],
1027 &["missing_amplitude"],
1028 None,
1029 )
1030 .map(|_| ())
1031 }
1032
1033 fn case_likelihood_evaluate_short() -> LadduResult<()> {
1034 let alpha = LikelihoodScalar::new("alpha")?;
1035 alpha.evaluate(&[]).map(|_| ())
1036 }
1037
1038 fn case_likelihood_gradient_long() -> LadduResult<()> {
1039 let alpha = LikelihoodScalar::new("alpha")?;
1040 alpha.evaluate_gradient(&[1.0, 2.0]).map(|_| ())
1041 }
1042
/// Runs every length-validating entry point with a wrongly sized parameter
/// slice and checks both the error variant and its rendered message.
#[test]
fn table_driven_length_mismatch_errors() {
    let (owned_nll, _) = make_constant_nll();
    let nll: &NLL = owned_nll.as_ref();
    // Every case is executed while the table is built; the loop below only
    // inspects the recorded outcomes.
    let cases = [
        ("nll.evaluate short", case_nll_evaluate_short(nll)),
        (
            "nll.evaluate_gradient long",
            case_nll_evaluate_gradient_long(nll),
        ),
        ("nll.project_weights short", case_nll_project_short(nll)),
        (
            "nll.project_weights_and_gradients long",
            case_nll_project_weights_and_gradients_long(nll),
        ),
        (
            "nll.project_weights_subset short",
            case_nll_project_weights_subset_short(nll),
        ),
        (
            "nll.project_weights_and_gradients_subset long",
            case_nll_project_weights_and_gradients_subset_long(nll),
        ),
        (
            "likelihood.evaluate short",
            case_likelihood_evaluate_short(),
        ),
        (
            "likelihood.evaluate_gradient long",
            case_likelihood_gradient_long(),
        ),
    ];
    for (label, result) in cases {
        let err = result.unwrap_err();
        let is_length_mismatch = matches!(err, LadduError::LengthMismatch { .. });
        assert!(
            is_length_mismatch,
            "expected LengthMismatch for {label}, got {err:?}"
        );
        let rendered = err.to_string();
        assert!(
            rendered.contains(LENGTH_MISMATCH_MESSAGE_FRAGMENT),
            "expected message containing \"{LENGTH_MISMATCH_MESSAGE_FRAGMENT}\" for {label}, got {}",
            err
        );
    }
}
1090
/// Checks that every strict entry point rejects the unknown amplitude name
/// "missing_amplitude" with `LadduError::AmplitudeNotFoundError`, and that the
/// rendered message contains `AMPLITUDE_NOT_FOUND_MESSAGE_FRAGMENT`.
#[test]
fn table_driven_unknown_amplitude_errors() {
    let (nll, params) = make_constant_nll();
    // Every case is executed while the table is built; the loop below only
    // inspects the recorded outcomes.
    let cases: [(&str, LadduResult<()>); 4] = [
        (
            "activate_strict unknown",
            nll.activate_strict("missing_amplitude"),
        ),
        (
            "isolate_strict unknown",
            nll.isolate_strict("missing_amplitude"),
        ),
        (
            "project_weights_subset unknown",
            nll.project_weights_subset_local_strict::<&str>(
                &params,
                &["missing_amplitude"],
                None,
            )
            .map(|_| ()),
        ),
        (
            "project_weights_and_gradients_subset unknown",
            nll.project_weights_and_gradients_subset_local_strict::<&str>(
                &params,
                &["missing_amplitude"],
                None,
            )
            .map(|_| ()),
        ),
    ];
    for (label, result) in cases {
        let err = result.unwrap_err();
        assert!(
            matches!(err, LadduError::AmplitudeNotFoundError { .. }),
            "expected AmplitudeNotFoundError for {label}, got {err:?}"
        );
        assert!(
            err.to_string()
                .contains(AMPLITUDE_NOT_FOUND_MESSAGE_FRAGMENT),
            "expected message containing \"{AMPLITUDE_NOT_FOUND_MESSAGE_FRAGMENT}\" for {label}, got {}",
            err
        );
    }
}
1136
/// The sum of two scalar terms evaluates to `alpha + beta` and has a unit
/// gradient with respect to each parameter.
#[test]
fn likelihood_expression_evaluates_scalar_sum() {
    let alpha = LikelihoodScalar::new("alpha").unwrap();
    let beta = LikelihoodScalar::new("beta").unwrap();
    let expr = &alpha + &beta;
    assert_eq!(expr.parameters().names(), vec!["alpha", "beta"]);
    let params = vec![2.0, 3.0];
    assert_relative_eq!(expr.evaluate(&params).unwrap(), 5.0);
    let grad = expr.evaluate_gradient(&params).unwrap();
    assert_relative_eq!(grad[0], 1.0);
    assert_relative_eq!(grad[1], 1.0);
}
1149
/// The product of two scalar terms evaluates to `alpha * beta`; each gradient
/// component equals the value of the other factor.
#[test]
fn likelihood_expression_evaluates_scalar_product() {
    let alpha = LikelihoodScalar::new("alpha").unwrap();
    let beta = LikelihoodScalar::new("beta").unwrap();
    let expr = &alpha * &beta;
    let params = vec![2.0, 3.0];
    assert_relative_eq!(expr.evaluate(&params).unwrap(), 6.0);
    let grad = expr.evaluate_gradient(&params).unwrap();
    assert_relative_eq!(grad[0], 3.0);
    assert_relative_eq!(grad[1], 2.0);
}
1161
/// Fixing "alpha" on the combined expression removes it from the free
/// parameters: evaluation then takes only `beta`, and the gradient has a single
/// component.
#[test]
fn likelihood_expression_tracks_fixed_parameters() {
    let alpha = LikelihoodScalar::new("alpha").unwrap();
    let beta = LikelihoodScalar::new("beta").unwrap();
    let expr = &alpha + &beta;
    expr.fix_parameter("alpha", 1.5).unwrap();
    assert_eq!(expr.parameters().names(), vec!["alpha", "beta"]);
    assert_eq!(expr.parameters().free().names(), vec!["beta"]);
    assert_eq!(expr.parameters().fixed().names(), vec!["alpha"]);
    // Only the free parameter (beta = 2.0) is supplied; alpha stays at 1.5.
    let params_free = vec![2.0];
    assert_relative_eq!(expr.evaluate(&params_free).unwrap(), 3.5);
    let grad_free = expr.evaluate_gradient(&params_free).unwrap();
    assert_eq!(grad_free.len(), 1);
    assert_relative_eq!(grad_free[0], 1.0);
}
1177
/// A parameter fixed on an individual term before the terms are combined is
/// still reported as fixed by the resulting sum, which then evaluates with only
/// the free parameter.
#[test]
fn likelihood_expression_handles_term_local_fixed_parameters() {
    let alpha = LikelihoodScalar::new("alpha").unwrap();
    alpha.fix_parameter("alpha", 1.5).unwrap();
    let beta = LikelihoodScalar::new("beta").unwrap();
    let expr = &alpha + &beta;
    assert_eq!(expr.parameters().names(), vec!["alpha", "beta"]);
    assert_eq!(expr.parameters().free().names(), vec!["beta"]);
    assert_eq!(expr.parameters().fixed().names(), vec!["alpha"]);

    // Only the free parameter (beta = 2.0) is supplied; alpha stays at 1.5.
    let params_free = vec![2.0];
    assert_relative_eq!(expr.evaluate(&params_free).unwrap(), 3.5);
    let grad_free = expr.evaluate_gradient(&params_free).unwrap();
    assert_eq!(grad_free.len(), 1);
    assert_relative_eq!(grad_free[0], 1.0);
}
1194
/// A parameter fixed on an individual term before the terms are multiplied is
/// still reported as fixed by the product; the remaining gradient component
/// equals the fixed factor.
#[test]
fn likelihood_product_handles_term_local_fixed_parameters() {
    let alpha = LikelihoodScalar::new("alpha").unwrap();
    alpha.fix_parameter("alpha", 1.5).unwrap();
    let beta = LikelihoodScalar::new("beta").unwrap();
    let expr = &alpha * &beta;
    assert_eq!(expr.parameters().names(), vec!["alpha", "beta"]);
    assert_eq!(expr.parameters().free().names(), vec!["beta"]);
    assert_eq!(expr.parameters().fixed().names(), vec!["alpha"]);

    // Only the free parameter (beta = 2.0) is supplied; alpha stays at 1.5.
    let params_free = vec![2.0];
    assert_relative_eq!(expr.evaluate(&params_free).unwrap(), 3.0);
    let grad_free = expr.evaluate_gradient(&params_free).unwrap();
    assert_eq!(grad_free.len(), 1);
    assert_relative_eq!(grad_free[0], 1.5);
}
1211
/// Compares the NLL value and gradient of the constant fixture against the
/// closed form `-2 * (W * ln(I) - I)` with `I = p^2`.
#[test]
fn nll_evaluate_and_gradient_match_closed_form() {
    let (nll, params) = make_constant_nll();
    let intensity = params[0] * params[0];
    // NOTE(review): presumably the summed data weights of the fixture —
    // confirm against `make_constant_nll`.
    let weight_sum = 3.0;
    let expected = -2.0 * (weight_sum * intensity.ln() - intensity);
    assert_relative_eq!(nll.evaluate(&params).unwrap(), expected, epsilon = 1e-12);
    let grad = nll.evaluate_gradient(&params).unwrap();
    // d/dp of the closed form above: -2 * (2W/p - 2p).
    let expected_grad = -4.0 * (weight_sum / params[0] - params[0]);
    assert_relative_eq!(grad[0], expected_grad, epsilon = 1e-12);
}
1223
/// Evaluates gradients of a one-parameter and a two-parameter NLL from eight
/// threads at once (100 iterations each) and checks every result against the
/// gradient computed up front on the calling thread.
#[cfg(feature = "rayon")]
#[test]
fn gradient_scratch_reuse_is_thread_safe_across_parallel_calls() {
    let (nll_single, params_single) = make_constant_nll();
    let (nll_multi, params_multi) = make_two_parameter_nll();
    let nll_single = Arc::new(*nll_single);
    let nll_multi = Arc::new(*nll_multi);
    let expected_single = nll_single
        .evaluate_gradient(&params_single)
        .expect("single-parameter gradient should evaluate");
    let expected_multi = nll_multi
        .evaluate_gradient(&params_multi)
        .expect("two-parameter gradient should evaluate");
    std::thread::scope(|scope| {
        for _ in 0..8 {
            // Each worker owns its own handles and copies of the inputs.
            let nll_single = Arc::clone(&nll_single);
            let nll_multi = Arc::clone(&nll_multi);
            let params_single = params_single.clone();
            let params_multi = params_multi.clone();
            let expected_single = expected_single.clone();
            let expected_multi = expected_multi.clone();
            scope.spawn(move || {
                for _ in 0..100 {
                    let single_gradient = nll_single
                        .evaluate_gradient(&params_single)
                        .expect("single-parameter gradient should evaluate");
                    assert_relative_eq!(
                        single_gradient[0],
                        expected_single[0],
                        epsilon = 1e-12
                    );
                    let multi_gradient = nll_multi
                        .evaluate_gradient(&params_multi)
                        .expect("two-parameter gradient should evaluate");
                    // Lengths are checked first so the zip below cannot
                    // silently skip trailing components.
                    assert_eq!(multi_gradient.len(), expected_multi.len());
                    for (actual, expected) in multi_gradient.iter().zip(expected_multi.iter()) {
                        assert_relative_eq!(*actual, *expected, epsilon = 1e-12);
                    }
                }
            });
        }
    });
}
1271
/// Checks the NLL value for a constant amplitude against the weighted closed
/// form when event weights span many orders of magnitude.
#[test]
fn nll_value_matches_mixed_scale_weighted_closed_form() {
    let amp = ConstantAmplitude::new("amp", parameter!("scale")).unwrap();
    let expr = amp.norm_sqr();
    let data = dataset_with_weights(&[1.0e12, 1.0e-12, 3.5, 7.25e4, 2.0e-3]);
    let mc = dataset_with_weights(&[4.0e9, 9.0e-6, 1.25, 2.5e2, 8.0e-4]);
    let nll = NLL::new(&expr, &data, &mc, None).unwrap();
    let params = vec![1.125];

    // The intensity is the same for every event, so both sums factor out.
    let intensity: f64 = params[0] * params[0];
    let data_weight_sum = data.weights_local().iter().copied().sum::<f64>();
    let mc_weight_sum = mc.weights_local().iter().copied().sum::<f64>();
    let n_mc = mc.n_events_weighted();
    let expected = -2.0 * (data_weight_sum * intensity.ln() - mc_weight_sum * intensity / n_mc);

    let value = nll.evaluate(&params).unwrap();
    assert_relative_eq!(value, expected, epsilon = 1e-9, max_relative = 1e-12);
}
1290
/// Regression check: the NLL value and gradient of a two-amplitude model on
/// small weighted datasets must match hard-coded reference numbers.
#[test]
fn nll_evaluate_and_gradient_match_hardcoded_weighted_reference() {
    let amp_a = CachedBeamScaleAmplitude::new("amp_a", parameter!("alpha"), 0).unwrap();
    let amp_b = CachedBeamScaleAmplitude::new("amp_b", parameter!("beta"), 1).unwrap();
    let expr = (&amp_a + &amp_b).norm_sqr();
    let data = dataset_with_two_p4_and_weights(
        &[(1.0, 0.8), (2.5, 1.7), (4.0, 2.4), (3.3, 1.1)],
        &[0.7, 1.2, 0.9, 1.5],
    );
    let mc = dataset_with_two_p4_and_weights(
        &[(1.5, 1.0), (3.0, 2.1), (5.5, 2.9), (2.0, 1.2), (4.2, 1.8)],
        &[0.8, 1.4, 0.6, 1.1, 0.75],
    );
    let nll = NLL::new(&expr, &data, &mc, None).unwrap();
    let params = vec![0.6, 1.1];
    assert_eq!(nll.parameters().free().names(), vec!["alpha", "beta"]);

    let value = nll.evaluate(&params).unwrap();
    assert_relative_eq!(value, 12.242296380697244, epsilon = 1e-12);

    let gradient = nll.evaluate_gradient(&params).unwrap();
    assert_eq!(gradient.len(), 2);
    assert_relative_eq!(gradient[0], 37.78259267741666, epsilon = 1e-12);
    assert_relative_eq!(gradient[1], 21.8538272590435, epsilon = 1e-12);
}
1316
/// Builds one deterministic fixture per model kind and checks each against the
/// weighted baseline.
#[test]
fn nll_deterministic_fixtures_cover_separable_partial_and_non_separable_models() {
    // All three fixtures are constructed before the first one is checked.
    let fixtures = [
        DeterministicModelKind::Separable,
        DeterministicModelKind::Partial,
        DeterministicModelKind::NonSeparable,
    ]
    .map(make_deterministic_nll_fixture);

    for fixture in fixtures {
        assert_nll_fixture_matches_weighted_baseline(&fixture);
    }
}
1327
/// The partial fixture must agree with the weighted baseline initially, after
/// isolating "p" and "c", and again after reactivating everything.
#[test]
fn nll_deterministic_fixture_matches_baseline_across_activation_toggles() {
    let fixture = make_deterministic_nll_fixture(DeterministicModelKind::Partial);
    let check_baseline = || assert_nll_fixture_matches_weighted_baseline(&fixture);

    check_baseline();

    fixture.nll.isolate_many(&["p", "c"]);
    check_baseline();

    fixture.nll.activate_all();
    check_baseline();
}
1339
/// The local projection of the constant fixture is expected to yield the
/// values 1.0 and 3.0 for its first two entries.
#[test]
fn nll_project_returns_weighted_intensity() {
    let (nll, params) = make_constant_nll();
    let projection = nll.project_weights_local(&params, None).unwrap();
    assert_relative_eq!(projection[0], 1.0, epsilon = 1e-12);
    assert_relative_eq!(projection[1], 3.0, epsilon = 1e-12);
}
1347
/// Projecting with an empty parameter slice must fail with a `LengthMismatch`
/// that records one expected and zero supplied parameters.
#[test]
fn nll_project_reports_structured_length_error() {
    let (nll, _params) = make_constant_nll();
    let outcome = nll.project_weights(&[], None);
    let err = outcome.unwrap_err();
    assert!(matches!(
        err,
        LadduError::LengthMismatch {
            expected: 1,
            actual: 0,
            ..
        }
    ));
}
1361
/// The strict subset projection must reject an unknown amplitude name with
/// `AmplitudeNotFoundError`.
#[test]
fn nll_project_weights_subset_reports_structured_missing_amplitude_error() {
    let (nll, params) = make_constant_nll();
    let err = nll
        .project_weights_subset_local_strict::<&str>(&params, &["missing_amplitude"], None)
        .unwrap_err();
    assert!(matches!(err, LadduError::AmplitudeNotFoundError { .. }));
}
1370
/// The batched subset projection must return, subset by subset, the same
/// values as calling the single-subset projection once per subset.
#[test]
fn nll_project_weights_subsets_matches_repeated_project_weights_subset_calls() {
    let (nll, params) = make_two_parameter_nll();
    let subsets = vec![
        vec!["amp_a".to_string()],
        vec!["amp_b".to_string()],
        vec!["amp_a".to_string(), "amp_b".to_string()],
    ];
    let batched = nll
        .project_weights_subsets_local(&params, &subsets, None)
        .expect("batched projection should evaluate");
    let repeated = subsets
        .iter()
        .map(|subset| {
            nll.project_weights_subset_local(&params, subset, None)
                .expect("single subset projection should evaluate")
        })
        .collect::<Vec<_>>();
    // Lengths are compared first so the zips below cannot hide missing rows.
    assert_eq!(batched.len(), repeated.len());
    for (lhs, rhs) in batched.iter().zip(repeated.iter()) {
        assert_eq!(lhs.len(), rhs.len());
        for (lhs_value, rhs_value) in lhs.iter().zip(rhs.iter()) {
            assert_relative_eq!(lhs_value, rhs_value, epsilon = 1e-12);
        }
    }
}
1397
/// An empty subset list yields an empty result, and a list with repeated
/// subsets in arbitrary order still matches one single-subset call per entry.
#[test]
fn nll_project_weights_subsets_handles_empty_and_duplicate_subsets() {
    let (nll, params) = make_two_parameter_nll();
    let empty: Vec<Vec<String>> = Vec::new();
    let empty_projection = nll
        .project_weights_subsets_local(&params, &empty, None)
        .expect("empty subset list should evaluate");
    assert!(empty_projection.is_empty());

    // "amp_a" and "amp_b" each appear twice as single-element subsets.
    let subsets = vec![
        vec!["amp_b".to_string()],
        vec!["amp_a".to_string()],
        vec!["amp_a".to_string(), "amp_b".to_string()],
        vec!["amp_a".to_string()],
        vec!["amp_b".to_string()],
    ];
    let batched = nll
        .project_weights_subsets_local(&params, &subsets, None)
        .expect("batched projection should evaluate");
    let repeated = subsets
        .iter()
        .map(|subset| {
            nll.project_weights_subset_local(&params, subset, None)
                .expect("single subset projection should evaluate")
        })
        .collect::<Vec<_>>();
    assert_eq!(batched.len(), repeated.len());
    for (lhs, rhs) in batched.iter().zip(repeated.iter()) {
        assert_eq!(lhs.len(), rhs.len());
        for (lhs_value, rhs_value) in lhs.iter().zip(rhs.iter()) {
            assert_relative_eq!(lhs_value, rhs_value, epsilon = 1e-12);
        }
    }
}
1432
/// The strict batched projection must fail with `AmplitudeNotFoundError` when
/// any subset names an unknown amplitude.
#[test]
fn nll_project_weights_subsets_reports_missing_amplitude_error() {
    let (nll, params) = make_two_parameter_nll();
    let subsets = vec![vec!["amp_a".to_string()], vec!["missing".to_string()]];
    let err = nll
        .project_weights_subsets_local_strict(&params, &subsets, None)
        .expect_err("missing amplitude should fail");
    assert!(matches!(err, LadduError::AmplitudeNotFoundError { .. }));
}
1442
/// For several subsets (including a repeated one), the `_local` projection of
/// weights and gradients must agree with `project_weights_and_gradients_subset`.
#[test]
fn nll_project_weights_and_gradients_subset_matches_repeated_calls() {
    let (nll, params) = make_two_parameter_nll();
    let subsets = vec![
        vec!["amp_b".to_string()],
        vec!["amp_a".to_string()],
        vec!["amp_a".to_string(), "amp_b".to_string()],
        vec!["amp_a".to_string()],
    ];
    for subset in subsets {
        let (weights_local, gradients_local) = nll
            .project_weights_and_gradients_subset_local(&params, &subset, None)
            .expect("local gradient projection should evaluate");
        let (weights_auto, gradients_auto) = nll
            .project_weights_and_gradients_subset(&params, &subset, None)
            .expect("auto gradient projection should evaluate");
        // Lengths are compared first so the zips below cannot hide missing rows.
        assert_eq!(weights_local.len(), weights_auto.len());
        assert_eq!(gradients_local.len(), gradients_auto.len());
        for (lhs, rhs) in weights_local.iter().zip(weights_auto.iter()) {
            assert_relative_eq!(lhs, rhs, epsilon = 1e-12);
        }
        for (lhs, rhs) in gradients_local.iter().zip(gradients_auto.iter()) {
            assert_relative_eq!(lhs, rhs, epsilon = 1e-12);
        }
    }
}
1469
/// The projection mask cache starts empty, is populated by a subset
/// projection, is emptied again by `deactivate`, and a later subset projection
/// still produces the expected values.
#[test]
fn nll_activation_changes_invalidate_projection_mask_cache() {
    let (nll, params) = make_constant_nll();
    assert!(nll.projection_active_mask_cache.lock().is_empty());

    let _ = nll
        .project_weights_subset_local::<&str>(&params, &["amp"], None)
        .unwrap();
    assert!(!nll.projection_active_mask_cache.lock().is_empty());

    nll.deactivate("amp");
    assert!(nll.projection_active_mask_cache.lock().is_empty());

    // Projecting the "amp" subset after deactivating it must still give the
    // same values as before.
    let projection = nll
        .project_weights_subset_local::<&str>(&params, &["amp"], None)
        .unwrap();
    assert_relative_eq!(projection[0], 1.0, epsilon = 1e-12);
    assert_relative_eq!(projection[1], 3.0, epsilon = 1e-12);
}
1489
/// With both a wrongly sized parameter slice and an unknown amplitude name,
/// the subset projection must report the length mismatch.
#[test]
fn nll_project_weights_subset_validates_length_before_isolation() {
    let (nll, _params) = make_constant_nll();
    let outcome = nll.project_weights_subset_local::<&str>(&[], &["missing_amplitude"], None);
    let err = outcome.unwrap_err();
    assert!(matches!(
        err,
        LadduError::LengthMismatch {
            expected: 1,
            actual: 0,
            ..
        }
    ));
}
1505
/// With both an oversized parameter slice and an unknown amplitude name, the
/// weights-and-gradients subset projection must report the length mismatch.
#[test]
fn nll_project_weights_and_gradients_subset_validates_length_before_isolation() {
    let (nll, _params) = make_constant_nll();
    let outcome = nll.project_weights_and_gradients_subset_local::<&str>(
        &[1.0, 2.0],
        &["missing_amplitude"],
        None,
    );
    let err = outcome.unwrap_err();
    assert!(matches!(
        err,
        LadduError::LengthMismatch {
            expected: 1,
            actual: 2,
            ..
        }
    ));
}
1525
/// `to_stochastic` must reject a batch size of zero and a batch size of three
/// for the constant fixture, reporting `LengthMismatch` with `expected: 2`.
#[test]
fn stochastic_nll_validates_batch_size() {
    let (nll, _params) = make_constant_nll();

    // `.err()` discards the success payload without requiring it to be `Debug`.
    let err_zero = nll
        .to_stochastic(0, Some(0))
        .err()
        .unwrap_or_else(|| panic!("expected batch_size=0 to return an error"));
    assert!(matches!(
        err_zero,
        LadduError::LengthMismatch {
            expected: 2,
            actual: 0,
            ..
        }
    ));

    let err_large = nll
        .to_stochastic(3, Some(0))
        .err()
        .unwrap_or_else(|| panic!("expected oversized batch to return an error"));
    assert!(matches!(
        err_large,
        LadduError::LengthMismatch {
            expected: 2,
            actual: 3,
            ..
        }
    ));
}
1555
/// A batch size of two is accepted for the constant fixture, and the
/// resulting stochastic NLL evaluates to a finite value.
#[test]
fn stochastic_nll_accepts_full_dataset_batch() {
    let (nll, params) = make_constant_nll();
    let stochastic = nll.to_stochastic(2, Some(0)).unwrap();
    let value = stochastic.evaluate(&params).unwrap();
    assert!(value.is_finite());
}
1563
/// When the batch covers every data event, the stochastic NLL must evaluate to
/// the same value as the deterministic NLL.
#[test]
fn stochastic_nll_matches_closed_form_on_full_batch() {
    let (nll, params) = make_constant_nll();
    let stochastic = nll
        .to_stochastic(nll.data_evaluator.dataset.n_events(), Some(0))
        .unwrap();
    let stochastic_value = stochastic.evaluate(&params).unwrap();
    let deterministic_value = nll.evaluate(&params).unwrap();
    assert_relative_eq!(stochastic_value, deterministic_value, epsilon = 1e-12);
}
1574
/// A single-parameter scalar term must report `LengthMismatch` with the exact
/// expected and actual counts for both too few and too many parameters.
#[test]
fn likelihood_evaluator_reports_length_mismatch() {
    let alpha = LikelihoodScalar::new("alpha").unwrap();

    // Too few parameters: zero supplied, one expected.
    let short_outcome = alpha.evaluate(&[]);
    let err_short = short_outcome.unwrap_err();
    assert!(matches!(
        err_short,
        LadduError::LengthMismatch {
            expected: 1,
            actual: 0,
            ..
        }
    ));

    // Too many parameters: two supplied, one expected.
    let long_outcome = alpha.evaluate_gradient(&[1.0, 2.0]);
    let err_long = long_outcome.unwrap_err();
    assert!(matches!(
        err_long,
        LadduError::LengthMismatch {
            expected: 1,
            actual: 2,
            ..
        }
    ));
}
1599
/// The MPI projection entry points must report the same structured errors as
/// their local counterparts: `LengthMismatch` for a wrongly sized parameter
/// slice and `AmplitudeNotFoundError` for an unknown amplitude name.
#[cfg(feature = "mpi")]
#[mpi_test(np = [2])]
fn mpi_negative_paths_report_structured_errors() {
    use_mpi(true);
    let world = get_world().expect("MPI world should be initialized");
    let (nll, params) = make_constant_nll();

    let err_len = nll.project_weights_mpi(&[], None, &world).unwrap_err();
    assert!(matches!(
        err_len,
        LadduError::LengthMismatch {
            expected: 1,
            actual: 0,
            ..
        }
    ));

    let err_amp = nll
        .project_weights_subset_mpi_strict::<&str>(
            &params,
            &["missing_amplitude"],
            None,
            &world,
        )
        .unwrap_err();
    assert!(matches!(err_amp, LadduError::AmplitudeNotFoundError { .. }));
    finalize_mpi();
}
1628
/// Rebuilds the NLL value and gradient from per-rank local sums reduced across
/// the world, and checks that `evaluate_mpi` / `evaluate_gradient_mpi` return
/// the same results.
#[cfg(feature = "mpi")]
#[mpi_test(np = [2])]
fn mpi_value_and_gradient_match_total_non_mpi() {
    use_mpi(true);
    let world = get_world().expect("MPI world should be initialized");
    let (nll, params) = make_constant_nll();

    // Value: reduce the local data and MC sums, then apply the NLL formula.
    let data_term_local = crate::likelihood::nll::evaluate_weighted_expression_sum_local(
        &nll.data_evaluator,
        &params,
        |l| f64::ln(l.re),
    )
    .expect("evaluate should succeed");
    let mc_term_local = nll
        .accmc_evaluator
        .evaluate_weighted_value_sum_local(&params)
        .expect("evaluate should succeed");
    let data_term = crate::likelihood::nll::reduce_scalar(&world, data_term_local);
    let mc_term = crate::likelihood::nll::reduce_scalar(&world, mc_term_local);
    let expected_value = -2.0 * (data_term - mc_term / nll.n_mc);

    let mpi_value = nll
        .evaluate_mpi(&params, &world)
        .expect("evaluate should succeed");
    assert_relative_eq!(mpi_value, expected_value);

    // Gradient: same construction with the local gradient sums.
    let data_gradient_local = nll
        .evaluate_data_gradient_term_local(&params)
        .expect("evaluate should succeed");
    let mc_gradient_local = nll
        .accmc_evaluator
        .evaluate_weighted_gradient_sum_local(&params)
        .expect("evaluate should succeed");
    let data_gradient = crate::likelihood::nll::reduce_gradient(&world, &data_gradient_local);
    let mc_gradient = crate::likelihood::nll::reduce_gradient(&world, &mc_gradient_local);
    let expected_gradient = -2.0 * (data_gradient - mc_gradient / nll.n_mc);
    let mpi_gradient = nll
        .evaluate_gradient_mpi(&params, &world)
        .expect("evaluate should succeed");
    assert_relative_eq!(mpi_gradient, expected_gradient);

    finalize_mpi();
}
1671
/// Under MPI, the partial fixture must agree with both the weighted baseline
/// and the MPI-reduced baseline initially, after isolating "p" and "c", and
/// again after reactivating everything.
#[cfg(feature = "mpi")]
#[mpi_test(np = [2])]
fn mpi_deterministic_fixture_matches_local_and_reduced_baselines_across_activation_toggles() {
    use_mpi(true);
    let world = get_world().expect("MPI world should be initialized");

    let fixture = make_deterministic_nll_fixture(DeterministicModelKind::Partial);
    let check_baselines = || {
        assert_nll_fixture_matches_weighted_baseline(&fixture);
        assert_nll_fixture_matches_mpi_reduced_baseline(&fixture, &world);
    };

    check_baselines();

    fixture.nll.isolate_many(&["p", "c"]);
    check_baselines();

    fixture.nll.activate_all();
    check_baselines();

    finalize_mpi();
}
1692
/// With event weights spanning many orders of magnitude, `evaluate_mpi` must
/// match the value rebuilt from per-event local evaluations that are weighted,
/// summed, and reduced across the world.
#[cfg(feature = "mpi")]
#[mpi_test(np = [2])]
fn mpi_mixed_scale_value_matches_local_evaluate() {
    use_mpi(true);
    let world = get_world().expect("MPI world should be initialized");
    let amp_a = CachedBeamScaleAmplitude::new("amp_a", parameter!("scale_a"), 0).unwrap();
    let amp_b = CachedBeamScaleAmplitude::new("amp_b", parameter!("scale_b"), 1).unwrap();
    let expr = (amp_a + amp_b).norm_sqr();
    let data = dataset_with_two_p4_and_weights(
        &[(1.0, 0.5), (10.0, 1.0), (3.0, 5.0), (1.0e2, 2.0e-1)],
        &[1.0e12, 1.0e-12, 3.5, 7.25e4],
    );
    let mc = dataset_with_two_p4_and_weights(
        &[(4.0, 0.1), (6.0, 2.0), (8.0, 1.5), (1.0e1, 3.0)],
        &[4.0e9, 9.0e-6, 1.25, 2.5e2],
    );
    let nll = NLL::new(&expr, &data, &mc, None).unwrap();
    let params = vec![1.125, -0.375];

    let data_local = nll
        .data_evaluator
        .evaluate_local(&params)
        .expect("evaluate should succeed");
    let mc_local = nll
        .accmc_evaluator
        .evaluate_local(&params)
        .expect("evaluate should succeed");
    // Data term: sum of weight * ln(intensity) over this rank's events.
    let data_term_local: f64 = data_local
        .iter()
        .zip(nll.data_evaluator.dataset.weights_local().iter())
        .map(|(value, weight)| *weight * value.re.ln())
        .sum();
    // MC term: sum of weight * intensity over this rank's events.
    let mc_term_local: f64 = mc_local
        .iter()
        .zip(nll.accmc_evaluator.dataset.weights_local().iter())
        .map(|(value, weight)| *weight * value.re)
        .sum();
    let data_term = crate::likelihood::nll::reduce_scalar(&world, data_term_local);
    let mc_term = crate::likelihood::nll::reduce_scalar(&world, mc_term_local);
    let expected = -2.0 * (data_term - mc_term / nll.n_mc);
    let mpi_value = nll
        .evaluate_mpi(&params, &world)
        .expect("evaluate should succeed");
    assert_relative_eq!(mpi_value, expected, epsilon = 1e-9, max_relative = 1e-12);
    finalize_mpi();
}
1739
/// The `_local` projections return only this rank's events, while the `_mpi`
/// projections return one entry per event in the full dataset; this rank's
/// slice of the gathered result must equal its local result.
#[cfg(feature = "mpi")]
#[mpi_test(np = [2])]
fn mpi_projection_paths_are_explicit_global_gathers() {
    use_mpi(true);
    let world = get_world().expect("MPI world should be initialized");
    let (nll, params) = make_constant_nll();

    let local_projection = nll
        .project_weights_local(&params, None)
        .expect("local projection should evaluate");
    let gathered_projection = nll
        .project_weights_mpi(&params, None, &world)
        .expect("mpi projection should gather global projection");
    let local_len = nll.accmc_evaluator.dataset.n_events_local();
    let total_len = nll.accmc_evaluator.dataset.n_events();
    assert_eq!(local_projection.len(), local_len);
    assert_eq!(gathered_projection.len(), total_len);

    // Locate this rank's window inside the gathered output.
    let (counts, displs) = world.get_counts_displs(total_len);
    let rank = world.rank() as usize;
    let start = displs[rank] as usize;
    let end = start + counts[rank] as usize;
    assert_eq!(
        &gathered_projection[start..end],
        local_projection.as_slice()
    );

    let (local_weights, local_gradients) = nll
        .project_weights_and_gradients_local(&params, None)
        .expect("local projection gradient should evaluate");
    let (gathered_weights, gathered_gradients) = nll
        .project_weights_and_gradients_mpi(&params, None, &world)
        .expect("mpi projection gradient should gather global projection");
    assert_eq!(local_weights.len(), local_len);
    assert_eq!(local_gradients.len(), local_len);
    assert_eq!(gathered_weights.len(), total_len);
    assert_eq!(gathered_gradients.len(), total_len);
    assert_eq!(&gathered_weights[start..end], local_weights.as_slice());

    let local_grad_slice = &gathered_gradients[start..end];
    for (lhs, rhs) in local_grad_slice.iter().zip(local_gradients.iter()) {
        assert_relative_eq!(lhs, rhs);
    }
    finalize_mpi();
}
1785
/// The batched MPI subset projection must return, subset by subset, the same
/// values as calling the single-subset MPI projection once per subset.
#[cfg(feature = "mpi")]
#[mpi_test(np = [2])]
fn mpi_project_weights_subsets_matches_repeated_project_weights_subset_mpi() {
    use_mpi(true);
    let world = get_world().expect("MPI world should be initialized");
    let (nll, params) = make_two_parameter_nll();
    let subsets = vec![
        vec!["amp_b".to_string()],
        vec!["amp_a".to_string()],
        vec!["amp_a".to_string(), "amp_b".to_string()],
        vec!["amp_a".to_string()],
    ];
    let batched = nll
        .project_weights_subsets_mpi(&params, &subsets, None, &world)
        .expect("batched mpi projection should evaluate");
    let repeated = subsets
        .iter()
        .map(|subset| {
            nll.project_weights_subset_mpi(&params, subset, None, &world)
                .expect("single mpi subset projection should evaluate")
        })
        .collect::<Vec<_>>();
    // Lengths are compared first so the zips below cannot hide missing rows.
    assert_eq!(batched.len(), repeated.len());
    for (lhs, rhs) in batched.iter().zip(repeated.iter()) {
        assert_eq!(lhs.len(), rhs.len());
        for (lhs_value, rhs_value) in lhs.iter().zip(rhs.iter()) {
            assert_relative_eq!(lhs_value, rhs_value, epsilon = 1e-12);
        }
    }
    finalize_mpi();
}
1817
/// For each subset, the explicit `_mpi` projection of weights and gradients
/// must agree with `project_weights_and_gradients_subset`.
#[cfg(feature = "mpi")]
#[mpi_test(np = [2])]
fn mpi_project_weights_and_gradients_subset_matches_repeated_project_weights_and_gradients_subset_mpi(
) {
    use_mpi(true);
    let world = get_world().expect("MPI world should be initialized");
    let (nll, params) = make_two_parameter_nll();
    let subsets = vec![
        vec!["amp_b".to_string()],
        vec!["amp_a".to_string()],
        vec!["amp_a".to_string(), "amp_b".to_string()],
    ];
    for subset in subsets {
        let (weights_mpi, gradients_mpi) = nll
            .project_weights_and_gradients_subset_mpi(&params, &subset, None, &world)
            .expect("mpi gradient projection should evaluate");
        let (weights_auto, gradients_auto) = nll
            .project_weights_and_gradients_subset(&params, &subset, None)
            .expect("auto gradient projection should evaluate");
        // Lengths are compared first so the zips below cannot hide missing rows.
        assert_eq!(weights_mpi.len(), weights_auto.len());
        assert_eq!(gradients_mpi.len(), gradients_auto.len());
        for (lhs, rhs) in weights_mpi.iter().zip(weights_auto.iter()) {
            assert_relative_eq!(lhs, rhs, epsilon = 1e-12);
        }
        for (lhs, rhs) in gradients_mpi.iter().zip(gradients_auto.iter()) {
            assert_relative_eq!(lhs, rhs, epsilon = 1e-12);
        }
    }
    finalize_mpi();
}
1848
/// Repeats a mixed MPI workload (value, gradient, projection, projection with
/// gradients) 24 times, checking every pass against the first results, and
/// asserts that resident memory sampled after the first three passes neither
/// grows nor spreads by more than 64 MiB.
#[cfg(feature = "mpi")]
#[mpi_test(np = [2])]
fn mpi_mixed_workload_rss_stays_bounded() {
    use_mpi(true);
    let world = get_world().expect("MPI world should be initialized");
    let fixture = make_mixed_workload_nll_fixture(2_048);
    let nll = &fixture.nll;
    let parameters = &fixture.parameters;

    // Reference results that every later pass must reproduce.
    let baseline_value = nll
        .evaluate_mpi(parameters, &world)
        .expect("evaluate should succeed");
    let baseline_gradient = nll
        .evaluate_gradient_mpi(parameters, &world)
        .expect("evaluate should succeed");
    let baseline_weights = nll
        .project_weights_mpi(parameters, None, &world)
        .expect("baseline MPI projection should evaluate");
    let (baseline_projection_weights, baseline_projection_gradients) = nll
        .project_weights_and_gradients_mpi(parameters, None, &world)
        .expect("baseline MPI projection gradient should evaluate");

    assert_relative_eq!(
        baseline_weights.as_slice(),
        baseline_projection_weights.as_slice(),
        epsilon = DETERMINISTIC_STRICT_ABS_TOL,
        max_relative = DETERMINISTIC_STRICT_REL_TOL
    );

    let mut post_warmup_rss_kb = Vec::new();
    for pass_index in 0..24 {
        let value = nll
            .evaluate_mpi(parameters, &world)
            .expect("evaluate should succeed");
        assert_relative_eq!(
            value,
            baseline_value,
            epsilon = DETERMINISTIC_STRICT_ABS_TOL,
            max_relative = DETERMINISTIC_STRICT_REL_TOL
        );

        let gradient = nll
            .evaluate_gradient_mpi(parameters, &world)
            .expect("evaluate should succeed");
        assert_eq!(
            gradient.len(),
            baseline_gradient.len(),
            "mixed-workload MPI gradient length should remain stable"
        );
        for (actual, expected) in gradient.iter().zip(baseline_gradient.iter()) {
            assert_relative_eq!(
                *actual,
                *expected,
                epsilon = DETERMINISTIC_STRICT_ABS_TOL,
                max_relative = DETERMINISTIC_STRICT_REL_TOL
            );
        }

        let weights = nll
            .project_weights_mpi(parameters, None, &world)
            .expect("MPI projection should remain evaluable");
        assert_eq!(
            weights.len(),
            baseline_weights.len(),
            "mixed-workload MPI projection length should remain stable"
        );
        for (actual, expected) in weights.iter().zip(baseline_weights.iter()) {
            assert_relative_eq!(
                *actual,
                *expected,
                epsilon = DETERMINISTIC_STRICT_ABS_TOL,
                max_relative = DETERMINISTIC_STRICT_REL_TOL
            );
        }

        let (projection_weights, projection_gradients) = nll
            .project_weights_and_gradients_mpi(parameters, None, &world)
            .expect("MPI projection gradients should remain evaluable");
        assert_eq!(
            projection_weights.len(),
            baseline_projection_weights.len(),
            "mixed-workload MPI projection-gradient weight length should remain stable"
        );
        assert_eq!(
            projection_gradients.len(),
            baseline_projection_gradients.len(),
            "mixed-workload MPI projection-gradient length should remain stable"
        );
        for (actual, expected) in projection_weights
            .iter()
            .zip(baseline_projection_weights.iter())
        {
            assert_relative_eq!(
                *actual,
                *expected,
                epsilon = DETERMINISTIC_STRICT_ABS_TOL,
                max_relative = DETERMINISTIC_STRICT_REL_TOL
            );
        }
        for (actual_row, expected_row) in projection_gradients
            .iter()
            .zip(baseline_projection_gradients.iter())
        {
            assert_eq!(
                actual_row.len(),
                expected_row.len(),
                "mixed-workload MPI projection-gradient vector length should remain stable"
            );
            for (actual, expected) in actual_row.iter().zip(expected_row.iter()) {
                assert_relative_eq!(
                    *actual,
                    *expected,
                    epsilon = DETERMINISTIC_STRICT_ABS_TOL,
                    max_relative = DETERMINISTIC_STRICT_REL_TOL
                );
            }
        }

        // The first three passes are warm-up; a missing reading is skipped.
        if pass_index >= 3 {
            post_warmup_rss_kb.extend(read_resident_rss_kb());
        }
    }

    // With no samples at all the memory checks are skipped entirely.
    if let (Some(&first_rss_kb), Some(&last_rss_kb)) =
        (post_warmup_rss_kb.first(), post_warmup_rss_kb.last())
    {
        const MAX_POST_WARMUP_RSS_GROWTH_KB: u64 = 64 * 1024;
        const MAX_POST_WARMUP_RSS_SPREAD_KB: u64 = 64 * 1024;
        let min_rss_kb = post_warmup_rss_kb
            .iter()
            .copied()
            .min()
            .expect("post-warmup RSS sample should exist");
        let max_rss_kb = post_warmup_rss_kb
            .iter()
            .copied()
            .max()
            .expect("post-warmup RSS sample should exist");
        let growth_kb = last_rss_kb.saturating_sub(first_rss_kb);
        let spread_kb = max_rss_kb.saturating_sub(min_rss_kb);
        assert!(
            growth_kb <= MAX_POST_WARMUP_RSS_GROWTH_KB,
            "mixed-workload post-warmup RSS grew by {} KiB (first={} KiB, last={} KiB)",
            growth_kb,
            first_rss_kb,
            last_rss_kb
        );
        assert!(
            spread_kb <= MAX_POST_WARMUP_RSS_SPREAD_KB,
            "mixed-workload post-warmup RSS spread was {} KiB (min={} KiB, max={} KiB)",
            spread_kb,
            min_rss_kb,
            max_rss_kb
        );
    }

    finalize_mpi();
}
2014}