1#[cfg(feature = "rayon")]
2use std::cell::RefCell;
3#[cfg(feature = "rayon")]
4use std::fmt::Debug;
5use std::{collections::HashMap, sync::Arc};
6
7use accurate::{sum::Klein, traits::*};
8use fastrand::Rng;
9#[cfg(feature = "mpi")]
10use laddu_core::mpi::LadduMPI;
11use laddu_core::{
12 amplitude::{CompiledExpression, Evaluator, Expression, ParameterMap},
13 data::Dataset,
14 validate_free_parameter_len, LadduError, LadduResult,
15};
16#[cfg(feature = "mpi")]
17use mpi::{
18 collective::SystemOperation, datatype::PartitionMut, topology::SimpleCommunicator, traits::*,
19};
20use nalgebra::DVector;
21use num::complex::Complex64;
22use parking_lot::Mutex;
23#[cfg(feature = "rayon")]
24use rayon::prelude::*;
25
26use super::term::LikelihoodTerm;
27use crate::random::RngSubsetExtension;
28
/// Key of the projection-mask cache: the `strict` flag paired with a list of names.
///
/// `NLL::projection_cache_key` builds this key with the names sorted and de-duplicated.
pub(crate) type ProjectionMaskCacheKey = (bool, Vec<String>);
30
31pub(crate) fn validate_stochastic_batch_size(
32 batch_size: usize,
33 n_events: usize,
34) -> LadduResult<()> {
35 if n_events == 0 {
36 return Err(LadduError::Custom(
37 "stochastic batch_size requires a non-empty dataset".to_string(),
38 ));
39 }
40 if batch_size == 0 || batch_size > n_events {
41 return Err(LadduError::LengthMismatch {
42 context: format!("stochastic batch_size (valid range: 1..={n_events})"),
43 expected: n_events,
44 actual: batch_size,
45 });
46 }
47 Ok(())
48}
49
/// Sums `value` over all ranks of `world` and returns the total on every rank.
#[cfg(feature = "mpi")]
pub(crate) fn reduce_scalar(world: &SimpleCommunicator, value: f64) -> f64 {
    let mut total = 0.0;
    let operation = SystemOperation::sum();
    world.all_reduce_into(&value, &mut total, operation);
    total
}
56
/// Sums `gradient` element-wise over all ranks of `world` and returns the summed vector.
#[cfg(feature = "mpi")]
pub(crate) fn reduce_gradient(world: &SimpleCommunicator, gradient: &DVector<f64>) -> DVector<f64> {
    let local = gradient.as_slice();
    // Receive buffer has the same length as the local gradient.
    let mut summed = vec![0.0; gradient.len()];
    world.all_reduce_into(local, &mut summed, SystemOperation::sum());
    DVector::from_vec(summed)
}
63
64pub(crate) fn evaluate_weighted_expression_sum_local<F>(
65 evaluator: &Evaluator,
66 parameters: &[f64],
67 value_map: F,
68) -> LadduResult<f64>
69where
70 F: Fn(Complex64) -> f64 + Copy + Send + Sync,
71{
72 let resources = evaluator.resources.read();
73 let parameters = resources.parameter_map.assemble(parameters)?;
74 let amplitude_len = evaluator.amplitude_value_slot_count();
75 let active_indices = resources.active_indices().to_vec();
76 let program_snapshot = evaluator.expression_value_program_snapshot();
77 let slot_count = evaluator.expression_value_program_snapshot_slot_count(&program_snapshot);
78 #[cfg(feature = "rayon")]
79 {
80 Ok(resources
81 .caches
82 .par_iter()
83 .zip(evaluator.dataset.weights_local().par_iter())
84 .map_init(
85 || {
86 (
87 vec![Complex64::ZERO; amplitude_len],
88 vec![Complex64::ZERO; slot_count],
89 )
90 },
91 |(amplitude_values, expr_slots), (cache, event)| {
92 evaluator.fill_amplitude_values(
93 amplitude_values,
94 &active_indices,
95 ¶meters,
96 cache,
97 );
98 let l = evaluator.evaluate_expression_value_with_program_snapshot(
99 &program_snapshot,
100 amplitude_values,
101 expr_slots,
102 );
103 *event * value_map(l)
104 },
105 )
106 .parallel_sum_with_accumulator::<Klein<f64>>())
107 }
108 #[cfg(not(feature = "rayon"))]
109 {
110 let mut amplitude_values = vec![Complex64::ZERO; amplitude_len];
111 let mut expr_slots = vec![Complex64::ZERO; slot_count];
112 Ok(resources
113 .caches
114 .iter()
115 .zip(evaluator.dataset.weights_local().iter())
116 .map(|(cache, event)| {
117 evaluator.fill_amplitude_values(
118 &mut amplitude_values,
119 &active_indices,
120 ¶meters,
121 cache,
122 );
123 let l = evaluator.evaluate_expression_value_with_program_snapshot(
124 &program_snapshot,
125 &litude_values,
126 &mut expr_slots,
127 );
128 *event * value_map(l)
129 })
130 .sum_with_accumulator::<Klein<f64>>())
131 }
132}
133
134pub(crate) fn project_weights_local_from_evaluator(
135 evaluator: &Evaluator,
136 parameters: &[f64],
137 n_mc: f64,
138) -> LadduResult<Vec<f64>> {
139 let resources = evaluator.resources.read();
140 let parameters = resources.parameter_map.assemble(parameters)?;
141 let amplitude_len = evaluator.amplitude_value_slot_count();
142 let active_indices = resources.active_indices().to_vec();
143 let program_snapshot = evaluator.expression_value_program_snapshot();
144 let slot_count = evaluator.expression_value_program_snapshot_slot_count(&program_snapshot);
145 #[cfg(feature = "rayon")]
146 {
147 Ok(resources
148 .caches
149 .par_iter()
150 .zip(evaluator.dataset.weights_local().par_iter())
151 .map_init(
152 || {
153 (
154 vec![Complex64::ZERO; amplitude_len],
155 vec![Complex64::ZERO; slot_count],
156 )
157 },
158 |(amplitude_values, expr_slots), (cache, event)| {
159 evaluator.fill_amplitude_values(
160 amplitude_values,
161 &active_indices,
162 ¶meters,
163 cache,
164 );
165 let value = evaluator.evaluate_expression_value_with_program_snapshot(
166 &program_snapshot,
167 amplitude_values,
168 expr_slots,
169 );
170 *event * value.re / n_mc
171 },
172 )
173 .collect())
174 }
175 #[cfg(not(feature = "rayon"))]
176 {
177 let mut amplitude_values = vec![Complex64::ZERO; amplitude_len];
178 let mut expr_slots = vec![Complex64::ZERO; slot_count];
179 Ok(resources
180 .caches
181 .iter()
182 .zip(evaluator.dataset.weights_local().iter())
183 .map(|(cache, event)| {
184 evaluator.fill_amplitude_values(
185 &mut amplitude_values,
186 &active_indices,
187 ¶meters,
188 cache,
189 );
190 let value = evaluator.evaluate_expression_value_with_program_snapshot(
191 &program_snapshot,
192 &litude_values,
193 &mut expr_slots,
194 );
195 *event * value.re / n_mc
196 })
197 .collect())
198 }
199}
200
201pub(crate) fn project_weights_local_from_resolved_mask(
202 evaluator: &Evaluator,
203 parameters: &[f64],
204 n_mc: f64,
205 resolved_mask: &[bool],
206) -> LadduResult<Vec<f64>> {
207 let resources = evaluator.resources.read();
208 let parameters = resources.parameter_map.assemble(parameters)?;
209 let amplitude_len = evaluator.amplitude_value_slot_count();
210 let active_indices = resolved_mask
211 .iter()
212 .enumerate()
213 .filter_map(|(index, &active)| if active { Some(index) } else { None })
214 .collect::<Vec<_>>();
215 let program_snapshot =
216 evaluator.expression_value_program_snapshot_for_active_mask(resolved_mask)?;
217 let slot_count = evaluator.expression_value_program_snapshot_slot_count(&program_snapshot);
218 #[cfg(feature = "rayon")]
219 {
220 Ok(resources
221 .caches
222 .par_iter()
223 .zip(evaluator.dataset.weights_local().par_iter())
224 .map_init(
225 || {
226 (
227 vec![Complex64::ZERO; amplitude_len],
228 vec![Complex64::ZERO; slot_count],
229 )
230 },
231 |(amplitude_values, expr_slots), (cache, event)| {
232 evaluator.fill_amplitude_values(
233 amplitude_values,
234 &active_indices,
235 ¶meters,
236 cache,
237 );
238 let value = evaluator.evaluate_expression_value_with_program_snapshot(
239 &program_snapshot,
240 amplitude_values,
241 expr_slots,
242 );
243 *event * value.re / n_mc
244 },
245 )
246 .collect())
247 }
248 #[cfg(not(feature = "rayon"))]
249 {
250 let mut amplitude_values = vec![Complex64::ZERO; amplitude_len];
251 let mut expr_slots = vec![Complex64::ZERO; slot_count];
252 Ok(resources
253 .caches
254 .iter()
255 .zip(evaluator.dataset.weights_local().iter())
256 .map(|(cache, event)| {
257 evaluator.fill_amplitude_values(
258 &mut amplitude_values,
259 &active_indices,
260 ¶meters,
261 cache,
262 );
263 let value = evaluator.evaluate_expression_value_with_program_snapshot(
264 &program_snapshot,
265 &litude_values,
266 &mut expr_slots,
267 );
268 *event * value.re / n_mc
269 })
270 .collect())
271 }
272}
273
274pub(crate) fn project_weights_and_gradients_local_from_evaluator(
275 evaluator: &Evaluator,
276 parameters: &[f64],
277 n_mc: f64,
278) -> LadduResult<(Vec<f64>, Vec<DVector<f64>>)> {
279 let resources = evaluator.resources.read();
280 let parameters = resources.parameter_map.assemble(parameters)?;
281 let amplitude_len = evaluator.amplitude_value_slot_count();
282 let grad_dim = parameters.len();
283 let active_indices = resources.active_indices().to_vec();
284 let active_mask = resources.active.clone();
285 let slot_count = evaluator.expression_value_gradient_slot_count_public();
286
287 #[cfg(feature = "rayon")]
288 {
289 let weighted = resources
290 .caches
291 .par_iter()
292 .zip(evaluator.dataset.weights_local().par_iter())
293 .map_init(
294 || {
295 (
296 vec![Complex64::ZERO; amplitude_len],
297 vec![DVector::zeros(grad_dim); amplitude_len],
298 vec![Complex64::ZERO; slot_count],
299 vec![DVector::zeros(grad_dim); slot_count],
300 )
301 },
302 |(amplitude_values, gradient_values, value_slots, gradient_slots),
303 (cache, event)| {
304 evaluator.fill_amplitude_values_and_gradients(
305 amplitude_values,
306 gradient_values,
307 &active_indices,
308 &active_mask,
309 ¶meters,
310 cache,
311 );
312 let (value, gradient) = evaluator
313 .evaluate_expression_value_gradient_with_scratch(
314 amplitude_values,
315 gradient_values,
316 value_slots,
317 gradient_slots,
318 );
319 (
320 *event * value.re / n_mc,
321 gradient.map(|g| g.re).scale(*event / n_mc),
322 )
323 },
324 )
325 .collect::<Vec<_>>();
326 Ok(weighted.into_iter().unzip())
327 }
328 #[cfg(not(feature = "rayon"))]
329 {
330 let mut amplitude_values = vec![Complex64::ZERO; amplitude_len];
331 let mut gradient_values = vec![DVector::zeros(grad_dim); amplitude_len];
332 let mut value_slots = vec![Complex64::ZERO; slot_count];
333 let mut gradient_slots = vec![DVector::zeros(grad_dim); slot_count];
334 Ok(resources
335 .caches
336 .iter()
337 .zip(evaluator.dataset.weights_local().iter())
338 .map(|(cache, event)| {
339 evaluator.fill_amplitude_values_and_gradients(
340 &mut amplitude_values,
341 &mut gradient_values,
342 &active_indices,
343 &active_mask,
344 ¶meters,
345 cache,
346 );
347 let (value, gradient) = evaluator.evaluate_expression_value_gradient_with_scratch(
348 &litude_values,
349 &gradient_values,
350 &mut value_slots,
351 &mut gradient_slots,
352 );
353 (
354 *event * value.re / n_mc,
355 gradient.map(|g| g.re).scale(*event / n_mc),
356 )
357 })
358 .unzip())
359 }
360}
361
/// Adds up all vectors produced by a parallel iterator.
///
/// `len` is the length of the zero vector used as the reduction identity, so it must
/// match the length of the items being summed.
#[cfg(feature = "rayon")]
pub(crate) fn sum_dvectors_parallel(
    iter: impl rayon::iter::ParallelIterator<Item = DVector<f64>>,
    len: usize,
) -> DVector<f64> {
    let zero = || DVector::zeros(len);
    let add = |mut total: DVector<f64>, term: DVector<f64>| {
        total += term;
        total
    };
    iter.reduce(zero, add)
}
375
/// Dimensions that identify a reusable gradient scratch workspace.
///
/// Used as the key of the thread-local workspace pool.
#[cfg(feature = "rayon")]
#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq)]
pub(crate) struct GradientScratchKey {
    /// Length of every gradient vector held by the workspace.
    n_parameters: usize,
    /// Number of amplitude value / amplitude gradient entries.
    n_amplitudes: usize,
    /// Number of expression value / expression gradient slots.
    n_expression_slots: usize,
}
383
/// Scratch buffers for gradient evaluation, sized according to a `GradientScratchKey`.
#[cfg(feature = "rayon")]
pub(crate) struct GradientScratchWorkspace {
    /// One complex value per amplitude (`n_amplitudes` entries).
    amplitude_values: Vec<Complex64>,
    /// One gradient vector of length `n_parameters` per amplitude.
    gradient_values: Vec<DVector<Complex64>>,
    /// One complex value per expression slot (`n_expression_slots` entries).
    value_slots: Vec<Complex64>,
    /// One gradient vector of length `n_parameters` per expression slot.
    gradient_slots: Vec<DVector<Complex64>>,
}
391
#[cfg(feature = "rayon")]
impl GradientScratchWorkspace {
    /// Allocates a zero-initialised workspace with the dimensions described by `key`.
    fn new(key: GradientScratchKey) -> Self {
        let GradientScratchKey {
            n_parameters,
            n_amplitudes,
            n_expression_slots,
        } = key;
        let amplitude_values = vec![Complex64::ZERO; n_amplitudes];
        let gradient_values = vec![DVector::zeros(n_parameters); n_amplitudes];
        let value_slots = vec![Complex64::ZERO; n_expression_slots];
        let gradient_slots = vec![DVector::zeros(n_parameters); n_expression_slots];
        Self {
            amplitude_values,
            gradient_values,
            value_slots,
            gradient_slots,
        }
    }

