1use core::fmt::Display;
2use std::{
3 net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr},
4 sync::Arc,
5 time::Duration,
7};
8
9use async_recursion::async_recursion;
10use narrowlink_types::{
11 generic::{Connect, CryptographicAlgorithm, Protocol, SigningAlgorithm},
12 NatType, Peer2PeerInstruction,
13};
14use quinn::{ClientConfig, Connection, Endpoint, EndpointConfig, RecvStream, SendStream};
15use tokio::{
16 io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt},
17 net::UdpSocket,
18};
19use tracing::{debug, field::debug, info, warn};
20
21use crate::error::NetworkError;
/// Command byte that opens a serialized [`Request`].
///
/// It encodes both how the target is given (IPv4 address, IPv6 address or domain
/// name) and the transport (TCP or UDP). The discriminant is the byte on the wire.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Command {
    /// IPv4 address, TCP transport.
    IPv4TCP = 0x01,
    /// IPv6 address, TCP transport.
    IPv6TCP = 0x02,
    /// Length-prefixed domain name, TCP transport.
    DomainTCP = 0x03,
    /// IPv4 address, UDP transport.
    IPv4UDP = 0x04,
    /// IPv6 address, UDP transport.
    IPv6UDP = 0x05,
    /// Length-prefixed domain name, UDP transport.
    DomainUDP = 0x06,
}
31
32impl Command {
33 fn from_u8(val: u8) -> Result<Self, NetworkError> {
34 match val {
35 0x01 => Ok(Self::IPv4TCP),
36 0x02 => Ok(Self::IPv6TCP),
37 0x03 => Ok(Self::DomainTCP),
38 0x04 => Ok(Self::IPv4UDP),
39 0x05 => Ok(Self::IPv6UDP),
40 0x06 => Ok(Self::DomainUDP),
41 _ => Err(NetworkError::P2PInvalidCommand),
42 }
43 }
44}
45
46pub enum Request {
47 Ip(
49 SocketAddr,
50 bool,
51 Option<(CryptographicAlgorithm, SigningAlgorithm)>,
52 ), Dns(
54 String,
55 u16,
56 bool,
57 Option<(CryptographicAlgorithm, SigningAlgorithm)>,
58 ), }
60
61impl Request {
62 pub async fn read(mut reader: impl AsyncRead + Unpin) -> Result<Self, NetworkError> {
63 let cmd = Command::from_u8(reader.read_u8().await?)?;
64 let req = match cmd {
65 Command::DomainTCP | Command::DomainUDP => {
66 let len = reader.read_u8().await?;
67 let mut buf = vec![0; len as usize + 2];
68 reader.read_exact(&mut buf).await?;
69 let domain = String::from_utf8(buf[..buf.len() - 2].to_vec())
70 .map_err(|_| NetworkError::P2PInvalidDomain)?;
71 let port = u16::from_be_bytes([buf[buf.len() - 2], buf[buf.len() - 1]]);
72 Self::Dns(domain, port, cmd == Command::DomainUDP, None)
73 }
74 Command::IPv4TCP | Command::IPv4UDP => {
75 let mut buf = vec![0; 4 + 2];
76 reader.read_exact(&mut buf).await?;
77 let ipv4 = std::net::Ipv4Addr::new(buf[0], buf[1], buf[2], buf[3]);
78 let port = u16::from_be_bytes([buf[buf.len() - 2], buf[buf.len() - 1]]);
79 Self::Ip(
80 SocketAddr::new(ipv4.into(), port),
81 cmd == Command::IPv4UDP,
82 None,
83 )
84 }
85 Command::IPv6TCP | Command::IPv6UDP => {
86 let mut buf = vec![0; 16 + 2];
87 reader.read_exact(&mut buf).await?;
88 let ipv6 = std::net::Ipv6Addr::new(
89 u16::from_be_bytes([buf[0], buf[1]]),
90 u16::from_be_bytes([buf[2], buf[3]]),
91 u16::from_be_bytes([buf[4], buf[5]]),
92 u16::from_be_bytes([buf[6], buf[7]]),
93 u16::from_be_bytes([buf[8], buf[9]]),
94 u16::from_be_bytes([buf[10], buf[11]]),
95 u16::from_be_bytes([buf[12], buf[13]]),
96 u16::from_be_bytes([buf[14], buf[15]]),
97 );
98 let port = u16::from_be_bytes([buf[buf.len() - 2], buf[buf.len() - 1]]);
99 Self::Ip(
100 SocketAddr::new(ipv6.into(), port),
101 cmd == Command::IPv6UDP,
102 None,
103 )
104 }
105 };
106 if reader.read_u8().await? == 1 {
107 let mut buf = vec![0; 24 + 32];
108 reader.read_exact(&mut buf).await?;
109 let crypto = CryptographicAlgorithm::XChaCha20Poly1305(
110 buf[..24]
111 .try_into()
112 .map_err(|_| NetworkError::P2PInvalidCrypto)?,
113 );
114 let sign = SigningAlgorithm::HmacSha256(
115 buf[24..]
116 .try_into()
117 .map_err(|_| NetworkError::P2PInvalidCrypto)?,
118 );
119 let req = match req {
120 Self::Ip(ip, udp, _) => Self::Ip(ip, udp, Some((crypto, sign))),
121 Self::Dns(domain, port, udp, _) => {
122 Self::Dns(domain, port, udp, Some((crypto, sign)))
123 }
124 };
125 Ok(req)
126 } else {
127 Ok(req)
128 }
129 }
130 pub async fn write(&self, mut writer: impl AsyncWrite + Unpin) -> Result<(), NetworkError> {
131 match self {
132 Request::Ip(ip, udp, crypt) => {
133 let cmd = if ip.is_ipv4() {
134 if *udp {
135 Command::IPv4UDP
136 } else {
137 Command::IPv4TCP
138 }
139 } else if *udp {
140 Command::IPv6UDP
141 } else {
142 Command::IPv6TCP
143 };
144 writer.write_u8(cmd as u8).await?;
145 match ip {
146 SocketAddr::V4(ipv4) => {
147 writer.write_all(&ipv4.ip().octets()).await?;
148 }
149 SocketAddr::V6(ipv6) => {
150 writer.write_all(&ipv6.ip().octets()).await?;
151 }
152 }
153 writer.write_u16(ip.port()).await?;
154 if let Some(c) = crypt {
155 writer.write_u8(1).await?;
156 match c {
157 (
158 CryptographicAlgorithm::XChaCha20Poly1305(iv),
159 SigningAlgorithm::HmacSha256(key),
160 ) => {
161 writer.write_all(iv).await?;
162 writer.write_all(key).await?;
163 }
164 }
165 } else {
166 writer.write_u8(0).await?;
167 }
168 }
169 Request::Dns(domain, port, udp, crypt) => {
170 let cmd = if *udp {
171 Command::DomainUDP
172 } else {
173 Command::DomainTCP
174 };
175 writer.write_u8(cmd as u8).await?;
176 writer.write_u8(domain.len() as u8).await?;
177 writer.write_all(domain.as_bytes()).await?;
178 writer.write_u16(*port).await?;
179 if let Some(c) = crypt {
180 writer.write_u8(1).await?;
181 match c {
182 (
183 CryptographicAlgorithm::XChaCha20Poly1305(iv),
184 SigningAlgorithm::HmacSha256(key),
185 ) => {
186 writer.write_all(iv).await?;
187 writer.write_all(key).await?;
188 }
189 }
190 } else {
191 writer.write_u8(0).await?;
192 }
193 }
194 }
195 Ok(())
196 }
197}
198
/// One-byte status code.
///
/// The discriminant is exactly the value that `Response::write` puts on the wire.
#[derive(Clone, Copy, Debug)]
pub enum Response {
    /// Wire value `0x00`.
    Success = 0,
    /// Wire value `0x01`.
    InvalidRequest = 1,
    /// Wire value `0x02`.
    AccessDenied = 2,
    /// Wire value `0x03`.
    UnableToResolve = 3,
    /// Wire value `0xFF`.
    Failed = 255,
}
207
208impl Display for Response {
209 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
210 match self {
211 Self::Success => write!(f, "Success"),
212 Self::InvalidRequest => write!(f, "InvalidRequest"),
213 Self::AccessDenied => write!(f, "AccessDenied"),
214 Self::UnableToResolve => write!(f, "UnableToResolve"),
215 Self::Failed => write!(f, "Failed"),
216 }
217 }
218}
219
220impl Response {
221 pub async fn read(mut reader: impl AsyncRead + Unpin) -> Result<Self, NetworkError> {
222 let val = reader.read_u8().await?;
223 match val {
224 0x00 => Ok(Self::Success),
225 0x01 => Ok(Self::InvalidRequest),
226 0x02 => Ok(Self::AccessDenied),
227 0xFF => Ok(Self::Failed),
228 _ => Err(NetworkError::P2PInvalidCommand),
229 }
230 }
231 pub async fn write(&self, mut writer: impl AsyncWrite + Unpin) -> Result<(), NetworkError> {
232 writer.write_u8(*self as u8).await?;
233 Ok(())
234 }
235}
236
237impl From<&Request> for Connect {
238 fn from(r: &Request) -> Self {
239 let (host, port, is_udp, crypt) = match r {
240 Request::Ip(ip, udp, crypt) => (ip.ip().to_string(), ip.port(), udp, crypt),
241 Request::Dns(domain, port, udp, crypt) => (domain.to_owned(), *port, udp, crypt),
242 };
243 let (cryptography, sign) = if let Some((c, s)) = crypt {
244 (Some(c.clone()), Some(s.clone()))
245 } else {
246 (None, None)
247 };
248 Connect {
249 host,
250 port,
251 protocol: if *is_udp {
252 Protocol::UDP
253 } else {
254 Protocol::TCP
255 },
256 cryptography,
257 sign,
258 }
259 }
260}
261
262impl From<&Connect> for Request {
263 fn from(connect: &Connect) -> Self {
264 let crypt = if let (Some(c), Some(s)) = (&connect.cryptography, &connect.sign) {
265 Some((c.clone(), s.clone()))
266 } else {
267 None
268 };
269 match connect.protocol {
270 Protocol::TCP | Protocol::HTTP | Protocol::HTTPS | Protocol::TLS => {
271 match connect.host.parse::<IpAddr>() {
272 Ok(ip) => Request::Ip(SocketAddr::new(ip, connect.port), false, crypt),
273 Err(_) => Request::Dns(connect.host.to_owned(), connect.port, false, crypt),
274 }
275 }
276 Protocol::UDP | Protocol::DTLS | Protocol::QUIC => match connect.host.parse::<IpAddr>()
277 {
278 Ok(ip) => Request::Ip(SocketAddr::new(ip, connect.port), true, crypt),
279 Err(_) => Request::Dns(connect.host.to_owned(), connect.port, true, crypt),
280 },
281 }
282 }
283}
284
285pub struct QuicStream {
286 con: Connection,
287 }
289
290impl QuicStream {
291 pub async fn new_client(
292 remote_addr: SocketAddr,
293 socket: UdpSocket, cert: Vec<u8>,
295 ) -> Result<Self, NetworkError> {
296 debug!("Connecting to {}", remote_addr);
297 let mut end = Endpoint::new(
298 EndpointConfig::default(),
299 None,
300 socket.into_std()?,
301 Arc::new(quinn::TokioRuntime),
302 )?;
303 let mut root_store = rustls::RootCertStore::empty();
304 root_store
305 .add(&rustls::Certificate(cert))
306 .map_err(|_| NetworkError::TlsError)?;
307 let mut config = rustls::ClientConfig::builder()
308 .with_safe_defaults()
309 .with_root_certificates(root_store)
310 .with_no_client_auth();
311 config.enable_sni = false;
312 end.set_default_client_config(ClientConfig::new(Arc::new(config)));
313
314 let con = end
315 .connect(remote_addr, &remote_addr.ip().to_string())
316 .map_err(|_| NetworkError::QuicError)?
317 .await
318 .map_err(|_| NetworkError::QuicError)?;
319 Ok(Self { con })
320 }
321 pub async fn new_server(
322 socket: UdpSocket, cert: Vec<u8>,
324 key: Vec<u8>,
325 ) -> Result<Self, NetworkError> {
326 debug("Accepting connection");
327 let mut server_config = quinn::ServerConfig::with_single_cert(
328 vec![rustls::Certificate(cert)],
329 rustls::PrivateKey(key),
330 )
331 .map_err(|_| NetworkError::TlsError)?;
332 if let Some(conf) = std::sync::Arc::get_mut(&mut server_config.transport) {
333 conf.keep_alive_interval(Some(Duration::from_secs(4)));
334 conf.max_concurrent_uni_streams(0_u8.into());
335 conf.max_concurrent_bidi_streams(1024_u16.into());
336 };
337 let end = Endpoint::new(
338 EndpointConfig::default(),
339 Some(server_config),
340 socket.into_std()?,
341 Arc::new(quinn::TokioRuntime),
342 )?;
343 let con = end
344 .accept()
345 .await
346 .ok_or(NetworkError::QuicError)?
347 .await
348 .map_err(|_| NetworkError::QuicError)?;
349 Ok(Self { con })
350 }
351 pub async fn open_bi(&self) -> Result<QuicBiSocket, NetworkError> {
352 let (send, recv) = self
353 .con
354 .open_bi()
355 .await
356 .map_err(|_| NetworkError::QuicError)?;
357 Ok(QuicBiSocket {
359 send,
360 recv,
361 })
363 }
364 pub async fn accept_bi(&self) -> Result<QuicBiSocket, NetworkError> {
365 let (send, recv) = self
366 .con
367 .accept_bi()
368 .await
369 .map_err(|_| NetworkError::QuicError)?;
370 Ok(QuicBiSocket {
372 send,
373 recv,
374 })
376 }
377 pub fn remote_addr(&self) -> SocketAddr {
378 self.con.remote_address()
379 }
380}
381
382pub struct QuicBiSocket {
383 send: SendStream,
384 recv: RecvStream,
385 }
387
388impl AsyncRead for QuicBiSocket {
389 fn poll_read(
390 mut self: std::pin::Pin<&mut Self>,
391 cx: &mut std::task::Context<'_>,
392 buf: &mut tokio::io::ReadBuf<'_>,
393 ) -> std::task::Poll<std::io::Result<()>> {
394 std::pin::Pin::new(&mut self.recv).poll_read(cx, buf)
395 }
396}
397
398impl AsyncWrite for QuicBiSocket {
399 fn poll_write(
400 mut self: std::pin::Pin<&mut Self>,
401 cx: &mut std::task::Context<'_>,
402 buf: &[u8],
403 ) -> std::task::Poll<Result<usize, std::io::Error>> {
404 std::pin::Pin::new(&mut self.send).poll_write(cx, buf)
405 }
406
407 fn poll_flush(
408 mut self: std::pin::Pin<&mut Self>,
409 cx: &mut std::task::Context<'_>,
410 ) -> std::task::Poll<Result<(), std::io::Error>> {
411 std::pin::Pin::new(&mut self.send).poll_flush(cx)
412 }
413
414 fn poll_shutdown(
415 mut self: std::pin::Pin<&mut Self>,
416 cx: &mut std::task::Context<'_>,
417 ) -> std::task::Poll<Result<(), std::io::Error>> {
418 std::pin::Pin::new(&mut self.send).poll_shutdown(cx)
419 }
420}
421
422#[async_recursion]
429pub async fn udp_punched_socket(
430 p2p: Peer2PeerInstruction,
431 handshake_key: &[u8],
432 left: bool,
433 inner: bool,
434) -> Result<(UdpSocket, SocketAddr), NetworkError> {
435 debug!("P2P: {:?}", p2p);
436 let unspecified_ip = if p2p.peer_ip.is_ipv4() {
437 IpAddr::V4(Ipv4Addr::UNSPECIFIED)
438 } else {
439 IpAddr::V6(Ipv6Addr::UNSPECIFIED)
440 };
441 #[cfg(unix)]
442 let no_file_limit = rlimit::getrlimit(rlimit::Resource::NOFILE)
443 .map(|(n, _)| n)
444 .ok();
445
446 #[cfg(unix)]
447 if p2p.seq > 128 && no_file_limit.is_some() {
448 _ = rlimit::increase_nofile_limit(512);
449 }
450
451 let (puncher, dyn_my_port, dyn_peer_port) = match (p2p.nat, p2p.peer_nat) {
452 (NatType::Easy, NatType::Easy) => (left, true, true),
453 (NatType::Easy, NatType::Hard) => (true, false, true),
454 (NatType::Easy, NatType::Unknown) => (true, false, true),
455 (NatType::Hard, NatType::Easy) => (false, true, false),
456 (NatType::Hard, NatType::Hard) => (left, left, !left),
457 (NatType::Hard, NatType::Unknown) => (false, true, false),
458 (NatType::Unknown, NatType::Easy) => (false, true, false),
459 (NatType::Unknown, NatType::Hard) => (true, false, true),
460 (NatType::Unknown, NatType::Unknown) => (left, left, !left),
461 };
462
463 if !puncher {
464 tokio::time::sleep(Duration::from_millis(1000)).await;
465 }
466
467 let mut sockets = Vec::new();
468 let mut socket: Option<UdpSocket> = None;
469 for s in 1..p2p.seq + 1 {
470 let my_port = if dyn_my_port {
471 if left {
472 p2p.seed_port - s
473 } else {
474 p2p.seed_port + s
475 }
476 } else {
477 p2p.seed_port
478 };
479 let peer_port = if dyn_peer_port {
480 if left {
481 p2p.seed_port + s
482 } else {
483 p2p.seed_port - s
484 }
485 } else {
486 p2p.seed_port
487 };
488 if socket.is_none() || dyn_my_port {
489 match UdpSocket::bind(SocketAddr::new(unspecified_ip, my_port)).await {
490 Ok(s) => socket.replace(s),
491 Err(e) => {
492 warn!("Error binding socket on {}, {}", my_port, e.to_string());
493 continue;
494 }
495 };
496 }
497
498 if let Some(socket) = socket.as_ref() {
499 let buf = if puncher {
500 debug!(
501 "Punching peer {}:{} -> {}:{}",
502 unspecified_ip, my_port, p2p.peer_ip, peer_port
503 );
504 vec![0]
505 } else {
506 debug!(
507 "Discovering peer {}:{} -> {}:{}",
508 unspecified_ip, my_port, p2p.peer_ip, peer_port
509 );
510 handshake_key[0..3].to_vec()
511 };
512 if let Err(e) = socket
513 .send_to(&buf, SocketAddr::new(p2p.peer_ip, peer_port))
514 .await
515 {
516 warn!("Error sending to peer: {}", e);
517 };
518 }
519 if s == p2p.seq || dyn_my_port {
520 if let Some(socket) = socket.take() {
521 sockets.push(Box::pin(async { socket.readable().await.map(|_| socket) }));
522 }
523 }
524 }
525 loop {
526 if sockets.is_empty() {
527 #[cfg(unix)]
528 no_file_limit.and_then(|n| rlimit::increase_nofile_limit(n).ok());
529 return Err(NetworkError::P2PFailed);
530 };
531 let Ok((socket, _size, remaining_sockets)) = tokio::time::timeout(
532 Duration::from_secs(if p2p.seq > 128 { 15 } else { 5 }),
533 futures_util::future::select_all(sockets),
534 )
535 .await
536 else {
537 warn!("Timeout waiting for response from peer");
538 if !inner && p2p.nat == p2p.peer_nat {
539 info!("Trying to punch peer from other side");
540 if puncher {
541 tokio::time::sleep(Duration::from_millis(1000)).await;
542 }
543 return udp_punched_socket(p2p, handshake_key, !left, true).await;
544 }
545 #[cfg(unix)]
546 no_file_limit.and_then(|n| rlimit::increase_nofile_limit(n).ok());
547 return Err(NetworkError::P2PTimeout);
548 };
549 let socket = match socket {
550 Ok(socket) => socket,
551 Err(e) => {
552 warn!("Error reading from socket: {}", e);
553 sockets = remaining_sockets;
554 continue;
555 }
556 };
557
558 let mut buf = vec![0u8; 3];
559 let peer = match socket.recv_from(&mut buf).await {
560 Ok((_, peer)) => peer,
561 Err(e) => {
562 warn!("Error receiving from socket: {}", e);
563 sockets = remaining_sockets;
564 continue;
565 }
566 };
567
568 if puncher && handshake_key[0..3] == buf[0..3] {
569 if let Ok(local_addr) = socket.local_addr() {
570 debug!(
571 "Confirming p2p channel peer {}:{} -> {}:{}",
572 local_addr.ip(),
573 local_addr.port(),
574 peer.ip(),
575 peer.port()
576 );
577 }
578 if let Err(e) = socket.send_to(&handshake_key[3..6], peer).await {
579 warn!("Error sending to peer: {}", e);
580 sockets = remaining_sockets;
581 continue;
582 }
583 } else if handshake_key[3..6] == buf[0..3] {
584 } else {
585 warn!("Invalid response from peer");
586 sockets = remaining_sockets;
587 continue;
588 };
589 #[cfg(unix)]
590 no_file_limit.and_then(|n| rlimit::increase_nofile_limit(n).ok());
591 return Ok((socket, peer));
592 }
593}