ringkernel_montecarlo/variance/
control.rs1use crate::rng::GpuRng;
11
/// Estimates `E[X]` with the control-variate method, using paired samples of a
/// control `Y` whose true mean `mu_y` is known.
///
/// The coefficient is the variance-minimising `c = Cov(X, Y) / Var(Y)` and the
/// estimate is `mean(X) - c * (mean(Y) - mu_y)`. All statistics are accumulated
/// in `f64`, because naive `f32` summation drifts noticeably once the sample count
/// reaches ~1e5-1e6. The results are narrowed back to `f32` on return.
///
/// # Returns
///
/// `(estimate, c, variance_reduction)`.
///
/// `variance_reduction = 1 - r²` is the ratio of the controlled estimator's
/// variance to that of the plain sample mean: `0` means all variance was removed,
/// `1` means no benefit. It is clamped to `[0, 1]` so rounding can never report
/// a negative ratio.
///
/// Degenerate inputs:
/// - if the sample variance of `Y` is `<= 1e-10`, the control is ignored (`c = 0`);
/// - if either sample variance is `<= 1e-10`, `r²` is taken as `0`.
///
/// # Panics
///
/// Panics if the slices differ in length or contain fewer than two samples.
pub fn control_variate_estimate(
    x_samples: &[f32],
    y_samples: &[f32],
    mu_y: f32,
) -> (f32, f32, f32) {
    // Variances at or below this are treated as zero to avoid dividing by noise.
    const VAR_EPS: f64 = 1e-10;

    let n = x_samples.len();
    assert_eq!(n, y_samples.len(), "Sample arrays must have same length");
    assert!(n > 1, "Need at least 2 samples");

    let n_f = n as f64;
    let mean_x = x_samples.iter().map(|&x| f64::from(x)).sum::<f64>() / n_f;
    let mean_y = y_samples.iter().map(|&y| f64::from(y)).sum::<f64>() / n_f;

    // Centred sums of squares and cross-products (two-pass form: subtract the
    // mean first rather than using E[X^2] - E[X]^2, which cancels badly).
    let (sxx, syy, sxy) = x_samples.iter().zip(y_samples).fold(
        (0.0f64, 0.0f64, 0.0f64),
        |(sxx, syy, sxy), (&x, &y)| {
            let dx = f64::from(x) - mean_x;
            let dy = f64::from(y) - mean_y;
            (sxx + dx * dx, syy + dy * dy, sxy + dx * dy)
        },
    );

    // Bessel-corrected (n - 1) sample estimates.
    let dof = n_f - 1.0;
    let var_x = sxx / dof;
    let var_y = syy / dof;
    let cov_xy = sxy / dof;

    let c = if var_y > VAR_EPS { cov_xy / var_y } else { 0.0 };
    let estimate = mean_x - c * (mean_y - f64::from(mu_y));

    // r² is mathematically in [0, 1]; rounding can push it slightly above 1.
    let r_squared = if var_x > VAR_EPS && var_y > VAR_EPS {
        ((cov_xy * cov_xy) / (var_x * var_y)).min(1.0)
    } else {
        0.0
    };

    (estimate as f32, c as f32, (1.0 - r_squared) as f32)
}
80
/// Control-variate Monte Carlo estimator. Each uniform draw is evaluated by
/// both a target function and a control function.
#[derive(Clone, Debug)]
pub struct ControlVariates {
    /// Number of uniform draws taken on each call to `estimate`.
    pub n_samples: usize,
}
87
88impl ControlVariates {
89 pub fn new(n_samples: usize) -> Self {
91 Self { n_samples }
92 }
93
94 pub fn estimate<R: GpuRng, F, G>(
107 &self,
108 state: &mut R::State,
109 f: F,
110 g: G,
111 mu_g: f32,
112 ) -> (f32, f32, f32)
113 where
114 F: Fn(f32) -> f32,
115 G: Fn(f32) -> f32,
116 {
117 let mut x_samples = Vec::with_capacity(self.n_samples);
118 let mut y_samples = Vec::with_capacity(self.n_samples);
119
120 for _ in 0..self.n_samples {
121 let u = R::next_uniform(state);
122 x_samples.push(f(u));
123 y_samples.push(g(u));
124 }
125
126 control_variate_estimate(&x_samples, &y_samples, mu_g)
127 }
128}
129
130impl Default for ControlVariates {
131 fn default() -> Self {
132 Self::new(1000)
133 }
134}
135
#[cfg(test)]
mod tests {
    use super::*;
    use crate::rng::PhiloxRng;

