// turn_client_proto/tcp.rs

1// Copyright (C) 2025 Matthew Waters <matthew@centricular.com>
2//
3// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
4// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
5// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
6// option. This file may not be copied, modified, or distributed
7// except according to those terms.
8//
9// SPDX-License-Identifier: MIT OR Apache-2.0
10
11//! TCP TURN client.
12//!
13//! An implementation of a TURN client suitable for TCP connections.
14
15use alloc::collections::BTreeMap;
16use alloc::vec::Vec;
17use core::net::{IpAddr, SocketAddr};
18use core::ops::Range;
19use turn_types::stun::message::Message;
20
21use stun_proto::agent::{StunAgent, Transmit};
22use stun_proto::types::data::Data;
23use stun_proto::types::TransportType;
24use stun_proto::Instant;
25
26use turn_types::channel::ChannelData;
27use turn_types::tcp::{IncomingTcp, StoredTcp, TurnTcpBuffer};
28
29use tracing::{trace, warn};
30
31use crate::api::{
32    DataRangeOrOwned, DelayedMessageOrChannelSend, Socket5Tuple, TcpAllocateError, TcpConnectError,
33    TransmitBuild, TurnClientApi, TurnConfig, TurnPeerData,
34};
35use crate::protocol::{TurnClientProtocol, TurnProtocolChannelRecv, TurnProtocolRecv};
36
37pub use crate::api::{
38    BindChannelError, CreatePermissionError, DeleteError, SendError, TurnEvent, TurnPollRet,
39    TurnRecvRet,
40};
41
42/// A TURN client.
43#[derive(Debug)]
44pub struct TurnClientTcp {
45    protocol: TurnClientProtocol,
46    incoming_tcp_buffers: BTreeMap<(SocketAddr, SocketAddr), TcpBuffer>,
47}
48
49#[derive(Debug)]
50enum TcpBuffer {
51    // The control TURN connection. Always buffered.
52    Control(TurnTcpBuffer),
53    WaitingForConnectionBindResponse(TurnTcpBuffer),
54    // reached if after ConnectionBind there is more data.
55    PendingData(Vec<u8>, SocketAddr),
56    // peer address
57    Passthrough(SocketAddr),
58}
59
60impl TurnClientTcp {
61    /// Allocate an address on a TURN server to relay data to and from peers.
62    ///
63    /// # Examples
64    /// ```
65    /// # use turn_types::TurnCredentials;
66    /// # use turn_client_proto::prelude::*;
67    /// # use turn_client_proto::tcp::TurnClientTcp;
68    /// # use turn_client_proto::api::TurnConfig;
69    /// # use stun_proto::types::TransportType;
70    /// let credentials = TurnCredentials::new("tuser", "tpass");
71    /// let mut config = TurnConfig::new(credentials);
72    /// // The transport protocol of the allocation on the TURN server.
73    /// config.set_allocation_transport(TransportType::Udp);
74    /// let local_addr = "192.168.0.1:4000".parse().unwrap();
75    /// let remote_addr = "10.0.0.1:3478".parse().unwrap();
76    /// let client = TurnClientTcp::allocate(
77    ///     local_addr,
78    ///     remote_addr,
79    ///     config,
80    /// );
81    /// assert_eq!(client.transport(), TransportType::Tcp);
82    /// assert_eq!(client.local_addr(), local_addr);
83    /// assert_eq!(client.remote_addr(), remote_addr);
84    /// ```
85    #[tracing::instrument(
86        name = "turn_client_tcp_allocate"
87        skip(config),
88        fields(
89            allocation_transport = %config.allocation_transport(),
90        )
91    )]
92    pub fn allocate(local_addr: SocketAddr, remote_addr: SocketAddr, config: TurnConfig) -> Self {
93        let stun_agent = StunAgent::builder(TransportType::Tcp, local_addr)
94            .remote_addr(remote_addr)
95            .build();
96
97        Self {
98            protocol: TurnClientProtocol::new(stun_agent, config),
99            incoming_tcp_buffers: BTreeMap::from([(
100                (local_addr, remote_addr),
101                TcpBuffer::Control(TurnTcpBuffer::new()),
102            )]),
103        }
104    }
105}
106
107impl TurnClientApi for TurnClientTcp {
108    fn transport(&self) -> TransportType {
109        self.protocol.transport()
110    }
111
112    fn local_addr(&self) -> SocketAddr {
113        self.protocol.local_addr()
114    }
115
116    fn remote_addr(&self) -> SocketAddr {
117        self.protocol.remote_addr()
118    }
119
120    fn poll(&mut self, now: Instant) -> TurnPollRet {
121        self.protocol.poll(now)
122    }
123
124    fn relayed_addresses(&self) -> impl Iterator<Item = (TransportType, SocketAddr)> + '_ {
125        self.protocol.relayed_addresses()
126    }
127
128    fn permissions(
129        &self,
130        transport: TransportType,
131        relayed: SocketAddr,
132    ) -> impl Iterator<Item = IpAddr> + '_ {
133        self.protocol.permissions(transport, relayed)
134    }
135
136    fn poll_transmit(&mut self, now: Instant) -> Option<Transmit<Data<'static>>> {
137        self.protocol.poll_transmit(now)
138    }
139
140    fn poll_event(&mut self) -> Option<TurnEvent> {
141        self.protocol.poll_event()
142    }
143
144    fn delete(&mut self, now: Instant) -> Result<(), DeleteError> {
145        self.protocol.delete(now)
146    }
147
148    fn create_permission(
149        &mut self,
150        transport: TransportType,
151        peer_addr: IpAddr,
152        now: Instant,
153    ) -> Result<(), CreatePermissionError> {
154        self.protocol.create_permission(transport, peer_addr, now)
155    }
156
157    fn have_permission(&self, transport: TransportType, to: IpAddr) -> bool {
158        self.protocol.have_permission(transport, to)
159    }
160
161    fn bind_channel(
162        &mut self,
163        transport: TransportType,
164        peer_addr: SocketAddr,
165        now: Instant,
166    ) -> Result<(), BindChannelError> {
167        self.protocol.bind_channel(transport, peer_addr, now)
168    }
169
170    fn tcp_connect(&mut self, peer_addr: SocketAddr, now: Instant) -> Result<(), TcpConnectError> {
171        self.protocol.tcp_connect(peer_addr, now)
172    }
173
174    fn allocated_tcp_socket(
175        &mut self,
176        id: u32,
177        five_tuple: Socket5Tuple,
178        peer_addr: SocketAddr,
179        local_addr: Option<SocketAddr>,
180        now: Instant,
181    ) -> Result<(), TcpAllocateError> {
182        self.protocol
183            .allocated_tcp_socket(id, five_tuple, peer_addr, local_addr, now)?;
184        if let Some(local_addr) = local_addr {
185            self.incoming_tcp_buffers.insert(
186                (local_addr, self.remote_addr()),
187                TcpBuffer::WaitingForConnectionBindResponse(TurnTcpBuffer::new()),
188            );
189        }
190        Ok(())
191    }
192
193    fn tcp_closed(&mut self, local_addr: SocketAddr, remote_addr: SocketAddr, now: Instant) {
194        self.protocol.tcp_closed(local_addr, remote_addr, now);
195    }
196
197    fn send_to<T: AsRef<[u8]> + core::fmt::Debug>(
198        &mut self,
199        transport: TransportType,
200        to: SocketAddr,
201        data: T,
202        now: Instant,
203    ) -> Result<Option<TransmitBuild<DelayedMessageOrChannelSend<T>>>, SendError> {
204        self.protocol.send_to(transport, to, data, now).map(Some)
205    }
206
207    fn recv<T: AsRef<[u8]> + core::fmt::Debug>(
208        &mut self,
209        transmit: Transmit<T>,
210        now: Instant,
211    ) -> TurnRecvRet<T> {
212        /* is this data for our client? */
213        if self.transport() != transmit.transport || transmit.from != self.remote_addr() {
214            trace!(
215                "received data not directed at us ({:?}) but for {:?}!",
216                self.local_addr(),
217                transmit.to
218            );
219            return TurnRecvRet::Ignored(transmit);
220        }
221
222        let Some(tcp_buffer) = self
223            .incoming_tcp_buffers
224            .get_mut(&(transmit.to, transmit.from))
225        else {
226            return TurnRecvRet::Ignored(transmit);
227        };
228
229        if transmit.data.as_ref().is_empty() {
230            self.protocol.tcp_closed(transmit.to, transmit.from, now);
231            self.incoming_tcp_buffers
232                .remove(&(transmit.to, transmit.from));
233            return TurnRecvRet::Handled;
234        }
235
236        let tcp_buffer = match tcp_buffer {
237            TcpBuffer::WaitingForConnectionBindResponse(buffer) => {
238                match buffer.incoming_tcp(transmit) {
239                    None => return TurnRecvRet::Handled,
240                    // protocol violation
241                    Some(
242                        IncomingTcp::CompleteChannel(transmit, _)
243                        | IncomingTcp::StoredChannel(_, transmit),
244                    ) => {
245                        return TurnRecvRet::Ignored(transmit);
246                    }
247                    Some(IncomingTcp::CompleteMessage(transmit, msg_range)) => {
248                        let Ok(msg) = Message::from_bytes(
249                            &transmit.data.as_ref()[msg_range.start..msg_range.end],
250                        ) else {
251                            // protocol violation
252                            return TurnRecvRet::Handled;
253                        };
254                        let msg_transmit =
255                            Transmit::new(msg, transmit.transport, transmit.from, transmit.to);
256                        if let TurnProtocolRecv::TcpConnectionBound { peer_addr } =
257                            self.protocol.handle_message(msg_transmit, now)
258                        {
259                            let data_len = transmit.data.as_ref().len();
260                            if msg_range.end < data_len {
261                                trace!(
262                                    "Have {} bytes after success ConnectionBind from peer",
263                                    data_len - msg_range.end
264                                );
265                                *tcp_buffer = TcpBuffer::PendingData(
266                                    transmit.data.as_ref()[msg_range.end..].to_vec(),
267                                    peer_addr,
268                                );
269                            } else {
270                                *tcp_buffer = TcpBuffer::Passthrough(peer_addr);
271                            }
272                            return TurnRecvRet::Handled;
273                        } else {
274                            // possible protocol violation
275                            return TurnRecvRet::Handled;
276                        }
277                    }
278                    Some(IncomingTcp::StoredMessage(msg_data, transmit)) => {
279                        let Ok(msg) = Message::from_bytes(&msg_data) else {
280                            return TurnRecvRet::Handled;
281                        };
282                        let msg_transmit =
283                            Transmit::new(msg, transmit.transport, transmit.from, transmit.to);
284                        if let TurnProtocolRecv::TcpConnectionBound { peer_addr } =
285                            self.protocol.handle_message(msg_transmit, now)
286                        {
287                            if buffer.is_empty() {
288                                *tcp_buffer = TcpBuffer::Passthrough(peer_addr);
289                            } else {
290                                let mut new_buffer = TurnTcpBuffer::new();
291                                core::mem::swap(buffer, &mut new_buffer);
292                                let data = new_buffer.into_inner();
293                                *tcp_buffer = TcpBuffer::PendingData(data, peer_addr);
294                            }
295                        }
296                        return TurnRecvRet::Handled;
297                    }
298                }
299            }
300            TcpBuffer::PendingData(data, peer) => {
301                let mut replace = Vec::default();
302                core::mem::swap(&mut replace, data);
303                replace.extend_from_slice(transmit.data.as_ref());
304                let ret = TurnRecvRet::PeerData(TurnPeerData {
305                    data: DataRangeOrOwned::Owned(replace),
306                    transport: transmit.transport,
307                    peer: *peer,
308                });
309                *tcp_buffer = TcpBuffer::Passthrough(*peer);
310                return ret;
311            }
312            TcpBuffer::Passthrough(peer) => {
313                return TurnRecvRet::PeerData(TurnPeerData {
314                    data: DataRangeOrOwned::Range {
315                        range: 0..transmit.data.as_ref().len(),
316                        data: transmit.data,
317                    },
318                    transport: transmit.transport,
319                    peer: *peer,
320                });
321            }
322            TcpBuffer::Control(tcp_buffer) => tcp_buffer,
323        };
324
325        let ret = match tcp_buffer.incoming_tcp(transmit) {
326            None => TurnRecvRet::Handled,
327            Some(IncomingTcp::CompleteMessage(transmit, msg_range)) => {
328                let Ok(msg) =
329                    Message::from_bytes(&transmit.data.as_ref()[msg_range.start..msg_range.end])
330                else {
331                    return TurnRecvRet::Handled;
332                };
333                let msg_transmit =
334                    Transmit::new(msg, transmit.transport, transmit.from, transmit.to);
335                TurnRecvRet::from_protocol_recv_subrange(
336                    self.protocol.handle_message(msg_transmit, now),
337                    transmit,
338                    msg_range.start,
339                )
340            }
341            Some(IncomingTcp::CompleteChannel(transmit, range)) => {
342                let channel =
343                    ChannelData::parse(&transmit.data.as_ref()[range.start..range.end]).unwrap();
344                match self.protocol.handle_channel(channel, now) {
345                    // XXX: Ignored should probably produce an error for TCP
346                    TurnProtocolChannelRecv::Ignored => TurnRecvRet::Ignored(transmit),
347                    TurnProtocolChannelRecv::PeerData {
348                        range,
349                        transport,
350                        peer,
351                    } => TurnRecvRet::PeerData(TurnPeerData {
352                        data: DataRangeOrOwned::Range {
353                            data: transmit.data,
354                            range,
355                        },
356                        transport,
357                        peer,
358                    }),
359                }
360            }
361            Some(IncomingTcp::StoredMessage(msg_data, transmit)) => {
362                let Ok(msg) = Message::from_bytes(&msg_data) else {
363                    return TurnRecvRet::Handled;
364                };
365                let msg_transmit =
366                    Transmit::new(msg, transmit.transport, transmit.from, transmit.to);
367                TurnRecvRet::from_protocol_recv_stored(
368                    self.protocol.handle_message(msg_transmit, now),
369                    transmit,
370                    msg_data,
371                )
372            }
373            Some(IncomingTcp::StoredChannel(data, transmit)) => {
374                let channel = ChannelData::parse(&data).unwrap();
375                match self.protocol.handle_channel(channel, now) {
376                    // XXX: Ignored should probably produce an error for TCP
377                    TurnProtocolChannelRecv::Ignored => TurnRecvRet::Ignored(transmit),
378                    TurnProtocolChannelRecv::PeerData {
379                        range,
380                        transport,
381                        peer,
382                    } => TurnRecvRet::PeerData(TurnPeerData {
383                        data: DataRangeOrOwned::Owned(ensure_data_owned(data, range)),
384                        transport,
385                        peer,
386                    }),
387                }
388            }
389        };
390
391        if matches!(ret, TurnRecvRet::Handled | TurnRecvRet::Ignored(_)) {
392            if let Some(TurnPeerData {
393                data,
394                transport,
395                peer,
396            }) = self.poll_recv(now)
397            {
398                return TurnRecvRet::PeerData(TurnPeerData {
399                    data: data.into_owned(),
400                    transport,
401                    peer,
402                });
403            }
404        }
405        ret
406    }
407
408    fn poll_recv(&mut self, now: Instant) -> Option<TurnPeerData<Vec<u8>>> {
409        for ((local_addr, remote_addr), tcp_buffer) in self.incoming_tcp_buffers.iter_mut() {
410            match tcp_buffer {
411                TcpBuffer::Passthrough(_) => continue,
412                TcpBuffer::PendingData(data, peer) => {
413                    let mut replace = Vec::default();
414                    core::mem::swap(&mut replace, data);
415                    let ret = Some(TurnPeerData {
416                        data: DataRangeOrOwned::Owned(replace),
417                        transport: TransportType::Tcp,
418                        peer: *peer,
419                    });
420                    *tcp_buffer = TcpBuffer::Passthrough(*peer);
421                    return ret;
422                }
423                TcpBuffer::WaitingForConnectionBindResponse(buffer) => {
424                    if let Some(recv) = buffer.poll_recv() {
425                        match recv {
426                            // protocol violation
427                            StoredTcp::Channel(_) => continue,
428                            StoredTcp::Message(msg_data) => {
429                                let Ok(msg) = Message::from_bytes(&msg_data) else {
430                                    continue;
431                                };
432                                if let TurnProtocolRecv::TcpConnectionBound { peer_addr } =
433                                    self.protocol.handle_message(
434                                        Transmit::new(
435                                            msg,
436                                            TransportType::Tcp,
437                                            *remote_addr,
438                                            *local_addr,
439                                        ),
440                                        now,
441                                    )
442                                {
443                                    if buffer.is_empty() {
444                                        *tcp_buffer = TcpBuffer::Passthrough(peer_addr);
445                                    } else {
446                                        let mut new_buffer = TurnTcpBuffer::new();
447                                        core::mem::swap(buffer, &mut new_buffer);
448                                        let data = new_buffer.into_inner();
449                                        *tcp_buffer = TcpBuffer::PendingData(data, peer_addr);
450                                    }
451                                }
452                            }
453                        }
454                    }
455                }
456                TcpBuffer::Control(buffer) => {
457                    while let Some(recv) = buffer.poll_recv() {
458                        match recv {
459                            StoredTcp::Message(msg_data) => {
460                                let Ok(msg) = Message::from_bytes(&msg_data) else {
461                                    continue;
462                                };
463                                let msg_transmit = Transmit::new(
464                                    msg,
465                                    TransportType::Tcp,
466                                    *remote_addr,
467                                    *local_addr,
468                                );
469                                if let TurnProtocolRecv::PeerData {
470                                    range,
471                                    transport,
472                                    peer,
473                                } = self.protocol.handle_message(msg_transmit, now)
474                                {
475                                    return Some(TurnPeerData {
476                                        data: DataRangeOrOwned::Range {
477                                            data: msg_data,
478                                            range,
479                                        },
480                                        transport,
481                                        peer,
482                                    });
483                                }
484                            }
485                            StoredTcp::Channel(data) => {
486                                let Ok(channel) = ChannelData::parse(&data) else {
487                                    continue;
488                                };
489                                if let TurnProtocolChannelRecv::PeerData {
490                                    range,
491                                    transport,
492                                    peer,
493                                } = self.protocol.handle_channel(channel, now)
494                                {
495                                    return Some(TurnPeerData {
496                                        data: DataRangeOrOwned::Range { data, range },
497                                        transport,
498                                        peer,
499                                    });
500                                }
501                            }
502                        }
503                    }
504                }
505            }
506        }
507        None
508    }
509
510    fn protocol_error(&mut self) {
511        self.protocol.protocol_error()
512    }
513}
514
/// Turn `data` into an owned buffer containing exactly the bytes in `range`.
///
/// If `range` already spans the whole of `data`, the vector is handed back
/// unchanged (no copy). Otherwise the selected bytes are copied into a new vector.
///
/// # Panics
///
/// Panics if `range` is out of bounds for `data` or `range.start > range.end`.
pub(crate) fn ensure_data_owned(data: Vec<u8>, range: Range<usize>) -> Vec<u8> {
    let spans_everything = range.start == 0 && range.end == data.len();
    if !spans_everything {
        return data[range].to_vec();
    }
    data
}