use std::{future::Future, sync::Arc, time::Duration};
use futures_util::future::BoxFuture;
use crate::policy::Policy;
use crate::timeout::TimeoutError;
/// Type-erased async callback stored by the policy: a shareable (`Arc`),
/// thread-safe closure that produces a boxed `'static` future resolving to `()`.
type Hook = Arc<dyn Fn() -> BoxFuture<'static, ()> + Send + Sync>;
/// A policy that bounds how long an async operation may run and awaits
/// optional async hooks on success, failure, or timeout.
///
/// Construct one with [`Builder`] or via `TimeoutPolicy::default()`.
#[derive(Clone)]
pub struct TimeoutPolicy {
/// Deadline applied to the operation when `cancel` is `true`.
pub(crate) duration: Duration,
/// When `true`, the operation is raced against `duration` using
/// `tokio::time::timeout`; when `false`, the execution paths in this file
/// await it to completion without any deadline.
pub(crate) cancel: bool,
/// Optional label copied into `TimeoutError::Elapsed`.
pub(crate) name: Option<String>,
/// Awaited after the deadline elapses, before the timeout error is returned.
pub(crate) on_timeout: Option<Hook>,
/// Awaited after the operation returns `Ok`.
pub(crate) on_success: Option<Hook>,
/// Awaited after the operation returns `Err`.
pub(crate) on_failure: Option<Hook>,
}
/// Builder for [`TimeoutPolicy`].
///
/// `Builder::new()` starts from a 30-second timeout, `cancel = true`, no name,
/// and no hooks; each field mirrors the identically named `TimeoutPolicy` field.
pub struct Builder {
duration: Duration,
cancel: bool,
name: Option<String>,
on_timeout: Option<Hook>,
on_success: Option<Hook>,
on_failure: Option<Hook>,
}
impl Builder {
    /// Creates a builder with a 30-second timeout, `cancel` set to `true`,
    /// no name, and no hooks.
    #[must_use]
    pub fn new() -> Self {
        Self {
            duration: Duration::from_secs(30),
            cancel: true,
            name: None,
            on_timeout: None,
            on_success: None,
            on_failure: None,
        }
    }

    /// Sets the timeout to `duration`.
    #[must_use]
    pub fn with_timeout(mut self, duration: Duration) -> Self {
        self.duration = duration;
        self
    }

    /// Sets the timeout to `millis` milliseconds.
    #[must_use]
    pub fn with_timeout_millis(mut self, millis: u64) -> Self {
        self.duration = Duration::from_millis(millis);
        self
    }

    /// Sets the timeout to `secs` seconds.
    #[must_use]
    pub fn with_timeout_secs(mut self, secs: u64) -> Self {
        self.duration = Duration::from_secs(secs);
        self
    }

    /// Sets the timeout to `mins` minutes.
    ///
    /// If `mins * 60` does not fit in a `u64`, the timeout saturates to
    /// `u64::MAX` seconds instead of overflowing.
    #[must_use]
    pub fn with_timeout_minutes(mut self, mins: u64) -> Self {
        // A plain `mins * 60` panics in debug builds and wraps in release
        // builds, where it would silently produce a much shorter timeout.
        self.duration = Duration::from_secs(mins.saturating_mul(60));
        self
    }

    /// Sets the timeout to `hours` hours.
    ///
    /// If `hours * 3600` does not fit in a `u64`, the timeout saturates to
    /// `u64::MAX` seconds instead of overflowing.
    #[must_use]
    pub fn with_timeout_hours(mut self, hours: u64) -> Self {
        self.duration = Duration::from_secs(hours.saturating_mul(3600));
        self
    }

    /// Chooses whether the operation is raced against the deadline (`true`)
    /// or awaited to completion with no deadline (`false`).
    #[must_use]
    pub fn with_cancel(mut self, cancel: bool) -> Self {
        self.cancel = cancel;
        self
    }

    /// Sets the name reported in `TimeoutError::Elapsed` when the deadline
    /// elapses.
    #[must_use]
    pub fn with_name(mut self, name: impl ToString) -> Self {
        self.name = Some(name.to_string());
        self
    }

