use super::clock::Clock;
use crate::Cmd;
use std::time::{Duration, Instant};
/// Leading-edge throttler backed by the system monotonic clock ([`Instant`]).
///
/// The first call always goes through. After that, a call goes through only
/// once at least `interval` has elapsed since the last one that did. Calls
/// arriving inside that window are dropped, not queued.
///
/// `interval` is supplied per call rather than stored, so callers should pass
/// the same value each time for the throttle to behave consistently.
#[derive(Debug, Clone, Default)]
pub struct Throttler {
    /// Start of the current throttle window; `None` until something has run.
    last_run: Option<Instant>,
}

impl Throttler {
    /// Creates a throttler whose first call will run immediately.
    pub fn new() -> Self {
        Self { last_run: None }
    }

    /// Time elapsed between the start of the current window and `now`, or
    /// `None` if nothing has run yet.
    ///
    /// `Instant::duration_since` saturates to zero (Rust 1.60+) instead of
    /// panicking if `now` precedes the window start.
    fn elapsed_at(&self, now: Instant) -> Option<Duration> {
        self.last_run.map(|last| now.duration_since(last))
    }

    /// Runs `action` and returns its command if the window is open, starting a
    /// new window at the current instant.
    ///
    /// While throttled, returns `Cmd::none()` and `action` is never called.
    pub fn run<Msg, F>(&mut self, interval: Duration, action: F) -> Cmd<Msg>
    where
        F: FnOnce() -> Cmd<Msg>,
    {
        let now = Instant::now();
        let throttled = matches!(self.elapsed_at(now), Some(elapsed) if elapsed < interval);
        if throttled {
            return Cmd::none();
        }
        self.last_run = Some(now);
        action()
    }

    /// Like [`Throttler::run`], emitting `Cmd::Msg(msg)` when the window is
    /// open. While throttled, `msg` is dropped and `Cmd::none()` is returned.
    ///
    /// `msg` is moved into the command, so no `Clone` bound is required.
    pub fn run_msg<Msg>(&mut self, interval: Duration, msg: Msg) -> Cmd<Msg> {
        self.run(interval, || Cmd::Msg(msg))
    }

    /// Returns `true` if a call made now with this `interval` would be dropped.
    pub fn is_throttled(&self, interval: Duration) -> bool {
        matches!(self.elapsed_at(Instant::now()), Some(elapsed) if elapsed < interval)
    }

    /// Time left until the window reopens, or `None` if a call made now would
    /// run (nothing has run yet, or `interval` has fully elapsed).
    pub fn time_remaining(&self, interval: Duration) -> Option<Duration> {
        self.elapsed_at(Instant::now())
            .and_then(|elapsed| interval.checked_sub(elapsed))
            // `checked_sub` yields `Some(0)` when exactly `interval` has passed;
            // the window is open at that point, so report `None`.
            .filter(|remaining| !remaining.is_zero())
    }

    /// Forgets the last run so the next call goes through immediately.
    pub fn reset(&mut self) {
        self.last_run = None;
    }

    /// Starts a new throttle window now without running anything, so calls
    /// within the next `interval` are dropped.
    pub fn suppress(&mut self) {
        self.last_run = Some(Instant::now());
    }
}
/// Leading-edge throttler that reads time from an injected [`Clock`] instead
/// of [`Instant`], so it can be driven deterministically (e.g. with a fake
/// clock in tests).
///
/// Timestamps are the `Duration` values returned by `Clock::now`; only the
/// differences between them are used. The rules match [`Throttler`]: the
/// first call runs, then calls run only once `interval` has elapsed since the
/// last call that ran.
///
/// Elapsed time is computed as `now.saturating_sub(last)` rather than
/// comparing against `last + interval`. `Duration` addition panics on
/// overflow, so the old form crashed for very large intervals such as
/// `Duration::MAX`. If the clock ever reads earlier than the last run, elapsed
/// time is treated as zero, the same as [`Throttler`] does with `Instant`.
#[derive(Debug, Clone)]
pub struct ThrottlerWithClock<C: Clock> {
    /// Time source consulted on every query.
    clock: C,
    /// Clock reading at the start of the current window; `None` until something has run.
    last_run: Option<Duration>,
}

