quad_rs/
state.rs

1use nalgebra::ComplexField;
2use serde::{Deserialize, Serialize};
3use std::collections::HashMap;
4use trellis_runner::{UpdateData, UserState};
5
6use crate::{IntegrableFloat, IntegrationOutput, Segment, SegmentHeap, Segments, Values};
7
8#[derive(Clone, Default, Debug, Deserialize, Serialize)]
9#[allow(clippy::module_name_repetitions)]
10pub struct IntegrationState<I, O, F>
11where
12    F: PartialOrd + PartialEq,
13{
14    /// Current value of the integral
15    pub integral: Option<O>,
16    /// Previous value of the integral
17    pub prev_integral: Option<O>,
18    /// Lowest error value of the integral
19    pub best_integral: Option<O>,
20    /// Previous best parameter vector
21    pub prev_best_integral: Option<O>,
22    /// segments in the integral
23    pub segments: SegmentHeap<I, O, F>,
24    /// Evaluation counts
25    pub counts: HashMap<String, usize>,
26    /// Whether to accumulate resolved values in the output
27    pub accumulate_values: bool,
28}
29
30impl<I, O, F> IntegrationState<I, O, F>
31where
32    O: IntegrationOutput<Float = F>,
33    I: ComplexField<RealField = F> + Copy,
34    F: IntegrableFloat,
35{
36    #[must_use]
37    pub fn param(mut self, param: O) -> Self {
38        std::mem::swap(&mut self.prev_integral, &mut self.integral);
39        self.integral = Some(param);
40        self
41    }
42
43    // Add the new segments to the internal heap
44    #[must_use]
45    pub fn segments(mut self, segments: Vec<Segment<I, O, F>>) -> Self {
46        segments
47            .into_iter()
48            .for_each(|segment| self.segments.push(segment));
49        self
50    }
51
52    pub fn pop_worst_segment(&mut self) -> Option<Segment<I, O, F>> {
53        self.segments.pop()
54    }
55
56    pub fn take_integral(&mut self) -> Option<O> {
57        self.integral.take()
58    }
59
60    pub fn get_integral(&self) -> Option<&O> {
61        self.integral.as_ref()
62    }
63
64    pub fn get_prev_integral(&self) -> Option<&O> {
65        self.prev_integral.as_ref()
66    }
67
68    pub fn take_prev_integral(&mut self) -> Option<O> {
69        self.prev_integral.take()
70    }
71
72    pub fn get_prev_best_integral(&self) -> Option<&O> {
73        self.prev_best_integral.as_ref()
74    }
75
76    pub fn take_best_integral(&mut self) -> Option<O> {
77        self.best_integral.take()
78    }
79
80    pub fn take_prev_best_integral(&mut self) -> Option<O> {
81        self.prev_best_integral.take()
82    }
83
84    // Consume the state to get the ordered raw values
85    pub fn into_resolved(self) -> Option<Values<I, O>> {
86        // Segments ordered by the input vector
87        let ordered_segments = self.segments.into_input_ordered();
88
89        let mut points = Vec::new();
90        let mut values = Vec::new();
91        let mut weights = Vec::new();
92
93        for segment in ordered_segments.into_iter() {
94            if let Some(data) = segment.data {
95                points.extend_from_slice(&data.points);
96                values.extend_from_slice(&data.values);
97                weights.extend_from_slice(&data.weights);
98            }
99        }
100
101        Some(Values {
102            points,
103            values,
104            weights,
105        })
106    }
107}
108
109impl<I, O, F> UserState for IntegrationState<I, O, F>
110where
111    I: ComplexField<RealField = F> + Copy,
112    O: IntegrationOutput<Float = F>,
113    F: IntegrableFloat,
114{
115    type Float = F;
116    type Param = O;
117
118    fn new() -> Self {
119        Self {
120            integral: None,
121            prev_integral: None,
122            best_integral: None,
123            prev_best_integral: None,
124            segments: SegmentHeap::empty(),
125            counts: HashMap::new(),
126            accumulate_values: false,
127        }
128    }
129
130    fn is_initialised(&self) -> bool {
131        self.get_integral().is_some()
132    }
133
134    fn update(&mut self) -> impl Into<std::option::Option<UpdateData<<Self as UserState>::Float>>> {
135        let absolute_error = self.segments.error().into_inner();
136        let result = self.segments.result();
137        let relative_error = absolute_error / result.l2_norm();
138        self.integral = Some(result);
139
140        Some(UpdateData::ErrorEstimate {
141            relative: relative_error,
142            absolute: absolute_error,
143        })
144    }
145
146    fn get_param(&self) -> Option<&O> {
147        self.get_integral()
148    }
149
150    fn last_was_best(&mut self) {
151        // If the last iteration was the best one, we swap the previous best and best integral
152        // values
153        if let Some(integral) = self.get_integral().cloned() {
154            std::mem::swap(&mut self.prev_best_integral, &mut self.best_integral);
155            self.best_integral = Some(integral);
156        }
157    }
158}