1use futures::{Stream, StreamExt, stream::Fuse};
4use hashbrown::{HashMap, hash_map::RawEntryMut};
5use pin_project::pin_project;
6use std::{
7 collections::HashSet,
8 hash::Hash,
9 pin::Pin,
10 task::{Context, Poll},
11 time::Duration,
12};
13use tokio::time::Instant;
14use tokio_util::time::delay_queue::{self, DelayQueue};
15
/// A request for the [`Scheduler`] to emit `message` at `run_at`.
///
/// Requests whose messages compare equal (via `Eq`/`Hash`) are deduplicated by the scheduler.
#[derive(Debug)]
pub struct ScheduleRequest<T> {
    /// The message to emit.
    pub message: T,
    /// The requested emission time, before any debounce configured on the scheduler is added.
    pub run_at: Instant,
}
22
/// Bookkeeping for a message that is currently waiting in the scheduler's delay queue.
struct ScheduledEntry {
    /// The deadline the message is currently queued for (debounce already applied).
    run_at: Instant,
    /// Handle used to move this message's deadline in the [`DelayQueue`].
    queue_key: delay_queue::Key,
}
28
/// A stream that emits messages once their scheduled time has been reached, deduplicating
/// requests for messages that compare equal.
#[pin_project(project = SchedulerProj)]
pub struct Scheduler<T, R> {
    /// Messages waiting for their deadline. Every message here is expected to have a matching
    /// entry in `scheduled`.
    queue: DelayQueue<T>,
    /// Metadata for each message in `queue`, keyed by the message. The key may be a newer (but
    /// equal) version of the message than the copy held by `queue`.
    scheduled: HashMap<T, ScheduledEntry>,
    /// Messages whose deadline has passed but which have not been taken by the consumer yet.
    /// New requests for a message in this set are ignored.
    pending: HashSet<T>,
    /// The incoming stream of requests; fused so it is safe to poll again after it has ended.
    #[pin]
    requests: Fuse<R>,
    /// Extra delay added to every request's `run_at`.
    debounce: Duration,
}
55
56impl<T, R: Stream> Scheduler<T, R> {
57 fn new(requests: R, debounce: Duration) -> Self {
58 Self {
59 queue: DelayQueue::new(),
60 scheduled: HashMap::new(),
61 pending: HashSet::new(),
62 requests: requests.fuse(),
63 debounce,
64 }
65 }
66}
67
68impl<T: Hash + Eq + Clone, R> SchedulerProj<'_, T, R> {
69 fn schedule_message(&mut self, request: ScheduleRequest<T>) {
73 if self.pending.contains(&request.message) {
74 return;
76 }
77 let next_time = request
78 .run_at
79 .checked_add(*self.debounce)
80 .map_or_else(max_schedule_time, |time|
81 time.min(max_schedule_time()));
83 match self.scheduled.raw_entry_mut().from_key(&request.message) {
84 RawEntryMut::Occupied(mut old_entry) if old_entry.get().run_at >= request.run_at => {
88 let entry = old_entry.get_mut();
90 self.queue.reset_at(&entry.queue_key, next_time);
91 entry.run_at = next_time;
92 old_entry.insert_key(request.message);
93 }
94 RawEntryMut::Occupied(_old_entry) => {
95 }
97 RawEntryMut::Vacant(entry) => {
98 let message = request.message.clone();
100 entry.insert(request.message, ScheduledEntry {
101 run_at: next_time,
102 queue_key: self.queue.insert_at(message, next_time),
103 });
104 }
105 }
106 }
107
108 fn poll_pop_queue_message(
110 &mut self,
111 cx: &mut Context<'_>,
112 can_take_message: impl Fn(&T) -> bool,
113 ) -> Poll<T> {
114 if let Some(msg) = self.pending.iter().find(|msg| can_take_message(*msg)).cloned() {
115 return Poll::Ready(self.pending.take(&msg).unwrap());
116 }
117
118 loop {
119 match self.queue.poll_expired(cx) {
120 Poll::Ready(Some(msg)) => {
121 let msg = msg.into_inner();
122 let (msg, _) = self.scheduled.remove_entry(&msg).expect(
123 "Expired message was popped from the Scheduler queue, but was not in the metadata map",
124 );
125 if can_take_message(&msg) {
126 break Poll::Ready(msg);
127 }
128 self.pending.insert(msg);
129 }
130 Poll::Ready(None) | Poll::Pending => break Poll::Pending,
131 }
132 }
133 }
134
135 pub fn pop_queue_message_into_pending(&mut self, cx: &mut Context<'_>) {
137 while let Poll::Ready(Some(msg)) = self.queue.poll_expired(cx) {
138 let msg = msg.into_inner();
139 self.scheduled.remove_entry(&msg).expect(
140 "Expired message was popped from the Scheduler queue, but was not in the metadata map",
141 );
142 self.pending.insert(msg);
143 }
144 }
145}
146
/// A view of a [`Scheduler`] that keeps consuming requests but never emits messages.
///
/// Expired messages are moved into the scheduler's pending set instead. Created by
/// [`Scheduler::hold`].
pub struct Hold<'a, T, R> {
    scheduler: Pin<&'a mut Scheduler<T, R>>,
}
151
152impl<T, R> Stream for Hold<'_, T, R>
153where
154 T: Eq + Hash + Clone,
155 R: Stream<Item = ScheduleRequest<T>>,
156{
157 type Item = T;
158
159 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
160 let this = self.get_mut();
161 let mut scheduler = this.scheduler.as_mut().project();
162
163 loop {
164 match scheduler.requests.as_mut().poll_next(cx) {
165 Poll::Ready(Some(request)) => scheduler.schedule_message(request),
166 Poll::Ready(None) => return Poll::Ready(None),
167 Poll::Pending => break,
168 }
169 }
170
171 scheduler.pop_queue_message_into_pending(cx);
172 Poll::Pending
173 }
174}
175
/// A view of a [`Scheduler`] that only emits messages for which `can_take_message` returns
/// `true`.
///
/// Other ready messages are kept in the scheduler's pending set. Created by
/// [`Scheduler::hold_unless`].
pub struct HoldUnless<'a, T, R, C> {
    scheduler: Pin<&'a mut Scheduler<T, R>>,
    can_take_message: C,
}
181
182impl<T, R, C> Stream for HoldUnless<'_, T, R, C>
183where
184 T: Eq + Hash + Clone,
185 R: Stream<Item = ScheduleRequest<T>>,
186 C: Fn(&T) -> bool + Unpin,
187{
188 type Item = T;
189
190 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
191 let this = self.get_mut();
192 let can_take_message = &this.can_take_message;
193 let mut scheduler = this.scheduler.as_mut().project();
194
195 loop {
196 match scheduler.requests.as_mut().poll_next(cx) {
197 Poll::Ready(Some(request)) => scheduler.schedule_message(request),
198 Poll::Ready(None) => return Poll::Ready(None),
199 Poll::Pending => break,
200 }
201 }
202
203 match scheduler.poll_pop_queue_message(cx, can_take_message) {
204 Poll::Ready(expired) => Poll::Ready(Some(expired)),
205 Poll::Pending => Poll::Pending,
206 }
207 }
208}
209
210impl<T, R> Scheduler<T, R>
211where
212 T: Eq + Hash + Clone,
213 R: Stream<Item = ScheduleRequest<T>>,
214{
215 pub fn hold_unless<C: Fn(&T) -> bool>(
226 self: Pin<&'_ mut Self>,
227 can_take_message: C,
228 ) -> HoldUnless<'_, T, R, C> {
229 HoldUnless {
230 scheduler: self,
231 can_take_message,
232 }
233 }
234
235 #[must_use]
239 pub fn hold(self: Pin<&'_ mut Self>) -> Hold<'_, T, R> {
240 Hold { scheduler: self }
241 }
242
243 #[cfg(test)]
245 pub fn contains_pending(&self, msg: &T) -> bool {
246 self.pending.contains(msg)
247 }
248}
249
250impl<T, R> Stream for Scheduler<T, R>
251where
252 T: Eq + Hash + Clone,
253 R: Stream<Item = ScheduleRequest<T>>,
254{
255 type Item = T;
256
257 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
258 Pin::new(&mut self.hold_unless(|_| true)).poll_next(cx)
259 }
260}
261
262pub fn scheduler<T: Eq + Hash + Clone, S: Stream<Item = ScheduleRequest<T>>>(requests: S) -> Scheduler<T, S> {
273 Scheduler::new(requests, Duration::ZERO)
274}
275
276#[allow(clippy::module_name_repetitions)]
284pub fn debounced_scheduler<T: Eq + Hash + Clone, S: Stream<Item = ScheduleRequest<T>>>(
285 requests: S,
286 debounce: Duration,
287) -> Scheduler<T, S> {
288 Scheduler::new(requests, debounce)
289}
290
291pub(crate) fn max_schedule_time() -> Instant {
295 Instant::now() + Duration::from_secs(86400 * 30 * 6)
296}
297
// Tests for the scheduler's timing, deduplication, holding and debounce behavior.
// Most tests call `pause()` so that tokio time only moves via `advance`/`sleep`.
#[cfg(test)]
mod tests {
    use crate::utils::KubeRuntimeStreamExt;

