1use std::net::SocketAddr;
4
5use stun_rs::{
6 attributes::stun::{Fingerprint, XorMappedAddress},
7 DecoderContextBuilder, MessageDecoderBuilder, MessageEncoderBuilder, StunMessageBuilder,
8};
9pub use stun_rs::{
10 attributes::StunAttribute, error::StunDecodeError, methods, MessageClass, MessageDecoder,
11 TransactionId,
12};
13
14#[derive(Debug, thiserror::Error)]
16pub enum Error {
17 #[error("invalid message")]
19 InvalidMessage,
20 #[error("not binding")]
22 NotBinding,
23 #[error("not success response")]
25 NotSuccessResponse,
26 #[error("malformed attributes")]
28 MalformedAttrs,
29 #[error("no fingerprint")]
31 NoFingerprint,
32 #[error("invalid fingerprint")]
34 InvalidFingerprint,
35}
36
37pub fn request(tx: TransactionId) -> Vec<u8> {
39 let fp = Fingerprint::default();
40 let msg = StunMessageBuilder::new(methods::BINDING, MessageClass::Request)
41 .with_transaction_id(tx)
42 .with_attribute(fp)
43 .build();
44
45 let encoder = MessageEncoderBuilder::default().build();
46 let mut buffer = vec![0u8; 150];
47 let size = encoder.encode(&mut buffer, &msg).expect("invalid encoding");
48 buffer.truncate(size);
49 buffer
50}
51
52pub fn response(tx: TransactionId, addr: SocketAddr) -> Vec<u8> {
54 let msg = StunMessageBuilder::new(methods::BINDING, MessageClass::SuccessResponse)
55 .with_transaction_id(tx)
56 .with_attribute(XorMappedAddress::from(addr))
57 .build();
58
59 let encoder = MessageEncoderBuilder::default().build();
60 let mut buffer = vec![0u8; 150];
61 let size = encoder.encode(&mut buffer, &msg).expect("invalid encoding");
62 buffer.truncate(size);
63 buffer
64}
65
/// The STUN magic cookie `0x2112A442`, in network (big-endian) byte order.
///
/// [`is`] expects these four bytes at offset `4..8` of a message header.
const COOKIE: [u8; 4] = [0x21, 0x12, 0xA4, 0x42];
69
70pub fn is(b: &[u8]) -> bool {
72 b.len() >= stun_rs::MESSAGE_HEADER_SIZE &&
73 b[0]&0b11000000 == 0 && b[4..8] == COOKIE
75}
76
77pub fn parse_binding_request(b: &[u8]) -> Result<TransactionId, Error> {
79 let ctx = DecoderContextBuilder::default()
80 .with_validation() .build();
82 let decoder = MessageDecoderBuilder::default().with_context(ctx).build();
83 let (msg, _) = decoder.decode(b).map_err(|_| Error::InvalidMessage)?;
84
85 let tx = *msg.transaction_id();
86 if msg.method() != methods::BINDING {
87 return Err(Error::NotBinding);
88 }
89
90 if msg
93 .attributes()
94 .last()
95 .map(|attr| !attr.is_fingerprint())
96 .unwrap_or_default()
97 {
98 return Err(Error::NoFingerprint);
99 }
100
101 Ok(tx)
102}
103
104pub fn parse_response(b: &[u8]) -> Result<(TransactionId, SocketAddr), Error> {
107 let decoder = MessageDecoder::default();
108 let (msg, _) = decoder.decode(b).map_err(|_| Error::InvalidMessage)?;
109
110 let tx = *msg.transaction_id();
111 if msg.class() != MessageClass::SuccessResponse {
112 return Err(Error::NotSuccessResponse);
113 }
114
115 let mut addr = None;
122 let mut fallback_addr = None;
123 for attr in msg.attributes() {
124 match attr {
125 StunAttribute::XorMappedAddress(a) => {
126 let mut a = *a.socket_address();
127 a.set_ip(a.ip().to_canonical());
128 addr = Some(a);
129 }
130 StunAttribute::MappedAddress(a) => {
131 let mut a = *a.socket_address();
132 a.set_ip(a.ip().to_canonical());
133 fallback_addr = Some(a);
134 }
135 _ => {}
136 }
137 }
138
139 if let Some(addr) = addr {
140 return Ok((tx, addr));
141 }
142
143 if let Some(addr) = fallback_addr {
144 return Ok((tx, addr));
145 }
146
147 Err(Error::MalformedAttrs)
148}
149
#[cfg(test)]
pub(crate) mod tests {
    use std::{
        net::{IpAddr, Ipv4Addr, Ipv6Addr},
        sync::Arc,
    };