impl<C: Clock> ThrottlerWithClock<C> {
    /// Creates a throttler reading time from `clock`; its first call runs immediately.
    pub fn new(clock: C) -> Self {
        Self {
            clock,
            last_run: None,
        }
    }

    /// Time elapsed since the start of the current window, or `None` if
    /// nothing has run yet. The clock is only read when there is a window.
    fn elapsed_now(&self) -> Option<Duration> {
        self.last_run
            .map(|last| self.clock.now().saturating_sub(last))
    }

    /// Runs `action` and returns its command if the window is open, starting a
    /// new window at the current clock reading.
    ///
    /// While throttled, returns `Cmd::none()` and `action` is never called.
    pub fn run<Msg, F>(&mut self, interval: Duration, action: F) -> Cmd<Msg>
    where
        F: FnOnce() -> Cmd<Msg>,
    {
        let now = self.clock.now();
        let should_run = match self.last_run {
            None => true,
            // Subtract instead of adding `interval` to `last`: cannot overflow.
            Some(last) => now.saturating_sub(last) >= interval,
        };
        if should_run {
            self.last_run = Some(now);
            action()
        } else {
            Cmd::none()
        }
    }

    /// Like [`ThrottlerWithClock::run`], emitting `Cmd::Msg(msg)` when the
    /// window is open. While throttled, `msg` is dropped and `Cmd::none()` is
    /// returned.
    ///
    /// `msg` is moved into the command, so no `Clone` bound is required.
    pub fn run_msg<Msg>(&mut self, interval: Duration, msg: Msg) -> Cmd<Msg> {
        self.run(interval, || Cmd::Msg(msg))
    }

    /// Returns `true` if a call made now with this `interval` would be dropped.
    pub fn is_throttled(&self, interval: Duration) -> bool {
        matches!(self.elapsed_now(), Some(elapsed) if elapsed < interval)
    }

    /// Time left until the window reopens, or `None` if a call made now would
    /// run (nothing has run yet, or `interval` has fully elapsed).
    pub fn time_remaining(&self, interval: Duration) -> Option<Duration> {
        self.elapsed_now()
            .and_then(|elapsed| interval.checked_sub(elapsed))
            // Exactly `interval` elapsed means the window is open again.
            .filter(|remaining| !remaining.is_zero())
    }

    /// Forgets the last run so the next call goes through immediately.
    pub fn reset(&mut self) {
        self.last_run = None;
    }

    /// Starts a new throttle window at the current clock reading without
    /// running anything.
    pub fn suppress(&mut self) {
        self.last_run = Some(self.clock.now());
    }
}
/// Throttler that delivers on the leading edge and, if further calls arrive
/// while throttled, once more on the trailing edge. Time is measured with
/// [`Instant`].
///
/// - The first call in an open window is delivered immediately.
/// - The first call that arrives inside the window asks the caller to schedule
///   a single trailing delivery for when the window closes.
/// - Later calls in the same window are absorbed into that pending delivery.
#[derive(Debug, Clone, Default)]
pub struct TrailingThrottler {
    /// Time of the last leading or trailing delivery; `None` until one happens.
    last_run: Option<Instant>,
    /// A call was throttled since the last delivery.
    has_pending: bool,
    /// A trailing delivery was requested and not yet resolved by
    /// `should_fire_trailing`.
    trailing_scheduled: bool,
}

/// Outcome of admitting one call to a [`TrailingThrottler`].
enum TrailingDecision {
    /// The window was open: deliver now.
    Leading,
    /// Throttled: schedule the trailing delivery after this delay.
    ScheduleTrailing(Duration),
    /// Throttled: a trailing delivery is already scheduled.
    Coalesced,
}

