turn_server/server.rs
1use crate::{
2 config::{Config, Interface},
3 router::Router,
4 statistics::Statistics,
5 turn::{Observer, Service},
6};
7
8use std::net::SocketAddr;
9
10#[allow(unused)]
11struct ServerStartOptions<T> {
12 bind: SocketAddr,
13 external: SocketAddr,
14 service: Service<T>,
15 router: Router,
16 statistics: Statistics,
17}
18
19#[allow(unused)]
20trait Server {
21 async fn start<T>(options: ServerStartOptions<T>) -> Result<(), anyhow::Error>
22 where
23 T: Clone + Observer + 'static;
24}
25
#[cfg(feature = "udp")]
mod udp {
    use super::{Server as ServerExt, ServerStartOptions};
    use crate::{
        statistics::Stats,
        stun::Transport,
        turn::{Observer, ResponseMethod, SessionAddr},
    };

    use std::{io::ErrorKind::ConnectionReset, sync::Arc};

    use tokio::net::UdpSocket;

    /// UDP server for one configured interface.
    ///
    /// Two tasks share the bound socket:
    /// - one reads datagrams, hands them to the TURN operationer and either
    ///   sends the response back out of the socket or routes it to another
    ///   interface;
    /// - the other drains packets that were routed to this interface and
    ///   writes them out of the socket.
    pub struct Server;

    impl ServerExt for Server {
        async fn start<T>(
            ServerStartOptions {
                bind,
                external,
                service,
                router,
                statistics,
            }: ServerStartOptions<T>,
        ) -> Result<(), anyhow::Error>
        where
            T: Clone + Observer + 'static,
        {
            let socket = Arc::new(UdpSocket::bind(bind).await?);
            let local_addr = socket.local_addr()?;

            {
                let socket = socket.clone();
                let router = router.clone();
                let reporter = statistics.get_reporter(Transport::UDP);
                let mut operationer = service.get_operationer(external, external);

                let mut session_addr = SessionAddr {
                    address: external,
                    interface: external,
                };

                tokio::spawn(async move {
                    let mut buf = vec![0u8; 2048];

                    loop {
                        let (size, addr) = match socket.recv_from(&mut buf).await {
                            Ok(received) => received,
                            // `ConnectionReset` can surface on a UDP socket when a
                            // remote peer is unreachable; the local socket is still
                            // usable, so keep receiving.
                            Err(e) if e.kind() == ConnectionReset => continue,
                            // Any other error ends this receive loop.
                            Err(e) => {
                                log::error!(
                                    "udp socket recv failed: interface={:?}, err={}",
                                    local_addr,
                                    e
                                );
                                break;
                            }
                        };

                        session_addr.address = addr;
                        reporter.send(&session_addr, &[Stats::ReceivedBytes(size), Stats::ReceivedPkts(1)]);

                        // The smallest message (channel data without payload) is
                        // 4 bytes; anything shorter cannot be processed.
                        if size < 4 {
                            continue;
                        }

                        let Ok(Some(res)) = operationer.route(&buf[..size], addr) else {
                            continue;
                        };

                        let target = res.relay.as_ref().unwrap_or(&addr);

                        // The response belongs to another interface: hand it to the router.
                        if let Some(ref endpoint) = res.endpoint {
                            router.send(endpoint, res.method, target, res.bytes);
                            continue;
                        }

                        match socket.send_to(res.bytes, target).await {
                            Ok(_) => {
                                // Only count packets that were actually sent.
                                reporter.send(
                                    &session_addr,
                                    &[Stats::SendBytes(res.bytes.len()), Stats::SendPkts(1)],
                                );

                                if let ResponseMethod::Stun(method) = res.method {
                                    if method.is_error() {
                                        reporter.send(&session_addr, &[Stats::ErrorPkts(1)]);
                                    }
                                }
                            }
                            Err(e) if e.kind() == ConnectionReset => {}
                            Err(e) => {
                                log::error!(
                                    "udp socket send failed: interface={:?}, err={}",
                                    local_addr,
                                    e
                                );
                                break;
                            }
                        }
                    }
                });
            }

            tokio::spawn(async move {
                let mut session_addr = SessionAddr {
                    address: external,
                    interface: external,
                };

                let reporter = statistics.get_reporter(Transport::UDP);
                let mut receiver = router.get_receiver(external);
                while let Some((bytes, _, addr)) = receiver.recv().await {
                    session_addr.address = addr;

                    match socket.send_to(&bytes, addr).await {
                        Ok(_) => {
                            reporter.send(&session_addr, &[Stats::SendBytes(bytes.len()), Stats::SendPkts(1)]);
                        }
                        Err(e) if e.kind() == ConnectionReset => {}
                        Err(_) => break,
                    }
                }

                // Stop receiving routed packets for this interface.
                router.remove(&external);

                log::error!("udp server close: interface={:?}", local_addr);
            });

            log::info!(
                "turn server listening: bind={}, external={}, transport=UDP",
                bind,
                external,
            );

            Ok(())
        }
    }
}
155
#[cfg(feature = "tcp")]
mod tcp {
    use super::{Server as ServerExt, ServerStartOptions};
    use crate::{
        statistics::Stats,
        stun::{Decoder, Transport},
        turn::{Observer, ResponseMethod, SessionAddr},
    };

