1use crate::server::server_router::TcpServerRouter;
2use crate::structures::s_type;
3use crate::structures::s_type::ServerErrorEn::InternalError;
4use crate::structures::s_type::{PacketMeta, ServerErrorEn};
5use std::fmt;
6use std::net::SocketAddr;
7use std::ops::Deref;
8use std::sync::Arc;
9
10use tokio::sync::{Mutex, Notify, RwLock};
11
12use crate::codec::codec_trait::TfCodec;
13use crate::server::handler::Handler;
14use crate::structures::traffic_proc::TrafficProcessorHolder;
15use crate::structures::transport::Transport;
16use futures_util::SinkExt;
17use tokio::io;
18use tokio::io::AsyncWriteExt;
19use tokio::net::{TcpListener, TcpStream};
20use tokio::sync::mpsc::{Receiver, Sender};
21use tokio::task::JoinHandle;
22use tokio_rustls::TlsAcceptor;
23use tokio_rustls::rustls::ServerConfig;
24use tokio_util::bytes::{Bytes, BytesMut};
25use tokio_util::codec::{Decoder, Encoder, Framed};
26
27pub type RequestChannel<C> = (
33 Sender<Arc<Mutex<dyn Handler<Codec = C>>>>,
34 Receiver<Arc<Mutex<dyn Handler<Codec = C>>>>,
35);
36
37pub struct TcpServer<C>
44where
45 C: Encoder<Bytes, Error = io::Error>
46 + Decoder<Item = BytesMut, Error = io::Error>
47 + Send
48 + Sync
49 + Clone
50 + 'static
51 + TfCodec,
52{
53 router: Arc<TcpServerRouter<C>>,
54 socket: Arc<TcpListener>,
55 shutdown_sig: Arc<Notify>,
56 processor: Option<TrafficProcessorHolder<C>>,
57 codec: C,
58 config: Option<ServerConfig>,
59}
60
61impl<C> TcpServer<C>
62where
63 C: Encoder<Bytes, Error = io::Error>
64 + Decoder<Item = BytesMut, Error = io::Error>
65 + Send
66 + Sync
67 + Clone
68 + 'static
69 + TfCodec,
70{
71 pub async fn new(
79 bind_address: String,
80 router: Arc<TcpServerRouter<C>>,
81 processor: Option<TrafficProcessorHolder<C>>,
82 codec: C,
83 config: Option<ServerConfig>,
84 ) -> Self {
85 Self {
86 router,
87 socket: Arc::new(
88 TcpListener::bind(&bind_address)
89 .await
90 .expect("Failed to bind to address"),
91 ),
92 shutdown_sig: Arc::new(Notify::new()),
93 processor,
94 codec,
95 config,
96 }
97 }
98
99 pub async fn start(&mut self) -> JoinHandle<()> {
103 let (listener, router, shutdown_sig) = {
104 (
105 self.socket.clone(),
106 self.router.clone(),
107 self.shutdown_sig.clone(),
108 )
109 };
110 let mut processor = if let Some(proc) = self.processor.take() {
111 proc
112 } else {
113 TrafficProcessorHolder::new()
114 };
115 let codec = self.codec.clone();
116 let config = self.config.clone();
117 tokio::spawn(async move {
118 loop {
119 tokio::select! {
120 res = listener.accept() => {
121 if res.is_ok() {
122 let stream = res.unwrap();
123 let res = stream.0.set_nodelay(true);
124 if res.is_ok() {
125 let codec = codec.clone();
126
127
128 let transport = Self::initial_accept(stream.0, config.clone(), codec).await;
129 if let Some(mut transport) = transport {
130 if processor.initial_connect(&mut transport.0).await {
131 let mut framed = Framed::new(transport.0, transport.1);
132
133 if processor.initial_framed_connect(&mut framed).await {
134 let router = router.clone();
135 let prc_clone = processor.clone();
136
137 tokio::spawn(async move {
138 Self::handle_connection(
139 stream.1,
140 framed,
141 router.as_ref(),
142 prc_clone,
143 )
144 .await;
145 });
146 }
147 } else {
148 let _ = transport.0.shutdown().await;
149 }
150 }
151
152 }
153
154 }
155 }
156 _ = shutdown_sig.notified() => break,
157 }
158 }
159 })
160 }
161
162 async fn initial_accept(
164 stream: TcpStream,
165 config: Option<ServerConfig>,
166 mut codec_setup: C,
167 ) -> Option<(Transport, C)> {
168 if config.is_none() {
169 let mut res = Transport::plain(stream);
170 if !codec_setup.initial_setup(&mut res).await {
171 return None;
172 }
173 return Some((res, codec_setup));
174 } else {
175 let cfg = config.unwrap();
176 let acceptor = TlsAcceptor::from(Arc::new(cfg));
177 let res = acceptor.accept(stream).await;
178 if res.is_err() {
179 return None;
180 }
181 let mut res = Transport::tls_server(res.unwrap());
182 if !codec_setup.initial_setup(&mut res).await {
183 return None;
184 }
185 return Some((res, codec_setup));
186 }
187 }
188 pub fn send_stop(&self) {
190 self.shutdown_sig.notify_waiters();
191 }
192
193 async fn handle_connection(
195 addr: SocketAddr,
196 mut stream: Framed<Transport, C>,
197 router: &TcpServerRouter<C>,
198 mut processor: TrafficProcessorHolder<C>,
199 ) {
200 use futures_util::SinkExt;
201 let move_sig = tokio::sync::oneshot::channel::<Arc<RwLock<dyn Handler<Codec = C>>>>();
202 let mut move_sig = (Some(move_sig.0), move_sig.1);
203 loop {
204 let meta_data: Result<Option<BytesMut>, bool> =
205 Self::receive_message(addr.clone(), &mut stream, &mut processor).await;
206 if meta_data.is_err() {
207 if meta_data.unwrap_err() {
208 stream.close().await.unwrap();
209 return;
210 }
211 continue;
212 }
213
214 let meta_data = meta_data.unwrap();
215 if meta_data.is_none() {
216 continue;
217 }
218 let meta_data = meta_data.unwrap();
219 let has_payload = match s_type::from_slice::<PacketMeta>(meta_data.deref()) {
220 Ok(meta) => meta.has_payload,
221 Err(_) => false,
222 };
223
224 let mut payload: BytesMut = BytesMut::new();
225 if has_payload {
226 let payload_res =
227 Self::receive_message(addr.clone(), &mut stream, &mut processor).await;
228 if payload_res.is_err() {
229 if payload_res.unwrap_err() {
230 stream.close().await.unwrap();
231 return;
232 }
233 continue;
234 }
235 let payload_opt = payload_res.unwrap();
236 if payload_opt.is_none() {
237 let _ = stream.close().await;
238 return;
239 }
240 payload = payload_opt.unwrap();
241 }
242 let res = router
243 .serve_packet(meta_data, payload, (addr, &mut move_sig.0))
244 .await;
245
246 let message = res.unwrap_or_else(|err| s_type::to_vec(&err).unwrap());
247 let res = Self::send_message(&mut stream, message, &mut processor).await;
248
249 if let Ok(requester) = move_sig.1.try_recv() {
250 requester
251 .write().await
252 .accept_stream(addr, (stream, processor.clone()))
253 .await;
254 return;
255 }
256
257 match res {
258 Err(_) => {
259 let _ = stream.close();
260 return;
261 }
262 _ => {}
263 }
264 }
265 }
266 async fn send_message(
267 stream: &mut Framed<Transport, C>,
268 message: Vec<u8>,
269 processor: &mut TrafficProcessorHolder<C>,
270 ) -> Result<(), io::Error> {
271 let message = Bytes::from(processor.post_process_traffic(message).await);
272 stream.send(message).await
273 }
274
275 async fn receive_message(
276 _: SocketAddr,
277 stream: &mut Framed<Transport, C>,
278 processor: &mut TrafficProcessorHolder<C>,
279 ) -> Result<Option<BytesMut>, bool> {
280 use futures_util::StreamExt;
281 match stream.next().await {
282 Some(data) => match data {
283 Ok(mut data) => {
284 data = processor.pre_process_traffic(data).await;
285 return Ok(Some(data));
286 }
287 Err(e) => {
288 match e.kind() {
290 std::io::ErrorKind::ConnectionReset
292 | std::io::ErrorKind::ConnectionAborted
293 | std::io::ErrorKind::BrokenPipe
294 | std::io::ErrorKind::UnexpectedEof => {
295 println!("Client disconnected");
296 return Err(true);
297 }
298
299 std::io::ErrorKind::InvalidData => {
301 eprintln!("Frame exceeded maximum size: {e}");
302 return Err(false);
303 }
304
305 _ => {
307 eprintln!("IO error while reading frame: {e}");
308 return Err(false);
309 }
310 }
311 }
312 },
313 None => {
314 return Err(true);
315 }
316 }
317 }
318}
319
320impl fmt::Display for ServerErrorEn {
322 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
323 match self {
324 ServerErrorEn::MalformedMetaInfo(Some(msg)) => {
325 write!(f, "Malformed meta info: {}", msg)
326 }
327 ServerErrorEn::MalformedMetaInfo(None) => write!(f, "Malformed meta info!"),
328 ServerErrorEn::NoSuchHandler(Some(msg)) => write!(f, "No such handler: {}", msg),
329 ServerErrorEn::NoSuchHandler(None) => write!(f, "No such handler!"),
330 InternalError(Some(data)) => {
331 write!(
332 f,
333 "{}",
334 String::from_utf8(data.clone())
335 .unwrap_or_else(|_| "Internal server error!".to_owned())
336 )
337 }
338 InternalError(None) => {
339 write!(f, "Internal server error!")
340 }
341 ServerErrorEn::PayloadLost => {
342 write!(f, "Payload lost!")
343 }
344 }
345 }
346}
347
348impl std::error::Error for ServerErrorEn {}