impl TrailingThrottler {
    /// Creates a throttler with no history; its first call is delivered immediately.
    pub fn new() -> Self {
        Self {
            last_run: None,
            has_pending: false,
            trailing_scheduled: false,
        }
    }

    /// Shared state machine behind `run` and `mark_run`. Records one call and
    /// decides whether it is delivered now, needs a trailing delivery
    /// scheduled, or is absorbed.
    fn admit(&mut self, interval: Duration) -> TrailingDecision {
        let now = Instant::now();
        match self.last_run.map(|last| now.duration_since(last)) {
            Some(elapsed) if elapsed < interval => {
                self.has_pending = true;
                if self.trailing_scheduled {
                    TrailingDecision::Coalesced
                } else {
                    self.trailing_scheduled = true;
                    // `elapsed < interval` in this arm, so this cannot underflow.
                    TrailingDecision::ScheduleTrailing(interval - elapsed)
                }
            }
            _ => {
                self.last_run = Some(now);
                self.has_pending = false;
                self.trailing_scheduled = false;
                TrailingDecision::Leading
            }
        }
    }

    /// Records a call and returns the command to execute for it.
    ///
    /// | Situation                    | Returned command                      |
    /// |------------------------------|---------------------------------------|
    /// | Window open                  | `Cmd::Msg(msg)`                       |
    /// | First throttled call         | `Cmd::delay(remaining, trailing_msg)` |
    /// | Later throttled calls        | `Cmd::none()`                         |
    ///
    /// `remaining` is the time left in the current window. Presumably the
    /// handler for `trailing_msg` calls
    /// [`TrailingThrottler::should_fire_trailing`] to decide whether the
    /// trailing delivery is still due.
    #[cfg(feature = "tokio")]
    pub fn run<Msg>(&mut self, interval: Duration, msg: Msg, trailing_msg: Msg) -> Cmd<Msg>
    where
        Msg: Clone + Send + 'static,
    {
        match self.admit(interval) {
            TrailingDecision::Leading => Cmd::Msg(msg),
            TrailingDecision::ScheduleTrailing(delay) => Cmd::delay(delay, trailing_msg),
            TrailingDecision::Coalesced => Cmd::none(),
        }
    }

    /// Records a call without building a command, for callers that schedule
    /// the trailing delivery themselves.
    ///
    /// Returns `Some(delay)` only for the first throttled call in a window.
    /// The caller should then check
    /// [`TrailingThrottler::should_fire_trailing`] after `delay`.
    ///
    /// NOTE(review): `None` covers both "delivered now (leading edge)" and
    /// "absorbed, a trailing delivery is already scheduled". Callers cannot
    /// tell these apart from the return value alone.
    pub fn mark_run(&mut self, interval: Duration) -> Option<Duration> {
        match self.admit(interval) {
            TrailingDecision::ScheduleTrailing(delay) => Some(delay),
            TrailingDecision::Leading | TrailingDecision::Coalesced => None,
        }
    }

    /// Resolves a scheduled trailing delivery.
    ///
    /// Returns `true` if a throttled call is still pending. In that case the
    /// pending flag is cleared and a new window starts now. In every case the
    /// trailing-scheduled flag is cleared. The clock is never consulted.
    ///
    /// NOTE(review): the flags carry no timer identity. A delay scheduled
    /// before a later leading delivery can therefore resolve a newer pending
    /// call early.
    pub fn should_fire_trailing(&mut self) -> bool {
        self.trailing_scheduled = false;
        let fire = std::mem::take(&mut self.has_pending);
        if fire {
            self.last_run = Some(Instant::now());
        }
        fire
    }

