ppaass_v3_common/connection/
agent.rs1use crate::connection::codec::{
2 HandshakeRequestDecoder, HandshakeResponseEncoder, TunnelControlRequestResponseCodec,
3};
4use crate::connection::CryptoLengthDelimitedFramed;
5use crate::error::CommonError;
6use crate::user::repo::fs::USER_INFO_ADDITION_INFO_EXPIRED_DATE_TIME;
7use crate::user::UserInfoRepository;
8use crate::{
9 random_generate_encryption, rsa_decrypt_encryption, rsa_encrypt_encryption, FramedConnection,
10};
11use chrono::{DateTime, Utc};
12use futures_util::SinkExt;
13use futures_util::StreamExt;
14use ppaass_protocol::{
15 Encryption, HandshakeRequest, HandshakeResponse, HeartbeatResponse, TunnelControlRequest,
16 TunnelControlResponse, TunnelInitRequest, TunnelInitResponse,
17};
18use std::net::SocketAddr;
19use std::sync::Arc;
20use tokio::net::TcpStream;
21use tokio_util::bytes::BytesMut;
22use tokio_util::codec::{Framed, FramedParts};
23use tokio_util::io::{SinkWriter, StreamReader};
24use tracing::debug;
/// Type-state marker for an agent connection that has been accepted but has not yet
/// completed the handshake. `FramedConnection<AgentTcpConnectionNewState>::create`
/// performs the handshake and yields a connection in the tunnel-control state.
pub struct AgentTcpConnectionNewState {
    // Intentionally empty: this type only tags the connection's current state.
}
26pub struct AgentTcpConnectionTunnelCtlState {
27 tunnel_ctl_request_response_framed: Framed<TcpStream, TunnelControlRequestResponseCodec>,
28 proxy_encryption: Arc<Encryption>,
29 agent_encryption: Arc<Encryption>,
30}
31
32impl FramedConnection<AgentTcpConnectionNewState> {
33 pub async fn create<R>(
34 agent_tcp_stream: TcpStream,
35 agent_socket_address: SocketAddr,
36 user_info_repo: &R,
37 frame_buffer_size: usize,
38 ) -> Result<FramedConnection<AgentTcpConnectionTunnelCtlState>, CommonError>
39 where
40 R: UserInfoRepository + Sync + Send + 'static,
41 {
42 let mut handshake_request_framed =
43 Framed::new(agent_tcp_stream, HandshakeRequestDecoder::new());
44 let HandshakeRequest {
45 authentication,
46 encryption,
47 } = handshake_request_framed
48 .next()
49 .await
50 .ok_or(CommonError::ConnectionExhausted(agent_socket_address))??;
51 let user_info = user_info_repo
52 .get_user(&authentication)
53 .await?
54 .ok_or(CommonError::RsaCryptoNotFound(authentication.clone()))?;
55 let user_info = user_info.read().await;
56 let user_expired_time = user_info
57 .get_additional_info::<DateTime<Utc>>(USER_INFO_ADDITION_INFO_EXPIRED_DATE_TIME);
58 if let Some(user_expired_time) = user_expired_time {
59 if Utc::now() > *user_expired_time {
60 return Err(CommonError::UserExpired(authentication));
61 }
62 }
63 let agent_encryption =
64 rsa_decrypt_encryption(&encryption, user_info.rsa_crypto())?.into_owned();
65 let proxy_encryption = random_generate_encryption();
66 let encrypted_proxy_encryption =
67 rsa_encrypt_encryption(&proxy_encryption, user_info.rsa_crypto())?;
68 let handshake_response = HandshakeResponse {
69 encryption: encrypted_proxy_encryption.into_owned(),
70 };
71 let FramedParts {
72 io: agent_tcp_stream,
73 ..
74 } = handshake_request_framed.into_parts();
75 let mut handshake_response_framed =
76 Framed::new(agent_tcp_stream, HandshakeResponseEncoder::new());
77 handshake_response_framed.send(handshake_response).await?;
78 let FramedParts {
79 io: agent_tcp_stream,
80 ..
81 } = handshake_response_framed.into_parts();
82 let proxy_encryption = Arc::new(proxy_encryption);
83 let agent_encryption = Arc::new(agent_encryption);
84 Ok(FramedConnection {
85 socket_address: agent_socket_address,
86
87 frame_buffer_size,
88 state: AgentTcpConnectionTunnelCtlState {
89 proxy_encryption: proxy_encryption.clone(),
90 agent_encryption: agent_encryption.clone(),
91 tunnel_ctl_request_response_framed: Framed::with_capacity(
92 agent_tcp_stream,
93 TunnelControlRequestResponseCodec::new(agent_encryption, proxy_encryption),
94 frame_buffer_size,
95 ),
96 },
97 })
98 }
99}
100impl FramedConnection<AgentTcpConnectionTunnelCtlState> {
101 pub async fn wait_tunnel_init(&mut self) -> Result<TunnelInitRequest, CommonError> {
102 loop {
103 let tunnel_ctl_request = self
104 .state
105 .tunnel_ctl_request_response_framed
106 .next()
107 .await
108 .ok_or(CommonError::ConnectionExhausted(self.socket_address))??;
109 match tunnel_ctl_request {
110 TunnelControlRequest::Heartbeat(heartbeat_request) => {
111 debug!(
112 "Receive heartbeat request from agent connection [{}]: {heartbeat_request:?}",
113 self.socket_address
114 );
115 let heartbeat_response =
116 TunnelControlResponse::Heartbeat(HeartbeatResponse::new());
117 self.state
118 .tunnel_ctl_request_response_framed
119 .send(heartbeat_response)
120 .await?;
121 continue;
122 }
123 TunnelControlRequest::TunnelInit(tunnel_init_request) => {
124 return Ok(tunnel_init_request);
125 }
126 }
127 }
128 }
129
130 pub async fn response_tunnel_init(
131 mut self,
132 tunnel_init_response: TunnelInitResponse,
133 ) -> Result<
134 FramedConnection<
135 SinkWriter<StreamReader<CryptoLengthDelimitedFramed<TcpStream>, BytesMut>>,
136 >,
137 CommonError,
138 > {
139 let tunnel_ctl_response = TunnelControlResponse::TunnelInit(tunnel_init_response);
140 self.state
141 .tunnel_ctl_request_response_framed
142 .send(tunnel_ctl_response)
143 .await?;
144 let FramedParts { io, .. } = self.state.tunnel_ctl_request_response_framed.into_parts();
145 Ok(FramedConnection {
146 socket_address: self.socket_address,
147 state: SinkWriter::new(StreamReader::new(CryptoLengthDelimitedFramed::new(
148 io,
149 self.state.agent_encryption,
150 self.state.proxy_encryption,
151 self.frame_buffer_size,
152 ))),
153 frame_buffer_size: self.frame_buffer_size,
154 })
155 }
156}