1use std::net::ToSocketAddrs;
2use std::net::{TcpListener, TcpStream};
3use std::time::{Duration, Instant};
4
5use std::sync::{Arc, Mutex};
6
7use super::protocol::MessageReader;
8
9pub struct Settings<F, F1, F2, S>
10where
11 F: Fn(&mut SocketHandle, &mut S, Vec<u8>) + Send + 'static,
12 F1: Fn(&mut SocketHandle, &S) + Send + 'static,
13 F2: Fn(&S) + Send + 'static,
14 S: Default + Send,
15{
16 pub on_message: F,
17 pub on_timer: F1,
18 pub on_disconnect: F2,
19 pub timer: Option<Duration>,
20
21 pub _marker: std::marker::PhantomData<S>,
22}
23
24enum Sender<'a> {
25 WebSocket(&'a ws::Sender),
26 Tcp(&'a mut TcpStream),
27}
28
29pub struct SocketHandle<'a> {
30 sender: Sender<'a>,
31 disconnect: bool,
32}
33
34impl<'a> Sender<'a> {
35 fn send(&mut self, data: &[u8]) -> Option<()> {
36 use std::io::Write;
37
38 match self {
39 Sender::WebSocket(out) => {
40 out.send(data).ok()?;
41 }
42 Sender::Tcp(stream) => {
43 stream.write(&[data.len() as u8]).ok()?;
44 stream.write(data).ok()?;
45 }
46 }
47
48 Some(())
49 }
50}
51
52impl<'a> SocketHandle<'a> {
53 fn new(sender: Sender<'a>) -> SocketHandle<'a> {
54 SocketHandle {
55 sender,
56 disconnect: false,
57 }
58 }
59
60 pub fn send(&mut self, data: &[u8]) -> Result<(), ()> {
61 self.sender.send(data).ok_or(())
62 }
63
64 #[cfg(feature = "nanoserde")]
65 pub fn send_bin<T: nanoserde::SerBin>(&mut self, data: &T) -> Result<(), ()> {
66 self.send(&nanoserde::SerBin::serialize_bin(data))
67 }
68
69 pub fn disconnect(&mut self) {
70 self.disconnect = true;
71 }
72}
73
74pub fn listen<A, A1, F, F1, F2, S>(tcp_addr: A, ws_addr: A1, settings: Settings<F, F1, F2, S>)
75where
76 A: ToSocketAddrs + std::fmt::Debug + Send,
77 A1: ToSocketAddrs + std::fmt::Debug + Send + 'static,
78 F: Fn(&mut SocketHandle, &mut S, Vec<u8>) + Send + 'static,
79 F1: Fn(&mut SocketHandle, &S) + Send + 'static,
80 F2: Fn(&S) + Send + 'static,
81 S: Default + Send + 'static,
82{
83 let on_message = Arc::new(Mutex::new(settings.on_message));
84 let on_timer = Arc::new(Mutex::new(settings.on_timer));
85 let on_disconnect = Arc::new(Mutex::new(settings.on_disconnect));
86 let timer = settings.timer;
87
88 struct WsHandler<
89 S: Default,
90 F: Fn(&mut SocketHandle, &mut S, Vec<u8>) + Send + 'static,
91 F1: Fn(&mut SocketHandle, &S) + Send + 'static,
92 F2: Fn(&S) + Send + 'static,
93 > {
94 out: ws::Sender,
95 state: S,
96 on_message: Arc<Mutex<F>>,
97 on_timer: Arc<Mutex<F1>>,
98 on_disconnect: Arc<Mutex<F2>>,
99 timeout: Option<Duration>,
100 }
101
102 impl<
103 S: Default,
104 F: Fn(&mut SocketHandle, &mut S, Vec<u8>) + Send + 'static,
105 F1: Fn(&mut SocketHandle, &S) + Send + 'static,
106 F2: Fn(&S) + Send + 'static,
107 > ws::Handler for WsHandler<S, F, F1, F2>
108 {
109 fn on_message(&mut self, msg: ws::Message) -> ws::Result<()> {
110 let data = msg.into_data();
111 let mut handle = SocketHandle::new(Sender::WebSocket(&self.out));
112 (self.on_message.lock().unwrap())(&mut handle, &mut self.state, data);
113 if handle.disconnect {
114 self.out.close(ws::CloseCode::Normal)?;
115 }
116 Ok(())
117 }
118
119 fn on_open(&mut self, _: ws::Handshake) -> ws::Result<()> {
120 if let Some(timeout) = self.timeout {
121 self.out
122 .timeout(timeout.as_millis() as _, ws::util::Token(1))?;
123 }
124 Ok(())
125 }
126
127 fn on_timeout(&mut self, _: ws::util::Token) -> ws::Result<()> {
128 if let Some(timeout) = self.timeout {
129 let mut handle = SocketHandle::new(Sender::WebSocket(&self.out));
130 (self.on_timer.lock().unwrap())(&mut handle, &self.state);
131 if handle.disconnect == false {
132 self.out
133 .timeout(timeout.as_millis() as _, ws::util::Token(1))?;
134 } else {
135 self.out.close(ws::CloseCode::Normal)?;
136 }
137 }
138 Ok(())
139 }
140
141 fn on_close(&mut self, _code: ws::CloseCode, _reason: &str) {
142 (self.on_disconnect.lock().unwrap())(&self.state);
143 }
144 }
145
146 std::thread::spawn({
147 let on_message = on_message.clone();
148 let on_timer = on_timer.clone();
149 let on_disconnect = on_disconnect.clone();
150
151 move || {
152 ws::Builder::new()
153 .with_settings(ws::Settings {
154 timer_tick_millis: 10,
155 tcp_nodelay: true,
156 ..ws::Settings::default()
157 })
158 .build(move |out| {
159 let on_message = on_message.clone();
160 let on_timer = on_timer.clone();
161 let on_disconnect = on_disconnect.clone();
162
163 WsHandler {
164 out,
165 state: S::default(),
166 on_message,
167 on_timer,
168 on_disconnect,
169 timeout: timer,
170 }
171 })
172 .unwrap()
173 .listen(ws_addr)
174 .unwrap();
175 }
176 });
177
178 let listener = TcpListener::bind(tcp_addr).unwrap();
179 for stream in listener.incoming() {
180 let on_message = on_message.clone();
181 let on_timer = on_timer.clone();
182 let on_disconnect = on_disconnect.clone();
183
184 std::thread::spawn(move || {
185 let mut stream = stream.unwrap();
186 stream.set_nodelay(true).unwrap();
187 stream.set_nonblocking(true).unwrap();
188 let mut message_reader = MessageReader::new();
189 let mut state = S::default();
190
191 let mut time = Instant::now();
192 loop {
193 match message_reader.next(&mut stream) {
194 Ok(Some(message)) => {
195 let mut handle = SocketHandle::new(Sender::Tcp(&mut stream));
196 (on_message.lock().unwrap())(&mut handle, &mut state, message);
197 if handle.disconnect {
198 (on_disconnect.lock().unwrap())(&state);
199 return;
200 }
201 }
202 Ok(None) => {}
203 Err(_err) => {
204 (on_disconnect.lock().unwrap())(&state);
205 return;
206 }
207 }
208
209 if let Some(timer) = timer {
210 if time.elapsed() >= timer {
211 time = Instant::now();
212 let mut handle = SocketHandle::new(Sender::Tcp(&mut stream));
213
214 (on_timer.lock().unwrap())(&mut handle, &state);
215 if handle.disconnect {
216 (on_disconnect.lock().unwrap())(&state);
217 return;
218 }
219 }
220 }
221 }
222 });
223 }
224}