narrowlink_network/
lib.rs

1pub use async_tools::{AsyncToStream, StreamToAsync};
2use chacha20poly1305::{aead::Aead, KeyInit, XChaCha20Poly1305};
3use chunkio::ChunkIO;
4use std::{io, pin::Pin, task::Poll};
5mod async_tools;
6pub mod error;
7pub mod event;
8pub mod p2p;
9pub mod transport;
10pub mod ws;
11use error::NetworkError;
12use futures_util::{Sink, SinkExt, Stream, StreamExt};
13use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
14
15pub trait AsyncSocket: AsyncRead + AsyncWrite + Unpin + Send + 'static {}
16impl<T> AsyncSocket for T where T: AsyncRead + AsyncWrite + Unpin + Send + 'static {}
17
18pub trait UniversalStream<S, E>:
19    Stream<Item = Result<S, E>> + Sink<S, Error = E> + Unpin + Send + 'static
20{
21}
22impl<S, E, T> UniversalStream<S, E> for T where
23    T: Stream<Item = Result<S, E>> + Sink<S, Error = E> + Unpin + Send + 'static
24{
25}
26
27pub async fn stream_forward(
28    left: impl UniversalStream<Vec<u8>, NetworkError>,
29    right: impl UniversalStream<Vec<u8>, NetworkError>,
30) -> Result<(), NetworkError> {
31    let (mut left_tx, mut left_rx) = left.split();
32    let (mut right_tx, mut right_rx) = right.split();
33
34    loop {
35        tokio::select! {
36            res = left_rx.next() => {
37                match res{
38                    Some(v)=>right_tx.send(v?).await?,
39                    None=>{
40                        let _ = left_tx.close().await;
41                        let _ = right_tx.close().await;
42                        return Ok(())
43                    }
44                };
45            },
46            res = right_rx.next() => {
47                match res{
48                    Some(v)=>left_tx.send(v?).await?,
49                    None=>{
50                        let _ = left_tx.close().await;
51                        let _ = right_tx.close().await;
52                        return Ok(())
53                    }
54                };
55            },
56        }
57    }
58}
59
60pub async fn async_forward(
61    left: impl AsyncSocket,
62    right: impl AsyncSocket,
63) -> Result<(), NetworkError> {
64    let (mut left_rx, mut left_tx) = tokio::io::split(left);
65    let (mut right_rx, mut right_tx) = tokio::io::split(right);
66    loop {
67        tokio::select! {
68            res = tokio::io::copy(&mut left_rx, &mut right_tx) => {
69                if res? == 0 {
70                    let _ = left_tx.shutdown().await;
71                    let _ = right_tx.shutdown().await;
72                    return Ok(())
73                }
74            },
75            res = tokio::io::copy(&mut right_rx, &mut left_tx) => {
76                if res? == 0 {
77                    let _ = left_tx.shutdown().await;
78                    let _ = right_tx.shutdown().await;
79                    return Ok(())
80                }
81            },
82        }
83    }
84}
85
86pub struct AsyncSocketCrypt {
87    inner: ChunkIO<Box<dyn AsyncSocket>>,
88    cipher: XChaCha20Poly1305,
89    nonce: [u8; 24],
90}
91
92impl AsyncSocketCrypt {
93    pub async fn new(key: [u8; 32], nonce: [u8; 24], inner: Box<dyn AsyncSocket>) -> Self {
94        let cipher = XChaCha20Poly1305::new(&key.into());
95        Self {
96            inner: ChunkIO::new(inner),
97            cipher,
98            nonce,
99        }
100    }
101}
102
103impl Stream for AsyncSocketCrypt {
104    type Item = Result<Vec<u8>, std::io::Error>;
105
106    fn poll_next(
107        mut self: Pin<&mut Self>,
108        cx: &mut std::task::Context<'_>,
109    ) -> Poll<Option<Self::Item>> {
110        match self
111            .inner
112            .poll_next_unpin(cx)
113            .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e.to_string()))?
114        {
115            Poll::Ready(Some(chunk)) => Poll::Ready(Some(
116                self.cipher
117                    .decrypt(&self.nonce.into(), chunk.as_ref())
118                    .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e.to_string())),
119            )),
120            Poll::Pending => Poll::Pending,
121            _ => Poll::Ready(None),
122        }
123    }
124}
125impl Sink<Vec<u8>> for AsyncSocketCrypt {
126    type Error = std::io::Error;
127
128    fn poll_ready(
129        mut self: Pin<&mut Self>,
130        cx: &mut std::task::Context<'_>,
131    ) -> Poll<Result<(), Self::Error>> {
132        self.inner
133            .poll_ready_unpin(cx)
134            .map_err(|e| io::Error::new(io::ErrorKind::Other, e.to_string()))
135    }
136
137    fn start_send(mut self: Pin<&mut Self>, item: Vec<u8>) -> Result<(), Self::Error> {
138        let buf = self
139            .cipher
140            .encrypt(&self.nonce.into(), item.as_ref())
141            .map_err(|e| io::Error::new(io::ErrorKind::Other, e.to_string()))?;
142        self.inner
143            .start_send_unpin(buf)
144            .map_err(|e| io::Error::new(io::ErrorKind::Other, e.to_string()))
145    }
146
147    fn poll_flush(
148        mut self: Pin<&mut Self>,
149        cx: &mut std::task::Context<'_>,
150    ) -> Poll<Result<(), Self::Error>> {
151        self.inner
152            .poll_flush_unpin(cx)
153            .map_err(|e| io::Error::new(io::ErrorKind::Other, e.to_string()))
154    }
155
156    fn poll_close(
157        mut self: Pin<&mut Self>,
158        cx: &mut std::task::Context<'_>,
159    ) -> Poll<Result<(), Self::Error>> {
160        self.inner
161            .poll_close_unpin(cx)
162            .map_err(|e| io::Error::new(io::ErrorKind::Other, e.to_string()))
163    }
164}
165
166impl AsyncRead for AsyncSocketCrypt {
167    fn poll_read(
168        mut self: Pin<&mut Self>,
169        cx: &mut std::task::Context<'_>,
170        buf: &mut tokio::io::ReadBuf<'_>,
171    ) -> Poll<std::io::Result<()>> {
172        match Pin::new(&mut self).poll_next(cx)? {
173            Poll::Ready(Some(item)) => {
174                let b = self
175                    .cipher
176                    .decrypt(&self.nonce.into(), item.as_slice())
177                    .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e.to_string()))?;
178                buf.put_slice(&b); // todo: fix
179                Poll::Ready(Ok(()))
180            }
181            Poll::Ready(None) => Poll::Ready(Ok(())),
182            Poll::Pending => Poll::Pending,
183        }
184    }
185}
186
187impl AsyncWrite for AsyncSocketCrypt {
188    fn poll_write(
189        mut self: Pin<&mut Self>,
190        cx: &mut std::task::Context<'_>,
191        buf: &[u8],
192    ) -> Poll<Result<usize, std::io::Error>> {
193        match self.as_mut().poll_ready(cx) {
194            Poll::Ready(Ok(())) => {
195                let cipher_text = self
196                    .cipher
197                    .encrypt(&self.nonce.into(), buf)
198                    .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e.to_string()))?;
199                match self.start_send_unpin(cipher_text) {
200                    Ok(()) => Poll::Ready(Ok(buf.len())),
201                    Err(e) => Poll::Ready(Err(e)),
202                }
203            }
204            Poll::Ready(Err(e)) => Poll::Ready(Err(e)),
205            Poll::Pending => Poll::Pending,
206        }
207    }
208
209    fn poll_flush(
210        self: Pin<&mut Self>,
211        cx: &mut std::task::Context<'_>,
212    ) -> Poll<Result<(), std::io::Error>> {
213        <dyn Sink<Vec<u8>, Error = std::io::Error>>::poll_flush(self, cx)
214    }
215
216    fn poll_shutdown(
217        self: Pin<&mut Self>,
218        cx: &mut std::task::Context<'_>,
219    ) -> Poll<Result<(), std::io::Error>> {
220        self.poll_close(cx)
221    }
222}