#[macro_use(defer)]
extern crate scopeguard;

use std::error::Error;
use std::fmt;
use std::result;
use std::thread;
use std::time::Duration;
use std::time::Instant;

use crossbeam::channel::Receiver;
use crossbeam::channel::{bounded, never};
use crossbeam::crossbeam_channel::{select, tick, unbounded};
use crossbeam::sync::WaitGroup;
use rand::Rng;
/// Result type returned by the wait and poll helpers in this module: `Ok(())` on success,
/// otherwise a boxed error.
type Result = result::Result<(), Box<dyn Error>>;
/// Tracks a set of spawned worker threads so the owner can block until all of them return.
pub struct Group {
    wg: WaitGroup,
}

impl Group {
    /// Creates a group with no running workers.
    pub fn new() -> Group {
        let wg = WaitGroup::new();
        Group { wg }
    }

    /// Consumes the group and blocks until every thread started through it has finished.
    pub fn wait(self) {
        let Group { wg } = self;
        wg.wait();
    }

    /// Starts `f` on a new thread, handing it a clone of `stop_ch` on each invocation.
    pub fn start_with_channel<F>(&self, stop_ch: Receiver<bool>, f: F)
    where
        F: Fn(Receiver<bool>) -> () + 'static + std::marker::Send + std::marker::Sync,
    {
        let worker = move || f(stop_ch.clone());
        self.start(worker);
    }

    /// Runs `f` on a new, detached thread that is counted by this group until `f` returns.
    pub fn start<F>(&self, f: F)
    where
        F: Fn() -> () + std::marker::Send + 'static,
    {
        let member = self.wg.clone();
        thread::spawn(move || {
            f();
            // Mentioning `member` here is what moves it into the closure; dropping it marks
            // this worker as finished for `wait`.
            drop(member);
        });
    }
}
/// Runs `f` repeatedly, waiting `period` after each run finishes, and never returns.
///
/// The stop sender is held for the whole call, so the loop never sees a stop signal.
pub fn forever<F>(f: F, period: Duration)
where
    F: Fn() -> (),
{
    let (_keep_alive, never_stopped) = unbounded();
    until(f, period, never_stopped)
}
/// Runs `f` every `period` (measured from the end of each run) until `stop_ch` fires.
pub fn until<F>(f: F, period: Duration, stop_ch: Receiver<bool>)
where
    F: Fn() -> (),
{
    const NO_JITTER: f64 = 0.0;
    const SLIDING: bool = true;
    jitter_until(f, period, NO_JITTER, SLIDING, stop_ch)
}
/// Runs `f` every `period` (measured from the start of each run) until `stop_ch` fires.
pub fn non_sliding_until<F>(f: F, period: Duration, stop_ch: Receiver<bool>)
where
    F: Fn() -> (),
{
    const NO_JITTER: f64 = 0.0;
    const SLIDING: bool = false;
    jitter_until(f, period, NO_JITTER, SLIDING, stop_ch)
}
/// Runs `f` periodically until `stop_ch` fires, waiting `period` between runs.
///
/// A positive `jitter_factor` stretches each wait randomly via [`jitter`]. `sliding` selects
/// whether the wait is measured from the end (`true`) or the start (`false`) of each run.
pub fn jitter_until<F>(
    f: F,
    period: Duration,
    jitter_factor: f64,
    sliding: bool,
    stop_ch: Receiver<bool>,
) where
    F: Fn() -> (),
{
    let manager = JitteredBackoffManager::new_jittered_backoff_manager(period, jitter_factor);
    backoff_until(f, manager, sliding, stop_ch)
}
/// Repeatedly runs `f`, waiting for the delay supplied by `backoff` between runs, until
/// `stop_ch` fires (a message arrives or its sender is dropped).
///
/// Exactly one delay is requested from `backoff` per iteration:
/// * `sliding == true`: the delay is requested after `f` returns, so `f`'s run time is not
///   counted toward it.
/// * `sliding == false`: the delay is requested before `f` runs, so `f`'s run time is
///   counted toward it.
///
/// Requesting a single delay matters for stateful managers such as
/// `ExponentialBackoffManager`, where each request advances the backoff sequence.
pub fn backoff_until<F>(
    f: F,
    mut backoff: Box<dyn BackoffManager>,
    sliding: bool,
    stop_ch: Receiver<bool>,
) where
    F: Fn() -> (),
{
    loop {
        // Non-blocking check so a stop requested before this iteration skips `f` entirely.
        select! {
            recv(stop_ch) -> _ => return,
            default => {}
        }

        let timer = if sliding {
            f();
            backoff.backoff()
        } else {
            let timer = backoff.backoff();
            f();
            timer
        };

        select! {
            recv(stop_ch) -> _ => return,
            recv(timer) -> _ => {}
        }
    }
}
/// Returns `duration` plus a random extra amount in `[0, max_factor * duration)`.
///
/// A non-positive `max_factor` is treated as `1.0`, giving a result in
/// `[duration, 2 * duration)`. The result saturates at `u64::MAX` nanoseconds.
pub fn jitter(duration: Duration, max_factor: f64) -> Duration {
    let factor = if max_factor <= 0.0 { 1.0 } else { max_factor };
    // `gen::<f64>()` draws uniformly from [0, 1), which bounds the extra delay by
    // `factor * duration`. An integer draw would scale the delay by up to ~1.8e19.
    let extra = rand::thread_rng().gen::<f64>() * factor;
    let nanos = duration.as_nanos() as f64 * (1.0 + extra);
    Duration::from_nanos(nanos as u64)
}
/// Error returned when a wait or poll gives up before its condition reports success.
#[derive(Debug, Clone)]
struct WaitTimeoutError;

