// ppaass_v3_common/connection/proxy/mod.rs

mod pool;
use crate::connection::codec::{
    HandshakeRequestEncoder, HandshakeResponseDecoder, TunnelControlResponseRequestCodec,
};
use crate::connection::CryptoLengthDelimitedFramed;
use crate::error::CommonError;
use crate::user::repo::fs::USER_INFO_ADDITION_INFO_PROXY_SERVERS;
use crate::user::UserInfo;
use crate::{
    parse_to_socket_addresses, random_generate_encryption, rsa_decrypt_encryption,
    rsa_encrypt_encryption, FramedConnection,
};
use bytes::BytesMut;
use chrono::Utc;
use futures_util::{SinkExt, StreamExt};
pub use pool::*;
use ppaass_protocol::{
    Encryption, HandshakeRequest, HandshakeResponse, HeartbeatRequest, TunnelControlRequest,
    TunnelControlResponse, TunnelInitFailureReason, TunnelInitRequest, TunnelInitResponse,
};
use std::net::SocketAddr;
use std::sync::Arc;
use std::time::Duration;
use tokio::net::TcpStream;
use tokio::time::timeout;
use tokio_util::codec::{Framed, FramedParts};
use tokio_util::io::{SinkWriter, StreamReader};
use tracing::{debug, warn};
/// The proxy server chosen for a connection together with the credential presented to it.
#[derive(Debug, Clone)]
pub struct ProxyTcpConnectionInfo {
    /// Socket address of the selected proxy server.
    proxy_address: SocketAddr,
    /// Value placed in the `authentication` field of the handshake request.
    authentication: String,
}

impl ProxyTcpConnectionInfo {
    /// Pairs a proxy address with the authentication value to present to that proxy.
    pub fn new(proxy_address: SocketAddr, authentication: String) -> Self {
        Self { proxy_address, authentication }
    }

    /// Borrows the authentication value.
    pub fn authentication(&self) -> &str {
        self.authentication.as_str()
    }

    /// Returns (a copy of) the proxy server's socket address.
    pub fn proxy_address(&self) -> SocketAddr {
        let Self { proxy_address, .. } = self;
        *proxy_address
    }
}

