use super::KCPPeer;
use crate::kcp::kcp_module::prelude::Kcp;
use async_lock::MutexGuard;
use futures::future::BoxFuture;
use futures::{ready, FutureExt};
use std::io::{Error, ErrorKind};
use std::pin::Pin;
use std::sync::atomic::Ordering;
use std::task::{Context, Poll};
use tokio::io::ReadBuf;
/// Reader over a [`KCPPeer`], usable through both `futures::AsyncRead` and
/// `tokio::io::AsyncRead`.
///
/// When the next message reported by KCP is larger than the caller's buffer, it is
/// received whole into `cache` and handed out over successive reads.
pub struct KcpReader<'a> {
    /// Peer whose `kcp` control block is read from.
    peer: &'a KCPPeer,
    /// Progress of the current read: idle, or waiting for the KCP lock.
    state: KcpReaderState<'a>,
    /// Scratch buffer holding a message that did not fit into a caller buffer.
    cache: Vec<u8>,
    /// Number of valid bytes at the front of `cache`.
    cache_siz: usize,
    /// Index in `cache` of the next byte to hand out.
    cache_pos: usize,
}
/// Internal state machine driving [`KcpReader`] reads.
enum KcpReaderState<'a> {
    /// Waiting for the boxed lock future on the peer's KCP mutex to resolve.
    ReadBuff(BoxFuture<'a, MutexGuard<'a, Kcp>>),
    /// Idle: serve cached bytes or start acquiring the KCP lock.
    Begin,
}
impl<'a> From<&'a KCPPeer> for KcpReader<'a> {
fn from(peer: &'a KCPPeer) -> Self {
Self {
peer,
cache: vec![],
cache_siz: 0,
cache_pos: 0,
state: KcpReaderState::Begin,
}
}
}
impl KcpReader<'_> {
#[inline]
fn poll_recv(&mut self, cx: &mut Context<'_>, buf: &mut [u8]) -> Poll<std::io::Result<usize>> {
loop {
match self.state {
KcpReaderState::Begin => {
if self.cache_pos < self.cache_siz {
let copy_size = (self.cache_siz - self.cache_pos).min(buf.len());
buf[..copy_size].copy_from_slice(
&self.cache[self.cache_pos..self.cache_pos + copy_size],
);
self.cache_pos += copy_size;
return Ok(copy_size).into();
}
self.state = KcpReaderState::ReadBuff(self.peer.kcp.lock().boxed());
}
KcpReaderState::ReadBuff(ref mut size_future) => {
let mut kcp = ready!(size_future.as_mut().poll(cx));
let size = match kcp.peek_size() {
Ok(size) => Ok(size),
Err(err) => {
if self.peer.is_broken_pipe.load(Ordering::Acquire) {
Err(crate::prelude::kcp_module::Error::BrokenPipe)
} else {
Err(err)
}
}
};
match size {
Ok(size) => {
if size > buf.len() {
if self.cache.len() < size {
self.cache.resize(size, 0);
}
match kcp.recv(&mut self.cache) {
Ok(0) => return Ok(0).into(),
Ok(size) => {
self.cache_siz = size;
self.cache_pos = 0;
self.state = KcpReaderState::Begin;
}
Err(err) => return Err(err.into()).into(),
}
} else {
match kcp.recv(buf) {
Ok(size) => return Ok(size).into(),
Err(crate::prelude::kcp_module::Error::RecvQueueEmpty) => {
self.state = KcpReaderState::Begin;
}
Err(crate::prelude::kcp_module::Error::UserBufTooSmall(
size,
)) => {
if self.cache.len() < size {
self.cache.resize(size, 0);
}
match kcp.recv(&mut self.cache) {
Ok(0) => return Ok(0).into(),
Ok(size) => {
self.cache_siz = size;
self.cache_pos = 0;
self.state = KcpReaderState::Begin;
}
Err(err) => return Err(err.into()).into(),
}
}
Err(err) => return Err(err.into()).into(),
}
}
}
Err(crate::prelude::kcp_module::Error::BrokenPipe) => {
return Err(Error::new(
ErrorKind::BrokenPipe,
"kcp peer is broken pipe",
))
.into()
}
Err(_) => {
self.peer.wake.register(cx.waker());
self.state = KcpReaderState::Begin;
return Poll::Pending;
}
}
}
}
}
}
}
impl<'a> futures::AsyncRead for KcpReader<'a> {
    /// Reads received KCP data into `buf` by delegating to `poll_recv`.
    ///
    /// Once the read completes (successfully or not), the reader is put back into the
    /// idle `Begin` state.
    #[inline]
    fn poll_read(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut [u8],
    ) -> Poll<std::io::Result<usize>> {
        let reader = &mut *self;
        match reader.poll_recv(cx, buf) {
            Poll::Pending => Poll::Pending,
            Poll::Ready(outcome) => {
                // Return to idle so a completed lock future is never polled again.
                reader.state = KcpReaderState::Begin;
                Poll::Ready(outcome)
            }
        }
    }
}
impl<'a> tokio::io::AsyncRead for KcpReader<'a> {
    /// Reads received KCP data into the unfilled part of `buf`.
    ///
    /// The unfilled region is initialized and handed to `poll_recv`. On success, `buf` is
    /// advanced by the number of bytes written. Once the read completes, the reader is
    /// put back into the idle `Begin` state.
    #[inline]
    fn poll_read(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut ReadBuf<'_>,
    ) -> Poll<std::io::Result<()>> {
        let reader = &mut *self;
        let outcome = match reader.poll_recv(cx, buf.initialize_unfilled()) {
            Poll::Ready(outcome) => outcome,
            Poll::Pending => return Poll::Pending,
        };
        // Return to idle so a completed lock future is never polled again.
        reader.state = KcpReaderState::Begin;
        Poll::Ready(outcome.map(|filled| buf.advance(filled)))
    }
}