1use std::collections::HashMap;
40use std::net::SocketAddr;
41use std::sync::{Arc, RwLock};
42use std::time::Duration;
43
44use bytes::Bytes;
45use futures_util::StreamExt;
46use tokio::sync::broadcast;
47use tracing::{debug, warn};
48
49use crate::high_level::{
50 Connection as HighLevelConnection, RecvStream as HighLevelRecvStream,
51 SendStream as HighLevelSendStream,
52};
53use crate::link_transport::{
54 BoxFuture, BoxStream, Capabilities, ConnectionStats, DisconnectReason, Incoming, LinkConn,
55 LinkError, LinkEvent, LinkRecvStream, LinkResult, LinkSendStream, LinkTransport, ProtocolId,
56};
57use crate::nat_traversal_api::PeerId;
58use crate::p2p_endpoint::{P2pEndpoint, P2pEvent};
59use crate::unified_config::P2pConfig;
60
61pub struct P2pLinkConn {
67 inner: HighLevelConnection,
69 peer_id: PeerId,
71 remote_addr: SocketAddr,
73 connected_at: std::time::Instant,
75}
76
77impl P2pLinkConn {
78 pub fn new(inner: HighLevelConnection, peer_id: PeerId, remote_addr: SocketAddr) -> Self {
80 Self {
81 inner,
82 peer_id,
83 remote_addr,
84 connected_at: std::time::Instant::now(),
85 }
86 }
87
88 pub fn inner(&self) -> &HighLevelConnection {
90 &self.inner
91 }
92}
93
94impl LinkConn for P2pLinkConn {
95 fn peer(&self) -> PeerId {
96 self.peer_id
97 }
98
99 fn remote_addr(&self) -> SocketAddr {
100 self.remote_addr
101 }
102
103 fn open_uni(&self) -> BoxFuture<'_, LinkResult<Box<dyn LinkSendStream>>> {
104 Box::pin(async move {
105 let stream = self
106 .inner
107 .open_uni()
108 .await
109 .map_err(|e| LinkError::ConnectionFailed(e.to_string()))?;
110 Ok(Box::new(P2pSendStream::new(stream)) as Box<dyn LinkSendStream>)
111 })
112 }
113
114 fn open_bi(
115 &self,
116 ) -> BoxFuture<'_, LinkResult<(Box<dyn LinkSendStream>, Box<dyn LinkRecvStream>)>> {
117 Box::pin(async move {
118 let (send, recv) = self
119 .inner
120 .open_bi()
121 .await
122 .map_err(|e| LinkError::ConnectionFailed(e.to_string()))?;
123 Ok((
124 Box::new(P2pSendStream::new(send)) as Box<dyn LinkSendStream>,
125 Box::new(P2pRecvStream::new(recv)) as Box<dyn LinkRecvStream>,
126 ))
127 })
128 }
129
130 fn send_datagram(&self, data: Bytes) -> LinkResult<()> {
131 self.inner
132 .send_datagram(data)
133 .map_err(|e| LinkError::Io(e.to_string()))
134 }
135
136 fn recv_datagrams(&self) -> BoxStream<'_, Bytes> {
137 let conn = self.inner.clone();
139 Box::pin(futures_util::stream::unfold(conn, |conn| async move {
140 match conn.read_datagram().await {
141 Ok(data) => Some((data, conn)),
142 Err(_) => None,
143 }
144 }))
145 }
146
147 fn close(&self, error_code: u64, reason: &str) {
148 self.inner.close(
149 crate::VarInt::from_u64(error_code).unwrap_or(crate::VarInt::MAX),
150 reason.as_bytes(),
151 );
152 }
153
154 fn is_open(&self) -> bool {
155 self.inner.close_reason().is_none()
157 }
158
159 fn stats(&self) -> ConnectionStats {
160 let quic_stats = self.inner.stats();
161 ConnectionStats {
162 bytes_sent: quic_stats.udp_tx.bytes,
163 bytes_received: quic_stats.udp_rx.bytes,
164 rtt: quic_stats.path.rtt,
165 connected_duration: self.connected_at.elapsed(),
166 streams_opened: 0, packets_lost: quic_stats.path.lost_packets,
168 }
169 }
170}
171
172pub struct P2pSendStream {
178 inner: HighLevelSendStream,
179}
180
181impl P2pSendStream {
182 pub fn new(inner: HighLevelSendStream) -> Self {
184 Self { inner }
185 }
186}
187
188impl LinkSendStream for P2pSendStream {
189 fn write<'a>(&'a mut self, data: &'a [u8]) -> BoxFuture<'a, LinkResult<usize>> {
190 Box::pin(async move {
191 self.inner
192 .write(data)
193 .await
194 .map_err(|e| LinkError::Io(e.to_string()))
195 })
196 }
197
198 fn write_all<'a>(&'a mut self, data: &'a [u8]) -> BoxFuture<'a, LinkResult<()>> {
199 Box::pin(async move {
200 self.inner
201 .write_all(data)
202 .await
203 .map_err(|e| LinkError::Io(e.to_string()))
204 })
205 }
206
207 fn finish(&mut self) -> LinkResult<()> {
208 self.inner
209 .finish()
210 .map_err(|_| LinkError::ConnectionClosed)
211 }
212
213 fn reset(&mut self, error_code: u64) -> LinkResult<()> {
214 let code = crate::VarInt::from_u64(error_code).unwrap_or(crate::VarInt::MAX);
215 self.inner
216 .reset(code)
217 .map_err(|_| LinkError::ConnectionClosed)
218 }
219
220 fn id(&self) -> u64 {
221 self.inner.id().into()
222 }
223}
224
225pub struct P2pRecvStream {
231 inner: HighLevelRecvStream,
232}
233
234impl P2pRecvStream {
235 pub fn new(inner: HighLevelRecvStream) -> Self {
237 Self { inner }
238 }
239}
240
241impl LinkRecvStream for P2pRecvStream {
242 fn read<'a>(&'a mut self, buf: &'a mut [u8]) -> BoxFuture<'a, LinkResult<Option<usize>>> {
243 Box::pin(async move {
244 self.inner
245 .read(buf)
246 .await
247 .map_err(|e| LinkError::Io(e.to_string()))
248 })
249 }
250
251 fn read_to_end(&mut self, size_limit: usize) -> BoxFuture<'_, LinkResult<Vec<u8>>> {
252 Box::pin(async move {
253 self.inner
254 .read_to_end(size_limit)
255 .await
256 .map_err(|e| LinkError::Io(e.to_string()))
257 })
258 }
259
260 fn stop(&mut self, error_code: u64) -> LinkResult<()> {
261 let code = crate::VarInt::from_u64(error_code).unwrap_or(crate::VarInt::MAX);
262 self.inner
263 .stop(code)
264 .map_err(|_| LinkError::ConnectionClosed)
265 }
266
267 fn id(&self) -> u64 {
268 self.inner.id().into()
269 }
270}
271
272struct LinkTransportState {
278 protocols: Vec<ProtocolId>,
280 capabilities: HashMap<PeerId, Capabilities>,
282 event_tx: broadcast::Sender<LinkEvent>,
284}
285
286impl Default for LinkTransportState {
287 fn default() -> Self {
288 let (event_tx, _) = broadcast::channel(256);
289 Self {
290 protocols: vec![ProtocolId::DEFAULT],
291 capabilities: HashMap::new(),
292 event_tx,
293 }
294 }
295}
296
297pub struct P2pLinkTransport {
302 endpoint: Arc<P2pEndpoint>,
304 state: Arc<RwLock<LinkTransportState>>,
306}
307
308impl P2pLinkTransport {
309 pub async fn new(config: P2pConfig) -> Result<Self, crate::p2p_endpoint::EndpointError> {
311 let endpoint = Arc::new(P2pEndpoint::new(config).await?);
312 let state = Arc::new(RwLock::new(LinkTransportState::default()));
313
314 let endpoint_clone = endpoint.clone();
316 let state_clone = state.clone();
317 tokio::spawn(async move {
318 Self::event_forwarder(endpoint_clone, state_clone).await;
319 });
320
321 Ok(Self { endpoint, state })
322 }
323
324 pub fn from_endpoint(endpoint: Arc<P2pEndpoint>) -> Self {
326 let state = Arc::new(RwLock::new(LinkTransportState::default()));
327
328 let endpoint_clone = endpoint.clone();
330 let state_clone = state.clone();
331 tokio::spawn(async move {
332 Self::event_forwarder(endpoint_clone, state_clone).await;
333 });
334
335 Self { endpoint, state }
336 }
337
338 async fn event_forwarder(
340 endpoint: Arc<P2pEndpoint>,
341 state: Arc<RwLock<LinkTransportState>>,
342 ) {
343 let mut rx = endpoint.subscribe();
344 loop {
345 match rx.recv().await {
346 Ok(event) => {
347 let link_event = match event {
348 P2pEvent::PeerConnected { peer_id, addr } => {
349 let caps = Capabilities::new_connected(addr);
350 if let Ok(mut state) = state.write() {
352 state.capabilities.insert(peer_id, caps.clone());
353 }
354 Some(LinkEvent::PeerConnected { peer: peer_id, caps })
355 }
356 P2pEvent::PeerDisconnected { peer_id, reason } => {
357 let disconnect_reason = match reason {
358 crate::p2p_endpoint::DisconnectReason::Normal => {
359 DisconnectReason::LocalClose
360 }
361 crate::p2p_endpoint::DisconnectReason::RemoteClosed => {
362 DisconnectReason::RemoteClose
363 }
364 crate::p2p_endpoint::DisconnectReason::Timeout => {
365 DisconnectReason::Timeout
366 }
367 crate::p2p_endpoint::DisconnectReason::ProtocolError(msg) => {
368 DisconnectReason::TransportError(msg)
369 }
370 crate::p2p_endpoint::DisconnectReason::AuthenticationFailed => {
371 DisconnectReason::TransportError(
372 "Authentication failed".to_string(),
373 )
374 }
375 crate::p2p_endpoint::DisconnectReason::ConnectionLost => {
376 DisconnectReason::Reset
377 }
378 };
379 if let Ok(mut state) = state.write() {
381 if let Some(caps) = state.capabilities.get_mut(&peer_id) {
382 caps.is_connected = false;
383 }
384 }
385 Some(LinkEvent::PeerDisconnected {
386 peer: peer_id,
387 reason: disconnect_reason,
388 })
389 }
390 P2pEvent::ExternalAddressDiscovered { addr } => {
391 Some(LinkEvent::ExternalAddressUpdated { addr })
392 }
393 _ => None,
394 };
395
396 if let Some(event) = link_event {
397 if let Ok(state) = state.read() {
398 let _ = state.event_tx.send(event);
399 }
400 }
401 }
402 Err(broadcast::error::RecvError::Lagged(n)) => {
403 warn!("Event forwarder lagged by {} events", n);
404 }
405 Err(broadcast::error::RecvError::Closed) => {
406 debug!("Event forwarder channel closed");
407 break;
408 }
409 }
410 }
411 }
412
413 pub fn endpoint(&self) -> &P2pEndpoint {
415 &self.endpoint
416 }
417}
418
419impl LinkTransport for P2pLinkTransport {
420 type Conn = P2pLinkConn;
421
422 fn local_peer(&self) -> PeerId {
423 self.endpoint.peer_id()
424 }
425
426 fn external_address(&self) -> Option<SocketAddr> {
427 self.endpoint.external_addr()
428 }
429
430 fn peer_table(&self) -> Vec<(PeerId, Capabilities)> {
431 self.state
432 .read()
433 .map(|state| {
434 state
435 .capabilities
436 .iter()
437 .map(|(k, v)| (*k, v.clone()))
438 .collect()
439 })
440 .unwrap_or_default()
441 }
442
443 fn peer_capabilities(&self, peer: &PeerId) -> Option<Capabilities> {
444 self.state
445 .read()
446 .ok()
447 .and_then(|state| state.capabilities.get(peer).cloned())
448 }
449
450 fn subscribe(&self) -> broadcast::Receiver<LinkEvent> {
451 self.state
452 .read()
453 .map(|state| state.event_tx.subscribe())
454 .unwrap_or_else(|_| {
455 let (tx, rx) = broadcast::channel(1);
456 drop(tx);
457 rx
458 })
459 }
460
461 fn accept(&self, _proto: ProtocolId) -> Incoming<Self::Conn> {
462 let endpoint = self.endpoint.clone();
465
466 Box::pin(futures_util::stream::unfold(endpoint, |endpoint| async move {
467 if let Some(peer_conn) = endpoint.accept().await {
469 if let Some(conn) = endpoint
471 .get_quic_connection(&peer_conn.peer_id)
472 .ok()
473 .flatten()
474 {
475 let link_conn =
476 P2pLinkConn::new(conn, peer_conn.peer_id, peer_conn.remote_addr);
477 Some((Ok(link_conn), endpoint))
478 } else {
479 Some((
481 Err(LinkError::ConnectionFailed(
482 "Connection not found".to_string(),
483 )),
484 endpoint,
485 ))
486 }
487 } else {
488 None
490 }
491 }))
492 }
493
494 fn dial(&self, peer: PeerId, _proto: ProtocolId) -> BoxFuture<'_, LinkResult<Self::Conn>> {
495 Box::pin(async move {
496 let addr = self
498 .state
499 .read()
500 .ok()
501 .and_then(|state| {
502 state
503 .capabilities
504 .get(&peer)
505 .and_then(|caps| caps.observed_addrs.first().copied())
506 });
507
508 match addr {
509 Some(addr) => {
510 let peer_conn = self
512 .endpoint
513 .connect(addr)
514 .await
515 .map_err(|e| LinkError::ConnectionFailed(e.to_string()))?;
516
517 let conn = self
519 .endpoint
520 .get_quic_connection(&peer_conn.peer_id)
521 .map_err(|e| LinkError::ConnectionFailed(e.to_string()))?
522 .ok_or_else(|| {
523 LinkError::ConnectionFailed("Connection not found".to_string())
524 })?;
525
526 Ok(P2pLinkConn::new(conn, peer_conn.peer_id, addr))
527 }
528 None => Err(LinkError::PeerNotFound(format!("{:?}", peer))),
529 }
530 })
531 }
532
533 fn dial_addr(
534 &self,
535 addr: SocketAddr,
536 _proto: ProtocolId,
537 ) -> BoxFuture<'_, LinkResult<Self::Conn>> {
538 Box::pin(async move {
539 let peer_conn = self
541 .endpoint
542 .connect(addr)
543 .await
544 .map_err(|e| LinkError::ConnectionFailed(e.to_string()))?;
545
546 let conn = self
548 .endpoint
549 .get_quic_connection(&peer_conn.peer_id)
550 .map_err(|e| LinkError::ConnectionFailed(e.to_string()))?
551 .ok_or_else(|| LinkError::ConnectionFailed("Connection not found".to_string()))?;
552
553 Ok(P2pLinkConn::new(conn, peer_conn.peer_id, addr))
554 })
555 }
556
557 fn supported_protocols(&self) -> Vec<ProtocolId> {
558 self.state
559 .read()
560 .map(|state| state.protocols.clone())
561 .unwrap_or_default()
562 }
563
564 fn register_protocol(&self, proto: ProtocolId) {
565 if let Ok(mut state) = self.state.write() {
566 if !state.protocols.contains(&proto) {
567 state.protocols.push(proto);
568 }
569 }
570 }
571
572 fn unregister_protocol(&self, proto: ProtocolId) {
573 if let Ok(mut state) = self.state.write() {
574 state.protocols.retain(|p| p != &proto);
575 }
576 }
577
578 fn is_connected(&self, peer: &PeerId) -> bool {
579 self.state
580 .read()
581 .ok()
582 .and_then(|state| state.capabilities.get(peer).map(|caps| caps.is_connected))
583 .unwrap_or(false)
584 }
585
586 fn active_connections(&self) -> usize {
587 self.state
588 .read()
589 .map(|state| {
590 state
591 .capabilities
592 .values()
593 .filter(|caps| caps.is_connected)
594 .count()
595 })
596 .unwrap_or(0)
597 }
598
599 fn shutdown(&self) -> BoxFuture<'_, ()> {
600 Box::pin(async move {
601 self.endpoint.shutdown().await;
602 })
603 }
604}
605
#[cfg(test)]
mod tests {
    use super::*;

