ppaass_v3_common/connection/mod.rs

1mod agent;
2mod codec;
3mod proxy;
4use crate::connection::codec::CryptoLengthDelimitedCodec;
5use crate::error::CommonError;
6pub use agent::*;
7use futures_util::{Sink, SinkExt, Stream, StreamExt};
8use ppaass_protocol::Encryption;
9pub use proxy::*;
10use std::io::Error;
11use std::net::SocketAddr;
12use std::pin::Pin;
13use std::sync::Arc;
14use std::task::{Context, Poll};
15use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
16use tokio::net::TcpStream;
17use tokio::pin;
18use tokio_util::bytes::BytesMut;
19use tokio_util::codec::Framed;
20use tokio_util::io::{SinkWriter, StreamReader};
21pub struct CryptoLengthDelimitedFramed<T>
22where
23    T: AsyncRead + AsyncWrite + Send + Sync + Unpin + 'static,
24{
25    crypto_length_delimited_framed: Framed<T, CryptoLengthDelimitedCodec>,
26}
27
28impl<T> CryptoLengthDelimitedFramed<T>
29where
30    T: AsyncRead + AsyncWrite + Send + Sync + Unpin + 'static,
31{
32    pub fn new(
33        tcp_stream: T,
34        decoder_encryption: Arc<Encryption>,
35        encoder_encryption: Arc<Encryption>,
36        frame_buffer_size: usize,
37    ) -> Self {
38        Self {
39            crypto_length_delimited_framed: Framed::with_capacity(
40                tcp_stream,
41                CryptoLengthDelimitedCodec::new(decoder_encryption, encoder_encryption),
42                frame_buffer_size,
43            ),
44        }
45    }
46}
47
48impl<T> Stream for CryptoLengthDelimitedFramed<T>
49where
50    T: AsyncRead + AsyncWrite + Send + Sync + Unpin + 'static,
51{
52    type Item = Result<BytesMut, CommonError>;
53    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
54        self.get_mut()
55            .crypto_length_delimited_framed
56            .poll_next_unpin(cx)
57    }
58}
59impl<T> Sink<&[u8]> for CryptoLengthDelimitedFramed<T>
60where
61    T: AsyncRead + AsyncWrite + Send + Sync + Unpin + 'static,
62{
63    type Error = CommonError;
64    fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), CommonError>> {
65        self.get_mut()
66            .crypto_length_delimited_framed
67            .poll_ready_unpin(cx)
68    }
69    fn start_send(self: Pin<&mut Self>, item: &[u8]) -> Result<(), CommonError> {
70        self.get_mut()
71            .crypto_length_delimited_framed
72            .start_send_unpin(BytesMut::from(item))
73    }
74    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), CommonError>> {
75        self.get_mut()
76            .crypto_length_delimited_framed
77            .poll_flush_unpin(cx)
78    }
79    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), CommonError>> {
80        self.get_mut()
81            .crypto_length_delimited_framed
82            .poll_close_unpin(cx)
83    }
84}
85
/// A connection wrapper generic over its `state`, bundled with a socket
/// address and a frame buffer size.
///
/// When `S` is the `SinkWriter<StreamReader<CryptoLengthDelimitedFramed<TcpStream>, BytesMut>>`
/// stack, this module implements `AsyncRead`/`AsyncWrite` for the connection by
/// delegating to `state`.
pub struct FramedConnection<S> {
    /// The wrapped connection state; I/O is delegated to it where implemented.
    state: S,
    /// Socket address recorded for this connection.
    /// NOTE(review): whether this is the peer or local address is decided by
    /// the callers — confirm.
    socket_address: SocketAddr,
    /// Frame buffer size recorded for this connection. Not read in this
    /// module; presumably consumed by the `agent`/`proxy` submodules (confirm).
    frame_buffer_size: usize,
}
91
92impl<S> FramedConnection<S> {
93    pub fn new(state: S, socket_address: SocketAddr, frame_buffer_size: usize) -> Self {
94        Self {
95            state,
96            socket_address,
97            frame_buffer_size,
98        }
99    }
100}
101
102impl AsyncRead
103    for FramedConnection<SinkWriter<StreamReader<CryptoLengthDelimitedFramed<TcpStream>, BytesMut>>>
104{
105    fn poll_read(
106        self: Pin<&mut Self>,
107        cx: &mut Context<'_>,
108        buf: &mut ReadBuf<'_>,
109    ) -> Poll<std::io::Result<()>> {
110        let crypto_tcp_read_write = &mut self.get_mut().state;
111        pin!(crypto_tcp_read_write);
112        crypto_tcp_read_write.poll_read(cx, buf)
113    }
114}
115
116impl AsyncWrite
117    for FramedConnection<SinkWriter<StreamReader<CryptoLengthDelimitedFramed<TcpStream>, BytesMut>>>
118{
119    fn poll_write(
120        self: Pin<&mut Self>,
121        cx: &mut Context<'_>,
122        buf: &[u8],
123    ) -> Poll<Result<usize, Error>> {
124        let crypto_tcp_read_write = &mut self.get_mut().state;
125        pin!(crypto_tcp_read_write);
126        crypto_tcp_read_write.poll_write(cx, buf)
127    }
128    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
129        let crypto_tcp_read_write = &mut self.get_mut().state;
130        pin!(crypto_tcp_read_write);
131        crypto_tcp_read_write.poll_flush(cx)
132    }
133    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
134        let crypto_tcp_read_write = &mut self.get_mut().state;
135        pin!(crypto_tcp_read_write);
136        crypto_tcp_read_write.poll_shutdown(cx)
137    }
138}