1#![deny(missing_debug_implementations)]
30#![deny(missing_docs)]
31#![cfg_attr(docsrs, feature(doc_cfg))]
32#![deny(clippy::std_instead_of_core)]
33#![deny(clippy::std_instead_of_alloc)]
34#![no_std]
35
36extern crate alloc;
37
38#[cfg(any(feature = "std", test))]
39extern crate std;
40
41use alloc::collections::VecDeque;
42use alloc::string::String;
43use alloc::sync::Arc;
44use alloc::vec;
45use alloc::vec::Vec;
46use core::net::SocketAddr;
47use core::time::Duration;
48use turn_server_proto::types::prelude::DelayedTransmitBuild;
49use turn_server_proto::types::transmit::TransmitBuild;
50use turn_server_proto::types::AddressFamily;
51
52use turn_server_proto::api::Transmit;
53use turn_server_proto::types::Instant;
54use turn_server_proto::types::stun::TransportType;
55
56pub use turn_server_proto as proto;
57pub use turn_server_proto::api as api;
58
59use turn_server_proto::api::{
60 DelayedMessageOrChannelSend, SocketAllocateError, TurnServerApi, TurnServerPollRet,
61};
62use turn_server_proto::server::TurnServer;
63
64use tracing::{info, trace, warn};
65
/// A TURN server whose clients talk to it over DTLS.
///
/// Datagrams addressed to the listen address are decrypted per remote address with a
/// [`dimpl`] DTLS session, and the resulting plaintext is handed to the wrapped
/// [`TurnServer`]. Responses from the listen address to a DTLS client are encrypted
/// again before they are returned to the caller.
#[derive(Debug)]
pub struct DimplTurnServer {
    /// The sans-IO TURN protocol implementation that processes the decrypted traffic.
    server: TurnServer,
    /// DTLS configuration, cloned into every new client session.
    config: Arc<dimpl::Config>,
    /// DTLS certificate, cloned into every new client session.
    certificate: dimpl::DtlsCertificate,
    /// DTLS session state, one entry per remote client address.
    clients: Vec<Client>,
}
74
/// DTLS session state for a single remote address talking to [`DimplTurnServer`].
#[derive(Debug)]
struct Client {
    /// Remote address that this session's DTLS datagrams come from and are sent to.
    client_addr: SocketAddr,
    /// The DTLS state machine for this client (configured as the passive side via
    /// `set_active(false)` when the session is created).
    dtls: dimpl::Dtls,
    // NOTE(review): `std::time::Instant` needs the `std` crate, which this `no_std`
    // crate only links with `feature = "std"` or under `cfg(test)` — confirm the crate
    // is never built without `std`.
    /// `std` clock reading captured when the session was created. Together with
    /// `base_now`, it is used to translate between TURN `Instant`s and `std` instants.
    base_instant: std::time::Instant,
    /// TURN-clock `now` supplied when the session was created (paired with `base_instant`).
    base_now: Instant,
    /// Set once dimpl reports `Output::Connected`.
    connected: bool,
    /// Encrypted DTLS datagrams waiting to be sent to `client_addr`, oldest first.
    pending_encrypted: VecDeque<Vec<u8>>,
    /// Decrypted application data from the client, waiting to be fed to the TURN server.
    pending_incoming_plaintext: VecDeque<Vec<u8>>,
}
85
86impl Client {
87 fn poll(&mut self, now: Instant) -> Option<Instant> {
88 let _ = self.dtls.handle_timeout(
89 Instant::from_nanos((now - self.base_now).as_nanos() as i64).to_std(self.base_instant),
90 );
91 let mut out = [0; 2048];
92 let mut earliest_wait = None;
93 loop {
94 match self.dtls.poll_output(&mut out) {
95 dimpl::Output::Packet(p) => {
96 self.pending_encrypted.push_back(p.to_vec());
97 earliest_wait = Some(now);
98 }
99 dimpl::Output::Timeout(time) => {
100 let wait = Instant::from_nanos((time - self.base_instant).as_nanos() as i64);
101 if wait == now {
102 let _ = self.dtls.handle_timeout(time);
103 continue;
104 }
105 if earliest_wait.is_none_or(|earliest| earliest > wait) {
106 earliest_wait = Some(wait);
107 }
108 break;
109 }
110 dimpl::Output::Connected => self.connected = true,
111 dimpl::Output::PeerCert(_peer_cert) => (),
113 dimpl::Output::KeyingMaterial(_key, _srtp_profile) => (),
114 dimpl::Output::ApplicationData(app_data) => {
115 self.pending_incoming_plaintext.push_back(app_data.to_vec());
116 }
117 _ => (),
118 }
119 }
120 earliest_wait
121 }
122
123 fn poll_plaintext(&mut self) -> Option<Vec<u8>> {
124 self.pending_incoming_plaintext.pop_front()
125 }
126
127 fn poll_encrypted(&mut self) -> Option<Vec<u8>> {
128 self.pending_encrypted.pop_front()
129 }
130}
131
132impl DimplTurnServer {
133 pub fn new(
135 transport: TransportType,
136 listen_addr: SocketAddr,
137 realm: String,
138 config: Arc<dimpl::Config>,
139 certificate: dimpl::DtlsCertificate,
140 ) -> Self {
141 Self {
142 server: TurnServer::new(transport, listen_addr, realm),
143 config,
144 certificate,
145 clients: vec![],
146 }
147 }
148}
149
150impl TurnServerApi for DimplTurnServer {
151 fn add_user(&mut self, username: String, password: String) {
153 self.server.add_user(username, password)
154 }
155
156 fn listen_address(&self) -> SocketAddr {
158 self.server.listen_address()
159 }
160
161 fn set_nonce_expiry_duration(&mut self, expiry_duration: Duration) {
164 self.server.set_nonce_expiry_duration(expiry_duration)
165 }
166
167 #[tracing::instrument(
171 name = "turn_server_dimpl_recv",
172 skip(self, transmit, now),
173 fields(
174 from = ?transmit.from,
175 data_len = transmit.data.as_ref().len()
176 )
177 )]
178 fn recv<T: AsRef<[u8]> + core::fmt::Debug>(
179 &mut self,
180 transmit: Transmit<T>,
181 now: Instant,
182 ) -> Option<TransmitBuild<DelayedMessageOrChannelSend<T>>> {
183 let listen_address = self.listen_address();
184 if transmit.to == listen_address {
185 trace!("receiving TLS data: {:x?}", transmit.data.as_ref());
186 let client = match self
188 .clients
189 .iter_mut()
190 .find(|client| client.client_addr == transmit.from)
191 {
192 Some(client) => client,
193 None => {
194 let len = self.clients.len();
195 let base_instant = std::time::Instant::now();
196 let mut dtls = dimpl::Dtls::new_auto(
197 self.config.clone(),
198 self.certificate.clone(),
199 base_instant,
200 );
201 dtls.set_active(false);
202 let mut client = Client {
203 client_addr: transmit.from,
204 dtls,
205 base_instant,
206 base_now: now,
207 connected: false,
208 pending_encrypted: VecDeque::default(),
209 pending_incoming_plaintext: VecDeque::default(),
210 };
211 client.poll(now);
213 self.clients.push(client);
214 info!(
215 "new connection from {} {}",
216 transmit.transport, transmit.from
217 );
218 &mut self.clients[len]
219 }
220 };
221 match client.dtls.handle_packet(transmit.data.as_ref()) {
222 Ok(_) => (),
223 Err(e) => {
224 warn!("error accepting TLS: {e}");
225 return None;
226 }
227 };
228
229 client.poll(now);
230 while let Some(plaintext) = client.poll_plaintext() {
231 let Some(transmit) = self.server.recv(
232 Transmit::new(plaintext, transmit.transport, transmit.from, transmit.to),
233 now,
234 ) else {
235 continue;
236 };
237
238 if transmit.from == listen_address && transmit.to == client.client_addr {
239 client
240 .dtls
241 .send_application_data(&transmit.data.build())
242 .unwrap();
243 client.poll(now);
244 let Some(data) = client.poll_encrypted() else {
245 continue;
246 };
247 return Some(TransmitBuild::new(
248 DelayedMessageOrChannelSend::Owned(data),
249 transmit.transport,
250 listen_address,
251 client.client_addr,
252 ));
253 } else {
254 let transmit = transmit.build();
255 return Some(TransmitBuild::new(
256 DelayedMessageOrChannelSend::Owned(transmit.data),
257 transmit.transport,
258 transmit.from,
259 transmit.to,
260 ));
261 }
262 }
263 None
264 } else if let Some(transmit) = self.server.recv(transmit, now) {
265 if transmit.from == listen_address {
267 let Some(client) = self
268 .clients
269 .iter_mut()
270 .find(|client| transmit.to == client.client_addr)
271 else {
272 return Some(transmit);
273 };
274
275 let _ = client.dtls.send_application_data(&transmit.data.build());
276 client.poll(now);
277 client.poll_encrypted().map(|encrypted| {
278 TransmitBuild::new(
279 DelayedMessageOrChannelSend::Owned(encrypted),
280 transmit.transport,
281 listen_address,
282 client.client_addr,
283 )
284 })
285 } else {
286 Some(transmit)
287 }
288 } else {
289 None
290 }
291 }
292
293 fn recv_icmp<T: AsRef<[u8]>>(
294 &mut self,
295 family: AddressFamily,
296 bytes: T,
297 now: Instant,
298 ) -> Option<Transmit<Vec<u8>>> {
299 let transmit = self.server.recv_icmp(family, bytes, now)?;
300 let listen_address = self.listen_address();
302 if transmit.from == listen_address {
303 let Some(client) = self
304 .clients
305 .iter_mut()
306 .find(|client| transmit.to == client.client_addr)
307 else {
308 return Some(transmit);
309 };
310
311 client.dtls.send_application_data(&transmit.data).unwrap();
312 client.poll(now);
313 client.poll_encrypted().map(|encrypted| {
314 Transmit::new(
315 encrypted,
316 transmit.transport,
317 listen_address,
318 client.client_addr,
319 )
320 })
321 } else {
322 Some(transmit)
323 }
324 }
325
326 fn poll(&mut self, now: Instant) -> TurnServerPollRet {
330 let protocol_ret = self.server.poll(now);
331 let mut have_pending = false;
332 for client in self.clients.iter_mut() {
333 client.poll(now);
334 if !client.pending_encrypted.is_empty() {
335 have_pending = true;
336 continue;
337 }
338 }
339 if have_pending {
340 return TurnServerPollRet::WaitUntil(now);
341 }
342 protocol_ret
343 }
344
345 fn poll_transmit(&mut self, now: Instant) -> Option<Transmit<Vec<u8>>> {
347 let listen_address = self.listen_address();
348
349 for client in self.clients.iter_mut() {
350 if let Some(data) = client.poll_encrypted() {
351 return Some(Transmit::new(
352 data,
353 TransportType::Udp,
354 listen_address,
355 client.client_addr,
356 ));
357 }
358 }
359
360 while let Some(transmit) = self.server.poll_transmit(now) {
361 let Some(client) = self
362 .clients
363 .iter_mut()
364 .find(|client| transmit.to == client.client_addr)
365 else {
366 warn!("return transmit: {transmit:?}");
367 return Some(transmit);
368 };
369 client.dtls.send_application_data(&transmit.data).unwrap();
370 client.poll(now);
371
372 if let Some(data) = client.poll_encrypted() {
373 return Some(Transmit::new(
374 data,
375 TransportType::Udp,
376 listen_address,
377 client.client_addr,
378 ));
379 }
380 }
381 None
382 }
383
384 fn allocated_socket(
387 &mut self,
388 transport: TransportType,
389 local_addr: SocketAddr,
390 remote_addr: SocketAddr,
391 allocation_transport: TransportType,
392 family: AddressFamily,
393 socket_addr: Result<SocketAddr, SocketAllocateError>,
394 now: Instant,
395 ) {
396 self.server.allocated_socket(
397 transport,
398 local_addr,
399 remote_addr,
400 allocation_transport,
401 family,
402 socket_addr,
403 now,
404 )
405 }
406
407 fn tcp_connected(
408 &mut self,
409 relayed_addr: SocketAddr,
410 peer_addr: SocketAddr,
411 listen_addr: SocketAddr,
412 client_addr: SocketAddr,
413 socket_addr: Result<SocketAddr, api::TcpConnectError>,
414 now: Instant,
415 ) {
416 self.server.tcp_connected(
417 relayed_addr,
418 peer_addr,
419 listen_addr,
420 client_addr,
421 socket_addr,
422 now,
423 )
424 }
425}
426
427
#[cfg(test)]
mod tests {
    use tracing::subscriber::DefaultGuard;
    use tracing_subscriber::layer::SubscriberExt;
    use tracing_subscriber::Layer;

