// rift_nat/turn.rs

//! TURN (Traversal Using Relays around NAT) client implementation.
//!
//! This module provides a minimal TURN client for UDP relays. It supports:
//! - Allocation requests with long-term credentials
//! - Permission and channel binding management
//! - Sending and receiving relayed data (ChannelData frames and Send/Data indications)
//!
//! The implementation favors simplicity and correctness over completeness.

use std::collections::{HashMap, HashSet};
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr};
use std::sync::Arc;
use std::time::Duration;

use tokio::net::UdpSocket;
use tokio::sync::Mutex;
use tokio::time::timeout;

use crc32fast::Hasher as Crc32;
use hmac::{Hmac, Mac};
use rand::RngCore;
use sha1::Sha1;

use crate::StunError;

26type HmacSha1 = Hmac<Sha1>;
27
// ---- STUN framing ----------------------------------------------------------------------------

/// Fixed magic cookie carried in bytes 4..8 of every STUN message.
const STUN_MAGIC_COOKIE: u32 = 0x2112A442;
/// Fixed STUN header size: type (2) + length (2) + cookie (4) + transaction id (12).
const STUN_HEADER_LEN: usize = 20;

// ---- TURN message types (method combined with class bits) -----------------------------------

const TURN_ALLOCATE_REQUEST: u16 = 0x0003;
const TURN_ALLOCATE_RESPONSE: u16 = 0x0103;
const TURN_ALLOCATE_ERROR: u16 = 0x0113;
const TURN_CREATE_PERMISSION_REQUEST: u16 = 0x0008;
const TURN_CREATE_PERMISSION_RESPONSE: u16 = 0x0108;
const TURN_CHANNEL_BIND_REQUEST: u16 = 0x0009;
const TURN_CHANNEL_BIND_RESPONSE: u16 = 0x0109;
const TURN_SEND_INDICATION: u16 = 0x0016;
const TURN_DATA_INDICATION: u16 = 0x0017;

// ---- STUN/TURN attribute types, in code-point order ------------------------------------------

const ATTR_USERNAME: u16 = 0x0006;
const ATTR_MESSAGE_INTEGRITY: u16 = 0x0008;
// Defined for completeness; error responses are not decoded by code (allow keeps builds quiet).
#[allow(dead_code)]
const ATTR_ERROR_CODE: u16 = 0x0009;
const ATTR_CHANNEL_NUMBER: u16 = 0x000C;
// Defined for completeness; may be unused depending on which requests are built.
#[allow(dead_code)]
const ATTR_LIFETIME: u16 = 0x000D;
const ATTR_XOR_PEER_ADDRESS: u16 = 0x0012;
const ATTR_DATA: u16 = 0x0013;
const ATTR_REALM: u16 = 0x0014;
const ATTR_NONCE: u16 = 0x0015;
const ATTR_XOR_RELAYED_ADDRESS: u16 = 0x0016;
const ATTR_REQUESTED_TRANSPORT: u16 = 0x0019;
const ATTR_FINGERPRINT: u16 = 0x8028;

// ---- Misc ------------------------------------------------------------------------------------

/// IANA protocol number for UDP, carried in REQUESTED-TRANSPORT.
const TURN_UDP_TRANSPORT: u8 = 17;
/// Port assumed when a TURN server URI does not specify one.
const TURN_DEFAULT_PORT: u16 = 3478;

/// Connection parameters for a single TURN server.
///
/// `Debug` is implemented by hand so the credential never ends up in logs; it only reports
/// whether a credential is present.
#[derive(Clone)]
pub struct TurnServerConfig {
    /// TURN server socket address.
    pub addr: SocketAddr,
    /// Optional long-term auth username.
    pub username: Option<String>,
    /// Optional long-term auth credential (password).
    pub credential: Option<String>,
}

impl std::fmt::Debug for TurnServerConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("TurnServerConfig")
            .field("addr", &self.addr)
            .field("username", &self.username)
            // Reveal presence only, never the secret itself.
            .field("credential", &self.credential.as_ref().map(|_| "<redacted>"))
            .finish()
    }
}

