turn_server/server.rs
1use crate::{
2 config::{Config, Interface},
3 router::Router,
4 statistics::Statistics,
5};
6
7use std::net::SocketAddr;
8
9use turn::{Observer, Service};
10
11#[allow(unused)]
12struct ServerStartOptions<T> {
13 bind: SocketAddr,
14 external: SocketAddr,
15 service: Service<T>,
16 router: Router,
17 statistics: Statistics,
18}
19
20#[allow(unused)]
21trait Server {
22 async fn start<T>(options: ServerStartOptions<T>) -> Result<(), anyhow::Error>
23 where
24 T: Clone + Observer + 'static;
25}
26
#[cfg(feature = "udp")]
mod udp {
    use super::{Server as ServerExt, ServerStartOptions};
    use crate::statistics::Stats;

    use std::{io::ErrorKind::ConnectionReset, ops::Deref, sync::Arc};

    use once_cell::sync::Lazy;
    use stun::Transport;
    use tokio::net::UdpSocket;
    use turn::{Observer, ResponseMethod, SessionAddr};

    /// Number of receive workers spawned per UDP interface: one per CPU as
    /// reported by `num_cpus::get`.
    static NUM_CPUS: Lazy<usize> = Lazy::new(|| num_cpus::get());

    /// UDP interface server.
    ///
    /// Spawns one receive worker per CPU, all sharing the bound socket. Each
    /// worker hands incoming datagrams to the TURN operationer and sends the
    /// response either back through the socket or, when the response names
    /// another endpoint, through the router. A further loop writes out the
    /// messages that other interfaces route to this one.
    pub struct Server;

    impl ServerExt for Server {
        async fn start<T>(
            ServerStartOptions {
                bind,
                external,
                service,
                router,
                statistics,
            }: ServerStartOptions<T>,
        ) -> Result<(), anyhow::Error>
        where
            T: Clone + Observer + 'static,
        {
            let socket = Arc::new(UdpSocket::bind(bind).await?);
            let local_addr = socket.local_addr()?;

            tokio::spawn(async move {
                for _ in 0..*NUM_CPUS.deref() {
                    let socket = socket.clone();
                    let router = router.clone();
                    let reporter = statistics.get_reporter(Transport::UDP);
                    let mut operationer = service.get_operationer(external, external);

                    let mut session_addr = SessionAddr {
                        address: external,
                        interface: external,
                    };

                    tokio::spawn(async move {
                        let mut buf = vec![0u8; 2048];

                        loop {
                            // ConnectionReset (which can be reported when a remote
                            // host shuts down) is ignored without logging; any other
                            // receive error stops this worker.
                            let (size, addr) = match socket.recv_from(&mut buf).await {
                                Err(e) if e.kind() != ConnectionReset => break,
                                Ok(s) => s,
                                _ => continue,
                            };

                            session_addr.address = addr;

                            reporter.send(
                                &session_addr,
                                &[Stats::ReceivedBytes(size as u32), Stats::ReceivedPkts(1)],
                            );

                            // A message needs at least 4 bytes (the smallest one is a
                            // ChannelData header without payload).
                            if size < 4 {
                                continue;
                            }

                            let Ok(Some(res)) = operationer.route(&buf[..size], addr).await else {
                                continue;
                            };

                            let target = res.relay.as_ref().unwrap_or(&addr);

                            // Response destined for another interface: hand it off.
                            if let Some(ref endpoint) = res.endpoint {
                                router.send(endpoint, res.method, target, res.bytes);
                                continue;
                            }

                            match socket.send_to(res.bytes, target).await {
                                // Only count packets that were actually sent, matching
                                // the forwarding loop below.
                                Ok(_) => {
                                    reporter.send(
                                        &session_addr,
                                        &[Stats::SendBytes(res.bytes.len() as u32), Stats::SendPkts(1)],
                                    );

                                    if let ResponseMethod::Stun(method) = res.method {
                                        if method.is_error() {
                                            reporter.send(&session_addr, &[Stats::ErrorPkts(1)]);
                                        }
                                    }
                                }
                                Err(e) if e.kind() != ConnectionReset => break,
                                Err(_) => {}
                            }
                        }
                    });
                }

                // Write out the messages that other interfaces' workers route to
                // this interface via `router.send`.
                {
                    let mut session_addr = SessionAddr {
                        address: external,
                        interface: external,
                    };

                    let reporter = statistics.get_reporter(Transport::UDP);
                    let mut receiver = router.get_receiver(external);
                    while let Some((bytes, _, addr)) = receiver.recv().await {
                        session_addr.address = addr;

                        if let Err(e) = socket.send_to(&bytes, addr).await {
                            if e.kind() != ConnectionReset {
                                break;
                            }
                        } else {
                            reporter.send(
                                &session_addr,
                                &[Stats::SendBytes(bytes.len() as u32), Stats::SendPkts(1)],
                            );
                        }
                    }

                    router.remove(&external);
                }

                log::error!("udp server close: interface={:?}", local_addr);
            });

            log::info!(
                "turn server listening: bind={}, external={}, transport=UDP",
                bind,
                external,
            );

            Ok(())
        }
    }
}
167
#[cfg(feature = "tcp")]
mod tcp {
    use super::{Server as ServerExt, ServerStartOptions};
    use crate::statistics::Stats;

