Skip to main content

lmn_core/
sampling.rs

1use rand::Rng;
2
3// ── SamplingParams ────────────────────────────────────────────────────────────
4
/// Tuning knobs for the two-stage sampling mechanism: a virtual-user (VU)
/// threshold gate followed by a bounded reservoir.
///
/// The type is a plain pair of integers, so it is cheap to copy, compare and
/// print in logs or assertions.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct SamplingParams {
    /// VU count at or below which every result is collected.
    /// Set to `0` to disable VU-threshold sampling entirely (always rate 1.0).
    /// Default: 50.
    pub vu_threshold: usize,
    /// Maximum number of results to retain in the buffer.
    /// Default: 100_000.
    pub reservoir_size: usize,
}
15
16impl Default for SamplingParams {
17    fn default() -> Self {
18        Self {
19            vu_threshold: 50,
20            reservoir_size: 100_000,
21        }
22    }
23}
24
25// ── ReservoirAction ───────────────────────────────────────────────────────────
26
/// Instruction returned by `SamplingState::reservoir_slot`, telling the caller
/// what to do with the result it is currently holding.
///
/// The enum is a small value type, so it is copyable, comparable and
/// printable; callers and tests can use `assert_eq!` / `{:?}` on it directly.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ReservoirAction {
    /// Append the result to the end of the results buffer.
    Push,
    /// Replace the result at the given index in the results buffer.
    Replace(usize),
    /// Drop the result; the buffer is full and this slot lost the lottery.
    Discard,
}
36
37// ── SamplingState ─────────────────────────────────────────────────────────────
38
39/// Runtime sampling state. Tracks counters and drives both the VU-threshold
40/// gate and Vitter's Algorithm R reservoir gate.
41pub struct SamplingState {
42    vu_threshold: usize,
43    reservoir_size: usize,
44    sample_rate: f64,
45    min_sample_rate: f64,
46    /// Actual (unsampled) total request count — always incremented.
47    total_requests: usize,
48    /// Actual (unsampled) failure count — always incremented.
49    total_failures: usize,
50    /// Denominator for the reservoir replacement lottery (Vitter's Algorithm R).
51    total_seen_for_reservoir: usize,
52    rng: rand::rngs::ThreadRng,
53}
54
55impl SamplingState {
56    pub fn new(params: SamplingParams) -> Self {
57        Self {
58            vu_threshold: params.vu_threshold,
59            reservoir_size: params.reservoir_size,
60            sample_rate: 1.0,
61            min_sample_rate: 1.0,
62            total_requests: 0,
63            total_failures: 0,
64            total_seen_for_reservoir: 0,
65            rng: rand::rng(),
66        }
67    }
68
69    /// Call on each coordinator tick when the active VU count may have changed.
70    ///
71    /// When `vus > vu_threshold` (and `vu_threshold != 0`), caps the collection
72    /// rate to `threshold / vus`. Otherwise rate is 1.0 (collect everything).
73    pub fn set_active_vus(&mut self, vus: usize) {
74        self.sample_rate = if self.vu_threshold == 0 || vus <= self.vu_threshold {
75            1.0
76        } else {
77            self.vu_threshold as f64 / vus as f64
78        };
79        self.min_sample_rate = self.min_sample_rate.min(self.sample_rate);
80    }
81
82    /// Always call for every completed request. Updates the unsampled counters
83    /// regardless of whether this result will be stored in the reservoir.
84    pub fn record_request(&mut self, success: bool) {
85        self.total_requests += 1;
86        if !success {
87            self.total_failures += 1;
88        }
89    }
90
91    /// VU-threshold gate: returns `true` if this result should proceed toward
92    /// the reservoir. At `sample_rate >= 1.0` always returns `true`.
93    pub fn should_collect(&mut self) -> bool {
94        self.sample_rate >= 1.0 || self.rng.random::<f64>() < self.sample_rate
95    }
96
97    /// Reservoir gate (Vitter's Algorithm R). Call only when `should_collect()`
98    /// returned `true`.
99    ///
100    /// Increments the internal seen-counter and returns the storage instruction:
101    /// - `Push` — buffer not yet full; append.
102    /// - `Replace(j)` — buffer full; replace slot `j` (uniform random).
103    /// - `Discard` — buffer full; this result lost the lottery.
104    pub fn reservoir_slot(&mut self, results_len: usize) -> ReservoirAction {
105        self.total_seen_for_reservoir += 1;
106        if results_len < self.reservoir_size {
107            ReservoirAction::Push
108        } else {
109            let j = self.rng.random_range(0..self.total_seen_for_reservoir);
110            if j < self.reservoir_size {
111                ReservoirAction::Replace(j)
112            } else {
113                ReservoirAction::Discard
114            }
115        }
116    }
117
118    // ── Accessors ─────────────────────────────────────────────────────────────
119
120    pub fn total_requests(&self) -> usize {
121        self.total_requests
122    }
123
124    pub fn total_failures(&self) -> usize {
125        self.total_failures
126    }
127
128    pub fn sample_rate(&self) -> f64 {
129        self.sample_rate
130    }
131
132    pub fn min_sample_rate(&self) -> f64 {
133        self.min_sample_rate
134    }
135}
136
137// ── Tests ─────────────────────────────────────────────────────────────────────
138
#[cfg(test)]
mod tests {
    use super::*;