impl fmt::Display for WaitTimeoutError {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        f.write_str("timed out waiting for the condition")
    }
}

/// Lets the timeout be returned as a `Box<dyn Error>`.
impl Error for WaitTimeoutError {}
/// Evaluates `condition` once and returns its result unchanged.
///
/// Despite the name, no panic recovery is performed: a panic inside `condition` propagates
/// to the caller.
fn run_condition_with_crash_protection<F>(condition: F) -> std::result::Result<bool, Box<dyn Error>>
where
    F: Fn() -> std::result::Result<bool, Box<dyn Error>> + Copy,
{
    (condition)()
}
/// Parameters and mutable state of an exponential backoff sequence; see `Backoff::step`.
///
/// The fields are public so that code outside this module can build a value to pass to
/// `exponential_backoff`.
#[derive(Debug, Clone)]
pub struct Backoff {
    /// Delay returned by the next `step`; grows by `factor` as steps are consumed.
    pub duration: Duration,
    /// Multiplier applied to `duration` after each step; `0.0` keeps it constant.
    pub factor: f64,
    /// When positive, the jitter factor passed to `jitter` for each returned delay.
    pub jitter: f64,
    /// Remaining steps during which `duration` may still change.
    pub steps: i32,
    /// Upper bound for `duration`; zero means uncapped. Exceeding it clamps `duration` and
    /// sets `steps` to 0.
    pub cap: Duration,
}
impl Backoff {
    /// Returns the next wait duration and advances the backoff state.
    ///
    /// While `steps >= 1`, each call consumes one step, returns the current `duration`, and
    /// (if `factor` is non-zero) multiplies `duration` by `factor` for the next call. If a
    /// non-zero `cap` is exceeded, `duration` is clamped to `cap` and `steps` drops to 0.
    /// Once `steps < 1`, the current `duration` is returned without changing state.
    /// A positive `jitter` is applied to the returned value via `jitter`.
    pub fn step(&mut self) -> Duration {
        let with_jitter = |base: Duration, factor: f64| {
            if factor > 0.0 {
                jitter(base, factor)
            } else {
                base
            }
        };

        // Exhausted: keep returning the current duration.
        if self.steps < 1 {
            return with_jitter(self.duration, self.jitter);
        }

        self.steps -= 1;
        let current = self.duration;

        if self.factor != 0.0 {
            let grown_nanos = self.duration.as_nanos() as f64 * self.factor;
            self.duration = Duration::from_nanos(grown_nanos as u64);
            // A zero cap means "uncapped"; hitting a non-zero cap ends further growth.
            let over_cap = self.cap.as_nanos() != 0 && self.duration > self.cap;
            if over_cap {
                self.duration = self.cap;
                self.steps = 0;
            }
        }

        with_jitter(current, self.jitter)
    }
}
/// Source of wait timers for `backoff_until`.
pub trait BackoffManager {
    /// Computes the next delay (possibly advancing internal state) and returns a channel
    /// that delivers an `Instant` when the caller may proceed. The implementations in this
    /// file return a crossbeam `tick` channel for the computed delay.
    fn backoff(&mut self) -> Receiver<Instant>;
}
/// A `BackoffManager` whose delay grows geometrically and restarts after a quiet period.
///
/// Delays come from an internal `Backoff` with `i32::MAX` steps. If more than
/// `backoff_reset_duration` elapses between two delay requests, the sequence restarts at
/// `initial_backoff`.
pub struct ExponentialBackoffManager {
    backoff: Backoff,
    last_backoff_start: Instant,
    initial_backoff: Duration,
    backoff_reset_duration: Duration,
}

