ppaass_v3_common/connection/
mod.rs1mod agent;
2mod codec;
3mod proxy;
4use crate::connection::codec::CryptoLengthDelimitedCodec;
5use crate::error::CommonError;
6pub use agent::*;
7use futures_util::{Sink, SinkExt, Stream, StreamExt};
8use ppaass_protocol::Encryption;
9pub use proxy::*;
10use std::io::Error;
11use std::net::SocketAddr;
12use std::pin::Pin;
13use std::sync::Arc;
14use std::task::{Context, Poll};
15use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
16use tokio::net::TcpStream;
17use tokio::pin;
18use tokio_util::bytes::BytesMut;
19use tokio_util::codec::Framed;
20use tokio_util::io::{SinkWriter, StreamReader};
21pub struct CryptoLengthDelimitedFramed<T>
22where
23 T: AsyncRead + AsyncWrite + Send + Sync + Unpin + 'static,
24{
25 crypto_length_delimited_framed: Framed<T, CryptoLengthDelimitedCodec>,
26}
27
28impl<T> CryptoLengthDelimitedFramed<T>
29where
30 T: AsyncRead + AsyncWrite + Send + Sync + Unpin + 'static,
31{
32 pub fn new(
33 tcp_stream: T,
34 decoder_encryption: Arc<Encryption>,
35 encoder_encryption: Arc<Encryption>,
36 frame_buffer_size: usize,
37 ) -> Self {
38 Self {
39 crypto_length_delimited_framed: Framed::with_capacity(
40 tcp_stream,
41 CryptoLengthDelimitedCodec::new(decoder_encryption, encoder_encryption),
42 frame_buffer_size,
43 ),
44 }
45 }
46}
47
48impl<T> Stream for CryptoLengthDelimitedFramed<T>
49where
50 T: AsyncRead + AsyncWrite + Send + Sync + Unpin + 'static,
51{
52 type Item = Result<BytesMut, CommonError>;
53 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
54 self.get_mut()
55 .crypto_length_delimited_framed
56 .poll_next_unpin(cx)
57 }
58}
59impl<T> Sink<&[u8]> for CryptoLengthDelimitedFramed<T>
60where
61 T: AsyncRead + AsyncWrite + Send + Sync + Unpin + 'static,
62{
63 type Error = CommonError;
64 fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), CommonError>> {
65 self.get_mut()
66 .crypto_length_delimited_framed
67 .poll_ready_unpin(cx)
68 }
69 fn start_send(self: Pin<&mut Self>, item: &[u8]) -> Result<(), CommonError> {
70 self.get_mut()
71 .crypto_length_delimited_framed
72 .start_send_unpin(BytesMut::from(item))
73 }
74 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), CommonError>> {
75 self.get_mut()
76 .crypto_length_delimited_framed
77 .poll_flush_unpin(cx)
78 }
79 fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), CommonError>> {
80 self.get_mut()
81 .crypto_length_delimited_framed
82 .poll_close_unpin(cx)
83 }
84}
85
/// A connection that performs its I/O through the wrapped `state`.
///
/// The connection also records the socket address it is associated with and
/// the frame buffer size it was configured with.
pub struct FramedConnection<S> {
    /// Socket address supplied at construction. Which endpoint it denotes is up to the caller.
    socket_address: SocketAddr,
    /// Frame buffer size supplied at construction.
    frame_buffer_size: usize,
    /// Wrapped transport that the `AsyncRead` / `AsyncWrite` impls delegate to.
    state: S,
}
91
92impl<S> FramedConnection<S> {
93 pub fn new(state: S, socket_address: SocketAddr, frame_buffer_size: usize) -> Self {
94 Self {
95 state,
96 socket_address,
97 frame_buffer_size,
98 }
99 }
100}
101
102impl AsyncRead
103 for FramedConnection<SinkWriter<StreamReader<CryptoLengthDelimitedFramed<TcpStream>, BytesMut>>>
104{
105 fn poll_read(
106 self: Pin<&mut Self>,
107 cx: &mut Context<'_>,
108 buf: &mut ReadBuf<'_>,
109 ) -> Poll<std::io::Result<()>> {
110 let crypto_tcp_read_write = &mut self.get_mut().state;
111 pin!(crypto_tcp_read_write);
112 crypto_tcp_read_write.poll_read(cx, buf)
113 }
114}
115
116impl AsyncWrite
117 for FramedConnection<SinkWriter<StreamReader<CryptoLengthDelimitedFramed<TcpStream>, BytesMut>>>
118{
119 fn poll_write(
120 self: Pin<&mut Self>,
121 cx: &mut Context<'_>,
122 buf: &[u8],
123 ) -> Poll<Result<usize, Error>> {
124 let crypto_tcp_read_write = &mut self.get_mut().state;
125 pin!(crypto_tcp_read_write);
126 crypto_tcp_read_write.poll_write(cx, buf)
127 }
128 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
129 let crypto_tcp_read_write = &mut self.get_mut().state;
130 pin!(crypto_tcp_read_write);
131 crypto_tcp_read_write.poll_flush(cx)
132 }
133 fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
134 let crypto_tcp_read_write = &mut self.get_mut().state;
135 pin!(crypto_tcp_read_write);
136 crypto_tcp_read_write.poll_shutdown(cx)
137 }
138}