1#![deny(missing_debug_implementations)]
28#![deny(missing_docs)]
29#![cfg_attr(docsrs, feature(doc_cfg))]
30#![deny(clippy::std_instead_of_core)]
31#![deny(clippy::std_instead_of_alloc)]
32#![no_std]
33
34extern crate alloc;
35
36#[cfg(any(feature = "std", test))]
37extern crate std;
38
39use alloc::collections::VecDeque;
40use alloc::string::String;
41use alloc::vec;
42use alloc::vec::Vec;
43use core::net::SocketAddr;
44use core::time::Duration;
45
46use std::io::{Read, Write};
47
48use turn_server_proto::types::prelude::DelayedTransmitBuild;
49use turn_server_proto::types::transmit::TransmitBuild;
50use turn_server_proto::types::AddressFamily;
51
52use turn_server_proto::api::Transmit;
53use turn_server_proto::types::stun::TransportType;
54use turn_server_proto::types::Instant;
55
56pub use turn_server_proto as proto;
57pub use turn_server_proto::api;
58
59use turn_server_proto::api::{
60 DelayedMessageOrChannelSend, SocketAllocateError, TurnServerApi, TurnServerPollRet,
61};
62use turn_server_proto::server::TurnServer;
63
64use tracing::{info, trace, warn};
65
66use openssl::ssl::{
67 HandshakeError, MidHandshakeSslStream, ShutdownState, Ssl, SslContext, SslStream,
68};
69
70#[derive(Debug)]
72pub struct OpensslTurnServer {
73 server: TurnServer,
74 ssl_context: SslContext,
75 clients: Vec<Client>,
76}
77
78#[derive(Debug)]
79struct Client {
80 transport: TransportType,
81 client_addr: SocketAddr,
82 tls: HandshakeState,
83 shutdown: ShutdownState,
84}
85
86#[derive(Debug)]
87enum HandshakeState {
88 Init(Ssl, OsslBio),
89 Handshaking(MidHandshakeSslStream<OsslBio>),
90 Done(SslStream<OsslBio>),
91 Nothing,
92}
93
94impl HandshakeState {
95 fn complete(&mut self) -> Result<&mut SslStream<OsslBio>, std::io::Error> {
96 if let Self::Done(s) = self {
97 return Ok(s);
98 }
99 let taken = core::mem::replace(self, Self::Nothing);
100
101 let ret = match taken {
102 Self::Init(ssl, bio) => ssl.accept(bio),
103 Self::Handshaking(mid) => mid.handshake(),
104 Self::Done(_) | Self::Nothing => unreachable!(),
105 };
106
107 match ret {
108 Ok(s) => {
109 info!(
110 "SSL handshake completed with version {} cipher: {:?}",
111 s.ssl().version_str(),
112 s.ssl().current_cipher()
113 );
114 *self = Self::Done(s);
115 Ok(self.complete()?)
116 }
117 Err(HandshakeError::WouldBlock(mid)) => {
118 *self = Self::Handshaking(mid);
119 Err(std::io::Error::new(
120 std::io::ErrorKind::WouldBlock,
121 "Would Block",
122 ))
123 }
124 Err(HandshakeError::SetupFailure(e)) => {
125 warn!("Error during ssl setup: {e}");
126 Err(std::io::Error::new(
127 std::io::ErrorKind::ConnectionRefused,
128 e,
129 ))
130 }
131 Err(HandshakeError::Failure(mid)) => {
132 warn!("Failure during ssl setup: {}", mid.error());
133 *self = Self::Handshaking(mid);
134 Err(std::io::Error::new(
135 std::io::ErrorKind::WouldBlock,
136 "Would Block",
137 ))
138 }
139 }
140 }
141 fn inner_mut(&mut self) -> &mut OsslBio {
142 match self {
143 Self::Init(_ssl, stream) => stream,
144 Self::Handshaking(mid) => mid.get_mut(),
145 Self::Done(stream) => stream.get_mut(),
146 Self::Nothing => unreachable!(),
147 }
148 }
149}
150
/// In-memory transport handed to OpenSSL in place of a socket.
///
/// Encrypted bytes received from the network are appended with `push_incoming` and consumed
/// through [`std::io::Read`]; bytes written through [`std::io::Write`] are queued and retrieved
/// with `pop_outgoing`.
#[derive(Debug, Default)]
struct OsslBio {
    /// Received bytes that have not been read yet.
    incoming: Vec<u8>,
    /// Written chunks awaiting transmission, oldest first. One entry per non-empty `write`.
    outgoing: VecDeque<Vec<u8>>,
}

