use std::{env, thread, time};
use anyhow::anyhow;
/// How [`sleep_unless`] spaces out successive checks of its stop condition.
#[derive(Clone, Copy, Debug)]
pub enum PollStrategy {
    /// Check the stop condition at a fixed cadence.
    Uniform {
        /// Pause between consecutive checks.
        interval: time::Duration,
    },
    /// Start with a short pause and grow it geometrically up to a ceiling.
    Backoff {
        /// Pause used after the first check.
        initial_interval: time::Duration,
        /// Multiplier applied to the pause after every sleep; a value that is not
        /// greater than `1.0` leaves the pause unchanged.
        factor: f32,
        /// Ceiling that the pause is clamped to after each growth step.
        max_interval: time::Duration,
    },
}
/// Sleeps for at most `total_sleep`, returning early once `stop` reports `true`.
///
/// `stop` is checked before every sleep and once more after the last one, so a
/// condition that turns true right at the deadline is still observed. The pause
/// between checks follows `poll_strategy` and is always capped by the time left
/// until the deadline. A zero pause is raised to 1ms, whether it is configured
/// directly or produced by a zero `max_interval` during backoff, so the loop
/// never busy-spins.
///
/// If `total_sleep` is so large that the deadline cannot be represented as a
/// [`time::Instant`], this keeps polling until `stop` returns `true`.
///
/// Returns `true` if `stop` fired, or `false` if the deadline passed first.
pub fn sleep_unless<F>(
    total_sleep: time::Duration,
    mut stop: F,
    poll_strategy: PollStrategy,
) -> bool
where
    F: FnMut() -> bool,
{
    /// Shortest pause between checks; sleeping for zero would spin a core.
    const MIN_INTERVAL: time::Duration = time::Duration::from_millis(1);

    /// Replaces a zero pause with `MIN_INTERVAL`.
    fn at_least_min(interval: time::Duration) -> time::Duration {
        if interval.is_zero() {
            MIN_INTERVAL
        } else {
            interval
        }
    }

    /// Computes `interval * factor` clamped to `cap`, saturating at `cap` instead
    /// of panicking when the product is infinite or too large for a `Duration`.
    fn grow(interval: time::Duration, factor: f32, cap: time::Duration) -> time::Duration {
        let scaled_secs = interval.as_secs_f64() * f64::from(factor);
        if scaled_secs.is_finite() && scaled_secs < cap.as_secs_f64() {
            time::Duration::from_secs_f64(scaled_secs).min(cap)
        } else {
            cap
        }
    }

    // `None` when `total_sleep` overflows `Instant`; treated as "no deadline"
    // rather than panicking on the addition.
    let deadline = time::Instant::now().checked_add(total_sleep);
    let mut next_interval = at_least_min(match poll_strategy {
        PollStrategy::Uniform { interval } => interval,
        PollStrategy::Backoff { initial_interval, .. } => initial_interval,
    });

    loop {
        if stop() {
            return true;
        }
        let pause = match deadline {
            Some(deadline) => {
                let now = time::Instant::now();
                if now >= deadline {
                    return false;
                }
                deadline.saturating_duration_since(now).min(next_interval)
            }
            None => next_interval,
        };
        thread::sleep(pause);
        if let PollStrategy::Backoff { factor, max_interval, .. } = poll_strategy {
            // `factor > 1.0` is false for NaN, so a NaN factor simply disables growth.
            if factor > 1.0 {
                // Re-apply the floor: a zero `max_interval` would otherwise clamp the
                // pause to zero and turn the loop into a busy spin.
                next_interval = at_least_min(grow(next_interval, factor, max_interval));
            }
        }
    }
}
/// Falls back to the current session when the caller named no target sessions.
///
/// If `sessions` is empty, the value of the `SHPOOL_SESSION_NAME` environment
/// variable (when it is set and valid unicode) is pushed as the only target.
/// A non-empty `sessions` is left untouched.
///
/// # Errors
///
/// Returns an error if `sessions` is empty and `SHPOOL_SESSION_NAME` cannot be
/// read. The same `no session to {action}` message is also printed to stderr.
pub fn resolve_sessions(sessions: &mut Vec<String>, action: &str) -> anyhow::Result<()> {
    if !sessions.is_empty() {
        return Ok(());
    }
    match env::var("SHPOOL_SESSION_NAME") {
        Ok(current) => {
            sessions.push(current);
            Ok(())
        }
        Err(_) => {
            eprintln!("no session to {action}");
            Err(anyhow!("no session to {action}"))
        }
    }
}
#[cfg(test)]
mod tests {
    use std::cell::Cell;
    use std::time::Duration;

    use super::{sleep_unless, PollStrategy};

    #[test]
    fn sleep_unless_returns_immediately_when_stop_is_true() {
        let stopped = sleep_unless(
            Duration::from_millis(10),
            || true,
            PollStrategy::Uniform { interval: Duration::from_millis(1) },
        );
        assert!(stopped);
    }

    #[test]
    fn sleep_unless_times_out_when_stop_is_false() {
        let stopped = sleep_unless(
            Duration::from_millis(3),
            || false,
            PollStrategy::Uniform { interval: Duration::from_millis(1) },
        );
        assert!(!stopped);
    }

    #[test]
    fn sleep_unless_rechecks_stop_with_backoff() {
        let checks = Cell::new(0usize);
        let stopped = sleep_unless(
            Duration::from_millis(20),
            || {
                let n = checks.get() + 1;
                checks.set(n);
                n >= 3
            },
            PollStrategy::Backoff {
                initial_interval: Duration::from_millis(1),
                factor: 2.0,
                max_interval: Duration::from_millis(4),
            },
        );
        assert!(stopped);
        // `stop` is polled only until it first returns true, so exactly 3 calls.
        assert_eq!(checks.get(), 3);
    }

    #[test]
    fn sleep_unless_treats_zero_uniform_interval_as_nonzero() {
        let checks = Cell::new(0usize);
        let stopped = sleep_unless(
            Duration::from_millis(5),
            || {
                checks.set(checks.get() + 1);
                false
            },
            PollStrategy::Uniform { interval: Duration::from_millis(0) },
        );
        assert!(!stopped);
        // A 1ms floor gives a handful of polls; a busy spin would give thousands.
        assert!(checks.get() < 50, "stop polled {} times", checks.get());
    }
}