iroh_net/
stun.rs

1//! STUN packets sending and receiving.
2
3use std::net::SocketAddr;
4
5use stun_rs::{
6    attributes::stun::{Fingerprint, XorMappedAddress},
7    DecoderContextBuilder, MessageDecoderBuilder, MessageEncoderBuilder, StunMessageBuilder,
8};
9pub use stun_rs::{
10    attributes::StunAttribute, error::StunDecodeError, methods, MessageClass, MessageDecoder,
11    TransactionId,
12};
13
14/// Errors that can occur when handling a STUN packet.
15#[derive(Debug, thiserror::Error)]
16pub enum Error {
17    /// The STUN message could not be parsed or is otherwise invalid.
18    #[error("invalid message")]
19    InvalidMessage,
20    /// STUN request is not a binding request when it should be.
21    #[error("not binding")]
22    NotBinding,
23    /// STUN packet is not a response when it should be.
24    #[error("not success response")]
25    NotSuccessResponse,
26    /// STUN response has malformed attributes.
27    #[error("malformed attributes")]
28    MalformedAttrs,
29    /// STUN request didn't end in fingerprint.
30    #[error("no fingerprint")]
31    NoFingerprint,
32    /// STUN request had bogus fingerprint.
33    #[error("invalid fingerprint")]
34    InvalidFingerprint,
35}
36
37/// Generates a binding request STUN packet.
38pub fn request(tx: TransactionId) -> Vec<u8> {
39    let fp = Fingerprint::default();
40    let msg = StunMessageBuilder::new(methods::BINDING, MessageClass::Request)
41        .with_transaction_id(tx)
42        .with_attribute(fp)
43        .build();
44
45    let encoder = MessageEncoderBuilder::default().build();
46    let mut buffer = vec![0u8; 150];
47    let size = encoder.encode(&mut buffer, &msg).expect("invalid encoding");
48    buffer.truncate(size);
49    buffer
50}
51
52/// Generates a binding response.
53pub fn response(tx: TransactionId, addr: SocketAddr) -> Vec<u8> {
54    let msg = StunMessageBuilder::new(methods::BINDING, MessageClass::SuccessResponse)
55        .with_transaction_id(tx)
56        .with_attribute(XorMappedAddress::from(addr))
57        .build();
58
59    let encoder = MessageEncoderBuilder::default().build();
60    let mut buffer = vec![0u8; 150];
61    let size = encoder.encode(&mut buffer, &msg).expect("invalid encoding");
62    buffer.truncate(size);
63    buffer
64}
65
66// Copied from stun_rs
67// const MAGIC_COOKIE: Cookie = Cookie(0x2112_A442);
68const COOKIE: [u8; 4] = 0x2112_A442u32.to_be_bytes();
69
70/// Reports whether b is a STUN message.
71pub fn is(b: &[u8]) -> bool {
72    b.len() >= stun_rs::MESSAGE_HEADER_SIZE &&
73        b[0]&0b11000000 == 0 && // top two bits must be zero
74        b[4..8] == COOKIE
75}
76
77/// Parses a STUN binding request.
78pub fn parse_binding_request(b: &[u8]) -> Result<TransactionId, Error> {
79    let ctx = DecoderContextBuilder::default()
80        .with_validation() // ensure fingerprint is validated
81        .build();
82    let decoder = MessageDecoderBuilder::default().with_context(ctx).build();
83    let (msg, _) = decoder.decode(b).map_err(|_| Error::InvalidMessage)?;
84
85    let tx = *msg.transaction_id();
86    if msg.method() != methods::BINDING {
87        return Err(Error::NotBinding);
88    }
89
90    // TODO: Tailscale sets the software to tailscale, we should check if we want to do this too.
91
92    if msg
93        .attributes()
94        .last()
95        .map(|attr| !attr.is_fingerprint())
96        .unwrap_or_default()
97    {
98        return Err(Error::NoFingerprint);
99    }
100
101    Ok(tx)
102}
103
104/// Parses a successful binding response STUN packet.
105/// The IP address is extracted from the XOR-MAPPED-ADDRESS attribute.
106pub fn parse_response(b: &[u8]) -> Result<(TransactionId, SocketAddr), Error> {
107    let decoder = MessageDecoder::default();
108    let (msg, _) = decoder.decode(b).map_err(|_| Error::InvalidMessage)?;
109
110    let tx = *msg.transaction_id();
111    if msg.class() != MessageClass::SuccessResponse {
112        return Err(Error::NotSuccessResponse);
113    }
114
115    // Read through the attributes.
116    // The the addr+port reported by XOR-MAPPED-ADDRESS
117    // as the canonical value. If the attribute is not
118    // present but the STUN server responds with
119    // MAPPED-ADDRESS we fall back to it.
120
121    let mut addr = None;
122    let mut fallback_addr = None;
123    for attr in msg.attributes() {
124        match attr {
125            StunAttribute::XorMappedAddress(a) => {
126                let mut a = *a.socket_address();
127                a.set_ip(a.ip().to_canonical());
128                addr = Some(a);
129            }
130            StunAttribute::MappedAddress(a) => {
131                let mut a = *a.socket_address();
132                a.set_ip(a.ip().to_canonical());
133                fallback_addr = Some(a);
134            }
135            _ => {}
136        }
137    }
138
139    if let Some(addr) = addr {
140        return Ok((tx, addr));
141    }
142
143    if let Some(addr) = fallback_addr {
144        return Ok((tx, addr));
145    }
146
147    Err(Error::MalformedAttrs)
148}
149
#[cfg(test)]
pub(crate) mod tests {
    use std::{
        net::{IpAddr, Ipv4Addr, Ipv6Addr},
        sync::Arc,
    };