    use anyhow::Result;
    use tokio::{
        net,
        sync::{oneshot, Mutex},
    };
    use tracing::{debug, trace};

    use super::*;
    use crate::{
        relay::{RelayMap, RelayNode, RelayUrl},
        test_utils::CleanupDropGuard,
    };

    /// Counters of Binding requests answered by the test STUN server.
    ///
    /// The tuple is `(ipv4_requests, ipv6_requests)`, split by the sender's
    /// address family.
    #[derive(Debug, Default, Clone)]
    pub struct StunStats(Arc<Mutex<(usize, usize)>>);

    impl StunStats {
        /// Total number of Binding requests answered, across both families.
        pub async fn total(&self) -> usize {
            let s = self.0.lock().await;
            s.0 + s.1
        }
    }

    /// Builds a [`RelayMap`] with one STUN-only node per address in `stun`.
    pub fn relay_map_of(stun: impl Iterator<Item = SocketAddr>) -> RelayMap {
        relay_map_of_opts(stun.map(|addr| (addr, true)))
    }

    /// Builds a [`RelayMap`] with one node per `(address, stun_only)` pair.
    ///
    /// Each node's URL is `http://<addr>` and its STUN port is the address's port.
    ///
    /// # Panics
    ///
    /// Panics if a generated URL does not parse or the node set is rejected.
    pub fn relay_map_of_opts(stun: impl Iterator<Item = (SocketAddr, bool)>) -> RelayMap {
        let nodes = stun.map(|(addr, stun_only)| {
            let port = addr.port();
            // Format through `SocketAddr`'s Display so IPv6 hosts are bracketed
            // (`[::1]:port`). A hand-built `ip:port` is not a parsable URL for IPv6.
            // Rebuilding the address also drops any IPv6 scope id.
            let host_port = SocketAddr::new(addr.ip(), port);
            let url: RelayUrl = format!("http://{host_port}").parse().unwrap();
            RelayNode {
                url,
                stun_port: port,
                stun_only,
            }
        });
        RelayMap::from_nodes(nodes).expect("generated invalid nodes")
    }

    /// Starts a test STUN server on an ephemeral port of the IPv4 wildcard address.
    pub(crate) async fn serve_v4() -> Result<(SocketAddr, StunStats, CleanupDropGuard)> {
        serve(Ipv4Addr::UNSPECIFIED.into()).await
    }

    /// Starts a test STUN server bound to `ip` on an ephemeral port.
    ///
    /// Returns three things:
    /// - the address clients should send to (a wildcard bind is reported as the
    ///   loopback address of the same family);
    /// - the server's request counters;
    /// - a guard owning the shutdown sender. The server loop exits once that
    ///   sender fires or is dropped.
    pub(crate) async fn serve(ip: IpAddr) -> Result<(SocketAddr, StunStats, CleanupDropGuard)> {
        let stats = StunStats::default();

        let pc = net::UdpSocket::bind((ip, 0)).await?;
        let mut addr = pc.local_addr()?;
        // A wildcard address is not something a client can send to; advertise
        // the loopback address of the same family instead. (Previously IPv6
        // binds hit `unreachable!` even though the signature accepts them.)
        if addr.ip().is_unspecified() {
            let loopback: IpAddr = match addr.ip() {
                IpAddr::V4(_) => Ipv4Addr::LOCALHOST.into(),
                IpAddr::V6(_) => Ipv6Addr::LOCALHOST.into(),
            };
            addr.set_ip(loopback);
        }

        println!("STUN listening on {}", addr);
        let (s, r) = oneshot::channel();
        let stats_c = stats.clone();
        tokio::task::spawn(async move {
            run_stun(pc, stats_c, r).await;
        });

        Ok((addr, stats, CleanupDropGuard(s)))
    }

