use std::{ops::Deref, sync::Arc};
use concurrent_queue::{ConcurrentQueue, PushError};
use onetime::{channel, RecvError, Sender};
use thiserror::Error;
/// Shared queue of one-shot senders.
///
/// `Customer::request` pushes a sender onto it and `Vendor::send` pops
/// senders off it to deliver a resource.
#[derive(Debug)]
struct Waiters<T>(ConcurrentQueue<Sender<T>>);

impl<T> Default for Waiters<T> {
    /// Starts with a new, empty, unbounded queue.
    fn default() -> Self {
        let queue = ConcurrentQueue::unbounded();
        Waiters(queue)
    }
}

// Exposes the `ConcurrentQueue` API (`push`, `pop`, `len`, ...) directly on
// `Waiters`, so holders of `Arc<Waiters<T>>` can call it without `.0`.
impl<T> Deref for Waiters<T> {
    type Target = ConcurrentQueue<Sender<T>>;

    fn deref(&self) -> &Self::Target {
        let Waiters(queue) = self;
        queue
    }
}
#[derive(Debug)]
pub struct Customer<T> {
waiters: Arc<Waiters<T>>,
}
impl<T> Clone for Customer<T> {
fn clone(&self) -> Self {
Self { waiters: self.waiters.clone() }
}
}
impl<T> Customer<T> {
pub async fn request(&self) -> Result<T, RequestError> {
let (tx, rx) = channel();
self.waiters.push(tx)?;
rx.recv().await.map_err(Into::into)
}
}
#[derive(Debug)]
pub struct Vendor<T> {
waiters: Arc<Waiters<T>>,
}
impl<T> Clone for Vendor<T> {
fn clone(&self) -> Self {
Self { waiters: self.waiters.clone() }
}
}
impl<T> Default for Vendor<T> {
fn default() -> Self {
Self { waiters: Arc::default() }
}
}
impl<T> Vendor<T> {
    /// Creates a vendor with an empty waiter queue (same as `Vendor::default()`).
    pub fn new() -> Self {
        Self::default()
    }

    /// Returns a [`Customer`] handle that shares this vendor's waiter queue.
    pub fn customer(&self) -> Customer<T> {
        Customer {
            waiters: self.waiters.clone(),
        }
    }

    /// Hands `resource` to every waiter currently in the queue.
    ///
    /// The queue length is sampled once on entry and that many waiters are
    /// popped. All but the last receive a clone, and the last receives
    /// `resource` itself, so no clone is made for a single waiter. Waiters
    /// enqueued after the sample are not served by this call. With no waiters
    /// this is a no-op and `resource` is dropped.
    ///
    /// Delivery is best-effort: a failed `send` on an individual waiter is
    /// ignored and does not stop the remaining waiters from being served.
    pub fn send(&self, resource: T)
    where
        T: Clone,
    {
        // Sample once so the empty check and the loop bound agree. An empty
        // queue must return early: `0 - 1` on `usize` panics in debug builds
        // and wraps to `usize::MAX` (an effectively endless loop) in release.
        let count = self.waiters_count();
        if count == 0 {
            return;
        }
        // `1..count` serves `count - 1` waiters with clones ...
        for _ in 1..count {
            if let Ok(waiter) = self.waiters.pop() {
                let _ = waiter.send(resource.clone());
            }
        }
        // ... and the final waiter gets the original value, moved.
        if let Ok(waiter) = self.waiters.pop() {
            let _ = waiter.send(resource);
        }
    }

    /// Returns the number of one-shot senders currently queued.
    ///
    /// The queue is shared by every clone of this vendor and its customers,
    /// so the value is a snapshot that may change immediately.
    pub fn waiters_count(&self) -> usize {
        self.waiters.len()
    }

    /// Returns `true` if at least one sender is currently queued.
    pub fn has_waiters(&self) -> bool {
        self.waiters_count() > 0
    }
}
/// Error returned by `Customer::request`.
///
/// Both variants share the same `Display` message; match on the variant to
/// tell them apart.
#[derive(Debug, Error)]
#[error("failed to request resource")]
pub enum RequestError {
    /// The customer's one-shot sender could not be pushed onto the waiter queue.
    Push,
    /// Awaiting the one-shot receiver yielded a `RecvError` instead of a value.
    Recv,
}
impl<T> From<PushError<T>> for RequestError {
fn from(_: PushError<T>) -> Self {
Self::Push
}
}
impl From<RecvError> for RequestError {
fn from(_: RecvError) -> Self {
Self::Recv
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Two customers waiting on the same vendor both receive the value it sends.
    #[test]
    fn it_works() {
        smol::block_on(async move {
            let vendor = Vendor::new();
            let customer1 = vendor.customer();
            let customer2 = vendor.customer();
            let t1 = smol::spawn(async move {
                assert!(matches!(customer1.request().await, Ok("ok")));
            });
            let t2 = smol::spawn(async move {
                assert!(matches!(customer2.request().await, Ok("ok")));
            });
            let t3 = smol::spawn(async move {
                // `send` only serves waiters that are already queued. Without
                // this wait the spawned tasks race, and the vendor can send
                // before the customers enqueue, leaving `t1`/`t2` blocked forever.
                while vendor.waiters_count() < 2 {
                    smol::future::yield_now().await;
                }
                vendor.send("ok");
            });
            t1.await;
            t2.await;
            t3.await;
        });
    }
}