    #[test]
    fn test_control_variate_perfect_correlation() {
        // With Y = X, the control explains all of X's variance.
        let x = [1.0, 2.0, 3.0, 4.0, 5.0];
        let y = x;
        let mu_y = 3.0;
        let (estimate, c, variance_reduction) = control_variate_estimate(&x, &y, mu_y);

        assert!(
            (c - 1.0).abs() < 1e-6,
            "c should be 1 for perfect correlation"
        );
        assert!(
            (estimate - mu_y).abs() < 1e-6,
            "Estimate should equal mu_y when Y = X"
        );
        assert!(
            variance_reduction.abs() < 1e-6,
            "Perfect correlation should leave no residual variance, got {}",
            variance_reduction
        );
    }

    #[test]
    fn test_constant_control_is_ignored() {
        // A control with zero sample variance carries no information, so the
        // estimate must fall back to the plain mean of X.
        let x = [1.0, 2.0, 3.0];
        let y = [4.0, 4.0, 4.0];
        let (estimate, c, variance_reduction) = control_variate_estimate(&x, &y, 10.0);

        assert_eq!(c, 0.0, "A zero-variance control must not be used");
        assert!(
            (estimate - 2.0).abs() < 1e-6,
            "Estimate {} should be the plain mean of X",
            estimate
        );
        assert_eq!(variance_reduction, 1.0);
    }

    #[test]
    #[should_panic(expected = "Sample arrays must have same length")]
    fn test_mismatched_lengths_panic() {
        let _ = control_variate_estimate(&[1.0, 2.0], &[1.0], 0.0);
    }

    #[test]
    #[should_panic(expected = "Need at least 2 samples")]
    fn test_single_sample_panics() {
        let _ = control_variate_estimate(&[1.0], &[1.0], 0.0);
    }

    #[test]
    fn test_control_variate_reduces_variance() {
        // Estimate the integral of e^u over [0, 1], i.e. e - 1, using U itself
        // (mean 0.5) as the control. This assumes PhiloxRng::next_uniform
        // draws from Uniform(0, 1).
        let n = 5000;

        let mut state = PhiloxRng::seed(42, 0);
        let naive_samples: Vec<f32> = (0..n)
            .map(|_| PhiloxRng::next_uniform(&mut state).exp())
            .collect();
        let naive_mean: f32 = naive_samples.iter().sum::<f32>() / n as f32;

        let mut state = PhiloxRng::seed(42, 0);
        let cv = ControlVariates::new(n);
        let (cv_estimate, _c, var_reduction) =
            cv.estimate::<PhiloxRng, _, _>(&mut state, |u| u.exp(), |u| u, 0.5);

        let true_value = std::f32::consts::E - 1.0;

        assert!(
            (naive_mean - true_value).abs() < 0.1,
            "Naive {} far from {}",
            naive_mean,
            true_value
        );
        // The controlled estimator's standard error here is roughly 1e-3, so
        // the tolerance can be much tighter than the naive one.
        assert!(
            (cv_estimate - true_value).abs() < 0.02,
            "CV {} far from {}",
            cv_estimate,
            true_value
        );

        // Theory: 1 - Corr(e^U, U)^2 is about 0.016 for U ~ Uniform(0, 1), so a
        // working control must remove the vast majority of the variance.
        assert!(
            var_reduction < 0.05,
            "Control variate should reduce variance, residual ratio {}",
            var_reduction
        );
    }
}