1use crate::{Duration, Instant};
2
3#[derive(Debug)]
19pub(crate) struct WorkLimiter {
20 mode: Mode,
22 cycle: u16,
24 start_time: Option<Instant>,
26 completed: usize,
28 allowed: usize,
30 desired_cycle_time: Duration,
32 smoothed_time_per_work_item_nanos: f64,
34}
35
36impl WorkLimiter {
37 pub(crate) fn new(desired_cycle_time: Duration) -> Self {
38 Self {
39 mode: Mode::Measure,
40 cycle: 0,
41 start_time: None,
42 completed: 0,
43 allowed: 0,
44 desired_cycle_time,
45 smoothed_time_per_work_item_nanos: 0.0,
46 }
47 }
48
49 pub(crate) fn start_cycle(&mut self, now: impl Fn() -> Instant) {
51 self.completed = 0;
52 if let Mode::Measure = self.mode {
53 self.start_time = Some(now());
54 }
55 }
56
57 pub(crate) fn allow_work(&mut self, now: impl Fn() -> Instant) -> bool {
61 match self.mode {
62 Mode::Measure => (now() - self.start_time.unwrap()) < self.desired_cycle_time,
63 Mode::HistoricData => self.completed < self.allowed,
64 }
65 }
66
67 pub(crate) fn record_work(&mut self, work: usize) {
71 self.completed += work;
72 }
73
74 pub(crate) fn finish_cycle(&mut self, now: impl Fn() -> Instant) {
82 if self.completed == 0 {
84 return;
85 }
86
87 if let Mode::Measure = self.mode {
88 let elapsed = now() - self.start_time.unwrap();
89
90 let time_per_work_item_nanos = (elapsed.as_nanos()) as f64 / self.completed as f64;
91
92 self.smoothed_time_per_work_item_nanos = if self.allowed == 0 {
95 time_per_work_item_nanos
97 } else {
98 (7.0 * self.smoothed_time_per_work_item_nanos + time_per_work_item_nanos) / 8.0
100 }
101 .max(1.0);
102
103 self.allowed = (((self.desired_cycle_time.as_nanos()) as f64
105 / self.smoothed_time_per_work_item_nanos) as usize)
106 .max(1);
107 self.start_time = None;
108 }
109
110 self.cycle = self.cycle.wrapping_add(1);
111 self.mode = match self.cycle % SAMPLING_INTERVAL {
112 0 => Mode::Measure,
113 _ => Mode::HistoricData,
114 };
115 }
116}
117
/// A measurement cycle happens once every `SAMPLING_INTERVAL` cycles.
///
/// Kept a power of two that divides 2^16, so the wrapping `u16` cycle counter
/// keeps an evenly spaced measurement cadence across wrap-around.
const SAMPLING_INTERVAL: u16 = 1 << 8;
120
/// How the limiter decides whether more work fits into the current cycle.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum Mode {
    /// Time the cycle against the wall clock and refresh the per-item estimate.
    Measure,
    /// Reuse the work-item limit derived from the most recent measurement.
    HistoricData,
}
126
#[cfg(test)]
mod tests {
    use super::*;
    use std::cell::Cell;

    /// Drives the limiter through:
    /// - one measured cycle;
    /// - several count-limited cycles;
    /// - the rest of the sampling interval;
    /// - a second measured cycle.
    ///
    /// It checks the derived work limit after each measurement.
    #[test]
    fn limit_work() {
        const CYCLE_TIME: Duration = Duration::from_millis(500);
        const BATCH_WORK_ITEMS: usize = 12;
        const BATCH_TIME: Duration = Duration::from_millis(100);

        // Batches that fit into one cycle, and the work items they add up to.
        const EXPECTED_INITIAL_BATCHES: usize =
            (CYCLE_TIME.as_nanos() / BATCH_TIME.as_nanos()) as usize;
        const EXPECTED_ALLOWED_WORK_ITEMS: usize = EXPECTED_INITIAL_BATCHES * BATCH_WORK_ITEMS;

        let mut limiter = WorkLimiter::new(CYCLE_TIME);
        reset_time();

        // Cycle 0 measures: it runs until the fake clock reports the cycle is used up.
        let measured_batches = run_timed_cycle(&mut limiter, BATCH_WORK_ITEMS, BATCH_TIME);
        assert_eq!(measured_batches, EXPECTED_INITIAL_BATCHES);
        assert_eq!(limiter.allowed, EXPECTED_ALLOWED_WORK_ITEMS);
        let initial_time_per_work_item = limiter.smoothed_time_per_work_item_nanos;

        // The next cycles are bounded by the item count alone; the clock never moves.
        const BATCH_SIZES: [usize; 4] = [1, 2, 3, 5];
        for &batch_size in BATCH_SIZES.iter() {
            limiter.start_cycle(get_time);
            let mut allowed_work = 0;
            while limiter.allow_work(get_time) {
                limiter.record_work(batch_size);
                allowed_work += batch_size;
            }
            limiter.finish_cycle(get_time);

            assert_eq!(allowed_work, EXPECTED_ALLOWED_WORK_ITEMS);
        }

        // Use up the rest of the sampling interval so the next cycle measures again.
        // Cycle 0 and the batch-size cycles above already count toward it.
        let remaining_cycles = SAMPLING_INTERVAL as usize - BATCH_SIZES.len() - 1;
        for _ in 0..remaining_cycles {
            limiter.start_cycle(get_time);
            limiter.record_work(1);
            limiter.finish_cycle(get_time);
        }

        // Second measurement: 8x the work per batch. The new sample is blended
        // into the previous estimate with weight 1/8.
        const BATCH_WORK_ITEMS_2: usize = 96;
        const TIME_PER_WORK_ITEMS_2_NANOS: f64 =
            CYCLE_TIME.as_nanos() as f64 / (EXPECTED_INITIAL_BATCHES * BATCH_WORK_ITEMS_2) as f64;

        let expected_updated_time_per_work_item =
            (initial_time_per_work_item * 7.0 + TIME_PER_WORK_ITEMS_2_NANOS) / 8.0;
        let expected_updated_allowed_work_items =
            (CYCLE_TIME.as_nanos() as f64 / expected_updated_time_per_work_item) as usize;

        let measured_batches = run_timed_cycle(&mut limiter, BATCH_WORK_ITEMS_2, BATCH_TIME);
        assert_eq!(measured_batches, EXPECTED_INITIAL_BATCHES);
        assert_eq!(limiter.allowed, expected_updated_allowed_work_items);
    }

    /// Runs one full cycle. Each batch records `items` work items and advances
    /// the fake clock by `batch_time`. Returns how many batches were allowed.
    fn run_timed_cycle(limiter: &mut WorkLimiter, items: usize, batch_time: Duration) -> usize {
        limiter.start_cycle(get_time);
        let mut batches = 0;
        while limiter.allow_work(get_time) {
            limiter.record_work(items);
            advance_time(batch_time);
            batches += 1;
        }
        limiter.finish_cycle(get_time);
        batches
    }

    thread_local! {
        /// Per-thread fake clock read by the limiter under test.
        static TIME: Cell<Instant> = Cell::new(Instant::now());
    }

    /// Resets the fake clock to the current real time.
    fn reset_time() {
        TIME.with(|clock| clock.set(Instant::now()));
    }

    /// Reads the fake clock.
    fn get_time() -> Instant {
        TIME.with(Cell::get)
    }

    /// Moves the fake clock forward by `duration`.
    fn advance_time(duration: Duration) {
        TIME.with(|clock| clock.set(clock.get() + duration));
    }
}