    use std::{
        ops::{Deref, DerefMut},
        sync::Arc,
    };

    use stun::{Decoder, Transport};
    use tokio::{io::AsyncReadExt, io::AsyncWriteExt, net::TcpListener, sync::Mutex};
    use turn::{Observer, ResponseMethod, SessionAddr};

    /// Zero bytes used to pad ChannelData written to TCP up to a multiple of 4.
    static ZERO_BYTES: [u8; 8] = [0u8; 8];

    /// Largest message accepted from a TCP connection, and the capacity of each
    /// half of [`ExchangeBuffer`]. A larger declared size closes the connection.
    const MAX_MESSAGE_SIZE: usize = 2048;

    /// Double buffer used to reassemble messages read from a TCP stream.
    ///
    /// Bytes are read into the free tail of the active buffer until a complete
    /// message sits at its front. [`ExchangeBuffer::split`] then hands that
    /// message out and makes the other buffer active. Only the bytes that
    /// follow the message (the start of the next one) are copied over, so
    /// further reads continue after them. Nothing is allocated after
    /// construction.
    struct ExchangeBuffer {
        /// Two backing buffers, each paired with its count of valid bytes.
        buffers: [(Vec<u8>, usize /* len */); 2],
        /// Index (0 or 1) of the active buffer.
        index: usize,
    }

    impl Default for ExchangeBuffer {
        fn default() -> Self {
            Self {
                index: 0,
                buffers: [
                    (vec![0u8; MAX_MESSAGE_SIZE], 0),
                    (vec![0u8; MAX_MESSAGE_SIZE], 0),
                ],
            }
        }
    }

    impl Deref for ExchangeBuffer {
        type Target = [u8];

        /// The whole active buffer, including bytes past `len()` that are not
        /// valid data.
        fn deref(&self) -> &Self::Target {
            &self.buffers[self.index].0[..]
        }
    }

    impl DerefMut for ExchangeBuffer {
        /// The writable tail of the active buffer. It starts after the valid
        /// data, so reads never overwrite unconsumed bytes.
        fn deref_mut(&mut self) -> &mut Self::Target {
            let len = self.buffers[self.index].1;
            &mut self.buffers[self.index].0[len..]
        }
    }

    impl ExchangeBuffer {
        /// Number of valid bytes in the active buffer.
        fn len(&self) -> usize {
            self.buffers[self.index].1
        }

