use std::pin::Pin;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
use std::time::{Duration, Instant};
use future::FutureExt;
use futures::channel::{mpsc, oneshot};
use futures::future::join_all;
use futures::prelude::*;
use sharded_slab::Slab;
use tokio::sync::Mutex;
use tokio::time::sleep;
use thiserror::Error;
/// Errors produced by the request/response channel.
#[derive(Debug, Error)]
pub enum Error {
    /// A request could not be queued on the request channel.
    #[error("Send failed: {0}")]
    SendFailed(String),
    /// The response sender for a request was dropped without replying.
    #[error("Receive failed: {0}")]
    ReceiveFailed(String),
    /// No response arrived before the deadline.
    #[error("Response timeout")]
    Timeout,
    /// The responder could not send a response back toward the client.
    #[error("Send response failed: {0}")]
    SendResponseFailed(String),
    /// An internal invariant failed (e.g. the pending-request slab rejected an insert).
    #[error("Internal error: {0}")]
    InternalError(String),
}
/// Client half of the request/response channel created by `channel`.
///
/// Each request is tagged with the id of a slot in `pending`. The background task spawned by
/// `channel` delivers the responder's `(id, response)` pair to the `oneshot` stored in that slot.
pub struct Client<Req, Resp>
where
    Req: Send + Sync + 'static,
    Resp: Send + Sync + 'static,
{
    /// Shared request sender. It sits behind a mutex because sending needs `&mut` access.
    tx: Arc<Mutex<mpsc::Sender<(usize, Req)>>>,
    /// Waiting callers, keyed by request id.
    pending: Arc<Slab<oneshot::Sender<Resp>>>,
    /// Base timeout used by `request` and `request_batch`.
    timeout: Duration,
}