    /// Registers an async hook awaited after the deadline elapses, before the
    /// timeout error is returned.
    #[must_use]
    pub fn with_on_timeout<F, Fut>(mut self, f: F) -> Self
    where
        F: Fn() -> Fut + Send + Sync + 'static,
        Fut: Future<Output = ()> + Send + 'static,
    {
        self.on_timeout = Some(Self::box_hook(f));
        self
    }

    /// Registers an async hook awaited after the operation returns `Ok`.
    #[must_use]
    pub fn with_on_success<F, Fut>(mut self, f: F) -> Self
    where
        F: Fn() -> Fut + Send + Sync + 'static,
        Fut: Future<Output = ()> + Send + 'static,
    {
        self.on_success = Some(Self::box_hook(f));
        self
    }

    /// Registers an async hook awaited after the operation returns `Err`.
    #[must_use]
    pub fn with_on_failure<F, Fut>(mut self, f: F) -> Self
    where
        F: Fn() -> Fut + Send + Sync + 'static,
        Fut: Future<Output = ()> + Send + 'static,
    {
        self.on_failure = Some(Self::box_hook(f));
        self
    }

    /// Finalizes the configuration into a [`TimeoutPolicy`].
    #[must_use]
    pub fn build(self) -> TimeoutPolicy {
        TimeoutPolicy {
            duration: self.duration,
            cancel: self.cancel,
            name: self.name,
            on_timeout: self.on_timeout,
            on_success: self.on_success,
            on_failure: self.on_failure,
        }
    }

    /// Type-erases a user callback into the shared, boxed-future form stored
    /// in the hook fields.
    fn box_hook<F, Fut>(f: F) -> Arc<dyn Fn() -> BoxFuture<'static, ()> + Send + Sync>
    where
        F: Fn() -> Fut + Send + Sync + 'static,
        Fut: Future<Output = ()> + Send + 'static,
    {
        Arc::new(move || -> BoxFuture<'static, ()> { Box::pin(f()) })
    }
}
impl TimeoutPolicy {
    /// Returns this policy with its timeout replaced by `duration`.
    #[must_use]
    pub fn with_timeout(mut self, duration: Duration) -> Self {
        self.duration = duration;
        self
    }

    /// Returns this policy with its timeout set to `millis` milliseconds.
    #[must_use]
    pub fn with_timeout_millis(mut self, millis: u64) -> Self {
        self.duration = Duration::from_millis(millis);
        self
    }

    /// Returns this policy with its timeout set to `secs` seconds.
    #[must_use]
    pub fn with_timeout_secs(mut self, secs: u64) -> Self {
        self.duration = Duration::from_secs(secs);
        self
    }

    /// Returns this policy with its timeout set to `mins` minutes.
    ///
    /// If `mins * 60` does not fit in a `u64`, the timeout saturates to
    /// `u64::MAX` seconds instead of overflowing.
    #[must_use]
    pub fn with_timeout_minutes(mut self, mins: u64) -> Self {
        // A plain `mins * 60` panics in debug builds and wraps in release
        // builds, where it would silently produce a much shorter timeout.
        self.duration = Duration::from_secs(mins.saturating_mul(60));
        self
    }

    /// Returns this policy with its timeout set to `hours` hours.
    ///
    /// If `hours * 3600` does not fit in a `u64`, the timeout saturates to
    /// `u64::MAX` seconds instead of overflowing.
    #[must_use]
    pub fn with_timeout_hours(mut self, hours: u64) -> Self {
        self.duration = Duration::from_secs(hours.saturating_mul(3600));
        self
    }
}
impl TimeoutPolicy {
    /// Runs the operation produced by `f` under this policy.
    ///
    /// `f` is called exactly once. When `cancel` is `true`, the resulting
    /// future is raced against `duration` with `tokio::time::timeout`. When
    /// `cancel` is `false`, it is awaited to completion with no deadline. Once
    /// the outcome is known, the matching hook (`on_success`, `on_failure` or
    /// `on_timeout`) is awaited before returning.
    ///
    /// # Errors
    ///
    /// Returns the operation's own error unchanged. If the deadline elapsed,
    /// returns `TimeoutError::Elapsed` converted into `E`.
    pub async fn run<F, Fut, T, E>(&self, mut f: F) -> Result<T, E>
    where
        F: FnMut() -> Fut + Send,
        Fut: Future<Output = Result<T, E>> + Send,
        T: Send,
        E: Send + From<TimeoutError>,
    {
        match self.drive(f()).await {
            Ok(outcome) => outcome,
            Err(timeout_err) => Err(E::from(timeout_err)),
        }
    }