        /// Marks `len` more bytes of the active buffer as valid.
        ///
        /// Unlike `BytesMut`, the cursor does not move by itself. Call this
        /// after writing through `deref_mut`.
        fn advance(&mut self, len: usize) {
            self.buffers[self.index].1 += len;
        }

        /// Takes the first `len` bytes of the active buffer as one message and
        /// switches to the other buffer.
        ///
        /// Bytes after the message are copied to the front of the other
        /// buffer, which becomes active. The returned slice borrows the
        /// previously active buffer, which is not written again until the
        /// next `split`.
        ///
        /// # Panics
        ///
        /// Panics if `len` exceeds the number of buffered bytes.
        fn split(&mut self, len: usize) -> &[u8] {
            let current = self.index;
            let current_len = self.buffers[current].1;

            // The split point cannot lie beyond the buffered data.
            assert!(len <= current_len);

            // Unconsumed bytes following the message.
            let remaining = current_len - len;

            // Borrow both halves disjointly, so the leftover bytes can be
            // copied without turning a shared reference into a mutable one.
            let [first, second] = &mut self.buffers;
            let (consumed, next) = if current == 0 {
                (first, second)
            } else {
                (second, first)
            };

            next.0[..remaining].copy_from_slice(&consumed.0[len..current_len]);
            next.1 = remaining;
            // The consumed buffer becomes the free one.
            consumed.1 = 0;

            self.index = 1 - current;

            &consumed.0[..len]
        }
    }

    /// TCP interface server.
    ///
    /// Accepts every connection on the bound listener and spawns two tasks
    /// per connection:
    /// - one decodes incoming messages and answers them;
    /// - one writes the messages that other interfaces route to this
    ///   connection.
    pub struct Server;