impl BackoffManager for ExponentialBackoffManager {
    fn backoff(&mut self) -> Receiver<Instant> {
        let delay = self.get_next_backoff();
        tick(delay)
    }
}

impl ExponentialBackoffManager {
    /// Builds a boxed exponential backoff manager.
    ///
    /// * `init_backoff`: first delay, and the value restored after a reset.
    /// * `max_backoff`: cap on the delay; zero means uncapped.
    /// * `reset_duration`: idle time after which the sequence restarts.
    /// * `backoff_factor`: multiplier applied after each delay.
    /// * `jitter`: when positive, the jitter factor applied to each delay.
    pub fn new_exponential_backoff_manager(
        init_backoff: Duration,
        max_backoff: Duration,
        reset_duration: Duration,
        backoff_factor: f64,
        jitter: f64,
    ) -> Box<dyn BackoffManager> {
        let backoff = Backoff {
            duration: init_backoff,
            factor: backoff_factor,
            jitter,
            steps: std::i32::MAX,
            cap: max_backoff,
        };
        let manager = ExponentialBackoffManager {
            backoff,
            initial_backoff: init_backoff,
            last_backoff_start: Instant::now(),
            backoff_reset_duration: reset_duration,
        };
        Box::new(manager)
    }

    /// Returns the next delay, first restarting the sequence if the previous request was
    /// more than `backoff_reset_duration` ago.
    fn get_next_backoff(&mut self) -> Duration {
        if self.last_backoff_start.elapsed() > self.backoff_reset_duration {
            self.backoff.steps = std::i32::MAX;
            self.backoff.duration = self.initial_backoff;
        }
        self.last_backoff_start = Instant::now();
        self.backoff.step()
    }
}
/// A `BackoffManager` that always waits `duration`, optionally stretched by `jitter`.
pub struct JitteredBackoffManager {
    duration: Duration,
    jitter: f64,
}

impl BackoffManager for JitteredBackoffManager {
    fn backoff(&mut self) -> Receiver<Instant> {
        let delay = self.get_next_backoff();
        tick(delay)
    }
}

impl JitteredBackoffManager {
    /// Builds a boxed manager with a fixed base `duration`. A positive `jitter` adds random
    /// slack to each delay.
    pub fn new_jittered_backoff_manager(
        duration: Duration,
        jitter: f64,
    ) -> Box<dyn BackoffManager> {
        Box::new(JitteredBackoffManager { duration, jitter })
    }

