turn_server/
server.rs

1use crate::{
2    config::{Config, Interface},
3    router::Router,
4    statistics::Statistics,
5};
6
7use std::net::SocketAddr;
8
9use turn::{Observer, Service};
10
11#[allow(unused)]
12struct ServerStartOptions<T> {
13    bind: SocketAddr,
14    external: SocketAddr,
15    service: Service<T>,
16    router: Router,
17    statistics: Statistics,
18}
19
20#[allow(unused)]
21trait Server {
22    async fn start<T>(options: ServerStartOptions<T>) -> Result<(), anyhow::Error>
23    where
24        T: Clone + Observer + 'static;
25}
26
27#[cfg(feature = "udp")]
28mod udp {
29    use super::{Server as ServerExt, ServerStartOptions};
30    use crate::statistics::Stats;
31
32    use std::{io::ErrorKind::ConnectionReset, ops::Deref, sync::Arc};
33
34    use once_cell::sync::Lazy;
35    use stun::Transport;
36    use tokio::net::UdpSocket;
37    use turn::{Observer, ResponseMethod, SessionAddr};
38
39    static NUM_CPUS: Lazy<usize> = Lazy::new(|| num_cpus::get());
40
41    /// udp socket process thread.
42    ///
43    /// read the data packet from the UDP socket and hand
44    /// it to the proto for processing, and send the processed
45    /// data packet to the specified address.
46    pub struct Server;
47
48    impl ServerExt for Server {
49        async fn start<T>(
50            ServerStartOptions {
51                bind,
52                external,
53                service,
54                router,
55                statistics,
56            }: ServerStartOptions<T>,
57        ) -> Result<(), anyhow::Error>
58        where
59            T: Clone + Observer + 'static,
60        {
61            let socket = Arc::new(UdpSocket::bind(bind).await?);
62            let local_addr = socket.local_addr()?;
63
64            tokio::spawn(async move {
65                for _ in 0..*NUM_CPUS.deref() {
66                    let socket = socket.clone();
67                    let router = router.clone();
68                    let reporter = statistics.get_reporter(Transport::UDP);
69                    let mut operationer = service.get_operationer(external, external);
70
71                    let mut session_addr = SessionAddr {
72                        address: external,
73                        interface: external,
74                    };
75
76                    tokio::spawn(async move {
77                        let mut buf = vec![0u8; 2048];
78
79                        loop {
80                            // Note: An error will also be reported when the remote host is
81                            // shut down, which is not processed yet, but a
82                            // warning will be issued.
83                            let (size, addr) = match socket.recv_from(&mut buf).await {
84                                Err(e) if e.kind() != ConnectionReset => break,
85                                Ok(s) => s,
86                                _ => continue,
87                            };
88
89                            session_addr.address = addr;
90
91                            reporter.send(
92                                &session_addr,
93                                &[Stats::ReceivedBytes(size as u32), Stats::ReceivedPkts(1)],
94                            );
95
96                            // The stun message requires at least 4 bytes. (currently the
97                            // smallest stun message is channel data,
98                            // excluding content)
99                            if size >= 4 {
100                                if let Ok(Some(res)) = operationer.route(&buf[..size], addr).await {
101                                    let target = res.relay.as_ref().unwrap_or(&addr);
102                                    if let Some(ref endpoint) = res.endpoint {
103                                        router.send(endpoint, res.method, target, res.bytes);
104                                    } else {
105                                        if let Err(e) = socket.send_to(res.bytes, target).await {
106                                            if e.kind() != ConnectionReset {
107                                                break;
108                                            }
109                                        }
110
111                                        reporter.send(
112                                            &session_addr,
113                                            &[Stats::SendBytes(res.bytes.len() as u32), Stats::SendPkts(1)],
114                                        );
115
116                                        if let ResponseMethod::Stun(method) = res.method {
117                                            if method.is_error() {
118                                                reporter.send(&session_addr, &[Stats::ErrorPkts(1)]);
119                                            }
120                                        }
121                                    }
122                                }
123                            }
124                        }
125                    });
126                }
127
128                {
129                    let mut session_addr = SessionAddr {
130                        address: external,
131                        interface: external,
132                    };
133
134                    let reporter = statistics.get_reporter(Transport::UDP);
135                    let mut receiver = router.get_receiver(external);
136                    while let Some((bytes, _, addr)) = receiver.recv().await {
137                        session_addr.address = addr;
138
139                        if let Err(e) = socket.send_to(&bytes, addr).await {
140                            if e.kind() != ConnectionReset {
141                                break;
142                            }
143                        } else {
144                            reporter.send(
145                                &session_addr,
146                                &[Stats::SendBytes(bytes.len() as u32), Stats::SendPkts(1)],
147                            );
148                        }
149                    }
150
151                    router.remove(&external);
152                }
153
154                log::error!("udp server close: interface={:?}", local_addr);
155            });
156
157            log::info!(
158                "turn server listening: bind={}, external={}, transport=UDP",
159                bind,
160                external,
161            );
162
163            Ok(())
164        }
165    }
166}
167
168#[cfg(feature = "tcp")]
169mod tcp {
170    use super::{Server as ServerExt, ServerStartOptions};
171    use crate::statistics::Stats;
172
173    use std::{
174        ops::{Deref, DerefMut},
175        sync::Arc,
176    };
177
178    use stun::{Decoder, Transport};
179    use tokio::{io::AsyncReadExt, io::AsyncWriteExt, net::TcpListener, sync::Mutex};
180    use turn::{Observer, ResponseMethod, SessionAddr};
181
182    static ZERO_BYTES: [u8; 8] = [0u8; 8];
183
184    /// An emulated double buffer queue, this is used when reading data over
185    /// TCP.
186    ///
187    /// When reading data over TCP, you need to keep adding to the buffer until
188    /// you find the delimited position. But this double buffer queue solves
189    /// this problem well, in the queue, the separation is treated as the first
190    /// read operation and after the separation the buffer is reversed and
191    /// another free buffer is used for writing the data.
192    ///
193    /// If the current buffer in the separation after the existence of
194    /// unconsumed data, this time the unconsumed data will be copied to another
195    /// free buffer, and fill the length of the free buffer data, this time to
196    /// write data again when you can continue to fill to the end of the
197    /// unconsumed data.
198    ///
199    /// This queue only needs to copy the unconsumed data without duplicating
200    /// the memory allocation, which will reduce a lot of overhead.
201    struct ExchangeBuffer {
202        buffers: [(Vec<u8>, usize /* len */); 2],
203        index: usize,
204    }
205
206    impl Default for ExchangeBuffer {
207        #[rustfmt::skip]
208        fn default() -> Self {
209            Self {
210                index: 0,
211                buffers: [
212                    (vec![0u8; 2048], 0),
213                    (vec![0u8; 2048], 0),
214                ],
215            }
216        }
217    }
218
219    impl Deref for ExchangeBuffer {
220        type Target = [u8];
221
222        fn deref(&self) -> &Self::Target {
223            &self.buffers[self.index].0[..]
224        }
225    }
226
227    impl DerefMut for ExchangeBuffer {
228        // Writes need to take into account overwriting written data, so fetching the
229        // writable buffer starts with the internal cursor.
230        fn deref_mut(&mut self) -> &mut Self::Target {
231            let len = self.buffers[self.index].1;
232            &mut self.buffers[self.index].0[len..]
233        }
234    }
235
236    impl ExchangeBuffer {
237        fn len(&self) -> usize {
238            self.buffers[self.index].1
239        }
240
241        /// The buffer does not automatically advance the cursor as BytesMut
242        /// does, and you need to manually advance the length of the data
243        /// written.
244        fn advance(&mut self, len: usize) {
245            self.buffers[self.index].1 += len;
246        }
247
248        fn split(&mut self, len: usize) -> &[u8] {
249            let (ref current_bytes, current_len) = self.buffers[self.index];
250
251            // The length of the separation cannot be greater than the length of the data.
252            assert!(len <= current_len);
253
254            // Length of unconsumed data
255            let remaining = current_len - len;
256
257            {
258                // The current buffer is no longer in use, resetting the content length.
259                self.buffers[self.index].1 = 0;
260
261                // Invert the buffer.
262                self.index = if self.index == 0 { 1 } else { 0 };
263
264                // The length of unconsumed data needs to be updated into the reversed
265                // completion buffer.
266                self.buffers[self.index].1 = remaining;
267            }
268
269            // Unconsumed data exists and is copied to the free buffer.
270            #[allow(mutable_transmutes)]
271            if remaining > 0 {
272                unsafe { std::mem::transmute::<&[u8], &mut [u8]>(&self.buffers[self.index].0[..remaining]) }
273                    .copy_from_slice(&current_bytes[len..current_len]);
274            }
275
276            &current_bytes[..len]
277        }
278    }
279
280    /// tcp socket process thread.
281    ///
282    /// This function is used to handle all connections coming from the tcp
283    /// listener, and handle the receiving, sending and forwarding of messages.
284    pub struct Server;
285
286    impl ServerExt for Server {
287        async fn start<T>(
288            ServerStartOptions {
289                bind,
290                external,
291                service,
292                router,
293                statistics,
294            }: ServerStartOptions<T>,
295        ) -> Result<(), anyhow::Error>
296        where
297            T: Clone + Observer + 'static,
298        {
299            let listener = TcpListener::bind(bind).await?;
300            let local_addr = listener.local_addr()?;
301
302            tokio::spawn(async move {
303                // Accept all connections on the current listener, but exit the entire
304                // process when an error occurs.
305                while let Ok((socket, address)) = listener.accept().await {
306                    let router = router.clone();
307                    let reporter = statistics.get_reporter(Transport::TCP);
308                    let mut receiver = router.get_receiver(address);
309                    let mut operationer = service.get_operationer(address, external);
310
311                    log::info!("tcp socket accept: addr={:?}, interface={:?}", address, local_addr,);
312
313                    // Disable the Nagle algorithm.
314                    // because to maintain real-time, any received data should be processed
315                    // as soon as possible.
316                    if let Err(e) = socket.set_nodelay(true) {
317                        log::error!("tcp socket set nodelay failed!: addr={}, err={}", address, e);
318                    }
319
320                    let session_addr = SessionAddr {
321                        interface: external,
322                        address,
323                    };
324
325                    let (mut reader, writer) = socket.into_split();
326                    let writer = Arc::new(Mutex::new(writer));
327
328                    // Use a separate task to handle messages forwarded to this socket.
329                    let writer_ = writer.clone();
330                    let reporter_ = reporter.clone();
331                    tokio::spawn(async move {
332                        while let Some((bytes, method, _)) = receiver.recv().await {
333                            let mut writer = writer_.lock().await;
334                            if writer.write_all(bytes.as_slice()).await.is_err() {
335                                break;
336                            } else {
337                                reporter_.send(
338                                    &session_addr,
339                                    &[Stats::SendBytes(bytes.len() as u32), Stats::SendPkts(1)],
340                                );
341                            }
342
343                            // The channel data needs to be aligned in multiples of 4 in
344                            // tcp. If the channel data is forwarded to tcp, the alignment
345                            // bit needs to be filled, because if the channel data comes
346                            // from udp, it is not guaranteed to be aligned and needs to be
347                            // checked.
348                            if method == ResponseMethod::ChannelData {
349                                let pad = bytes.len() % 4;
350                                if pad > 0 && writer.write_all(&ZERO_BYTES[..(4 - pad)]).await.is_err() {
351                                    break;
352                                }
353                            }
354                        }
355                    });
356
357                    let sessions = service.get_sessions();
358                    tokio::spawn(async move {
359                        let mut buffer = ExchangeBuffer::default();
360
361                        'a: while let Ok(size) = reader.read(&mut buffer).await {
362                            // When the received message is 0, it means that the socket
363                            // has been closed.
364                            if size == 0 {
365                                break;
366                            } else {
367                                reporter.send(&session_addr, &[Stats::ReceivedBytes(size as u32)]);
368                                buffer.advance(size);
369                            }
370
371                            // The minimum length of a stun message will not be less
372                            // than 4.
373                            if buffer.len() < 4 {
374                                continue;
375                            }
376
377                            loop {
378                                if buffer.len() <= 4 {
379                                    break;
380                                }
381
382                                // Try to get the message length, if the currently
383                                // received data is less than the message length, jump
384                                // out of the current loop and continue to receive more
385                                // data.
386                                let size = match Decoder::message_size(&buffer, true) {
387                                    Err(_) => break,
388                                    Ok(s) => {
389                                        // Limit the maximum length of messages to 2048, this is to prevent buffer
390                                        // overflow attacks.
391                                        if s > 2048 {
392                                            break 'a;
393                                        }
394
395                                        if s > buffer.len() {
396                                            break;
397                                        }
398
399                                        reporter.send(&session_addr, &[Stats::ReceivedPkts(1)]);
400
401                                        s
402                                    }
403                                };
404
405                                let chunk = buffer.split(size);
406                                if let Ok(ret) = operationer.route(chunk, address).await {
407                                    if let Some(res) = ret {
408                                        if let Some(ref inerface) = res.endpoint {
409                                            router.send(
410                                                inerface,
411                                                res.method,
412                                                res.relay.as_ref().unwrap_or(&address),
413                                                res.bytes,
414                                            );
415                                        } else {
416                                            if writer.lock().await.write_all(res.bytes).await.is_err() {
417                                                break 'a;
418                                            }
419
420                                            reporter.send(
421                                                &session_addr,
422                                                &[Stats::SendBytes(res.bytes.len() as u32), Stats::SendPkts(1)],
423                                            );
424
425                                            if let ResponseMethod::Stun(method) = res.method {
426                                                if method.is_error() {
427                                                    reporter.send(&session_addr, &[Stats::ErrorPkts(1)]);
428                                                }
429                                            }
430                                        }
431                                    }
432                                } else {
433                                    break 'a;
434                                }
435                            }
436                        }
437
438                        // When the tcp connection is closed, the procedure to close the session is
439                        // process directly once, avoiding the connection being disconnected
440                        // directly without going through the closing
441                        // process.
442                        sessions.refresh(&session_addr, 0);
443
444                        router.remove(&address);
445
446                        log::info!("tcp socket disconnect: addr={:?}, interface={:?}", address, local_addr);
447                    });
448                }
449
450                log::error!("tcp server close: interface={:?}", local_addr);
451            });
452
453            log::info!(
454                "turn server listening: bind={}, external={}, transport=TCP",
455                bind,
456                external,
457            );
458
459            Ok(())
460        }
461    }
462}
463
464/// start turn server.
465///
466/// create a specified number of threads,
467/// each thread processes udp data separately.
468pub async fn start<T>(config: &Config, statistics: &Statistics, service: &Service<T>) -> anyhow::Result<()>
469where
470    T: Clone + Observer + 'static,
471{
472    #[allow(unused)]
473    use crate::config::Transport;
474
475    let router = Router::default();
476    for Interface {
477        transport,
478        external,
479        bind,
480    } in config.turn.interfaces.iter().cloned()
481    {
482        #[allow(unused)]
483        let options = ServerStartOptions {
484            statistics: statistics.clone(),
485            service: service.clone(),
486            router: router.clone(),
487            external,
488            bind,
489        };
490
491        match transport {
492            #[cfg(feature = "udp")]
493            Transport::UDP => udp::Server::start(options).await?,
494            #[cfg(feature = "tcp")]
495            Transport::TCP => tcp::Server::start(options).await?,
496            #[allow(unreachable_patterns)]
497            _ => (),
498        };
499    }
500
501    Ok(())
502}