    /// Clears all history and pending state; the next call is delivered immediately.
    pub fn reset(&mut self) {
        *self = Self::new();
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_throttler_first_call() {
        let mut throttler = Throttler::new();
        assert!(!throttler.is_throttled(Duration::from_millis(100)));
        let cmd = throttler.run(Duration::from_millis(100), || Cmd::Msg(42));
        assert!(cmd.is_msg());
    }

    #[test]
    fn test_throttler_blocks_rapid_calls() {
        use crate::testing::FakeClock;
        let clock = FakeClock::new();
        let mut throttler = ThrottlerWithClock::new(clock.clone());
        let interval = Duration::from_millis(50);
        let cmd1 = throttler.run(interval, || Cmd::Msg(1));
        assert!(cmd1.is_msg());
        let cmd2 = throttler.run(interval, || Cmd::Msg(2));
        assert!(cmd2.is_none());
        clock.advance(Duration::from_millis(20));
        let cmd3 = throttler.run(interval, || Cmd::Msg(3));
        assert!(cmd3.is_none());
        clock.advance(Duration::from_millis(35));
        let cmd4 = throttler.run(interval, || Cmd::Msg(4));
        assert!(cmd4.is_msg());
    }

    #[test]
    fn test_throttler_with_clock_time_remaining() {
        use crate::testing::FakeClock;
        let clock = FakeClock::new();
        let mut throttler = ThrottlerWithClock::new(clock.clone());
        let interval = Duration::from_millis(50);
        assert_eq!(throttler.time_remaining(interval), None);
        let _ = throttler.run(interval, || Cmd::Msg(1));
        clock.advance(Duration::from_millis(20));
        assert_eq!(
            throttler.time_remaining(interval),
            Some(Duration::from_millis(30))
        );
        clock.advance(Duration::from_millis(30));
        assert!(!throttler.is_throttled(interval));
        assert_eq!(throttler.time_remaining(interval), None);
    }

    #[test]
    fn test_throttler_reset() {
        let mut throttler = Throttler::new();
        let interval = Duration::from_millis(100);
        let _ = throttler.run(interval, || Cmd::Msg(1));
        throttler.reset();
        let cmd = throttler.run(interval, || Cmd::Msg(2));
        assert!(cmd.is_msg());
    }

    #[test]
    fn test_throttler_suppress() {
        let mut throttler = Throttler::new();
        let interval = Duration::from_secs(3600);
        throttler.suppress();
        assert!(throttler.is_throttled(interval));
        let cmd = throttler.run(interval, || Cmd::Msg(1));
        assert!(cmd.is_none());
    }

    #[test]
    fn test_throttler_time_remaining() {
        let mut throttler = Throttler::new();
        let interval = Duration::from_millis(100);
        assert!(throttler.time_remaining(interval).is_none());
        let _ = throttler.run(interval, || Cmd::Msg(1));
        let remaining = throttler.time_remaining(interval);
        assert!(remaining.is_some());
        assert!(remaining.unwrap() <= interval);
    }

    #[test]
    fn test_trailing_throttler_mark_run() {
        let mut throttler = TrailingThrottler::new();
        let interval = Duration::from_secs(3600);
        assert_eq!(throttler.mark_run(interval), None);
        let delay = throttler.mark_run(interval);
        assert!(matches!(delay, Some(d) if d <= interval));
        assert_eq!(throttler.mark_run(interval), None);
        assert!(throttler.should_fire_trailing());
        assert!(!throttler.should_fire_trailing());
    }

    #[test]
    #[cfg(feature = "tokio")]
    fn test_trailing_throttler_basic() {
        let mut throttler = TrailingThrottler::new();
        // Long interval so a slow test machine cannot reopen the window between calls.
        let interval = Duration::from_secs(60);
        let cmd1 = throttler.run(interval, 1, 100);
        assert!(cmd1.is_msg_eq(&1));
        let cmd2 = throttler.run(interval, 2, 100);
        assert!(!cmd2.is_none());
        let cmd3 = throttler.run(interval, 3, 100);
        assert!(cmd3.is_none());
        // `should_fire_trailing` only inspects the pending flag and never reads
        // the clock, so there is no need to sleep until the delay would fire.
        assert!(throttler.should_fire_trailing());
    }
}