    /// Returns the base duration, jittered only when the jitter factor is positive.
    fn get_next_backoff(&self) -> Duration {
        match self.jitter {
            factor if factor > 0.0 => jitter(self.duration, factor),
            _ => self.duration,
        }
    }
}
/// Evaluates `condition` while `backoff.steps` is positive, sleeping `backoff.step()`
/// between attempts. No sleep follows the final attempt.
///
/// # Errors
/// Propagates any error from `condition`. Returns `WaitTimeoutError` when the steps run
/// out before `condition` returns `Ok(true)`, including when `steps` starts at or below 0.
pub fn exponential_backoff<F>(backoff: &mut Backoff, condition: F) -> Result
where
    F: Fn() -> std::result::Result<bool, Box<dyn Error>> + Copy,
{
    loop {
        if backoff.steps <= 0 {
            break;
        }
        let satisfied = run_condition_with_crash_protection(condition)?;
        if satisfied {
            return Ok(());
        }
        // That was the last permitted attempt; report the timeout without sleeping.
        if backoff.steps == 1 {
            break;
        }
        let pause = backoff.step();
        thread::sleep(pause);
    }
    Err(Box::new(WaitTimeoutError))
}
/// Checks `condition` every `interval` until it returns `Ok(true)`, returns an error, or
/// `timeout` elapses. The first check happens after one `interval`. A zero `timeout`
/// disables the deadline.
///
/// # Errors
/// Propagates the condition's error. Returns `WaitTimeoutError` if the deadline passes and
/// a final check still fails.
pub fn poll<F>(interval: Duration, timeout: Duration, condition: F) -> Result
where
    F: Fn() -> std::result::Result<bool, Box<dyn Error>> + Copy,
{
    let ticks = poller(interval, timeout);
    poll_internal(ticks, condition)
}
/// Runs `wait_for` with `wait`'s ticks and a stop channel that is never signalled.
fn poll_internal<F>(wait: Box<dyn Fn(Receiver<bool>) -> Receiver<bool>>, condition: F) -> Result
where
    F: Fn() -> std::result::Result<bool, Box<dyn Error>> + Copy,
{
    // The sender stays alive until `wait_for` returns, so `no_stop` never fires.
    let (_held_sender, no_stop) = unbounded();
    wait_for(wait, condition, no_stop)
}
/// Like `poll`, but checks `condition` once right away before waiting for the first tick.
pub fn poll_immediate<F>(interval: Duration, timeout: Duration, condition: F) -> Result
where
    F: Fn() -> std::result::Result<bool, Box<dyn Error>> + Copy,
{
    let ticks = poller(interval, timeout);
    poll_immediate_internal(ticks, condition)
}
/// Checks `condition` immediately, falling back to `poll_internal` if it is not yet true.
fn poll_immediate_internal<F>(
    wait: Box<dyn Fn(Receiver<bool>) -> Receiver<bool>>,
    condition: F,
) -> Result
where
    F: Fn() -> std::result::Result<bool, Box<dyn Error>> + Copy,
{
    let already_done = run_condition_with_crash_protection(condition)?;
    if !already_done {
        return poll_internal(wait, condition);
    }
    Ok(())
}
/// Checks `condition` every `interval` with no deadline, until it returns `Ok(true)` or an
/// error.
pub fn poll_infinite<F>(interval: Duration, condition: F) -> Result
where
    F: Fn() -> std::result::Result<bool, Box<dyn Error>> + Copy,
{
    // Holding the sender for the whole call keeps the stop channel silent.
    let (_held_sender, no_stop) = unbounded();
    poll_until(interval, condition, no_stop)
}
/// Like `poll_infinite`, but checks `condition` once right away.
pub fn poll_immediate_infinite<F>(interval: Duration, condition: F) -> Result
where
    F: Fn() -> std::result::Result<bool, Box<dyn Error>> + Copy,
{
    let already_done = run_condition_with_crash_protection(condition)?;
    if !already_done {
        return poll_infinite(interval, condition);
    }
    Ok(())
}
/// Checks `condition` every `interval` until it returns `Ok(true)`, returns an error, or
/// `stop_ch` fires (a message arrives or its sender is dropped). There is no timeout.
pub fn poll_until<F>(interval: Duration, condition: F, stop_ch: Receiver<bool>) -> Result
where
    F: Fn() -> std::result::Result<bool, Box<dyn Error>> + Copy,
{
    // A zero timeout tells `poller` not to impose a deadline.
    let no_timeout = Duration::from_secs(0);
    wait_for(poller(interval, no_timeout), condition, stop_ch)
}
pub fn poll_immediate_until<F>(interval: Duration, condition: F, stop_ch: Receiver<bool>) -> Result
where
F: Fn() -> std::result::Result<bool, Box<dyn Error>> + Copy,
{
let done = condition()?;
if done {
return Ok(());
}
select! {
recv(stop_ch) -> _ => return Err(Box::new(WaitTimeoutError)) ,
default => return poll_until(interval, condition, stop_ch)
}
}
/// Evaluates `func` each time the channel produced by `wait` yields, until `func` returns
/// `Ok(true)`, returns an error, or `done` fires.
///
/// When the wait channel closes, `func` gets one final check before the timeout is
/// reported. On return, the stop channel handed to `wait` is disconnected so its producer
/// can exit.
///
/// # Errors
/// Propagates `func`'s error. Returns `WaitTimeoutError` when the wait channel closes with
/// `func` still false, or when `done` fires.
pub fn wait_for<F>(
    wait: Box<dyn Fn(Receiver<bool>) -> Receiver<bool>>,
    func: F,
    done: Receiver<bool>,
) -> Result
where
    F: Fn() -> std::result::Result<bool, Box<dyn Error>> + Copy,
{
    let (stop_tx, stop_rx) = unbounded();
    let signals = wait(stop_rx);
    // Dropping the sender on every exit path disconnects `stop_rx`, stopping the producer.
    defer! { drop(stop_tx); };
    loop {
        select! {
            recv(signals) -> signal => {
                let satisfied = run_condition_with_crash_protection(func)?;
                if satisfied {
                    return Ok(());
                }
                // The producer closed its channel: the final check has just failed.
                if signal.is_err() {
                    return Err(Box::new(WaitTimeoutError));
                }
            },
            recv(done) -> _ => return Err(Box::new(WaitTimeoutError)),
        }
    }
}
fn poller(interval: Duration, timeout: Duration) -> Box<dyn Fn(Receiver<bool>) -> Receiver<bool>> {
let func = move |done: Receiver<bool>| -> Receiver<bool> {
let (s, r) = unbounded();
let rr = r.clone();
thread::spawn(move || {
let ticker = tick(interval);
let mut after = tick(Duration::from_secs(1000000000));
if !(timeout.as_nanos() == 0) {
after = tick(timeout);
}
loop {
select! {
recv(ticker) -> _ => {
s.send(true).unwrap();
},
recv(after) -> _ => {
return
},
recv(done) -> _ => {
return
},
}
}
});
rr
};
Box::new(func)
}
#[cfg(test)]
mod tests {
    // Several of these tests depend on wall-clock timing (sleeps of tens to hundreds of
    // milliseconds), so they can be sensitive to heavily loaded machines.
    use super::*;
    use std::sync::{Arc, Mutex};