    /// Like [`TimeoutPolicy::run`], but every error is reported as a
    /// [`TimeoutError`].
    ///
    /// # Errors
    ///
    /// - `TimeoutError::Elapsed` if the deadline elapsed.
    /// - `TimeoutError::Returning` wrapping the operation's own error otherwise.
    pub async fn run_with_timeout<F, Fut, T, E>(&self, mut f: F) -> Result<T, TimeoutError>
    where
        F: FnMut() -> Fut + Send,
        Fut: Future<Output = Result<T, E>> + Send,
        T: Send,
        E: Send + Sync + 'static + std::error::Error,
    {
        self.drive(f())
            .await?
            .map_err(|e| TimeoutError::Returning(Box::new(e)))
    }

    /// Shared core of `run` and `run_with_timeout`.
    ///
    /// Awaits `fut` according to the `cancel`/`duration` settings and awaits
    /// the hook matching the result.
    /// - The outer `Err` means the deadline elapsed; this can only happen when
    ///   `cancel` is `true`.
    /// - The inner `Result` is the operation's own outcome.
    ///
    /// Borrowing `self` avoids cloning the whole policy per call; `name` is
    /// cloned only when a timeout error is actually built.
    async fn drive<Fut, T, E>(&self, fut: Fut) -> Result<Result<T, E>, TimeoutError>
    where
        Fut: Future<Output = Result<T, E>>,
    {
        let outcome = if self.cancel {
            match tokio::time::timeout(self.duration, fut).await {
                Ok(outcome) => outcome,
                Err(_elapsed) => {
                    if let Some(cb) = &self.on_timeout {
                        cb().await;
                    }
                    return Err(TimeoutError::Elapsed {
                        duration: self.duration,
                        name: self.name.clone(),
                    });
                }
            }
        } else {
            fut.await
        };

        let hook = if outcome.is_ok() {
            &self.on_success
        } else {
            &self.on_failure
        };
        if let Some(cb) = hook {
            cb().await;
        }
        Ok(outcome)
    }
}
impl Default for TimeoutPolicy {
    /// Produces the policy built from an untouched [`Builder`]: a 30-second
    /// timeout, `cancel` enabled, no name, and no hooks.
    fn default() -> Self {
        let builder = Builder::new();
        builder.build()
    }
}
impl Default for Builder {
fn default() -> Self {
Self::new()
}
}
impl<T, E> Policy<T, E> for TimeoutPolicy
where
    E: Send + From<TimeoutError>,
{
    /// Applies this timeout policy to the operation produced by `f`.
    ///
    /// `f` is invoked lazily, the first time the returned future is polled.
    /// When `cancel` is `true`, the operation is raced against `duration`:
    /// - if the deadline wins, `on_timeout` is awaited and
    ///   `TimeoutError::Elapsed` is returned, converted into `E`;
    /// - otherwise `on_success` or `on_failure` is awaited depending on the
    ///   result, and the operation's result is returned unchanged.
    fn call<F, Fut>(&self, f: &mut F) -> impl Future<Output = Result<T, E>> + Send
    where
        F: FnMut() -> Fut + Send,
        Fut: Future<Output = Result<T, E>> + Send,
        T: Send,
        E: Send,
    {
        // The returned future owns its own copy of the policy.
        let policy = self.clone();
        async move {
            let outcome = if !policy.cancel {
                f().await
            } else {
                match tokio::time::timeout(policy.duration, f()).await {
                    Ok(outcome) => outcome,
                    Err(_elapsed) => {
                        if let Some(hook) = &policy.on_timeout {
                            hook().await;
                        }
                        let timeout_err = TimeoutError::Elapsed {
                            duration: policy.duration,
                            name: policy.name,
                        };
                        return Err(E::from(timeout_err));
                    }
                }
            };

            let selected = match &outcome {
                Ok(_) => &policy.on_success,
                Err(_) => &policy.on_failure,
            };
            if let Some(hook) = selected {
                hook().await;
            }
            outcome
        }
    }
}