    impl ServerExt for Server {
        async fn start<T>(
            ServerStartOptions {
                bind,
                external,
                service,
                router,
                statistics,
            }: ServerStartOptions<T>,
        ) -> Result<(), anyhow::Error>
        where
            T: Clone + Observer + 'static,
        {
            let listener = TcpListener::bind(bind).await?;
            let local_addr = listener.local_addr()?;

            tokio::spawn(async move {
                // Accept connections until `accept` fails. An error stops
                // accepting on this interface; it does not exit the process.
                while let Ok((socket, address)) = listener.accept().await {
                    let router = router.clone();
                    let reporter = statistics.get_reporter(Transport::TCP);
                    let mut receiver = router.get_receiver(address);
                    let mut operationer = service.get_operationer(address, external);

                    log::info!("tcp socket accept: addr={:?}, interface={:?}", address, local_addr,);

                    // Disable Nagle's algorithm so that data is sent as soon as
                    // it is written, keeping latency low.
                    if let Err(e) = socket.set_nodelay(true) {
                        log::error!("tcp socket set nodelay failed!: addr={}, err={}", address, e);
                    }

                    let session_addr = SessionAddr {
                        interface: external,
                        address,
                    };

                    let (mut reader, writer) = socket.into_split();
                    let writer = Arc::new(Mutex::new(writer));

                    // Write the messages that other interfaces route to this connection.
                    let forward_writer = writer.clone();
                    let forward_reporter = reporter.clone();
                    tokio::spawn(async move {
                        while let Some((bytes, method, _)) = receiver.recv().await {
                            // Hold the lock across the payload and its padding so
                            // other writes cannot interleave with them.
                            let mut writer = forward_writer.lock().await;
                            if writer.write_all(bytes.as_slice()).await.is_err() {
                                break;
                            }

                            forward_reporter.send(
                                &session_addr,
                                &[Stats::SendBytes(bytes.len() as u32), Stats::SendPkts(1)],
                            );

                            // ChannelData over TCP must be padded to a multiple of 4
                            // bytes. Data forwarded from UDP is not guaranteed to be
                            // aligned, so pad it here.
                            if method == ResponseMethod::ChannelData {
                                let pad = bytes.len() % 4;
                                if pad > 0 && writer.write_all(&ZERO_BYTES[..(4 - pad)]).await.is_err() {
                                    break;
                                }
                            }
                        }
                    });

                    let sessions = service.get_sessions();
                    tokio::spawn(async move {
                        let mut buffer = ExchangeBuffer::default();

                        'read: while let Ok(size) = reader.read(&mut buffer).await {
                            // A zero-length read means the socket has been closed.
                            if size == 0 {
                                break;
                            }

                            reporter.send(&session_addr, &[Stats::ReceivedBytes(size as u32)]);
                            buffer.advance(size);

                            // Handle every complete message currently buffered. At
                            // least 4 bytes are needed to read a message size, and the
                            // smallest message (an empty ChannelData) is exactly 4 bytes.
                            while buffer.len() >= 4 {
                                let msg_len = match Decoder::message_size(&buffer, true) {
                                    // Size could not be determined; read more data first.
                                    Err(_) => break,
                                    // Limit message length to prevent buffer overflow
                                    // attacks; drop the connection.
                                    Ok(s) if s > MAX_MESSAGE_SIZE => break 'read,
                                    // Message not fully received yet; read more data.
                                    Ok(s) if s > buffer.len() => break,
                                    Ok(s) => s,
                                };

                                reporter.send(&session_addr, &[Stats::ReceivedPkts(1)]);

                                let chunk = buffer.split(msg_len);
                                let res = match operationer.route(chunk, address).await {
                                    Err(_) => break 'read,
                                    Ok(None) => continue,
                                    Ok(Some(res)) => res,
                                };

                                // Response destined for another interface: hand it off.
                                if let Some(ref interface) = res.endpoint {
                                    router.send(
                                        interface,
                                        res.method,
                                        res.relay.as_ref().unwrap_or(&address),
                                        res.bytes,
                                    );
                                    continue;
                                }

                                if writer.lock().await.write_all(res.bytes).await.is_err() {
                                    break 'read;
                                }

                                reporter.send(
                                    &session_addr,
                                    &[Stats::SendBytes(res.bytes.len() as u32), Stats::SendPkts(1)],
                                );

                                if let ResponseMethod::Stun(method) = res.method {
                                    if method.is_error() {
                                        reporter.send(&session_addr, &[Stats::ErrorPkts(1)]);
                                    }
                                }
                            }
                        }

                        // Refresh the session with a lifetime of 0 as soon as the TCP
                        // connection ends. The session then goes through the closing
                        // procedure instead of the connection vanishing without it.
                        sessions.refresh(&session_addr, 0);

                        router.remove(&address);

                        log::info!("tcp socket disconnect: addr={:?}, interface={:?}", address, local_addr);
                    });
                }

                log::error!("tcp server close: interface={:?}", local_addr);
            });

            log::info!(
                "turn server listening: bind={}, external={}, transport=TCP",
                bind,
                external,
            );

            Ok(())
        }
    }
}
463
464/// start turn server.
465///
466/// create a specified number of threads,
467/// each thread processes udp data separately.
468pub async fn start<T>(config: &Config, statistics: &Statistics, service: &Service<T>) -> anyhow::Result<()>
469where
470 T: Clone + Observer + 'static,
471{
472 #[allow(unused)]
473 use crate::config::Transport;
474
475 let router = Router::default();
476 for Interface {
477 transport,
478 external,
479 bind,
480 } in config.turn.interfaces.iter().cloned()
481 {
482 #[allow(unused)]
483 let options = ServerStartOptions {
484 statistics: statistics.clone(),
485 service: service.clone(),
486 router: router.clone(),
487 external,
488 bind,
489 };
490
491 match transport {
492 #[cfg(feature = "udp")]
493 Transport::UDP => udp::Server::start(options).await?,
494 #[cfg(feature = "tcp")]
495 Transport::TCP => tcp::Server::start(options).await?,
496 #[allow(unreachable_patterns)]
497 _ => (),
498 };
499 }
500
501 Ok(())
502}