// Implemented by hand: `#[derive(Clone)]` would require `Req: Clone` and `Resp: Clone`,
// but every field here (`Arc`, `Duration`) is cloneable regardless of the payload types.
impl<Req, Resp> Clone for Client<Req, Resp>
where
    Req: Send + Sync + 'static,
    Resp: Send + Sync + 'static,
{
    fn clone(&self) -> Self {
        Self {
            tx: Arc::clone(&self.tx),
            pending: Arc::clone(&self.pending),
            timeout: self.timeout,
        }
    }
}
/// Responder side returned by `channel`, not yet attached to a handler.
///
/// Call `ResponderBuilder::build` with a request handler to obtain a `Responder`.
pub struct ResponderBuilder<Req, Resp>
where
    Req: Send + Sync + 'static,
    Resp: Send + Sync + 'static,
{
    /// Receives `(id, request)` pairs sent by the client.
    req_rx: mpsc::Receiver<(usize, Req)>,
    /// Sends `(id, response)` pairs back toward the client.
    resp_tx: mpsc::Sender<(usize, Resp)>,
}
impl<Req, Resp> ResponderBuilder<Req, Resp>
where
Req: Send + Sync + 'static,
Resp: Send + Sync + 'static,
{
pub fn build<
Fut: Future<Output = Resp> + Send + 'static,
F: (FnMut(Req) -> Fut) + Send + 'static,
>(
self,
mut handler: F,
) -> Responder<Req, Resp> {
let hdl = move |req: Req| {
let pinned_fut: InnerPinBoxFuture<Resp> = Box::pin(handler(req));
pinned_fut
};
let handler: PinBoxHandler<Req, Resp> = Box::pin(hdl);
Responder {
req_rx: self.req_rx,
resp_tx: self.resp_tx,
handler,
}
}
}
/// Boxed, pinned, type-erased future returned by a request handler.
type InnerPinBoxFuture<Resp> = Pin<Box<dyn Send + Future<Output = Resp>>>;
/// Boxed, pinned, type-erased request handler; each call yields an `InnerPinBoxFuture`.
type PinBoxHandler<Req, Resp> = Pin<Box<dyn Send + FnMut(Req) -> InnerPinBoxFuture<Resp>>>;
/// Request processor produced by `ResponderBuilder::build`.
///
/// It owns the receiving end of the request channel, the sending end of the response
/// channel, and the type-erased handler that is invoked for each request.
pub struct Responder<Req, Resp>
where
    Req: Send + Sync + 'static,
    Resp: Send + Sync + 'static,
{
    /// Incoming `(id, request)` pairs from the client.
    req_rx: mpsc::Receiver<(usize, Req)>,
    /// Outgoing `(id, response)` pairs to the routing task spawned by `channel`.
    resp_tx: mpsc::Sender<(usize, Resp)>,
    /// Handler called once per request.
    handler: PinBoxHandler<Req, Resp>,
}
/// Selects how a `Responder` bounds the number of requests it handles at once.
#[derive(Clone, Copy, Debug)]
pub enum ConcurrencyStrategy {
    /// A constant in-flight limit, forwarded to `Responder::process_requests_fixed`.
    Fixed(usize),
    /// `(initial, max)` in-flight limits, forwarded to `Responder::process_requests_dynamic`.
    Dynamic(usize, usize),
}
impl<Req, Resp> Client<Req, Resp>
where
Req: Send + Sync + 'static,
Resp: Send + Sync + 'static,
{
pub async fn request_timeout(&self, req: Req, timeout: Duration) -> Result<Resp, Error> {
let (tx, rx) = oneshot::channel();
let id = match self.pending.insert(tx) {
Some(id) => id,
None => {
return Err(Error::InternalError(
"Failed to insert into pending slab".into(),
));
}
};
self.tx
.lock()
.await
.send((id, req))
.map_err(|e| Error::SendFailed(e.to_string()))
.await?;
let pending = self.pending.clone();
tokio::select! {
resp = rx => {
resp.map_err(|e| Error::ReceiveFailed(e.to_string()))
},
_ = sleep(timeout) => {
pending.remove(id);
Err(Error::Timeout)
}
}
}
pub async fn request(&self, req: Req) -> Result<Resp, Error> {
self.request_timeout(req, self.timeout).await
}
pub async fn request_batch_timeout<ReqSeq>(
&self,
reqs: ReqSeq,
timeout: Duration,
concurrency: usize,
) -> Result<Vec<Resp>, Error>
where
ReqSeq: IntoIterator<Item = Req> + Send + 'static,
{
let req_seq: Vec<_> = reqs.into_iter().collect();
let count = req_seq.len();
let mut ids: Vec<usize> = Vec::with_capacity(count);
let mut rxs: Vec<oneshot::Receiver<Resp>> = Vec::with_capacity(count);
for _ in 0..count {
let (tx, rx) = oneshot::channel();
let id = self.pending.insert(tx).unwrap();
ids.push(id);
rxs.push(rx);
}
stream::iter(ids.clone().into_iter().zip(req_seq.into_iter()))
.for_each_concurrent(concurrency, |v| async {
self.tx.lock().await.send(v).await.unwrap();
})
.await;
tokio::select! {
results = join_all(rxs) => {
let results = results
.into_iter()
.collect::<Result<Vec<_>, _>>()
.map_err(|e| Error::ReceiveFailed(e.to_string()))?;
Ok(results)
},
_ = sleep(timeout) => {
for &id in ids.iter() {
self.pending.remove(id);
}
Err(Error::Timeout)
}
}
}
pub async fn request_batch<ReqSeq>(
&self,
reqs: ReqSeq,
concurrency: usize,
) -> Result<Vec<Resp>, Error>
where
ReqSeq: IntoIterator<Item = Req> + Send + 'static,
{
self.request_batch_timeout(reqs, self.timeout * (concurrency as u32), concurrency)
.await
}
}
impl<Req, Resp> Responder<Req, Resp>
where
Req: Send + Sync + 'static,
Resp: Send + Sync + 'static,
{
pub async fn process_requests_with_strategy(
self,
strategy: ConcurrencyStrategy,
) -> Result<(), Error> {
match strategy {
ConcurrencyStrategy::Fixed(concurrency) => {
self.process_requests_fixed(concurrency).await
}
ConcurrencyStrategy::Dynamic(initial, max) => {
self.process_requests_dynamic(initial, max).await
}
}
}
pub async fn process_requests(self) -> Result<(), Error> {
self.process_requests_fixed(16).await
}
pub async fn process_requests_fixed(self, concurrency: usize) -> Result<(), Error> {
let Self {
req_rx,
resp_tx,
mut handler,
} = self;
req_rx
.map(move |(id, req)| {
let hdl = unsafe { handler.as_mut().get_unchecked_mut() };
let fut = hdl(req);
fut.map(move |resp| Ok((id, resp)))
})
.buffer_unordered(concurrency)
.forward(resp_tx)
.map_err(|e| Error::SendResponseFailed(e.to_string()))
.await?;
Ok(())
}
pub async fn process_requests_dynamic(
self,
initial_concurrency: usize,
max_concurrency: usize,
) -> Result<(), Error> {
let Self {
req_rx,
resp_tx,
mut handler,
} = self;
let concurrency = Arc::new(AtomicUsize::new(initial_concurrency));
let concurrency_cloned = concurrency.clone();
req_rx
.map(move |(id, req)| {
let concurrency = Arc::clone(&concurrency);
let start = Instant::now();
let hdl = unsafe { handler.as_mut().get_unchecked_mut() };
let fut = hdl(req);
fut.map(move |resp| {
let dur = start.elapsed();
let currency = concurrency.load(Ordering::Relaxed);
if dur < Duration::from_millis(10) {
if currency < max_concurrency {
let increment = std::cmp::max(1, currency / 4);
let new_value =
std::cmp::min(max_concurrency, currency.saturating_add(increment));
concurrency.store(new_value, Ordering::Relaxed);
}
} else if dur < Duration::from_millis(100) {
if currency > 1 {
concurrency.fetch_sub(1, Ordering::Relaxed);
}
} else {
let decrement = std::cmp::max(2, currency / 4);
let new_value = std::cmp::min(1, currency.saturating_sub(decrement));
concurrency.store(new_value, Ordering::Relaxed);
}
Ok((id, resp))
})
})
.buffer_unordered(concurrency_cloned.load(Ordering::Relaxed))
.forward(resp_tx)
.map_err(|e| Error::SendResponseFailed(e.to_string()))
.await?;
Ok(())
}
}
pub fn channel<Req, Resp>(
buffer: usize,
timeout: Duration,
) -> (Client<Req, Resp>, ResponderBuilder<Req, Resp>)
where
Req: Send + Sync + 'static,
Resp: Send + Sync + 'static,
{
let (req_tx, req_rx) = mpsc::channel(buffer);
let (resp_tx, resp_rx) = mpsc::channel(buffer);
let client = Client {
tx: Arc::new(Mutex::new(req_tx)),
pending: Arc::new(Slab::new()),
timeout,
};
let responder_builder = ResponderBuilder { req_rx, resp_tx };
let pending = client.pending.clone();
tokio::spawn(async move {
resp_rx
.for_each_concurrent(64, |(id, res)| {
let pending = pending.clone();
async move {
if let Some(tx) = pending.take(id) {
let _ = tx.send(res);
}
}
})
.await;
});
(client, responder_builder)
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;
    use tokio::runtime::Runtime;

    /// Drives `body` to completion on a freshly created Tokio runtime.
    fn run<F>(body: F)
    where
        F: std::future::Future<Output = ()>,
    {
        let runtime = Runtime::new().unwrap();
        runtime.block_on(body);
    }

    /// Builds a channel (buffer 64, default timeout 2 s) whose responder echoes each request.
    /// It must be called inside a runtime because `channel` spawns a task.
    fn echo_pair() -> (Client<String, String>, Responder<String, String>) {
        let (client, builder) = channel(64, Duration::from_secs(2));
        let responder = builder.build(|req: String| async move { req });
        (client, responder)
    }

    #[test]
    fn test_single_request() {
        run(async {
            let (client, responder) = echo_pair();
            tokio::spawn(async move {
                responder.process_requests().await.unwrap();
            });
            let reply = client.request("Hello, world!".to_string()).await.unwrap();
            assert_eq!(reply, "Hello, world!");
        });
    }

    #[test]
    fn test_batch_requests() {
        run(async {
            let (client, responder) = echo_pair();
            tokio::spawn(async move {
                responder.process_requests().await.unwrap();
            });
            let expected = vec!["Request 1".to_string(), "Request 2".to_string()];
            let replies = client.request_batch(expected.clone(), 4).await.unwrap();
            assert_eq!(replies, expected);
        });
    }

    #[test]
    fn test_fixed_concurrency() {
        run(async {
            let (client, responder) = echo_pair();
            tokio::spawn(async move {
                responder
                    .process_requests_with_strategy(ConcurrencyStrategy::Fixed(16))
                    .await
                    .unwrap();
            });
            let reply = client.request("Hello, world!".to_string()).await.unwrap();
            assert_eq!(reply, "Hello, world!");
        });
    }

    #[test]
    fn test_dynamic_concurrency() {
        run(async {
            let (client, responder) = echo_pair();
            tokio::spawn(async move {
                responder
                    .process_requests_with_strategy(ConcurrencyStrategy::Dynamic(4, 16))
                    .await
                    .unwrap();
            });
            let reply = client.request("Hello, world!".to_string()).await.unwrap();
            assert_eq!(reply, "Hello, world!");
        });
    }

    #[test]
    fn test_single_request_timeout() {
        run(async {
            let (client, responder) = echo_pair();
            tokio::spawn(async move {
                responder.process_requests().await.unwrap();
            });
            let reply = client
                .request_timeout("Hello, world!".to_string(), Duration::from_secs(1))
                .await
                .unwrap();
            assert_eq!(reply, "Hello, world!");
        });
    }

    #[test]
    fn test_batch_requests_timeout() {
        run(async {
            let (client, responder) = echo_pair();
            tokio::spawn(async move {
                responder.process_requests().await.unwrap();
            });
            let expected = vec!["Request 1".to_string(), "Request 2".to_string()];
            let replies = client
                .request_batch_timeout(expected.clone(), Duration::from_secs(1), 4)
                .await
                .unwrap();
            assert_eq!(replies, expected);
        });
    }

    #[test]
    fn test_request_failure() {
        run(async {
            let (client, builder) = channel(64, Duration::from_secs(2));
            let responder = builder.build(|req: String| async move {
                if req != "fail" {
                    Ok(req)
                } else {
                    Err(Error::InternalError("Request failed".into()))
                }
            });
            tokio::spawn(async move {
                responder.process_requests().await.unwrap();
            });
            let outcome = client.request("fail".to_string()).await.unwrap();
            assert!(matches!(outcome, Err(Error::InternalError(_))));
        });
    }

    #[test]
    fn test_request_timeout() {
        run(async {
            let (client, builder) = channel(64, Duration::from_secs(2));
            let responder = builder.build(|req: String| async move {
                tokio::time::sleep(Duration::from_secs(3)).await;
                req
            });
            tokio::spawn(async move {
                responder.process_requests().await.unwrap();
            });
            let outcome = client
                .request_timeout("Hello, world!".to_string(), Duration::from_secs(1))
                .await;
            assert!(matches!(outcome, Err(Error::Timeout)));
        });
    }
}