use http::Uri;
use hyper::client::connect::Connection;
use std::{
cmp::min,
collections::VecDeque,
future::Future,
io::Result as IoResult,
pin::Pin,
sync::{Arc, Mutex},
task::{Context, Poll, Waker},
};
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
use tower_service::Service;
/// Creates a connected pair of in-memory stream ends.
///
/// Whatever is written to one end becomes readable on the other. Each direction has its own
/// shared buffer, so the two directions never interfere with each other.
pub(crate) fn chan() -> (SimStream, SimStream) {
    // Bytes the left end writes and the right end reads.
    let left_to_right = Arc::new(Mutex::new(BufferState::new()));
    // Bytes the right end writes and the left end reads.
    let right_to_left = Arc::new(Mutex::new(BufferState::new()));

    let left = SimStream {
        write: WriteHalf {
            buffer: Arc::clone(&left_to_right),
        },
        read: ReadHalf {
            buffer: Arc::clone(&right_to_left),
        },
    };
    // The right end takes ownership of the remaining handles, mirrored.
    let right = SimStream {
        write: WriteHalf {
            buffer: right_to_left,
        },
        read: ReadHalf {
            buffer: left_to_right,
        },
    };

    (left, right)
}
/// A connector that "connects" to any URI by handing back a clone of one pre-built
/// [`SimStream`].
///
/// Every clone shares the buffers of `inner`, so every returned connection talks over the same
/// simulated channel.
#[derive(Clone)]
pub struct Connector {
    pub inner: SimStream,
}

impl Service<Uri> for Connector {
    type Response = SimStream;
    type Error = std::io::Error;
    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;

    /// Always ready; producing a connection never has to wait.
    fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        Poll::Ready(Ok(()))
    }

    /// Ignores the target URI and resolves immediately with a clone of `inner`.
    fn call(&mut self, _uri: Uri) -> Self::Future {
        let connection: Result<Self::Response, Self::Error> = Ok(self.inner.clone());
        Box::pin(std::future::ready(connection))
    }
}
/// One end of an in-memory, bidirectional byte stream.
///
/// Reads drain the buffer the peer writes into. Writes append to the buffer the peer reads from.
/// Cloning yields another handle onto the same two buffers.
#[derive(Debug, Clone)]
pub struct SimStream {
    read: ReadHalf,
    write: WriteHalf,
}

impl Connection for SimStream {
    /// Returns the default value produced by `Connected::new()`.
    fn connected(&self) -> hyper::client::connect::Connected {
        hyper::client::connect::Connected::new()
    }
}

// Both halves only hold `Arc`s and are therefore `Unpin`, so `get_mut` is available and each
// call is forwarded to the matching half.
impl AsyncWrite for SimStream {
    fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<IoResult<usize>> {
        let this = self.get_mut();
        Pin::new(&mut this.write).poll_write(cx, buf)
    }

    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<IoResult<()>> {
        let this = self.get_mut();
        Pin::new(&mut this.write).poll_flush(cx)
    }

    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<IoResult<()>> {
        let this = self.get_mut();
        Pin::new(&mut this.write).poll_shutdown(cx)
    }
}

impl AsyncRead for SimStream {
    fn poll_read(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>) -> Poll<IoResult<()>> {
        let this = self.get_mut();
        Pin::new(&mut this.read).poll_read(cx, buf)
    }
}
/// One direction of a simulated stream: a FIFO byte queue plus the waker of a reader that is
/// waiting for data.
#[derive(Debug, Clone)]
pub struct BufferState {
    /// Bytes written but not yet read: appended at the back, consumed from the front.
    buffer: VecDeque<u8>,
    /// Waker registered by a reader that found `buffer` empty; taken and woken on the next write.
    read_waker: Option<Waker>,
}

impl BufferState {
    /// Creates an empty queue with no registered reader.
    fn new() -> Self {
        BufferState {
            buffer: VecDeque::new(),
            read_waker: None,
        }
    }

    /// Appends all of `buf` to the queue and wakes the registered reader, if any.
    ///
    /// The waker is consumed, so a reader must register again if it wants another wake-up.
    fn write(&mut self, buf: &[u8]) {
        self.buffer.extend(buf);
        if let Some(waker) = self.read_waker.take() {
            waker.wake();
        }
    }

    /// Moves up to `to_buf.len()` of the oldest queued bytes into the start of `to_buf`.
    ///
    /// Returns the number of bytes copied. `0` means the queue was empty or `to_buf` was empty.
    fn read(&mut self, to_buf: &mut [u8]) -> usize {
        let count = min(to_buf.len(), self.buffer.len());
        // `drain` removes exactly `count` bytes from the front, preserving FIFO order.
        for (slot, byte) in to_buf.iter_mut().zip(self.buffer.drain(..count)) {
            *slot = byte;
        }
        count
    }
}
/// Write side of a simulated stream.
///
/// Appends bytes to a shared [`BufferState`] that the peer's read side consumes.
#[derive(Debug, Clone)]
pub struct WriteHalf {
    buffer: Arc<Mutex<BufferState>>,
}

impl AsyncWrite for WriteHalf {
    /// Appends all of `buf` to the shared buffer; always completes immediately.
    ///
    /// # Panics
    ///
    /// Panics if the shared buffer's mutex is poisoned.
    fn poll_write(self: Pin<&mut Self>, _cx: &mut Context<'_>, buf: &[u8]) -> Poll<IoResult<usize>> {
        let accepted = buf.len();
        self.buffer
            .lock()
            .expect("Lock was poisoned when acquiring buffer lock for WriteHalf")
            .write(buf);
        Poll::Ready(Ok(accepted))
    }

    /// Writes go straight into the shared buffer, so nothing is pending locally to flush.
    fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<IoResult<()>> {
        Poll::Ready(Ok(()))
    }

    /// A no-op: the shared buffer is left untouched, so the peer is not signalled an end of
    /// stream.
    fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<IoResult<()>> {
        Poll::Ready(Ok(()))
    }
}
#[derive(Debug, Clone)]
pub struct ReadHalf {
buffer: Arc<Mutex<BufferState>>,
}
impl AsyncRead for ReadHalf {
fn poll_read(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>) -> Poll<IoResult<()>> {
let mut read_from = self
.buffer
.lock()
.expect("Lock was poisoned when acquiring buffer lock for ReadHalf");
let bytes_read = read_from.read(buf.initialize_unfilled());
if bytes_read == 0 {
read_from.read_waker = Some(cx.waker().clone());
Poll::Pending
} else {
buf.advance(bytes_read);
Poll::Ready(Ok(()))
}
}
}
#[cfg(test)]
mod tests {
    use super::chan;
    use tokio::io::{AsyncReadExt, AsyncWriteExt};

    /// Bytes written on one end must arrive unchanged on the other, in both directions.
    #[tokio::test]
    async fn ends_should_talk_to_each_other() {
        let (mut client, mut server) = chan();
        // One scratch buffer serves both directions; `read_exact` overwrites it completely.
        let mut received = [0_u8; 4];

        // client -> server
        client.write_all(b"Ping").await.expect("Write should succeed");
        server
            .read_exact(&mut received)
            .await
            .expect("Read should succeed");
        assert_eq!(&received, b"Ping");

        // server -> client
        server.write_all(b"Pong").await.expect("Write should succeed");
        client
            .read_exact(&mut received)
            .await
            .expect("Read should succeed");
        assert_eq!(&received, b"Pong");
    }
}