use std::io;
use std::fmt;
use std::mem;
use futures::{Async, Future, Poll};
use futures::sync::{BiLock, BiLockAcquired, BiLockAcquire};
use tokio_io::{AsyncRead, AsyncWrite};
use frame;
use {Buf, Encode, Decode, ReadFramed, WriteFramed};
/// Socket state shared by the read and write halves produced by `create`.
///
/// Each half owns one side of a `BiLock` around this value and only reaches
/// `socket` / `done` after acquiring that lock.
struct Shared<S> {
    /// The underlying I/O object.
    socket: S,
    /// Set once end-of-stream, `BrokenPipe` or `ConnectionReset` has been
    /// observed by either half; its initial value comes from `create`.
    done: bool,
}
/// Read half of a split socket.
pub struct ReadBuf<S> {
    /// Input buffer that `read` fills from the socket.
    pub in_buf: Buf,
    /// This half's side of the lock around the shared socket state.
    shared: BiLock<Shared<S>>,
}
/// Write half of a split socket.
pub struct WriteBuf<S> {
    /// Output buffer that `flush` drains into the socket.
    pub out_buf: Buf,
    /// This half's side of the lock around the shared socket state.
    shared: BiLock<Shared<S>>,
}
/// Unbuffered access to the socket while the shared lock is held.
///
/// Produced by the future returned from `WriteBuf::borrow_raw`; `into_buf`
/// releases the lock and turns it back into a buffered writer.
pub struct WriteRaw<S> {
    /// The acquired lock, dereferencing to the shared socket state.
    io: BiLockAcquired<Shared<S>>,
}
/// Future returned by `WriteBuf::borrow_raw`.
///
/// Resolves to a `WriteRaw` once any pending output has been flushed and the
/// shared lock has been acquired.
pub struct FutureWriteRaw<S>(WriteRawFutState<S>);

/// Internal progress of a `FutureWriteRaw`.
enum WriteRawFutState<S> {
    /// `out_buf` still holds bytes that must be written out first.
    Flushing(WriteBuf<S>),
    /// The buffer is empty; waiting to acquire the lock.
    Locking(BiLockAcquire<Shared<S>>),
    /// Placeholder stored while `poll` runs, and left behind once the future
    /// has resolved or returned an error.
    Done,
}
/// Splits a socket and its two buffers into independent write and read halves.
///
/// Both halves reach `socket` and the `done` flag only through a shared
/// `BiLock`, so the socket is never touched by both at once.
///
/// * `in_buf` - already-buffered input; becomes `ReadBuf::in_buf`.
/// * `out_buf` - output still waiting to be sent; becomes `WriteBuf::out_buf`.
/// * `socket` - the underlying I/O object.
/// * `done` - initial value of the shared end-of-stream flag.
///
/// Returns `(WriteBuf, ReadBuf)`: the write half comes first.
pub fn create<S>(in_buf: Buf, out_buf: Buf, socket: S, done: bool)
    -> (WriteBuf<S>, ReadBuf<S>)
{
    let (read_lock, write_lock) = BiLock::new(Shared {
        socket: socket,
        done: done,
    });
    // Each buffer must go to the half that owns its direction. Handing the
    // input buffer to the writer would echo received bytes back to the peer
    // and make pending output look like freshly received input.
    let writer = WriteBuf {
        out_buf: out_buf,
        shared: write_lock,
    };
    let reader = ReadBuf {
        in_buf: in_buf,
        shared: read_lock,
    };
    (writer, reader)
}
impl<S> ReadBuf<S> {
    /// Reads whatever the socket has available into `in_buf`.
    ///
    /// Returns the number of bytes appended. `Ok(0)` covers several cases:
    /// - the shared lock could not be taken right now;
    /// - the socket reported `WouldBlock`;
    /// - the stream ended.
    ///
    /// Check `done()` to distinguish them. End of stream (`Ok(0)` from the
    /// socket), `BrokenPipe` and `ConnectionReset` set the shared `done` flag
    /// instead of being reported.
    ///
    /// # Errors
    ///
    /// Any other I/O error from the socket is returned unchanged.
    pub fn read(&mut self) -> Result<usize, io::Error>
    where
        S: AsyncRead,
    {
        let mut guard = match self.shared.poll_lock() {
            Async::Ready(guard) => guard,
            // Lock is busy: report that nothing was read this time.
            Async::NotReady => return Ok(0),
        };
        match self.in_buf.read_from(&mut guard.socket) {
            Ok(0) => {
                guard.done = true;
                Ok(0)
            }
            Ok(nbytes) => Ok(nbytes),
            Err(err) => match err.kind() {
                io::ErrorKind::WouldBlock => Ok(0),
                io::ErrorKind::BrokenPipe | io::ErrorKind::ConnectionReset => {
                    guard.done = true;
                    Ok(0)
                }
                _ => Err(err),
            },
        }
    }

    /// Reports whether the shared `done` flag is set.
    ///
    /// Returns `false` when the lock cannot be taken right now.
    pub fn done(&self) -> bool {
        match self.shared.poll_lock() {
            Async::Ready(guard) => guard.done,
            Async::NotReady => false,
        }
    }

