use crate::*;
use std::sync::atomic::{AtomicU64, Ordering};
/// Backoff helper that sleeps for growing intervals, bounded by a
/// [`KitsuneTimeout`] deadline and a per-sleep maximum.
///
/// The current interval lives behind an `Arc`, so cloning a
/// `KitsuneBackoff` yields a handle that shares (and advances) the same
/// interval as the original.
#[derive(Debug, Clone)]
pub struct KitsuneBackoff {
// Deadline consulted by `wait` so a sleep is limited by the time remaining.
timeout: KitsuneTimeout,
// Interval (milliseconds) the next `wait` will start from; shared between clones.
cur_ms: Arc<AtomicU64>,
// Upper bound (milliseconds) applied to every individual sleep.
max_ms: u64,
}
impl KitsuneBackoff {
    /// Create a backoff bound to `timeout`.
    ///
    /// * `initial_ms` - interval, in milliseconds, used by the first call to
    ///   [`KitsuneBackoff::wait`]. An initial value of `0` never grows, so
    ///   such a backoff never sleeps.
    /// * `max_ms` - upper bound, in milliseconds, for any single sleep.
    ///
    /// No validation is done here: an `initial_ms` larger than `max_ms` is
    /// simply clamped when waiting.
    pub fn new(timeout: KitsuneTimeout, initial_ms: u64, max_ms: u64) -> Self {
        Self {
            timeout,
            cur_ms: Arc::new(AtomicU64::new(initial_ms)),
            max_ms,
        }
    }

    /// Sleep for the current interval, then double the interval for the
    /// next call.
    ///
    /// The actual sleep is the smallest of: the current interval, `max_ms`,
    /// and the time remaining on the timeout plus one millisecond. If that
    /// works out to zero, this returns without sleeping.
    pub async fn wait(&self) {
        // Read the interval and double it in ONE atomic step. A separate
        // load followed by `fetch_add` lets two clones race, and the
        // wrapping add eventually rolls the interval over to 0, after which
        // every wait returns immediately (a busy loop). Saturating keeps the
        // stored value pinned at `u64::MAX` instead.
        let cur = self
            .cur_ms
            .fetch_update(Ordering::Relaxed, Ordering::Relaxed, |ms| {
                Some(ms.saturating_mul(2))
            })
            // The closure always returns `Some`, so `Err` cannot occur; it
            // carries the same previous value either way.
            .unwrap_or_else(|ms| ms);

        // Clamp in u128 so the millisecond count is never truncated and the
        // `+ 1` cannot overflow.
        // NOTE(review): the extra millisecond presumably lets the final
        // sleep end just past the deadline rather than just before it.
        let remaining_ms = self.timeout.time_remaining().as_millis();
        let limit_ms = u128::from(self.max_ms).min(remaining_ms.saturating_add(1));
        // The result is <= `cur`, so narrowing back to u64 is lossless.
        let sleep_ms = u128::from(cur).min(limit_ms) as u64;

        if sleep_ms > 0 {
            tokio::time::sleep(std::time::Duration::from_millis(sleep_ms)).await;
        }
    }
}
/// A deadline, stored as the `tokio::time::Instant` at which it expires.
///
/// The type is `Copy`, so it can be handed to several operations that
/// should all respect the same deadline.
#[derive(Debug, Clone, Copy)]
pub struct KitsuneTimeout(tokio::time::Instant);
impl KitsuneTimeout {
    /// Create a timeout that expires `duration` from now.
    ///
    /// # Panics
    ///
    /// Panics if adding `duration` to the current instant overflows.
    pub fn new(duration: std::time::Duration) -> Self {
        let deadline = tokio::time::Instant::now()
            .checked_add(duration)
            .unwrap();
        Self(deadline)
    }

    /// Create a timeout that expires `millis` milliseconds from now.
    pub fn from_millis(millis: u64) -> Self {
        let duration = std::time::Duration::from_millis(millis);
        Self::new(duration)
    }

    /// Build a [`KitsuneBackoff`] bound to a copy of this timeout.
    pub fn backoff(&self, initial_ms: u64, max_ms: u64) -> KitsuneBackoff {
        let deadline = *self;
        KitsuneBackoff::new(deadline, initial_ms, max_ms)
    }

    /// Time left before the deadline; zero once the deadline has passed.
    pub fn time_remaining(&self) -> std::time::Duration {
        let deadline = self.0;
        let now = tokio::time::Instant::now();
        deadline.saturating_duration_since(now)
    }

    /// `true` once the current instant has reached or passed the deadline.
    pub fn is_expired(&self) -> bool {
        tokio::time::Instant::now() >= self.0
    }

    /// Return `Ok(())` while the deadline is still ahead, otherwise a
    /// `TimedOut` error carrying `ctx`.
    pub fn ok(&self, ctx: &str) -> KitsuneResult<()> {
        if !self.is_expired() {
            return Ok(());
        }
        Err(KitsuneErrorKind::TimedOut(ctx.into()).into())
    }

    /// Wrap future `f` so it resolves to a `TimedOut` error (carrying
    /// `ctx`) if it has not completed within the time remaining.
    ///
    /// The remaining time is measured when `mix` is called, not when the
    /// returned future is first polled.
    pub fn mix<'a, 'b, R, F>(
        &'a self,
        ctx: &str,
        f: F,
    ) -> impl std::future::Future<Output = KitsuneResult<R>> + 'b + Send
    where
        R: 'b,
        F: std::future::Future<Output = KitsuneResult<R>> + 'b + Send,
    {
        // Take owned copies up front so the returned future borrows
        // neither `self` nor `ctx`.
        let budget = self.time_remaining();
        let ctx = ctx.to_owned();
        async move {
            tokio::time::timeout(budget, f)
                .await
                .unwrap_or_else(|_| Err(KitsuneErrorKind::TimedOut(ctx).into()))
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly created timeout has time left and is not expired.
    #[test]
    fn basic_kitsune_timeout() {
        let timeout = KitsuneTimeout::new(std::time::Duration::from_millis(40));
        let remaining_ms = timeout.time_remaining().as_millis();
        assert!(remaining_ms > 0);
        assert!(!timeout.is_expired());
    }

    /// After sleeping past the deadline, no time remains and the timeout
    /// reports itself expired.
    #[tokio::test]
    async fn expired_kitsune_timeout() {
        let timeout = KitsuneTimeout::new(std::time::Duration::from_millis(1));
        tokio::time::sleep(std::time::Duration::from_millis(2)).await;
        let remaining_us = timeout.time_remaining().as_micros();
        assert!(remaining_us == 0);
        assert!(timeout.is_expired());
    }

    /// Waiting in a loop until a 100 ms timeout expires, with a 2 ms
    /// initial / 15 ms maximum backoff, takes more than 4 and fewer than
    /// 20 iterations.
    #[tokio::test(flavor = "multi_thread")]
    async fn kitsune_backoff() {
        let timeout = KitsuneTimeout::from_millis(100);
        let mut wake_times: Vec<u64> = Vec::new();
        let started = tokio::time::Instant::now();
        let backoff = timeout.backoff(2, 15);
        loop {
            if timeout.is_expired() {
                break;
            }
            // Record when each iteration begins, relative to the start.
            wake_times.push(started.elapsed().as_millis() as u64);
            backoff.wait().await;
        }
        println!("{:?}", wake_times);
        let rounds = wake_times.len();
        assert!(rounds > 4);
        assert!(rounds < 20);
    }
}