use anyhow::{Context, Result};
use std::future::Future;
use std::time::Duration;
use tokio::io::AsyncReadExt;
use tokio::time::timeout;
/// Deadline used by [`with_default_timeout`]: 30 seconds.
pub const DEFAULT_ASYNC_TIMEOUT: Duration = Duration::from_secs(30);

/// Deadline used by [`with_short_timeout`]: 5 seconds.
pub const SHORT_ASYNC_TIMEOUT: Duration = Duration::from_secs(5);

/// Deadline used by [`with_long_timeout`]: 5 minutes (300 seconds).
pub const LONG_ASYNC_TIMEOUT: Duration = Duration::from_secs(5 * 60);
/// Awaits `fut`, giving up once `duration` has elapsed.
///
/// On success the output of `fut` is returned unchanged, so a future that is
/// itself fallible produces a nested result that the caller must unwrap.
///
/// # Errors
///
/// Returns an error naming the exceeded `duration` and the supplied `context`
/// when `fut` does not complete in time.
pub async fn with_timeout<F, T>(fut: F, duration: Duration, context: &str) -> Result<T>
where
    F: Future<Output = T>,
{
    timeout(duration, fut)
        .await
        .map_err(|_elapsed| anyhow::anyhow!("Operation timed out after {duration:?}: {context}"))
}
/// Awaits `fut` under [`DEFAULT_ASYNC_TIMEOUT`].
///
/// # Errors
///
/// Fails with the timeout error produced by [`with_timeout`], tagged with
/// `context`, if `fut` does not finish before the deadline.
pub async fn with_default_timeout<F: Future<Output = T>, T>(fut: F, context: &str) -> Result<T> {
    with_timeout(fut, DEFAULT_ASYNC_TIMEOUT, context).await
}
/// Awaits `fut` under [`SHORT_ASYNC_TIMEOUT`].
///
/// # Errors
///
/// Fails with the timeout error produced by [`with_timeout`], tagged with
/// `context`, if `fut` does not finish before the deadline.
pub async fn with_short_timeout<F: Future<Output = T>, T>(fut: F, context: &str) -> Result<T> {
    with_timeout(fut, SHORT_ASYNC_TIMEOUT, context).await
}
/// Awaits `fut` under [`LONG_ASYNC_TIMEOUT`].
///
/// # Errors
///
/// Fails with the timeout error produced by [`with_timeout`], tagged with
/// `context`, if `fut` does not finish before the deadline.
pub async fn with_long_timeout<F: Future<Output = T>, T>(fut: F, context: &str) -> Result<T> {
    with_timeout(fut, LONG_ASYNC_TIMEOUT, context).await
}
pub async fn retry_with_backoff<F, Fut, T>(
mut op: F,
max_retries: usize,
initial_delay: Duration,
context: &str,
) -> Result<T>
where
F: FnMut() -> Fut,
Fut: Future<Output = Result<T>>,
{
let mut delay = initial_delay;
let mut last_error = None;
for i in 0..=max_retries {
match op().await {
Ok(result) => return Ok(result),
Err(e) => {
last_error = Some(e);
if i < max_retries {
tokio::time::sleep(delay).await;
delay *= 2;
}
}
}
}
let err = last_error.unwrap_or_else(|| anyhow::anyhow!("Retry failed without error"));
Err(err).with_context(|| format!("Operation failed after {max_retries} retries: {context}"))
}
/// Suspends the current task for `duration` using `tokio::time::sleep`.
///
/// `_context` is accepted but not used by the current implementation.
pub async fn sleep_with_context(duration: Duration, _context: &str) {
    let pause = tokio::time::sleep(duration);
    pause.await;
}
/// Awaits every future in `futs` via `futures::future::join_all` under one
/// shared deadline.
///
/// `duration` bounds the whole batch, not each future individually.
///
/// # Errors
///
/// Fails with the timeout error produced by [`with_timeout`], tagged with
/// `context`, if the batch as a whole does not finish within `duration`.
pub async fn join_all_with_timeout<F, T>(
    futs: Vec<F>,
    duration: Duration,
    context: &str,
) -> Result<Vec<T>>
where
    F: Future<Output = T>,
{
    let batch = futures::future::join_all(futs);
    with_timeout(batch, duration, context).await
}
/// Reads exactly `len` bytes from `reader` into a freshly allocated `Vec<u8>`.
///
/// The buffer is reserved up front and filled through `read_buf`, which
/// appends into the vector's spare capacity rather than into a pre-filled
/// slice. A `len` of zero returns an empty vector without touching `reader`.
///
/// # Errors
///
/// Returns `UnexpectedEof` if `reader` reports end-of-stream before `len`
/// bytes have arrived, and propagates any I/O error raised by `reader`.
pub async fn read_exact_uninit<R>(reader: &mut R, len: usize) -> std::io::Result<Vec<u8>>
where
    R: tokio::io::AsyncRead + Unpin,
{
    let mut buf = Vec::with_capacity(len);
    while buf.len() < len {
        // Cap each read at the bytes still missing. `Vec::with_capacity` may
        // reserve more than `len`, and an uncapped `read_buf` would fill that
        // extra room, returning too many bytes and consuming data that
        // belongs to whatever follows in the stream.
        let remaining = u64::try_from(len - buf.len()).unwrap_or(u64::MAX);
        let n = (&mut *reader).take(remaining).read_buf(&mut buf).await?;
        if n == 0 {
            return Err(std::io::Error::new(
                std::io::ErrorKind::UnexpectedEof,
                format!(
                    "unexpected EOF before reading {len} bytes (got {})",
                    buf.len()
                ),
            ));
        }
    }
    Ok(buf)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A payload that fits in one read comes back byte-for-byte.
    #[tokio::test]
    async fn read_exact_uninit_round_trips_known_payload() {
        let payload: Vec<u8> = (0..64u8).collect();
        let mut reader = std::io::Cursor::new(payload.clone());
        let got = read_exact_uninit(&mut reader, payload.len())
            .await
            .expect("read full payload");
        assert_eq!(got, payload);
    }

    /// The helper keeps reading until the requested length has been assembled.
    #[tokio::test]
    async fn read_exact_uninit_reads_across_multiple_poll_reads() {
        let payload: Vec<u8> = (0..2000u32).map(|i| (i % 256) as u8).collect();
        // A lone `Cursor` can satisfy the whole request in one read, which would
        // never exercise the loop. Chaining two cursors splits the payload: the
        // second half is only served after the first cursor is drained, so the
        // helper has to issue more than one read.
        let (head, tail) = payload.split_at(payload.len() / 2);
        let mut reader = tokio::io::AsyncReadExt::chain(
            std::io::Cursor::new(head.to_vec()),
            std::io::Cursor::new(tail.to_vec()),
        );
        let got = read_exact_uninit(&mut reader, payload.len())
            .await
            .expect("read full payload");
        assert_eq!(got, payload);
    }

    /// Running out of data part-way through is reported as `UnexpectedEof`.
    #[tokio::test]
    async fn read_exact_uninit_returns_unexpected_eof_on_short_read() {
        let payload = b"only ten!".to_vec();
        let mut reader = std::io::Cursor::new(payload);
        let err = read_exact_uninit(&mut reader, 64)
            .await
            .expect_err("short read must error");
        assert_eq!(err.kind(), std::io::ErrorKind::UnexpectedEof);
    }

    /// An empty reader cannot satisfy a non-zero request.
    #[tokio::test]
    async fn read_exact_uninit_returns_unexpected_eof_on_empty_reader() {
        let mut reader = std::io::Cursor::new(Vec::<u8>::new());
        let err = read_exact_uninit(&mut reader, 1)
            .await
            .expect_err("empty reader must error");
        assert_eq!(err.kind(), std::io::ErrorKind::UnexpectedEof);
    }

    /// Asking for zero bytes succeeds even when the reader has nothing to give.
    #[tokio::test]
    async fn read_exact_uninit_zero_len_returns_empty_vec() {
        let mut reader = std::io::Cursor::new(Vec::<u8>::new());
        let got = read_exact_uninit(&mut reader, 0)
            .await
            .expect("zero-length read must succeed");
        assert!(got.is_empty());
    }
}