use std::time::{Duration, Instant};
/// Failure outcomes of the polling helpers in this module.
///
/// `E` is the error type produced by the polled closure.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum PollError<E> {
    /// The polling deadline elapsed before the condition succeeded.
    ///
    /// Returned by `poll_until`; `poll_with_timeout` reports a timeout as
    /// `Ok(None)` instead.
    Timeout,
    /// The polled closure returned an error; polling stopped immediately.
    ConditionError(E),
}
impl<E> std::fmt::Display for PollError<E>
where
E: std::fmt::Display,
{
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
PollError::Timeout => write!(f, "Operation timed out"),
PollError::ConditionError(e) => write!(f, "Condition error: {}", e),
}
}
}
impl<E> std::error::Error for PollError<E>
where
    E: std::error::Error + 'static,
{
    /// Exposes the wrapped closure error as the underlying cause.
    ///
    /// A timeout has no underlying cause, so it yields `None`.
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        match self {
            PollError::ConditionError(cause) => {
                let cause: &(dyn std::error::Error + 'static) = cause;
                Some(cause)
            }
            PollError::Timeout => None,
        }
    }
}
/// Repeatedly evaluates `condition` until it returns `Ok(true)`, fails, or `timeout` elapses.
///
/// The condition is evaluated immediately, then again after each pause of
/// `poll_interval`. Pauses are clamped to the time remaining, so the function
/// never sleeps past the deadline. Even with a `poll_interval` longer than
/// `timeout`, it returns shortly after `timeout`.
///
/// The deadline is checked before every evaluation. A zero `timeout` therefore
/// returns [`PollError::Timeout`] without calling `condition`.
///
/// This blocks the calling thread via `std::thread::sleep`.
///
/// # Errors
///
/// - [`PollError::Timeout`] if the deadline passes before `condition` returns `Ok(true)`.
/// - [`PollError::ConditionError`] wrapping the first error returned by `condition`;
///   polling stops immediately in that case.
pub fn poll_until<F, E>(
    mut condition: F,
    timeout: Duration,
    poll_interval: Duration,
) -> Result<(), PollError<E>>
where
    F: FnMut() -> Result<bool, E>,
{
    let start = Instant::now();
    loop {
        if start.elapsed() >= timeout {
            return Err(PollError::Timeout);
        }
        if condition().map_err(PollError::ConditionError)? {
            return Ok(());
        }
        // Clamp the pause to the remaining budget; sleeping a full interval
        // here could overshoot the deadline by up to `poll_interval`.
        let remaining = timeout.saturating_sub(start.elapsed());
        std::thread::sleep(poll_interval.min(remaining));
    }
}
/// Repeatedly calls `operation` until it yields a value, fails, or `timeout` elapses.
///
/// `operation` is called immediately, then again after each pause of
/// `poll_interval`. Pauses are clamped to the time remaining, so the function
/// never sleeps past the deadline.
///
/// The deadline is checked before every call. A zero `timeout` therefore
/// returns `Ok(None)` without calling `operation`.
///
/// Returns `Ok(Some(value))` with the first value produced, or `Ok(None)` if
/// the deadline passes first. Unlike [`poll_until`], a timeout is *not*
/// reported as [`PollError::Timeout`].
///
/// This blocks the calling thread via `std::thread::sleep`.
///
/// # Errors
///
/// [`PollError::ConditionError`] wrapping the first error returned by
/// `operation`; polling stops immediately in that case.
pub fn poll_with_timeout<F, T, E>(
    mut operation: F,
    timeout: Duration,
    poll_interval: Duration,
) -> Result<Option<T>, PollError<E>>
where
    F: FnMut() -> Result<Option<T>, E>,
{
    let start = Instant::now();
    loop {
        if start.elapsed() >= timeout {
            return Ok(None);
        }
        if let Some(value) = operation().map_err(PollError::ConditionError)? {
            return Ok(Some(value));
        }
        // Clamp the pause to the remaining budget; sleeping a full interval
        // here could overshoot the deadline by up to `poll_interval`.
        let remaining = timeout.saturating_sub(start.elapsed());
        std::thread::sleep(poll_interval.min(remaining));
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    // The polled closures are `FnMut` and run on the test thread, so plain
    // captured counters suffice; no `Arc`/`Mutex` is needed.

    #[test]
    fn test_poll_until_success() {
        let mut calls = 0;
        let result = poll_until(
            || {
                calls += 1;
                Ok::<bool, &str>(calls >= 3)
            },
            Duration::from_millis(500),
            Duration::from_millis(10),
        );
        assert!(result.is_ok());
        // Polling stops at the first `Ok(true)`, so exactly three calls happen.
        assert_eq!(calls, 3);
    }

    #[test]
    fn test_poll_until_timeout() {
        let result = poll_until(
            || Ok::<bool, &str>(false),
            Duration::from_millis(50),
            Duration::from_millis(10),
        );
        assert!(matches!(result, Err(PollError::Timeout)));
    }

    #[test]
    fn test_poll_until_error() {
        let result = poll_until(
            || Err::<bool, &str>("test error"),
            Duration::from_millis(100),
            Duration::from_millis(10),
        );
        assert!(matches!(result, Err(PollError::ConditionError("test error"))));
    }

    #[test]
    fn test_poll_with_timeout_success() {
        let mut calls = 0;
        let result = poll_with_timeout(
            || {
                calls += 1;
                if calls >= 3 {
                    Ok::<Option<i32>, &str>(Some(calls))
                } else {
                    Ok(None)
                }
            },
            Duration::from_millis(500),
            Duration::from_millis(10),
        );
        assert_eq!(result.ok(), Some(Some(3)));
        assert_eq!(calls, 3);
    }

    #[test]
    fn test_poll_with_timeout_timeout() {
        let result = poll_with_timeout(
            || Ok::<Option<()>, &str>(None),
            Duration::from_millis(50),
            Duration::from_millis(10),
        );
        // A timeout is reported as `Ok(None)`, not as an error.
        assert_eq!(result.ok(), Some(None));
    }

    #[test]
    fn test_poll_with_timeout_error() {
        let result = poll_with_timeout(
            || Err::<Option<()>, &str>("test error"),
            Duration::from_millis(100),
            Duration::from_millis(10),
        );
        assert!(matches!(result, Err(PollError::ConditionError("test error"))));
    }
}