ppaass_v3_common/connection/agent.rs

1use crate::connection::codec::{
2    HandshakeRequestDecoder, HandshakeResponseEncoder, TunnelControlRequestResponseCodec,
3};
4use crate::connection::CryptoLengthDelimitedFramed;
5use crate::error::CommonError;
6use crate::user::repo::fs::USER_INFO_ADDITION_INFO_EXPIRED_DATE_TIME;
7use crate::user::UserInfoRepository;
8use crate::{
9    random_generate_encryption, rsa_decrypt_encryption, rsa_encrypt_encryption, FramedConnection,
10};
11use chrono::{DateTime, Utc};
12use futures_util::SinkExt;
13use futures_util::StreamExt;
14use ppaass_protocol::{
15    Encryption, HandshakeRequest, HandshakeResponse, HeartbeatResponse, TunnelControlRequest,
16    TunnelControlResponse, TunnelInitRequest, TunnelInitResponse,
17};
18use std::net::SocketAddr;
19use std::sync::Arc;
20use tokio::net::TcpStream;
21use tokio_util::bytes::BytesMut;
22use tokio_util::codec::{Framed, FramedParts};
23use tokio_util::io::{SinkWriter, StreamReader};
24use tracing::debug;
/// Marker state for an agent connection that has not completed the handshake yet.
///
/// `FramedConnection::<AgentTcpConnectionNewState>::create` performs the handshake and
/// yields a connection in the `AgentTcpConnectionTunnelCtlState` state. The marker carries
/// no data, so it is cheap to copy, default-construct and print.
#[derive(Debug, Default, Clone, Copy)]
pub struct AgentTcpConnectionNewState {}
/// State of an agent connection after a successful handshake, while tunnel-control
/// frames (heartbeats and the tunnel-init request/response) are exchanged.
pub struct AgentTcpConnectionTunnelCtlState {
    /// The agent's TCP stream framed with the tunnel-control request/response codec.
    tunnel_ctl_request_response_framed: Framed<TcpStream, TunnelControlRequestResponseCodec>,
    /// Encryption randomly generated by the proxy during the handshake and sent to the
    /// agent (RSA-encrypted) in the handshake response.
    proxy_encryption: Arc<Encryption>,
    /// Encryption supplied by the agent in its handshake request, after RSA decryption
    /// with the user's key.
    agent_encryption: Arc<Encryption>,
}
31
32impl FramedConnection<AgentTcpConnectionNewState> {
33    pub async fn create<R>(
34        agent_tcp_stream: TcpStream,
35        agent_socket_address: SocketAddr,
36        user_info_repo: &R,
37        frame_buffer_size: usize,
38    ) -> Result<FramedConnection<AgentTcpConnectionTunnelCtlState>, CommonError>
39    where
40        R: UserInfoRepository + Sync + Send + 'static,
41    {
42        let mut handshake_request_framed =
43            Framed::new(agent_tcp_stream, HandshakeRequestDecoder::new());
44        let HandshakeRequest {
45            authentication,
46            encryption,
47        } = handshake_request_framed
48            .next()
49            .await
50            .ok_or(CommonError::ConnectionExhausted(agent_socket_address))??;
51        let user_info = user_info_repo
52            .get_user(&authentication)
53            .await?
54            .ok_or(CommonError::RsaCryptoNotFound(authentication.clone()))?;
55        let user_info = user_info.read().await;
56        let user_expired_time = user_info
57            .get_additional_info::<DateTime<Utc>>(USER_INFO_ADDITION_INFO_EXPIRED_DATE_TIME);
58        if let Some(user_expired_time) = user_expired_time {
59            if Utc::now() > *user_expired_time {
60                return Err(CommonError::UserExpired(authentication));
61            }
62        }
63        let agent_encryption =
64            rsa_decrypt_encryption(&encryption, user_info.rsa_crypto())?.into_owned();
65        let proxy_encryption = random_generate_encryption();
66        let encrypted_proxy_encryption =
67            rsa_encrypt_encryption(&proxy_encryption, user_info.rsa_crypto())?;
68        let handshake_response = HandshakeResponse {
69            encryption: encrypted_proxy_encryption.into_owned(),
70        };
71        let FramedParts {
72            io: agent_tcp_stream,
73            ..
74        } = handshake_request_framed.into_parts();
75        let mut handshake_response_framed =
76            Framed::new(agent_tcp_stream, HandshakeResponseEncoder::new());
77        handshake_response_framed.send(handshake_response).await?;
78        let FramedParts {
79            io: agent_tcp_stream,
80            ..
81        } = handshake_response_framed.into_parts();
82        let proxy_encryption = Arc::new(proxy_encryption);
83        let agent_encryption = Arc::new(agent_encryption);
84        Ok(FramedConnection {
85            socket_address: agent_socket_address,
86
87            frame_buffer_size,
88            state: AgentTcpConnectionTunnelCtlState {
89                proxy_encryption: proxy_encryption.clone(),
90                agent_encryption: agent_encryption.clone(),
91                tunnel_ctl_request_response_framed: Framed::with_capacity(
92                    agent_tcp_stream,
93                    TunnelControlRequestResponseCodec::new(agent_encryption, proxy_encryption),
94                    frame_buffer_size,
95                ),
96            },
97        })
98    }
99}
100impl FramedConnection<AgentTcpConnectionTunnelCtlState> {
101    pub async fn wait_tunnel_init(&mut self) -> Result<TunnelInitRequest, CommonError> {
102        loop {
103            let tunnel_ctl_request = self
104                .state
105                .tunnel_ctl_request_response_framed
106                .next()
107                .await
108                .ok_or(CommonError::ConnectionExhausted(self.socket_address))??;
109            match tunnel_ctl_request {
110                TunnelControlRequest::Heartbeat(heartbeat_request) => {
111                    debug!(
112                        "Receive heartbeat request from agent connection [{}]: {heartbeat_request:?}",
113                        self.socket_address
114                    );
115                    let heartbeat_response =
116                        TunnelControlResponse::Heartbeat(HeartbeatResponse::new());
117                    self.state
118                        .tunnel_ctl_request_response_framed
119                        .send(heartbeat_response)
120                        .await?;
121                    continue;
122                }
123                TunnelControlRequest::TunnelInit(tunnel_init_request) => {
124                    return Ok(tunnel_init_request);
125                }
126            }
127        }
128    }
129
130    pub async fn response_tunnel_init(
131        mut self,
132        tunnel_init_response: TunnelInitResponse,
133    ) -> Result<
134        FramedConnection<
135            SinkWriter<StreamReader<CryptoLengthDelimitedFramed<TcpStream>, BytesMut>>,
136        >,
137        CommonError,
138    > {
139        let tunnel_ctl_response = TunnelControlResponse::TunnelInit(tunnel_init_response);
140        self.state
141            .tunnel_ctl_request_response_framed
142            .send(tunnel_ctl_response)
143            .await?;
144        let FramedParts { io, .. } = self.state.tunnel_ctl_request_response_framed.into_parts();
145        Ok(FramedConnection {
146            socket_address: self.socket_address,
147            state: SinkWriter::new(StreamReader::new(CryptoLengthDelimitedFramed::new(
148                io,
149                self.state.agent_encryption,
150                self.state.proxy_encryption,
151                self.frame_buffer_size,
152            ))),
153            frame_buffer_size: self.frame_buffer_size,
154        })
155    }
156}