use futures::future::BoxFuture;
use futures::{FutureExt, ready};
use std::mem;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll, Waker};
use tokio::sync::{Mutex, OwnedMutexGuard};
/// A handle to a buffer shared behind an async mutex, together with this handle's own
/// progress toward locking it.
#[derive(Default)]
pub(crate) struct InnerWrapper<T>
where
    T: 'static,
{
    /// The shared, mutex-protected content. Further handles to the same buffer can be
    /// created from it (see `clone_buffer` / `cloned_buffer`).
    pub(crate) buffer: Arc<Mutex<ContentWrapper<T>>>,
    /// Whether this handle is idle, waiting on a lock future, or holding the guard.
    lock_state: LockState<T>,
}
impl<T: Send> InnerWrapper<T> {
pub(crate) fn clone_buffer(&self) -> Arc<Mutex<ContentWrapper<T>>> {
Arc::clone(&self.buffer)
}
pub(crate) fn cloned_buffer(&self) -> Self {
assert!(matches!(self.lock_state, LockState::Idle));
InnerWrapper {
buffer: self.clone_buffer(),
lock_state: LockState::Idle,
}
}
pub(crate) fn poll_guard_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> {
match &mut self.lock_state {
LockState::Idle => {
let Ok(guard) = self.buffer.clone().try_lock_owned() else {
self.lock_state =
LockState::TryingToLock(self.buffer.clone().lock_owned().boxed());
return Poll::Pending;
};
self.lock_state = LockState::Locked(guard);
cx.waker().wake_by_ref();
Poll::Ready(())
}
LockState::TryingToLock(lock_fut) => {
let guard = ready!(lock_fut.as_mut().poll(cx));
self.lock_state = LockState::Locked(guard);
cx.waker().wake_by_ref();
Poll::Pending
}
LockState::Locked(_) => Poll::Ready(()),
}
}
pub(crate) fn guard(&mut self) -> Option<&mut OwnedMutexGuard<ContentWrapper<T>>> {
match &mut self.lock_state {
LockState::Locked(guard) => Some(guard),
_ => None,
}
}
pub(crate) fn transition_to_idle(&mut self) {
self.lock_state = LockState::Idle
}
}
/// Progress of a single `InnerWrapper` toward holding the shared buffer's lock.
#[derive(Default)]
pub(crate) enum LockState<T> {
    /// Neither holding the lock nor waiting for it.
    #[default]
    Idle,
    /// Waiting on a boxed future that resolves to an owned guard of the buffer.
    TryingToLock(
        BoxFuture<'static, OwnedMutexGuard<ContentWrapper<T>>>,
    ),
    /// Holding an owned guard of the buffer.
    Locked(
        OwnedMutexGuard<ContentWrapper<T>>,
    ),
}
/// Content stored behind the shared mutex, together with an optional stored waker.
#[derive(Default)]
pub struct ContentWrapper<T> {
    /// The wrapped value.
    pub(crate) content: T,
    /// An optional stored waker. Presumably set by a task waiting for content to change;
    /// NOTE(review): confirm against the callers that set/take it.
    pub(crate) waker: Option<Waker>,
}

impl<T> ContentWrapper<T> {
    /// Consumes the wrapper and returns its content, discarding any stored waker.
    pub fn into_content(self) -> T {
        self.content
    }

    /// Borrows the content.
    pub fn content(&self) -> &T {
        &self.content
    }

    /// Moves the content out, leaving `T::default()` in its place.
    pub(crate) fn take_content(&mut self) -> T
    where
        T: Default,
    {
        mem::take(&mut self.content)
    }
}

impl ContentWrapper<Vec<u8>> {
    /// Removes and returns up to `count` bytes from the front of the content.
    ///
    /// If at most `count` bytes are stored, all of them are returned and the content is left
    /// empty. Otherwise the first `count` bytes are returned and the rest stay in the
    /// wrapper, in their original order.
    pub fn take_at_most(&mut self, count: usize) -> Vec<u8> {
        match self.content.len() {
            // Nothing to hand out: return a fresh empty Vec and leave the existing buffer
            // (and its allocation) in place.
            0 => Vec::new(),
            // Everything fits: move the whole buffer out instead of copying it.
            len if len <= count => self.take_content(),
            _ => {
                // `split_off` keeps `[..count]` in `self.content` and returns the tail.
                // Swap them so the tail stays stored and the front is returned.
                let rest = self.content.split_off(count);
                mem::replace(&mut self.content, rest)
            }
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a byte wrapper holding a copy of `content` and no stored waker.
    fn bytes(content: &[u8]) -> ContentWrapper<Vec<u8>> {
        ContentWrapper {
            content: content.to_vec(),
            ..Default::default()
        }
    }

    #[test]
    fn take_at_most() {
        let nothing = Vec::<u8>::new();

        // An empty buffer yields nothing, whatever count is requested.
        let mut empty = bytes(&[]);
        for count in [0, 1, 42] {
            assert_eq!(empty.take_at_most(count), nothing);
        }

        // Successive takes consume the front of the buffer in order.
        let mut drained = bytes(&[1, 2, 3, 4, 5]);
        assert_eq!(drained.take_at_most(0), nothing);
        assert_eq!(drained.take_at_most(1), vec![1]);
        assert_eq!(drained.take_at_most(3), vec![2, 3, 4]);
        assert_eq!(drained.take_at_most(42), vec![5]);

        // Requesting more than is stored returns everything at once.
        let mut whole = bytes(&[1, 2, 3, 4, 5]);
        assert_eq!(whole.take_at_most(100), vec![1, 2, 3, 4, 5]);
    }
}