impl OsslBio {
    /// Append received bytes to the read buffer.
    fn push_incoming(&mut self, buf: &[u8]) {
        self.incoming.extend_from_slice(buf);
    }

    /// Remove and return the oldest queued chunk, if any.
    fn pop_outgoing(&mut self) -> Option<Vec<u8>> {
        self.outgoing.pop_front()
    }
}

impl std::io::Write for OsslBio {
    /// Queue `buf` as one outgoing chunk. Never fails and always accepts the whole buffer.
    fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
        // An empty chunk would be handed out as a zero-length transmit and would make the
        // queue look non-empty, so it is not queued at all.
        if !buf.is_empty() {
            self.outgoing.push_back(buf.to_vec());
        }
        Ok(buf.len())
    }

    /// Nothing to flush: written data is already queued.
    fn flush(&mut self) -> std::io::Result<()> {
        Ok(())
    }
}

impl std::io::Read for OsslBio {
    /// Copy as many buffered bytes as fit into `buf` and remove them from the buffer.
    ///
    /// Fails with [`std::io::ErrorKind::WouldBlock`] when no bytes are buffered.
    fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
        if self.incoming.is_empty() {
            return Err(std::io::Error::new(
                std::io::ErrorKind::WouldBlock,
                "Would Block",
            ));
        }

        let count = buf.len().min(self.incoming.len());
        buf[..count].copy_from_slice(&self.incoming[..count]);
        self.incoming.drain(..count);

