use std::{
any::Any,
future::Future,
pin::Pin,
task::{self, Poll},
};
use pin_project::pin_project;
use sealed::sealed;
use tokio::time::{Duration, Instant, Sleep};
use crate::{
envelope::{Envelope, MessageKind},
message::Message,
source::{SourceArc, SourceStream, UnattachedSource},
time::far_future,
tracing::TraceId,
Addr,
};
/// Handle to a periodic message source that emits clones of a message of type `M`.
///
/// Created in unattached form by [`Interval::new`]. The interval stays silent
/// until one of the `start*` methods sets a non-zero period.
pub struct Interval<M> {
// Handle to the underlying `IntervalSource` state; the methods of `Interval`
// lock it before reading or changing the schedule.
source: SourceArc<IntervalSource<M>>,
}
// Implements the generic source-handle operations by forwarding each call to
// the inner `SourceArc`.
#[sealed]
impl<M: Message> crate::source::SourceHandle for Interval<M> {
/// Returns whatever the inner source handle reports for `is_terminated()`.
fn is_terminated(&self) -> bool {
self.source.is_terminated()
}
/// Forwards the termination request to the inner source handle and returns
/// its result unchanged.
// NOTE(review): the meaning of the returned flag is defined by
// `SourceArc::terminate_by_ref`, which is not visible here — confirm.
fn terminate_by_ref(&self) -> bool {
self.source.terminate_by_ref()
}
}
/// Sentinel period meaning "not running".
///
/// It is the initial period set by `Interval::new`, is restored by
/// `Interval::stop`, and makes `poll_recv` return `Poll::Pending` without
/// polling the timer. The public `start*`/`set_period` methods reject it with
/// `assert_ne!`, so a running interval always has a non-zero period.
const NEVER: Duration = Duration::ZERO;
/// Stream state behind an `Interval`: the message to emit and the timer that
/// decides when to emit it.
#[pin_project]
struct IntervalSource<M> {
/// Message that is cloned into a new envelope on every tick.
message: M,
/// Time between ticks; equal to `NEVER` (zero) while the interval is stopped
/// or has not been started yet.
period: Duration,
/// `true` while the timer is waiting for an explicitly supplied deadline
/// (set by `start_at`/`start_after`, and also by `stop`); cleared when a
/// tick fires in `poll_recv`.
is_delayed: bool,
/// Timer for the next tick. Marked `#[pin]` so it can be polled and reset
/// in place through the pinned projection.
#[pin]
sleep: Sleep,
}
impl<M: Message> Interval<M> {
    /// Creates a stopped interval that will emit clones of `message` once started.
    ///
    /// The new source has a zero period and its timer is parked at
    /// `far_future()`, so nothing is emitted until one of the `start*` methods
    /// is called.
    pub fn new(message: M) -> UnattachedSource<Self> {
        let stream = IntervalSource {
            message,
            period: NEVER,
            is_delayed: false,
            sleep: tokio::time::sleep_until(far_future()),
        };
        let shared = SourceArc::new(stream, false);
        UnattachedSource::new(shared, |source| Self { source })
    }

    /// Replaces the message that is cloned on subsequent ticks.
    ///
    /// The schedule (period and pending deadline) is left untouched.
    pub fn set_message(&self, message: M) {
        // NOTE(review): `ward!` presumably leaves the method early when the
        // lock cannot be obtained — confirm against the macro definition.
        let mut guard = ward!(self.source.lock());
        let fields = guard.stream().project();
        *fields.message = message;
    }

