ringkernel_montecarlo/variance/importance.rs
1//! Importance sampling for variance reduction.
2//!
3//! Importance sampling rewrites an expectation E_p[f(X)] as E_q[f(X) * w(X)]
4//! where w(X) = p(X)/q(X) is the importance weight and q is a proposal distribution.
5//!
//! This can reduce variance when q is chosen to place more probability mass
//! in regions where |f(X)| p(X) is large; the variance-minimizing proposal is
//! q*(x) ∝ |f(x)| p(x).
8
9use crate::rng::GpuRng;
10
/// Compute a single importance-sampling term `f(x) * p(x) / q(x)`.
///
/// Given a sample `x` drawn from the proposal distribution `q`, this returns the
/// weighted contribution used to estimate `E_p[f(X)]`.
///
/// # Arguments
///
/// * `_x` - Sample drawn from the proposal `q` (not needed for the computation)
/// * `f_x` - Value of `f` at `x`
/// * `p_x` - Target density `p(x)`
/// * `q_x` - Proposal density `q(x)`
///
/// # Returns
///
/// The weighted sample `f(x) * p(x) / q(x)`, or `0.0` when `q(x)` is exactly
/// zero (such a point cannot be drawn from `q`, so it contributes nothing).
#[inline]
pub fn importance_sample(_x: f32, f_x: f32, p_x: f32, q_x: f32) -> f32 {
    // Only an exact zero is a division hazard. Small but non-zero densities are
    // legitimate (e.g. far in the tails) and their ratio is usually well scaled;
    // clamping them to zero would silently bias the estimator.
    if q_x == 0.0 {
        return 0.0;
    }
    // Form the density ratio first so `f_x * p_x` cannot underflow before the division.
    f_x * (p_x / q_x)
}
34
35/// Configuration for importance sampling.
36#[derive(Debug, Clone)]
37pub struct ImportanceSampling {
38 /// Number of samples.
39 pub n_samples: usize,
40 /// Whether to use self-normalized estimator.
41 pub self_normalized: bool,
42}
43
44impl ImportanceSampling {
45 /// Create new importance sampling configuration.
46 pub fn new(n_samples: usize) -> Self {
47 Self {
48 n_samples,
49 self_normalized: false,
50 }
51 }
52
53 /// Use self-normalized importance sampling.
54 ///
55 /// The self-normalized estimator divides by sum of weights:
56 /// `sum(w_i * f_i) / sum(w_i)`
57 ///
58 /// This is biased but can have lower variance and doesn't require
59 /// knowing the normalizing constant of p.
60 pub fn self_normalized(mut self) -> Self {
61 self.self_normalized = true;
62 self
63 }
64
65 /// Estimate E_p[f(X)] using importance sampling.
66 ///
67 /// # Arguments
68 ///
69 /// * `state` - RNG state
70 /// * `sample_q` - Function to sample from proposal distribution q
71 /// * `f` - Function to estimate expectation of
72 /// * `log_p` - Log of target density (unnormalized OK if self_normalized)
73 /// * `log_q` - Log of proposal density
74 ///
75 /// # Returns
76 ///
77 /// (estimate, effective_sample_size)
78 pub fn estimate<R: GpuRng, S, F, LP, LQ>(
79 &self,
80 state: &mut R::State,
81 sample_q: S,
82 f: F,
83 log_p: LP,
84 log_q: LQ,
85 ) -> (f32, f32)
86 where
87 S: Fn(&mut R::State) -> f32,
88 F: Fn(f32) -> f32,
89 LP: Fn(f32) -> f32,
90 LQ: Fn(f32) -> f32,
91 {
92 let mut weighted_sum = 0.0;
93 let mut weight_sum = 0.0;
94 let mut weight_sq_sum = 0.0;
95
96 for _ in 0..self.n_samples {
97 let x = sample_q(state);
98 let f_x = f(x);
99 let log_w = log_p(x) - log_q(x);
100 let w = log_w.exp();
101
102 weighted_sum += w * f_x;
103 weight_sum += w;
104 weight_sq_sum += w * w;
105 }
106
107 let estimate = if self.self_normalized {
108 if weight_sum.abs() < 1e-10 {
109 0.0
110 } else {
111 weighted_sum / weight_sum
112 }
113 } else {
114 weighted_sum / self.n_samples as f32
115 };
116
117 // Effective sample size: ESS = (sum w)² / (sum w²)
118 let ess = if weight_sq_sum.abs() < 1e-10 {
119 0.0
120 } else {
121 (weight_sum * weight_sum) / weight_sq_sum
122 };
123
124 (estimate, ess)
125 }
126}
127
128impl Default for ImportanceSampling {
129 fn default() -> Self {
130 Self::new(1000)
131 }
132}
133
134/// Exponential tilting proposal for rare event simulation.
135///
136/// For estimating P(X > a) where X ~ N(0,1), shift mean to a/2.
137#[allow(dead_code)]
138#[derive(Debug, Clone, Copy)]
139pub struct ExponentialTilt {
140 /// Tilt parameter (new mean).
141 pub theta: f32,
142}
143
144#[allow(dead_code)]
145impl ExponentialTilt {
146 /// Create exponential tilt for estimating P(X > a).
147 ///
148 /// Uses optimal tilt theta = a for normal random variables.
149 pub fn for_tail_probability(a: f32) -> Self {
150 Self { theta: a }
151 }
152
153 /// Sample from tilted distribution N(theta, 1).
154 pub fn sample<R: GpuRng>(&self, state: &mut R::State) -> f32 {
155 R::next_normal(state) + self.theta
156 }
157
158 /// Log ratio of p(x) / q(x) where p = N(0,1) and q = N(theta,1).
159 pub fn log_weight(&self, x: f32) -> f32 {
160 // log p(x) - log q(x) = -x²/2 - (-(x-θ)²/2)
161 // = -x²/2 + (x-θ)²/2
162 // = (-x² + x² - 2xθ + θ²) / 2
163 // = (-2xθ + θ²) / 2
164 // = θ² / 2 - xθ
165 0.5 * self.theta * self.theta - x * self.theta
166 }
167}
168
#[cfg(test)]
mod tests {
    use super::*;
    use crate::rng::PhiloxRng;