    use anyhow::Result;
    use tokio::{
        net,
        sync::{oneshot, Mutex},
    };
    use tracing::{debug, trace, warn};

    use super::*;
    use crate::{
        relay::{RelayMap, RelayNode, RelayUrl},
        test_utils::CleanupDropGuard,
    };

    // TODO: make all this private

    /// Binding requests answered by a test STUN server, counted per peer address family as
    /// `(from IPv4 peers, from IPv6 peers)`.
    #[derive(Debug, Default, Clone)]
    pub struct StunStats(Arc<Mutex<(usize, usize)>>);

    impl StunStats {
        /// Total number of binding requests answered, across both address families.
        pub async fn total(&self) -> usize {
            let s = self.0.lock().await;
            s.0 + s.1
        }
    }

    /// Builds a [`RelayMap`] of STUN-only relay nodes, one per address.
    pub fn relay_map_of(stun: impl Iterator<Item = SocketAddr>) -> RelayMap {
        relay_map_of_opts(stun.map(|addr| (addr, true)))
    }

    /// Builds a [`RelayMap`] with one node per `(addr, stun_only)` pair.
    ///
    /// Each node gets the URL `http://{addr}` and uses `addr`'s port as its STUN port.
    pub fn relay_map_of_opts(stun: impl Iterator<Item = (SocketAddr, bool)>) -> RelayMap {
        let nodes = stun.map(|(addr, stun_only)| {
            // Format the whole `SocketAddr`: its `Display` brackets IPv6 hosts
            // (`[::1]:3478`), as URLs require. `{ip}:{port}` would yield an unparseable
            // `http://::1:3478` for IPv6.
            let url: RelayUrl = format!("http://{addr}").parse().unwrap();
            RelayNode {
                url,
                stun_port: addr.port(),
                stun_only,
            }
        });
        RelayMap::from_nodes(nodes).expect("generated invalid nodes")
    }

    /// Sets up a simple STUN server binding to `0.0.0.0:0`.
    ///
    /// See [`serve`] for more details.
    pub(crate) async fn serve_v4() -> Result<(SocketAddr, StunStats, CleanupDropGuard)> {
        serve(Ipv4Addr::UNSPECIFIED.into()).await
    }

