Skip to main content

turn_server_rustls/
lib.rs

1// Copyright (C) 2025 Matthew Waters <matthew@centricular.com>
2//
3// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
4// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
5// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
6// option. This file may not be copied, modified, or distributed
7// except according to those terms.
8//
9// SPDX-License-Identifier: MIT OR Apache-2.0
10
//! # turn-server-rustls
//!
//! A TURN server that can handle TLS client connections using `rustls`.
//!
//! `turn-server-rustls` provides a sans-IO API for a TURN server communicating with many TURN clients.
16//!
17//! Relevant standards:
18//! - [RFC5766]: Traversal Using Relays around NAT (TURN).
19//! - [RFC6062]: Traversal Using Relays around NAT (TURN) Extensions for TCP Allocations
20//! - [RFC6156]: Traversal Using Relays around NAT (TURN) Extension for IPv6
21//! - [RFC8656]: Traversal Using Relays around NAT (TURN): Relay Extensions to Session
22//!   Traversal Utilities for NAT (STUN)
23//!
24//! [RFC5766]: https://datatracker.ietf.org/doc/html/rfc5766
25//! [RFC6062]: https://tools.ietf.org/html/rfc6062
26//! [RFC6156]: https://tools.ietf.org/html/rfc6156
27//! [RFC8656]: https://tools.ietf.org/html/rfc8656
28
29#![deny(missing_debug_implementations)]
30#![deny(missing_docs)]
31#![cfg_attr(docsrs, feature(doc_cfg))]
32#![deny(clippy::std_instead_of_core)]
33#![deny(clippy::std_instead_of_alloc)]
34#![no_std]
35
36extern crate alloc;
37
38#[cfg(any(feature = "std", test))]
39extern crate std;
40
41use alloc::string::String;
42use alloc::sync::Arc;
43use alloc::vec;
44use alloc::vec::Vec;
45use core::net::SocketAddr;
46use core::time::Duration;
47use std::io::{Read, Write};
48
49use turn_server_proto::types::prelude::DelayedTransmitBuild;
50use turn_server_proto::types::transmit::TransmitBuild;
51use turn_server_proto::types::AddressFamily;
52
53use turn_server_proto::api::Transmit;
54use turn_server_proto::server::TurnServer;
55use turn_server_proto::types::stun::TransportType;
56use turn_server_proto::types::Instant;
57
58pub use turn_server_proto as proto;
59pub use turn_server_proto::api;
60
61use turn_server_proto::api::{
62    DelayedMessageOrChannelSend, SocketAllocateError, TurnServerApi, TurnServerPollRet,
63};
64
65use tracing::{info, trace, warn};
66
67use rustls::{ServerConfig, ServerConnection};
68
69/// A TURN server that can handle TLS connections.
70#[derive(Debug)]
71pub struct RustlsTurnServer {
72    server: TurnServer,
73    config: Arc<ServerConfig>,
74    clients: Vec<Client>,
75}
76
77#[derive(Debug)]
78struct Client {
79    client_addr: SocketAddr,
80    tls: ServerConnection,
81    local_closed: bool,
82    peer_closed: bool,
83}
84
85impl RustlsTurnServer {
86    /// Construct a now Turn server that can handle TLS connections.
87    pub fn new(listen_addr: SocketAddr, realm: String, config: Arc<ServerConfig>) -> Self {
88        Self {
89            server: TurnServer::new(TransportType::Tcp, listen_addr, realm),
90            config,
91            clients: vec![],
92        }
93    }
94}
95
96impl TurnServerApi for RustlsTurnServer {
97    /// Add a user credentials that would be accepted by this [`TurnServer`].
98    fn add_user(&mut self, username: String, password: String) {
99        self.server.add_user(username, password)
100    }
101
102    /// The address that the [`TurnServer`] is listening on for incoming client connections.
103    fn listen_address(&self) -> SocketAddr {
104        self.server.listen_address()
105    }
106
107    /// Set the amount of time that a Nonce (used for authentication) will expire and a new Nonce
108    /// will need to be acquired by a client.
109    fn set_nonce_expiry_duration(&mut self, expiry_duration: Duration) {
110        self.server.set_nonce_expiry_duration(expiry_duration)
111    }
112
113    /// Provide received data to the [`TurnServer`].
114    ///
115    /// Any returned Transmit should be forwarded to the appropriate socket.
116    #[tracing::instrument(
117        name = "turn_server_rustls_recv",
118        skip(self, transmit, now),
119        fields(
120            from = ?transmit.from,
121            data_len = transmit.data.as_ref().len()
122        )
123    )]
124    fn recv<T: AsRef<[u8]> + core::fmt::Debug>(
125        &mut self,
126        transmit: Transmit<T>,
127        now: Instant,
128    ) -> Option<TransmitBuild<DelayedMessageOrChannelSend<T>>> {
129        let listen_address = self.listen_address();
130        if transmit.transport == TransportType::Tcp && transmit.to == listen_address {
131            trace!("receiving TLS data: {:x?}", transmit.data.as_ref());
132            // incoming client
133            let client = match self
134                .clients
135                .iter_mut()
136                .find(|client| client.client_addr == transmit.from)
137            {
138                Some(client) => client,
139                None => {
140                    if transmit.data.as_ref().is_empty() {
141                        return None;
142                    }
143                    let len = self.clients.len();
144                    self.clients.push(Client {
145                        client_addr: transmit.from,
146                        tls: ServerConnection::new(self.config.clone()).unwrap(),
147                        local_closed: false,
148                        peer_closed: false,
149                    });
150                    info!("new connection from {}", transmit.from);
151                    &mut self.clients[len]
152                }
153            };
154            let mut input = std::io::Cursor::new(transmit.data.as_ref());
155            let io_state = match client.tls.read_tls(&mut input) {
156                Ok(_written) => match client.tls.process_new_packets() {
157                    Ok(io_state) => io_state,
158                    Err(e) => {
159                        warn!("Error processing incoming TLS: {e:?}");
160                        return None;
161                    }
162                },
163                Err(e) => {
164                    warn!("Error receiving data: {e:?}");
165                    return None;
166                }
167            };
168            if io_state.peer_has_closed() {
169                client.peer_closed = true;
170                if !client.local_closed {
171                    client.tls.send_close_notify();
172                    client.local_closed = true;
173                    let mut out = vec![];
174                    client.tls.write_tls(&mut out).unwrap();
175                    let client_addr = client.client_addr;
176                    info!("client {client_addr} TLS closed");
177                    return Some(TransmitBuild::new(
178                        DelayedMessageOrChannelSend::Owned(out),
179                        TransportType::Tcp,
180                        listen_address,
181                        client_addr,
182                    ));
183                } else {
184                    return None;
185                }
186            }
187            if io_state.plaintext_bytes_to_read() == 0 {
188                return None;
189            }
190            let mut vec = vec![0; 2048];
191            let n = match client.tls.reader().read(&mut vec) {
192                Ok(n) => n,
193                Err(e) => {
194                    if e.kind() == std::io::ErrorKind::WouldBlock {
195                        return None;
196                    } else {
197                        warn!("TLS error: {e:?}");
198                        return None;
199                    }
200                }
201            };
202            trace!("io_state: {io_state:?}, n: {n}");
203            vec.resize(n, 0);
204            let transmit = self.server.recv(
205                Transmit::new(vec, transmit.transport, transmit.from, transmit.to),
206                now,
207            )?;
208            if transmit.transport == TransportType::Tcp
209                && transmit.from == listen_address
210                && transmit.to == client.client_addr
211            {
212                let plaintext = transmit.data.build();
213                client.tls.writer().write_all(&plaintext).unwrap();
214                let mut out = vec![];
215                client.tls.write_tls(&mut out).unwrap();
216                Some(TransmitBuild::new(
217                    DelayedMessageOrChannelSend::Owned(out),
218                    TransportType::Tcp,
219                    listen_address,
220                    client.client_addr,
221                ))
222            } else {
223                let transmit = transmit.build();
224                Some(TransmitBuild::new(
225                    DelayedMessageOrChannelSend::Owned(transmit.data),
226                    transmit.transport,
227                    transmit.from,
228                    transmit.to,
229                ))
230            }
231        } else if let Some(transmit) = self.server.recv(transmit, now) {
232            // incoming allocated address
233            if transmit.transport == TransportType::Tcp && transmit.from == listen_address {
234                let Some(client) = self
235                    .clients
236                    .iter_mut()
237                    .find(|client| transmit.to == client.client_addr)
238                else {
239                    return Some(transmit);
240                };
241                let plaintext = transmit.data.build();
242                client.tls.writer().write_all(&plaintext).unwrap();
243                let mut out = vec![];
244                client.tls.write_tls(&mut out).unwrap();
245                Some(TransmitBuild::new(
246                    DelayedMessageOrChannelSend::Owned(out),
247                    TransportType::Tcp,
248                    listen_address,
249                    client.client_addr,
250                ))
251            } else {
252                Some(transmit)
253            }
254        } else {
255            None
256        }
257    }
258
259    fn recv_icmp<T: AsRef<[u8]>>(
260        &mut self,
261        family: AddressFamily,
262        bytes: T,
263        now: Instant,
264    ) -> Option<Transmit<Vec<u8>>> {
265        let transmit = self.server.recv_icmp(family, bytes, now)?;
266        // incoming allocated address
267        let listen_address = self.listen_address();
268        if transmit.transport == TransportType::Tcp && transmit.from == listen_address {
269            let Some(client) = self
270                .clients
271                .iter_mut()
272                .find(|client| transmit.to == client.client_addr)
273            else {
274                return Some(transmit);
275            };
276            client.tls.writer().write_all(&transmit.data).unwrap();
277            let mut out = vec![];
278            client.tls.write_tls(&mut out).unwrap();
279            Some(Transmit::new(
280                out,
281                TransportType::Tcp,
282                listen_address,
283                client.client_addr,
284            ))
285        } else {
286            Some(transmit)
287        }
288    }
289
290    /// Poll the [`TurnServer`] in order to make further progress.
291    ///
292    /// The returned value indicates what the caller should do.
293    fn poll(&mut self, now: Instant) -> TurnServerPollRet {
294        let protocol_ret = self.server.poll(now);
295        let mut have_pending = false;
296        for (idx, client) in self.clients.iter_mut().enumerate() {
297            trace!("client: {client:?}");
298            let io_state = match client.tls.process_new_packets() {
299                Ok(io_state) => io_state,
300                Err(e) => {
301                    warn!("Error processing TLS: {e:?}");
302                    continue;
303                }
304            };
305            trace!("{io_state:?}");
306            if io_state.tls_bytes_to_write() > 0 {
307                have_pending = true;
308                continue;
309            } else if !client.peer_closed && io_state.peer_has_closed() {
310                client.peer_closed = true;
311                if !client.local_closed {
312                    client.tls.send_close_notify();
313                    client.local_closed = true;
314                    have_pending = true;
315                    continue;
316                }
317            }
318            if client.local_closed && client.peer_closed && !client.tls.wants_write() {
319                let client = self.clients.remove(idx);
320                return TurnServerPollRet::TcpClose {
321                    local_addr: self.server.listen_address(),
322                    remote_addr: client.client_addr,
323                };
324            }
325        }
326        if let TurnServerPollRet::TcpClose {
327            local_addr,
328            remote_addr,
329        } = protocol_ret
330        {
331            let Some(client) = self
332                .clients
333                .iter_mut()
334                .find(|client| client.client_addr == remote_addr)
335            else {
336                return TurnServerPollRet::TcpClose {
337                    local_addr,
338                    remote_addr,
339                };
340            };
341            client.tls.send_close_notify();
342            client.local_closed = true;
343            return TurnServerPollRet::WaitUntil(now);
344        }
345        if have_pending {
346            return TurnServerPollRet::WaitUntil(now);
347        }
348        protocol_ret
349    }
350
351    /// Poll for a new Transmit to send over a socket.
352    fn poll_transmit(&mut self, now: Instant) -> Option<Transmit<Vec<u8>>> {
353        let listen_address = self.listen_address();
354
355        while let Some(transmit) = self.server.poll_transmit(now) {
356            if let Some(client) = self
357                .clients
358                .iter_mut()
359                .find(|client| transmit.to == client.client_addr)
360            {
361                if transmit.data.is_empty() {
362                    if !client.local_closed {
363                        warn!("client {} closed", client.client_addr);
364                        client.tls.send_close_notify();
365                        client.local_closed = true;
366                    }
367                } else {
368                    client.tls.writer().write_all(&transmit.data).unwrap();
369                }
370            } else {
371                warn!("return transmit: {transmit:?}");
372                return Some(transmit);
373            };
374        }
375
376        for client in self.clients.iter_mut() {
377            trace!("client: {client:?}");
378            let client_addr = client.client_addr;
379            if !client.tls.wants_write() {
380                continue;
381            }
382            let mut vec = vec![];
383            let n = match client.tls.write_tls(&mut vec) {
384                Ok(n) => n,
385                Err(e) => {
386                    warn!("error writing TLS: {e:?}");
387                    continue;
388                }
389            };
390            vec.resize(n, 0);
391            warn!("return transmit: {vec:x?}");
392            return Some(Transmit::new(
393                vec,
394                TransportType::Tcp,
395                listen_address,
396                client_addr,
397            ));
398        }
399        None
400    }
401
402    /// Notify the [`TurnServer`] that a UDP socket has been allocated (or an error) in response to
403    /// [TurnServerPollRet::AllocateSocket].
404    fn allocated_socket(
405        &mut self,
406        transport: TransportType,
407        local_addr: SocketAddr,
408        remote_addr: SocketAddr,
409        allocation_transport: TransportType,
410        family: AddressFamily,
411        socket_addr: Result<SocketAddr, SocketAllocateError>,
412        now: Instant,
413    ) {
414        self.server.allocated_socket(
415            transport,
416            local_addr,
417            remote_addr,
418            allocation_transport,
419            family,
420            socket_addr,
421            now,
422        )
423    }
424
425    fn tcp_connected(
426        &mut self,
427        relayed_addr: SocketAddr,
428        peer_addr: SocketAddr,
429        listen_addr: SocketAddr,
430        client_addr: SocketAddr,
431        socket_addr: Result<SocketAddr, crate::api::TcpConnectError>,
432        now: Instant,
433    ) {
434        self.server.tcp_connected(
435            relayed_addr,
436            peer_addr,
437            listen_addr,
438            client_addr,
439            socket_addr,
440            now,
441        )
442    }
443}