quad_net/quad_socket/server.rs

1use std::net::ToSocketAddrs;
2use std::net::{TcpListener, TcpStream};
3use std::time::{Duration, Instant};
4
5use std::sync::{Arc, Mutex};
6
7use super::protocol::MessageReader;
8
9pub struct Settings<F, F1, F2, S>
10where
11    F: Fn(&mut SocketHandle, &mut S, Vec<u8>) + Send + 'static,
12    F1: Fn(&mut SocketHandle, &S) + Send + 'static,
13    F2: Fn(&S) + Send + 'static,
14    S: Default + Send,
15{
16    pub on_message: F,
17    pub on_timer: F1,
18    pub on_disconnect: F2,
19    pub timer: Option<Duration>,
20
21    pub _marker: std::marker::PhantomData<S>,
22}
23
24enum Sender<'a> {
25    WebSocket(&'a ws::Sender),
26    Tcp(&'a mut TcpStream),
27}
28
29pub struct SocketHandle<'a> {
30    sender: Sender<'a>,
31    disconnect: bool,
32}
33
34impl<'a> Sender<'a> {
35    fn send(&mut self, data: &[u8]) -> Option<()> {
36        use std::io::Write;
37
38        match self {
39            Sender::WebSocket(out) => {
40                out.send(data).ok()?;
41            }
42            Sender::Tcp(stream) => {
43                stream.write(&[data.len() as u8]).ok()?;
44                stream.write(data).ok()?;
45            }
46        }
47
48        Some(())
49    }
50}
51
52impl<'a> SocketHandle<'a> {
53    fn new(sender: Sender<'a>) -> SocketHandle<'a> {
54        SocketHandle {
55            sender,
56            disconnect: false,
57        }
58    }
59
60    pub fn send(&mut self, data: &[u8]) -> Result<(), ()> {
61        self.sender.send(data).ok_or(())
62    }
63
64    #[cfg(feature = "nanoserde")]
65    pub fn send_bin<T: nanoserde::SerBin>(&mut self, data: &T) -> Result<(), ()> {
66        self.send(&nanoserde::SerBin::serialize_bin(data))
67    }
68
69    pub fn disconnect(&mut self) {
70        self.disconnect = true;
71    }
72}
73
74pub fn listen<A, A1, F, F1, F2, S>(tcp_addr: A, ws_addr: A1, settings: Settings<F, F1, F2, S>)
75where
76    A: ToSocketAddrs + std::fmt::Debug + Send,
77    A1: ToSocketAddrs + std::fmt::Debug + Send + 'static,
78    F: Fn(&mut SocketHandle, &mut S, Vec<u8>) + Send + 'static,
79    F1: Fn(&mut SocketHandle, &S) + Send + 'static,
80    F2: Fn(&S) + Send + 'static,
81    S: Default + Send + 'static,
82{
83    let on_message = Arc::new(Mutex::new(settings.on_message));
84    let on_timer = Arc::new(Mutex::new(settings.on_timer));
85    let on_disconnect = Arc::new(Mutex::new(settings.on_disconnect));
86    let timer = settings.timer;
87
88    struct WsHandler<
89        S: Default,
90        F: Fn(&mut SocketHandle, &mut S, Vec<u8>) + Send + 'static,
91        F1: Fn(&mut SocketHandle, &S) + Send + 'static,
92        F2: Fn(&S) + Send + 'static,
93    > {
94        out: ws::Sender,
95        state: S,
96        on_message: Arc<Mutex<F>>,
97        on_timer: Arc<Mutex<F1>>,
98        on_disconnect: Arc<Mutex<F2>>,
99        timeout: Option<Duration>,
100    }
101
102    impl<
103            S: Default,
104            F: Fn(&mut SocketHandle, &mut S, Vec<u8>) + Send + 'static,
105            F1: Fn(&mut SocketHandle, &S) + Send + 'static,
106            F2: Fn(&S) + Send + 'static,
107        > ws::Handler for WsHandler<S, F, F1, F2>
108    {
109        fn on_message(&mut self, msg: ws::Message) -> ws::Result<()> {
110            let data = msg.into_data();
111            let mut handle = SocketHandle::new(Sender::WebSocket(&self.out));
112            (self.on_message.lock().unwrap())(&mut handle, &mut self.state, data);
113            if handle.disconnect {
114                self.out.close(ws::CloseCode::Normal)?;
115            }
116            Ok(())
117        }
118
119        fn on_open(&mut self, _: ws::Handshake) -> ws::Result<()> {
120            if let Some(timeout) = self.timeout {
121                self.out
122                    .timeout(timeout.as_millis() as _, ws::util::Token(1))?;
123            }
124            Ok(())
125        }
126
127        fn on_timeout(&mut self, _: ws::util::Token) -> ws::Result<()> {
128            if let Some(timeout) = self.timeout {
129                let mut handle = SocketHandle::new(Sender::WebSocket(&self.out));
130                (self.on_timer.lock().unwrap())(&mut handle, &self.state);
131                if handle.disconnect == false {
132                    self.out
133                        .timeout(timeout.as_millis() as _, ws::util::Token(1))?;
134                } else {
135                    self.out.close(ws::CloseCode::Normal)?;
136                }
137            }
138            Ok(())
139        }
140
141        fn on_close(&mut self, _code: ws::CloseCode, _reason: &str) {
142            (self.on_disconnect.lock().unwrap())(&self.state);
143        }
144    }
145
146    std::thread::spawn({
147        let on_message = on_message.clone();
148        let on_timer = on_timer.clone();
149        let on_disconnect = on_disconnect.clone();
150
151        move || {
152            ws::Builder::new()
153                .with_settings(ws::Settings {
154                    timer_tick_millis: 10,
155                    tcp_nodelay: true,
156                    ..ws::Settings::default()
157                })
158                .build(move |out| {
159                    let on_message = on_message.clone();
160                    let on_timer = on_timer.clone();
161                    let on_disconnect = on_disconnect.clone();
162
163                    WsHandler {
164                        out,
165                        state: S::default(),
166                        on_message,
167                        on_timer,
168                        on_disconnect,
169                        timeout: timer,
170                    }
171                })
172                .unwrap()
173                .listen(ws_addr)
174                .unwrap();
175        }
176    });
177
178    let listener = TcpListener::bind(tcp_addr).unwrap();
179    for stream in listener.incoming() {
180        let on_message = on_message.clone();
181        let on_timer = on_timer.clone();
182        let on_disconnect = on_disconnect.clone();
183
184        std::thread::spawn(move || {
185            let mut stream = stream.unwrap();
186            stream.set_nodelay(true).unwrap();
187            stream.set_nonblocking(true).unwrap();
188            let mut message_reader = MessageReader::new();
189            let mut state = S::default();
190
191            let mut time = Instant::now();
192            loop {
193                match message_reader.next(&mut stream) {
194                    Ok(Some(message)) => {
195                        let mut handle = SocketHandle::new(Sender::Tcp(&mut stream));
196                        (on_message.lock().unwrap())(&mut handle, &mut state, message);
197                        if handle.disconnect {
198                            (on_disconnect.lock().unwrap())(&state);
199                            return;
200                        }
201                    }
202                    Ok(None) => {}
203                    Err(_err) => {
204                        (on_disconnect.lock().unwrap())(&state);
205                        return;
206                    }
207                }
208
209                if let Some(timer) = timer {
210                    if time.elapsed() >= timer {
211                        time = Instant::now();
212                        let mut handle = SocketHandle::new(Sender::Tcp(&mut stream));
213
214                        (on_timer.lock().unwrap())(&mut handle, &state);
215                        if handle.disconnect {
216                            (on_disconnect.lock().unwrap())(&state);
217                            return;
218                        }
219                    }
220                }
221            }
222        });
223    }
224}