    /// Returns `true` when every buffer has exactly the dimensions described by `key`.
    fn matches_key(&self, key: GradientScratchKey) -> bool {
        let outer_lengths_match = self.amplitude_values.len() == key.n_amplitudes
            && self.gradient_values.len() == key.n_amplitudes
            && self.value_slots.len() == key.n_expression_slots
            && self.gradient_slots.len() == key.n_expression_slots;
        if !outer_lengths_match {
            return false;
        }
        // Every individual gradient vector must also have the expected length.
        self.gradient_values
            .iter()
            .chain(self.gradient_slots.iter())
            .all(|vector| vector.len() == key.n_parameters)
    }
}
418
/// Temporary ownership of a `GradientScratchWorkspace` taken from the thread-local pool.
///
/// Dropping the lease puts the workspace back into the pool under `key`.
#[cfg(feature = "rayon")]
pub(crate) struct GradientScratchLease {
    /// Pool key the workspace is stored under when the lease is dropped.
    key: GradientScratchKey,
    /// The leased workspace; taken out (left as `None`) by `Drop`.
    workspace: Option<GradientScratchWorkspace>,
}
424
#[cfg(feature = "rayon")]
impl GradientScratchLease {
    /// Gives mutable access to the leased workspace.
    ///
    /// # Panics
    ///
    /// Panics if the workspace has already been taken out of the lease.
    fn workspace_mut(&mut self) -> &mut GradientScratchWorkspace {
        let leased = self.workspace.as_mut();
        leased.expect("gradient scratch workspace must be available while leased")
    }
}
433
#[cfg(feature = "rayon")]
impl Drop for GradientScratchLease {
    /// Hands the workspace back to the current thread's pool so it can be reused.
    fn drop(&mut self) {
        let key = self.key;
        match self.workspace.take() {
            Some(workspace) => TLS_GRADIENT_SCRATCH_POOL.with(|pool| {
                pool.borrow_mut().insert(key, workspace);
            }),
            // Nothing to return if the workspace was already taken.
            None => {}
        }
    }
}
444
/// Leases a gradient scratch workspace matching `key` from the current thread's pool.
///
/// A pooled workspace is reused only if its buffers still match `key`; otherwise a
/// freshly allocated workspace is returned.
#[cfg(feature = "rayon")]
pub(crate) fn acquire_gradient_scratch(key: GradientScratchKey) -> GradientScratchLease {
    let pooled = TLS_GRADIENT_SCRATCH_POOL.with(|pool| pool.borrow_mut().remove(&key));
    let workspace = match pooled {
        Some(existing) if existing.matches_key(key) => existing,
        // Missing or mismatching entry: build a new workspace for this key.
        _ => GradientScratchWorkspace::new(key),
    };
    GradientScratchLease {
        key,
        workspace: Some(workspace),
    }
}
460
// Per-thread pool of gradient scratch workspaces, keyed by their dimensions.
// `acquire_gradient_scratch` removes entries from it and dropping a
// `GradientScratchLease` inserts them back.
#[cfg(feature = "rayon")]
thread_local! {
    static TLS_GRADIENT_SCRATCH_POOL: RefCell<HashMap<GradientScratchKey, GradientScratchWorkspace>> =
        RefCell::new(HashMap::new());
}
466
/// Likelihood object holding one evaluator for the data sample and one for the
/// accepted Monte Carlo sample (presumably a negative log-likelihood, going by the name).
#[derive(Clone)]
pub struct NLL {
    /// Evaluator loaded from the expression on the data dataset.
    pub data_evaluator: Evaluator,
    /// Evaluator loaded from the expression on the accepted Monte Carlo dataset.
    pub accmc_evaluator: Evaluator,
    /// Normalisation divisor applied to projected weights; defaults to the weighted
    /// event count of the accepted Monte Carlo dataset when not given to `NLL::new`.
    pub(crate) n_mc: f64,
    /// Cache of resolved projection masks keyed by `(strict, sorted unique names)`.
    /// Stored behind an `Arc`, so clones of this struct share the same cache.
    pub(crate) projection_active_mask_cache: Arc<Mutex<HashMap<ProjectionMaskCacheKey, Vec<bool>>>>,
}
477
478impl NLL {
479 pub fn new(
484 expression: &Expression,
485 ds_data: &Arc<Dataset>,
486 ds_accmc: &Arc<Dataset>,
487 n_mc: Option<f64>,
488 ) -> LadduResult<Box<Self>> {
489 let data_evaluator = expression.load(ds_data)?;
490 let accmc_evaluator = expression.load(ds_accmc)?;
491 Ok(Self {
492 data_evaluator,
493 n_mc: n_mc.unwrap_or(accmc_evaluator.dataset.n_events_weighted()),
494 accmc_evaluator,
495 projection_active_mask_cache: Arc::new(Mutex::new(HashMap::new())),
496 }
497 .into())
498 }
499
500 fn normalized_projection_key<T: AsRef<str>>(names: &[T]) -> Vec<String> {
501 let mut key = names
502 .iter()
503 .map(|name| name.as_ref().to_string())
504 .collect::<Vec<_>>();
505 key.sort_unstable();
506 key.dedup();
507 key
508 }
509
510 fn projection_cache_key<T: AsRef<str>>(names: &[T], strict: bool) -> ProjectionMaskCacheKey {
511 (strict, Self::normalized_projection_key(names))
512 }
513
514 fn resolve_projection_active_mask_for_evaluator<T: AsRef<str>>(
515 evaluator: &Evaluator,
516 names: &[T],
517 strict: bool,
518 ) -> LadduResult<Vec<bool>> {
519 let current_active_mask = evaluator.active_mask();
520 let isolate_result = if strict {
521 evaluator.isolate_many_strict(names)
522 } else {
523 evaluator.isolate_many(names);
524 Ok(())
525 };
526 if let Err(err) = isolate_result {
527 evaluator.set_active_mask(¤t_active_mask)?;
528 return Err(err);
529 }
530 let resolved_mask = evaluator.active_mask();
531 evaluator.set_active_mask(¤t_active_mask)?;
532 Ok(resolved_mask)
533 }
534
535 fn get_or_build_projection_active_mask<T: AsRef<str>>(
536 &self,
537 names: &[T],
538 strict: bool,
539 ) -> LadduResult<Vec<bool>> {
540 let key = Self::projection_cache_key(names, strict);
541 if let Some(mask) = self.projection_active_mask_cache.lock().get(&key).cloned() {
542 return Ok(mask);
543 }
544
545 let resolved_mask = Self::resolve_projection_active_mask_for_evaluator(
546 &self.accmc_evaluator,
547 names,
548 strict,
549 )?;
550 self.projection_active_mask_cache
551 .lock()
552 .insert(key, resolved_mask.clone());
553 Ok(resolved_mask)
554 }
555
556 fn invalidate_projection_mask_cache(&self) {
557 self.projection_active_mask_cache.lock().clear();
558 }
559
560 pub fn parameters(&self) -> ParameterMap {
562 self.data_evaluator.parameters()
563 }
564
565 pub fn n_free(&self) -> usize {
567 self.data_evaluator.n_free()
568 }
569
570 pub fn n_fixed(&self) -> usize {
572 self.data_evaluator.n_fixed()
573 }
574
575 pub fn n_parameters(&self) -> usize {
577 self.data_evaluator.n_parameters()
578 }
579
580 pub fn expression(&self) -> Expression {
582 self.data_evaluator.expression()
583 }
584
585 pub fn compiled_expression(&self) -> CompiledExpression {
588 self.data_evaluator.compiled_expression()
589 }
590
591 pub fn to_stochastic(
593 &self,
594 batch_size: usize,
595 seed: Option<usize>,
596 ) -> LadduResult<StochasticNLL> {
597 StochasticNLL::new(self.clone(), batch_size, seed)
598 }
599 pub fn activate<T: AsRef<str>>(&self, name: T) {
601 self.invalidate_projection_mask_cache();
602 self.data_evaluator.activate(&name);
603 self.accmc_evaluator.activate(name);
604 }
605 pub fn activate_strict<T: AsRef<str>>(&self, name: T) -> LadduResult<()> {
607 self.invalidate_projection_mask_cache();
608 self.data_evaluator.activate_strict(&name)?;
609 self.accmc_evaluator.activate_strict(name)?;
610 Ok(())
611 }
612 pub fn activate_many<T: AsRef<str>>(&self, names: &[T]) {
614 self.invalidate_projection_mask_cache();
615 self.data_evaluator.activate_many(names);
616 self.accmc_evaluator.activate_many(names);
617 }
618 pub fn activate_many_strict<T: AsRef<str>>(&self, names: &[T]) -> LadduResult<()> {
620 self.invalidate_projection_mask_cache();
621 self.data_evaluator.activate_many_strict(names)?;
622 self.accmc_evaluator.activate_many_strict(names)?;
623 Ok(())
624 }
625 pub fn activate_all(&self) {
627 self.invalidate_projection_mask_cache();
628 self.data_evaluator.activate_all();
629 self.accmc_evaluator.activate_all();
630 }
631 pub fn deactivate<T: AsRef<str>>(&self, name: T) {
633 self.invalidate_projection_mask_cache();
634 self.data_evaluator.deactivate(&name);
635 self.accmc_evaluator.deactivate(name);
636 }
637 pub fn deactivate_strict<T: AsRef<str>>(&self, name: T) -> LadduResult<()> {
639 self.invalidate_projection_mask_cache();
640 self.data_evaluator.deactivate_strict(&name)?;
641 self.accmc_evaluator.deactivate_strict(name)?;
642 Ok(())
643 }
644 pub fn deactivate_many<T: AsRef<str>>(&self, names: &[T]) {
646 self.invalidate_projection_mask_cache();
647 self.data_evaluator.deactivate_many(names);
648 self.accmc_evaluator.deactivate_many(names);
649 }
650 pub fn deactivate_many_strict<T: AsRef<str>>(&self, names: &[T]) -> LadduResult<()> {
652 self.invalidate_projection_mask_cache();
653 self.data_evaluator.deactivate_many_strict(names)?;
654 self.accmc_evaluator.deactivate_many_strict(names)?;
655 Ok(())
656 }
657 pub fn deactivate_all(&self) {
659 self.invalidate_projection_mask_cache();
660 self.data_evaluator.deactivate_all();
661 self.accmc_evaluator.deactivate_all();
662 }
663 pub fn isolate<T: AsRef<str>>(&self, name: T) {
665 self.invalidate_projection_mask_cache();
666 self.data_evaluator.isolate(&name);
667 self.accmc_evaluator.isolate(name);
668 }
669 pub fn isolate_strict<T: AsRef<str>>(&self, name: T) -> LadduResult<()> {
671 self.invalidate_projection_mask_cache();
672 self.data_evaluator.isolate_strict(&name)?;
673 self.accmc_evaluator.isolate_strict(name)?;
674 Ok(())
675 }
676 pub fn isolate_many<T: AsRef<str>>(&self, names: &[T]) {
678 self.invalidate_projection_mask_cache();
679 self.data_evaluator.isolate_many(names);
680 self.accmc_evaluator.isolate_many(names);
681 }
682 pub fn isolate_many_strict<T: AsRef<str>>(&self, names: &[T]) -> LadduResult<()> {
684 self.invalidate_projection_mask_cache();
685 self.data_evaluator.isolate_many_strict(names)?;
686 self.accmc_evaluator.isolate_many_strict(names)?;
687 Ok(())
688 }
689
690 pub fn project_weights_local(
699 &self,
700 parameters: &[f64],
701 mc_evaluator: Option<Evaluator>,
702 ) -> LadduResult<Vec<f64>> {
703 validate_free_parameter_len(parameters.len(), self.n_free())?;
704 if let Some(mc_evaluator) = mc_evaluator {
705 project_weights_local_from_evaluator(&mc_evaluator, parameters, self.n_mc)
706 } else {
707 project_weights_local_from_evaluator(&self.accmc_evaluator, parameters, self.n_mc)
708 }
709 }
710
711 #[cfg(feature = "mpi")]
720 pub fn project_weights_mpi(
721 &self,
722 parameters: &[f64],
723 mc_evaluator: Option<Evaluator>,
724 world: &SimpleCommunicator,
725 ) -> LadduResult<Vec<f64>> {
726 let n_events = mc_evaluator
727 .as_ref()
728 .unwrap_or(&self.accmc_evaluator)
729 .dataset
730 .n_events();
731 let local_projection = self.project_weights_local(parameters, mc_evaluator)?;
732 let mut buffer: Vec<f64> = vec![0.0; n_events];
733 let (counts, displs) = world.get_counts_displs(n_events);
734 {
735 let mut partitioned_buffer = PartitionMut::new(&mut buffer, counts, displs);
738 world.all_gather_varcount_into(&local_projection, &mut partitioned_buffer);
739 }
740 Ok(buffer)
741 }
742
743 pub fn project_weights(
757 &self,
758 parameters: &[f64],
759 mc_evaluator: Option<Evaluator>,
760 ) -> LadduResult<Vec<f64>> {
761 #[cfg(feature = "mpi")]
762 {
763 if let Some(world) = laddu_core::mpi::get_world() {
764 return self.project_weights_mpi(parameters, mc_evaluator, &world);
765 }
766 }
767 self.project_weights_local(parameters, mc_evaluator)
768 }
769
770 pub fn project_weights_and_gradients_local(
779 &self,
780 parameters: &[f64],
781 mc_evaluator: Option<Evaluator>,
782 ) -> LadduResult<(Vec<f64>, Vec<DVector<f64>>)> {
783 validate_free_parameter_len(parameters.len(), self.n_free())?;
784 if let Some(mc_evaluator) = mc_evaluator {
785 project_weights_and_gradients_local_from_evaluator(&mc_evaluator, parameters, self.n_mc)
786 } else {
787 project_weights_and_gradients_local_from_evaluator(
788 &self.accmc_evaluator,
789 parameters,
790 self.n_mc,
791 )
792 }
793 }
794
795 #[cfg(feature = "mpi")]
804 pub fn project_weights_and_gradients_mpi(
805 &self,
806 parameters: &[f64],
807 mc_evaluator: Option<Evaluator>,
808 world: &SimpleCommunicator,
809 ) -> LadduResult<(Vec<f64>, Vec<DVector<f64>>)> {
810 let n_events = mc_evaluator
811 .as_ref()
812 .unwrap_or(&self.accmc_evaluator)
813 .dataset
814 .n_events();
815 let (local_projection, local_gradient_projection) =
816 self.project_weights_and_gradients_local(parameters, mc_evaluator)?;
817 let mut projection_result: Vec<f64> = vec![0.0; n_events];
818 let (counts, displs) = world.get_counts_displs(n_events);
819 {
820 let mut partitioned_buffer = PartitionMut::new(&mut projection_result, counts, displs);
822 world.all_gather_varcount_into(&local_projection, &mut partitioned_buffer);
823 }
824
825 let flattened_local_gradient_projection = local_gradient_projection
826 .iter()
827 .flat_map(|g| g.data.as_vec().to_vec())
828 .collect::<Vec<f64>>();
829 let (counts, displs) = world.get_flattened_counts_displs(n_events, parameters.len());
830 let mut flattened_result_buffer = vec![0.0; n_events * parameters.len()];
831 let mut partitioned_flattened_result_buffer =
832 PartitionMut::new(&mut flattened_result_buffer, counts, displs);
833 world.all_gather_varcount_into(
835 &flattened_local_gradient_projection,
836 &mut partitioned_flattened_result_buffer,
837 );
838 let gradient_projection_result = flattened_result_buffer
839 .chunks(parameters.len())
840 .map(DVector::from_row_slice)
841 .collect();
842 Ok((projection_result, gradient_projection_result))
843 }
844 pub fn project_weights_and_gradients(
858 &self,
859 parameters: &[f64],
860 mc_evaluator: Option<Evaluator>,
861 ) -> LadduResult<(Vec<f64>, Vec<DVector<f64>>)> {
862 #[cfg(feature = "mpi")]
863 {
864 if let Some(world) = laddu_core::mpi::get_world() {
865 return self.project_weights_and_gradients_mpi(parameters, mc_evaluator, &world);
866 }
867 }
868 self.project_weights_and_gradients_local(parameters, mc_evaluator)
869 }
870
871 fn project_weights_subset_local_with_strict<T: AsRef<str>>(
881 &self,
882 parameters: &[f64],
883 names: &[T],
884 mc_evaluator: Option<Evaluator>,
885 strict: bool,
886 ) -> LadduResult<Vec<f64>> {
887 validate_free_parameter_len(parameters.len(), self.n_free())?;
888 if let Some(mc_evaluator) = mc_evaluator.as_ref() {
889 let resolved_mask =
890 Self::resolve_projection_active_mask_for_evaluator(mc_evaluator, names, strict)?;
891 project_weights_local_from_resolved_mask(
892 mc_evaluator,
893 parameters,
894 self.n_mc,
895 &resolved_mask,
896 )
897 } else {
898 let resolved_mask = self.get_or_build_projection_active_mask(names, strict)?;
899 project_weights_local_from_resolved_mask(
900 &self.accmc_evaluator,
901 parameters,
902 self.n_mc,
903 &resolved_mask,
904 )
905 }
906 }
907
908 pub fn project_weights_subset_local<T: AsRef<str>>(
911 &self,
912 parameters: &[f64],
913 names: &[T],
914 mc_evaluator: Option<Evaluator>,
915 ) -> LadduResult<Vec<f64>> {
916 self.project_weights_subset_local_with_strict(parameters, names, mc_evaluator, false)
917 }
918
919 pub fn project_weights_subset_local_strict<T: AsRef<str>>(
922 &self,
923 parameters: &[f64],
924 names: &[T],
925 mc_evaluator: Option<Evaluator>,
926 ) -> LadduResult<Vec<f64>> {
927 self.project_weights_subset_local_with_strict(parameters, names, mc_evaluator, true)
928 }
929
930 #[cfg(feature = "mpi")]
940 fn project_weights_subset_mpi_with_strict<T: AsRef<str>>(
941 &self,
942 parameters: &[f64],
943 names: &[T],
944 mc_evaluator: Option<Evaluator>,
945 world: &SimpleCommunicator,
946 strict: bool,
947 ) -> LadduResult<Vec<f64>> {
948 let n_events = mc_evaluator
949 .as_ref()
950 .unwrap_or(&self.accmc_evaluator)
951 .dataset
952 .n_events();
953 let local_projection =
954 self.project_weights_subset_local_with_strict(parameters, names, mc_evaluator, strict)?;
955 let mut buffer: Vec<f64> = vec![0.0; n_events];
956 let (counts, displs) = world.get_counts_displs(n_events);
957 {
958 let mut partitioned_buffer = PartitionMut::new(&mut buffer, counts, displs);
960 world.all_gather_varcount_into(&local_projection, &mut partitioned_buffer);
961 }
962 Ok(buffer)
963 }
964
965 #[cfg(feature = "mpi")]
966 pub fn project_weights_subset_mpi<T: AsRef<str>>(
969 &self,
970 parameters: &[f64],
971 names: &[T],
972 mc_evaluator: Option<Evaluator>,
973 world: &SimpleCommunicator,
974 ) -> LadduResult<Vec<f64>> {
975 self.project_weights_subset_mpi_with_strict(parameters, names, mc_evaluator, world, false)
976 }
977
978 #[cfg(feature = "mpi")]
979 pub fn project_weights_subset_mpi_strict<T: AsRef<str>>(
982 &self,
983 parameters: &[f64],
984 names: &[T],
985 mc_evaluator: Option<Evaluator>,
986 world: &SimpleCommunicator,
987 ) -> LadduResult<Vec<f64>> {
988 self.project_weights_subset_mpi_with_strict(parameters, names, mc_evaluator, world, true)
989 }
990
991 fn project_weights_subset_with_strict<T: AsRef<str>>(
1008 &self,
1009 parameters: &[f64],
1010 names: &[T],
1011 mc_evaluator: Option<Evaluator>,
1012 strict: bool,
1013 ) -> LadduResult<Vec<f64>> {
1014 #[cfg(feature = "mpi")]
1015 {
1016 if let Some(world) = laddu_core::mpi::get_world() {
1017 return self.project_weights_subset_mpi_with_strict(
1018 parameters,
1019 names,
1020 mc_evaluator,
1021 &world,
1022 strict,
1023 );
1024 }
1025 }
1026 self.project_weights_subset_local_with_strict(parameters, names, mc_evaluator, strict)
1027 }
1028
1029 pub fn project_weights_subset<T: AsRef<str>>(
1032 &self,
1033 parameters: &[f64],
1034 names: &[T],
1035 mc_evaluator: Option<Evaluator>,
1036 ) -> LadduResult<Vec<f64>> {
1037 self.project_weights_subset_with_strict(parameters, names, mc_evaluator, false)
1038 }
1039
1040 pub fn project_weights_subset_strict<T: AsRef<str>>(
1043 &self,
1044 parameters: &[f64],
1045 names: &[T],
1046 mc_evaluator: Option<Evaluator>,
1047 ) -> LadduResult<Vec<f64>> {
1048 self.project_weights_subset_with_strict(parameters, names, mc_evaluator, true)
1049 }
1050
1051 fn project_weights_subsets_local_with_strict<T: AsRef<str>>(
1053 &self,
1054 parameters: &[f64],
1055 subsets: &[Vec<T>],
1056 mc_evaluator: Option<Evaluator>,
1057 strict: bool,
1058 ) -> LadduResult<Vec<Vec<f64>>> {
1059 validate_free_parameter_len(parameters.len(), self.n_free())?;
1060 if subsets.is_empty() {
1061 return Ok(Vec::new());
1062 }
1063 if let Some(mc_evaluator) = mc_evaluator.as_ref() {
1064 let resolved_masks = subsets
1065 .iter()
1066 .map(|names| {
1067 Self::resolve_projection_active_mask_for_evaluator(mc_evaluator, names, strict)
1068 })
1069 .collect::<LadduResult<Vec<_>>>()?;
1070 resolved_masks
1071 .iter()
1072 .map(|mask| {
1073 project_weights_local_from_resolved_mask(
1074 mc_evaluator,
1075 parameters,
1076 self.n_mc,
1077 mask,
1078 )
1079 })
1080 .collect()
1081 } else {
1082 let resolved_masks = subsets
1083 .iter()
1084 .map(|names| self.get_or_build_projection_active_mask(names, strict))
1085 .collect::<LadduResult<Vec<_>>>()?;
1086 resolved_masks
1087 .iter()
1088 .map(|mask| {
1089 project_weights_local_from_resolved_mask(
1090 &self.accmc_evaluator,
1091 parameters,
1092 self.n_mc,
1093 mask,
1094 )
1095 })
1096 .collect()
1097 }
1098 }
1099
1100 pub fn project_weights_subsets_local<T: AsRef<str>>(
1103 &self,
1104 parameters: &[f64],
1105 subsets: &[Vec<T>],
1106 mc_evaluator: Option<Evaluator>,
1107 ) -> LadduResult<Vec<Vec<f64>>> {
1108 self.project_weights_subsets_local_with_strict(parameters, subsets, mc_evaluator, false)
1109 }
1110
1111 pub fn project_weights_subsets_local_strict<T: AsRef<str>>(
1114 &self,
1115 parameters: &[f64],
1116 subsets: &[Vec<T>],
1117 mc_evaluator: Option<Evaluator>,
1118 ) -> LadduResult<Vec<Vec<f64>>> {
1119 self.project_weights_subsets_local_with_strict(parameters, subsets, mc_evaluator, true)
1120 }
1121
1122 #[cfg(feature = "mpi")]
1124 fn project_weights_subsets_mpi_with_strict<T: AsRef<str>>(
1125 &self,
1126 parameters: &[f64],
1127 subsets: &[Vec<T>],
1128 mc_evaluator: Option<Evaluator>,
1129 world: &SimpleCommunicator,
1130 strict: bool,
1131 ) -> LadduResult<Vec<Vec<f64>>> {
1132 let n_events = mc_evaluator
1133 .as_ref()
1134 .unwrap_or(&self.accmc_evaluator)
1135 .dataset
1136 .n_events();
1137 let local_projections = self.project_weights_subsets_local_with_strict(
1138 parameters,
1139 subsets,
1140 mc_evaluator,
1141 strict,
1142 )?;
1143 let (counts, displs) = world.get_counts_displs(n_events);
1144 let mut gathered = Vec::with_capacity(local_projections.len());
1145 for local_projection in local_projections {
1146 let mut buffer = vec![0.0; n_events];
1147 {
1148 let mut partitioned_buffer =
1149 PartitionMut::new(&mut buffer, counts.clone(), displs.clone());
1150 world.all_gather_varcount_into(&local_projection, &mut partitioned_buffer);
1151 }
1152 gathered.push(buffer);
1153 }
1154 Ok(gathered)
1155 }
1156
1157 #[cfg(feature = "mpi")]
1158 pub fn project_weights_subsets_mpi<T: AsRef<str>>(
1161 &self,
1162 parameters: &[f64],
1163 subsets: &[Vec<T>],
1164 mc_evaluator: Option<Evaluator>,
1165 world: &SimpleCommunicator,
1166 ) -> LadduResult<Vec<Vec<f64>>> {
1167 self.project_weights_subsets_mpi_with_strict(
1168 parameters,
1169 subsets,
1170 mc_evaluator,
1171 world,
1172 false,
1173 )
1174 }
1175
1176 #[cfg(feature = "mpi")]
1177 pub fn project_weights_subsets_mpi_strict<T: AsRef<str>>(
1180 &self,
1181 parameters: &[f64],
1182 subsets: &[Vec<T>],
1183 mc_evaluator: Option<Evaluator>,
1184 world: &SimpleCommunicator,
1185 ) -> LadduResult<Vec<Vec<f64>>> {
1186 self.project_weights_subsets_mpi_with_strict(parameters, subsets, mc_evaluator, world, true)
1187 }
1188
1189 fn project_weights_subsets_with_strict<T: AsRef<str>>(
1191 &self,
1192 parameters: &[f64],
1193 subsets: &[Vec<T>],
1194 mc_evaluator: Option<Evaluator>,
1195 strict: bool,
1196 ) -> LadduResult<Vec<Vec<f64>>> {
1197 #[cfg(feature = "mpi")]
1198 {
1199 if let Some(world) = laddu_core::mpi::get_world() {
1200 return self.project_weights_subsets_mpi_with_strict(
1201 parameters,
1202 subsets,
1203 mc_evaluator,
1204 &world,
1205 strict,
1206 );
1207 }
1208 }
1209 self.project_weights_subsets_local_with_strict(parameters, subsets, mc_evaluator, strict)
1210 }
1211
1212 pub fn project_weights_subsets<T: AsRef<str>>(
1215 &self,
1216 parameters: &[f64],
1217 subsets: &[Vec<T>],
1218 mc_evaluator: Option<Evaluator>,
1219 ) -> LadduResult<Vec<Vec<f64>>> {
1220 self.project_weights_subsets_with_strict(parameters, subsets, mc_evaluator, false)
1221 }
1222
1223 pub fn project_weights_subsets_strict<T: AsRef<str>>(
1226 &self,
1227 parameters: &[f64],
1228 subsets: &[Vec<T>],
1229 mc_evaluator: Option<Evaluator>,
1230 ) -> LadduResult<Vec<Vec<f64>>> {
1231 self.project_weights_subsets_with_strict(parameters, subsets, mc_evaluator, true)
1232 }
1233
1234 fn project_weights_and_gradients_subset_local_with_strict<T: AsRef<str>>(
1245 &self,
1246 parameters: &[f64],
1247 names: &[T],
1248 mc_evaluator: Option<Evaluator>,
1249 strict: bool,
1250 ) -> LadduResult<(Vec<f64>, Vec<DVector<f64>>)> {
1251 validate_free_parameter_len(parameters.len(), self.n_free())?;
1252 let evaluator = mc_evaluator.as_ref().unwrap_or(&self.accmc_evaluator);
1253 let resolved_mask = if let Some(mc_evaluator) = mc_evaluator.as_ref() {
1254 Self::resolve_projection_active_mask_for_evaluator(mc_evaluator, names, strict)?
1255 } else {
1256 self.get_or_build_projection_active_mask(names, strict)?
1257 };
1258 let mc_dataset = &evaluator.dataset;
1259 let result =
1260 evaluator.evaluate_with_gradient_local_with_active_mask(parameters, &resolved_mask)?;
1261 #[cfg(feature = "rayon")]
1262 let (res, res_gradient) = {
1263 (
1264 result
1265 .par_iter()
1266 .zip(mc_dataset.weights_local().par_iter())
1267 .map(|((l, _), e)| e * l.re / self.n_mc)
1268 .collect(),
1269 result
1270 .par_iter()
1271 .zip(mc_dataset.weights_local().par_iter())
1272 .map(|((_, grad_l), e)| grad_l.map(|g| g.re).scale(e / self.n_mc))
1273 .collect(),
1274 )
1275 };
1276 #[cfg(not(feature = "rayon"))]
1277 let (res, res_gradient) = {
1278 (
1279 result
1280 .iter()
1281 .zip(mc_dataset.weights_local().iter())
1282 .map(|((l, _), e)| e * l.re / self.n_mc)
1283 .collect(),
1284 result
1285 .iter()
1286 .zip(mc_dataset.weights_local().iter())
1287 .map(|((_, grad_l), e)| grad_l.map(|g| g.re).scale(e / self.n_mc))
1288 .collect(),
1289 )
1290 };
1291 Ok((res, res_gradient))
1292 }
1293
1294 pub fn project_weights_and_gradients_subset_local<T: AsRef<str>>(
1297 &self,
1298 parameters: &[f64],
1299 names: &[T],
1300 mc_evaluator: Option<Evaluator>,
1301 ) -> LadduResult<(Vec<f64>, Vec<DVector<f64>>)> {
1302 self.project_weights_and_gradients_subset_local_with_strict(
1303 parameters,
1304 names,
1305 mc_evaluator,
1306 false,
1307 )
1308 }
1309
1310 pub fn project_weights_and_gradients_subset_local_strict<T: AsRef<str>>(
1313 &self,
1314 parameters: &[f64],
1315 names: &[T],
1316 mc_evaluator: Option<Evaluator>,
1317 ) -> LadduResult<(Vec<f64>, Vec<DVector<f64>>)> {
1318 self.project_weights_and_gradients_subset_local_with_strict(
1319 parameters,
1320 names,
1321 mc_evaluator,
1322 true,
1323 )
1324 }
1325
1326 #[cfg(feature = "mpi")]
1337 fn project_weights_and_gradients_subset_mpi_with_strict<T: AsRef<str>>(
1338 &self,
1339 parameters: &[f64],
1340 names: &[T],
1341 mc_evaluator: Option<Evaluator>,
1342 world: &SimpleCommunicator,
1343 strict: bool,
1344 ) -> LadduResult<(Vec<f64>, Vec<DVector<f64>>)> {
1345 let n_events = mc_evaluator
1346 .as_ref()
1347 .unwrap_or(&self.accmc_evaluator)
1348 .dataset
1349 .n_events();
1350 let (local_projection, local_gradient_projection) = self
1351 .project_weights_and_gradients_subset_local_with_strict(
1352 parameters,
1353 names,
1354 mc_evaluator,
1355 strict,
1356 )?;
1357 let mut projection_result: Vec<f64> = vec![0.0; n_events];
1358 let (counts, displs) = world.get_counts_displs(n_events);
1359 {
1360 let mut partitioned_buffer = PartitionMut::new(&mut projection_result, counts, displs);
1362 world.all_gather_varcount_into(&local_projection, &mut partitioned_buffer);
1363 }
1364
1365 let flattened_local_gradient_projection = local_gradient_projection
1366 .iter()
1367 .flat_map(|g| g.data.as_vec().to_vec())
1368 .collect::<Vec<f64>>();
1369 let (counts, displs) = world.get_flattened_counts_displs(n_events, parameters.len());
1370 let mut flattened_result_buffer = vec![0.0; n_events * parameters.len()];
1371 let mut partitioned_flattened_result_buffer =
1372 PartitionMut::new(&mut flattened_result_buffer, counts, displs);
1373 world.all_gather_varcount_into(
1375 &flattened_local_gradient_projection,
1376 &mut partitioned_flattened_result_buffer,
1377 );
1378 let gradient_projection_result = flattened_result_buffer
1379 .chunks(parameters.len())
1380 .map(DVector::from_row_slice)
1381 .collect();
1382 Ok((projection_result, gradient_projection_result))
1383 }
1384
1385 #[cfg(feature = "mpi")]
1386 pub fn project_weights_and_gradients_subset_mpi<T: AsRef<str>>(
1389 &self,
1390 parameters: &[f64],
1391 names: &[T],
1392 mc_evaluator: Option<Evaluator>,
1393 world: &SimpleCommunicator,
1394 ) -> LadduResult<(Vec<f64>, Vec<DVector<f64>>)> {
1395 self.project_weights_and_gradients_subset_mpi_with_strict(
1396 parameters,
1397 names,
1398 mc_evaluator,
1399 world,
1400 false,
1401 )
1402 }
1403
1404 #[cfg(feature = "mpi")]
1405 pub fn project_weights_and_gradients_subset_mpi_strict<T: AsRef<str>>(
1408 &self,
1409 parameters: &[f64],
1410 names: &[T],
1411 mc_evaluator: Option<Evaluator>,
1412 world: &SimpleCommunicator,
1413 ) -> LadduResult<(Vec<f64>, Vec<DVector<f64>>)> {
1414 self.project_weights_and_gradients_subset_mpi_with_strict(
1415 parameters,
1416 names,
1417 mc_evaluator,
1418 world,
1419 true,
1420 )
1421 }
1422 fn project_weights_and_gradients_subset_with_strict<T: AsRef<str>>(
1441 &self,
1442 parameters: &[f64],
1443 names: &[T],
1444 mc_evaluator: Option<Evaluator>,
1445 strict: bool,
1446 ) -> LadduResult<(Vec<f64>, Vec<DVector<f64>>)> {
1447 #[cfg(feature = "mpi")]
1448 {
1449 if let Some(world) = laddu_core::mpi::get_world() {
1450 return self.project_weights_and_gradients_subset_mpi_with_strict(
1451 parameters,
1452 names,
1453 mc_evaluator,
1454 &world,
1455 strict,
1456 );
1457 }
1458 }
1459 self.project_weights_and_gradients_subset_local_with_strict(
1460 parameters,
1461 names,
1462 mc_evaluator,
1463 strict,
1464 )
1465 }
1466
1467 pub fn project_weights_and_gradients_subset<T: AsRef<str>>(
1470 &self,
1471 parameters: &[f64],
1472 names: &[T],
1473 mc_evaluator: Option<Evaluator>,
1474 ) -> LadduResult<(Vec<f64>, Vec<DVector<f64>>)> {
1475 self.project_weights_and_gradients_subset_with_strict(
1476 parameters,
1477 names,
1478 mc_evaluator,
1479 false,
1480 )
1481 }
1482
1483 pub fn project_weights_and_gradients_subset_strict<T: AsRef<str>>(
1486 &self,
1487 parameters: &[f64],
1488 names: &[T],
1489 mc_evaluator: Option<Evaluator>,
1490 ) -> LadduResult<(Vec<f64>, Vec<DVector<f64>>)> {
1491 self.project_weights_and_gradients_subset_with_strict(parameters, names, mc_evaluator, true)
1492 }
1493
1494 fn evaluate_data_term_local(&self, parameters: &[f64]) -> LadduResult<f64> {
1495 evaluate_weighted_expression_sum_local(&self.data_evaluator, parameters, |l| f64::ln(l.re))
1496 }
1497
1498 fn evaluate_mc_term_local(&self, parameters: &[f64]) -> LadduResult<f64> {
1499 self.accmc_evaluator
1500 .evaluate_weighted_value_sum_local(parameters)
1501 }
1502
1503 #[doc(hidden)]
1504 pub fn profile_data_term_local_value(&self, parameters: &[f64]) -> LadduResult<f64> {
1505 self.evaluate_data_term_local(parameters)
1506 }
1507
1508 #[doc(hidden)]
1509 pub fn profile_mc_term_local_value(&self, parameters: &[f64]) -> LadduResult<f64> {
1510 self.evaluate_mc_term_local(parameters)
1511 }
1512
1513 pub(crate) fn evaluate_local(&self, parameters: &[f64]) -> LadduResult<f64> {
1514 let data_term = self.evaluate_data_term_local(parameters)?;
1515 let mc_term = self.evaluate_mc_term_local(parameters)?;
1516 Ok(-2.0 * (data_term - mc_term / self.n_mc))
1517 }
1518
1519 #[cfg(feature = "mpi")]
1520 #[doc(hidden)]
1521 pub fn evaluate_mpi(&self, parameters: &[f64], world: &SimpleCommunicator) -> LadduResult<f64> {
1522 let data_term_local = self.evaluate_data_term_local(parameters)?;
1523 let data_term = reduce_scalar(world, data_term_local);
1524 let mc_term = self
1525 .accmc_evaluator
1526 .evaluate_weighted_value_sum_mpi(parameters, world)?;
1527 Ok(-2.0 * (data_term - mc_term / self.n_mc))
1528 }
1529
1530 pub(crate) fn evaluate_data_gradient_term_local(
1531 &self,
1532 parameters: &[f64],
1533 ) -> LadduResult<DVector<f64>> {
1534 let data_resources = self.data_evaluator.resources.read();
1535 let data_parameters = data_resources.parameter_map.assemble(parameters)?;
1536 let data_active_indices = data_resources.active_indices().to_vec();
1537 let data_active_mask = data_resources.active.clone();
1538 #[cfg(feature = "rayon")]
1539 let n_parameters = parameters.len();
1540 #[cfg(feature = "rayon")]
1541 let data_scratch_key = GradientScratchKey {
1542 n_parameters,
1543 n_amplitudes: self.data_evaluator.amplitude_value_slot_count(),
1544 n_expression_slots: self.data_evaluator.expression_slot_count(),
1545 };
1546 #[cfg(feature = "rayon")]
1547 let data_term: DVector<f64> = sum_dvectors_parallel(
1548 self.data_evaluator
1549 .dataset
1550 .weights_local()
1551 .par_iter()
1552 .zip(data_resources.caches.par_iter())
1553 .map_init(
1554 || acquire_gradient_scratch(data_scratch_key),
1555 |scratch, (event, cache)| {
1556 let workspace = scratch.workspace_mut();
1557 let amp_vals = &mut workspace.amplitude_values;
1558 let grad_vals = &mut workspace.gradient_values;
1559 self.data_evaluator.fill_amplitude_values_and_gradients(
1560 amp_vals,
1561 grad_vals,
1562 &data_active_indices,
1563 &data_active_mask,
1564 &data_parameters,
1565 cache,
1566 );
1567 let (value, gradient) = self
1568 .data_evaluator
1569 .evaluate_expression_value_gradient_with_scratch(
1570 amp_vals,
1571 grad_vals,
1572 &mut workspace.value_slots,
1573 &mut workspace.gradient_slots,
1574 );
1575 (*event, value, gradient)
1576 },
1577 )
1578 .map(|(w, l, g)| g.map(|gi| gi.re * w / l.re)),
1579 n_parameters,
1580 );
1581 #[cfg(not(feature = "rayon"))]
1582 let data_term: DVector<f64> = {
1583 let amplitude_len = self.data_evaluator.amplitude_value_slot_count();
1584 let mut amp_vals = vec![Complex64::ZERO; amplitude_len];
1585 let mut grad_vals = vec![DVector::zeros(parameters.len()); amplitude_len];
1586 let mut value_slots =
1587 vec![Complex64::ZERO; self.data_evaluator.expression_slot_count()];
1588 let mut gradient_slots =
1589 vec![DVector::zeros(parameters.len()); self.data_evaluator.expression_slot_count()];
1590 self.data_evaluator
1591 .dataset
1592 .weights_local()
1593 .iter()
1594 .zip(data_resources.caches.iter())
1595 .map(|(event, cache)| {
1596 self.data_evaluator.fill_amplitude_values_and_gradients(
1597 &mut amp_vals,
1598 &mut grad_vals,
1599 &data_active_indices,
1600 &data_active_mask,
1601 &data_parameters,
1602 cache,
1603 );
1604 let (value, gradient) = self
1605 .data_evaluator
1606 .evaluate_expression_value_gradient_with_scratch(
1607 &_vals,
1608 &grad_vals,
1609 &mut value_slots,
1610 &mut gradient_slots,
1611 );
1612 (*event, value, gradient)
1613 })
1614 .map(|(w, l, g)| g.map(|gi| gi.re * w / l.re))
1615 .sum()
1616 };
1617 Ok(data_term)
1618 }
1619
1620 #[doc(hidden)]
1621 pub fn evaluate_gradient_local(&self, parameters: &[f64]) -> LadduResult<DVector<f64>> {
1622 let data_term = self.evaluate_data_gradient_term_local(parameters)?;
1623 let mc_term = self
1624 .accmc_evaluator
1625 .evaluate_weighted_gradient_sum_local(parameters)?;
1626 Ok(-2.0 * (data_term - mc_term / self.n_mc))
1627 }
1628
1629 #[cfg(feature = "mpi")]
1630 #[doc(hidden)]
1631 pub fn evaluate_gradient_mpi(
1632 &self,
1633 parameters: &[f64],
1634 world: &SimpleCommunicator,
1635 ) -> LadduResult<DVector<f64>> {
1636 let data_term_local = self.evaluate_data_gradient_term_local(parameters)?;
1637 let data_term = reduce_gradient(world, &data_term_local);
1638 let mc_term = self
1639 .accmc_evaluator
1640 .evaluate_weighted_gradient_sum_mpi(parameters, world)?;
1641 Ok(-2.0 * (data_term - mc_term / self.n_mc))
1642 }
1643}
1644
1645impl LikelihoodTerm for NLL {
1646 fn evaluate(&self, parameters: &[f64]) -> LadduResult<f64> {
1647 validate_free_parameter_len(parameters.len(), self.n_free())?;
1648 #[cfg(feature = "mpi")]
1649 {
1650 if let Some(world) = laddu_core::mpi::get_world() {
1651 return self.evaluate_mpi(parameters, &world);
1652 }
1653 }
1654 self.evaluate_local(parameters)
1655 }
1656 fn evaluate_gradient(&self, parameters: &[f64]) -> LadduResult<DVector<f64>> {
1657 validate_free_parameter_len(parameters.len(), self.n_free())?;
1658 #[cfg(feature = "mpi")]
1659 {
1660 if let Some(world) = laddu_core::mpi::get_world() {
1661 return self.evaluate_gradient_mpi(parameters, &world);
1662 }
1663 }
1664 self.evaluate_gradient_local(parameters)
1665 }
1666 fn fix_parameter(&self, name: &str, value: f64) -> LadduResult<()> {
1667 self.data_evaluator.fix_parameter(name, value)?;
1668 self.accmc_evaluator.fix_parameter(name, value)?;
1669 Ok(())
1670 }
1671 fn free_parameter(&self, name: &str) -> LadduResult<()> {
1672 self.data_evaluator.free_parameter(name)?;
1673 self.accmc_evaluator.free_parameter(name)?;
1674 Ok(())
1675 }
1676 fn rename_parameter(&self, old: &str, new: &str) -> LadduResult<()> {
1677 self.data_evaluator.rename_parameter(old, new)?;
1678 self.accmc_evaluator.rename_parameter(old, new)?;
1679 Ok(())
1680 }
1681 fn rename_parameters(&self, mapping: &HashMap<String, String>) -> LadduResult<()> {
1682 self.data_evaluator.rename_parameters(mapping)?;
1683 self.accmc_evaluator.rename_parameters(mapping)?;
1684 Ok(())
1685 }
1686 fn parameter_map(&self) -> ParameterMap {
1687 self.data_evaluator.resources.read().parameter_map.clone()
1688 }
1689}
1690
1691#[derive(Clone)]
1697pub struct StochasticNLL {
1698 pub nll: NLL,
1700 n: usize,
1701 batch_size: usize,
1702 batch_indices: Arc<Mutex<Vec<usize>>>,
1703 rng: Arc<Mutex<Rng>>,
1704}
1705
1706impl LikelihoodTerm for StochasticNLL {
1707 fn evaluate(&self, parameters: &[f64]) -> LadduResult<f64> {
1708 validate_free_parameter_len(parameters.len(), self.nll.n_free())?;
1709 let indices = self.batch_indices.lock();
1710 #[cfg(feature = "mpi")]
1711 {
1712 if let Some(world) = laddu_core::mpi::get_world() {
1713 return self.evaluate_mpi(parameters, &indices, &world);
1714 }
1715 }
1716 #[cfg(feature = "rayon")]
1717 let n_data_batch_local = indices
1718 .par_iter()
1719 .map(|&i| self.nll.data_evaluator.dataset.weights_local()[i])
1720 .parallel_sum_with_accumulator::<Klein<f64>>();
1721 #[cfg(not(feature = "rayon"))]
1722 let n_data_batch_local = indices
1723 .iter()
1724 .map(|&i| self.nll.data_evaluator.dataset.weights_local()[i])
1725 .sum_with_accumulator::<Klein<f64>>();
1726 self.evaluate_local(parameters, &indices, n_data_batch_local)
1727 }
1728 fn evaluate_gradient(&self, parameters: &[f64]) -> LadduResult<DVector<f64>> {
1729 validate_free_parameter_len(parameters.len(), self.nll.n_free())?;
1730 let indices = self.batch_indices.lock();
1731 #[cfg(feature = "mpi")]
1732 {
1733 if let Some(world) = laddu_core::mpi::get_world() {
1734 return self.evaluate_gradient_mpi(parameters, &indices, &world);
1735 }
1736 }
1737 #[cfg(feature = "rayon")]
1738 let n_data_batch_local = indices
1739 .par_iter()
1740 .map(|&i| self.nll.data_evaluator.dataset.weights_local()[i])
1741 .parallel_sum_with_accumulator::<Klein<f64>>();
1742 #[cfg(not(feature = "rayon"))]
1743 let n_data_batch_local = indices
1744 .iter()
1745 .map(|&i| self.nll.data_evaluator.dataset.weights_local()[i])
1746 .sum_with_accumulator::<Klein<f64>>();
1747 self.evaluate_gradient_local(parameters, &indices, n_data_batch_local)
1748 }
1749 fn fix_parameter(&self, name: &str, value: f64) -> LadduResult<()> {
1750 self.nll.fix_parameter(name, value)
1751 }
1752 fn free_parameter(&self, name: &str) -> LadduResult<()> {
1753 self.nll.free_parameter(name)
1754 }
1755 fn rename_parameter(&self, old: &str, new: &str) -> LadduResult<()> {
1756 self.nll.rename_parameter(old, new)
1757 }
1758 fn rename_parameters(&self, mapping: &HashMap<String, String>) -> LadduResult<()> {
1759 self.nll.rename_parameters(mapping)
1760 }
1761 fn update(&self) {
1762 self.resample();
1763 }
1764 fn parameter_map(&self) -> ParameterMap {
1765 self.nll.parameter_map()
1766 }
1767}
1768
1769impl StochasticNLL {
1770 pub fn new(nll: NLL, batch_size: usize, seed: Option<usize>) -> LadduResult<Self> {
1776 let mut rng = seed.map_or_else(Rng::new, |seed| Rng::with_seed(seed as u64));
1777 let n = nll.data_evaluator.dataset.n_events();
1778 validate_stochastic_batch_size(batch_size, n)?;
1779 let batch_indices = rng.subset(batch_size, n);
1780 Ok(Self {
1781 nll,
1782 n,
1783 batch_size,
1784 batch_indices: Arc::new(Mutex::new(batch_indices)),
1785 rng: Arc::new(Mutex::new(rng)),
1786 })
1787 }
1788 pub fn resample(&self) {
1790 let mut rng = self.rng.lock();
1791 *self.batch_indices.lock() = rng.subset(self.batch_size, self.n);
1792 }
1793
1794 pub fn parameters(&self) -> ParameterMap {
1796 self.nll.parameters()
1797 }
1798
1799 pub fn n_free(&self) -> usize {
1801 self.nll.n_free()
1802 }
1803
1804 pub fn n_fixed(&self) -> usize {
1806 self.nll.n_fixed()
1807 }
1808
1809 pub fn n_parameters(&self) -> usize {
1811 self.nll.n_parameters()
1812 }
1813
1814 pub fn expression(&self) -> Expression {
1816 self.nll.expression()
1817 }
1818
1819 pub fn compiled_expression(&self) -> CompiledExpression {
1822 self.nll.compiled_expression()
1823 }
1824
1825 #[cfg(feature = "mpi")]
1826 fn data_batch_weight_local(&self, indices: &[usize]) -> f64 {
1827 #[cfg(feature = "rayon")]
1828 return indices
1829 .par_iter()
1830 .map(|&i| self.nll.data_evaluator.dataset.weights_local()[i])
1831 .parallel_sum_with_accumulator::<Klein<f64>>();
1832 #[cfg(not(feature = "rayon"))]
1833 return indices
1834 .iter()
1835 .map(|&i| self.nll.data_evaluator.dataset.weights_local()[i])
1836 .sum_with_accumulator::<Klein<f64>>();
1837 }
1838
1839 fn evaluate_data_term_local(&self, parameters: &[f64], indices: &[usize]) -> LadduResult<f64> {
1840 let data_result = self
1841 .nll
1842 .data_evaluator
1843 .evaluate_batch_local(parameters, indices)?;
1844 #[cfg(feature = "rayon")]
1845 {
1846 Ok(indices
1847 .par_iter()
1848 .zip(data_result.par_iter())
1849 .map(|(&i, &l)| {
1850 let e = &self.nll.data_evaluator.dataset.weights_local()[i];
1851 e * l.re.ln()
1852 })
1853 .parallel_sum_with_accumulator::<Klein<f64>>())
1854 }
1855 #[cfg(not(feature = "rayon"))]
1856 {
1857 Ok(indices
1858 .iter()
1859 .zip(data_result.iter())
1860 .map(|(&i, &l)| {
1861 let e = &self.nll.data_evaluator.dataset.weights_local()[i];
1862 e * l.re.ln()
1863 })
1864 .sum_with_accumulator::<Klein<f64>>())
1865 }
1866 }
1867
1868 fn evaluate_local(
1869 &self,
1870 parameters: &[f64],
1871 indices: &[usize],
1872 n_data_batch: f64,
1873 ) -> LadduResult<f64> {
1874 let n_data_total = self.nll.data_evaluator.dataset.n_events_weighted();
1875 let data_term = self.evaluate_data_term_local(parameters, indices)?;
1876 let mc_term = self
1877 .nll
1878 .accmc_evaluator
1879 .evaluate_weighted_value_sum_local(parameters)?;
1880 Ok(-2.0 * (data_term * n_data_total / n_data_batch - mc_term / self.nll.n_mc))
1881 }
1882
1883 #[cfg(feature = "mpi")]
1884 fn evaluate_mpi(
1885 &self,
1886 parameters: &[f64],
1887 indices: &[usize],
1888 world: &SimpleCommunicator,
1889 ) -> LadduResult<f64> {
1890 let total = self.nll.data_evaluator.dataset.n_events();
1891 let locals = world.locals_from_globals(indices, total);
1892 let n_data_batch_local = self.data_batch_weight_local(&locals);
1893 let n_data_total = self.nll.data_evaluator.dataset.n_events_weighted();
1894 let data_term_local = self.evaluate_data_term_local(parameters, &locals)?;
1895 let n_data_batch = reduce_scalar(world, n_data_batch_local);
1896 let data_term = reduce_scalar(world, data_term_local);
1897 let mc_term = self
1898 .nll
1899 .accmc_evaluator
1900 .evaluate_weighted_value_sum_mpi(parameters, world)?;
1901 Ok(-2.0 * (data_term * n_data_total / n_data_batch - mc_term / self.nll.n_mc))
1902 }
1903
1904 fn evaluate_data_gradient_term_local(
1905 &self,
1906 parameters: &[f64],
1907 indices: &[usize],
1908 ) -> LadduResult<DVector<f64>> {
1909 let data_resources = self.nll.data_evaluator.resources.read();
1910 let data_parameters = data_resources.parameter_map.assemble(parameters)?;
1911 let data_active_indices = data_resources.active_indices().to_vec();
1912 let data_active_mask = data_resources.active.clone();
1913 #[cfg(feature = "rayon")]
1914 let n_parameters = parameters.len();
1915 #[cfg(feature = "rayon")]
1916 let data_scratch_key = GradientScratchKey {
1917 n_parameters,
1918 n_amplitudes: self.nll.data_evaluator.amplitude_value_slot_count(),
1919 n_expression_slots: self.nll.data_evaluator.expression_slot_count(),
1920 };
1921 #[cfg(feature = "rayon")]
1922 let data_term: DVector<f64> = sum_dvectors_parallel(
1923 indices
1924 .par_iter()
1925 .map_init(
1926 || acquire_gradient_scratch(data_scratch_key),
1927 |scratch, &idx| {
1928 let workspace = scratch.workspace_mut();
1929 let amp_vals = &mut workspace.amplitude_values;
1930 let grad_vals = &mut workspace.gradient_values;
1931 let event = &self.nll.data_evaluator.dataset.weights_local()[idx];
1932 let cache = &data_resources.caches[idx];
1933 self.nll.data_evaluator.fill_amplitude_values_and_gradients(
1934 amp_vals,
1935 grad_vals,
1936 &data_active_indices,
1937 &data_active_mask,
1938 &data_parameters,
1939 cache,
1940 );
1941 let (value, gradient) = self
1942 .nll
1943 .data_evaluator
1944 .evaluate_expression_value_gradient_with_scratch(
1945 amp_vals,
1946 grad_vals,
1947 &mut workspace.value_slots,
1948 &mut workspace.gradient_slots,
1949 );
1950 (*event, value, gradient)
1951 },
1952 )
1953 .map(|(w, l, g)| g.map(|gi| gi.re * w / l.re)),
1954 n_parameters,
1955 );
1956 #[cfg(not(feature = "rayon"))]
1957 let data_term: DVector<f64> = {
1958 let amplitude_len = self.nll.data_evaluator.amplitude_value_slot_count();
1959 let mut amp_vals = vec![Complex64::ZERO; amplitude_len];
1960 let mut grad_vals = vec![DVector::zeros(parameters.len()); amplitude_len];
1961 let mut value_slots =
1962 vec![Complex64::ZERO; self.nll.data_evaluator.expression_slot_count()];
1963 let mut gradient_slots = vec![
1964 DVector::zeros(parameters.len());
1965 self.nll.data_evaluator.expression_slot_count()
1966 ];
1967 indices
1968 .iter()
1969 .map(|&idx| {
1970 let event = &self.nll.data_evaluator.dataset.weights_local()[idx];
1971 let cache = &data_resources.caches[idx];
1972 self.nll.data_evaluator.fill_amplitude_values_and_gradients(
1973 &mut amp_vals,
1974 &mut grad_vals,
1975 &data_active_indices,
1976 &data_active_mask,
1977 &data_parameters,
1978 cache,
1979 );
1980 let (value, gradient) = self
1981 .nll
1982 .data_evaluator
1983 .evaluate_expression_value_gradient_with_scratch(
1984 &_vals,
1985 &grad_vals,
1986 &mut value_slots,
1987 &mut gradient_slots,
1988 );
1989 (*event, value, gradient)
1990 })
1991 .map(|(w, l, g)| g.map(|gi| gi.re * w / l.re))
1992 .sum()
1993 };
1994 Ok(data_term)
1995 }
1996
1997 fn evaluate_gradient_local(
1998 &self,
1999 parameters: &[f64],
2000 indices: &[usize],
2001 n_data_batch: f64,
2002 ) -> LadduResult<DVector<f64>> {
2003 let n_data_total = self.nll.data_evaluator.dataset.n_events_weighted();
2004 let data_term = self.evaluate_data_gradient_term_local(parameters, indices)?;
2005 let mc_term = self
2006 .nll
2007 .accmc_evaluator
2008 .evaluate_weighted_gradient_sum_local(parameters)?;
2009 Ok(-2.0 * (data_term * n_data_total / n_data_batch - mc_term / self.nll.n_mc))
2010 }
2011
2012 #[cfg(feature = "mpi")]
2013 fn evaluate_gradient_mpi(
2014 &self,
2015 parameters: &[f64],
2016 indices: &[usize],
2017 world: &SimpleCommunicator,
2018 ) -> LadduResult<DVector<f64>> {
2019 let total = self.nll.data_evaluator.dataset.n_events();
2020 let locals = world.locals_from_globals(indices, total);
2021 let n_data_batch_local = self.data_batch_weight_local(&locals);
2022 let n_data_total = self.nll.data_evaluator.dataset.n_events_weighted();
2023 let data_term_local = self.evaluate_data_gradient_term_local(parameters, &locals)?;
2024 let n_data_batch = reduce_scalar(world, n_data_batch_local);
2025 let data_term = reduce_gradient(world, &data_term_local);
2026 let mc_term = self
2027 .nll
2028 .accmc_evaluator
2029 .evaluate_weighted_gradient_sum_mpi(parameters, world)?;
2030 Ok(-2.0 * (data_term * n_data_total / n_data_batch - mc_term / self.nll.n_mc))
2031 }
2032}