1use futures::{Stream, StreamExt, stream::Fuse};
4use hashbrown::{HashMap, hash_map::RawEntryMut};
5use pin_project::pin_project;
6use std::{
7 collections::HashSet,
8 hash::Hash,
9 pin::Pin,
10 task::{Context, Poll},
11 time::Duration,
12};
13use tokio::time::Instant;
14use tokio_util::time::delay_queue::{self, DelayQueue};
15
/// A request asking the scheduler to emit `message` once `run_at` has been reached.
///
/// Requests whose messages compare equal (by `Eq`/`Hash`) are merged by the scheduler
/// instead of being emitted more than once.
#[derive(Debug)]
pub struct ScheduleRequest<T> {
    /// Payload handed back to the consumer when the request becomes due.
    pub message: T,
    /// Instant at which the message becomes due (before any debounce period is added).
    pub run_at: Instant,
}
24
25struct ScheduledEntry {
27 run_at: Instant,
28 queue_key: delay_queue::Key,
29}
30
31#[pin_project(project = SchedulerProj)]
36pub struct Scheduler<T, R> {
37 queue: DelayQueue<T>,
45 scheduled: HashMap<T, ScheduledEntry>,
49 pending: HashSet<T>,
51 #[pin]
53 requests: Fuse<R>,
54 debounce: Duration,
60}
61
62impl<T, R: Stream> Scheduler<T, R> {
63 fn new(requests: R, debounce: Duration) -> Self {
64 Self {
65 queue: DelayQueue::new(),
66 scheduled: HashMap::new(),
67 pending: HashSet::new(),
68 requests: requests.fuse(),
69 debounce,
70 }
71 }
72}
73
74impl<T: Hash + Eq + Clone, R> SchedulerProj<'_, T, R> {
75 fn schedule_message(&mut self, request: ScheduleRequest<T>) {
79 if self.pending.contains(&request.message) {
80 return;
82 }
83 let next_time = request
84 .run_at
85 .checked_add(*self.debounce)
86 .map_or_else(max_schedule_time, |time|
87 time.min(max_schedule_time()));
89 match self.scheduled.raw_entry_mut().from_key(&request.message) {
90 RawEntryMut::Occupied(mut old_entry) if old_entry.get().run_at >= request.run_at => {
94 let entry = old_entry.get_mut();
96 self.queue.reset_at(&entry.queue_key, next_time);
97 entry.run_at = next_time;
98 old_entry.insert_key(request.message);
99 }
100 RawEntryMut::Occupied(_old_entry) => {
101 }
103 RawEntryMut::Vacant(entry) => {
104 let message = request.message.clone();
106 entry.insert(request.message, ScheduledEntry {
107 run_at: next_time,
108 queue_key: self.queue.insert_at(message, next_time),
109 });
110 }
111 }
112 }
113
114 fn poll_pop_queue_message(
116 &mut self,
117 cx: &mut Context<'_>,
118 can_take_message: impl Fn(&T) -> bool,
119 ) -> Poll<T> {
120 if let Some(msg) = self.pending.iter().find(|msg| can_take_message(*msg)).cloned() {
121 return Poll::Ready(self.pending.take(&msg).unwrap());
122 }
123
124 loop {
125 match self.queue.poll_expired(cx) {
126 Poll::Ready(Some(msg)) => {
127 let msg = msg.into_inner();
128 let (msg, _) = self.scheduled.remove_entry(&msg).expect(
129 "Expired message was popped from the Scheduler queue, but was not in the metadata map",
130 );
131 if can_take_message(&msg) {
132 break Poll::Ready(msg);
133 }
134 self.pending.insert(msg);
135 }
136 Poll::Ready(None) | Poll::Pending => break Poll::Pending,
137 }
138 }
139 }
140
141 pub fn pop_queue_message_into_pending(&mut self, cx: &mut Context<'_>) {
143 while let Poll::Ready(Some(msg)) = self.queue.poll_expired(cx) {
144 let msg = msg.into_inner();
145 self.scheduled.remove_entry(&msg).expect(
146 "Expired message was popped from the Scheduler queue, but was not in the metadata map",
147 );
148 self.pending.insert(msg);
149 }
150 }
151}
152
153pub struct Hold<'a, T, R> {
155 scheduler: Pin<&'a mut Scheduler<T, R>>,
156}
157
158impl<T, R> Stream for Hold<'_, T, R>
159where
160 T: Eq + Hash + Clone,
161 R: Stream<Item = ScheduleRequest<T>>,
162{
163 type Item = T;
164
165 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
166 let this = self.get_mut();
167 let mut scheduler = this.scheduler.as_mut().project();
168
169 loop {
170 match scheduler.requests.as_mut().poll_next(cx) {
171 Poll::Ready(Some(request)) => scheduler.schedule_message(request),
172 Poll::Ready(None) => return Poll::Ready(None),
173 Poll::Pending => break,
174 }
175 }
176
177 scheduler.pop_queue_message_into_pending(cx);
178 Poll::Pending
179 }
180}
181
182pub struct HoldUnless<'a, T, R, C> {
184 scheduler: Pin<&'a mut Scheduler<T, R>>,
185 can_take_message: C,
186}
187
188impl<T, R, C> Stream for HoldUnless<'_, T, R, C>
189where
190 T: Eq + Hash + Clone,
191 R: Stream<Item = ScheduleRequest<T>>,
192 C: Fn(&T) -> bool + Unpin,
193{
194 type Item = T;
195
196 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
197 let this = self.get_mut();
198 let can_take_message = &this.can_take_message;
199 let mut scheduler = this.scheduler.as_mut().project();
200
201 loop {
202 match scheduler.requests.as_mut().poll_next(cx) {
203 Poll::Ready(Some(request)) => scheduler.schedule_message(request),
204 Poll::Ready(None) => return Poll::Ready(None),
205 Poll::Pending => break,
206 }
207 }
208
209 match scheduler.poll_pop_queue_message(cx, can_take_message) {
210 Poll::Ready(expired) => Poll::Ready(Some(expired)),
211 Poll::Pending => Poll::Pending,
212 }
213 }
214}
215
216impl<T, R> Scheduler<T, R>
217where
218 T: Eq + Hash + Clone,
219 R: Stream<Item = ScheduleRequest<T>>,
220{
221 pub fn hold_unless<C: Fn(&T) -> bool>(
232 self: Pin<&'_ mut Self>,
233 can_take_message: C,
234 ) -> HoldUnless<'_, T, R, C> {
235 HoldUnless {
236 scheduler: self,
237 can_take_message,
238 }
239 }
240
241 #[must_use]
245 pub fn hold(self: Pin<&'_ mut Self>) -> Hold<'_, T, R> {
246 Hold { scheduler: self }
247 }
248
249 #[cfg(test)]
251 pub fn contains_pending(&self, msg: &T) -> bool {
252 self.pending.contains(msg)
253 }
254}
255
256impl<T, R> Stream for Scheduler<T, R>
257where
258 T: Eq + Hash + Clone,
259 R: Stream<Item = ScheduleRequest<T>>,
260{
261 type Item = T;
262
263 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
264 Pin::new(&mut self.hold_unless(|_| true)).poll_next(cx)
265 }
266}
267
268pub fn scheduler<T: Eq + Hash + Clone, S: Stream<Item = ScheduleRequest<T>>>(requests: S) -> Scheduler<T, S> {
279 Scheduler::new(requests, Duration::ZERO)
280}
281
282pub fn debounced_scheduler<T: Eq + Hash + Clone, S: Stream<Item = ScheduleRequest<T>>>(
290 requests: S,
291 debounce: Duration,
292) -> Scheduler<T, S> {
293 Scheduler::new(requests, debounce)
294}
295
/// Latest deadline the scheduler hands to its delay queue: 180 days (about six months)
/// from now.
///
/// It serves as the fallback when `run_at + debounce` overflows and as an upper clamp on
/// every deadline. tokio-util documents that `DelayQueue` panics on deadlines too far in
/// the future.
pub(crate) fn max_schedule_time() -> Instant {
    const HORIZON: Duration = Duration::from_secs(60 * 60 * 24 * 180);
    Instant::now() + HORIZON
}
302
#[cfg(test)]
mod tests {
    use crate::utils::KubeRuntimeStreamExt;

