Skip to main content

turn_server_openssl/
lib.rs

1// Copyright (C) 2025 Matthew Waters <matthew@centricular.com>
2//
3// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
4// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
5// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
6// option. This file may not be copied, modified, or distributed
7// except according to those terms.
8//
9// SPDX-License-Identifier: MIT OR Apache-2.0
10
//! # turn-server-openssl
12//!
13//! A TURN server that can handle TLS client connections.
14//!
15//! Relevant standards:
16//! - [RFC5766]: Traversal Using Relays around NAT (TURN).
17//! - [RFC6062]: Traversal Using Relays around NAT (TURN) Extensions for TCP Allocations
18//! - [RFC6156]: Traversal Using Relays around NAT (TURN) Extension for IPv6
19//! - [RFC8656]: Traversal Using Relays around NAT (TURN): Relay Extensions to Session
20//!   Traversal Utilities for NAT (STUN)
21//!
22//! [RFC5766]: https://datatracker.ietf.org/doc/html/rfc5766
23//! [RFC6062]: https://tools.ietf.org/html/rfc6062
24//! [RFC6156]: https://tools.ietf.org/html/rfc6156
25//! [RFC8656]: https://tools.ietf.org/html/rfc8656
26
27#![deny(missing_debug_implementations)]
28#![deny(missing_docs)]
29#![cfg_attr(docsrs, feature(doc_cfg))]
30#![deny(clippy::std_instead_of_core)]
31#![deny(clippy::std_instead_of_alloc)]
32#![no_std]
33
34extern crate alloc;
35
36#[cfg(any(feature = "std", test))]
37extern crate std;
38
39use alloc::collections::VecDeque;
40use alloc::string::String;
41use alloc::vec;
42use alloc::vec::Vec;
43use core::net::SocketAddr;
44use core::time::Duration;
45
46use std::io::{Read, Write};
47
48use turn_server_proto::types::prelude::DelayedTransmitBuild;
49use turn_server_proto::types::transmit::TransmitBuild;
50use turn_server_proto::types::AddressFamily;
51
52use turn_server_proto::api::Transmit;
53use turn_server_proto::types::stun::TransportType;
54use turn_server_proto::types::Instant;
55
56pub use turn_server_proto as proto;
57pub use turn_server_proto::api;
58
59use turn_server_proto::api::{
60    DelayedMessageOrChannelSend, SocketAllocateError, TurnServerApi, TurnServerPollRet,
61};
62use turn_server_proto::server::TurnServer;
63
64use tracing::{info, trace, warn};
65
66use openssl::ssl::{
67    HandshakeError, MidHandshakeSslStream, ShutdownState, Ssl, SslContext, SslStream,
68};
69
/// A TURN server that can handle TLS connections.
///
/// Client traffic arriving on the listening address is decrypted with OpenSSL before being
/// handed to an inner [`TurnServer`], and replies to those clients are encrypted on the way out.
#[derive(Debug)]
pub struct OpensslTurnServer {
    // The inner TURN protocol implementation; it receives decrypted client traffic.
    server: TurnServer,
    // Context from which a new `Ssl` session is created for every new client connection.
    ssl_context: SslContext,
    // One entry per TLS client connection, looked up by the client's remote address.
    clients: Vec<Client>,
}
77
/// Per-connection state for a TLS client of the TURN server.
#[derive(Debug)]
struct Client {
    // Transport the client connected with; reused when sending TLS output back to it.
    transport: TransportType,
    // The client's remote address; used to match incoming data and replies to this connection.
    client_addr: SocketAddr,
    // The TLS session for this connection, from the initial accept to an established stream.
    tls: HandshakeState,
    // Last observed TLS shutdown state. Once both SENT and RECEIVED are set (and all output
    // has been flushed) `poll()` reports the TCP connection as closed.
    shutdown: ShutdownState,
}
85
/// Progress of the server-side TLS handshake for one client connection.
#[derive(Debug)]
enum HandshakeState {
    /// No handshake attempted yet: a fresh `Ssl` session and its in-memory BIO.
    Init(Ssl, OsslBio),
    /// Handshake in progress and waiting for more client data; a failed handshake is also
    /// kept in this state by `complete()`.
    Handshaking(MidHandshakeSslStream<OsslBio>),
    /// Handshake finished; application data can be exchanged.
    Done(SslStream<OsslBio>),
    /// Placeholder used while moving between states; also left behind if TLS setup fails.
    Nothing,
}
93
94impl HandshakeState {
95    fn complete(&mut self) -> Result<&mut SslStream<OsslBio>, std::io::Error> {
96        if let Self::Done(s) = self {
97            return Ok(s);
98        }
99        let taken = core::mem::replace(self, Self::Nothing);
100
101        let ret = match taken {
102            Self::Init(ssl, bio) => ssl.accept(bio),
103            Self::Handshaking(mid) => mid.handshake(),
104            Self::Done(_) | Self::Nothing => unreachable!(),
105        };
106
107        match ret {
108            Ok(s) => {
109                info!(
110                    "SSL handshake completed with version {} cipher: {:?}",
111                    s.ssl().version_str(),
112                    s.ssl().current_cipher()
113                );
114                *self = Self::Done(s);
115                Ok(self.complete()?)
116            }
117            Err(HandshakeError::WouldBlock(mid)) => {
118                *self = Self::Handshaking(mid);
119                Err(std::io::Error::new(
120                    std::io::ErrorKind::WouldBlock,
121                    "Would Block",
122                ))
123            }
124            Err(HandshakeError::SetupFailure(e)) => {
125                warn!("Error during ssl setup: {e}");
126                Err(std::io::Error::new(
127                    std::io::ErrorKind::ConnectionRefused,
128                    e,
129                ))
130            }
131            Err(HandshakeError::Failure(mid)) => {
132                warn!("Failure during ssl setup: {}", mid.error());
133                *self = Self::Handshaking(mid);
134                Err(std::io::Error::new(
135                    std::io::ErrorKind::WouldBlock,
136                    "Would Block",
137                ))
138            }
139        }
140    }
141    fn inner_mut(&mut self) -> &mut OsslBio {
142        match self {
143            Self::Init(_ssl, stream) => stream,
144            Self::Handshaking(mid) => mid.get_mut(),
145            Self::Done(stream) => stream.get_mut(),
146            Self::Nothing => unreachable!(),
147        }
148    }
149}
150
/// An in-memory transport placed underneath an OpenSSL stream.
///
/// Ciphertext received from the network is appended with `push_incoming()` and consumed by
/// OpenSSL through [`std::io::Read`]. Every `write()` made by OpenSSL is kept as a separate
/// chunk, and the chunks are handed out oldest-first by `pop_outgoing()`.
#[derive(Debug, Default)]
struct OsslBio {
    // Received bytes that OpenSSL has not consumed yet.
    incoming: Vec<u8>,
    // Output produced by OpenSSL, one entry per `write()` call.
    outgoing: VecDeque<Vec<u8>>,
}

