commonware_p2p/utils/requester/requester.rs

use super::{Config, PeerLabel};
use commonware_cryptography::PublicKey;
use commonware_runtime::{
    telemetry::metrics::status::{CounterExt, Status},
    Clock, Metrics,
};
use commonware_utils::PrioritySet;
use either::Either;
use governor::{
    clock::Clock as GClock, middleware::NoOpMiddleware, state::keyed::HashMapStateStore,
    RateLimiter,
};
use rand::{seq::SliceRandom, Rng};
use std::{
    collections::{HashMap, HashSet},
    time::{Duration, SystemTime},
};

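/// Unique identifier assigned to each request.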
pub type ID = u64;

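/// Issues rate-limited requests to participants and tracks them until they are
/// resolved, cancelled, timed out, or failed. Participants are prioritized by
/// their observed response times so that the fastest peers are tried first.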
pub struct Requester<E: Clock + GClock + Rng + Metrics, P: PublicKey> {
    context: E,
    /// Our own public key (never selected as a target), if known.
    me: Option<P>,
    metrics: super::Metrics,
    /// Performance initially assumed for new participants.
    initial: Duration,
    /// Duration after which an outstanding request is considered expired.
    timeout: Duration,

    /// Participants that should never be sent requests.
    excluded: HashSet<P>,

    /// Per-participant rate limiter.
    #[allow(clippy::type_complexity)]
    rate_limiter: RateLimiter<P, HashMapStateStore<P>, E, NoOpMiddleware<E::Instant>>,
    /// Participants ordered by expected response time (in milliseconds).
    participants: PrioritySet<P, u128>,

    /// Next request ID to assign.
    id: ID,
    /// Outstanding requests: ID -> (participant, start time).
    requests: HashMap<ID, (P, SystemTime)>,
    /// Outstanding requests ordered by deadline (soonest first).
    deadlines: PrioritySet<ID, SystemTime>,
}

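/// An outstanding request issued by a [`Requester`].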
pub struct Request<P: PublicKey> {
    /// Identifier assigned to the request.
    pub id: ID,

    /// Participant the request was issued to.
    participant: P,

    /// Time at which the request was created.
    start: SystemTime,
}

impl<E: Clock + GClock + Rng + Metrics, P: PublicKey> Requester<E, P> {
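    /// Create a new requester with the given configuration.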
    pub fn new(context: E, config: Config<P>) -> Self {
        let rate_limiter = RateLimiter::hashmap_with_clock(config.rate_limit, &context);

        let metrics = super::Metrics::init(context.clone());
        Self {
            context,
            me: config.me,
            metrics,
            initial: config.initial,
            timeout: config.timeout,

            excluded: HashSet::new(),

            rate_limiter,
            participants: PrioritySet::new(),

            id: 0,
            requests: HashMap::new(),
            deadlines: PrioritySet::new(),
        }
    }

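    /// Reconcile the set of participants that can be sent requests, using the
    /// configured initial performance as the default priority for new entries.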
    pub fn reconcile(&mut self, participants: &[P]) {
        self.participants
            .reconcile(participants, self.initial.as_millis());
        self.rate_limiter.shrink_to_fit();
    }

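    /// Prevent a participant from being sent any future requests.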
    pub fn block(&mut self, participant: P) {
        self.excluded.insert(participant);
    }

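    /// Select a participant to send a request to, returning the participant and
    /// a new request ID.
    ///
    /// Participants are considered in order of expected performance (or in random
    /// order if `shuffle` is set). Ourselves, blocked participants, and rate-limited
    /// participants are skipped. Returns `None` if no participant is available.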
    pub fn request(&mut self, shuffle: bool) -> Option<(P, ID)> {
        // Iterate over participants in priority order (or shuffled, if requested)
        let participant_iter = if shuffle {
            let mut participants = self.participants.iter().collect::<Vec<_>>();
            participants.shuffle(&mut self.context);
            Either::Left(participants.into_iter())
        } else {
            Either::Right(self.participants.iter())
        };

        for (participant, _) in participant_iter {
            // Skip ourselves
            if Some(participant) == self.me.as_ref() {
                continue;
            }

            // Skip blocked participants
            if self.excluded.contains(participant) {
                continue;
            }

            // Skip rate-limited participants
            if self.rate_limiter.check_key(participant).is_err() {
                continue;
            }

            // Assign an ID to the request
            let id = self.id;
            self.id = self.id.wrapping_add(1);

            // Record the outstanding request and its deadline
            let now = self.context.current();
            self.requests.insert(id, (participant.clone(), now));
            let deadline = now.checked_add(self.timeout).expect("time overflowed");
            self.deadlines.put(id, deadline);

            // Return the selected participant
            self.metrics.created.inc(Status::Success);
            return Some((participant.clone(), id));
        }

        // No participant was available
        self.metrics.created.inc(Status::Failure);
        None
    }

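    /// Update a participant's expected performance with a new observation.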
    fn update(&mut self, participant: P, elapsed: Duration) {
        // Ignore participants that are no longer eligible
        let Some(past) = self.participants.get(&participant) else {
            return;
        };

        // Average the previous estimate with the new observation and record it
        let next = past.saturating_add(elapsed.as_millis()) / 2;
        self.metrics
            .performance
            .get_or_create(&PeerLabel::from(&participant))
            .set(next as i64);
        self.participants.put(participant, next);
    }

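    /// Remove an outstanding request by ID, returning it if it exists.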
    pub fn cancel(&mut self, id: ID) -> Option<Request<P>> {
        let (participant, start) = self.requests.remove(&id)?;
        self.deadlines.remove(&id);
        Some(Request {
            id,
            participant,
            start,
        })
    }

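    /// Handle a response for a request, returning the outstanding request if the
    /// ID exists and was issued to the given participant.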
    pub fn handle(&mut self, participant: &P, id: ID) -> Option<Request<P>> {
        // Confirm the request was issued to this participant
        let (expected, _) = self.requests.get(&id)?;
        if expected != participant {
            return None;
        }

        // Remove the outstanding request
        self.cancel(id)
    }

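    /// Mark a request as successfully resolved, crediting the participant with
    /// the observed response time.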
    pub fn resolve(&mut self, request: Request<P>) {
        // Compute the elapsed time, defaulting to zero if the clock went backwards
        let elapsed = self
            .context
            .current()
            .duration_since(request.start)
            .unwrap_or_default();

        // Update the participant's expected performance and record metrics
        self.update(request.participant, elapsed);
        self.metrics.requests.inc(Status::Success);
        self.metrics.resolves.observe(elapsed.as_secs_f64());
    }

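    /// Mark a request as timed out, penalizing the participant with the full
    /// timeout duration.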
    pub fn timeout(&mut self, request: Request<P>) {
        self.update(request.participant, self.timeout);
        self.metrics.requests.inc(Status::Timeout);
    }

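    /// Mark a request as failed, penalizing the participant with the full
    /// timeout duration.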
    pub fn fail(&mut self, request: Request<P>) {
        self.update(request.participant, self.timeout);
        self.metrics.requests.inc(Status::Failure);
    }

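    /// Return the ID and deadline of the outstanding request that expires soonest,
    /// if any.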
    pub fn next(&self) -> Option<(ID, SystemTime)> {
        let (id, deadline) = self.deadlines.peek()?;
        Some((*id, *deadline))
    }

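    /// Return the number of outstanding requests.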
    #[allow(clippy::len_without_is_empty)]
    pub fn len(&self) -> usize {
        self.requests.len()
    }

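    /// Return the number of blocked participants.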
    pub fn len_blocked(&self) -> usize {
        self.excluded.len()
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use commonware_cryptography::{ed25519::PrivateKey, PrivateKeyExt as _, Signer as _};
    use commonware_runtime::{deterministic, Runner};
    use commonware_utils::NZU32;
    use governor::Quota;
    use std::time::Duration;

    #[test]
    fn test_requester_basic() {
        // Create a deterministic runtime
        let executor = deterministic::Runner::seeded(0);
        executor.start(|context| async move {
            // Create a requester
            let scheme = PrivateKey::from_seed(0);
            let me = scheme.public_key();
            let timeout = Duration::from_secs(5);
            let config = Config {
                me: Some(scheme.public_key()),
                rate_limit: Quota::per_second(NZU32!(1)),
                initial: Duration::from_millis(100),
                timeout,
            };
            let mut requester = Requester::new(context.clone(), config);

            // With no participants, no request can be issued
            assert_eq!(requester.request(false), None);
            assert_eq!(requester.len(), 0);

            // No deadlines are outstanding
            assert_eq!(requester.next(), None);

            // Handling an unknown ID does nothing
            assert!(requester.handle(&me, 0).is_none());

            // Add participants (including ourselves)
            let other = PrivateKey::from_seed(1).public_key();
            requester.reconcile(&[me.clone(), other.clone()]);

            // The request is issued to the other participant (never to ourselves)
            let current = context.current();
            let (participant, id) = requester.request(false).expect("failed to get participant");
            assert_eq!(id, 0);
            assert_eq!(participant, other);

            // The deadline reflects the configured timeout
            let (id, deadline) = requester.next().expect("failed to get deadline");
            assert_eq!(id, 0);
            assert_eq!(deadline, current + timeout);
            assert_eq!(requester.len(), 1);

            // The other participant is now rate limited
            assert_eq!(requester.request(false), None);

            // Wait a little before handling the response
            context.sleep(Duration::from_millis(10)).await;

            // Handling with the wrong participant fails
            assert!(requester.handle(&me, id).is_none());

            // Handling with the right participant returns the request
            let request = requester
                .handle(&participant, id)
                .expect("failed to get request");
            assert_eq!(request.id, id);
            requester.resolve(request);

            // Still rate limited
            assert_eq!(requester.request(false), None);

            assert_eq!(requester.request(false), None);

            // Wait for the rate limiter to refresh
            context.sleep(Duration::from_secs(1)).await;

            // A new request can be issued
            let (participant, id) = requester.request(false).expect("failed to get participant");
            assert_eq!(participant, other);
            assert_eq!(id, 1);

            // Time the request out
            let request = requester
                .handle(&participant, id)
                .expect("failed to get request");
            requester.timeout(request);

            // Rate limited again
            assert_eq!(requester.request(false), None);

            // Wait for the rate limiter to refresh
            context.sleep(Duration::from_secs(1)).await;

            // Issue another request and cancel it
            let (participant, id) = requester.request(false).expect("failed to get participant");
            assert_eq!(participant, other);
            assert_eq!(id, 2);

            assert!(requester.cancel(id).is_some());

            // Nothing is outstanding after the cancel
            assert_eq!(requester.next(), None);
            assert_eq!(requester.len(), 0);

            // Wait for the rate limiter to refresh
            context.sleep(Duration::from_secs(1)).await;

            // Block the only other participant; no request can be issued
            requester.block(other);

            assert_eq!(requester.request(false), None);
        });
    }

    #[test]
    fn test_requester_multiple() {
        // Create a deterministic runtime
        let executor = deterministic::Runner::seeded(0);
        executor.start(|context| async move {
            // Create a requester
            let scheme = PrivateKey::from_seed(0);
            let me = scheme.public_key();
            let timeout = Duration::from_secs(5);
            let config = Config {
                me: Some(scheme.public_key()),
                rate_limit: Quota::per_second(NZU32!(1)),
                initial: Duration::from_millis(100),
                timeout,
            };
            let mut requester = Requester::new(context.clone(), config);

            // With no participants, no request can be issued
            assert_eq!(requester.request(false), None);

            assert_eq!(requester.next(), None);

            // Add participants
            let other1 = PrivateKey::from_seed(1).public_key();
            let other2 = PrivateKey::from_seed(2).public_key();
            requester.reconcile(&[me.clone(), other1.clone(), other2.clone()]);

            // The first request goes to other1; let it time out
            let (participant, id) = requester.request(false).expect("failed to get participant");
            assert_eq!(id, 0);
            if participant == other1 {
                let request = requester
                    .handle(&participant, id)
                    .expect("failed to get request");
                requester.timeout(request);
            } else {
                panic!("unexpected participant");
            }

            // The next request goes to other2; resolve it quickly
            let (participant, id) = requester.request(false).expect("failed to get participant");
            assert_eq!(id, 1);
            if participant == other2 {
                context.sleep(Duration::from_millis(10)).await;
                let request = requester
                    .handle(&participant, id)
                    .expect("failed to get request");
                requester.resolve(request);
            } else {
                panic!("unexpected participant");
            }

            // Both participants are now rate limited
            assert_eq!(requester.request(false), None);

            // Wait for the rate limiter to refresh
            context.sleep(Duration::from_secs(1)).await;

            // other2 now has the best expected performance
            let (participant, id) = requester.request(false).expect("failed to get participant");
            assert_eq!(participant, other2);
            assert_eq!(id, 2);

            assert!(requester.cancel(id).is_some());

            // Add a new participant
            let other3 = PrivateKey::from_seed(3).public_key();
            requester.reconcile(&[me, other1, other2.clone(), other3.clone()]);

            // The new participant is selected next (other2 is still rate limited)
            let (participant, id) = requester.request(false).expect("failed to get participant");
            assert_eq!(participant, other3);
            assert_eq!(id, 3);

            // With shuffle enabled, keep requesting until other2 is selected
            context.sleep(Duration::from_secs(1)).await;
            loop {
                let (participant, _) = requester.request(true).unwrap();
                if participant == other2 {
                    break;
                }

                // Wait for the rate limiter to refresh before trying again
                context.sleep(Duration::from_secs(1)).await;
            }
        });
    }
}