ppaass_v3_proxy_core/tunnel/
mod.rs

1use crate::config::ProxyConfig;
2use crate::tunnel::destination::DestinationEdge;
3use ppaass_common::error::CommonError;
4use ppaass_common::server::ServerState;
5use ppaass_common::user::repo::fs::FileSystemUserInfoRepository;
6use ppaass_common::{
7    AgentTcpConnectionNewState, AgentTcpConnectionTunnelCtlState, FramedConnection,
8    TunnelInitFailureReason, TunnelInitRequest, TunnelInitResponse,
9};
10use std::net::SocketAddr;
11use std::sync::Arc;
12use tokio::{
13    io::{copy_bidirectional, copy_bidirectional_with_sizes},
14    net::TcpStream,
15};
16use tokio_util::io::{SinkWriter, StreamReader};
17use tracing::debug;
18mod destination;
19
/// State for a single agent connection that is being tunnelled through the proxy.
///
/// A value is built by [`Tunnel::new`] and consumed by [`Tunnel::run`].
pub struct Tunnel {
    /// Shared proxy configuration (forward settings, timeouts and buffer sizes are read from it).
    config: Arc<ProxyConfig>,
    /// Framed connection to the agent, in the tunnel-control state.
    agent_tcp_connection: FramedConnection<AgentTcpConnectionTunnelCtlState>,
    /// Socket address of the agent, as supplied by the caller; used when creating the
    /// framed connection and in log messages.
    agent_socket_address: SocketAddr,
    /// Shared server state; the user repository is looked up in it and it is passed on
    /// when the tunnel is forwarded to another proxy.
    server_state: Arc<ServerState>,
}
26
27impl Tunnel {
28    pub async fn new(
29        config: Arc<ProxyConfig>,
30        server_state: Arc<ServerState>,
31        agent_tcp_stream: TcpStream,
32        agent_socket_address: SocketAddr,
33    ) -> Result<Self, CommonError> {
34        let user_repo = server_state
35            .get_value::<Arc<FileSystemUserInfoRepository>>()
36            .ok_or(CommonError::Other(format!(
37                "Fail to get user crypto repository for agent: {agent_socket_address}"
38            )))?;
39        let agent_tcp_connection = FramedConnection::<AgentTcpConnectionNewState>::create(
40            agent_tcp_stream,
41            agent_socket_address,
42            user_repo.as_ref(),
43            config.agent_frame_buffer_size(),
44        )
45        .await?;
46        Ok(Self {
47            config,
48            server_state,
49            agent_tcp_connection,
50            agent_socket_address,
51        })
52    }
53
54    async fn initialize_tunnel(
55        tunnel_init_request: TunnelInitRequest,
56        agent_socket_address: SocketAddr,
57        config: &ProxyConfig,
58        server_state: &ServerState,
59    ) -> Result<DestinationEdge, CommonError> {
60        let TunnelInitRequest {
61            destination_address,
62            keep_alive,
63        } = tunnel_init_request;
64        match config.forward() {
65            None => {
66                debug!(
67                    "[START TCP] Begin to initialize tunnel for agent: {agent_socket_address:?}"
68                );
69                let destination_edge = DestinationEdge::start_direct(
70                    destination_address,
71                    keep_alive,
72                    config.destination_connect_timeout(),
73                )
74                .await?;
75                Ok(destination_edge)
76            }
77            Some(forward_config) => {
78                debug!(
79                    "[START FORWARD] Begin to initialize tunnel for agent: {agent_socket_address:?}"
80                );
81                let destination_edge = DestinationEdge::start_forward(
82                    server_state,
83                    forward_config,
84                    destination_address,
85                )
86                .await?;
87                Ok(destination_edge)
88            }
89        }
90    }
91
92    pub async fn run(mut self) -> Result<(), CommonError> {
93        let tunnel_init_request = self.agent_tcp_connection.wait_tunnel_init().await?;
94        match Self::initialize_tunnel(
95            tunnel_init_request,
96            self.agent_socket_address,
97            self.config.as_ref(),
98            self.server_state.as_ref(),
99        )
100        .await
101        {
102            Err(e) => {
103                self.agent_tcp_connection
104                    .response_tunnel_init(TunnelInitResponse::Failure(
105                        TunnelInitFailureReason::InitWithDestinationFail,
106                    ))
107                    .await?;
108                Err(e)
109            }
110            Ok(destination_edge) => match destination_edge {
111                DestinationEdge::Direct(destination_tcp_endpoint) => {
112                    let mut agent_tcp_connection = self
113                        .agent_tcp_connection
114                        .response_tunnel_init(TunnelInitResponse::Success)
115                        .await?;
116
117                    let destination_tcp_endpoint = StreamReader::new(destination_tcp_endpoint);
118                    let mut destination_tcp_connection = SinkWriter::new(destination_tcp_endpoint);
119
120                    let (agent_data_size, destination_data_size) = copy_bidirectional_with_sizes(
121                        &mut agent_tcp_connection,
122                        &mut destination_tcp_connection,
123                        self.config.proxy_to_destination_data_relay_buffer_size(),
124                        self.config.destination_to_proxy_data_relay_buffer_size(),
125                    )
126                    .await?;
127                    debug!(
128                        "[PROXYING] Copy data between agent and destination, agent data size: {agent_data_size}, destination data size: {destination_data_size}"
129                    );
130                    Ok(())
131                }
132                DestinationEdge::Forward(mut forward_proxy_tcp_connection) => {
133                    let mut agent_tcp_connection = self
134                        .agent_tcp_connection
135                        .response_tunnel_init(TunnelInitResponse::Success)
136                        .await?;
137
138                    let (agent_data_size, proxy_data_size) = copy_bidirectional(
139                        &mut agent_tcp_connection,
140                        &mut forward_proxy_tcp_connection,
141                    )
142                    .await?;
143                    debug!(
144                        "[FORWARDING] Copy data between agent and proxy, agent data size: {agent_data_size}, proxy data size: {proxy_data_size}"
145                    );
146                    Ok(())
147                }
148            },
149        }
150    }
151}
152
153pub async fn handle_agent_connection(
154    config: Arc<ProxyConfig>,
155    server_state: Arc<ServerState>,
156    agent_tcp_stream: TcpStream,
157    agent_socket_address: SocketAddr,
158) -> Result<(), CommonError> {
159    let tunnel = Tunnel::new(config, server_state, agent_tcp_stream, agent_socket_address).await?;
160    tunnel.run().await
161}