    /// Sets up a simple STUN server on an ephemeral UDP port of `ip`.
    ///
    /// When `ip` is the unspecified address (`0.0.0.0` or `::`), the returned address is the
    /// loopback address of the same family, because that is what local clients should send to.
    /// The server runs in a spawned task. It stops once the oneshot sender wrapped in the
    /// returned [`CleanupDropGuard`] is used or dropped.
    pub(crate) async fn serve(ip: IpAddr) -> Result<(SocketAddr, StunStats, CleanupDropGuard)> {
        let stats = StunStats::default();

        let pc = net::UdpSocket::bind((ip, 0)).await?;
        let mut addr = pc.local_addr()?;
        if addr.ip().is_unspecified() {
            // Handle both families; previously any IPv6 `ip` hit an `unreachable!` panic.
            let loopback: IpAddr = match addr.ip() {
                IpAddr::V4(_) => Ipv4Addr::LOCALHOST.into(),
                IpAddr::V6(_) => Ipv6Addr::LOCALHOST.into(),
            };
            addr.set_ip(loopback);
        }

        debug!("STUN listening on {}", addr);
        let (s, r) = oneshot::channel();
        let stats_c = stats.clone();
        tokio::task::spawn(async move {
            run_stun(pc, stats_c, r).await;
        });

        Ok((addr, stats, CleanupDropGuard(s)))
    }

    /// Answers STUN binding requests on `pc` until `done` resolves.
    ///
    /// Packets that are not STUN, or not valid binding requests, are ignored. Each answered
    /// request is counted in `stats` by the sender's address family.
    async fn run_stun(pc: net::UdpSocket, stats: StunStats, mut done: oneshot::Receiver<()>) {
        let mut buf = vec![0u8; 64 << 10];
        loop {
            trace!("read loop");
            tokio::select! {
                _ = &mut done => {
                    debug!("shutting down");
                    break;
                }
                res = pc.recv_from(&mut buf) => match res {
                    Ok((n, addr)) => {
                        trace!("read packet {}bytes from {}", n, addr);
                        let pkt = &buf[..n];
                        if !is(pkt) {
                            debug!("received non STUN pkt");
                            continue;
                        }
                        if let Ok(txid) = parse_binding_request(pkt) {
                            debug!("received binding request");
                            let mut s = stats.0.lock().await;
                            if addr.is_ipv4() {
                                s.0 += 1;
                            } else {
                                s.1 += 1;
                            }
                            // Release the lock before awaiting the send.
                            drop(s);

                            let res = response(txid, addr);
                            if let Err(err) = pc.send_to(&res, addr).await {
                                warn!("STUN server write failed: {:?}", err);
                            }
                        }
                    }
                    Err(err) => {
                        warn!("failed to read: {:?}", err);
                    }
                }
            }
        }
    }

    // Manual test against a live STUN server (needs network access and hickory_resolver).
    // #[tokio::test]
    // async fn test_stun_server() {
    //     use tokio::net::UdpSocket;
    //     use std::sync::Arc;
    //     use hickory_resolver::TokioAsyncResolver;
    //
    //     let domain = "cert-test.iroh.computer";
    //     let port = 3478;
    //
    //     let txid = TransactionId::default();
    //     let req = request(txid);
    //     let socket = Arc::new(UdpSocket::bind("0.0.0.0:0").await.unwrap());
    //
    //     let resolver = TokioAsyncResolver::tokio_from_system_conf().unwrap();
    //     let response = resolver.lookup_ip(domain).await.unwrap();
    //     dbg!(&response);
    //
    //     let server_socket = socket.clone();
    //     let server_task = tokio::task::spawn(async move {
    //         let mut buf = vec![0u8; 64000];
    //         let len = server_socket.recv(&mut buf).await.unwrap();
    //         dbg!(len);
    //         buf.truncate(len);
    //         buf
    //     });
    //
    //     for addr in response {
    //         let addr = SocketAddr::new(addr, port);
    //         println!("sending to {addr}");
    //         socket.send_to(&req, addr).await.unwrap();
    //     }
    //
    //     let response = server_task.await.unwrap();
    //     let (txid_back, response_addr) = parse_response(&response).unwrap();
    //     assert_eq!(txid, txid_back);
    //     println!("got {response_addr}");
    // }

