1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
pub use async_tools::{AsyncToStream, StreamToAsync};
use chacha20poly1305::{
    aead::{self, Aead, AeadCore, OsRng, Payload},
    KeyInit, XChaCha20Poly1305, XNonce,
};
use std::{pin::Pin, task::Poll};
mod async_tools;
pub mod error;
pub mod event;
pub mod transport;
pub mod ws;
use error::NetworkError;
use futures_util::{Sink, SinkExt, Stream, StreamExt};
use tokio::io::{AsyncRead, AsyncWrite};

/// Bound bundle for a raw, byte-oriented async transport.
///
/// A type qualifies when it is readable and writable through tokio's I/O traits,
/// is `Unpin` and `Send`, and borrows nothing (`'static`). The blanket impl below
/// makes this act as a trait alias: it never needs to be implemented by hand.
pub trait AsyncSocket: AsyncRead + AsyncWrite + Unpin + Send + 'static {}

impl<T: AsyncRead + AsyncWrite + Unpin + Send + 'static> AsyncSocket for T {}

pub trait UniversalStream<S, E>:
    Stream<Item = Result<S, E>> + Sink<S, Error = E> + Unpin + Send + 'static
{
}
impl<S, E, T> UniversalStream<S, E> for T where
    T: Stream<Item = Result<S, E>> + Sink<S, Error = E> + Unpin + Send + 'static
{
}

/// Relays frames between two bidirectional channels until either one ends.
///
/// Each frame read from `left` is sent to `right`, and each frame read from
/// `right` is sent to `left`. `tokio::select!` waits on both read halves at once.
/// The send for a received frame is awaited to completion before either side is
/// polled again.
///
/// # Errors
///
/// Returns the first error produced by either stream or either sink.
///
/// Returns `Ok(())` as soon as either stream reports end-of-stream. The split
/// halves are then dropped without being explicitly closed.
pub async fn stream_forward(
    left: impl UniversalStream<Vec<u8>, NetworkError>,
    right: impl UniversalStream<Vec<u8>, NetworkError>,
) -> Result<(), NetworkError> {
    let (mut to_left, mut from_left) = left.split();
    let (mut to_right, mut from_right) = right.split();

    loop {
        tokio::select! {
            incoming = from_left.next() => match incoming {
                Some(frame) => to_right.send(frame?).await?,
                // `left` finished: stop relaying in both directions.
                None => break,
            },
            incoming = from_right.next() => match incoming {
                Some(frame) => to_left.send(frame?).await?,
                // `right` finished: stop relaying in both directions.
                None => break,
            },
        }
    }

    Ok(())
}

pub struct StreamCrypt {
    inner: Box<dyn UniversalStream<Vec<u8>, NetworkError>>,
    cipher: XChaCha20Poly1305,
    nonce: [u8; 24],
}

impl StreamCrypt {
    pub fn new(
        key: [u8; 32],
        nonce: [u8; 24],
        inner: impl UniversalStream<Vec<u8>, NetworkError>,
    ) -> Self {
        let cipher = XChaCha20Poly1305::new(&key.into());
        Self {
            inner: Box::new(inner),
            cipher,
            nonce,
        }
    }
}

impl Stream for StreamCrypt {
    type Item = Result<Vec<u8>, NetworkError>;

    fn poll_next(
        mut self: Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> Poll<Option<Self::Item>> {
        match self.inner.poll_next_unpin(cx)? {
            Poll::Ready(Some(buf)) => Poll::Ready(Some(
                self.cipher
                    .decrypt(&self.nonce.into(), buf.as_ref())
                    .map_err(|e| e.into()),
            )),
            Poll::Ready(None) => Poll::Ready(None),
            Poll::Pending => Poll::Pending,
        }
    }
}
impl Sink<Vec<u8>> for StreamCrypt {
    type Error = NetworkError;

    fn poll_ready(
        mut self: Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> Poll<Result<(), Self::Error>> {
        self.inner.poll_ready_unpin(cx)
    }

    fn start_send(mut self: Pin<&mut Self>, item: Vec<u8>) -> Result<(), Self::Error> {
        let buf = self.cipher.encrypt(&self.nonce.into(), item.as_ref())?;
        self.inner.start_send_unpin(buf)
    }

    fn poll_flush(
        mut self: Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> Poll<Result<(), Self::Error>> {
        self.inner.poll_flush_unpin(cx)
    }

    fn poll_close(
        mut self: Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> Poll<Result<(), Self::Error>> {
        self.inner.poll_close_unpin(cx)
    }
}