1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
use futures::Future;
use futures::lock::Mutex;
use futures::prelude::{AsyncRead, AsyncWrite};
use futures::stream::Stream;
use futures::task::Context;
use std::clone::Clone;
use std::pin::Pin;
use std::sync::Arc;
use std::task::Poll;
use crate::State;

/// The asynchronous socket that mimics the network stream.
///
/// The socket is a thin handle around a shared [`State`]. This file implements
/// `AsyncRead`, `AsyncWrite` and `Stream` for it, all of which operate on the
/// byte buffer held by that state. Cloning a socket copies only the `Arc`, so
/// every clone reads from and writes to the same state.
#[derive(Debug)]
pub struct Socket {
    /// A central socket state which is shared among all the cloned instances.
    /// Access goes through the asynchronous `futures::lock::Mutex`.
    pub state: Arc<Mutex<State>>,
}

impl Socket {
    /// Creates a socket whose shared state is built with
    /// `State::with_chunk_size(csize)`.
    ///
    /// NOTE(review): `csize` is presumably what ends up in `State::chunk_size`,
    /// which the `Stream` implementation uses as the upper bound for the size
    /// of each yielded item — confirm against `State`.
    pub fn with_chunk_size(csize: usize) -> Self {
        let shared = State::with_chunk_size(csize);
        let state = Arc::new(Mutex::new(shared));
        Self { state }
    }
}

impl Default for Socket {
    /// Creates a socket backed by a freshly constructed `State::default()`.
    fn default() -> Self {
        let shared = State::default();
        let state = Arc::new(Mutex::new(shared));
        Self { state }
    }
}

impl Clone for Socket {
    fn clone(&self) -> Self {
        Self {
            state: self.state.clone(),
        }
    }
}

impl AsyncRead for Socket {
    /// Moves buffered bytes from the shared state into `buf`.
    ///
    /// Up to `buf.len()` bytes are removed from the front of the shared buffer
    /// and copied into the start of `buf`; the number of bytes copied is
    /// returned. A partial read is returned as soon as any data is available.
    ///
    /// Returns `Poll::Pending` when:
    /// * the state lock is currently held elsewhere (the task is re-scheduled
    ///   immediately so it retries), or
    /// * the shared buffer is empty (the task's waker is stored in the state so
    ///   that a later write can wake it).
    ///
    /// An empty `buf` yields `Poll::Ready(Ok(0))`. This method never returns an
    /// error.
    fn poll_read(self: Pin<&mut Self>, cx: &mut Context, buf: &mut [u8]) -> Poll<std::io::Result<usize>> {
        // A lock future created and dropped within a single poll loses its
        // wake-up registration, so contention is handled by trying the lock and
        // asking the executor to poll this task again.
        let mut state = match self.state.try_lock() {
            Some(guard) => guard,
            None => {
                cx.waker().wake_by_ref();
                return Poll::Pending;
            }
        };

        if buf.is_empty() {
            return Poll::Ready(Ok(0));
        }

        let count = std::cmp::min(buf.len(), state.buf.len());
        if count == 0 {
            // Nothing to read yet: register interest before suspending.
            // NOTE(review): relies on `State::wake` waking `state.waker` — confirm.
            state.waker = Some(cx.waker().clone());
            return Poll::Pending;
        }

        // Copy straight out of the shared buffer, then discard the consumed prefix.
        buf[..count].copy_from_slice(&state.buf[..count]);
        state.buf.drain(..count);
        Poll::Ready(Ok(count))
    }
}

impl AsyncWrite for Socket {
    /// Appends `data` to the shared buffer and notifies the state.
    ///
    /// All of `data` is always accepted, so the returned count is
    /// `data.len()` — the number of bytes written by this call, not the total
    /// size of the shared buffer.
    ///
    /// Returns `Poll::Pending` only when the state lock is currently held
    /// elsewhere; the task is re-scheduled immediately so it retries. This
    /// method never returns an error.
    fn poll_write(self: Pin<&mut Self>, cx: &mut Context, data: &[u8]) -> Poll<std::io::Result<usize>> {
        // A lock future created and dropped within a single poll loses its
        // wake-up registration, so contention is handled by trying the lock and
        // asking the executor to poll this task again.
        let mut state = match self.state.try_lock() {
            Some(guard) => guard,
            None => {
                cx.waker().wake_by_ref();
                return Poll::Pending;
            }
        };

        state.buf.extend_from_slice(data);
        // Let the state signal that new data is available.
        // NOTE(review): presumably wakes the waker stored in the state — confirm.
        state.wake();
        Poll::Ready(Ok(data.len()))
    }

    /// Always ready: writes land directly in the shared buffer, so there is
    /// nothing to flush.
    fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context) -> Poll<std::io::Result<()>> {
        Poll::Ready(Ok(()))
    }

    /// Always ready: closing performs no action on the shared state.
    fn poll_close(self: Pin<&mut Self>, _cx: &mut Context) -> Poll<std::io::Result<()>> {
        Poll::Ready(Ok(()))
    }
}

impl Stream for Socket {
    type Item = Vec<u8>;

    /// Yields the next chunk of buffered bytes.
    ///
    /// Each item holds at most `state.chunk_size` bytes taken from the front
    /// of the shared buffer. The task's waker is stored in the state on every
    /// poll that acquires the lock.
    ///
    /// Returns `Poll::Pending` when:
    /// * the state lock is currently held elsewhere (the task is re-scheduled
    ///   immediately so it retries), or
    /// * no bytes can be taken, i.e. the buffer is empty or `chunk_size` is 0.
    ///
    /// The stream never yields `None`.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
        // A lock future created and dropped within a single poll loses its
        // wake-up registration, so contention is handled by trying the lock and
        // asking the executor to poll this task again.
        let mut state = match self.state.try_lock() {
            Some(guard) => guard,
            None => {
                cx.waker().wake_by_ref();
                return Poll::Pending;
            }
        };

        // Register interest so a later write can wake this task.
        // NOTE(review): relies on `State::wake` waking `state.waker` — confirm.
        state.waker = Some(cx.waker().clone());

        let take = std::cmp::min(state.chunk_size, state.buf.len());
        if take == 0 {
            return Poll::Pending;
        }

        let chunk: Vec<u8> = state.buf.drain(..take).collect();
        Poll::Ready(Some(chunk))
    }
}