    /// State built from `SamplingParams::default()`.
    fn default_state() -> SamplingState {
        SamplingState::new(SamplingParams::default())
    }

    /// State with the VU gate disabled and the given reservoir capacity.
    fn reservoir_state(reservoir_size: usize) -> SamplingState {
        SamplingState::new(SamplingParams {
            vu_threshold: 0,
            reservoir_size,
        })
    }

    /// Offers `count` results to an initially empty reservoir, reporting the
    /// buffer length as it would grow. The returned actions are not needed.
    fn fill_reservoir(s: &mut SamplingState, count: usize) {
        for len in 0..count {
            let _ = s.reservoir_slot(len);
        }
    }

    // ── set_active_vus ────────────────────────────────────────────────────────

    #[test]
    fn rate_is_1_below_threshold() {
        let mut s = default_state();
        s.set_active_vus(49);
        assert_eq!(s.sample_rate(), 1.0);
    }

    #[test]
    fn rate_is_1_at_threshold() {
        let mut s = default_state();
        s.set_active_vus(50);
        assert_eq!(s.sample_rate(), 1.0);
    }

    #[test]
    fn rate_drops_above_threshold() {
        let mut s = default_state();
        s.set_active_vus(100);
        assert!((s.sample_rate() - 0.5).abs() < f64::EPSILON);
    }

    #[test]
    fn rate_scales_proportionally() {
        let mut s = SamplingState::new(SamplingParams {
            vu_threshold: 50,
            reservoir_size: 100_000,
        });
        s.set_active_vus(200);
        assert!((s.sample_rate() - 0.25).abs() < f64::EPSILON);
    }

    #[test]
    fn zero_threshold_always_collects() {
        let mut s = reservoir_state(100_000);
        s.set_active_vus(10_000);
        assert_eq!(s.sample_rate(), 1.0);
        // should_collect must always be true when rate == 1.0
        for _ in 0..100 {
            assert!(s.should_collect());
        }
    }

    #[test]
    fn min_sample_rate_tracks_lowest_observed() {
        let mut s = default_state();
        s.set_active_vus(100); // rate = 0.5
        s.set_active_vus(200); // rate = 0.25
        s.set_active_vus(50); // rate = 1.0 — min must not increase
        assert!((s.min_sample_rate() - 0.25).abs() < f64::EPSILON);
    }

    #[test]
    fn min_sample_rate_starts_at_1() {
        let s = default_state();
        assert_eq!(s.min_sample_rate(), 1.0);
    }

    // ── record_request ────────────────────────────────────────────────────────

    #[test]
    fn record_request_increments_total() {
        let mut s = default_state();
        s.record_request(true);
        s.record_request(true);
        assert_eq!(s.total_requests(), 2);
    }