    /// Server loop: answers every STUN Binding request received on `pc`.
    ///
    /// Non-STUN packets and packets that fail `parse_binding_request` are
    /// ignored. The loop ends when `done` resolves, either by a send or by the
    /// sender being dropped.
    async fn run_stun(pc: net::UdpSocket, stats: StunStats, mut done: oneshot::Receiver<()>) {
        let mut buf = vec![0u8; 64 << 10];
        loop {
            trace!("read loop");
            tokio::select! {
                _ = &mut done => {
                    debug!("shutting down");
                    break;
                }
                res = pc.recv_from(&mut buf) => match res {
                    Ok((n, addr)) => {
                        trace!("read packet {}bytes from {}", n, addr);
                        let pkt = &buf[..n];
                        if !is(pkt) {
                            debug!("received non STUN pkt");
                            continue;
                        }
                        if let Ok(txid) = parse_binding_request(pkt) {
                            debug!("received binding request");
                            let mut s = stats.0.lock().await;
                            if addr.is_ipv4() {
                                s.0 += 1;
                            } else {
                                s.1 += 1;
                            }
                            // Release the stats lock before awaiting the send.
                            drop(s);

                            let res = response(txid, addr);
                            if let Err(err) = pc.send_to(&res, addr).await {
                                eprintln!("STUN server write failed: {:?}", err);
                            }
                        }
                    }
                    Err(err) => {
                        eprintln!("failed to read: {:?}", err);
                    }
                }
            }
        }
    }

    /// A recorded STUN response and the values `parse_response` must extract from it.
    struct ResponseTestCase {
        name: &'static str,
        data: Vec<u8>,
        want_tid: Vec<u8>,
        want_addr: IpAddr,
        want_port: u16,
    }