    use super::{ScheduleRequest, debounced_scheduler, scheduler};
    use educe::Educe;
    use futures::{FutureExt, SinkExt, StreamExt, channel::mpsc, future, poll, stream};
    use std::{pin::pin, task::Poll};
    use tokio::time::{Duration, Instant, advance, pause, sleep};

    /// Extracts the value of a `Poll::Ready`, panicking if the poll is still pending.
    fn unwrap_poll<T>(poll: Poll<T>) -> T {
        if let Poll::Ready(x) = poll {
            x
        } else {
            panic!("Tried to unwrap a pending poll!")
        }
    }

    /// Test message whose payload is ignored by `PartialEq` and `Hash`, so all instances
    /// compare equal. This lets tests observe which version of a deduplicated message is
    /// emitted.
    #[derive(Educe, Eq, Clone, Debug)]
    #[educe(PartialEq, Hash)]
    struct SingletonMessage(#[educe(PartialEq(ignore), Hash(ignore))] u8);

    // A ready message rejected by the predicate is parked in `pending`, and is emitted once a
    // later view accepts it.
    // NOTE(review): `on_complete(sleep(..))` presumably keeps the request stream open until the
    // sleep elapses — confirm against `KubeRuntimeStreamExt`.
    #[tokio::test]
    async fn scheduler_should_hold_and_release_items() {
        pause();
        let mut scheduler = Box::pin(scheduler(
            stream::iter(vec![ScheduleRequest {
                message: 1_u8,
                run_at: Instant::now(),
            }])
            .on_complete(sleep(Duration::from_secs(4))),
        ));
        assert!(!scheduler.contains_pending(&1));
        assert!(poll!(scheduler.as_mut().hold_unless(|_| false).next()).is_pending());
        assert!(scheduler.contains_pending(&1));
        assert_eq!(
            unwrap_poll(poll!(scheduler.as_mut().hold_unless(|_| true).next())).unwrap(),
            1_u8
        );
        assert!(!scheduler.contains_pending(&1));
        assert!(scheduler.as_mut().hold_unless(|_| true).next().await.is_none());
    }

    // A request for a message that is already pending is ignored, so the message is emitted
    // only once.
    #[tokio::test]
    async fn scheduler_should_not_reschedule_pending_items() {
        pause();
        let (mut tx, rx) = mpsc::unbounded::<ScheduleRequest<u8>>();
        let mut scheduler = Box::pin(scheduler(rx));
        tx.send(ScheduleRequest {
            message: 1,
            run_at: Instant::now(),
        })
        .await
        .unwrap();
        assert!(poll!(scheduler.as_mut().hold_unless(|_| false).next()).is_pending());
        tx.send(ScheduleRequest {
            message: 1,
            run_at: Instant::now(),
        })
        .await
        .unwrap();
        future::join(
            async {
                sleep(Duration::from_secs(2)).await;
                drop(tx);
            },
            async {
                assert_eq!(scheduler.next().await.unwrap(), 1);
                assert!(scheduler.next().await.is_none())
            },
        )
        .await;
    }

    // A message rejected by the predicate must not prevent later ready messages from being
    // emitted.
    #[tokio::test]
    async fn scheduler_pending_message_should_not_block_head_of_line() {
        let mut scheduler = Box::pin(scheduler(
            stream::iter(vec![
                ScheduleRequest {
                    message: 1,
                    run_at: Instant::now(),
                },
                ScheduleRequest {
                    message: 2,
                    run_at: Instant::now(),
                },
            ])
            .on_complete(sleep(Duration::from_secs(2))),
        ));
        assert_eq!(
            scheduler.as_mut().hold_unless(|x| *x != 1).next().await.unwrap(),
            2
        );
    }

    // Messages are emitted in deadline order, only once their deadline has passed.
    #[tokio::test]
    async fn scheduler_should_emit_items_as_requested() {
        pause();
        let mut scheduler = pin!(scheduler(
            stream::iter(vec![
                ScheduleRequest {
                    message: 1_u8,
                    run_at: Instant::now() + Duration::from_secs(1),
                },
                ScheduleRequest {
                    message: 2,
                    run_at: Instant::now() + Duration::from_secs(3),
                },
            ])
            .on_complete(sleep(Duration::from_secs(5))),
        ));
        assert!(poll!(scheduler.next()).is_pending());
        advance(Duration::from_secs(2)).await;
        assert_eq!(scheduler.next().now_or_never().unwrap().unwrap(), 1);
        assert!(poll!(scheduler.next()).is_pending());
        advance(Duration::from_secs(2)).await;
        assert_eq!(scheduler.next().now_or_never().unwrap().unwrap(), 2);
        assert!(scheduler.next().await.is_none());
    }

    // Two requests for equal messages: the earlier deadline (1s) wins and the later (3s) one
    // is dropped, so only one emission happens.
    #[tokio::test]
    async fn scheduler_dedupe_should_keep_earlier_item() {
        pause();
        let mut scheduler = pin!(scheduler(
            stream::iter(vec![
                ScheduleRequest {
                    message: (),
                    run_at: Instant::now() + Duration::from_secs(1),
                },
                ScheduleRequest {
                    message: (),
                    run_at: Instant::now() + Duration::from_secs(3),
                },
            ])
            .on_complete(sleep(Duration::from_secs(5))),
        ));
        assert!(poll!(scheduler.next()).is_pending());
        advance(Duration::from_secs(2)).await;
        scheduler.next().now_or_never().unwrap().unwrap();
        assert!(scheduler.next().await.is_none());
    }

    // A later request with an earlier deadline (1s) replaces the existing one (3s), so the
    // message is emitted by 2s, and only once.
    #[tokio::test]
    async fn scheduler_dedupe_should_replace_later_item() {
        pause();
        let mut scheduler = pin!(scheduler(
            stream::iter(vec![
                ScheduleRequest {
                    message: (),
                    run_at: Instant::now() + Duration::from_secs(3),
                },
                ScheduleRequest {
                    message: (),
                    run_at: Instant::now() + Duration::from_secs(1),
                },
            ])
            .on_complete(sleep(Duration::from_secs(5))),
        ));
        assert!(poll!(scheduler.next()).is_pending());
        advance(Duration::from_secs(2)).await;
        scheduler.next().now_or_never().unwrap().unwrap();
        assert!(scheduler.next().await.is_none());
    }

    // Once a message has been emitted, deduplication no longer applies to it, so it can be
    // scheduled and emitted again.
    #[tokio::test]
    async fn scheduler_dedupe_should_allow_rescheduling_emitted_item() {
        pause();
        let (mut schedule_tx, schedule_rx) = mpsc::unbounded();
        let mut scheduler = scheduler(schedule_rx);
        schedule_tx
            .send(ScheduleRequest {
                message: (),
                run_at: Instant::now() + Duration::from_secs(1),
            })
            .await
            .unwrap();
        assert!(poll!(scheduler.next()).is_pending());
        advance(Duration::from_secs(2)).await;
        scheduler.next().now_or_never().unwrap().unwrap();
        assert!(poll!(scheduler.next()).is_pending());
        schedule_tx
            .send(ScheduleRequest {
                message: (),
                run_at: Instant::now() + Duration::from_secs(1),
            })
            .await
            .unwrap();
        assert!(poll!(scheduler.next()).is_pending());
        advance(Duration::from_secs(2)).await;
        scheduler.next().now_or_never().unwrap().unwrap();
        assert!(poll!(scheduler.next()).is_pending());
    }

    // When an earlier request replaces an existing entry, the emitted value is the newer
    // message (payload 2).
    #[tokio::test]
    async fn scheduler_should_overwrite_message_with_soonest_version() {
        pause();

        let now = Instant::now();
        let scheduler = scheduler(
            stream::iter([
                ScheduleRequest {
                    message: SingletonMessage(1),
                    run_at: now + Duration::from_secs(2),
                },
                ScheduleRequest {
                    message: SingletonMessage(2),
                    run_at: now + Duration::from_secs(1),
                },
            ])
            .on_complete(sleep(Duration::from_secs(5))),
        );
        assert_eq!(scheduler.map(|msg| msg.0).collect::<Vec<_>>().await, vec![2]);
    }

    // A later request does not replace the existing earlier entry, so the original value
    // (payload 1) is emitted.
    #[tokio::test]
    async fn scheduler_should_not_overwrite_message_with_later_version() {
        pause();

        let now = Instant::now();
        let scheduler = scheduler(
            stream::iter([
                ScheduleRequest {
                    message: SingletonMessage(1),
                    run_at: now + Duration::from_secs(1),
                },
                ScheduleRequest {
                    message: SingletonMessage(2),
                    run_at: now + Duration::from_secs(2),
                },
            ])
            .on_complete(sleep(Duration::from_secs(5))),
        );
        assert_eq!(scheduler.map(|msg| msg.0).collect::<Vec<_>>().await, vec![1]);
    }

    // With a 2s debounce, a request for `now` is not emitted at 1s but is emitted by 4s.
    #[tokio::test]
    async fn scheduler_should_add_debounce_to_a_request() {
        pause();

        let now = Instant::now();
        let (mut sched_tx, sched_rx) = mpsc::unbounded::<ScheduleRequest<SingletonMessage>>();
        let mut scheduler = debounced_scheduler(sched_rx, Duration::from_secs(2));

        sched_tx
            .send(ScheduleRequest {
                message: SingletonMessage(1),
                run_at: now,
            })
            .await
            .unwrap();
        advance(Duration::from_secs(1)).await;
        assert!(poll!(scheduler.next()).is_pending());
        advance(Duration::from_secs(3)).await;
        assert_eq!(scheduler.next().now_or_never().unwrap().unwrap().0, 1);
    }

    // With a 3s debounce, the first request is queued for 3s.
    // The second request (sent at 1s) is only processed at the next poll, which happens at
    // 3.5s. The scheduler drains requests before checking the queue, so the deadline moves to
    // 1s + 3s = 4s before expiry is checked.
    // Result: nothing is emitted at 3.5s, and only the newer version (payload 2) is emitted
    // afterwards.
    #[tokio::test]
    async fn scheduler_should_dedup_message_within_debounce_period() {
        pause();

        let mut now = Instant::now();
        let (mut sched_tx, sched_rx) = mpsc::unbounded::<ScheduleRequest<SingletonMessage>>();
        let mut scheduler = debounced_scheduler(sched_rx, Duration::from_secs(3));

        sched_tx
            .send(ScheduleRequest {
                message: SingletonMessage(1),
                run_at: now,
            })
            .await
            .unwrap();
        assert!(poll!(scheduler.next()).is_pending());
        advance(Duration::from_secs(1)).await;

        now = Instant::now();
        sched_tx
            .send(ScheduleRequest {
                message: SingletonMessage(2),
                run_at: now,
            })
            .await
            .unwrap();
        advance(Duration::from_millis(2500)).await;
        // t = 3.5s: past the original deadline (3s) but before the pushed-back one (4s).
        assert!(poll!(scheduler.next()).is_pending());

        advance(Duration::from_secs(3)).await;
        assert_eq!(scheduler.next().now_or_never().unwrap().unwrap().0, 2);
        assert!(poll!(scheduler.next()).is_pending());
    }
}