    /// A captured STUN response and the values `parse_response` must extract from it.
    struct ResponseTestCase {
        name: &'static str,
        data: Vec<u8>,
        want_tid: Vec<u8>,
        want_addr: IpAddr,
        want_port: u16,
    }

    #[test]
    fn test_parse_response() {
        let cases = vec![
            ResponseTestCase {
                name: "google-1",
                data: vec![
                    0x01, 0x01, 0x00, 0x0c, 0x21, 0x12, 0xa4, 0x42,
                    0x23, 0x60, 0xb1, 0x1e, 0x3e, 0xc6, 0x8f, 0xfa,
                    0x93, 0xe0, 0x80, 0x07, 0x00, 0x20, 0x00, 0x08,
                    0x00, 0x01, 0xc7, 0x86, 0x69, 0x57, 0x85, 0x6f,
                ],
                want_tid: vec![
                    0x23, 0x60, 0xb1, 0x1e, 0x3e, 0xc6, 0x8f, 0xfa,
                    0x93, 0xe0, 0x80, 0x07,
                ],
                want_addr: IpAddr::V4(Ipv4Addr::from([72, 69, 33, 45])),
                want_port: 59028,
            },
            ResponseTestCase {
                name: "google-2",
                data: vec![
                    0x01, 0x01, 0x00, 0x0c, 0x21, 0x12, 0xa4, 0x42,
                    0xf9, 0xf1, 0x21, 0xcb, 0xde, 0x7d, 0x7c, 0x75,
                    0x92, 0x3c, 0xe2, 0x71, 0x00, 0x20, 0x00, 0x08,
                    0x00, 0x01, 0xc7, 0x87, 0x69, 0x57, 0x85, 0x6f,
                ],
                want_tid: vec![
                    0xf9, 0xf1, 0x21, 0xcb, 0xde, 0x7d, 0x7c, 0x75,
                    0x92, 0x3c, 0xe2, 0x71,
                ],
                want_addr: IpAddr::V4(Ipv4Addr::from([72, 69, 33, 45])),
                want_port: 59029,
            },
            ResponseTestCase {
                name: "stun.sipgate.net:10000",
                data: vec![
                    0x01, 0x01, 0x00, 0x44, 0x21, 0x12, 0xa4, 0x42,
                    0x48, 0x2e, 0xb6, 0x47, 0x15, 0xe8, 0xb2, 0x8e,
                    0xae, 0xad, 0x64, 0x44, 0x00, 0x01, 0x00, 0x08,
                    0x00, 0x01, 0xe4, 0xab, 0x48, 0x45, 0x21, 0x2d,
                    0x00, 0x04, 0x00, 0x08, 0x00, 0x01, 0x27, 0x10,
                    0xd9, 0x0a, 0x44, 0x98, 0x00, 0x05, 0x00, 0x08,
                    0x00, 0x01, 0x27, 0x11, 0xd9, 0x74, 0x7a, 0x8a,
                    0x80, 0x20, 0x00, 0x08, 0x00, 0x01, 0xc5, 0xb9,
                    0x69, 0x57, 0x85, 0x6f, 0x80, 0x22, 0x00, 0x10,
                    0x56, 0x6f, 0x76, 0x69, 0x64, 0x61, 0x2e, 0x6f,
                    0x72, 0x67, 0x20, 0x30, 0x2e, 0x39, 0x36, 0x00,
                ],
                want_tid: vec![
                    0x48, 0x2e, 0xb6, 0x47, 0x15, 0xe8, 0xb2, 0x8e,
                    0xae, 0xad, 0x64, 0x44,
                ],
                want_addr: IpAddr::V4(Ipv4Addr::from([72, 69, 33, 45])),
                want_port: 58539,
            },
            ResponseTestCase {
                name: "stun.powervoip.com:3478",
                data: vec![
                    0x01, 0x01, 0x00, 0x24, 0x21, 0x12, 0xa4, 0x42,
                    0x7e, 0x57, 0x96, 0x68, 0x29, 0xf4, 0x44, 0x60,
                    0x9d, 0x1d, 0xea, 0xa6, 0x00, 0x01, 0x00, 0x08,
                    0x00, 0x01, 0xe9, 0xd3, 0x48, 0x45, 0x21, 0x2d,
                    0x00, 0x04, 0x00, 0x08, 0x00, 0x01, 0x0d, 0x96,
                    0x4d, 0x48, 0xa9, 0xd4, 0x00, 0x05, 0x00, 0x08,
                    0x00, 0x01, 0x0d, 0x97, 0x4d, 0x48, 0xa9, 0xd5,
                ],
                want_tid: vec![
                    0x7e, 0x57, 0x96, 0x68, 0x29, 0xf4, 0x44, 0x60,
                    0x9d, 0x1d, 0xea, 0xa6,
                ],
                want_addr: IpAddr::V4(Ipv4Addr::from([72, 69, 33, 45])),
                want_port: 59859,
            },
            ResponseTestCase {
                name: "in-process pion server",
                data: vec![
                    0x01, 0x01, 0x00, 0x24, 0x21, 0x12, 0xa4, 0x42,
                    0xeb, 0xc2, 0xd3, 0x6e, 0xf4, 0x71, 0x21, 0x7c,
                    0x4f, 0x3e, 0x30, 0x8e, 0x80, 0x22, 0x00, 0x0a,
                    0x65, 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74,
                    0x65, 0x72, 0x00, 0x00, 0x00, 0x20, 0x00, 0x08,
                    0x00, 0x01, 0xce, 0x66, 0x5e, 0x12, 0xa4, 0x43,
                    0x80, 0x28, 0x00, 0x04, 0xb6, 0x99, 0xbb, 0x02,
                    0x01, 0x01, 0x00, 0x24, 0x21, 0x12, 0xa4, 0x42,
                ],
                want_tid: vec![
                    0xeb, 0xc2, 0xd3, 0x6e, 0xf4, 0x71, 0x21, 0x7c,
                    0x4f, 0x3e, 0x30, 0x8e,
                ],
                want_addr: IpAddr::V4(Ipv4Addr::from([127, 0, 0, 1])),
                want_port: 61300,
            },
            ResponseTestCase {
                name: "stuntman-server ipv6",
                data: vec![
                    0x01, 0x01, 0x00, 0x48, 0x21, 0x12, 0xa4, 0x42,
                    0x06, 0xf5, 0x66, 0x85, 0xd2, 0x8a, 0xf3, 0xe6,
                    0x9c, 0xe3, 0x41, 0xe2, 0x00, 0x01, 0x00, 0x14,
                    0x00, 0x02, 0x90, 0xce, 0x26, 0x02, 0x00, 0xd1,
                    0xb4, 0xcf, 0xc1, 0x00, 0x38, 0xb2, 0x31, 0xff,
                    0xfe, 0xef, 0x96, 0xf6, 0x80, 0x2b, 0x00, 0x14,
                    0x00, 0x02, 0x0d, 0x96, 0x26, 0x04, 0xa8, 0x80,
                    0x00, 0x02, 0x00, 0xd1, 0x00, 0x00, 0x00, 0x00,
                    0x00, 0xc5, 0x70, 0x01, 0x00, 0x20, 0x00, 0x14,
                    0x00, 0x02, 0xb1, 0xdc, 0x07, 0x10, 0xa4, 0x93,
                    0xb2, 0x3a, 0xa7, 0x85, 0xea, 0x38, 0xc2, 0x19,
                    0x62, 0x0c, 0xd7, 0x14,
                ],
                want_tid: vec![
                    6, 245, 102, 133, 210, 138, 243, 230, 156, 227,
                    65, 226,
                ],
                want_addr: "2602:d1:b4cf:c100:38b2:31ff:feef:96f6".parse().unwrap(),
                want_port: 37070,
            },
            // Testing STUN attribute padding rules using STUN software attribute
            // with values of 1 & 3 length respectively before the XorMappedAddress attribute
            ResponseTestCase {
                name: "software-a",
                data: vec![
                    0x01, 0x01, 0x00, 0x14, 0x21, 0x12, 0xa4, 0x42,
                    0xeb, 0xc2, 0xd3, 0x6e, 0xf4, 0x71, 0x21, 0x7c,
                    0x4f, 0x3e, 0x30, 0x8e, 0x80, 0x22, 0x00, 0x01,
                    0x61, 0x00, 0x00, 0x00, 0x00, 0x20, 0x00, 0x08,
                    0x00, 0x01, 0xce, 0x66, 0x5e, 0x12, 0xa4, 0x43,
                ],
                want_tid: vec![
                    0xeb, 0xc2, 0xd3, 0x6e, 0xf4, 0x71, 0x21, 0x7c,
                    0x4f, 0x3e, 0x30, 0x8e,
                ],
                want_addr: IpAddr::V4(Ipv4Addr::from([127, 0, 0, 1])),
                want_port: 61300,
            },
            ResponseTestCase {
                name: "software-abc",
                data: vec![
                    0x01, 0x01, 0x00, 0x14, 0x21, 0x12, 0xa4, 0x42,
                    0xeb, 0xc2, 0xd3, 0x6e, 0xf4, 0x71, 0x21, 0x7c,
                    0x4f, 0x3e, 0x30, 0x8e, 0x80, 0x22, 0x00, 0x03,
                    0x61, 0x62, 0x63, 0x00, 0x00, 0x20, 0x00, 0x08,
                    0x00, 0x01, 0xce, 0x66, 0x5e, 0x12, 0xa4, 0x43,
                ],
                want_tid: vec![
                    0xeb, 0xc2, 0xd3, 0x6e, 0xf4, 0x71, 0x21, 0x7c,
                    0x4f, 0x3e, 0x30, 0x8e,
                ],
                want_addr: IpAddr::V4(Ipv4Addr::from([127, 0, 0, 1])),
                want_port: 61300,
            },
            ResponseTestCase {
                name: "no-4in6",
                data: hex::decode("010100182112a4424fd5d202dcb37d31fc773306002000140002cd3d2112a4424fd5d202dcb382ce2dc3fcc7").unwrap(),
                want_tid: vec![79, 213, 210, 2, 220, 179, 125, 49, 252, 119, 51, 6],
                want_addr: IpAddr::V4(Ipv4Addr::from([209, 180, 207, 193])),
                want_port: 60463,
            },
        ];

        for (i, test) in cases.into_iter().enumerate() {
            println!("Case {i}: {}", test.name);
            let (tx, addr_port) = parse_response(&test.data).unwrap();
            assert!(is(&test.data));
            assert_eq!(tx.as_bytes(), &test.want_tid[..]);
            assert_eq!(addr_port.ip(), test.want_addr);
            assert_eq!(addr_port.port(), test.want_port);
        }
    }