    #[test]
    fn record_request_tracks_failures() {
        let mut s = default_state();
        s.record_request(true);
        s.record_request(false);
        s.record_request(false);
        assert_eq!(s.total_requests(), 3);
        assert_eq!(s.total_failures(), 2);
    }

    #[test]
    fn record_request_success_does_not_increment_failures() {
        let mut s = default_state();
        s.record_request(true);
        assert_eq!(s.total_failures(), 0);
    }

    // ── should_collect ────────────────────────────────────────────────────────

    #[test]
    fn should_collect_always_true_at_full_rate() {
        let mut s = default_state();
        s.set_active_vus(10); // rate = 1.0
        for _ in 0..1000 {
            assert!(s.should_collect());
        }
    }

    #[test]
    fn should_collect_probabilistic_at_half_rate() {
        let mut s = default_state();
        s.set_active_vus(100); // rate = 0.5
        let collected: usize = (0..10_000).filter(|_| s.should_collect()).count();
        // Expect ~5000; the bounds below allow ±20%, far wider than the
        // statistical spread of 10 000 fair draws, so the test is not flaky.
        assert!(
            collected > 4_000 && collected < 6_000,
            "expected ~5000 collected, got {collected}"
        );
    }

    // ── reservoir_slot ────────────────────────────────────────────────────────

    #[test]
    fn reservoir_pushes_while_not_full() {
        let mut s = reservoir_state(5);
        for i in 0..5 {
            assert!(
                matches!(s.reservoir_slot(i), ReservoirAction::Push),
                "expected Push at results_len={i}"
            );
        }
    }

    #[test]
    fn reservoir_never_pushes_when_full() {
        let mut s = reservoir_state(5);
        // Bring total_seen_for_reservoir up to 5 (all Pushed).
        fill_reservoir(&mut s, 5);
        // Now results_len == reservoir_size == 5; must not Push.
        for _ in 0..100 {
            assert!(
                !matches!(s.reservoir_slot(5), ReservoirAction::Push),
                "Push when reservoir is full"
            );
        }
    }

    #[test]
    fn reservoir_replace_index_is_in_bounds() {
        let mut s = reservoir_state(5);
        // Fill reservoir first.
        fill_reservoir(&mut s, 5);
        for _ in 0..200 {
            if let ReservoirAction::Replace(idx) = s.reservoir_slot(5) {
                assert!(
                    idx < 5,
                    "Replace index {idx} out of bounds for reservoir_size=5"
                );
            }
        }
    }

    #[test]
    fn reservoir_discard_rate_decreases_over_time() {
        // Despite the name, what is checked here is that discards dominate
        // once total_seen is large relative to reservoir_size: the chance of
        // a replacement shrinks as more results are offered.
        let mut s = reservoir_state(10);
        // Fill reservoir.
        fill_reservoir(&mut s, 10);
        // Offer 1000 more results, so the seen-counter runs from 11 to 1010.
        // Each offer is kept with probability reservoir_size / seen, falling
        // from 10/11 to 10/1010; summed over the loop that is roughly 46
        // expected replacements against roughly 954 discards.
        let mut replaces = 0usize;
        let mut discards = 0usize;
        for _ in 0..1000 {
            match s.reservoir_slot(10) {
                ReservoirAction::Replace(_) => replaces += 1,
                ReservoirAction::Discard => discards += 1,
                ReservoirAction::Push => panic!("unexpected Push"),
            }
        }
        assert!(
            discards > replaces,
            "expected more discards than replaces at high total_seen; replaces={replaces}, discards={discards}"
        );
    }

    // ── sampling reflects history ─────────────────────────────────────────────

    #[test]
    fn is_sampling_reflects_history() {
        // min_sample_rate stays < 1.0 even if VUs later drop back below threshold.
        let mut s = default_state();
        s.set_active_vus(10); // rate = 1.0
        assert_eq!(s.min_sample_rate(), 1.0);

        s.set_active_vus(100); // rate = 0.5
        assert!((s.min_sample_rate() - 0.5).abs() < f64::EPSILON);

        s.set_active_vus(10); // rate = 1.0 again — min must not reset
        assert!((s.min_sample_rate() - 0.5).abs() < f64::EPSILON);
    }
}