    /// The well-known protocol ids render to their canonical names.
    #[test]
    fn test_protocol_id_constants() {
        let cases = [
            (ProtocolId::DEFAULT, "ant-quic/default"),
            (ProtocolId::NAT_TRAVERSAL, "ant-quic/nat"),
            (ProtocolId::RELAY, "ant-quic/relay"),
        ];
        for (proto, expected) in cases {
            assert_eq!(proto.to_string(), expected);
        }
    }

    /// `new_connected` marks the peer connected with exactly the given address.
    #[test]
    fn test_capabilities_connected() {
        let addr = "127.0.0.1:9000".parse::<SocketAddr>().expect("valid addr");
        let caps = Capabilities::new_connected(addr);

        assert!(caps.is_connected);
        assert_eq!(caps.observed_addrs.len(), 1);
        assert_eq!(caps.observed_addrs[0], addr);
    }

    /// Default stats start with zero traffic in both directions.
    #[test]
    fn test_connection_stats_default() {
        let ConnectionStats {
            bytes_sent,
            bytes_received,
            ..
        } = ConnectionStats::default();
        assert_eq!(bytes_sent, 0);
        assert_eq!(bytes_received, 0);
    }

    /// Fresh state advertises only the default protocol and knows no peers.
    #[test]
    fn test_link_transport_state_default() {
        let state = LinkTransportState::default();
        assert_eq!(state.protocols, [ProtocolId::DEFAULT]);
        assert!(state.capabilities.is_empty());
    }
}