    use super::{ScheduleRequest, debounced_scheduler, scheduler};
    use educe::Educe;
    use futures::{FutureExt, SinkExt, StreamExt, channel::mpsc, future, poll, stream};
    use std::{pin::pin, task::Poll};
    use tokio::time::{Duration, Instant, advance, pause, sleep};

    /// Returns the value inside a `Poll::Ready`, panicking if the poll is `Poll::Pending`.
    fn unwrap_poll<T>(poll: Poll<T>) -> T {
        if let Poll::Ready(x) = poll {
            x
        } else {
            panic!("Tried to unwrap a pending poll!")
        }
    }

    /// A message whose `u8` payload is ignored by `PartialEq` and `Hash`.
    ///
    /// Every instance therefore counts as the same message to the scheduler, while the
    /// payload lets a test see which version was actually emitted.
    #[derive(Educe, Eq, Clone, Debug)]
    #[educe(PartialEq, Hash)]
    struct SingletonMessage(#[educe(PartialEq(ignore), Hash(ignore))] u8);

    // A message rejected by `hold_unless` is parked as pending, and a later accepting
    // `hold_unless` emits it. The stream ends once the request stream (kept open by a
    // 4s sleep) completes.
    #[tokio::test]
    async fn scheduler_should_hold_and_release_items() {
        pause();
        let mut scheduler = Box::pin(scheduler(
            stream::iter(vec![ScheduleRequest {
                message: 1_u8,
                run_at: Instant::now(),
            }])
            .on_complete(sleep(Duration::from_secs(4))),
        ));
        assert!(!scheduler.contains_pending(&1));
        assert!(poll!(scheduler.as_mut().hold_unless(|_| false).next()).is_pending());
        assert!(scheduler.contains_pending(&1));
        assert_eq!(
            unwrap_poll(poll!(scheduler.as_mut().hold_unless(|_| true).next())).unwrap(),
            1_u8
        );
        assert!(!scheduler.contains_pending(&1));
        assert!(scheduler.as_mut().hold_unless(|_| true).next().await.is_none());
    }

    // A second request for a message that is already pending must not schedule another
    // copy: the message is emitted exactly once before the stream ends.
    #[tokio::test]
    async fn scheduler_should_not_reschedule_pending_items() {
        pause();
        let (mut tx, rx) = mpsc::unbounded::<ScheduleRequest<u8>>();
        let mut scheduler = Box::pin(scheduler(rx));
        tx.send(ScheduleRequest {
            message: 1,
            run_at: Instant::now(),
        })
        .await
        .unwrap();
        // Park message 1 as pending by rejecting it.
        assert!(poll!(scheduler.as_mut().hold_unless(|_| false).next()).is_pending());
        tx.send(ScheduleRequest {
            message: 1,
            run_at: Instant::now(),
        })
        .await
        .unwrap();
        future::join(
            async {
                sleep(Duration::from_secs(2)).await;
                drop(tx);
            },
            async {
                assert_eq!(scheduler.next().await.unwrap(), 1);
                assert!(scheduler.next().await.is_none())
            },
        )
        .await;
    }

    // A due message rejected by the filter must not prevent other due messages from being
    // emitted. This test runs on the real clock; both messages are due immediately.
    #[tokio::test]
    async fn scheduler_pending_message_should_not_block_head_of_line() {
        let mut scheduler = Box::pin(scheduler(
            stream::iter(vec![
                ScheduleRequest {
                    message: 1,
                    run_at: Instant::now(),
                },
                ScheduleRequest {
                    message: 2,
                    run_at: Instant::now(),
                },
            ])
            .on_complete(sleep(Duration::from_secs(2))),
        ));
        assert_eq!(
            scheduler.as_mut().hold_unless(|x| *x != 1).next().await.unwrap(),
            2
        );
    }

    // Messages are emitted only after their own `run_at`, in deadline order.
    #[tokio::test]
    async fn scheduler_should_emit_items_as_requested() {
        pause();
        let mut scheduler = pin!(scheduler(
            stream::iter(vec![
                ScheduleRequest {
                    message: 1_u8,
                    run_at: Instant::now() + Duration::from_secs(1),
                },
                ScheduleRequest {
                    message: 2,
                    run_at: Instant::now() + Duration::from_secs(3),
                },
            ])
            .on_complete(sleep(Duration::from_secs(5))),
        ));
        assert!(poll!(scheduler.next()).is_pending());
        advance(Duration::from_secs(2)).await;
        assert_eq!(scheduler.next().now_or_never().unwrap().unwrap(), 1);
        assert!(poll!(scheduler.next()).is_pending());
        advance(Duration::from_secs(2)).await;
        assert_eq!(scheduler.next().now_or_never().unwrap().unwrap(), 2);
        assert!(scheduler.next().await.is_none());
    }

    // Two requests for the same message: the earlier (1s) deadline is kept, and the
    // message is emitted only once.
    #[tokio::test]
    async fn scheduler_dedupe_should_keep_earlier_item() {
        pause();
        let mut scheduler = pin!(scheduler(
            stream::iter(vec![
                ScheduleRequest {
                    message: (),
                    run_at: Instant::now() + Duration::from_secs(1),
                },
                ScheduleRequest {
                    message: (),
                    run_at: Instant::now() + Duration::from_secs(3),
                },
            ])
            .on_complete(sleep(Duration::from_secs(5))),
        ));
        assert!(poll!(scheduler.next()).is_pending());
        advance(Duration::from_secs(2)).await;
        scheduler.next().now_or_never().unwrap().unwrap();
        assert!(scheduler.next().await.is_none());
    }

    // A second request with an earlier deadline replaces the first. The message is due
    // by 2s and emitted only once.
    #[tokio::test]
    async fn scheduler_dedupe_should_replace_later_item() {
        pause();
        let mut scheduler = pin!(scheduler(
            stream::iter(vec![
                ScheduleRequest {
                    message: (),
                    run_at: Instant::now() + Duration::from_secs(3),
                },
                ScheduleRequest {
                    message: (),
                    run_at: Instant::now() + Duration::from_secs(1),
                },
            ])
            .on_complete(sleep(Duration::from_secs(5))),
        ));
        assert!(poll!(scheduler.next()).is_pending());
        advance(Duration::from_secs(2)).await;
        scheduler.next().now_or_never().unwrap().unwrap();
        assert!(scheduler.next().await.is_none());
    }

    // Once a message has been emitted, scheduling it again works and it is emitted a
    // second time.
    #[tokio::test]
    async fn scheduler_dedupe_should_allow_rescheduling_emitted_item() {
        pause();
        let (mut schedule_tx, schedule_rx) = mpsc::unbounded();
        let mut scheduler = scheduler(schedule_rx);
        schedule_tx
            .send(ScheduleRequest {
                message: (),
                run_at: Instant::now() + Duration::from_secs(1),
            })
            .await
            .unwrap();
        assert!(poll!(scheduler.next()).is_pending());
        advance(Duration::from_secs(2)).await;
        scheduler.next().now_or_never().unwrap().unwrap();
        assert!(poll!(scheduler.next()).is_pending());
        schedule_tx
            .send(ScheduleRequest {
                message: (),
                run_at: Instant::now() + Duration::from_secs(1),
            })
            .await
            .unwrap();
        assert!(poll!(scheduler.next()).is_pending());
        advance(Duration::from_secs(2)).await;
        scheduler.next().now_or_never().unwrap().unwrap();
        assert!(poll!(scheduler.next()).is_pending());
    }

    // An equal message with an earlier deadline replaces the scheduled one, so the newer
    // version (payload 2) is what gets emitted.
    #[tokio::test]
    async fn scheduler_should_overwrite_message_with_soonest_version() {
        pause();

        let now = Instant::now();
        let scheduler = scheduler(
            stream::iter([
                ScheduleRequest {
                    message: SingletonMessage(1),
                    run_at: now + Duration::from_secs(2),
                },
                ScheduleRequest {
                    message: SingletonMessage(2),
                    run_at: now + Duration::from_secs(1),
                },
            ])
            .on_complete(sleep(Duration::from_secs(5))),
        );
        assert_eq!(scheduler.map(|msg| msg.0).collect::<Vec<_>>().await, vec![2]);
    }

    // An equal message with a later deadline is ignored, so the original version
    // (payload 1) is what gets emitted.
    #[tokio::test]
    async fn scheduler_should_not_overwrite_message_with_later_version() {
        pause();

        let now = Instant::now();
        let scheduler = scheduler(
            stream::iter([
                ScheduleRequest {
                    message: SingletonMessage(1),
                    run_at: now + Duration::from_secs(1),
                },
                ScheduleRequest {
                    message: SingletonMessage(2),
                    run_at: now + Duration::from_secs(2),
                },
            ])
            .on_complete(sleep(Duration::from_secs(5))),
        );
        assert_eq!(scheduler.map(|msg| msg.0).collect::<Vec<_>>().await, vec![1]);
    }

    // The debounce period is added on top of `run_at`. With `run_at = now` and a 2s
    // debounce, the message is not due at 1s but has been emitted by 4s.
    #[tokio::test]
    async fn scheduler_should_add_debounce_to_a_request() {
        pause();

        let now = Instant::now();
        let (mut sched_tx, sched_rx) = mpsc::unbounded::<ScheduleRequest<SingletonMessage>>();
        let mut scheduler = debounced_scheduler(sched_rx, Duration::from_secs(2));

        sched_tx
            .send(ScheduleRequest {
                message: SingletonMessage(1),
                run_at: now,
            })
            .await
            .unwrap();
        advance(Duration::from_secs(1)).await;
        assert!(poll!(scheduler.next()).is_pending());
        advance(Duration::from_secs(3)).await;
        assert_eq!(scheduler.next().now_or_never().unwrap().unwrap().0, 1);
    }

    // A duplicate arriving inside the debounce window restarts the debounce and replaces
    // the message. The first request would have been due at 3s, but the second one (sent
    // at 1s) moves the deadline to 4s, so the message is still pending at 3.5s. The only
    // emission carries payload 2.
    #[tokio::test]
    async fn scheduler_should_dedup_message_within_debounce_period() {
        pause();

        let mut now = Instant::now();
        let (mut sched_tx, sched_rx) = mpsc::unbounded::<ScheduleRequest<SingletonMessage>>();
        let mut scheduler = debounced_scheduler(sched_rx, Duration::from_secs(3));

        sched_tx
            .send(ScheduleRequest {
                message: SingletonMessage(1),
                run_at: now,
            })
            .await
            .unwrap();
        assert!(poll!(scheduler.next()).is_pending());
        advance(Duration::from_secs(1)).await;

        now = Instant::now();
        sched_tx
            .send(ScheduleRequest {
                message: SingletonMessage(2),
                run_at: now,
            })
            .await
            .unwrap();
        advance(Duration::from_millis(2500)).await;
        assert!(poll!(scheduler.next()).is_pending());

        advance(Duration::from_secs(3)).await;
        assert_eq!(scheduler.next().now_or_never().unwrap().unwrap().0, 2);
        assert!(poll!(scheduler.next()).is_pending());
    }
}