    use std::{
        ops::{Deref, DerefMut},
        sync::Arc,
    };

    use tokio::{io::AsyncReadExt, io::AsyncWriteExt, net::TcpListener, sync::Mutex};

    /// Zero bytes used to pad forwarded channel data to a multiple of 4 bytes.
    static ZERO_BYTES: [u8; 8] = [0u8; 8];

    /// Capacity of each half of an [`ExchangeBuffer`]. This is also the largest
    /// message a TCP client may send; a larger declared size closes the
    /// connection.
    const BUFFER_SIZE: usize = 2048;

    /// Smallest possible message: a channel data header with no payload.
    const MIN_MESSAGE_SIZE: usize = 4;

    /// A pair of fixed-size buffers used to frame messages read from a TCP
    /// stream.
    ///
    /// Incoming bytes are appended to the active buffer until it holds at
    /// least one complete message. [`ExchangeBuffer::split`] then hands out
    /// that message and copies any bytes that follow it to the start of the
    /// other buffer, which becomes the active one. Only those leftover bytes
    /// are copied; no memory is allocated after construction.
    struct ExchangeBuffer {
        buffers: [(Vec<u8>, usize /* len */); 2],
        index: usize,
    }

    impl Default for ExchangeBuffer {
        fn default() -> Self {
            Self {
                index: 0,
                buffers: [(vec![0u8; BUFFER_SIZE], 0), (vec![0u8; BUFFER_SIZE], 0)],
            }
        }
    }

    impl Deref for ExchangeBuffer {
        type Target = [u8];

        /// The whole active buffer. Only the first `len()` bytes are data;
        /// the rest may hold stale bytes from earlier use.
        fn deref(&self) -> &Self::Target {
            &self.buffers[self.index].0[..]
        }
    }

    impl DerefMut for ExchangeBuffer {
        /// The unused tail of the active buffer (everything after `len()`), so
        /// reads append to already-buffered data. The slice is empty once the
        /// buffer is full.
        fn deref_mut(&mut self) -> &mut Self::Target {
            let len = self.buffers[self.index].1;
            &mut self.buffers[self.index].0[len..]
        }
    }

    impl ExchangeBuffer {
        /// Number of buffered (not yet consumed) bytes.
        fn len(&self) -> usize {
            self.buffers[self.index].1
        }

        /// Record that `len` more bytes were written through `deref_mut`.
        ///
        /// Unlike `BytesMut`, the cursor does not advance on its own.
        fn advance(&mut self, len: usize) {
            self.buffers[self.index].1 += len;
        }

