// tfserver/server/tcp_server.rs

1use crate::server::server_router::TcpServerRouter;
2use crate::structures::s_type;
3use crate::structures::s_type::ServerErrorEn::InternalError;
4use crate::structures::s_type::{PacketMeta, ServerErrorEn};
5use std::fmt;
6use std::net::SocketAddr;
7use std::ops::Deref;
8use std::sync::Arc;
9
10use tokio::sync::{Mutex, Notify, RwLock};
11
12use crate::codec::codec_trait::TfCodec;
13use crate::server::handler::Handler;
14use crate::structures::traffic_proc::TrafficProcessorHolder;
15use crate::structures::transport::Transport;
16use futures_util::SinkExt;
17use tokio::io;
18use tokio::io::AsyncWriteExt;
19use tokio::net::{TcpListener, TcpStream};
20use tokio::sync::mpsc::{Receiver, Sender};
21use tokio::task::JoinHandle;
22use tokio_rustls::TlsAcceptor;
23use tokio_rustls::rustls::ServerConfig;
24use tokio_util::bytes::{Bytes, BytesMut};
25use tokio_util::codec::{Decoder, Encoder, Framed};
26
27///The request channel, used to move out tcp stream out of server control.
28///
29///When the stream is moved, the server does not owns it anymore.
30///
31///If is there need to return stream, only reconnect is available.
32pub type RequestChannel<C> = (
33    Sender<Arc<Mutex<dyn Handler<Codec = C>>>>,
34    Receiver<Arc<Mutex<dyn Handler<Codec = C>>>>,
35);
36
37///Base binary tcp server.
38///
39/// 'C' is you codec, that you want to use to encode/decode data.
40///
41///Recommended default codec is LengthDelimitedCodec, from the server codec module.
42
43pub struct TcpServer<C>
44where
45    C: Encoder<Bytes, Error = io::Error>
46        + Decoder<Item = BytesMut, Error = io::Error>
47        + Send
48        + Sync
49        + Clone
50        + 'static
51        + TfCodec,
52{
53    router: Arc<TcpServerRouter<C>>,
54    socket: Arc<TcpListener>,
55    shutdown_sig: Arc<Notify>,
56    processor: Option<TrafficProcessorHolder<C>>,
57    codec: C,
58    config: Option<ServerConfig>,
59}
60
61impl<C> TcpServer<C>
62where
63    C: Encoder<Bytes, Error = io::Error>
64        + Decoder<Item = BytesMut, Error = io::Error>
65        + Send
66        + Sync
67        + Clone
68        + 'static
69        + TfCodec,
70{
71    ///Creates a new instance of a server.
72    ///
73    /// 'bind_address' is a target address to bind current server. E.g: 0.0.0.0:8080
74    /// 'router' setted up router with handlers. Must be called commit_routes before using.
75    /// 'processor' Custom traffic processor, used for all streams.
76    /// 'codec' basically codec used for every stream with it's own instance, when the codec is applied to stream, first call is clone, the second call is initial_setup.
77    /// 'config' optional config for tls connection, when None the tls is not using, when some all connections are passed behind tls.
78    pub async fn new(
79        bind_address: String,
80        router: Arc<TcpServerRouter<C>>,
81        processor: Option<TrafficProcessorHolder<C>>,
82        codec: C,
83        config: Option<ServerConfig>,
84    ) -> Self {
85        Self {
86            router,
87            socket: Arc::new(
88                TcpListener::bind(&bind_address)
89                    .await
90                    .expect("Failed to bind to address"),
91            ),
92            shutdown_sig: Arc::new(Notify::new()),
93            processor,
94            codec,
95            config,
96        }
97    }
98
99    ///Start the task for handling connections.
100    ///
101    ///Return the join handle, of this task.
102    pub async fn start(&mut self) -> JoinHandle<()> {
103        let (listener, router, shutdown_sig) = {
104            (
105                self.socket.clone(),
106                self.router.clone(),
107                self.shutdown_sig.clone(),
108            )
109        };
110        let mut processor = if let Some(proc) = self.processor.take() {
111            proc
112        } else {
113            TrafficProcessorHolder::new()
114        };
115        let codec = self.codec.clone();
116        let config = self.config.clone();
117        tokio::spawn(async move {
118            loop {
119                tokio::select! {
120                            res = listener.accept() => {
121                            if res.is_ok() {
122                                let stream = res.unwrap();
123                                let res = stream.0.set_nodelay(true);
124                                if res.is_ok() {
125                                let codec = codec.clone();
126
127
128                                let transport = Self::initial_accept(stream.0, config.clone(), codec).await;
129                                if let Some(mut transport) = transport {
130                                 if processor.initial_connect(&mut transport.0).await {
131                                    let mut framed = Framed::new(transport.0, transport.1);
132
133                                if processor.initial_framed_connect(&mut framed).await {
134                                    let router = router.clone();
135                                    let prc_clone = processor.clone();
136
137                                    tokio::spawn(async move {
138                                        Self::handle_connection(
139                                            stream.1,
140                                            framed,
141                                            router.as_ref(),
142                                            prc_clone,
143                                        )
144                                        .await;
145                                    });
146                                }
147                                } else {
148                                    let _ = transport.0.shutdown().await;
149                                }
150                            }
151
152                            }
153
154                        }
155                    }
156                    _ = shutdown_sig.notified() => break,
157                }
158            }
159        })
160    }
161
162    ///Initial accept called for every connection, on connected event.
163    async fn initial_accept(
164        stream: TcpStream,
165        config: Option<ServerConfig>,
166        mut codec_setup: C,
167    ) -> Option<(Transport, C)> {
168        if config.is_none() {
169            let mut res = Transport::plain(stream);
170            if !codec_setup.initial_setup(&mut res).await {
171                return None;
172            }
173            return Some((res, codec_setup));
174        } else {
175            let cfg = config.unwrap();
176            let acceptor = TlsAcceptor::from(Arc::new(cfg));
177            let res = acceptor.accept(stream).await;
178            if res.is_err() {
179                return None;
180            }
181            let mut res = Transport::tls_server(res.unwrap());
182            if !codec_setup.initial_setup(&mut res).await {
183                return None;
184            }
185            return Some((res, codec_setup));
186        }
187    }
188    ///Stops the acceptor task.
189    pub fn send_stop(&self) {
190        self.shutdown_sig.notify_waiters();
191    }
192
193    ///Main function for every connection
194    async fn handle_connection(
195        addr: SocketAddr,
196        mut stream: Framed<Transport, C>,
197        router: &TcpServerRouter<C>,
198        mut processor: TrafficProcessorHolder<C>,
199    ) {
200        use futures_util::SinkExt;
201        let move_sig = tokio::sync::oneshot::channel::<Arc<RwLock<dyn Handler<Codec = C>>>>();
202        let mut move_sig = (Some(move_sig.0), move_sig.1);
203        loop {
204            let meta_data: Result<Option<BytesMut>, bool> =
205                Self::receive_message(addr.clone(), &mut stream, &mut processor).await;
206            if meta_data.is_err() {
207                if meta_data.unwrap_err() {
208                    stream.close().await.unwrap();
209                    return;
210                }
211                continue;
212            }
213
214            let meta_data = meta_data.unwrap();
215            if meta_data.is_none() {
216                continue;
217            }
218            let meta_data = meta_data.unwrap();
219            let has_payload = match s_type::from_slice::<PacketMeta>(meta_data.deref()) {
220                Ok(meta) => meta.has_payload,
221                Err(_) => false,
222            };
223
224            let mut payload: BytesMut = BytesMut::new();
225            if has_payload {
226                let payload_res =
227                    Self::receive_message(addr.clone(), &mut stream, &mut processor).await;
228                if payload_res.is_err() {
229                    if payload_res.unwrap_err() {
230                        stream.close().await.unwrap();
231                        return;
232                    }
233                    continue;
234                }
235                let payload_opt = payload_res.unwrap();
236                if payload_opt.is_none() {
237                    let _ = stream.close().await;
238                    return;
239                }
240                payload = payload_opt.unwrap();
241            }
242            let res = router
243                .serve_packet(meta_data, payload, (addr, &mut move_sig.0))
244                .await;
245
246            let message = res.unwrap_or_else(|err| s_type::to_vec(&err).unwrap());
247            let res = Self::send_message(&mut stream, message, &mut processor).await;
248
249            if let Ok(requester) = move_sig.1.try_recv() {
250                requester
251                    .write().await
252                    .accept_stream(addr, (stream, processor.clone()))
253                    .await;
254                return;
255            }
256
257            match res {
258                Err(_) => {
259                    let _ = stream.close();
260                    return;
261                }
262                _ => {}
263            }
264        }
265    }
266    async fn send_message(
267        stream: &mut Framed<Transport, C>,
268        message: Vec<u8>,
269        processor: &mut TrafficProcessorHolder<C>,
270    ) -> Result<(), io::Error> {
271        let message = Bytes::from(processor.post_process_traffic(message).await);
272        stream.send(message).await
273    }
274
275    async fn receive_message(
276        _: SocketAddr,
277        stream: &mut Framed<Transport, C>,
278        processor: &mut TrafficProcessorHolder<C>,
279    ) -> Result<Option<BytesMut>, bool> {
280        use futures_util::StreamExt;
281        match stream.next().await {
282            Some(data) => match data {
283                Ok(mut data) => {
284                    data = processor.pre_process_traffic(data).await;
285                    return Ok(Some(data));
286                }
287                Err(e) => {
288                    // This is where codec-level decoding errors happen
289                    match e.kind() {
290                        // IO errors usually mean the connection is broken
291                        std::io::ErrorKind::ConnectionReset
292                        | std::io::ErrorKind::ConnectionAborted
293                        | std::io::ErrorKind::BrokenPipe
294                        | std::io::ErrorKind::UnexpectedEof => {
295                            println!("Client disconnected");
296                            return Err(true);
297                        }
298
299                        // Frame too large (if you set max_frame_length)
300                        std::io::ErrorKind::InvalidData => {
301                            eprintln!("Frame exceeded maximum size: {e}");
302                            return Err(false);
303                        }
304
305                        // Other IO errors
306                        _ => {
307                            eprintln!("IO error while reading frame: {e}");
308                            return Err(false);
309                        }
310                    }
311                }
312            },
313            None => {
314                return Err(true);
315            }
316        }
317    }
318}
319
320// Custom Error Display
321impl fmt::Display for ServerErrorEn {
322    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
323        match self {
324            ServerErrorEn::MalformedMetaInfo(Some(msg)) => {
325                write!(f, "Malformed meta info: {}", msg)
326            }
327            ServerErrorEn::MalformedMetaInfo(None) => write!(f, "Malformed meta info!"),
328            ServerErrorEn::NoSuchHandler(Some(msg)) => write!(f, "No such handler: {}", msg),
329            ServerErrorEn::NoSuchHandler(None) => write!(f, "No such handler!"),
330            InternalError(Some(data)) => {
331                write!(
332                    f,
333                    "{}",
334                    String::from_utf8(data.clone())
335                        .unwrap_or_else(|_| "Internal server error!".to_owned())
336                )
337            }
338            InternalError(None) => {
339                write!(f, "Internal server error!")
340            }
341            ServerErrorEn::PayloadLost => {
342                write!(f, "Payload lost!")
343            }
344        }
345    }
346}
347
348impl std::error::Error for ServerErrorEn {}