69#[derive(Debug)]
70pub struct TurnRelay {
71    /// UDP socket used for TURN control/data.
72    socket: Arc<UdpSocket>,
73    /// TURN server address.
74    server: SocketAddr,
75    /// Allocated relay address provided by the TURN server.
76    relay_addr: SocketAddr,
77    /// TURN realm provided by server during auth challenge.
78    realm: Option<String>,
79    /// TURN nonce provided by server during auth challenge.
80    nonce: Option<String>,
81    /// Username for long-term auth.
82    username: Option<String>,
83    /// Credential for long-term auth.
84    credential: Option<String>,
85    /// Cached channel bindings (peer addr -> channel number).
86    channels: Mutex<HashMap<SocketAddr, u16>>,
87    /// Cached permissions for peer addresses.
88    permissions: Mutex<HashSet<SocketAddr>>,
89}
90
91#[derive(Debug, thiserror::Error)]
92pub enum TurnError {
93    /// No TURN servers configured.
94    #[error("no turn servers configured")]
95    NoServers,
96    /// TURN credentials were required but missing.
97    #[error("turn server missing credentials")]
98    MissingCredentials,
99    /// Allocation attempt failed after retries.
100    #[error("turn allocation failed")]
101    AllocationFailed,
102    /// Response from TURN server was invalid.
103    #[error("turn response invalid")]
104    InvalidResponse,
105    /// Authentication failed.
106    #[error("turn auth failed")]
107    AuthFailed,
108    /// Underlying socket I/O error.
109    #[error("io error: {0}")]
110    Io(#[from] std::io::Error),
111    /// STUN parsing errors.
112    #[error("stun error: {0}")]
113    Stun(#[from] StunError),
114}
115
116#[derive(Debug, Clone)]
117pub struct TurnCandidate {
118    /// Relay address on the TURN server.
119    pub relay_addr: SocketAddr,
120    /// TURN server address.
121    pub server: SocketAddr,
122    /// Shared relay handle.
123    pub relay: Arc<TurnRelay>,
124}
125
126/// Periodically send empty datagrams to keep the TURN allocation alive.
127pub fn spawn_turn_keepalive(relay: Arc<TurnRelay>, interval_ms: u64) -> tokio::task::JoinHandle<()> {
128    tokio::spawn(async move {
129        let mut tick = tokio::time::interval(Duration::from_millis(interval_ms.max(1000)));
130        loop {
131            tick.tick().await;
132            let _ = relay.send_to(relay.relay_addr(), b"").await;
133        }
134    })
135}
136
137/// Parse a TURN server URI into a config struct.
138pub fn parse_turn_server(uri: &str) -> Result<TurnServerConfig, TurnError> {
139    let trimmed = uri.trim();
140    let trimmed = trimmed.strip_prefix("turn:").unwrap_or(trimmed);
141    let (host_port, query) = match trimmed.split_once('?') {
142        Some((base, q)) => (base, Some(q)),
143        None => (trimmed, None),
144    };
145    let (host, port) = match host_port.rsplit_once(':') {
146        Some((h, p)) => (h, p.parse::<u16>().unwrap_or(TURN_DEFAULT_PORT)),
147        None => (host_port, TURN_DEFAULT_PORT),
148    };
149    let addr = format!("{}:{}", host, port)
150        .parse::<SocketAddr>()
151        .map_err(|_| TurnError::InvalidResponse)?;
152    let mut username = None;
153    let mut credential = None;
154    if let Some(query) = query {
155        for pair in query.split('&') {
156            if pair.is_empty() {
157                continue;
158            }
159            if let Some((k, v)) = pair.split_once('=') {
160                if k == "username" {
161                    username = Some(v.to_string());
162                } else if k == "credential" || k == "password" {
163                    credential = Some(v.to_string());
164                }
165            }
166        }
167    }
168    Ok(TurnServerConfig {
169        addr,
170        username,
171        credential,
172    })
173}
174
175pub async fn allocate_turn_relay(
176    server: TurnServerConfig,
177    timeout_ms: u64,
178) -> Result<TurnCandidate, TurnError> {
179    // Allocate a TURN relay:
180    // 1) Send allocate request
181    // 2) Handle auth challenge (nonce/realm)
182    // 3) Extract relayed address on success
183    let socket = UdpSocket::bind((Ipv4Addr::UNSPECIFIED, 0)).await?;
184    let socket = Arc::new(socket);
185
186    let mut nonce = None;
187    let mut realm = None;
188    let mut relay_addr = None;
189
190    for attempt in 0..=1 {
191        let tx_id = random_tx_id();
192        let mut msg = build_allocate_request(&tx_id, server.username.as_deref(), &nonce, &realm)?;
193        if let (Some(username), Some(password), Some(realm), Some(_nonce)) = (
194            server.username.as_deref(),
195            server.credential.as_deref(),
196            realm.as_deref(),
197            nonce.as_deref(),
198        ) {
199            add_message_integrity(&mut msg, username, realm, password);
200            add_fingerprint(&mut msg);
201        }
202
203        socket.send_to(&msg, server.addr).await?;
204
205        let mut buf = [0u8; 1500];
206        let (len, _) = timeout(Duration::from_millis(timeout_ms), socket.recv_from(&mut buf))
207            .await
208            .map_err(|_| TurnError::AllocationFailed)??;
209        let response = parse_turn_response(&buf[..len], &tx_id)?;
210
211        match response.kind {
212            TurnResponseKind::Success { relayed } => {
213                relay_addr = Some(relayed);
214                break;
215            }
216            TurnResponseKind::AuthChallenge { nonce: new_nonce, realm: new_realm } => {
217                if attempt == 0 {
218                    nonce = Some(new_nonce);
219                    realm = Some(new_realm);
220                    continue;
221                }
222                return Err(TurnError::AuthFailed);
223            }
224            TurnResponseKind::Error => return Err(TurnError::AllocationFailed),
225        }
226    }
227
228    let relay_addr = relay_addr.ok_or(TurnError::AllocationFailed)?;
229    let relay = Arc::new(TurnRelay {
230        socket: socket.clone(),
231        server: server.addr,
232        relay_addr,
233        realm,
234        nonce,
235        username: server.username.clone(),
236        credential: server.credential.clone(),
237        channels: Mutex::new(HashMap::new()),
238        permissions: Mutex::new(HashSet::new()),
239    });
240
241    Ok(TurnCandidate {
242        relay_addr,
243        server: server.addr,
244        relay,
245    })
246}
247
248impl TurnRelay {
249    /// Return the allocated relay address.
250    pub fn relay_addr(&self) -> SocketAddr {
251        self.relay_addr
252    }
253
254    /// Send a payload to a peer via TURN (channel data preferred, send indication fallback).
255    pub async fn send_to(&self, peer: SocketAddr, data: &[u8]) -> Result<(), TurnError> {
256        self.ensure_permission(peer).await?;
257        let channel = self.ensure_channel(peer).await.ok();
258        if let Some(channel) = channel {
259            let mut buf = Vec::with_capacity(4 + data.len());
260            buf.extend_from_slice(&channel.to_be_bytes());
261            buf.extend_from_slice(&(data.len() as u16).to_be_bytes());
262            buf.extend_from_slice(data);
263            self.socket.send_to(&buf, self.server).await?;
264            return Ok(());
265        }
266
267        let tx_id = random_tx_id();
268        let mut msg = build_send_indication(&tx_id, peer, data);
269        if self.should_auth() {
270            self.add_auth(&mut msg)?;
271        }
272        self.socket.send_to(&msg, self.server).await?;
273        Ok(())
274    }
275
276    /// Receive the next payload from the TURN relay.
277    pub async fn recv_from(&self, buf: &mut [u8]) -> Result<(usize, SocketAddr), TurnError> {
278        loop {
279            let (len, _addr) = self.socket.recv_from(buf).await?;
280            if len >= 4 && is_channel_data(buf) {
281                let channel = u16::from_be_bytes([buf[0], buf[1]]);
282                let data_len = u16::from_be_bytes([buf[2], buf[3]]) as usize;
283                if len < 4 + data_len {
284                    continue;
285                }
286                let peer = {
287                    let channels = self.channels.lock().await;
288                    channels.iter().find(|(_, c)| **c == channel).map(|(peer, _)| *peer)
289                };
290                if let Some(peer) = peer {
291                    buf.copy_within(4..4 + data_len, 0);
292                    return Ok((data_len, peer));
293                }
294            }
295
296            if let Ok((peer, payload)) = parse_data_indication(&buf[..len]) {
297                let payload_len = payload.len();
298                if payload_len > buf.len() {
299                    continue;
300                }
301                let mut tmp = vec![0u8; payload_len];
302                tmp.copy_from_slice(payload);
303                buf[..payload_len].copy_from_slice(&tmp);
304                return Ok((payload_len, peer));
305            }
306        }
307    }
308
309    async fn ensure_permission(&self, peer: SocketAddr) -> Result<(), TurnError> {
310        // TURN requires explicit permission before relaying to a peer.
311        {
312            let perms = self.permissions.lock().await;
313            if perms.contains(&peer) {
314                return Ok(());
315            }
316        }
317        let tx_id = random_tx_id();
318        let mut msg = build_create_permission(&tx_id, peer);
319        if self.should_auth() {
320            self.add_auth(&mut msg)?;
321        }
322        self.socket.send_to(&msg, self.server).await?;
323        let mut perms = self.permissions.lock().await;
324        perms.insert(peer);
325        Ok(())
326    }
327
328    async fn ensure_channel(&self, peer: SocketAddr) -> Result<u16, TurnError> {
329        // Channel bindings provide efficient data framing for frequent peers.
330        if let Some(channel) = { self.channels.lock().await.get(&peer).copied() } {
331            return Ok(channel);
332        }
333        let channel = allocate_channel_number(&self.channels).await;
334        let tx_id = random_tx_id();
335        let mut msg = build_channel_bind(&tx_id, peer, channel);
336        if self.should_auth() {
337            self.add_auth(&mut msg)?;
338        }
339        self.socket.send_to(&msg, self.server).await?;
340        let mut channels = self.channels.lock().await;
341        channels.insert(peer, channel);
342        Ok(channel)
343    }
344
345    fn should_auth(&self) -> bool {
346        // Auth is required once we have all parameters for long-term credentials.
347        self.username.is_some() && self.credential.is_some() && self.realm.is_some() && self.nonce.is_some()
348    }
349
350    fn add_auth(&self, msg: &mut Vec<u8>) -> Result<(), TurnError> {
351        // Add MESSAGE-INTEGRITY and FINGERPRINT to a TURN message.
352        let username = self.username.as_ref().ok_or(TurnError::MissingCredentials)?;
353        let credential = self.credential.as_ref().ok_or(TurnError::MissingCredentials)?;
354        let realm = self.realm.as_ref().ok_or(TurnError::MissingCredentials)?;
355        let _nonce = self.nonce.as_ref().ok_or(TurnError::MissingCredentials)?;
356        add_message_integrity(msg, username, realm, credential);
357        add_fingerprint(msg);
358        Ok(())
359    }
360}
361
362#[derive(Debug)]
363struct TurnResponse {
364    kind: TurnResponseKind,
365}
366
/// Classification of a TURN response.
#[derive(Debug)]
enum TurnResponseKind {
    /// Success response; `relayed` is the allocation's relayed address when one applies.
    Success { relayed: SocketAddr },
    /// Error response carrying the NONCE and REALM needed to retry with credentials.
    AuthChallenge { nonce: String, realm: String },
    /// Any other error or unexpected message.
    Error,
}

374fn build_allocate_request(
375    tx_id: &[u8; 12],
376    username: Option<&str>,
377    nonce: &Option<String>,
378    realm: &Option<String>,
379) -> Result<Vec<u8>, TurnError> {
380    // Build TURN allocate request with optional auth parameters.
381    let mut msg = build_stun_header(TURN_ALLOCATE_REQUEST, tx_id);
382    add_attr_u32(&mut msg, ATTR_REQUESTED_TRANSPORT, (TURN_UDP_TRANSPORT as u32) << 24);
383    if let Some(username) = username {
384        add_attr_bytes(&mut msg, ATTR_USERNAME, username.as_bytes());
385    }
386    if let Some(realm) = realm.as_ref() {
387        add_attr_bytes(&mut msg, ATTR_REALM, realm.as_bytes());
388    }
389    if let Some(nonce) = nonce.as_ref() {
390        add_attr_bytes(&mut msg, ATTR_NONCE, nonce.as_bytes());
391    }
392    finalize_length(&mut msg);
393    Ok(msg)
394}
395
396fn build_create_permission(tx_id: &[u8; 12], peer: SocketAddr) -> Vec<u8> {
397    // Build TURN create-permission request for a peer.
398    let mut msg = build_stun_header(TURN_CREATE_PERMISSION_REQUEST, tx_id);
399    add_attr_bytes(&mut msg, ATTR_XOR_PEER_ADDRESS, &encode_xor_addr(peer, tx_id));
400    finalize_length(&mut msg);
401    msg
402}
403
404fn build_channel_bind(tx_id: &[u8; 12], peer: SocketAddr, channel: u16) -> Vec<u8> {
405    // Build TURN channel bind request for a peer + channel.
406    let mut msg = build_stun_header(TURN_CHANNEL_BIND_REQUEST, tx_id);
407    add_attr_u32(&mut msg, ATTR_CHANNEL_NUMBER, (channel as u32) << 16);
408    add_attr_bytes(&mut msg, ATTR_XOR_PEER_ADDRESS, &encode_xor_addr(peer, tx_id));
409    finalize_length(&mut msg);
410    msg
411}
412
413fn build_send_indication(tx_id: &[u8; 12], peer: SocketAddr, data: &[u8]) -> Vec<u8> {
414    // Build TURN send indication (no channel binding required).
415    let mut msg = build_stun_header(TURN_SEND_INDICATION, tx_id);
416    add_attr_bytes(&mut msg, ATTR_XOR_PEER_ADDRESS, &encode_xor_addr(peer, tx_id));
417    add_attr_bytes(&mut msg, ATTR_DATA, data);
418    finalize_length(&mut msg);
419    msg
420}
421
422fn parse_turn_response(buf: &[u8], tx_id: &[u8; 12]) -> Result<TurnResponse, TurnError> {
423    // Parse STUN/TURN response and extract success or auth challenge.
424    if buf.len() < STUN_HEADER_LEN {
425        return Err(TurnError::InvalidResponse);
426    }
427    let msg_type = u16::from_be_bytes([buf[0], buf[1]]);
428    let msg_len = u16::from_be_bytes([buf[2], buf[3]]) as usize;
429    let cookie = u32::from_be_bytes([buf[4], buf[5], buf[6], buf[7]]);
430    if cookie != STUN_MAGIC_COOKIE || &buf[8..20] != tx_id {
431        return Err(TurnError::InvalidResponse);
432    }
433    let end = STUN_HEADER_LEN + msg_len.min(buf.len().saturating_sub(STUN_HEADER_LEN));
434    let mut offset = STUN_HEADER_LEN;
435    let mut nonce = None;
436    let mut realm = None;
437    let mut relayed = None;
438    while offset + 4 <= end {
439        let attr_type = u16::from_be_bytes([buf[offset], buf[offset + 1]]);
440        let attr_len = u16::from_be_bytes([buf[offset + 2], buf[offset + 3]]) as usize;
441        let value_start = offset + 4;
442        let value_end = value_start + attr_len;
443        if value_end > buf.len() {
444            break;
445        }
446        match attr_type {
447            ATTR_NONCE => nonce = Some(String::from_utf8_lossy(&buf[value_start..value_end]).to_string()),
448            ATTR_REALM => realm = Some(String::from_utf8_lossy(&buf[value_start..value_end]).to_string()),
449            ATTR_XOR_RELAYED_ADDRESS => {
450                if let Ok(addr) = decode_xor_addr(&buf[value_start..value_end], tx_id) {
451                    relayed = Some(addr);
452                }
453            }
454            _ => {}
455        }
456        offset = value_start + ((attr_len + 3) & !3);
457    }
458
459    match msg_type {
460        TURN_ALLOCATE_RESPONSE | TURN_CREATE_PERMISSION_RESPONSE | TURN_CHANNEL_BIND_RESPONSE => {
461            let relayed = relayed.unwrap_or_else(|| SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), 0));
462            Ok(TurnResponse { kind: TurnResponseKind::Success { relayed } })
463        }
464        TURN_ALLOCATE_ERROR => {
465            if let (Some(nonce), Some(realm)) = (nonce, realm) {
466                Ok(TurnResponse { kind: TurnResponseKind::AuthChallenge { nonce, realm } })
467            } else {
468                Ok(TurnResponse { kind: TurnResponseKind::Error })
469            }
470        }
471        _ => Ok(TurnResponse { kind: TurnResponseKind::Error }),
472    }
473}
474
475fn parse_data_indication(buf: &[u8]) -> Result<(SocketAddr, &[u8]), TurnError> {
476    if buf.len() < STUN_HEADER_LEN {
477        return Err(TurnError::InvalidResponse);
478    }
479    let msg_type = u16::from_be_bytes([buf[0], buf[1]]);
480    if msg_type != TURN_DATA_INDICATION {
481        return Err(TurnError::InvalidResponse);
482    }
483    let tx_id: [u8; 12] = buf[8..20].try_into().map_err(|_| TurnError::InvalidResponse)?;
484    let msg_len = u16::from_be_bytes([buf[2], buf[3]]) as usize;
485    let end = STUN_HEADER_LEN + msg_len.min(buf.len().saturating_sub(STUN_HEADER_LEN));
486    let mut offset = STUN_HEADER_LEN;
487    let mut peer = None;
488    let mut data = None;
489    while offset + 4 <= end {
490        let attr_type = u16::from_be_bytes([buf[offset], buf[offset + 1]]);
491        let attr_len = u16::from_be_bytes([buf[offset + 2], buf[offset + 3]]) as usize;
492        let value_start = offset + 4;
493        let value_end = value_start + attr_len;
494        if value_end > buf.len() {
495            break;
496        }
497        match attr_type {
498            ATTR_XOR_PEER_ADDRESS => {
499                peer = decode_xor_addr(&buf[value_start..value_end], &tx_id).ok();
500            }
501            ATTR_DATA => {
502                data = Some(&buf[value_start..value_end]);
503            }
504            _ => {}
505        }
506        offset = value_start + ((attr_len + 3) & !3);
507    }
508    let peer = peer.ok_or(TurnError::InvalidResponse)?;
509    let data = data.ok_or(TurnError::InvalidResponse)?;
510    Ok((peer, data))
511}
512
513fn build_stun_header(msg_type: u16, tx_id: &[u8; 12]) -> Vec<u8> {
514    let mut out = Vec::with_capacity(128);
515    out.extend_from_slice(&msg_type.to_be_bytes());
516    out.extend_from_slice(&0u16.to_be_bytes());
517    out.extend_from_slice(&STUN_MAGIC_COOKIE.to_be_bytes());
518    out.extend_from_slice(tx_id);
519    out
520}
521
/// Append an attribute whose value is a single big-endian `u32` (length field = 4).
fn add_attr_u32(buf: &mut Vec<u8>, attr: u16, value: u32) {
    let type_bytes = attr.to_be_bytes();
    let len_bytes = 4u16.to_be_bytes();
    let value_bytes = value.to_be_bytes();
    buf.extend(
        type_bytes
            .iter()
            .chain(len_bytes.iter())
            .chain(value_bytes.iter()),
    );
}