    #[test]
    fn test_parse_response() {
        let cases = vec![
            ResponseTestCase {
                name: "google-1",
                data: vec![
                    0x01, 0x01, 0x00, 0x0c, 0x21, 0x12, 0xa4, 0x42,
                    0x23, 0x60, 0xb1, 0x1e, 0x3e, 0xc6, 0x8f, 0xfa,
                    0x93, 0xe0, 0x80, 0x07, 0x00, 0x20, 0x00, 0x08,
                    0x00, 0x01, 0xc7, 0x86, 0x69, 0x57, 0x85, 0x6f,
                ],
                want_tid: vec![
                    0x23, 0x60, 0xb1, 0x1e, 0x3e, 0xc6, 0x8f, 0xfa,
                    0x93, 0xe0, 0x80, 0x07,
                ],
                want_addr: IpAddr::V4(Ipv4Addr::from([72, 69, 33, 45])),
                want_port: 59028,
            },
            ResponseTestCase {
                name: "google-2",
                data: vec![
                    0x01, 0x01, 0x00, 0x0c, 0x21, 0x12, 0xa4, 0x42,
                    0xf9, 0xf1, 0x21, 0xcb, 0xde, 0x7d, 0x7c, 0x75,
                    0x92, 0x3c, 0xe2, 0x71, 0x00, 0x20, 0x00, 0x08,
                    0x00, 0x01, 0xc7, 0x87, 0x69, 0x57, 0x85, 0x6f,
                ],
                want_tid: vec![
                    0xf9, 0xf1, 0x21, 0xcb, 0xde, 0x7d, 0x7c, 0x75,
                    0x92, 0x3c, 0xe2, 0x71,
                ],
                want_addr: IpAddr::V4(Ipv4Addr::from([72, 69, 33, 45])),
                want_port: 59029,
            },
            ResponseTestCase {
                name: "stun.sipgate.net:10000",
                data: vec![
                    0x01, 0x01, 0x00, 0x44, 0x21, 0x12, 0xa4, 0x42,
                    0x48, 0x2e, 0xb6, 0x47, 0x15, 0xe8, 0xb2, 0x8e,
                    0xae, 0xad, 0x64, 0x44, 0x00, 0x01, 0x00, 0x08,
                    0x00, 0x01, 0xe4, 0xab, 0x48, 0x45, 0x21, 0x2d,
                    0x00, 0x04, 0x00, 0x08, 0x00, 0x01, 0x27, 0x10,
                    0xd9, 0x0a, 0x44, 0x98, 0x00, 0x05, 0x00, 0x08,
                    0x00, 0x01, 0x27, 0x11, 0xd9, 0x74, 0x7a, 0x8a,
                    0x80, 0x20, 0x00, 0x08, 0x00, 0x01, 0xc5, 0xb9,
                    0x69, 0x57, 0x85, 0x6f, 0x80, 0x22, 0x00, 0x10,
                    0x56, 0x6f, 0x76, 0x69, 0x64, 0x61, 0x2e, 0x6f,
                    0x72, 0x67, 0x20, 0x30, 0x2e, 0x39, 0x36, 0x00,
                ],
                want_tid: vec![
                    0x48, 0x2e, 0xb6, 0x47, 0x15, 0xe8, 0xb2, 0x8e,
                    0xae, 0xad, 0x64, 0x44,
                ],
                want_addr: IpAddr::V4(Ipv4Addr::from([72, 69, 33, 45])),
                want_port: 58539,
            },
            ResponseTestCase {
                name: "stun.powervoip.com:3478",
                data: vec![
                    0x01, 0x01, 0x00, 0x24, 0x21, 0x12, 0xa4, 0x42,
                    0x7e, 0x57, 0x96, 0x68, 0x29, 0xf4, 0x44, 0x60,
                    0x9d, 0x1d, 0xea, 0xa6, 0x00, 0x01, 0x00, 0x08,
                    0x00, 0x01, 0xe9, 0xd3, 0x48, 0x45, 0x21, 0x2d,
                    0x00, 0x04, 0x00, 0x08, 0x00, 0x01, 0x0d, 0x96,
                    0x4d, 0x48, 0xa9, 0xd4, 0x00, 0x05, 0x00, 0x08,
                    0x00, 0x01, 0x0d, 0x97, 0x4d, 0x48, 0xa9, 0xd5,
                ],
                want_tid: vec![
                    0x7e, 0x57, 0x96, 0x68, 0x29, 0xf4, 0x44, 0x60,
                    0x9d, 0x1d, 0xea, 0xa6,
                ],
                want_addr: IpAddr::V4(Ipv4Addr::from([72, 69, 33, 45])),
                want_port: 59859,
            },
            ResponseTestCase {
                name: "in-process pion server",
                data: vec![
                    0x01, 0x01, 0x00, 0x24, 0x21, 0x12, 0xa4, 0x42,
                    0xeb, 0xc2, 0xd3, 0x6e, 0xf4, 0x71, 0x21, 0x7c,
                    0x4f, 0x3e, 0x30, 0x8e, 0x80, 0x22, 0x00, 0x0a,
                    0x65, 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74,
                    0x65, 0x72, 0x00, 0x00, 0x00, 0x20, 0x00, 0x08,
                    0x00, 0x01, 0xce, 0x66, 0x5e, 0x12, 0xa4, 0x43,
                    0x80, 0x28, 0x00, 0x04, 0xb6, 0x99, 0xbb, 0x02,
                    0x01, 0x01, 0x00, 0x24, 0x21, 0x12, 0xa4, 0x42,
                ],
                want_tid: vec![
                    0xeb, 0xc2, 0xd3, 0x6e, 0xf4, 0x71, 0x21, 0x7c,
                    0x4f, 0x3e, 0x30, 0x8e,
                ],
                want_addr: IpAddr::V4(Ipv4Addr::from([127, 0, 0, 1])),
                want_port: 61300,
            },
            ResponseTestCase {
                name: "stuntman-server ipv6",
                data: vec![
                    0x01, 0x01, 0x00, 0x48, 0x21, 0x12, 0xa4, 0x42,
                    0x06, 0xf5, 0x66, 0x85, 0xd2, 0x8a, 0xf3, 0xe6,
                    0x9c, 0xe3, 0x41, 0xe2, 0x00, 0x01, 0x00, 0x14,
                    0x00, 0x02, 0x90, 0xce, 0x26, 0x02, 0x00, 0xd1,
                    0xb4, 0xcf, 0xc1, 0x00, 0x38, 0xb2, 0x31, 0xff,
                    0xfe, 0xef, 0x96, 0xf6, 0x80, 0x2b, 0x00, 0x14,
                    0x00, 0x02, 0x0d, 0x96, 0x26, 0x04, 0xa8, 0x80,
                    0x00, 0x02, 0x00, 0xd1, 0x00, 0x00, 0x00, 0x00,
                    0x00, 0xc5, 0x70, 0x01, 0x00, 0x20, 0x00, 0x14,
                    0x00, 0x02, 0xb1, 0xdc, 0x07, 0x10, 0xa4, 0x93,
                    0xb2, 0x3a, 0xa7, 0x85, 0xea, 0x38, 0xc2, 0x19,
                    0x62, 0x0c, 0xd7, 0x14,
                ],
                want_tid: vec![
                    6, 245, 102, 133, 210, 138, 243, 230, 156, 227,
                    65, 226,
                ],
                want_addr: "2602:d1:b4cf:c100:38b2:31ff:feef:96f6".parse().unwrap(),
                want_port: 37070,
            },
            ResponseTestCase {
                name: "software-a",
                data: vec![
                    0x01, 0x01, 0x00, 0x14, 0x21, 0x12, 0xa4, 0x42,
                    0xeb, 0xc2, 0xd3, 0x6e, 0xf4, 0x71, 0x21, 0x7c,
                    0x4f, 0x3e, 0x30, 0x8e, 0x80, 0x22, 0x00, 0x01,
                    0x61, 0x00, 0x00, 0x00, 0x00, 0x20, 0x00, 0x08,
                    0x00, 0x01, 0xce, 0x66, 0x5e, 0x12, 0xa4, 0x43,
                ],
                want_tid: vec![
                    0xeb, 0xc2, 0xd3, 0x6e, 0xf4, 0x71, 0x21, 0x7c,
                    0x4f, 0x3e, 0x30, 0x8e,
                ],
                want_addr: IpAddr::V4(Ipv4Addr::from([127, 0, 0, 1])),
                want_port: 61300,
            },
            ResponseTestCase {
                name: "software-abc",
                data: vec![
                    0x01, 0x01, 0x00, 0x14, 0x21, 0x12, 0xa4, 0x42,
                    0xeb, 0xc2, 0xd3, 0x6e, 0xf4, 0x71, 0x21, 0x7c,
                    0x4f, 0x3e, 0x30, 0x8e, 0x80, 0x22, 0x00, 0x03,
                    0x61, 0x62, 0x63, 0x00, 0x00, 0x20, 0x00, 0x08,
                    0x00, 0x01, 0xce, 0x66, 0x5e, 0x12, 0xa4, 0x43,
                ],
                want_tid: vec![
                    0xeb, 0xc2, 0xd3, 0x6e, 0xf4, 0x71, 0x21, 0x7c,
                    0x4f, 0x3e, 0x30, 0x8e,
                ],
                want_addr: IpAddr::V4(Ipv4Addr::from([127, 0, 0, 1])),
                want_port: 61300,
            },
            ResponseTestCase {
                name: "no-4in6",
                data: hex::decode("010100182112a4424fd5d202dcb37d31fc773306002000140002cd3d2112a4424fd5d202dcb382ce2dc3fcc7").unwrap(),
                want_tid: vec![79, 213, 210, 2, 220, 179, 125, 49, 252, 119, 51, 6],
                want_addr: IpAddr::V4(Ipv4Addr::from([209, 180, 207, 193])),
                want_port: 60463,
            },
        ];

        for (i, test) in cases.into_iter().enumerate() {
            println!("Case {i}: {}", test.name);
            let (tx, addr_port) = parse_response(&test.data).unwrap();
            assert!(is(&test.data));
            assert_eq!(tx.as_bytes(), &test.want_tid[..]);
            assert_eq!(addr_port.ip(), test.want_addr);
            assert_eq!(addr_port.port(), test.want_port);
        }
    }

