1use nalgebra::ComplexField;
2use serde::{Deserialize, Serialize};
3use std::collections::HashMap;
4use trellis_runner::{UpdateData, UserState};
5
6use crate::{IntegrableFloat, IntegrationOutput, Segment, SegmentHeap, Segments, Values};
7
8#[derive(Clone, Default, Debug, Deserialize, Serialize)]
9#[allow(clippy::module_name_repetitions)]
10pub struct IntegrationState<I, O, F>
11where
12 F: PartialOrd + PartialEq,
13{
14 pub integral: Option<O>,
16 pub prev_integral: Option<O>,
18 pub best_integral: Option<O>,
20 pub prev_best_integral: Option<O>,
22 pub segments: SegmentHeap<I, O, F>,
24 pub counts: HashMap<String, usize>,
26 pub accumulate_values: bool,
28}
29
30impl<I, O, F> IntegrationState<I, O, F>
31where
32 O: IntegrationOutput<Float = F>,
33 I: ComplexField<RealField = F> + Copy,
34 F: IntegrableFloat,
35{
36 #[must_use]
37 pub fn param(mut self, param: O) -> Self {
38 std::mem::swap(&mut self.prev_integral, &mut self.integral);
39 self.integral = Some(param);
40 self
41 }
42
43 #[must_use]
45 pub fn segments(mut self, segments: Vec<Segment<I, O, F>>) -> Self {
46 segments
47 .into_iter()
48 .for_each(|segment| self.segments.push(segment));
49 self
50 }
51
52 pub fn pop_worst_segment(&mut self) -> Option<Segment<I, O, F>> {
53 self.segments.pop()
54 }
55
56 pub fn take_integral(&mut self) -> Option<O> {
57 self.integral.take()
58 }
59
60 pub fn get_integral(&self) -> Option<&O> {
61 self.integral.as_ref()
62 }
63
64 pub fn get_prev_integral(&self) -> Option<&O> {
65 self.prev_integral.as_ref()
66 }
67
68 pub fn take_prev_integral(&mut self) -> Option<O> {
69 self.prev_integral.take()
70 }
71
72 pub fn get_prev_best_integral(&self) -> Option<&O> {
73 self.prev_best_integral.as_ref()
74 }
75
76 pub fn take_best_integral(&mut self) -> Option<O> {
77 self.best_integral.take()
78 }
79
80 pub fn take_prev_best_integral(&mut self) -> Option<O> {
81 self.prev_best_integral.take()
82 }
83
84 pub fn into_resolved(self) -> Option<Values<I, O>> {
86 let ordered_segments = self.segments.into_input_ordered();
88
89 let mut points = Vec::new();
90 let mut values = Vec::new();
91 let mut weights = Vec::new();
92
93 for segment in ordered_segments.into_iter() {
94 if let Some(data) = segment.data {
95 points.extend_from_slice(&data.points);
96 values.extend_from_slice(&data.values);
97 weights.extend_from_slice(&data.weights);
98 }
99 }
100
101 Some(Values {
102 points,
103 values,
104 weights,
105 })
106 }
107}
108
109impl<I, O, F> UserState for IntegrationState<I, O, F>
110where
111 I: ComplexField<RealField = F> + Copy,
112 O: IntegrationOutput<Float = F>,
113 F: IntegrableFloat,
114{
115 type Float = F;
116 type Param = O;
117
118 fn new() -> Self {
119 Self {
120 integral: None,
121 prev_integral: None,
122 best_integral: None,
123 prev_best_integral: None,
124 segments: SegmentHeap::empty(),
125 counts: HashMap::new(),
126 accumulate_values: false,
127 }
128 }
129
130 fn is_initialised(&self) -> bool {
131 self.get_integral().is_some()
132 }
133
134 fn update(&mut self) -> impl Into<std::option::Option<UpdateData<<Self as UserState>::Float>>> {
135 let absolute_error = self.segments.error().into_inner();
136 let result = self.segments.result();
137 let relative_error = absolute_error / result.l2_norm();
138 self.integral = Some(result);
139
140 Some(UpdateData::ErrorEstimate {
141 relative: relative_error,
142 absolute: absolute_error,
143 })
144 }
145
146 fn get_param(&self) -> Option<&O> {
147 self.get_integral()
148 }
149
150 fn last_was_best(&mut self) {
151 if let Some(integral) = self.get_integral().cloned() {
154 std::mem::swap(&mut self.prev_best_integral, &mut self.best_integral);
155 self.best_integral = Some(integral);
156 }
157 }
158}