narrowlink_network/
p2p.rs

1use core::fmt::Display;
2use std::{
3    net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr},
4    sync::Arc,
5    // sync::{atomic::AtomicU32, Arc},
6    time::Duration,
7};
8
9use async_recursion::async_recursion;
10use narrowlink_types::{
11    generic::{Connect, CryptographicAlgorithm, Protocol, SigningAlgorithm},
12    NatType, Peer2PeerInstruction,
13};
14use quinn::{ClientConfig, Connection, Endpoint, EndpointConfig, RecvStream, SendStream};
15use tokio::{
16    io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt},
17    net::UdpSocket,
18};
19use tracing::{debug, field::debug, info, warn};
20
21use crate::error::NetworkError;
22#[derive(PartialEq)]
23pub enum Command {
24    IPv4TCP = 0x01,
25    IPv6TCP = 0x02,
26    DomainTCP = 0x03,
27    IPv4UDP = 0x04,
28    IPv6UDP = 0x05,
29    DomainUDP = 0x06,
30}
31
32impl Command {
33    fn from_u8(val: u8) -> Result<Self, NetworkError> {
34        match val {
35            0x01 => Ok(Self::IPv4TCP),
36            0x02 => Ok(Self::IPv6TCP),
37            0x03 => Ok(Self::DomainTCP),
38            0x04 => Ok(Self::IPv4UDP),
39            0x05 => Ok(Self::IPv6UDP),
40            0x06 => Ok(Self::DomainUDP),
41            _ => Err(NetworkError::P2PInvalidCommand),
42        }
43    }
44}
45
46pub enum Request {
47    // Todo: Add signature and salt
48    Ip(
49        SocketAddr,
50        bool,
51        Option<(CryptographicAlgorithm, SigningAlgorithm)>,
52    ), // bool is UDP
53    Dns(
54        String,
55        u16,
56        bool,
57        Option<(CryptographicAlgorithm, SigningAlgorithm)>,
58    ), // bool is UDP
59}
60
61impl Request {
62    pub async fn read(mut reader: impl AsyncRead + Unpin) -> Result<Self, NetworkError> {
63        let cmd = Command::from_u8(reader.read_u8().await?)?;
64        let req = match cmd {
65            Command::DomainTCP | Command::DomainUDP => {
66                let len = reader.read_u8().await?;
67                let mut buf = vec![0; len as usize + 2];
68                reader.read_exact(&mut buf).await?;
69                let domain = String::from_utf8(buf[..buf.len() - 2].to_vec())
70                    .map_err(|_| NetworkError::P2PInvalidDomain)?;
71                let port = u16::from_be_bytes([buf[buf.len() - 2], buf[buf.len() - 1]]);
72                Self::Dns(domain, port, cmd == Command::DomainUDP, None)
73            }
74            Command::IPv4TCP | Command::IPv4UDP => {
75                let mut buf = vec![0; 4 + 2];
76                reader.read_exact(&mut buf).await?;
77                let ipv4 = std::net::Ipv4Addr::new(buf[0], buf[1], buf[2], buf[3]);
78                let port = u16::from_be_bytes([buf[buf.len() - 2], buf[buf.len() - 1]]);
79                Self::Ip(
80                    SocketAddr::new(ipv4.into(), port),
81                    cmd == Command::IPv4UDP,
82                    None,
83                )
84            }
85            Command::IPv6TCP | Command::IPv6UDP => {
86                let mut buf = vec![0; 16 + 2];
87                reader.read_exact(&mut buf).await?;
88                let ipv6 = std::net::Ipv6Addr::new(
89                    u16::from_be_bytes([buf[0], buf[1]]),
90                    u16::from_be_bytes([buf[2], buf[3]]),
91                    u16::from_be_bytes([buf[4], buf[5]]),
92                    u16::from_be_bytes([buf[6], buf[7]]),
93                    u16::from_be_bytes([buf[8], buf[9]]),
94                    u16::from_be_bytes([buf[10], buf[11]]),
95                    u16::from_be_bytes([buf[12], buf[13]]),
96                    u16::from_be_bytes([buf[14], buf[15]]),
97                );
98                let port = u16::from_be_bytes([buf[buf.len() - 2], buf[buf.len() - 1]]);
99                Self::Ip(
100                    SocketAddr::new(ipv6.into(), port),
101                    cmd == Command::IPv6UDP,
102                    None,
103                )
104            }
105        };
106        if reader.read_u8().await? == 1 {
107            let mut buf = vec![0; 24 + 32];
108            reader.read_exact(&mut buf).await?;
109            let crypto = CryptographicAlgorithm::XChaCha20Poly1305(
110                buf[..24]
111                    .try_into()
112                    .map_err(|_| NetworkError::P2PInvalidCrypto)?,
113            );
114            let sign = SigningAlgorithm::HmacSha256(
115                buf[24..]
116                    .try_into()
117                    .map_err(|_| NetworkError::P2PInvalidCrypto)?,
118            );
119            let req = match req {
120                Self::Ip(ip, udp, _) => Self::Ip(ip, udp, Some((crypto, sign))),
121                Self::Dns(domain, port, udp, _) => {
122                    Self::Dns(domain, port, udp, Some((crypto, sign)))
123                }
124            };
125            Ok(req)
126        } else {
127            Ok(req)
128        }
129    }
130    pub async fn write(&self, mut writer: impl AsyncWrite + Unpin) -> Result<(), NetworkError> {
131        match self {
132            Request::Ip(ip, udp, crypt) => {
133                let cmd = if ip.is_ipv4() {
134                    if *udp {
135                        Command::IPv4UDP
136                    } else {
137                        Command::IPv4TCP
138                    }
139                } else if *udp {
140                    Command::IPv6UDP
141                } else {
142                    Command::IPv6TCP
143                };
144                writer.write_u8(cmd as u8).await?;
145                match ip {
146                    SocketAddr::V4(ipv4) => {
147                        writer.write_all(&ipv4.ip().octets()).await?;
148                    }
149                    SocketAddr::V6(ipv6) => {
150                        writer.write_all(&ipv6.ip().octets()).await?;
151                    }
152                }
153                writer.write_u16(ip.port()).await?;
154                if let Some(c) = crypt {
155                    writer.write_u8(1).await?;
156                    match c {
157                        (
158                            CryptographicAlgorithm::XChaCha20Poly1305(iv),
159                            SigningAlgorithm::HmacSha256(key),
160                        ) => {
161                            writer.write_all(iv).await?;
162                            writer.write_all(key).await?;
163                        }
164                    }
165                } else {
166                    writer.write_u8(0).await?;
167                }
168            }
169            Request::Dns(domain, port, udp, crypt) => {
170                let cmd = if *udp {
171                    Command::DomainUDP
172                } else {
173                    Command::DomainTCP
174                };
175                writer.write_u8(cmd as u8).await?;
176                writer.write_u8(domain.len() as u8).await?;
177                writer.write_all(domain.as_bytes()).await?;
178                writer.write_u16(*port).await?;
179                if let Some(c) = crypt {
180                    writer.write_u8(1).await?;
181                    match c {
182                        (
183                            CryptographicAlgorithm::XChaCha20Poly1305(iv),
184                            SigningAlgorithm::HmacSha256(key),
185                        ) => {
186                            writer.write_all(iv).await?;
187                            writer.write_all(key).await?;
188                        }
189                    }
190                } else {
191                    writer.write_u8(0).await?;
192                }
193            }
194        }
195        Ok(())
196    }
197}
198
/// Single-byte status sent back for a P2P [`Request`]; the discriminant is the wire value.
#[derive(Clone, Copy, Debug)]
pub enum Response {
    Success = 0x00,
    InvalidRequest = 0x01,
    AccessDenied = 0x02,
    UnableToResolve = 0x03,
    Failed = 0xFF,
}