    #[test]
    fn test_parse_binding_request() {
        let tx = TransactionId::default();
        let req = request(tx);
        assert!(is(&req));
        let got_tx = parse_binding_request(&req).unwrap();
        assert_eq!(got_tx, tx);
    }

    #[test]
    fn test_stun_cookie() {
        assert_eq!(stun_rs::MAGIC_COOKIE, COOKIE);
    }

    #[test]
    fn test_response() {
        let txn = |n| TransactionId::from([n; 12]);

        struct Case {
            tx: TransactionId,
            addr: IpAddr,
            port: u16,
        }
        let tests = vec![
            Case {
                tx: txn(1),
                addr: "1.2.3.4".parse().unwrap(),
                port: 254,
            },
            Case {
                tx: txn(2),
                addr: "1.2.3.4".parse().unwrap(),
                port: 257,
            },
            Case {
                tx: txn(3),
                addr: "1::4".parse().unwrap(),
                port: 254,
            },
            Case {
                tx: txn(4),
                addr: "1::4".parse().unwrap(),
                port: 257,
            },
        ];

        for tt in tests {
            let res = response(tt.tx, SocketAddr::new(tt.addr, tt.port));
            assert!(is(&res));
            let (tx2, addr2) = parse_response(&res).unwrap();
            assert_eq!(tt.tx, tx2);
            assert_eq!(tt.addr, addr2.ip());
            assert_eq!(tt.port, addr2.port());
        }
    }
}