/// Append a TLV attribute with a byte value, zero-padded to a 4-byte boundary.
///
/// The length field records the unpadded value length.
/// NOTE(review): values longer than 65535 bytes have their length silently truncated.
fn add_attr_bytes(buf: &mut Vec<u8>, attr: u16, value: &[u8]) {
    let declared_len = value.len() as u16;
    buf.extend_from_slice(&attr.to_be_bytes());
    buf.extend_from_slice(&declared_len.to_be_bytes());
    buf.extend_from_slice(value);
    let padding = (4 - value.len() % 4) % 4;
    let padded_len = buf.len() + padding;
    buf.resize(padded_len, 0);
}

538fn finalize_length(buf: &mut Vec<u8>) {
539    let len = buf.len().saturating_sub(STUN_HEADER_LEN) as u16;
540    buf[2..4].copy_from_slice(&len.to_be_bytes());
541}
542
543fn add_message_integrity(buf: &mut Vec<u8>, username: &str, realm: &str, password: &str) {
544    finalize_length(buf);
545    let key = format!("{}:{}:{}", username, realm, password);
546    let mut mac = HmacSha1::new_from_slice(key.as_bytes()).expect("hmac key");
547    mac.update(buf);
548    let result = mac.finalize().into_bytes();
549    add_attr_bytes(buf, ATTR_MESSAGE_INTEGRITY, &result);
550    finalize_length(buf);
551}
552
553fn add_fingerprint(buf: &mut Vec<u8>) {
554    finalize_length(buf);
555    let mut hasher = Crc32::new();
556    hasher.update(buf);
557    let crc = hasher.finalize() ^ 0x5354_554e;
558    add_attr_u32(buf, ATTR_FINGERPRINT, crc);
559    finalize_length(buf);
560}
561
562fn encode_xor_addr(addr: SocketAddr, tx_id: &[u8; 12]) -> Vec<u8> {
563    match addr {
564        SocketAddr::V4(addr) => {
565            let port = addr.port() ^ ((STUN_MAGIC_COOKIE >> 16) as u16);
566            let ip = u32::from(*addr.ip()) ^ STUN_MAGIC_COOKIE;
567            let mut out = Vec::with_capacity(8);
568            out.push(0);
569            out.push(0x01);
570            out.extend_from_slice(&port.to_be_bytes());
571            out.extend_from_slice(&ip.to_be_bytes());
572            out
573        }
574        SocketAddr::V6(addr) => {
575            let port = addr.port() ^ ((STUN_MAGIC_COOKIE >> 16) as u16);
576            let mut ip = addr.ip().octets();
577            let cookie = STUN_MAGIC_COOKIE.to_be_bytes();
578            for i in 0..4 {
579                ip[i] ^= cookie[i];
580            }
581            for i in 0..12 {
582                ip[4 + i] ^= tx_id[i];
583            }
584            let mut out = Vec::with_capacity(20);
585            out.push(0);
586            out.push(0x02);
587            out.extend_from_slice(&port.to_be_bytes());
588            out.extend_from_slice(&ip);
589            out
590        }
591    }
592}
593
594fn decode_xor_addr(buf: &[u8], tx_id: &[u8; 12]) -> Result<SocketAddr, TurnError> {
595    if buf.len() < 4 {
596        return Err(TurnError::InvalidResponse);
597    }
598    let family = buf[1];
599    let port = u16::from_be_bytes([buf[2], buf[3]]) ^ ((STUN_MAGIC_COOKIE >> 16) as u16);
600    match family {
601        0x01 => {
602            if buf.len() < 8 {
603                return Err(TurnError::InvalidResponse);
604            }
605            let mut ip = [0u8; 4];
606            ip.copy_from_slice(&buf[4..8]);
607            let cookie = STUN_MAGIC_COOKIE.to_be_bytes();
608            for i in 0..4 {
609                ip[i] ^= cookie[i];
610            }
611            Ok(SocketAddr::new(IpAddr::V4(ip.into()), port))
612        }
613        0x02 => {
614            if buf.len() < 20 {
615                return Err(TurnError::InvalidResponse);
616            }
617            let mut ip = [0u8; 16];
618            ip.copy_from_slice(&buf[4..20]);
619            let mut xor = [0u8; 16];
620            xor[..4].copy_from_slice(&STUN_MAGIC_COOKIE.to_be_bytes());
621            xor[4..].copy_from_slice(tx_id);
622            for i in 0..16 {
623                ip[i] ^= xor[i];
624            }
625            Ok(SocketAddr::new(IpAddr::V6(ip.into()), port))
626        }
627        _ => Err(TurnError::InvalidResponse),
628    }
629}
630
631fn random_tx_id() -> [u8; 12] {
632    let mut tx_id = [0u8; 12];
633    rand::rngs::OsRng.fill_bytes(&mut tx_id);
634    tx_id
635}
636
/// Whether `buf` looks like a ChannelData frame.
///
/// Requires at least the 4-byte header and a channel number in 0x4000..=0x7FFF, which is
/// exactly a first byte in 0x40..=0x7F.
fn is_channel_data(buf: &[u8]) -> bool {
    matches!(buf, [0x40..=0x7F, _, _, _, ..])
}