    #[test]
    fn test_parse_binding_request() {
        let tx = TransactionId::default();
        let req = request(tx);
        assert!(is(&req));
        let got_tx = parse_binding_request(&req).unwrap();
        assert_eq!(got_tx, tx);
    }

    #[test]
    fn test_stun_cookie() {
        assert_eq!(stun_rs::MAGIC_COOKIE, COOKIE);
    }

    #[test]
    fn test_response() {
        let txn = |n| TransactionId::from([n; 12]);

        struct Case {
            tx: TransactionId,
            addr: IpAddr,
            port: u16,
        }
        let tests = vec![
            Case {
                tx: txn(1),
                addr: "1.2.3.4".parse().unwrap(),
                port: 254,
            },
            Case {
                tx: txn(2),
                addr: "1.2.3.4".parse().unwrap(),
                port: 257,
            },
            Case {
                tx: txn(3),
                addr: "1::4".parse().unwrap(),
                port: 254,
            },
            Case {
                tx: txn(4),
                addr: "1::4".parse().unwrap(),
                port: 257,
            },
        ];

        for tt in tests {
            let res = response(tt.tx, SocketAddr::new(tt.addr, tt.port));
            assert!(is(&res));
            let (tx2, addr2) = parse_response(&res).unwrap();
            assert_eq!(tt.tx, tx2);
            assert_eq!(tt.addr, addr2.ip());
            assert_eq!(tt.port, addr2.port());
        }
    }
}