use std::marker::PhantomData;
use futures::{future::BoxFuture, FutureExt};
use tower::{Layer, Service};
use crate::{context::HasJobContext, error::BoxDynError, request::JobRequest, worker::WorkerId};
/// Error returned by an [`Ack`] implementation when acknowledging a job fails.
#[derive(Debug, thiserror::Error)]
pub enum AckError {
    /// The acknowledgement could not be performed. The wrapped error is the
    /// underlying cause; it is included in the Display message and is also
    /// exposed through `std::error::Error::source`.
    #[error("Acknowledgement failed {0}")]
    NoAck(#[source] BoxDynError),
}
/// Acknowledges that a job of type `J` has been handled by a worker.
///
/// `AckService` calls [`Ack::ack`] after its wrapped service has finished
/// with a job, passing acknowledgement data taken from the request's context.
#[async_trait::async_trait]
pub trait Ack<J> {
    /// Per-job data needed to perform the acknowledgement.
    ///
    /// `AckService` looks a value of this type up in each request's context;
    /// if none is present, no acknowledgement is attempted.
    type Acknowledger;
    /// Acknowledges a job on behalf of `worker_id`, using `data` taken from the
    /// job's context.
    ///
    /// # Errors
    /// Returns [`AckError`] if the acknowledgement could not be performed.
    async fn ack(&self, worker_id: &WorkerId, data: &Self::Acknowledger) -> Result<(), AckError>;
}
#[derive(Debug)]
pub struct AckLayer<A: Ack<J>, J> {
ack: A,
job_type: PhantomData<J>,
worker_id: WorkerId,
}
impl<A: Ack<J>, J> AckLayer<A, J> {
pub fn new(ack: A, worker_id: WorkerId) -> Self {
Self {
ack,
job_type: PhantomData,
worker_id,
}
}
}
impl<A, J, S> Layer<S> for AckLayer<A, J>
where
    S: Service<JobRequest<J>> + Send + 'static,
    S::Error: std::error::Error + Send + Sync + 'static,
    S::Future: Send + 'static,
    A: Ack<J> + Clone + Send + Sync + 'static,
{
    type Service = AckService<S, A, J>;

    /// Wraps `inner` in an [`AckService`] that receives its own clone of this
    /// layer's acknowledger and worker id.
    fn layer(&self, inner: S) -> Self::Service {
        let ack = A::clone(&self.ack);
        let worker_id = self.worker_id.clone();
        AckService {
            service: inner,
            ack,
            worker_id,
            job_type: PhantomData,
        }
    }
}
/// Service produced by [`AckLayer`]: forwards each job to the wrapped service
/// and then acknowledges it through `A`.
#[derive(Debug)]
pub struct AckService<SV, A, J> {
    /// The wrapped service that processes the job.
    service: SV,
    /// Acknowledger, cloned into every request future.
    ack: A,
    /// Ties the service to the job type `J` without storing a value of it.
    job_type: PhantomData<J>,
    /// Worker identity passed to [`Ack::ack`].
    worker_id: WorkerId,
}
/// Runs the wrapped service and acknowledges the job once it has finished.
///
/// The acknowledgement data (`A::Acknowledger`) is read from the request's
/// context before the request is handed to the inner service. After the inner
/// future resolves, [`Ack::ack`] is called regardless of whether the inner
/// service returned `Ok` or `Err`. The inner result is then returned unchanged.
/// Acknowledgement failures and missing acknowledgement data are only logged
/// with `tracing::warn!` and never alter the returned result.
///
/// The inner service only has to be `Send`, matching the bounds of
/// `Layer for AckLayer`. The boxed future owns the inner *future*, not the
/// service, so `SV: Sync` is not needed. Requiring it would leave services
/// built by the layer from non-`Sync` inner services unusable.
impl<SV, A, J> Service<JobRequest<J>> for AckService<SV, A, J>
where
    SV: Service<JobRequest<J>> + Send + 'static,
    SV::Error: std::error::Error + Send + Sync + 'static,
    SV::Future: Send + 'static,
    // Held across the acknowledgement `.await` inside a `Send` future.
    SV::Response: Send,
    // `Sync` because `ack` is borrowed across an `.await` in a `Send` future.
    A: Ack<J> + Clone + Send + Sync + 'static,
    // `Clone` to copy the data out of the context. `Send + Sync` because it is
    // moved into, and borrowed across an `.await` in, the boxed `Send` future.
    A::Acknowledger: Clone + Send + Sync,
    J: 'static,
{
    type Response = SV::Response;
    type Error = SV::Error;
    type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;

    fn poll_ready(
        &mut self,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Result<(), Self::Error>> {
        self.service.poll_ready(cx)
    }

    fn call(&mut self, request: JobRequest<J>) -> Self::Future {
        let ack = self.ack.clone();
        let worker_id = self.worker_id.clone();
        // Copy the acknowledgement data out now: `request` is moved into the
        // inner service below and is not accessible afterwards.
        let data = request
            .context()
            .data_opt::<A::Acknowledger>()
            .cloned();
        let fut = self.service.call(request);
        async move {
            let res = fut.await;
            match data {
                Some(data) => {
                    // Best-effort: a failed acknowledgement is logged, and the
                    // inner service's result is still returned as-is.
                    if let Err(e) = ack.ack(&worker_id, &data).await {
                        tracing::warn!("Acknowledgement Failed: {}", e);
                    }
                }
                None => {
                    tracing::warn!(
                        "Acknowledgement could not be called due to missing ack data in context : {}",
                        std::any::type_name::<A::Acknowledger>()
                    );
                }
            }
            res
        }
        .boxed()
    }
}