    /// Changes the period of a running interval.
    ///
    /// Does nothing if the interval is not running or if `period` equals the
    /// current period. While the interval is ticking regularly, the pending
    /// deadline is moved to `deadline - old_period + period`, i.e. as if the
    /// new period had been in effect since the previous tick. While the
    /// interval is still waiting for an explicitly scheduled first tick, only
    /// the stored period changes and that first deadline is kept.
    ///
    /// # Panics
    ///
    /// Panics if `period` is zero.
    #[track_caller]
    pub fn set_period(&self, period: Duration) {
        assert_ne!(period, NEVER, "period must be non-zero");

        let mut guard = ward!(self.source.lock());
        let fields = guard.stream().project();

        let previous = *fields.period;
        if previous == NEVER || previous == period {
            return;
        }

        // Waiting for an explicit first deadline: keep it, remember the period.
        if *fields.is_delayed {
            *fields.period = period;
            return;
        }

        // Inside a regular period: shift the pending tick to match the new period.
        let shifted_deadline = fields.sleep.deadline() - previous + period;
        fields.sleep.reset(shifted_deadline);
        *fields.period = period;
        guard.wake();
    }

    /// Starts (or restarts) the interval with the first tick one `period` from now.
    ///
    /// # Panics
    ///
    /// Panics if `period` is zero.
    #[track_caller]
    pub fn start(&self, period: Duration) {
        assert_ne!(period, NEVER, "period must be non-zero");
        self.schedule(None, period);
    }

    /// Starts (or restarts) the interval with the first tick after `delay`
    /// and subsequent ticks every `period`.
    ///
    /// # Panics
    ///
    /// Panics if `period` is zero.
    #[track_caller]
    pub fn start_after(&self, delay: Duration, period: Duration) {
        let first_tick = Instant::now() + delay;
        self.start_at(first_tick, period);
    }

    /// Starts (or restarts) the interval with the first tick at `when`
    /// and subsequent ticks every `period`.
    ///
    /// # Panics
    ///
    /// Panics if `period` is zero.
    #[instability::unstable]
    #[track_caller]
    pub fn start_at(&self, when: Instant, period: Duration) {
        assert_ne!(period, NEVER, "period must be non-zero");
        self.schedule(Some(when), period);
    }

    /// Stops the interval: the period goes back to zero and the timer is
    /// parked at `far_future()`.
    pub fn stop(&self) {
        self.schedule(Some(far_future()), NEVER);
    }

    /// Stores the new schedule and re-arms the timer.
    ///
    /// `when` is the explicit deadline of the first tick; with `None` the
    /// first tick is due one `period` from now. The interval counts as
    /// "delayed" exactly when an explicit deadline was supplied.
    fn schedule(&self, when: Option<Instant>, period: Duration) {
        let mut guard = ward!(self.source.lock());
        let fields = guard.stream().project();

        *fields.is_delayed = when.is_some();
        *fields.period = period;

        let first_deadline = match when {
            Some(instant) => instant,
            None => Instant::now() + period,
        };
        fields.sleep.reset(first_deadline);
        guard.wake();
    }
}
impl<M: Message> SourceStream for IntervalSource<M> {
    /// Exposes the pinned source as a pinned `dyn Any`.
    fn as_any_mut(self: Pin<&mut Self>) -> Pin<&mut dyn Any> {
        // SAFETY: the closure hands back the very reference it receives (only
        // unsized to `dyn Any`), so the pointee is never moved out of its
        // pinned location.
        unsafe { self.map_unchecked_mut(|this| this) }
    }

    /// Polls the interval for its next tick.
    ///
    /// Returns `Poll::Pending` while the interval is stopped (zero period) or
    /// the timer has not fired yet. On a tick, clears the "delayed" flag,
    /// re-arms the timer one period after the deadline that just fired, and
    /// yields a fresh envelope containing a clone of the stored message.
    fn poll_recv(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Option<Envelope>> {
        let mut this = self.project();

        // `||` short-circuits: the timer is not polled at all while stopped.
        let is_stopped = *this.period == NEVER;
        if is_stopped || this.sleep.as_mut().poll(cx).is_pending() {
            return Poll::Pending;
        }

        // Once a tick has fired, the interval is inside a regular period.
        *this.is_delayed = false;

        // Next tick is measured from the fired deadline, not from "now".
        let next_deadline = this.sleep.deadline() + *this.period;
        this.sleep.reset(next_deadline);

        let envelope = Envelope::with_trace_id(
            this.message.clone(),
            MessageKind::regular(Addr::NULL),
            TraceId::generate(),
        );
        Poll::Ready(Some(envelope))
    }
}