impl Display for Response {
    /// Prints the variant name verbatim (formatter width/padding flags are not applied).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match self {
            Self::Success => "Success",
            Self::InvalidRequest => "InvalidRequest",
            Self::AccessDenied => "AccessDenied",
            Self::UnableToResolve => "UnableToResolve",
            Self::Failed => "Failed",
        };
        f.write_str(label)
    }
}
219
220impl Response {
221    pub async fn read(mut reader: impl AsyncRead + Unpin) -> Result<Self, NetworkError> {
222        let val = reader.read_u8().await?;
223        match val {
224            0x00 => Ok(Self::Success),
225            0x01 => Ok(Self::InvalidRequest),
226            0x02 => Ok(Self::AccessDenied),
227            0xFF => Ok(Self::Failed),
228            _ => Err(NetworkError::P2PInvalidCommand),
229        }
230    }
231    pub async fn write(&self, mut writer: impl AsyncWrite + Unpin) -> Result<(), NetworkError> {
232        writer.write_u8(*self as u8).await?;
233        Ok(())
234    }
235}
236
237impl From<&Request> for Connect {
238    fn from(r: &Request) -> Self {
239        let (host, port, is_udp, crypt) = match r {
240            Request::Ip(ip, udp, crypt) => (ip.ip().to_string(), ip.port(), udp, crypt),
241            Request::Dns(domain, port, udp, crypt) => (domain.to_owned(), *port, udp, crypt),
242        };
243        let (cryptography, sign) = if let Some((c, s)) = crypt {
244            (Some(c.clone()), Some(s.clone()))
245        } else {
246            (None, None)
247        };
248        Connect {
249            host,
250            port,
251            protocol: if *is_udp {
252                Protocol::UDP
253            } else {
254                Protocol::TCP
255            },
256            cryptography,
257            sign,
258        }
259    }
260}
261
262impl From<&Connect> for Request {
263    fn from(connect: &Connect) -> Self {
264        let crypt = if let (Some(c), Some(s)) = (&connect.cryptography, &connect.sign) {
265            Some((c.clone(), s.clone()))
266        } else {
267            None
268        };
269        match connect.protocol {
270            Protocol::TCP | Protocol::HTTP | Protocol::HTTPS | Protocol::TLS => {
271                match connect.host.parse::<IpAddr>() {
272                    Ok(ip) => Request::Ip(SocketAddr::new(ip, connect.port), false, crypt),
273                    Err(_) => Request::Dns(connect.host.to_owned(), connect.port, false, crypt),
274                }
275            }
276            Protocol::UDP | Protocol::DTLS | Protocol::QUIC => match connect.host.parse::<IpAddr>()
277            {
278                Ok(ip) => Request::Ip(SocketAddr::new(ip, connect.port), true, crypt),
279                Err(_) => Request::Dns(connect.host.to_owned(), connect.port, true, crypt),
280            },
281        }
282    }
283}
284
285pub struct QuicStream {
286    con: Connection,
287    // number_of_streams: Arc<AtomicU32>,
288}
289
290impl QuicStream {
291    pub async fn new_client(
292        remote_addr: SocketAddr,
293        socket: UdpSocket, // tokio udpsocket
294        cert: Vec<u8>,
295    ) -> Result<Self, NetworkError> {
296        debug!("Connecting to {}", remote_addr);
297        let mut end = Endpoint::new(
298            EndpointConfig::default(),
299            None,
300            socket.into_std()?,
301            Arc::new(quinn::TokioRuntime),
302        )?;
303        let mut root_store = rustls::RootCertStore::empty();
304        root_store
305            .add(&rustls::Certificate(cert))
306            .map_err(|_| NetworkError::TlsError)?;
307        let mut config = rustls::ClientConfig::builder()
308            .with_safe_defaults()
309            .with_root_certificates(root_store)
310            .with_no_client_auth();
311        config.enable_sni = false;
312        end.set_default_client_config(ClientConfig::new(Arc::new(config)));
313
314        let con = end
315            .connect(remote_addr, &remote_addr.ip().to_string())
316            .map_err(|_| NetworkError::QuicError)?
317            .await
318            .map_err(|_| NetworkError::QuicError)?;
319        Ok(Self { con })
320    }
321    pub async fn new_server(
322        socket: UdpSocket, // tokio udpsocket
323        cert: Vec<u8>,
324        key: Vec<u8>,
325    ) -> Result<Self, NetworkError> {
326        debug("Accepting connection");
327        let mut server_config = quinn::ServerConfig::with_single_cert(
328            vec![rustls::Certificate(cert)],
329            rustls::PrivateKey(key),
330        )
331        .map_err(|_| NetworkError::TlsError)?;
332        if let Some(conf) = std::sync::Arc::get_mut(&mut server_config.transport) {
333            conf.keep_alive_interval(Some(Duration::from_secs(4)));
334            conf.max_concurrent_uni_streams(0_u8.into());
335            conf.max_concurrent_bidi_streams(1024_u16.into());
336        };
337        let end = Endpoint::new(
338            EndpointConfig::default(),
339            Some(server_config),
340            socket.into_std()?,
341            Arc::new(quinn::TokioRuntime),
342        )?;
343        let con = end
344            .accept()
345            .await
346            .ok_or(NetworkError::QuicError)?
347            .await
348            .map_err(|_| NetworkError::QuicError)?;
349        Ok(Self { con })
350    }
351    pub async fn open_bi(&self) -> Result<QuicBiSocket, NetworkError> {
352        let (send, recv) = self
353            .con
354            .open_bi()
355            .await
356            .map_err(|_| NetworkError::QuicError)?;
357        // self.number_of_streams.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
358        Ok(QuicBiSocket {
359            send,
360            recv,
361            // number_of_streams: self.number_of_streams.clone(),
362        })
363    }
364    pub async fn accept_bi(&self) -> Result<QuicBiSocket, NetworkError> {
365        let (send, recv) = self
366            .con
367            .accept_bi()
368            .await
369            .map_err(|_| NetworkError::QuicError)?;
370        // self.number_of_streams.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
371        Ok(QuicBiSocket {
372            send,
373            recv,
374            // number_of_streams: self.number_of_streams.clone(),
375        })
376    }
377    pub fn remote_addr(&self) -> SocketAddr {
378        self.con.remote_address()
379    }
380}
381
382pub struct QuicBiSocket {
383    send: SendStream,
384    recv: RecvStream,
385    // number_of_streams: Arc<AtomicU32>,
386}
387
388impl AsyncRead for QuicBiSocket {
389    fn poll_read(
390        mut self: std::pin::Pin<&mut Self>,
391        cx: &mut std::task::Context<'_>,
392        buf: &mut tokio::io::ReadBuf<'_>,
393    ) -> std::task::Poll<std::io::Result<()>> {
394        std::pin::Pin::new(&mut self.recv).poll_read(cx, buf)
395    }
396}
397
398impl AsyncWrite for QuicBiSocket {
399    fn poll_write(
400        mut self: std::pin::Pin<&mut Self>,
401        cx: &mut std::task::Context<'_>,
402        buf: &[u8],
403    ) -> std::task::Poll<Result<usize, std::io::Error>> {
404        std::pin::Pin::new(&mut self.send).poll_write(cx, buf)
405    }
406
407    fn poll_flush(
408        mut self: std::pin::Pin<&mut Self>,
409        cx: &mut std::task::Context<'_>,
410    ) -> std::task::Poll<Result<(), std::io::Error>> {
411        std::pin::Pin::new(&mut self.send).poll_flush(cx)
412    }
413
414    fn poll_shutdown(
415        mut self: std::pin::Pin<&mut Self>,
416        cx: &mut std::task::Context<'_>,
417    ) -> std::task::Poll<Result<(), std::io::Error>> {
418        std::pin::Pin::new(&mut self.send).poll_shutdown(cx)
419    }
420}
421
422// impl Drop for QuicBiSocket {
423//     fn drop(&mut self) {
424//         self.number_of_streams.fetch_sub(1, std::sync::atomic::Ordering::SeqCst);
425//     }
426// }
427
428#[async_recursion]
429pub async fn udp_punched_socket(
430    p2p: Peer2PeerInstruction,
431    handshake_key: &[u8],
432    left: bool,
433    inner: bool,
434) -> Result<(UdpSocket, SocketAddr), NetworkError> {
435    debug!("P2P: {:?}", p2p);
436    let unspecified_ip = if p2p.peer_ip.is_ipv4() {
437        IpAddr::V4(Ipv4Addr::UNSPECIFIED)
438    } else {
439        IpAddr::V6(Ipv6Addr::UNSPECIFIED)
440    };
441    #[cfg(unix)]
442    let no_file_limit = rlimit::getrlimit(rlimit::Resource::NOFILE)
443        .map(|(n, _)| n)
444        .ok();
445
446    #[cfg(unix)]
447    if p2p.seq > 128 && no_file_limit.is_some() {
448        _ = rlimit::increase_nofile_limit(512);
449    }
450
451    let (puncher, dyn_my_port, dyn_peer_port) = match (p2p.nat, p2p.peer_nat) {
452        (NatType::Easy, NatType::Easy) => (left, true, true),
453        (NatType::Easy, NatType::Hard) => (true, false, true),
454        (NatType::Easy, NatType::Unknown) => (true, false, true),
455        (NatType::Hard, NatType::Easy) => (false, true, false),
456        (NatType::Hard, NatType::Hard) => (left, left, !left),
457        (NatType::Hard, NatType::Unknown) => (false, true, false),
458        (NatType::Unknown, NatType::Easy) => (false, true, false),
459        (NatType::Unknown, NatType::Hard) => (true, false, true),
460        (NatType::Unknown, NatType::Unknown) => (left, left, !left),
461    };
462
463    if !puncher {
464        tokio::time::sleep(Duration::from_millis(1000)).await;
465    }
466
467    let mut sockets = Vec::new();
468    let mut socket: Option<UdpSocket> = None;
469    for s in 1..p2p.seq + 1 {
470        let my_port = if dyn_my_port {
471            if left {
472                p2p.seed_port - s
473            } else {
474                p2p.seed_port + s
475            }
476        } else {
477            p2p.seed_port
478        };
479        let peer_port = if dyn_peer_port {
480            if left {
481                p2p.seed_port + s
482            } else {
483                p2p.seed_port - s
484            }
485        } else {
486            p2p.seed_port
487        };
488        if socket.is_none() || dyn_my_port {
489            match UdpSocket::bind(SocketAddr::new(unspecified_ip, my_port)).await {
490                Ok(s) => socket.replace(s),
491                Err(e) => {
492                    warn!("Error binding socket on {}, {}", my_port, e.to_string());
493                    continue;
494                }
495            };
496        }
497
498        if let Some(socket) = socket.as_ref() {
499            let buf = if puncher {
500                debug!(
501                    "Punching peer {}:{} -> {}:{}",
502                    unspecified_ip, my_port, p2p.peer_ip, peer_port
503                );
504                vec![0]
505            } else {
506                debug!(
507                    "Discovering peer {}:{} -> {}:{}",
508                    unspecified_ip, my_port, p2p.peer_ip, peer_port
509                );
510                handshake_key[0..3].to_vec()
511            };
512            if let Err(e) = socket
513                .send_to(&buf, SocketAddr::new(p2p.peer_ip, peer_port))
514                .await
515            {
516                warn!("Error sending to peer: {}", e);
517            };
518        }
519        if s == p2p.seq || dyn_my_port {
520            if let Some(socket) = socket.take() {
521                sockets.push(Box::pin(async { socket.readable().await.map(|_| socket) }));
522            }
523        }
524    }
525    loop {
526        if sockets.is_empty() {
527            #[cfg(unix)]
528            no_file_limit.and_then(|n| rlimit::increase_nofile_limit(n).ok());
529            return Err(NetworkError::P2PFailed);
530        };
531        let Ok((socket, _size, remaining_sockets)) = tokio::time::timeout(
532            Duration::from_secs(if p2p.seq > 128 { 15 } else { 5 }),
533            futures_util::future::select_all(sockets),
534        )
535        .await
536        else {
537            warn!("Timeout waiting for response from peer");
538            if !inner && p2p.nat == p2p.peer_nat {
539                info!("Trying to punch peer from other side");
540                if puncher {
541                    tokio::time::sleep(Duration::from_millis(1000)).await;
542                }
543                return udp_punched_socket(p2p, handshake_key, !left, true).await;
544            }
545            #[cfg(unix)]
546            no_file_limit.and_then(|n| rlimit::increase_nofile_limit(n).ok());
547            return Err(NetworkError::P2PTimeout);
548        };
549        let socket = match socket {
550            Ok(socket) => socket,
551            Err(e) => {
552                warn!("Error reading from socket: {}", e);
553                sockets = remaining_sockets;
554                continue;
555            }
556        };
557
558        let mut buf = vec![0u8; 3];
559        let peer = match socket.recv_from(&mut buf).await {
560            Ok((_, peer)) => peer,
561            Err(e) => {
562                warn!("Error receiving from socket: {}", e);
563                sockets = remaining_sockets;
564                continue;
565            }
566        };
567
568        if puncher && handshake_key[0..3] == buf[0..3] {
569            if let Ok(local_addr) = socket.local_addr() {
570                debug!(
571                    "Confirming p2p channel peer {}:{} -> {}:{}",
572                    local_addr.ip(),
573                    local_addr.port(),
574                    peer.ip(),
575                    peer.port()
576                );
577            }
578            if let Err(e) = socket.send_to(&handshake_key[3..6], peer).await {
579                warn!("Error sending to peer: {}", e);
580                sockets = remaining_sockets;
581                continue;
582            }
583        } else if handshake_key[3..6] == buf[0..3] {
584        } else {
585            warn!("Invalid response from peer");
586            sockets = remaining_sockets;
587            continue;
588        };
589        #[cfg(unix)]
590        no_file_limit.and_then(|n| rlimit::increase_nofile_limit(n).ok());
591        return Ok((socket, peer));
592    }
593}