1use alloc::collections::BTreeMap;
16use alloc::vec::Vec;
17use core::net::{IpAddr, SocketAddr};
18use core::ops::Range;
19use turn_types::stun::message::Message;
20
21use stun_proto::agent::{StunAgent, Transmit};
22use stun_proto::types::data::Data;
23use stun_proto::types::TransportType;
24use stun_proto::Instant;
25
26use turn_types::channel::ChannelData;
27use turn_types::tcp::{IncomingTcp, StoredTcp, TurnTcpBuffer};
28
29use tracing::{trace, warn};
30
31use crate::api::{
32 DataRangeOrOwned, DelayedMessageOrChannelSend, Socket5Tuple, TcpAllocateError, TcpConnectError,
33 TransmitBuild, TurnClientApi, TurnConfig, TurnPeerData,
34};
35use crate::protocol::{TurnClientProtocol, TurnProtocolChannelRecv, TurnProtocolRecv};
36
37pub use crate::api::{
38 BindChannelError, CreatePermissionError, DeleteError, SendError, TurnEvent, TurnPollRet,
39 TurnRecvRet,
40};
41
/// A TURN client that communicates with the TURN server over TCP.
///
/// Wraps the shared [`TurnClientProtocol`] state machine and adds the TCP-specific receive
/// handling: incoming byte streams are framed into STUN messages and ChannelData before being
/// handed to the protocol.
#[derive(Debug)]
pub struct TurnClientTcp {
    protocol: TurnClientProtocol,
    // Receive state for every TCP connection to the TURN server, keyed by
    // `(local address, remote address)`. This holds the control connection created in `allocate`
    // and any data connections registered through `allocated_tcp_socket`.
    incoming_tcp_buffers: BTreeMap<(SocketAddr, SocketAddr), TcpBuffer>,
}
48
/// Receive state of a single TCP connection to the TURN server.
#[derive(Debug)]
enum TcpBuffer {
    /// The control connection: incoming bytes are framed into STUN messages and ChannelData.
    Control(TurnTcpBuffer),
    /// A data connection that has not yet seen a successful ConnectionBind response. Incoming
    /// bytes are framed so that the response can be extracted.
    WaitingForConnectionBindResponse(TurnTcpBuffer),
    /// A bound data connection. It holds the bytes that followed the ConnectionBind response
    /// and have not been delivered yet, plus the peer address those bytes came from.
    PendingData(Vec<u8>, SocketAddr),
    /// A bound data connection: every received byte is peer data from this address.
    Passthrough(SocketAddr),
}
59
60impl TurnClientTcp {
61 #[tracing::instrument(
86 name = "turn_client_tcp_allocate"
87 skip(config),
88 fields(
89 allocation_transport = %config.allocation_transport(),
90 )
91 )]
92 pub fn allocate(local_addr: SocketAddr, remote_addr: SocketAddr, config: TurnConfig) -> Self {
93 let stun_agent = StunAgent::builder(TransportType::Tcp, local_addr)
94 .remote_addr(remote_addr)
95 .build();
96
97 Self {
98 protocol: TurnClientProtocol::new(stun_agent, config),
99 incoming_tcp_buffers: BTreeMap::from([(
100 (local_addr, remote_addr),
101 TcpBuffer::Control(TurnTcpBuffer::new()),
102 )]),
103 }
104 }
105}
106
107impl TurnClientApi for TurnClientTcp {
108 fn transport(&self) -> TransportType {
109 self.protocol.transport()
110 }
111
112 fn local_addr(&self) -> SocketAddr {
113 self.protocol.local_addr()
114 }
115
116 fn remote_addr(&self) -> SocketAddr {
117 self.protocol.remote_addr()
118 }
119
120 fn poll(&mut self, now: Instant) -> TurnPollRet {
121 self.protocol.poll(now)
122 }
123
124 fn relayed_addresses(&self) -> impl Iterator<Item = (TransportType, SocketAddr)> + '_ {
125 self.protocol.relayed_addresses()
126 }
127
128 fn permissions(
129 &self,
130 transport: TransportType,
131 relayed: SocketAddr,
132 ) -> impl Iterator<Item = IpAddr> + '_ {
133 self.protocol.permissions(transport, relayed)
134 }
135
136 fn poll_transmit(&mut self, now: Instant) -> Option<Transmit<Data<'static>>> {
137 self.protocol.poll_transmit(now)
138 }
139
140 fn poll_event(&mut self) -> Option<TurnEvent> {
141 self.protocol.poll_event()
142 }
143
144 fn delete(&mut self, now: Instant) -> Result<(), DeleteError> {
145 self.protocol.delete(now)
146 }
147
148 fn create_permission(
149 &mut self,
150 transport: TransportType,
151 peer_addr: IpAddr,
152 now: Instant,
153 ) -> Result<(), CreatePermissionError> {
154 self.protocol.create_permission(transport, peer_addr, now)
155 }
156
157 fn have_permission(&self, transport: TransportType, to: IpAddr) -> bool {
158 self.protocol.have_permission(transport, to)
159 }
160
161 fn bind_channel(
162 &mut self,
163 transport: TransportType,
164 peer_addr: SocketAddr,
165 now: Instant,
166 ) -> Result<(), BindChannelError> {
167 self.protocol.bind_channel(transport, peer_addr, now)
168 }
169
170 fn tcp_connect(&mut self, peer_addr: SocketAddr, now: Instant) -> Result<(), TcpConnectError> {
171 self.protocol.tcp_connect(peer_addr, now)
172 }
173
174 fn allocated_tcp_socket(
175 &mut self,
176 id: u32,
177 five_tuple: Socket5Tuple,
178 peer_addr: SocketAddr,
179 local_addr: Option<SocketAddr>,
180 now: Instant,
181 ) -> Result<(), TcpAllocateError> {
182 self.protocol
183 .allocated_tcp_socket(id, five_tuple, peer_addr, local_addr, now)?;
184 if let Some(local_addr) = local_addr {
185 self.incoming_tcp_buffers.insert(
186 (local_addr, self.remote_addr()),
187 TcpBuffer::WaitingForConnectionBindResponse(TurnTcpBuffer::new()),
188 );
189 }
190 Ok(())
191 }
192
193 fn tcp_closed(&mut self, local_addr: SocketAddr, remote_addr: SocketAddr, now: Instant) {
194 self.protocol.tcp_closed(local_addr, remote_addr, now);
195 }
196
197 fn send_to<T: AsRef<[u8]> + core::fmt::Debug>(
198 &mut self,
199 transport: TransportType,
200 to: SocketAddr,
201 data: T,
202 now: Instant,
203 ) -> Result<Option<TransmitBuild<DelayedMessageOrChannelSend<T>>>, SendError> {
204 self.protocol.send_to(transport, to, data, now).map(Some)
205 }
206
207 fn recv<T: AsRef<[u8]> + core::fmt::Debug>(
208 &mut self,
209 transmit: Transmit<T>,
210 now: Instant,
211 ) -> TurnRecvRet<T> {
212 if self.transport() != transmit.transport || transmit.from != self.remote_addr() {
214 trace!(
215 "received data not directed at us ({:?}) but for {:?}!",
216 self.local_addr(),
217 transmit.to
218 );
219 return TurnRecvRet::Ignored(transmit);
220 }
221
222 let Some(tcp_buffer) = self
223 .incoming_tcp_buffers
224 .get_mut(&(transmit.to, transmit.from))
225 else {
226 return TurnRecvRet::Ignored(transmit);
227 };
228
229 if transmit.data.as_ref().is_empty() {
230 self.protocol.tcp_closed(transmit.to, transmit.from, now);
231 self.incoming_tcp_buffers
232 .remove(&(transmit.to, transmit.from));
233 return TurnRecvRet::Handled;
234 }
235
236 let tcp_buffer = match tcp_buffer {
237 TcpBuffer::WaitingForConnectionBindResponse(buffer) => {
238 match buffer.incoming_tcp(transmit) {
239 None => return TurnRecvRet::Handled,
240 Some(
242 IncomingTcp::CompleteChannel(transmit, _)
243 | IncomingTcp::StoredChannel(_, transmit),
244 ) => {
245 return TurnRecvRet::Ignored(transmit);
246 }
247 Some(IncomingTcp::CompleteMessage(transmit, msg_range)) => {
248 let Ok(msg) = Message::from_bytes(
249 &transmit.data.as_ref()[msg_range.start..msg_range.end],
250 ) else {
251 return TurnRecvRet::Handled;
253 };
254 let msg_transmit =
255 Transmit::new(msg, transmit.transport, transmit.from, transmit.to);
256 if let TurnProtocolRecv::TcpConnectionBound { peer_addr } =
257 self.protocol.handle_message(msg_transmit, now)
258 {
259 let data_len = transmit.data.as_ref().len();
260 if msg_range.end < data_len {
261 trace!(
262 "Have {} bytes after success ConnectionBind from peer",
263 data_len - msg_range.end
264 );
265 *tcp_buffer = TcpBuffer::PendingData(
266 transmit.data.as_ref()[msg_range.end..].to_vec(),
267 peer_addr,
268 );
269 } else {
270 *tcp_buffer = TcpBuffer::Passthrough(peer_addr);
271 }
272 return TurnRecvRet::Handled;
273 } else {
274 return TurnRecvRet::Handled;
276 }
277 }
278 Some(IncomingTcp::StoredMessage(msg_data, transmit)) => {
279 let Ok(msg) = Message::from_bytes(&msg_data) else {
280 return TurnRecvRet::Handled;
281 };
282 let msg_transmit =
283 Transmit::new(msg, transmit.transport, transmit.from, transmit.to);
284 if let TurnProtocolRecv::TcpConnectionBound { peer_addr } =
285 self.protocol.handle_message(msg_transmit, now)
286 {
287 if buffer.is_empty() {
288 *tcp_buffer = TcpBuffer::Passthrough(peer_addr);
289 } else {
290 let mut new_buffer = TurnTcpBuffer::new();
291 core::mem::swap(buffer, &mut new_buffer);
292 let data = new_buffer.into_inner();
293 *tcp_buffer = TcpBuffer::PendingData(data, peer_addr);
294 }
295 }
296 return TurnRecvRet::Handled;
297 }
298 }
299 }
300 TcpBuffer::PendingData(data, peer) => {
301 let mut replace = Vec::default();
302 core::mem::swap(&mut replace, data);
303 replace.extend_from_slice(transmit.data.as_ref());
304 let ret = TurnRecvRet::PeerData(TurnPeerData {
305 data: DataRangeOrOwned::Owned(replace),
306 transport: transmit.transport,
307 peer: *peer,
308 });
309 *tcp_buffer = TcpBuffer::Passthrough(*peer);
310 return ret;
311 }
312 TcpBuffer::Passthrough(peer) => {
313 return TurnRecvRet::PeerData(TurnPeerData {
314 data: DataRangeOrOwned::Range {
315 range: 0..transmit.data.as_ref().len(),
316 data: transmit.data,
317 },
318 transport: transmit.transport,
319 peer: *peer,
320 });
321 }
322 TcpBuffer::Control(tcp_buffer) => tcp_buffer,
323 };
324
325 let ret = match tcp_buffer.incoming_tcp(transmit) {
326 None => TurnRecvRet::Handled,
327 Some(IncomingTcp::CompleteMessage(transmit, msg_range)) => {
328 let Ok(msg) =
329 Message::from_bytes(&transmit.data.as_ref()[msg_range.start..msg_range.end])
330 else {
331 return TurnRecvRet::Handled;
332 };
333 let msg_transmit =
334 Transmit::new(msg, transmit.transport, transmit.from, transmit.to);
335 TurnRecvRet::from_protocol_recv_subrange(
336 self.protocol.handle_message(msg_transmit, now),
337 transmit,
338 msg_range.start,
339 )
340 }
341 Some(IncomingTcp::CompleteChannel(transmit, range)) => {
342 let channel =
343 ChannelData::parse(&transmit.data.as_ref()[range.start..range.end]).unwrap();
344 match self.protocol.handle_channel(channel, now) {
345 TurnProtocolChannelRecv::Ignored => TurnRecvRet::Ignored(transmit),
347 TurnProtocolChannelRecv::PeerData {
348 range,
349 transport,
350 peer,
351 } => TurnRecvRet::PeerData(TurnPeerData {
352 data: DataRangeOrOwned::Range {
353 data: transmit.data,
354 range,
355 },
356 transport,
357 peer,
358 }),
359 }
360 }
361 Some(IncomingTcp::StoredMessage(msg_data, transmit)) => {
362 let Ok(msg) = Message::from_bytes(&msg_data) else {
363 return TurnRecvRet::Handled;
364 };
365 let msg_transmit =
366 Transmit::new(msg, transmit.transport, transmit.from, transmit.to);
367 TurnRecvRet::from_protocol_recv_stored(
368 self.protocol.handle_message(msg_transmit, now),
369 transmit,
370 msg_data,
371 )
372 }
373 Some(IncomingTcp::StoredChannel(data, transmit)) => {
374 let channel = ChannelData::parse(&data).unwrap();
375 match self.protocol.handle_channel(channel, now) {
376 TurnProtocolChannelRecv::Ignored => TurnRecvRet::Ignored(transmit),
378 TurnProtocolChannelRecv::PeerData {
379 range,
380 transport,
381 peer,
382 } => TurnRecvRet::PeerData(TurnPeerData {
383 data: DataRangeOrOwned::Owned(ensure_data_owned(data, range)),
384 transport,
385 peer,
386 }),
387 }
388 }
389 };
390
391 if matches!(ret, TurnRecvRet::Handled | TurnRecvRet::Ignored(_)) {
392 if let Some(TurnPeerData {
393 data,
394 transport,
395 peer,
396 }) = self.poll_recv(now)
397 {
398 return TurnRecvRet::PeerData(TurnPeerData {
399 data: data.into_owned(),
400 transport,
401 peer,
402 });
403 }
404 }
405 ret
406 }
407
408 fn poll_recv(&mut self, now: Instant) -> Option<TurnPeerData<Vec<u8>>> {
409 for ((local_addr, remote_addr), tcp_buffer) in self.incoming_tcp_buffers.iter_mut() {
410 match tcp_buffer {
411 TcpBuffer::Passthrough(_) => continue,
412 TcpBuffer::PendingData(data, peer) => {
413 let mut replace = Vec::default();
414 core::mem::swap(&mut replace, data);
415 let ret = Some(TurnPeerData {
416 data: DataRangeOrOwned::Owned(replace),
417 transport: TransportType::Tcp,
418 peer: *peer,
419 });
420 *tcp_buffer = TcpBuffer::Passthrough(*peer);
421 return ret;
422 }
423 TcpBuffer::WaitingForConnectionBindResponse(buffer) => {
424 if let Some(recv) = buffer.poll_recv() {
425 match recv {
426 StoredTcp::Channel(_) => continue,
428 StoredTcp::Message(msg_data) => {
429 let Ok(msg) = Message::from_bytes(&msg_data) else {
430 continue;
431 };
432 if let TurnProtocolRecv::TcpConnectionBound { peer_addr } =
433 self.protocol.handle_message(
434 Transmit::new(
435 msg,
436 TransportType::Tcp,
437 *remote_addr,
438 *local_addr,
439 ),
440 now,
441 )
442 {
443 if buffer.is_empty() {
444 *tcp_buffer = TcpBuffer::Passthrough(peer_addr);
445 } else {
446 let mut new_buffer = TurnTcpBuffer::new();
447 core::mem::swap(buffer, &mut new_buffer);
448 let data = new_buffer.into_inner();
449 *tcp_buffer = TcpBuffer::PendingData(data, peer_addr);
450 }
451 }
452 }
453 }
454 }
455 }
456 TcpBuffer::Control(buffer) => {
457 while let Some(recv) = buffer.poll_recv() {
458 match recv {
459 StoredTcp::Message(msg_data) => {
460 let Ok(msg) = Message::from_bytes(&msg_data) else {
461 continue;
462 };
463 let msg_transmit = Transmit::new(
464 msg,
465 TransportType::Tcp,
466 *remote_addr,
467 *local_addr,
468 );
469 if let TurnProtocolRecv::PeerData {
470 range,
471 transport,
472 peer,
473 } = self.protocol.handle_message(msg_transmit, now)
474 {
475 return Some(TurnPeerData {
476 data: DataRangeOrOwned::Range {
477 data: msg_data,
478 range,
479 },
480 transport,
481 peer,
482 });
483 }
484 }
485 StoredTcp::Channel(data) => {
486 let Ok(channel) = ChannelData::parse(&data) else {
487 continue;
488 };
489 if let TurnProtocolChannelRecv::PeerData {
490 range,
491 transport,
492 peer,
493 } = self.protocol.handle_channel(channel, now)
494 {
495 return Some(TurnPeerData {
496 data: DataRangeOrOwned::Range { data, range },
497 transport,
498 peer,
499 });
500 }
501 }
502 }
503 }
504 }
505 }
506 }
507 None
508 }
509
510 fn protocol_error(&mut self) {
511 self.protocol.protocol_error()
512 }
513}
514
/// Return the bytes of `data` selected by `range` as an owned `Vec<u8>`.
///
/// If `range` covers the whole of `data`, the original allocation is returned without copying.
/// Otherwise the selected sub-slice is copied into a new vector.
///
/// # Panics
///
/// Panics if `range` is out of bounds for `data`, or if `range.start > range.end`.
pub(crate) fn ensure_data_owned(data: Vec<u8>, range: Range<usize>) -> Vec<u8> {
    let is_whole_buffer = range.start == 0 && range.end == data.len();
    if !is_whole_buffer {
        return data[range].to_vec();
    }
    data
}