        Ok(count)
    }
}
200
201impl OpensslTurnServer {
202 pub fn new(
204 transport: TransportType,
205 listen_addr: SocketAddr,
206 realm: String,
207 ssl_context: SslContext,
208 ) -> Self {
209 Self {
210 server: TurnServer::new(transport, listen_addr, realm),
211 ssl_context,
212 clients: vec![],
213 }
214 }
215}
216
217impl TurnServerApi for OpensslTurnServer {
218 fn add_user(&mut self, username: String, password: String) {
220 self.server.add_user(username, password)
221 }
222
223 fn listen_address(&self) -> SocketAddr {
225 self.server.listen_address()
226 }
227
228 fn set_nonce_expiry_duration(&mut self, expiry_duration: Duration) {
231 self.server.set_nonce_expiry_duration(expiry_duration)
232 }
233
234 #[tracing::instrument(
238 name = "turn_server_openssl_recv",
239 skip(self, transmit, now),
240 fields(
241 from = ?transmit.from,
242 data_len = transmit.data.as_ref().len()
243 )
244 )]
245 fn recv<T: AsRef<[u8]> + core::fmt::Debug>(
246 &mut self,
247 transmit: Transmit<T>,
248 now: Instant,
249 ) -> Option<TransmitBuild<DelayedMessageOrChannelSend<T>>> {
250 let listen_address = self.listen_address();
251 if transmit.to == listen_address {
252 trace!("receiving TLS data: {:x?}", transmit.data.as_ref());
253 let client = match self
255 .clients
256 .iter_mut()
257 .find(|client| client.client_addr == transmit.from)
258 {
259 Some(client) => client,
260 None => {
261 let len = self.clients.len();
262 let ssl = Ssl::new(&self.ssl_context).expect("Cannot create ssl structure");
263 self.clients.push(Client {
264 transport: transmit.transport,
265 client_addr: transmit.from,
266 tls: HandshakeState::Init(ssl, OsslBio::default()),
267 shutdown: ShutdownState::empty(),
268 });
269 info!(
270 "new connection from {} {}",
271 transmit.transport, transmit.from
272 );
273 &mut self.clients[len]
274 }
275 };
276 client.tls.inner_mut().push_incoming(transmit.data.as_ref());
277 let stream = match client.tls.complete() {
278 Ok(s) => s,
279 Err(e) => {
280 if e.kind() != std::io::ErrorKind::WouldBlock {
281 warn!("error accepting TLS: {e}");
282 }
283 return None;
284 }
285 };
286
287 let mut plaintext = vec![0; 2048];
288 let len = match stream.read(&mut plaintext) {
289 Ok(len) => len,
290 Err(e) => {
291 if e.kind() != std::io::ErrorKind::WouldBlock {
292 warn!("Error: {e}");
293 }
294 return None;
295 }
296 };
297 warn!("received: {len} plaintext bytes");
298 if len == 0 {
299 let pre_shutdown = stream.get_shutdown();
300 let _ = stream.shutdown();
301 client.shutdown = stream.get_shutdown();
302 if !pre_shutdown.contains(ShutdownState::SENT) {
303 return stream.get_mut().pop_outgoing().map(|data| {
304 TransmitBuild::new(
305 DelayedMessageOrChannelSend::Owned(data),
306 transmit.transport,
307 listen_address,
308 client.client_addr,
309 )
310 });
311 } else {
312 return None;
313 }
314 }
315 plaintext.resize(len, 0);
316
317 let transmit = self.server.recv(
318 Transmit::new(plaintext, transmit.transport, transmit.from, transmit.to),
319 now,
320 )?;
321
322 if transmit.from == listen_address && transmit.to == client.client_addr {
323 let plaintext = transmit.data.build();
324 stream.write_all(&plaintext).unwrap();
325 stream.get_mut().pop_outgoing().map(|data| {
326 TransmitBuild::new(
327 DelayedMessageOrChannelSend::Owned(data),
328 transmit.transport,
329 listen_address,
330 client.client_addr,
331 )
332 })
333 } else {
334 let transmit = transmit.build();
335 Some(TransmitBuild::new(
336 DelayedMessageOrChannelSend::Owned(transmit.data),
337 transmit.transport,
338 transmit.from,
339 transmit.to,
340 ))
341 }
342 } else if let Some(transmit) = self.server.recv(transmit, now) {
343 if transmit.from == listen_address {
345 let Some(client) = self
346 .clients
347 .iter_mut()
348 .find(|client| transmit.to == client.client_addr)
349 else {
350 return Some(transmit);
351 };
352
353 let plaintext = transmit.data.build();
354 let stream = match client.tls.complete() {
355 Ok(s) => s,
356 Err(e) => {
357 if e.kind() != std::io::ErrorKind::WouldBlock {
358 warn!("error accepting TLS: {e}");
359 }
360 return None;
361 }
362 };
363 stream.write_all(&plaintext).unwrap();
364 stream.get_mut().pop_outgoing().map(|data| {
365 TransmitBuild::new(
366 DelayedMessageOrChannelSend::Owned(data),
367 transmit.transport,
368 listen_address,
369 client.client_addr,
370 )
371 })
372 } else {
373 Some(transmit)
374 }
375 } else {
376 None
377 }
378 }
379
380 fn recv_icmp<T: AsRef<[u8]>>(
381 &mut self,
382 family: AddressFamily,
383 bytes: T,
384 now: Instant,
385 ) -> Option<Transmit<Vec<u8>>> {
386 let transmit = self.server.recv_icmp(family, bytes, now)?;
387 let listen_address = self.listen_address();
389 if transmit.from == listen_address {
390 let Some(client) = self
391 .clients
392 .iter_mut()
393 .find(|client| transmit.to == client.client_addr)
394 else {
395 return Some(transmit);
396 };
397 let stream = match client.tls.complete() {
398 Ok(s) => s,
399 Err(e) => {
400 if e.kind() != std::io::ErrorKind::WouldBlock {
401 warn!("error accepting TLS: {e}");
402 }
403 return None;
404 }
405 };
406 stream.write_all(&transmit.data).unwrap();
407 stream.get_mut().pop_outgoing().map(|data| {
408 Transmit::new(data, transmit.transport, listen_address, client.client_addr)
409 })
410 } else {
411 Some(transmit)
412 }
413 }
414
415 fn poll(&mut self, now: Instant) -> TurnServerPollRet {
419 let listen_address = self.listen_address();
420 let protocol_ret = self.server.poll(now);
421 let mut have_pending = false;
422 for (idx, client) in self.clients.iter_mut().enumerate() {
423 let stream = match client.tls.complete() {
424 Ok(s) => s,
425 Err(_) => continue,
426 };
427 client.shutdown = stream.get_shutdown();
428 if !stream.get_mut().outgoing.is_empty() {
429 have_pending = true;
430 continue;
431 }
432 if client
433 .shutdown
434 .contains(ShutdownState::SENT | ShutdownState::RECEIVED)
435 {
436 let client = self.clients.swap_remove(idx);
437 return TurnServerPollRet::TcpClose {
438 local_addr: listen_address,
439 remote_addr: client.client_addr,
440 };
441 }
442 }
443 if have_pending {
444 return TurnServerPollRet::WaitUntil(now);
445 }
446 if let TurnServerPollRet::TcpClose {
447 local_addr: _,
448 remote_addr,
449 } = protocol_ret
450 {
451 let Some(client) = self
452 .clients
453 .iter_mut()
454 .find(|client| client.client_addr == remote_addr)
455 else {
456 return protocol_ret;
457 };
458 if let Ok(stream) = client.tls.complete() {
459 if let Err(e) = stream.shutdown() {
460 warn!("Failed to shutdown ssl connection to {remote_addr}: {e:?}");
461 }
462 client.shutdown = stream.get_shutdown();
463 }
464 return TurnServerPollRet::WaitUntil(now);
465 }
466 protocol_ret
467 }
468
469 fn poll_transmit(&mut self, now: Instant) -> Option<Transmit<Vec<u8>>> {
471 let listen_address = self.listen_address();
472
473 for client in self.clients.iter_mut() {
474 if let Some(data) = client.tls.inner_mut().pop_outgoing() {
475 return Some(Transmit::new(
476 data,
477 client.transport,
478 listen_address,
479 client.client_addr,
480 ));
481 }
482 }
483
484 while let Some(transmit) = self.server.poll_transmit(now) {
485 let Some(client) = self
486 .clients
487 .iter_mut()
488 .find(|client| transmit.to == client.client_addr)
489 else {
490 warn!("return transmit: {transmit:?}");
491 return Some(transmit);
492 };
493 let stream = match client.tls.complete() {
494 Ok(s) => s,
495 Err(e) => {
497 warn!("early data -> ignored: {e:?}");
498 continue;
499 }
500 };
501 stream.write_all(&transmit.data).unwrap();
502
503 if let Some(data) = client.tls.inner_mut().pop_outgoing() {
504 return Some(Transmit::new(
505 data,
506 client.transport,
507 listen_address,
508 client.client_addr,
509 ));
510 }
511 }
512 None
513 }
514
515 fn allocated_socket(
518 &mut self,
519 transport: TransportType,
520 local_addr: SocketAddr,
521 remote_addr: SocketAddr,
522 allocation_transport: TransportType,
523 family: AddressFamily,
524 socket_addr: Result<SocketAddr, SocketAllocateError>,
525 now: Instant,
526 ) {
527 self.server.allocated_socket(
528 transport,
529 local_addr,
530 remote_addr,
531 allocation_transport,
532 family,
533 socket_addr,
534 now,
535 )
536 }
537
538 fn tcp_connected(
539 &mut self,
540 relayed_addr: SocketAddr,
541 peer_addr: SocketAddr,
542 listen_addr: SocketAddr,
543 client_addr: SocketAddr,
544 socket_addr: Result<SocketAddr, crate::api::TcpConnectError>,
545 now: Instant,
546 ) {
547 self.server.tcp_connected(
548 relayed_addr,
549 peer_addr,
550 listen_addr,
551 client_addr,
552 socket_addr,
553 now,
554 )
555 }
556}