    use super::*;

    /// Installs a tracing subscriber that writes to the test output until the returned
    /// guard is dropped.
    ///
    /// The filter comes from `TURN_LOG`, or from `RUST_LOG` when `TURN_LOG` is unset.
    /// If the chosen variable is missing or does not parse, everything down to `TRACE`
    /// is logged.
    fn test_init_log() -> DefaultGuard {
        crate::proto::types::debug_init();
        let env_filter = std::env::var("TURN_LOG")
            .or_else(|_| std::env::var("RUST_LOG"))
            .ok()
            .and_then(|spec| spec.parse::<tracing_subscriber::filter::Targets>().ok());
        let level_filter = env_filter.unwrap_or_else(|| {
            tracing_subscriber::filter::Targets::new().with_default(tracing::Level::TRACE)
        });
        let fmt_layer = tracing_subscriber::fmt::layer()
            .with_file(true)
            .with_line_number(true)
            .with_level(true)
            .with_target(false)
            .with_test_writer()
            .with_filter(level_filter);
        tracing::subscriber::set_default(tracing_subscriber::registry().with(fmt_layer))
    }

    /// Produces a fresh self-signed DTLS certificate for the tests.
    fn generate_cert() -> dimpl::DtlsCertificate {
        dimpl::certificate::generate_self_signed_certificate().unwrap()
    }

    #[test]
    fn constructor() {
        let _log = test_init_log();
        let listen_addr: SocketAddr = "127.0.0.1:3478".parse().unwrap();
        let server = DimplTurnServer::new(
            TransportType::Udp,
            listen_addr,
            String::from("realm"),
            Arc::new(dimpl::Config::builder().build().unwrap()),
            generate_cert(),
        );
        assert_eq!(server.listen_address(), listen_addr);
    }
}
471}