1use crate::server::server_router::TfServerRouter;
2use crate::structures::s_type;
3use crate::structures::s_type::ServerErrorEn::InternalError;
4use crate::structures::s_type::{PacketMeta, ServerErrorEn};
5use std::fmt;
6use std::net::SocketAddr;
7use std::ops::Deref;
8use std::sync::Arc;
9
10use tokio::sync::{Mutex, Notify, RwLock};
11
12use crate::codec::codec_trait::TfCodec;
13use crate::server::handler::Handler;
14use crate::structures::traffic_proc::TrafficProcessorHolder;
15use crate::structures::transport::Transport;
16use futures_util::SinkExt;
17use tokio::io;
18use tokio::io::AsyncWriteExt;
19use tokio::net::{TcpListener, TcpStream};
20use tokio::sync::mpsc::{Receiver, Sender};
21use tokio::task::JoinHandle;
22use tokio_rustls::TlsAcceptor;
23use tokio_rustls::rustls::ServerConfig;
24use tokio_util::bytes::{Bytes, BytesMut};
25use tokio_util::codec::Framed;
26
27pub type RequestChannel<C> = (
33 Sender<Arc<Mutex<dyn Handler<Codec = C>>>>,
34 Receiver<Arc<Mutex<dyn Handler<Codec = C>>>>,
35);
36
37
/// Transport flavour negotiated on top of the (optionally TLS-wrapped) TCP stream.
///
/// This is a fieldless enum, so it is `Copy` and can be compared and debug-printed
/// directly, for example in configuration logging or tests.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ServerMode {
    /// Frames are exchanged directly over the TCP (or TLS) stream.
    Tcp,
    /// A WebSocket handshake (`Transport::accept_websocket`) runs before any framing.
    WebSocket,
}
45
46pub struct TfServer<C>
53where
54 C: TfCodec,
55{
56 router: Arc<TfServerRouter<C>>,
57 socket: Arc<TcpListener>,
58 shutdown_sig: Arc<Notify>,
59 processor: Option<TrafficProcessorHolder<C>>,
60 codec: C,
61 config: Option<ServerConfig>,
62 mode: ServerMode,
63}
64
65impl<C> TfServer<C>
66where
67 C: TfCodec,
68{
69 pub async fn new(
77 bind_address: String,
78 router: Arc<TfServerRouter<C>>,
79 processor: Option<TrafficProcessorHolder<C>>,
80 codec: C,
81 config: Option<ServerConfig>,
82 mode: ServerMode,
83 ) -> Self {
84 Self {
85 router,
86 socket: Arc::new(
87 TcpListener::bind(&bind_address)
88 .await
89 .expect("Failed to bind to address"),
90 ),
91 shutdown_sig: Arc::new(Notify::new()),
92 processor,
93 codec,
94 config,
95 mode
96 }
97 }
98
99 pub async fn start(&mut self) -> JoinHandle<()> {
103 let (listener, router, shutdown_sig) = (
104 self.socket.clone(),
105 self.router.clone(),
106 self.shutdown_sig.clone(),
107 );
108 let mut processor = if let Some(proc) = self.processor.take() {
109 proc
110 } else {
111 TrafficProcessorHolder::new()
112 };
113 let codec = self.codec.clone();
114 let config = self.config.clone();
115 let mode = self.mode.clone(); tokio::spawn(async move {
118 loop {
119 tokio::select! {
120 res = listener.accept() => {
121 if let Ok((stream, addr)) = res {
122 let _ = stream.set_nodelay(true);
123 let codec = codec.clone();
124 let mode = mode.clone(); let transport = Self::initial_accept(stream, config.clone(), codec, &mode).await;
128
129 if let Some(mut transport) = transport {
130 if processor.initial_connect(&mut transport.0).await {
131 let mut framed = Framed::new(transport.0, transport.1);
132 if processor.initial_framed_connect(&mut framed).await {
133 let router = router.clone();
134 let prc_clone = processor.clone();
135 tokio::spawn(async move {
136 Self::handle_connection(addr, framed, router.as_ref(), prc_clone).await;
137 });
138 }
139 } else {
140 let _ = transport.0.shutdown().await;
141 }
142 }
143 }
144 }
145 _ = shutdown_sig.notified() => break,
146 }
147 }
148 })
149 }
150
151 async fn initial_accept(
153 stream: TcpStream,
154 config: Option<ServerConfig>,
155 mut codec_setup: C,
156 mode: &ServerMode,
157 ) -> Option<(Transport, C)> {
158 let transport = match &config {
159 None => Transport::plain(stream),
160 Some(cfg) => {
161 let acceptor = TlsAcceptor::from(Arc::new(cfg.clone()));
162 match acceptor.accept(stream).await {
163 Ok(tls) => Transport::tls_server(tls),
164 Err(_) => return None,
165 }
166 }
167 };
168
169
170 let mut transport = match mode {
171 ServerMode::Tcp => transport,
172 ServerMode::WebSocket => {
173 match Transport::accept_websocket(transport).await {
174 Ok(ws_stream) => ws_stream,
175 Err(e) => {
176 eprintln!("WebSocket handshake failed: {e}");
177 return None;
178 }
179 }
180 }
181 };
182
183 if !codec_setup.initial_setup(&mut transport).await {
184 return None;
185 }
186
187 Some((transport, codec_setup))
188 }
189 pub fn send_stop(&self) {
191 self.shutdown_sig.notify_waiters();
192 }
193
194 async fn handle_connection(
196 addr: SocketAddr,
197 mut stream: Framed<Transport, C>,
198 router: &TfServerRouter<C>,
199 mut processor: TrafficProcessorHolder<C>,
200 ) {
201 use futures_util::SinkExt;
202 let move_sig = tokio::sync::oneshot::channel::<Arc<RwLock<dyn Handler<Codec = C>>>>();
203 let mut move_sig = (Some(move_sig.0), move_sig.1);
204 loop {
205 let meta_data: Result<Option<BytesMut>, bool> =
206 Self::receive_message(addr.clone(), &mut stream, &mut processor).await;
207 if meta_data.is_err() {
208 if meta_data.unwrap_err() {
209 stream.close().await.unwrap();
210 return;
211 }
212 continue;
213 }
214
215 let meta_data = meta_data.unwrap();
216 if meta_data.is_none() {
217 continue;
218 }
219 let meta_data = meta_data.unwrap();
220 let has_payload = match s_type::from_slice::<PacketMeta>(meta_data.deref()) {
221 Ok(meta) => meta.has_payload,
222 Err(_) => false,
223 };
224
225 let mut payload: BytesMut = BytesMut::new();
226 if has_payload {
227 let payload_res =
228 Self::receive_message(addr.clone(), &mut stream, &mut processor).await;
229 if payload_res.is_err() {
230 if payload_res.unwrap_err() {
231 stream.close().await.unwrap();
232 return;
233 }
234 continue;
235 }
236 let payload_opt = payload_res.unwrap();
237 if payload_opt.is_none() {
238 let _ = stream.close().await;
239 return;
240 }
241 payload = payload_opt.unwrap();
242 }
243 let res = router
244 .serve_packet(meta_data, payload, (addr, &mut move_sig.0))
245 .await;
246
247 let message = res.unwrap_or_else(|err| s_type::to_vec(&err).unwrap());
248 let res = Self::send_message(&mut stream, message, &mut processor).await;
249
250 if let Ok(requester) = move_sig.1.try_recv() {
251 requester
252 .write()
253 .await
254 .accept_stream(addr, (stream, processor.clone()))
255 .await;
256 return;
257 }
258
259 match res {
260 Err(_) => {
261 let _ = stream.close();
262 return;
263 }
264 _ => {}
265 }
266 }
267 }
268 async fn send_message(
269 stream: &mut Framed<Transport, C>,
270 message: Vec<u8>,
271 processor: &mut TrafficProcessorHolder<C>,
272 ) -> Result<(), io::Error> {
273 let message = Bytes::from(processor.post_process_traffic(message).await);
274 stream.send(message).await
275 }
276
277 async fn receive_message(
278 _: SocketAddr,
279 stream: &mut Framed<Transport, C>,
280 processor: &mut TrafficProcessorHolder<C>,
281 ) -> Result<Option<BytesMut>, bool> {
282 use futures_util::StreamExt;
283 match stream.next().await {
284 Some(data) => match data {
285 Ok(mut data) => {
286 data = processor.pre_process_traffic(data).await;
287 return Ok(Some(data));
288 }
289 Err(e) => {
290 match e.kind() {
292 std::io::ErrorKind::ConnectionReset
294 | std::io::ErrorKind::ConnectionAborted
295 | std::io::ErrorKind::BrokenPipe
296 | std::io::ErrorKind::UnexpectedEof => {
297 println!("Client disconnected");
298 return Err(true);
299 }
300
301 std::io::ErrorKind::InvalidData => {
303 eprintln!("Frame exceeded maximum size: {e}");
304 return Err(false);
305 }
306
307 _ => {
309 eprintln!("IO error while reading frame: {e}");
310 return Err(false);
311 }
312 }
313 }
314 },
315 None => {
316 return Err(true);
317 }
318 }
319 }
320}
321
322impl fmt::Display for ServerErrorEn {
324 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
325 match self {
326 ServerErrorEn::MalformedMetaInfo(Some(msg)) => {
327 write!(f, "Malformed meta info: {}", msg)
328 }
329 ServerErrorEn::MalformedMetaInfo(None) => write!(f, "Malformed meta info!"),
330 ServerErrorEn::NoSuchHandler(Some(msg)) => write!(f, "No such handler: {}", msg),
331 ServerErrorEn::NoSuchHandler(None) => write!(f, "No such handler!"),
332 InternalError(Some(data)) => {
333 write!(
334 f,
335 "{}",
336 String::from_utf8(data.clone())
337 .unwrap_or_else(|_| "Internal server error!".to_owned())
338 )
339 }
340 InternalError(None) => {
341 write!(f, "Internal server error!")
342 }
343 ServerErrorEn::PayloadLost => {
344 write!(f, "Payload lost!")
345 }
346 }
347 }
348}
349
350impl std::error::Error for ServerErrorEn {}