49pub struct ProxyTcpConnectionNewState {}
50pub struct ProxyTcpConnectionTunnelCtlState {
51    tunnel_ctl_response_request_framed: Framed<TcpStream, TunnelControlResponseRequestCodec>,
52    proxy_encryption: Arc<Encryption>,
53    agent_encryption: Arc<Encryption>,
54}
55
56fn select_proxy_tcp_connection_info(
57    username: &str,
58    user_info: &UserInfo,
59) -> Result<ProxyTcpConnectionInfo, CommonError> {
60    let proxy_addresses = user_info
61        .get_additional_info::<Vec<String>>(USER_INFO_ADDITION_INFO_PROXY_SERVERS)
62        .ok_or(CommonError::Other(format!(
63            "No proxy servers defined in user info configuration: {user_info:?}"
64        )))?;
65    let proxy_addresses = parse_to_socket_addresses(proxy_addresses.iter())?;
66
67    let select_index = rand::random::<u64>() % proxy_addresses.len() as u64;
68    let proxy_address = proxy_addresses[select_index as usize];
69
70    Ok(ProxyTcpConnectionInfo::new(
71        proxy_address,
72        username.to_owned(),
73    ))
74}
75
76impl FramedConnection<ProxyTcpConnectionNewState> {
77    pub async fn create(
78        username: &str,
79        user_info: &UserInfo,
80        frame_buffer_size: usize,
81        connect_timeout: u64,
82    ) -> Result<FramedConnection<ProxyTcpConnectionTunnelCtlState>, CommonError> {
83        let proxy_tcp_connection_info = select_proxy_tcp_connection_info(username, user_info)?;
84        let proxy_tcp_stream = timeout(
85            Duration::from_secs(connect_timeout),
86            TcpStream::connect(proxy_tcp_connection_info.proxy_address()),
87        )
88        .await??;
89        proxy_tcp_stream.set_nodelay(true)?;
90        proxy_tcp_stream.set_linger(None)?;
91        let proxy_socket_address = proxy_tcp_stream.peer_addr()?;
92        let agent_encryption = random_generate_encryption();
93        let encrypt_agent_encryption =
94            rsa_encrypt_encryption(&agent_encryption, user_info.rsa_crypto())?;
95        let mut handshake_request_framed =
96            Framed::new(proxy_tcp_stream, HandshakeRequestEncoder::new());
97        let handshake_request = HandshakeRequest {
98            authentication: proxy_tcp_connection_info.authentication().to_owned(),
99            encryption: encrypt_agent_encryption.into_owned(),
100        };
101        debug!("Begin to send handshake request to proxy: {handshake_request:?}");
102        handshake_request_framed.send(handshake_request).await?;
103        debug!("Success to send handshake request to proxy: {proxy_socket_address:?}");
104        debug!("Begin to receive handshake response from proxy: {proxy_socket_address:?}");
105        let FramedParts {
106            io: proxy_tcp_stream,
107            ..
108        } = handshake_request_framed.into_parts();
109        let mut handshake_response_framed =
110            Framed::new(proxy_tcp_stream, HandshakeResponseDecoder::new());
111        let HandshakeResponse {
112            encryption: proxy_encryption,
113        } = handshake_response_framed
114            .next()
115            .await
116            .ok_or(CommonError::ConnectionExhausted(proxy_socket_address))??;
117        debug!("Success to receive handshake response from proxy: {proxy_socket_address:?}");
118        let proxy_encryption =
119            rsa_decrypt_encryption(&proxy_encryption, user_info.rsa_crypto())?.into_owned();
120        let FramedParts {
121            io: proxy_tcp_stream,
122            ..
123        } = handshake_response_framed.into_parts();
124        let socket_address = proxy_tcp_stream.peer_addr()?;
125        let proxy_encryption = Arc::new(proxy_encryption);
126        let agent_encryption = Arc::new(agent_encryption);
127        Ok(FramedConnection {
128            state: ProxyTcpConnectionTunnelCtlState {
129                proxy_encryption: proxy_encryption.clone(),
130                agent_encryption: agent_encryption.clone(),
131                tunnel_ctl_response_request_framed: Framed::with_capacity(
132                    proxy_tcp_stream,
133                    TunnelControlResponseRequestCodec::new(proxy_encryption, agent_encryption),
134                    frame_buffer_size,
135                ),
136            },
137            socket_address,
138            frame_buffer_size,
139        })
140    }
141}
142
143impl FramedConnection<ProxyTcpConnectionTunnelCtlState> {
144    pub async fn tunnel_init(
145        mut self,
146        tunnel_init_request: TunnelInitRequest,
147    ) -> Result<
148        FramedConnection<
149            SinkWriter<StreamReader<CryptoLengthDelimitedFramed<TcpStream>, BytesMut>>,
150        >,
151        CommonError,
152    > {
153        let tunnel_ctl_request = TunnelControlRequest::TunnelInit(tunnel_init_request);
154        self.state
155            .tunnel_ctl_response_request_framed
156            .send(tunnel_ctl_request)
157            .await?;
158        let mut times_to_receive_heartbeat = 0;
159        loop {
160            let tunnel_ctl_response = self
161                .state
162                .tunnel_ctl_response_request_framed
163                .next()
164                .await
165                .ok_or(CommonError::ConnectionExhausted(self.socket_address))??;
166            match tunnel_ctl_response {
167                TunnelControlResponse::Heartbeat(heartbeat) => {
168                    debug!("Receive heartbeat response from proxy connection: {heartbeat:?}");
169                    times_to_receive_heartbeat += 1;
170                    if times_to_receive_heartbeat >= 3 {
171                        return Err(CommonError::Other(
172                            "Receive too many heartbeats when initialize tunnel.".to_string(),
173                        ));
174                    }
175                    continue;
176                }
177                TunnelControlResponse::TunnelInit(tunnel_init_response) => {
178                    return match tunnel_init_response {
179                        TunnelInitResponse::Success => {
180                            let FramedParts { io, .. } =
181                                self.state.tunnel_ctl_response_request_framed.into_parts();
182                            Ok(FramedConnection {
183                                socket_address: self.socket_address,
184                                frame_buffer_size: self.frame_buffer_size,
185                                state: SinkWriter::new(StreamReader::new(
186                                    CryptoLengthDelimitedFramed::new(
187                                        io,
188                                        self.state.proxy_encryption,
189                                        self.state.agent_encryption,
190                                        self.frame_buffer_size,
191                                    ),
192                                )),
193                            })
194                        }
195                        TunnelInitResponse::Failure(TunnelInitFailureReason::AuthenticateFail) => {
196                            Err(CommonError::Other(format!(
197                                "Tunnel init fail on authenticate: {tunnel_init_response:?}",
198                            )))
199                        }
200                        TunnelInitResponse::Failure(
201                            TunnelInitFailureReason::InitWithDestinationFail,
202                        ) => Err(CommonError::Other(format!(
203                            "Tunnel init fail on connect destination: {tunnel_init_response:?}",
204                        ))),
205                    };
206                }
207            }
208        }
209    }
210    pub async fn heartbeat(&mut self, timeout_seconds: u64) -> Result<i64, CommonError> {
211        let start_time = Utc::now();
212
213        let heartbeat_request = TunnelControlRequest::Heartbeat(HeartbeatRequest::new());
214        self.state
215            .tunnel_ctl_response_request_framed
216            .send(heartbeat_request)
217            .await?;
218        let tunnel_ctl_response = timeout(
219            Duration::from_secs(timeout_seconds),
220            self.state.tunnel_ctl_response_request_framed.next(),
221        )
222        .await?
223        .ok_or(CommonError::ConnectionExhausted(self.socket_address))??;
224        match tunnel_ctl_response {
225            TunnelControlResponse::Heartbeat(heartbeat_response) => {
226                let end_time = Utc::now();
227                let check_duration = end_time
228                    .signed_duration_since(start_time)
229                    .num_milliseconds();
230                debug!("Receive heartbeat response from proxy connection: {heartbeat_response:?}");
231                Ok(check_duration)
232            }
233            TunnelControlResponse::TunnelInit(_) => Err(CommonError::Other(format!(
234                "Receive tunnel init response from proxy connection: {}",
235                self.socket_address
236            ))),
237        }
238    }
239
240    pub async fn close(&mut self) -> Result<(), CommonError> {
241        self.state.tunnel_ctl_response_request_framed.close().await
242    }
243}