645async fn allocate_channel_number(channels: &Mutex<HashMap<SocketAddr, u16>>) -> u16 {
646    let mut num = 0x4000u16;
647    let existing = channels.lock().await.values().copied().collect::<HashSet<_>>();
648    while existing.contains(&num) {
649        num = num.wrapping_add(1);
650        if num < 0x4000 {
651            num = 0x4000;
652        }
653    }
654    num
655}
656
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn parse_turn_uri_defaults() {
        let cfg = parse_turn_server("turn:127.0.0.1").unwrap();
        assert_eq!(cfg.addr.port(), TURN_DEFAULT_PORT);
        assert!(cfg.username.is_none());
        assert!(cfg.credential.is_none());
    }

    #[test]
    fn parse_turn_uri_with_credentials() {
        let cfg = parse_turn_server("turn:127.0.0.1:5349?username=alice&password=secret").unwrap();
        assert_eq!(cfg.addr, SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 5349));
        assert_eq!(cfg.username.as_deref(), Some("alice"));
        assert_eq!(cfg.credential.as_deref(), Some("secret"));
    }

    #[test]
    fn parse_turn_uri_bracketed_ipv6_with_port() {
        let cfg = parse_turn_server("turn:[::1]:3479").unwrap();
        assert!(cfg.addr.is_ipv6());
        assert_eq!(cfg.addr.port(), 3479);
    }

    #[tokio::test]
    async fn allocate_turn_no_auth() {
        // Bind the fake server before spawning so the request cannot race ahead of it, and
        // use an ephemeral port so the test never collides with another listener.
        let server = UdpSocket::bind((Ipv4Addr::LOCALHOST, 0)).await.unwrap();
        let server_addr = server.local_addr().unwrap();
        let relay_addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(203, 0, 113, 5)), 50000);

        let _server = tokio::spawn(async move {
            let mut buf = [0u8; 1500];
            let (_len, peer) = server.recv_from(&mut buf).await.unwrap();
            let tx_id: [u8; 12] = buf[8..20].try_into().unwrap();
            let mut response = build_stun_header(TURN_ALLOCATE_RESPONSE, &tx_id);
            add_attr_bytes(
                &mut response,
                ATTR_XOR_RELAYED_ADDRESS,
                &encode_xor_addr(relay_addr, &tx_id),
            );
            finalize_length(&mut response);
            let _ = server.send_to(&response, peer).await;
        });

        let cfg = TurnServerConfig {
            addr: server_addr,
            username: None,
            credential: None,
        };
        let cand = allocate_turn_relay(cfg, 1000).await.unwrap();
        assert_eq!(cand.relay_addr, relay_addr);
        assert_eq!(cand.server, server_addr);
    }
}