    /// `poll` succeeds once the condition reads the value 3 from the channel, and times out
    /// when the condition never returns `true`.
    #[test]
    fn test_poll() {
        let (s, r) = unbounded();
        let cond_fn = || {
            select! {
                recv(r) -> msg => {
                    if msg.unwrap() == 3 {
                        return Ok(true);
                    } else {
                        return Ok(false);
                    }
                },
                default => {
                    return Ok(false)
                }
                ,
            }
        };
        thread::spawn(move || {
            for i in 1..4 {
                thread::sleep(Duration::from_millis(20));
                s.send(i).unwrap();
            }
            drop(s);
            println!("sender dropped");
        });
        let ret = poll(
            Duration::from_millis(100),
            Duration::from_millis(300),
            cond_fn,
        );
        assert!(ret.is_ok());
        let cond_fn = || Ok(false);
        let ret = poll(
            Duration::from_millis(100),
            Duration::from_millis(300),
            cond_fn,
        );
        assert!(ret.is_err());
    }

    /// `until` keeps running the worker until the stop sender is dropped.
    #[test]
    fn test_until() {
        let (s, r) = unbounded();
        let counter = Arc::new(Mutex::new(0));
        let counter1 = counter.clone();
        let worker_fn = move || {
            let mut counter = counter1.lock().unwrap();
            *counter += 1;
            println!("do work {}", *counter);
        };
        let counter2 = counter.clone();
        std::thread::spawn(move || loop {
            {
                let counter = counter2.lock().unwrap();
                if *counter > 4 {
                    drop(s);
                    println!("sender dropped");
                    return;
                }
            }
            thread::sleep(Duration::from_millis(10));
        });
        until(worker_fn, Duration::from_millis(10), r);
        let counter = counter.lock().unwrap();
        println!("final counter {}", *counter);
        assert!(*counter > 4);
    }