        /// Remove the first `len` buffered bytes and return them.
        ///
        /// Bytes after `len` are copied to the start of the other buffer, which
        /// becomes the active one. The returned slice borrows `self`, so it
        /// must be used before the buffer is touched again.
        ///
        /// # Panics
        ///
        /// Panics if `len` is greater than the number of buffered bytes.
        fn split(&mut self, len: usize) -> &[u8] {
            let current_len = self.buffers[self.index].1;

            // The split point cannot lie beyond the buffered data.
            assert!(len <= current_len);

            // Bytes after the split that still have to be consumed.
            let remaining = current_len - len;

            // Borrow both halves at once so the leftover bytes can be copied
            // from the active buffer into the idle one without aliasing.
            let [first, second] = &mut self.buffers;
            let (current, next) = if self.index == 0 {
                (first, second)
            } else {
                (second, first)
            };

            current.1 = 0;
            next.1 = remaining;
            next.0[..remaining].copy_from_slice(&current.0[len..current_len]);

            // Swap which buffer is active.
            self.index ^= 1;

            &current.0[..len]
        }
    }

    /// TCP server for one configured interface.
    ///
    /// Accepts connections on the listener. For each connection it spawns one
    /// task that frames and processes incoming messages, and another that
    /// writes packets routed to the connection from other interfaces.
    pub struct Server;