    #[test]
    fn test_importance_sample_basic() {
        // Equal densities should give f(x)
        let result = importance_sample(1.0, 2.0, 0.5, 0.5);
        assert!((result - 2.0).abs() < 1e-6);

        // Double target density should double the weight
        let result = importance_sample(1.0, 2.0, 1.0, 0.5);
        assert!((result - 4.0).abs() < 1e-6);

        // A zero proposal density contributes nothing instead of dividing by zero
        assert_eq!(importance_sample(1.0, 2.0, 1.0, 0.0), 0.0);
    }

    #[test]
    fn test_importance_sampling_uniform() {
        // Estimate E[X²] where X ~ U[0,2]
        // E[X²] = (1/2) * integral_0^2 x² dx = 4/3

        let mut state = PhiloxRng::seed(42, 0);
        let is = ImportanceSampling::new(5000).self_normalized();

        // When p = q, use the same log density (they cancel out)
        let (estimate, ess) = is.estimate::<PhiloxRng, _, _, _, _>(
            &mut state,
            |s| PhiloxRng::next_uniform(s) * 2.0, // Sample from U[0,2]
            |x| x * x,                             // f(x) = x²
            |_| -std::f32::consts::LN_2,           // log p = log(0.5) for U[0,2]
            |_| -std::f32::consts::LN_2,           // log q = log(0.5) for U[0,2]
        );

        // When p = q every weight is exactly 1, so ESS equals n
        assert!(
            (ess - 5000.0).abs() < 1.0,
            "ESS {} should equal n when p = q",
            ess
        );

        let true_value = 4.0 / 3.0;
        assert!(
            (estimate - true_value).abs() < 0.1,
            "Estimate {} far from {}",
            estimate,
            true_value
        );
    }

    #[test]
    fn test_exponential_tilt() {
        // Estimate P(Z > 3) where Z ~ N(0,1)
        // True value ≈ 0.00135
        let mut state = PhiloxRng::seed(42, 0);
        let tilt = ExponentialTilt::for_tail_probability(3.0);
        let is = ImportanceSampling::new(10000).self_normalized();

        let (estimate, ess) = is.estimate::<PhiloxRng, _, _, _, _>(
            &mut state,
            |s| tilt.sample::<PhiloxRng>(s),
            |x| if x > 3.0 { 1.0 } else { 0.0 }, // Indicator X > 3
            |x| -0.5 * x * x,                     // log N(0,1)
            |x| -0.5 * (x - 3.0) * (x - 3.0),     // log N(3,1)
        );

        // The absolute tolerance alone is wider than the true value, so also
        // require a positive estimate: about half the proposal samples exceed 3.
        assert!(estimate > 0.0, "Estimate {} should be positive", estimate);

        // Should be reasonably close to true value (wider tolerance for variance)
        let true_value = 0.00135; // 1 - Phi(3)
        assert!(
            (estimate - true_value).abs() < 0.005,
            "Estimate {} far from {}",
            estimate,
            true_value
        );

        // Rare-event weights are highly uneven, so ESS is far below n; only
        // require that it is not degenerate.
        assert!(ess > 50.0, "ESS {} too low", ess);
    }

    #[test]
    fn test_exponential_tilt_log_weight() {
        // log_weight must equal log N(0,1) - log N(theta,1); the normalizers cancel.
        let tilt = ExponentialTilt::for_tail_probability(3.0);
        for &x in &[-2.0_f32, 0.0, 1.5, 3.0, 4.5] {
            let direct = -0.5 * x * x + 0.5 * (x - 3.0) * (x - 3.0);
            assert!(
                (tilt.log_weight(x) - direct).abs() < 1e-4,
                "log_weight({}) = {} but expected {}",
                x,
                tilt.log_weight(x),
                direct
            );
        }
    }
}
246}