use std::{
future::{Future, IntoFuture},
pin::Pin,
sync::Arc,
time::Instant,
};
use futures::{
future::BoxFuture,
task::{Context, Poll},
};
use tokio::sync::mpsc;
use crate::{
adaptors::throttle::{channel, ChatIdHash, FreezeUntil, RequestLock},
errors::AsResponseParameters,
requests::{HasPayload, Output, Request},
};
/// A request wrapper whose sending goes through the throttling worker.
///
/// Sending (by value, by reference, or via `.await`) first registers with the
/// worker queue and only then sends the wrapped request.
#[derive(Clone)]
#[must_use = "Requests are lazy and do nothing unless sent"]
pub struct ThrottlingRequest<R>
where
    R: HasPayload,
{
    /// The wrapped request. It is kept behind an `Arc` so that clones of this
    /// wrapper and `send_ref` calls share it rather than copying it.
    pub(super) request: Arc<R>,
    /// Extracts the [`ChatIdHash`] from the payload. The hash is handed to the
    /// worker together with every lock request.
    pub(super) chat_id: fn(&R::Payload) -> ChatIdHash,
    /// Queue to the throttling worker. Each send attempt pushes a
    /// `(chat, lock)` pair into it.
    pub(super) worker: mpsc::Sender<(ChatIdHash, RequestLock)>,
}
/// Future returned by [`ThrottlingRequest`]'s `send` and `send_ref`.
///
/// Wraps a boxed, `'static` future that resolves to the inner request's
/// result (`Ok(output)` or the inner request's error).
#[pin_project::pin_project]
pub struct ThrottlingSend<R: Request>(#[pin] BoxFuture<'static, Result<Output<R>, R::Err>>);
/// How the throttled `send` future holds on to the inner request.
enum ShareableRequest<R> {
    /// The request is shared with other handles and is only ever sent by
    /// reference (`send_ref`).
    Shared(Arc<R>),
    /// The request is exclusively owned. The `Option` lets the by-value send
    /// path `take()` the request out; it is `None` only after that happened.
    Owned(Option<R>),
}
impl<R> HasPayload for ThrottlingRequest<R>
where
    R: HasPayload + Clone,
{
    type Payload = R::Payload;

    /// Borrows the payload of the wrapped request.
    fn payload_ref(&self) -> &Self::Payload {
        let inner: &R = &self.request;
        inner.payload_ref()
    }

    /// Mutably borrows the payload of the wrapped request.
    ///
    /// If the inner request is still shared with other `Arc` handles (for
    /// example clones of this wrapper), `Arc::make_mut` first clones it into
    /// a fresh allocation. Edits made here are therefore never visible
    /// through those other handles.
    fn payload_mut(&mut self) -> &mut Self::Payload {
        let inner: &mut R = Arc::make_mut(&mut self.request);
        inner.payload_mut()
    }
}
impl<R> Request for ThrottlingRequest<R>
where
R: Request + Clone + Send + Sync + 'static, R::Err: AsResponseParameters + Send,
Output<R>: Send,
{
type Err = R::Err;
type Send = ThrottlingSend<R>;
type SendRef = ThrottlingSend<R>;
fn send(self) -> Self::Send {
let chat = (self.chat_id)(self.payload_ref());
let request = match Arc::try_unwrap(self.request) {
Ok(owned) => ShareableRequest::Owned(Some(owned)),
Err(shared) => ShareableRequest::Shared(shared),
};
let fut = send(request, chat, self.worker);
ThrottlingSend(Box::pin(fut))
}
fn send_ref(&self) -> Self::SendRef {
let chat = (self.chat_id)(self.payload_ref());
let request = ShareableRequest::Shared(Arc::clone(&self.request));
let fut = send(request, chat, self.worker.clone());
ThrottlingSend(Box::pin(fut))
}
}
impl<R> IntoFuture for ThrottlingRequest<R>
where
    R: Request + Clone + Send + Sync + 'static,
    R::Err: AsResponseParameters + Send,
    Output<R>: Send,
{
    type Output = Result<Output<Self>, <Self as Request>::Err>;
    type IntoFuture = <Self as Request>::Send;

    /// Makes `request.await` behave exactly like `request.send().await`.
    fn into_future(self) -> Self::IntoFuture {
        <Self as Request>::send(self)
    }
}
impl<R> Future for ThrottlingSend<R>
where
    R: Request,
    R::Err: AsResponseParameters,
{
    type Output = Result<Output<R>, R::Err>;

    /// Polls the boxed inner future and forwards its result unchanged.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let this = self.project();
        this.0.poll(cx)
    }
}
/// Drives a single throttled request to completion.
///
/// Each loop iteration does three things:
/// 1. Registers a `(chat, lock)` pair with the worker.
/// 2. Waits on the paired `wait` handle.
/// 3. Sends `request`.
///
/// If sending fails with an error that carries a `retry_after`, a
/// [`FreezeUntil`] is reported through the `freeze` handle obtained from
/// `wait`. When `retry` is set, the function then sleeps until the deadline
/// and tries again.
///
/// If the worker's queue is closed, the request is sent immediately, without
/// throttling.
///
/// # Errors
///
/// Returns the inner request's error when sending fails and no further retry
/// happens.
async fn send<R>(
    mut request: ShareableRequest<R>,
    chat: ChatIdHash,
    worker: mpsc::Sender<(ChatIdHash, RequestLock)>,
) -> Result<Output<R>, R::Err>
where
    R: Request + Send + Sync + 'static,
    R::Err: AsResponseParameters + Send,
    Output<R>: Send,
{
    // `ShareableRequest::Owned` wraps the request in an `Option` so the
    // by-value paths can `take()` it. Every `unwrap()` below is sound because
    // each path that takes the request leaves the loop right afterwards:
    // either the early `return`, or the `break` that is always reached when
    // `retry` is false.
    loop {
        // Fresh lock/wait pair for this attempt; `lock` goes to the worker.
        let (lock, wait) = channel();

        // The worker dropping its receiver is not expected. Rather than fail,
        // fall back to sending the request unthrottled.
        if worker.send((chat, lock)).await.is_err() {
            log::error!("Worker dropped the queue before sending all requests");

            let res = match &mut request {
                ShareableRequest::Shared(shared) => shared.send_ref().await,
                ShareableRequest::Owned(owned) => owned.take().unwrap().await,
            };

            return res;
        };

        // NOTE(review): `wait` presumably resolves once the worker lets this
        // request go out (defined elsewhere; confirm). It yields:
        // - `retry`: whether a rate-limited request is sent again;
        // - `freeze`: a handle for reporting a freeze back.
        let (retry, freeze) = wait.await;

        let res = match (retry, &mut request) {
            // With retries on, always send by reference (even when owned) so
            // the request is still available for another attempt.
            (true, request) => {
                let request = match request {
                    ShareableRequest::Shared(shared) => &**shared,
                    ShareableRequest::Owned(owned) => owned.as_ref().unwrap(),
                };

                request.send_ref().await
            }
            // No retry will happen, so an owned request can be sent by value.
            (false, ShareableRequest::Shared(shared)) => shared.send_ref().await,
            (false, ShareableRequest::Owned(owned)) => owned.take().unwrap().await,
        };

        let retry_after = res.as_ref().err().and_then(<_>::retry_after);
        if let Some(retry_after) = retry_after {
            let after = retry_after.duration();
            let until = Instant::now() + after;

            // Report the freeze. A send failure is deliberately ignored: when
            // retrying, a dead worker is detected by the `worker.send` check
            // at the top of the next iteration.
            let _ = freeze.send(FreezeUntil { until, after, chat }).await;

            if retry {
                log::warn!("Freezing, before retrying: {retry_after:?}");
                tokio::time::sleep_until(until.into()).await;
            }
        }

        // Only an error with `retry_after` and retries enabled loops again;
        // every other outcome is returned as-is.
        match res {
            Err(_) if retry && retry_after.is_some() => continue,
            res => break res,
        };
    }
}