    /// Sanity checks the library's assumptions: a zero `Duration` has zero nanoseconds, and
    /// an unbounded channel delivers every message sent.
    #[test]
    fn test_zero_duration_and_channel_roundtrip() {
        let zero_seconds = Duration::new(0, 0);
        assert_eq!(0, zero_seconds.as_nanos());
        let (s, r) = unbounded();
        thread::spawn(move || {
            s.send(1).unwrap();
            s.send(2).unwrap();
        });
        let msg1 = r.recv().unwrap();
        let msg2 = r.recv().unwrap();
        assert_eq!(msg1 + msg2, 3);
    }

    /// `Group::wait` returns only after both a plain worker and a channel-driven worker have
    /// run to completion.
    #[test]
    fn test_groups() {
        let counter = Arc::new(Mutex::new(0));
        let counter1 = counter.clone();
        let worker_fn1 = move || {
            let mut counter = counter1.lock().unwrap();
            *counter += 1;
            thread::sleep(Duration::from_millis(50));
            println!("worker 1 finished");
        };
        let (s, r) = unbounded();
        let counter2 = counter.clone();
        let worker_fn2 = move |x| {
            let mut counter = counter2.lock().unwrap();
            *counter += 1;
            select! {
                recv(x) -> _ => {
                    println!("worker 2 finished");
                }
            }
        };
        thread::spawn(move || {
            thread::sleep(Duration::from_millis(300));
            drop(s);
            println!("notify worker 2 to finish");
        });
        let group = Group::new();
        group.start(worker_fn1);
        println!("worker 1 started");
        group.start_with_channel(r, worker_fn2);
        println!("worker 2 started");
        println!("wait two workers to finish");
        group.wait();
        println!("two workers are finished");
        assert_eq!(2, *counter.lock().unwrap(), "both workers should have run");
    }

    /// Each successive exponential backoff waits at least the expected number of
    /// milliseconds: doubling from 1 ms and capped at 10 ms.
    #[test]
    fn test_exponential_backoff_manager() {
        let duration_factors = vec![1, 2, 4, 8, 10, 10, 10];
        let mut backoff_mgr = ExponentialBackoffManager::new_exponential_backoff_manager(
            Duration::from_millis(1),
            Duration::from_millis(10),
            Duration::from_secs(3600),
            2.0,
            0.0,
        );
        for i in duration_factors {
            let start = Instant::now();
            let r = backoff_mgr.backoff();
            select! {
                recv(r) -> _ => {},
            }
            let passed = Instant::now().duration_since(start).as_millis();
            assert!(
                passed >= i,
                "backoff should be at least {} ms, but got {}",
                i,
                passed
            );
        }
    }

    /// A jittered manager with zero jitter waits at least its base duration every time.
    #[test]
    fn test_jitter_backoff_manager_with_real_clock() {
        let mut backoff_mgr =
            JitteredBackoffManager::new_jittered_backoff_manager(Duration::from_millis(1), 0.0);
        for _ in 0..5 {
            let start = Instant::now();
            let r = backoff_mgr.backoff();
            select! {
                recv(r) -> _ => {},
            }
            let passed = Instant::now().duration_since(start).as_millis();
            assert!(
                passed >= 1,
                "backoff should be at least 1ms, but got {}",
                passed
            );
        }
    }
}