impl OsslBio {
    /// Queue received network bytes for OpenSSL to read.
    fn push_incoming(&mut self, data: &[u8]) {
        self.incoming.extend_from_slice(data);
    }

    /// Take the oldest chunk of output written by OpenSSL, if any.
    fn pop_outgoing(&mut self) -> Option<Vec<u8>> {
        self.outgoing.pop_front()
    }
}

impl std::io::Write for OsslBio {
    /// Store `data` as one outgoing chunk; every byte is always accepted.
    fn write(&mut self, data: &[u8]) -> std::io::Result<usize> {
        let accepted = data.len();
        self.outgoing.push_back(Vec::from(data));
        Ok(accepted)
    }

    /// Nothing is buffered outside `outgoing`, so there is nothing to flush.
    fn flush(&mut self) -> std::io::Result<()> {
        Ok(())
    }
}

impl std::io::Read for OsslBio {
    /// Move up to `dst.len()` queued bytes, oldest first, into `dst`.
    ///
    /// When nothing is queued a `WouldBlock` error is returned, signalling that more input
    /// must arrive before reading can continue.
    fn read(&mut self, dst: &mut [u8]) -> std::io::Result<usize> {
        if self.incoming.is_empty() {
            return Err(std::io::Error::new(
                std::io::ErrorKind::WouldBlock,
                "Would Block",
            ));
        }

        let count = core::cmp::min(dst.len(), self.incoming.len());
        for (slot, byte) in dst.iter_mut().zip(self.incoming.drain(..count)) {
            *slot = byte;
        }
        Ok(count)
    }
}
200
201impl OpensslTurnServer {
202    /// Construct a now Turn server that can handle TLS connections.
203    pub fn new(
204        transport: TransportType,
205        listen_addr: SocketAddr,
206        realm: String,
207        ssl_context: SslContext,
208    ) -> Self {
209        Self {
210            server: TurnServer::new(transport, listen_addr, realm),
211            ssl_context,
212            clients: vec![],
213        }
214    }
215}
216
217impl TurnServerApi for OpensslTurnServer {
218    /// Add a user credentials that would be accepted by this [`TurnServer`].
219    fn add_user(&mut self, username: String, password: String) {
220        self.server.add_user(username, password)
221    }
222
223    /// The address that the [`TurnServer`] is listening on for incoming client connections.
224    fn listen_address(&self) -> SocketAddr {
225        self.server.listen_address()
226    }
227
228    /// Set the amount of time that a Nonce (used for authentication) will expire and a new Nonce
229    /// will need to be acquired by a client.
230    fn set_nonce_expiry_duration(&mut self, expiry_duration: Duration) {
231        self.server.set_nonce_expiry_duration(expiry_duration)
232    }
233
234    /// Provide received data to the [`TurnServer`].
235    ///
236    /// Any returned Transmit should be forwarded to the appropriate socket.
237    #[tracing::instrument(
238        name = "turn_server_openssl_recv",
239        skip(self, transmit, now),
240        fields(
241            from = ?transmit.from,
242            data_len = transmit.data.as_ref().len()
243        )
244    )]
245    fn recv<T: AsRef<[u8]> + core::fmt::Debug>(
246        &mut self,
247        transmit: Transmit<T>,
248        now: Instant,
249    ) -> Option<TransmitBuild<DelayedMessageOrChannelSend<T>>> {
250        let listen_address = self.listen_address();
251        if transmit.to == listen_address {
252            trace!("receiving TLS data: {:x?}", transmit.data.as_ref());
253            // incoming client
254            let client = match self
255                .clients
256                .iter_mut()
257                .find(|client| client.client_addr == transmit.from)
258            {
259                Some(client) => client,
260                None => {
261                    let len = self.clients.len();
262                    let ssl = Ssl::new(&self.ssl_context).expect("Cannot create ssl structure");
263                    self.clients.push(Client {
264                        transport: transmit.transport,
265                        client_addr: transmit.from,
266                        tls: HandshakeState::Init(ssl, OsslBio::default()),
267                        shutdown: ShutdownState::empty(),
268                    });
269                    info!(
270                        "new connection from {} {}",
271                        transmit.transport, transmit.from
272                    );
273                    &mut self.clients[len]
274                }
275            };
276            client.tls.inner_mut().push_incoming(transmit.data.as_ref());
277            let stream = match client.tls.complete() {
278                Ok(s) => s,
279                Err(e) => {
280                    if e.kind() != std::io::ErrorKind::WouldBlock {
281                        warn!("error accepting TLS: {e}");
282                    }
283                    return None;
284                }
285            };
286
287            let mut plaintext = vec![0; 2048];
288            let len = match stream.read(&mut plaintext) {
289                Ok(len) => len,
290                Err(e) => {
291                    if e.kind() != std::io::ErrorKind::WouldBlock {
292                        warn!("Error: {e}");
293                    }
294                    return None;
295                }
296            };
297            warn!("received: {len} plaintext bytes");
298            if len == 0 {
299                let pre_shutdown = stream.get_shutdown();
300                let _ = stream.shutdown();
301                client.shutdown = stream.get_shutdown();
302                if !pre_shutdown.contains(ShutdownState::SENT) {
303                    return stream.get_mut().pop_outgoing().map(|data| {
304                        TransmitBuild::new(
305                            DelayedMessageOrChannelSend::Owned(data),
306                            transmit.transport,
307                            listen_address,
308                            client.client_addr,
309                        )
310                    });
311                } else {
312                    return None;
313                }
314            }
315            plaintext.resize(len, 0);
316
317            let transmit = self.server.recv(
318                Transmit::new(plaintext, transmit.transport, transmit.from, transmit.to),
319                now,
320            )?;
321
322            if transmit.from == listen_address && transmit.to == client.client_addr {
323                let plaintext = transmit.data.build();
324                stream.write_all(&plaintext).unwrap();
325                stream.get_mut().pop_outgoing().map(|data| {
326                    TransmitBuild::new(
327                        DelayedMessageOrChannelSend::Owned(data),
328                        transmit.transport,
329                        listen_address,
330                        client.client_addr,
331                    )
332                })
333            } else {
334                let transmit = transmit.build();
335                Some(TransmitBuild::new(
336                    DelayedMessageOrChannelSend::Owned(transmit.data),
337                    transmit.transport,
338                    transmit.from,
339                    transmit.to,
340                ))
341            }
342        } else if let Some(transmit) = self.server.recv(transmit, now) {
343            // incoming allocated address
344            if transmit.from == listen_address {
345                let Some(client) = self
346                    .clients
347                    .iter_mut()
348                    .find(|client| transmit.to == client.client_addr)
349                else {
350                    return Some(transmit);
351                };
352
353                let plaintext = transmit.data.build();
354                let stream = match client.tls.complete() {
355                    Ok(s) => s,
356                    Err(e) => {
357                        if e.kind() != std::io::ErrorKind::WouldBlock {
358                            warn!("error accepting TLS: {e}");
359                        }
360                        return None;
361                    }
362                };
363                stream.write_all(&plaintext).unwrap();
364                stream.get_mut().pop_outgoing().map(|data| {
365                    TransmitBuild::new(
366                        DelayedMessageOrChannelSend::Owned(data),
367                        transmit.transport,
368                        listen_address,
369                        client.client_addr,
370                    )
371                })
372            } else {
373                Some(transmit)
374            }
375        } else {
376            None
377        }
378    }
379
380    fn recv_icmp<T: AsRef<[u8]>>(
381        &mut self,
382        family: AddressFamily,
383        bytes: T,
384        now: Instant,
385    ) -> Option<Transmit<Vec<u8>>> {
386        let transmit = self.server.recv_icmp(family, bytes, now)?;
387        // incoming allocated address
388        let listen_address = self.listen_address();
389        if transmit.from == listen_address {
390            let Some(client) = self
391                .clients
392                .iter_mut()
393                .find(|client| transmit.to == client.client_addr)
394            else {
395                return Some(transmit);
396            };
397            let stream = match client.tls.complete() {
398                Ok(s) => s,
399                Err(e) => {
400                    if e.kind() != std::io::ErrorKind::WouldBlock {
401                        warn!("error accepting TLS: {e}");
402                    }
403                    return None;
404                }
405            };
406            stream.write_all(&transmit.data).unwrap();
407            stream.get_mut().pop_outgoing().map(|data| {
408                Transmit::new(data, transmit.transport, listen_address, client.client_addr)
409            })
410        } else {
411            Some(transmit)
412        }
413    }
414
415    /// Poll the [`TurnServer`] in order to make further progress.
416    ///
417    /// The returned value indicates what the caller should do.
418    fn poll(&mut self, now: Instant) -> TurnServerPollRet {
419        let listen_address = self.listen_address();
420        let protocol_ret = self.server.poll(now);
421        let mut have_pending = false;
422        for (idx, client) in self.clients.iter_mut().enumerate() {
423            let stream = match client.tls.complete() {
424                Ok(s) => s,
425                Err(_) => continue,
426            };
427            client.shutdown = stream.get_shutdown();
428            if !stream.get_mut().outgoing.is_empty() {
429                have_pending = true;
430                continue;
431            }
432            if client
433                .shutdown
434                .contains(ShutdownState::SENT | ShutdownState::RECEIVED)
435            {
436                let client = self.clients.swap_remove(idx);
437                return TurnServerPollRet::TcpClose {
438                    local_addr: listen_address,
439                    remote_addr: client.client_addr,
440                };
441            }
442        }
443        if have_pending {
444            return TurnServerPollRet::WaitUntil(now);
445        }
446        if let TurnServerPollRet::TcpClose {
447            local_addr: _,
448            remote_addr,
449        } = protocol_ret
450        {
451            let Some(client) = self
452                .clients
453                .iter_mut()
454                .find(|client| client.client_addr == remote_addr)
455            else {
456                return protocol_ret;
457            };
458            if let Ok(stream) = client.tls.complete() {
459                if let Err(e) = stream.shutdown() {
460                    warn!("Failed to shutdown ssl connection to {remote_addr}: {e:?}");
461                }
462                client.shutdown = stream.get_shutdown();
463            }
464            return TurnServerPollRet::WaitUntil(now);
465        }
466        protocol_ret
467    }
468
469    /// Poll for a new Transmit to send over a socket.
470    fn poll_transmit(&mut self, now: Instant) -> Option<Transmit<Vec<u8>>> {
471        let listen_address = self.listen_address();
472
473        for client in self.clients.iter_mut() {
474            if let Some(data) = client.tls.inner_mut().pop_outgoing() {
475                return Some(Transmit::new(
476                    data,
477                    client.transport,
478                    listen_address,
479                    client.client_addr,
480                ));
481            }
482        }
483
484        while let Some(transmit) = self.server.poll_transmit(now) {
485            let Some(client) = self
486                .clients
487                .iter_mut()
488                .find(|client| transmit.to == client.client_addr)
489            else {
490                warn!("return transmit: {transmit:?}");
491                return Some(transmit);
492            };
493            let stream = match client.tls.complete() {
494                Ok(s) => s,
495                // FIXME: how to deal with early data
496                Err(e) => {
497                    warn!("early data -> ignored: {e:?}");
498                    continue;
499                }
500            };
501            stream.write_all(&transmit.data).unwrap();
502
503            if let Some(data) = client.tls.inner_mut().pop_outgoing() {
504                return Some(Transmit::new(
505                    data,
506                    client.transport,
507                    listen_address,
508                    client.client_addr,
509                ));
510            }
511        }
512        None
513    }
514
515    /// Notify the [`TurnServer`] that a socket has been allocated (or an error) in response to
516    /// [TurnServerPollRet::AllocateSocket].
517    fn allocated_socket(
518        &mut self,
519        transport: TransportType,
520        local_addr: SocketAddr,
521        remote_addr: SocketAddr,
522        allocation_transport: TransportType,
523        family: AddressFamily,
524        socket_addr: Result<SocketAddr, SocketAllocateError>,
525        now: Instant,
526    ) {
527        self.server.allocated_socket(
528            transport,
529            local_addr,
530            remote_addr,
531            allocation_transport,
532            family,
533            socket_addr,
534            now,
535        )
536    }
537
538    fn tcp_connected(
539        &mut self,
540        relayed_addr: SocketAddr,
541        peer_addr: SocketAddr,
542        listen_addr: SocketAddr,
543        client_addr: SocketAddr,
544        socket_addr: Result<SocketAddr, crate::api::TcpConnectError>,
545        now: Instant,
546    ) {
547        self.server.tcp_connected(
548            relayed_addr,
549            peer_addr,
550            listen_addr,
551            client_addr,
552            socket_addr,
553            now,
554        )
555    }
556}