    /// Wraps this read half with `codec` via `frame::read_framed`.
    pub fn framed<D: Decode>(self, codec: D) -> ReadFramed<S, D> {
        frame::read_framed(self, codec)
    }
}
impl<S> WriteBuf<S> {
    /// Pushes as much of `out_buf` into the socket as it accepts, then flushes
    /// the socket.
    ///
    /// Writing stops in any of these cases; whatever is left stays in
    /// `out_buf`:
    /// - the buffer is empty;
    /// - the socket accepts zero bytes;
    /// - the socket reports `WouldBlock`.
    ///
    /// `BrokenPipe` and `ConnectionReset`, from either the writes or the final
    /// flush, set the shared `done` flag and are not reported. When the shared
    /// lock cannot be taken right now, nothing happens and `Ok(())` is
    /// returned.
    ///
    /// # Errors
    ///
    /// Any other I/O error from writing or flushing the socket.
    pub fn flush(&mut self) -> Result<(), io::Error>
    where
        S: AsyncWrite,
    {
        let mut guard = match self.shared.poll_lock() {
            Async::Ready(guard) => guard,
            Async::NotReady => return Ok(()),
        };
        while self.out_buf.len() != 0 {
            match self.out_buf.write_to(&mut guard.socket) {
                Ok(0) => break,
                Ok(_) => {}
                Err(err) => match err.kind() {
                    io::ErrorKind::WouldBlock => break,
                    io::ErrorKind::BrokenPipe | io::ErrorKind::ConnectionReset => {
                        guard.done = true;
                        break;
                    }
                    _ => return Err(err),
                },
            }
        }
        // The socket's own flush is attempted however the write loop ended.
        match guard.socket.flush() {
            Ok(()) => Ok(()),
            Err(err) => match err.kind() {
                io::ErrorKind::WouldBlock => Ok(()),
                io::ErrorKind::BrokenPipe | io::ErrorKind::ConnectionReset => {
                    guard.done = true;
                    Ok(())
                }
                _ => Err(err),
            },
        }
    }

    /// Reports whether the shared `done` flag is set.
    ///
    /// Returns `false` when the lock cannot be taken right now.
    pub fn done(&self) -> bool {
        match self.shared.poll_lock() {
            Async::Ready(guard) => guard.done,
            Async::NotReady => false,
        }
    }

    /// Turns this write half into a future yielding unbuffered access to the
    /// socket.
    ///
    /// Bytes still in `out_buf` are flushed before the lock is requested.
    pub fn borrow_raw(self) -> FutureWriteRaw<S> {
        let state = if self.out_buf.len() != 0 {
            WriteRawFutState::Flushing(self)
        } else {
            WriteRawFutState::Locking(self.shared.lock())
        };
        FutureWriteRaw(state)
    }

    /// Wraps this write half with `codec` via `frame::write_framed`.
    pub fn framed<E: Encode>(self, codec: E) -> WriteFramed<S, E> {
        frame::write_framed(self, codec)
    }
}
impl<S> WriteRaw<S> {
pub fn into_buf(self) -> WriteBuf<S> {
WriteBuf {
out_buf: Buf::new(),
shared: self.io.unlock(),
}
}
pub fn get_ref(&self) -> &S {
&self.io.socket
}
pub fn get_mut(&mut self) -> &mut S {
&mut self.io.socket
}
}
impl<S: AsyncWrite> Future for FutureWriteRaw<S> {
type Item = WriteRaw<S>;
type Error = io::Error;
fn poll(&mut self) -> Poll<WriteRaw<S>, io::Error> {
use self::WriteRawFutState::*;
self.0 = match mem::replace(&mut self.0, Done) {
Flushing(mut buf) => {
buf.flush()?;
if buf.out_buf.len() == 0 {
let mut lock = buf.shared.lock();
match lock.poll().expect("lock never fails") {
Async::Ready(s) => {
return Ok(Async::Ready(WriteRaw { io: s }));
}
Async::NotReady => {}
}
Locking(lock)
} else {
Flushing(buf)
}
}
Locking(mut f) => {
match f.poll().expect("lock never fails") {
Async::Ready(s) => {
return Ok(Async::Ready(WriteRaw { io: s }));
}
Async::NotReady => {}
}
Locking(f)
}
Done => panic!("future polled after completion"),
};
return Ok(Async::NotReady);
}
}
/// Writes and flushes go straight to the underlying socket; no buffering is
/// added on top.
impl<S: AsyncWrite> io::Write for WriteRaw<S> {
    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
        io::Write::write(&mut self.io.socket, buf)
    }

    fn flush(&mut self) -> io::Result<()> {
        io::Write::flush(&mut self.io.socket)
    }
}
impl<S: AsyncWrite> AsyncWrite for WriteRaw<S> {
fn shutdown(&mut self) -> Poll<(), io::Error> {
self.io.socket.shutdown()
}
}
/// Shows only how many bytes are buffered for reading; the socket is not
/// printed (so `S` needs no `Debug` bound).
impl<S> fmt::Debug for ReadBuf<S> {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let buffered = self.in_buf.len();
        write!(f, "ReadBuf {{ in: {}b }}", buffered)
    }
}
/// Shows only how many bytes are waiting to be written; the socket is not
/// printed (so `S` needs no `Debug` bound).
impl<S> fmt::Debug for WriteBuf<S> {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let pending = self.out_buf.len();
        write!(f, "WriteBuf {{ out: {}b }}", pending)
    }
}