Skip to main content

tfserver/server/
server.rs

1use crate::server::server_router::TfServerRouter;
2use crate::structures::s_type;
3use crate::structures::s_type::ServerErrorEn::InternalError;
4use crate::structures::s_type::{PacketMeta, ServerErrorEn};
5use std::fmt;
6use std::net::SocketAddr;
7use std::ops::Deref;
8use std::sync::Arc;
9
10use tokio::sync::{Mutex, Notify, RwLock};
11
12use crate::codec::codec_trait::TfCodec;
13use crate::server::handler::Handler;
14use crate::structures::traffic_proc::TrafficProcessorHolder;
15use crate::structures::transport::Transport;
16use futures_util::SinkExt;
17use tokio::io;
18use tokio::io::AsyncWriteExt;
19use tokio::net::{TcpListener, TcpStream};
20use tokio::sync::mpsc::{Receiver, Sender};
21use tokio::task::JoinHandle;
22use tokio_rustls::TlsAcceptor;
23use tokio_rustls::rustls::ServerConfig;
24use tokio_util::bytes::{Bytes, BytesMut};
25use tokio_util::codec::Framed;
26
/// The request channel, used to move a TCP stream out of the server's control.
///
/// The tuple holds the `Sender` and `Receiver` halves of an mpsc channel whose items
/// are shared, mutex-guarded handlers (`Arc<Mutex<dyn Handler<Codec = C>>>`).
///
/// Once the stream has been moved, the server no longer owns it.
///
/// If the stream has to come back under server control, the only option is to reconnect.
pub type RequestChannel<C> = (
    Sender<Arc<Mutex<dyn Handler<Codec = C>>>>,
    Receiver<Arc<Mutex<dyn Handler<Codec = C>>>>,
);
36
37
/// Wire protocol the server speaks on every accepted connection.
///
/// The mode is independent of TLS: both variants work over a plain or a TLS stream.
///
/// The enum carries no data, so it is `Copy`; it also derives `Debug` and equality so
/// it can be logged and compared.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum ServerMode {
    /// Plain TCP or TLS
    Tcp,
    /// WebSocket upgrade over plain TCP or TLS
    WebSocket,
}
45
/// Base binary TCP server.
///
/// `C` is your codec: the type used to encode and decode the data of every connection.
///
/// The recommended default codec is `LengthDelimitedCodec`, from the server codec module.
pub struct TfServer<C>
where
    C: TfCodec,
{
    // Router that serves every received packet; shared with the per-connection tasks.
    router: Arc<TfServerRouter<C>>,
    // Listening socket, bound in `new` and polled by the acceptor task.
    socket: Arc<TcpListener>,
    // Shutdown signal: `send_stop` notifies it, the acceptor loop waits on it.
    shutdown_sig: Arc<Notify>,
    // Optional custom traffic processor; `start` takes it out of this field.
    processor: Option<TrafficProcessorHolder<C>>,
    // Prototype codec; every connection gets its own clone.
    codec: C,
    // TLS configuration; `None` means connections are served without TLS.
    config: Option<ServerConfig>,
    // Whether connections are served as raw TCP or upgraded to WebSocket.
    mode: ServerMode,
}
64
65impl<C> TfServer<C>
66where
67    C: TfCodec,
68{
69    ///Creates a new instance of a server.
70    ///
71    /// 'bind_address' is a target address to bind current server. E.g: 0.0.0.0:8080
72    /// 'router' setted up router with handlers. Must be called commit_routes before using.
73    /// 'processor' Custom traffic processor, used for all streams.
74    /// 'codec' basically codec used for every stream with it's own instance, when the codec is applied to stream, first call is clone, the second call is initial_setup.
75    /// 'config' optional config for tls connection, when None the tls is not using, when some all connections are passed behind tls.
76    pub async fn new(
77        bind_address: String,
78        router: Arc<TfServerRouter<C>>,
79        processor: Option<TrafficProcessorHolder<C>>,
80        codec: C,
81        config: Option<ServerConfig>,
82        mode: ServerMode,
83    ) -> Self {
84        Self {
85            router,
86            socket: Arc::new(
87                TcpListener::bind(&bind_address)
88                    .await
89                    .expect("Failed to bind to address"),
90            ),
91            shutdown_sig: Arc::new(Notify::new()),
92            processor,
93            codec,
94            config,
95            mode
96        }
97    }
98
99    ///Start the task for handling connections.
100    ///
101    ///Return the join handle, of this task.
102    pub async fn start(&mut self) -> JoinHandle<()> {
103        let (listener, router, shutdown_sig) = (
104            self.socket.clone(),
105            self.router.clone(),
106            self.shutdown_sig.clone(),
107        );
108        let mut processor = if let Some(proc) = self.processor.take() {
109            proc
110        } else {
111            TrafficProcessorHolder::new()
112        };
113        let codec = self.codec.clone();
114        let config = self.config.clone();
115        let mode = self.mode.clone();   // ← new
116
117        tokio::spawn(async move {
118            loop {
119                tokio::select! {
120                res = listener.accept() => {
121                    if let Ok((stream, addr)) = res {
122                        let _ = stream.set_nodelay(true);
123                        let codec = codec.clone();
124                        let mode = mode.clone();    // ← new
125
126                        // ← swapped to new unified accept
127                        let transport = Self::initial_accept(stream, config.clone(), codec, &mode).await;
128
129                        if let Some(mut transport) = transport {
130                            if processor.initial_connect(&mut transport.0).await {
131                                let mut framed = Framed::new(transport.0, transport.1);
132                                if processor.initial_framed_connect(&mut framed).await {
133                                    let router = router.clone();
134                                    let prc_clone = processor.clone();
135                                    tokio::spawn(async move {
136                                        Self::handle_connection(addr, framed, router.as_ref(), prc_clone).await;
137                                    });
138                                }
139                            } else {
140                                let _ = transport.0.shutdown().await;
141                            }
142                        }
143                    }
144                }
145                _ = shutdown_sig.notified() => break,
146            }
147            }
148        })
149    }
150
151    ///Initial accept called for every connection, on connected event.
152    async fn initial_accept(
153        stream: TcpStream,
154        config: Option<ServerConfig>,
155        mut codec_setup: C,
156        mode: &ServerMode,
157    ) -> Option<(Transport, C)> {
158        let transport = match &config {
159            None => Transport::plain(stream),
160            Some(cfg) => {
161                let acceptor = TlsAcceptor::from(Arc::new(cfg.clone()));
162                match acceptor.accept(stream).await {
163                    Ok(tls) => Transport::tls_server(tls),
164                    Err(_) => return None,
165                }
166            }
167        };
168
169
170        let mut transport = match mode {
171            ServerMode::Tcp => transport,
172            ServerMode::WebSocket => {
173                match Transport::accept_websocket(transport).await {
174                    Ok(ws_stream) => ws_stream,
175                    Err(e) => {
176                        eprintln!("WebSocket handshake failed: {e}");
177                        return None;
178                    }
179                }
180            }
181        };
182
183        if !codec_setup.initial_setup(&mut transport).await {
184            return None;
185        }
186
187        Some((transport, codec_setup))
188    }
189    ///Stops the acceptor task.
190    pub fn send_stop(&self) {
191        self.shutdown_sig.notify_waiters();
192    }
193
194    ///Main function for every connection
195    async fn handle_connection(
196        addr: SocketAddr,
197        mut stream: Framed<Transport, C>,
198        router: &TfServerRouter<C>,
199        mut processor: TrafficProcessorHolder<C>,
200    ) {
201        use futures_util::SinkExt;
202        let move_sig = tokio::sync::oneshot::channel::<Arc<RwLock<dyn Handler<Codec = C>>>>();
203        let mut move_sig = (Some(move_sig.0), move_sig.1);
204        loop {
205            let meta_data: Result<Option<BytesMut>, bool> =
206                Self::receive_message(addr.clone(), &mut stream, &mut processor).await;
207            if meta_data.is_err() {
208                if meta_data.unwrap_err() {
209                    stream.close().await.unwrap();
210                    return;
211                }
212                continue;
213            }
214
215            let meta_data = meta_data.unwrap();
216            if meta_data.is_none() {
217                continue;
218            }
219            let meta_data = meta_data.unwrap();
220            let has_payload = match s_type::from_slice::<PacketMeta>(meta_data.deref()) {
221                Ok(meta) => meta.has_payload,
222                Err(_) => false,
223            };
224
225            let mut payload: BytesMut = BytesMut::new();
226            if has_payload {
227                let payload_res =
228                    Self::receive_message(addr.clone(), &mut stream, &mut processor).await;
229                if payload_res.is_err() {
230                    if payload_res.unwrap_err() {
231                        stream.close().await.unwrap();
232                        return;
233                    }
234                    continue;
235                }
236                let payload_opt = payload_res.unwrap();
237                if payload_opt.is_none() {
238                    let _ = stream.close().await;
239                    return;
240                }
241                payload = payload_opt.unwrap();
242            }
243            let res = router
244                .serve_packet(meta_data, payload, (addr, &mut move_sig.0))
245                .await;
246
247            let message = res.unwrap_or_else(|err| s_type::to_vec(&err).unwrap());
248            let res = Self::send_message(&mut stream, message, &mut processor).await;
249
250            if let Ok(requester) = move_sig.1.try_recv() {
251                requester
252                    .write()
253                    .await
254                    .accept_stream(addr, (stream, processor.clone()))
255                    .await;
256                return;
257            }
258
259            match res {
260                Err(_) => {
261                    let _ = stream.close();
262                    return;
263                }
264                _ => {}
265            }
266        }
267    }
268    async fn send_message(
269        stream: &mut Framed<Transport, C>,
270        message: Vec<u8>,
271        processor: &mut TrafficProcessorHolder<C>,
272    ) -> Result<(), io::Error> {
273        let message = Bytes::from(processor.post_process_traffic(message).await);
274        stream.send(message).await
275    }
276
277    async fn receive_message(
278        _: SocketAddr,
279        stream: &mut Framed<Transport, C>,
280        processor: &mut TrafficProcessorHolder<C>,
281    ) -> Result<Option<BytesMut>, bool> {
282        use futures_util::StreamExt;
283        match stream.next().await {
284            Some(data) => match data {
285                Ok(mut data) => {
286                    data = processor.pre_process_traffic(data).await;
287                    return Ok(Some(data));
288                }
289                Err(e) => {
290                    // This is where codec-level decoding errors happen
291                    match e.kind() {
292                        // IO errors usually mean the connection is broken
293                        std::io::ErrorKind::ConnectionReset
294                        | std::io::ErrorKind::ConnectionAborted
295                        | std::io::ErrorKind::BrokenPipe
296                        | std::io::ErrorKind::UnexpectedEof => {
297                            println!("Client disconnected");
298                            return Err(true);
299                        }
300
301                        // Frame too large (if you set max_frame_length)
302                        std::io::ErrorKind::InvalidData => {
303                            eprintln!("Frame exceeded maximum size: {e}");
304                            return Err(false);
305                        }
306
307                        // Other IO errors
308                        _ => {
309                            eprintln!("IO error while reading frame: {e}");
310                            return Err(false);
311                        }
312                    }
313                }
314            },
315            None => {
316                return Err(true);
317            }
318        }
319    }
320}
321
322// Custom Error Display
323impl fmt::Display for ServerErrorEn {
324    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
325        match self {
326            ServerErrorEn::MalformedMetaInfo(Some(msg)) => {
327                write!(f, "Malformed meta info: {}", msg)
328            }
329            ServerErrorEn::MalformedMetaInfo(None) => write!(f, "Malformed meta info!"),
330            ServerErrorEn::NoSuchHandler(Some(msg)) => write!(f, "No such handler: {}", msg),
331            ServerErrorEn::NoSuchHandler(None) => write!(f, "No such handler!"),
332            InternalError(Some(data)) => {
333                write!(
334                    f,
335                    "{}",
336                    String::from_utf8(data.clone())
337                        .unwrap_or_else(|_| "Internal server error!".to_owned())
338                )
339            }
340            InternalError(None) => {
341                write!(f, "Internal server error!")
342            }
343            ServerErrorEn::PayloadLost => {
344                write!(f, "Payload lost!")
345            }
346        }
347    }
348}
349
// Empty impl: `ServerErrorEn` uses the default `std::error::Error` methods.
impl std::error::Error for ServerErrorEn {}