Skip to main content

turn_client_openssl/
lib.rs

1// Copyright (C) 2025 Matthew Waters <matthew@centricular.com>
2//
3// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
4// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
5// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
6// option. This file may not be copied, modified, or distributed
7// except according to those terms.
8//
9// SPDX-License-Identifier: MIT OR Apache-2.0
10
//! # turn-client-openssl
12//!
13//! TLS TURN client using OpenSSL.
14//!
15//! An implementation of a TURN client suitable for TLS over TCP connections and DTLS over UDP
16//! connections.
17
18#![deny(missing_debug_implementations)]
19#![deny(missing_docs)]
20#![cfg_attr(docsrs, feature(doc_cfg))]
21#![deny(clippy::std_instead_of_core)]
22#![deny(clippy::std_instead_of_alloc)]
23#![no_std]
24
25extern crate alloc;
26
27pub use openssl;
28
29#[cfg(any(feature = "std", test))]
30extern crate std;
31
32pub use turn_client_proto::api;
33
34use std::io::{Read, Write};
35
36use alloc::collections::VecDeque;
37use alloc::vec;
38use alloc::vec::Vec;
39
40use core::net::{IpAddr, SocketAddr};
41use core::time::Duration;
42
43use turn_client_proto::types::Instant;
44use turn_client_proto::types::TransportType;
45
46use tracing::{info, trace, warn};
47
48use turn_client_proto::api::*;
49use turn_client_proto::tcp::TurnClientTcp;
50use turn_client_proto::udp::TurnClientUdp;
51
52use openssl::ssl::{
53    HandshakeError, MidHandshakeSslStream, ShutdownResult, ShutdownState, Ssl, SslContext,
54    SslStream,
55};
56
// Generates `TcpOrUdp`, the type stored in `TurnClientOpensslTls::protocol`: a wrapper with
// `Udp(TurnClientUdp)` and `Tcp(TurnClientTcp)` variants through which the TURN-level calls made
// on `self.protocol` (poll, send_to, recv, ...) are dispatched.
// NOTE(review): the exact generated items are defined by `turn_client_proto::impl_client!`;
// confirm against that macro's documentation.
turn_client_proto::impl_client!(TcpOrUdp, (Udp, TurnClientUdp), (Tcp, TurnClientTcp));
58
/// A TURN client that communicates over TLS.
///
/// TURN messages are produced and consumed by a wrapped TCP or UDP TURN client. This type
/// encrypts outgoing messages and decrypts incoming ones with OpenSSL. It uses whichever TLS or
/// DTLS method the caller-supplied [`SslContext`] was built with.
#[derive(Debug)]
pub struct TurnClientOpensslTls {
    /// The TURN protocol state machine, chosen by the transport passed to `allocate()`.
    protocol: TcpOrUdp,
    /// Used to create a fresh `Ssl` for every socket: the initial connection to the TURN
    /// server and each connection added by `allocated_tcp_socket()`.
    ssl_context: SslContext,
    /// One TLS session per (local address, remote address) pair.
    sockets: Vec<Socket>,
}
66
/// TLS state for a single connection between a local and a remote address.
#[derive(Debug)]
struct Socket {
    /// Local address this connection's transmits are sent from.
    local_addr: SocketAddr,
    /// Remote address (the TURN server) this connection's transmits are sent to.
    remote_addr: SocketAddr,
    /// Handshake progress; owns the in-memory BIO that carries the ciphertext.
    handshake: HandshakeState,
    /// Plaintext produced while the handshake was still in progress, written into the TLS
    /// stream once the handshake completes.
    pending_write: VecDeque<Data<'static>>,
    /// Last observed TLS shutdown state (close_notify sent and/or received).
    shutdown: ShutdownState,
}
75
/// Progress of the TLS handshake on one socket.
#[derive(Debug)]
enum HandshakeState {
    /// Not started yet; `complete()` starts the handshake with `Ssl::connect()` on this BIO.
    Init(Ssl, OsslBio),
    /// Handshake in progress. Also kept after a handshake failure so the BIO stays reachable.
    Handshaking(MidHandshakeSslStream<OsslBio>),
    /// Handshake finished; application data can be read and written.
    Done(SslStream<OsslBio>),
    /// Placeholder while `complete()` has moved the state out. It is also the state left
    /// behind after a handshake setup failure, when the `Ssl` and its BIO were dropped.
    Nothing,
}
83
84impl HandshakeState {
85    fn complete(&mut self) -> Result<&mut SslStream<OsslBio>, std::io::Error> {
86        if let Self::Done(s) = self {
87            return Ok(s);
88        }
89        let taken = core::mem::replace(self, Self::Nothing);
90
91        let ret = match taken {
92            Self::Init(ssl, bio) => ssl.connect(bio),
93            Self::Handshaking(mid) => mid.handshake(),
94            Self::Done(_) | Self::Nothing => unreachable!(),
95        };
96
97        match ret {
98            Ok(s) => {
99                info!(
100                    "SSL handshake completed with version {} cipher: {:?}",
101                    s.ssl().version_str(),
102                    s.ssl().current_cipher()
103                );
104                *self = Self::Done(s);
105                Ok(self.complete()?)
106            }
107            Err(HandshakeError::WouldBlock(mid)) => {
108                *self = Self::Handshaking(mid);
109                Err(std::io::Error::new(
110                    std::io::ErrorKind::WouldBlock,
111                    "Would Block",
112                ))
113            }
114            Err(HandshakeError::SetupFailure(e)) => {
115                warn!("Error during ssl setup: {e}");
116                Err(std::io::Error::new(
117                    std::io::ErrorKind::ConnectionRefused,
118                    e,
119                ))
120            }
121            Err(HandshakeError::Failure(mid)) => {
122                warn!("Failure during ssl setup: {}", mid.error());
123                *self = Self::Handshaking(mid);
124                Err(std::io::Error::new(
125                    std::io::ErrorKind::ConnectionRefused,
126                    "Failure to setup SSL parameters",
127                ))
128            }
129        }
130    }
131    fn inner_mut(&mut self) -> &mut OsslBio {
132        match self {
133            Self::Init(_ssl, stream) => stream,
134            Self::Handshaking(mid) => mid.get_mut(),
135            Self::Done(stream) => stream.get_mut(),
136            Self::Nothing => unreachable!(),
137        }
138    }
139}
140
/// In-memory stream handed to OpenSSL (as the `Read + Write` transport of the `SslStream`) in
/// place of a real socket.
///
/// OpenSSL reads ciphertext from `incoming` and writes ciphertext into `outgoing`. The client
/// itself moves bytes between these buffers and the network.
#[derive(Debug, Default)]
struct OsslBio {
    /// Ciphertext received from the peer and not yet consumed by OpenSSL.
    incoming: Vec<u8>,
    /// Ciphertext written by OpenSSL, one entry per `write()` call, waiting to be transmitted.
    outgoing: VecDeque<Vec<u8>>,
}
146
147impl OsslBio {
148    fn push_incoming(&mut self, buf: &[u8]) {
149        self.incoming.extend_from_slice(buf)
150    }
151
152    fn pop_outgoing(&mut self) -> Option<Vec<u8>> {
153        self.outgoing.pop_front()
154    }
155}
156
157impl std::io::Write for OsslBio {
158    fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
159        self.outgoing.push_back(buf.to_vec());
160        Ok(buf.len())
161    }
162
163    fn flush(&mut self) -> std::io::Result<()> {
164        Ok(())
165    }
166}
167
168impl std::io::Read for OsslBio {
169    fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
170        let len = self.incoming.len();
171        let max = buf.len().min(len);
172
173        if len == 0 {
174            return Err(std::io::Error::new(
175                std::io::ErrorKind::WouldBlock,
176                "Would Block",
177            ));
178        }
179
180        buf[..max].copy_from_slice(&self.incoming[..max]);
181        if max == len {
182            self.incoming.truncate(0);
183        } else {
184            self.incoming.drain(..max);
185        }
186
187        Ok(max)
188    }
189}
190
191impl TurnClientOpensslTls {
192    /// Allocate an address on a TURN server to relay data to and from peers.
193    pub fn allocate(
194        transport: TransportType,
195        local_addr: SocketAddr,
196        remote_addr: SocketAddr,
197        config: TurnConfig,
198        ssl_context: SslContext,
199    ) -> Self {
200        let ssl = Ssl::new(&ssl_context).expect("Cannot create ssl structure");
201
202        Self {
203            protocol: match transport {
204                TransportType::Udp => {
205                    if config.allocation_transport() != TransportType::Udp {
206                        panic!("Cannot create a TCP allocation with a UDP connection to the TURN server")
207                    }
208                    TcpOrUdp::Udp(TurnClientUdp::allocate(local_addr, remote_addr, config))
209                }
210                TransportType::Tcp => {
211                    TcpOrUdp::Tcp(TurnClientTcp::allocate(local_addr, remote_addr, config))
212                }
213            },
214            ssl_context,
215            sockets: vec![Socket {
216                local_addr,
217                remote_addr,
218                handshake: HandshakeState::Init(ssl, OsslBio::default()),
219                pending_write: VecDeque::default(),
220                shutdown: ShutdownState::empty(),
221            }],
222        }
223    }
224
225    fn empty_transmit_queue(&mut self, now: Instant) {
226        while let Some(transmit) = self.protocol.poll_transmit(now) {
227            let Some(socket) = self.sockets.iter_mut().find(|socket| {
228                socket.local_addr == transmit.from && socket.remote_addr == transmit.to
229            }) else {
230                warn!(
231                    "no socket for transmit from {} to {}",
232                    transmit.from, transmit.to
233                );
234                continue;
235            };
236            match socket.handshake.complete() {
237                Ok(stream) => {
238                    for data in socket.pending_write.drain(..) {
239                        warn!("write early data, {} bytes", data.len());
240                        stream.write_all(&data).unwrap()
241                    }
242                    stream.write_all(&transmit.data).unwrap()
243                }
244                Err(e) => {
245                    if e.kind() == std::io::ErrorKind::WouldBlock {
246                        warn!("early data ({} bytes), storing", transmit.data.len());
247                        socket.pending_write.push_back(transmit.data);
248                    } else {
249                        warn!("Failure to send data: {e:?}");
250                        continue;
251                    }
252                }
253            }
254        }
255    }
256}
257
258impl TurnClientApi for TurnClientOpensslTls {
259    fn transport(&self) -> TransportType {
260        self.protocol.transport()
261    }
262
263    fn local_addr(&self) -> SocketAddr {
264        self.protocol.local_addr()
265    }
266
267    fn remote_addr(&self) -> SocketAddr {
268        self.protocol.remote_addr()
269    }
270
271    fn poll(&mut self, now: Instant) -> TurnPollRet {
272        let mut is_handshaking = false;
273        let mut have_outgoing = false;
274        for (idx, socket) in self.sockets.iter_mut().enumerate() {
275            let stream = match socket.handshake.complete() {
276                Ok(stream) => stream,
277                Err(e) => {
278                    if e.kind() == std::io::ErrorKind::WouldBlock {
279                        is_handshaking = true;
280                        continue;
281                    } else {
282                        warn!("Openssl produced error: {e:?}");
283                        return TurnPollRet::Closed;
284                    }
285                }
286            };
287            socket.shutdown = stream.get_shutdown();
288            if !socket.handshake.inner_mut().outgoing.is_empty() {
289                have_outgoing = true;
290                continue;
291            }
292            if socket
293                .shutdown
294                .contains(ShutdownState::SENT | ShutdownState::RECEIVED)
295            {
296                let socket = self.sockets.swap_remove(idx);
297                if self.transport() == TransportType::Tcp {
298                    return TurnPollRet::TcpClose {
299                        local_addr: socket.local_addr,
300                        remote_addr: socket.remote_addr,
301                    };
302                } else {
303                    have_outgoing = true;
304                    break;
305                }
306            }
307        }
308        if have_outgoing {
309            return TurnPollRet::WaitUntil(now);
310        }
311        if is_handshaking {
312            // FIXME: try to determine a more appropriate timeout for an in progress handshake.
313            return TurnPollRet::WaitUntil(now + Duration::from_millis(200));
314        }
315        let protocol_ret = self.protocol.poll(now);
316        if let TurnPollRet::TcpClose {
317            local_addr,
318            remote_addr,
319        } = protocol_ret
320        {
321            if let Some((idx, socket)) =
322                self.sockets.iter_mut().enumerate().find(|(_idx, socket)| {
323                    socket.local_addr == local_addr && socket.remote_addr == remote_addr
324                })
325            {
326                if let Ok(stream) = socket.handshake.complete() {
327                    let _ = stream.shutdown();
328                    socket.shutdown = stream.get_shutdown();
329                } else {
330                    self.sockets.swap_remove(idx);
331                }
332                return TurnPollRet::WaitUntil(now);
333            }
334        }
335        protocol_ret
336    }
337
338    fn relayed_addresses(&self) -> impl Iterator<Item = (TransportType, SocketAddr)> + '_ {
339        self.protocol.relayed_addresses()
340    }
341
342    fn permissions(
343        &self,
344        transport: TransportType,
345        relayed: SocketAddr,
346    ) -> impl Iterator<Item = IpAddr> + '_ {
347        self.protocol.permissions(transport, relayed)
348    }
349
350    fn poll_transmit(&mut self, now: Instant) -> Option<Transmit<Data<'static>>> {
351        let client_transport = self.transport();
352        for socket in self.sockets.iter_mut() {
353            if let Some(outgoing) = socket.handshake.inner_mut().pop_outgoing() {
354                return Some(Transmit::new(
355                    outgoing.into_boxed_slice().into(),
356                    client_transport,
357                    socket.local_addr,
358                    socket.remote_addr,
359                ));
360            }
361
362            let stream = match socket.handshake.complete() {
363                Ok(stream) => stream,
364                Err(e) => {
365                    warn!("handshake error: {e:?}");
366                    if let Some(outgoing) = socket.handshake.inner_mut().pop_outgoing() {
367                        return Some(Transmit::new(
368                            outgoing.into_boxed_slice().into(),
369                            client_transport,
370                            socket.local_addr,
371                            socket.remote_addr,
372                        ));
373                    } else {
374                        return None;
375                    }
376                }
377            };
378            for data in socket.pending_write.drain(..) {
379                warn!("write early data, {} bytes", data.len());
380                stream.write_all(&data).unwrap()
381            }
382        }
383        self.empty_transmit_queue(now);
384        for socket in self.sockets.iter_mut() {
385            if let Some(outgoing) = socket.handshake.inner_mut().pop_outgoing() {
386                return Some(Transmit::new(
387                    outgoing.into_boxed_slice().into(),
388                    client_transport,
389                    socket.local_addr,
390                    socket.remote_addr,
391                ));
392            }
393        }
394        None
395    }
396
397    fn poll_event(&mut self) -> Option<TurnEvent> {
398        self.protocol.poll_event()
399    }
400
401    fn delete(&mut self, now: Instant) -> Result<(), DeleteError> {
402        self.protocol.delete(now)?;
403        self.empty_transmit_queue(now);
404        Ok(())
405    }
406
407    fn create_permission(
408        &mut self,
409        transport: TransportType,
410        peer_addr: IpAddr,
411        now: Instant,
412    ) -> Result<(), CreatePermissionError> {
413        self.protocol.create_permission(transport, peer_addr, now)?;
414        self.empty_transmit_queue(now);
415        Ok(())
416    }
417
418    fn have_permission(&self, transport: TransportType, to: IpAddr) -> bool {
419        self.protocol.have_permission(transport, to)
420    }
421
422    fn bind_channel(
423        &mut self,
424        transport: TransportType,
425        peer_addr: SocketAddr,
426        now: Instant,
427    ) -> Result<(), BindChannelError> {
428        self.protocol.bind_channel(transport, peer_addr, now)?;
429        self.empty_transmit_queue(now);
430        Ok(())
431    }
432
433    fn tcp_connect(&mut self, peer_addr: SocketAddr, now: Instant) -> Result<(), TcpConnectError> {
434        self.protocol.tcp_connect(peer_addr, now)?;
435
436        self.empty_transmit_queue(now);
437
438        Ok(())
439    }
440
441    fn allocated_tcp_socket(
442        &mut self,
443        id: u32,
444        five_tuple: Socket5Tuple,
445        peer_addr: SocketAddr,
446        local_addr: Option<SocketAddr>,
447        now: Instant,
448    ) -> Result<(), TcpAllocateError> {
449        self.protocol
450            .allocated_tcp_socket(id, five_tuple, peer_addr, local_addr, now)?;
451
452        if let Some(local_addr) = local_addr {
453            self.sockets.push(Socket {
454                local_addr,
455                remote_addr: self.remote_addr(),
456                handshake: HandshakeState::Init(
457                    Ssl::new(&self.ssl_context).expect("Failed to create SSL"),
458                    OsslBio::default(),
459                ),
460                pending_write: VecDeque::default(),
461                shutdown: ShutdownState::empty(),
462            });
463        }
464
465        self.empty_transmit_queue(now);
466
467        Ok(())
468    }
469
470    fn tcp_closed(&mut self, local_addr: SocketAddr, remote_addr: SocketAddr, now: Instant) {
471        let Some(socket) = self
472            .sockets
473            .iter_mut()
474            .find(|socket| socket.local_addr == local_addr && socket.remote_addr == remote_addr)
475        else {
476            warn!(
477                "Unknown socket local:{}, remote:{}",
478                local_addr, remote_addr
479            );
480            return;
481        };
482        self.protocol.tcp_closed(local_addr, remote_addr, now);
483        if let Ok(stream) = socket.handshake.complete() {
484            socket.shutdown |= match stream.shutdown() {
485                Ok(ShutdownResult::Sent) => ShutdownState::SENT,
486                Ok(ShutdownResult::Received) => ShutdownState::RECEIVED,
487                Err(e) => {
488                    warn!("Failed to close TLS connection: {e:?}");
489                    return;
490                }
491            }
492        }
493    }
494
495    fn send_to<T: AsRef<[u8]> + core::fmt::Debug>(
496        &mut self,
497        transport: TransportType,
498        to: SocketAddr,
499        data: T,
500        now: Instant,
501    ) -> Result<Option<TransmitBuild<DelayedMessageOrChannelSend<T>>>, SendError> {
502        let client_transport = self.transport();
503        if let Some(transmit) = self.protocol.send_to(transport, to, data, now)? {
504            let Some(socket) = self.sockets.iter_mut().find(|socket| {
505                socket.local_addr == transmit.from
506                    && socket.remote_addr == transmit.to
507                    && !socket.shutdown.contains(ShutdownState::SENT)
508            }) else {
509                warn!(
510                    "no socket for transmit from {} to {}",
511                    transmit.from, transmit.to
512                );
513                return Err(SendError::NoTcpSocket);
514            };
515            let stream = socket.handshake.complete().expect("No TLS connection yet");
516            let transmit = transmit.build();
517            for data in socket.pending_write.drain(..) {
518                stream.write_all(&data).unwrap()
519            }
520            if let Err(e) = stream.write_all(&transmit.data) {
521                self.protocol.protocol_error();
522                warn!("Error when writing plaintext: {e:?}");
523                return Err(SendError::NoAllocation);
524            }
525
526            if let Some(outgoing) = stream.get_mut().pop_outgoing() {
527                return Ok(Some(TransmitBuild::new(
528                    DelayedMessageOrChannelSend::OwnedData(outgoing),
529                    client_transport,
530                    socket.local_addr,
531                    socket.remote_addr,
532                )));
533            }
534        }
535
536        Ok(None)
537    }
538
539    #[tracing::instrument(
540        name = "turn_openssl_recv",
541        skip(self, transmit, now),
542        fields(
543            transport = %transmit.transport,
544            from = ?transmit.from,
545            data_len = transmit.data.as_ref().len()
546        )
547    )]
548    fn recv<T: AsRef<[u8]> + core::fmt::Debug>(
549        &mut self,
550        transmit: Transmit<T>,
551        now: Instant,
552    ) -> TurnRecvRet<T> {
553        /* is this data for our client? */
554        if self.transport() != transmit.transport {
555            return TurnRecvRet::Ignored(transmit);
556        }
557        let Some(socket) = self
558            .sockets
559            .iter_mut()
560            .find(|socket| socket.local_addr == transmit.to && socket.remote_addr == transmit.from)
561        else {
562            trace!(
563                "received data not directed at us ({} {:?}) but for {} {:?}!",
564                self.transport(),
565                self.local_addr(),
566                transmit.transport,
567                transmit.to,
568            );
569            return TurnRecvRet::Ignored(transmit);
570        };
571
572        socket
573            .handshake
574            .inner_mut()
575            .push_incoming(transmit.data.as_ref());
576
577        let stream = match socket.handshake.complete() {
578            Ok(stream) => stream,
579            Err(e) => {
580                if e.kind() == std::io::ErrorKind::WouldBlock {
581                    return TurnRecvRet::Handled;
582                }
583                return TurnRecvRet::Ignored(transmit);
584            }
585        };
586
587        let mut out = vec![0; 2048];
588        let len = match stream.read(&mut out) {
589            Ok(len) => len,
590            Err(e) => {
591                if e.kind() != std::io::ErrorKind::WouldBlock {
592                    self.protocol.protocol_error();
593                    tracing::warn!("Error: {e}");
594                }
595                return TurnRecvRet::Ignored(transmit);
596            }
597        };
598        out.resize(len, 0);
599
600        let transmit = Transmit::new(out, transmit.transport, transmit.from, transmit.to);
601
602        match self.protocol.recv(transmit, now) {
603            TurnRecvRet::Ignored(_) => unreachable!(),
604            TurnRecvRet::PeerData(peer_data) => TurnRecvRet::PeerData(peer_data.into_owned()),
605            TurnRecvRet::Handled => TurnRecvRet::Handled,
606            TurnRecvRet::PeerIcmp {
607                transport,
608                peer,
609                icmp_type,
610                icmp_code,
611                icmp_data,
612            } => TurnRecvRet::PeerIcmp {
613                transport,
614                peer,
615                icmp_type,
616                icmp_code,
617                icmp_data,
618            },
619        }
620    }
621
622    fn poll_recv(&mut self, now: Instant) -> Option<TurnPeerData<Vec<u8>>> {
623        self.protocol.poll_recv(now)
624    }
625
626    fn protocol_error(&mut self) {
627        self.protocol.protocol_error()
628    }
629}