    impl ServerExt for Server {
        async fn start<T>(
            ServerStartOptions {
                bind,
                external,
                service,
                router,
                statistics,
            }: ServerStartOptions<T>,
        ) -> Result<(), anyhow::Error>
        where
            T: Clone + Observer + 'static,
        {
            let listener = TcpListener::bind(bind).await?;
            let local_addr = listener.local_addr()?;

            tokio::spawn(async move {
                // Accept connections until the listener reports an error. That
                // error stops this interface from accepting new connections.
                loop {
                    let (socket, address) = match listener.accept().await {
                        Ok(accepted) => accepted,
                        Err(e) => {
                            log::error!(
                                "tcp listener accept failed: interface={:?}, err={}",
                                local_addr,
                                e
                            );
                            break;
                        }
                    };

                    let router = router.clone();
                    let reporter = statistics.get_reporter(Transport::TCP);
                    let mut receiver = router.get_receiver(address);
                    let mut operationer = service.get_operationer(address, external);

                    log::info!("tcp socket accept: addr={:?}, interface={:?}", address, local_addr,);

                    // Disable Nagle's algorithm so that small messages are sent
                    // immediately rather than coalesced.
                    if let Err(e) = socket.set_nodelay(true) {
                        log::error!("tcp socket set nodelay failed!: addr={}, err={}", address, e);
                    }

                    let session_addr = SessionAddr {
                        interface: external,
                        address,
                    };

                    let (mut reader, writer) = socket.into_split();
                    let writer = Arc::new(Mutex::new(writer));

                    // Separate task that writes packets routed to this connection.
                    let writer_ = writer.clone();
                    let reporter_ = reporter.clone();
                    tokio::spawn(async move {
                        while let Some((bytes, method, _)) = receiver.recv().await {
                            let mut writer = writer_.lock().await;
                            if writer.write_all(bytes.as_slice()).await.is_err() {
                                break;
                            }

                            reporter_.send(&session_addr, &[Stats::SendBytes(bytes.len()), Stats::SendPkts(1)]);

                            // Over TCP, channel data must be padded to a multiple
                            // of 4 bytes. Data forwarded from UDP is not
                            // necessarily aligned, so add the padding here. The
                            // lock is still held, so the padding follows the
                            // payload directly.
                            if method == ResponseMethod::ChannelData {
                                let pad = bytes.len() % 4;
                                if pad > 0 && writer.write_all(&ZERO_BYTES[..(4 - pad)]).await.is_err() {
                                    break;
                                }
                            }
                        }
                    });

                    let sessions = service.get_sessions();
                    tokio::spawn(async move {
                        let mut buffer = ExchangeBuffer::default();

                        'read: while let Ok(size) = reader.read(&mut buffer).await {
                            // A zero-length read means the peer closed the connection.
                            // It also happens when the buffer is full and no data
                            // could be framed, which likewise ends the connection.
                            if size == 0 {
                                break;
                            }

                            reporter.send(&session_addr, &[Stats::ReceivedBytes(size)]);
                            buffer.advance(size);

                            // Process every complete message currently buffered.
                            while buffer.len() >= MIN_MESSAGE_SIZE {
                                let size = match Decoder::message_size(&buffer, true) {
                                    // The header cannot be decoded yet; wait for more data.
                                    Err(_) => break,
                                    // A message larger than the buffer can never be
                                    // framed (this also blocks oversized-message
                                    // attacks), and one shorter than the minimum would
                                    // make no progress. Both end the connection.
                                    Ok(s) if !(MIN_MESSAGE_SIZE..=BUFFER_SIZE).contains(&s) => break 'read,
                                    // Message is incomplete; wait for more data.
                                    Ok(s) if s > buffer.len() => break,
                                    Ok(s) => s,
                                };

                                reporter.send(&session_addr, &[Stats::ReceivedPkts(1)]);

                                let chunk = buffer.split(size);
                                let res = match operationer.route(chunk, address) {
                                    Ok(Some(res)) => res,
                                    // Message consumed without a reply.
                                    Ok(None) => continue,
                                    Err(_) => break 'read,
                                };

                                // The response belongs to another interface: hand it to the router.
                                if let Some(ref interface) = res.endpoint {
                                    router.send(
                                        interface,
                                        res.method,
                                        res.relay.as_ref().unwrap_or(&address),
                                        res.bytes,
                                    );

                                    continue;
                                }

                                if writer.lock().await.write_all(res.bytes).await.is_err() {
                                    break 'read;
                                }

                                reporter.send(
                                    &session_addr,
                                    &[Stats::SendBytes(res.bytes.len()), Stats::SendPkts(1)],
                                );

                                if let ResponseMethod::Stun(method) = res.method {
                                    if method.is_error() {
                                        reporter.send(&session_addr, &[Stats::ErrorPkts(1)]);
                                    }
                                }
                            }
                        }

                        // When the TCP connection ends, close the session at once
                        // (refresh with a lifetime of 0). The session is not left
                        // dangling after an abrupt disconnect.
                        sessions.refresh(&session_addr, 0);

                        router.remove(&address);

                        log::info!("tcp socket disconnect: addr={:?}, interface={:?}", address, local_addr);
                    });
                }

                log::error!("tcp server close: interface={:?}", local_addr);
            });

            log::info!(
                "turn server listening: bind={}, external={}, transport=TCP",
                bind,
                external,
            );

            Ok(())
        }
    }
}
450
451/// start turn server.
452///
453/// create a specified number of threads,
454/// each thread processes udp data separately.
455pub async fn start<T>(config: &Config, statistics: &Statistics, service: &Service<T>) -> anyhow::Result<()>
456where
457 T: Clone + Observer + 'static,
458{
459 #[allow(unused)]
460 use crate::config::Transport;
461
462 let router = Router::default();
463 for Interface {
464 transport,
465 external,
466 bind,
467 } in config.turn.interfaces.iter().cloned()
468 {
469 #[allow(unused)]
470 let options = ServerStartOptions {
471 statistics: statistics.clone(),
472 service: service.clone(),
473 router: router.clone(),
474 external,
475 bind,
476 };
477
478 match transport {
479 #[cfg(feature = "udp")]
480 Transport::UDP => udp::Server::start(options).await?,
481 #[cfg(feature = "tcp")]
482 Transport::TCP => tcp::Server::start(options).await?,
483 #[allow(unreachable_patterns)]
484 _ => (),
485 };
486 }
487
488 Ok(())
489}