ant_quic/
link_transport_impl.rs

1// Copyright 2024 Saorsa Labs Ltd.
2//
3// This Saorsa Network Software is licensed under the General Public License (GPL), version 3.
4// Please see the file LICENSE-GPL, or visit <http://www.gnu.org/licenses/> for the full text.
5//
6// Full details available at https://saorsalabs.com/licenses
7
8//! # P2pEndpoint LinkTransport Implementation
9//!
10//! This module provides the concrete implementation of [`LinkTransport`] and [`LinkConn`]
11//! for [`P2pEndpoint`], bridging the high-level P2P API with the transport abstraction layer.
12//!
13//! ## Usage
14//!
15//! ```rust,ignore
16//! use ant_quic::{P2pConfig, P2pLinkTransport};
17//! use ant_quic::link_transport::{LinkTransport, ProtocolId};
18//!
19//! #[tokio::main]
20//! async fn main() -> anyhow::Result<()> {
21//!     let config = P2pConfig::builder()
22//!         .bind_addr("0.0.0.0:0".parse()?)
23//!         .build()?;
24//!
25//!     let transport = P2pLinkTransport::new(config).await?;
26//!
27//!     // Use as LinkTransport
28//!     let local_peer = transport.local_peer();
29//!     let peers = transport.peer_table();
30//!
31//!     // Dial with protocol
32//!     let proto = ProtocolId::from("my-app/1.0");
33//!     let conn = transport.dial_addr("127.0.0.1:9000".parse()?, proto).await?;
34//!
35//!     Ok(())
36//! }
37//! ```
38
39use std::collections::HashMap;
40use std::net::SocketAddr;
41use std::sync::{Arc, RwLock};
42use std::time::Duration;
43
44use bytes::Bytes;
45use futures_util::StreamExt;
46use tokio::sync::broadcast;
47use tracing::{debug, warn};
48
49use crate::high_level::{
50    Connection as HighLevelConnection, RecvStream as HighLevelRecvStream,
51    SendStream as HighLevelSendStream,
52};
53use crate::link_transport::{
54    BoxFuture, BoxStream, Capabilities, ConnectionStats, DisconnectReason, Incoming, LinkConn,
55    LinkError, LinkEvent, LinkRecvStream, LinkResult, LinkSendStream, LinkTransport, ProtocolId,
56};
57use crate::nat_traversal_api::PeerId;
58use crate::p2p_endpoint::{P2pEndpoint, P2pEvent};
59use crate::unified_config::P2pConfig;
60
61// ============================================================================
62// P2pLinkConn - Connection wrapper
63// ============================================================================
64
65/// A [`LinkConn`] implementation wrapping a high-level QUIC connection.
66pub struct P2pLinkConn {
67    /// The underlying QUIC connection.
68    inner: HighLevelConnection,
69    /// Remote peer ID.
70    peer_id: PeerId,
71    /// Remote address.
72    remote_addr: SocketAddr,
73    /// Connection start time.
74    connected_at: std::time::Instant,
75}
76
77impl P2pLinkConn {
78    /// Create a new connection wrapper.
79    pub fn new(inner: HighLevelConnection, peer_id: PeerId, remote_addr: SocketAddr) -> Self {
80        Self {
81            inner,
82            peer_id,
83            remote_addr,
84            connected_at: std::time::Instant::now(),
85        }
86    }
87
88    /// Get the underlying connection.
89    pub fn inner(&self) -> &HighLevelConnection {
90        &self.inner
91    }
92}
93
94impl LinkConn for P2pLinkConn {
95    fn peer(&self) -> PeerId {
96        self.peer_id
97    }
98
99    fn remote_addr(&self) -> SocketAddr {
100        self.remote_addr
101    }
102
103    fn open_uni(&self) -> BoxFuture<'_, LinkResult<Box<dyn LinkSendStream>>> {
104        Box::pin(async move {
105            let stream = self
106                .inner
107                .open_uni()
108                .await
109                .map_err(|e| LinkError::ConnectionFailed(e.to_string()))?;
110            Ok(Box::new(P2pSendStream::new(stream)) as Box<dyn LinkSendStream>)
111        })
112    }
113
114    fn open_bi(
115        &self,
116    ) -> BoxFuture<'_, LinkResult<(Box<dyn LinkSendStream>, Box<dyn LinkRecvStream>)>> {
117        Box::pin(async move {
118            let (send, recv) = self
119                .inner
120                .open_bi()
121                .await
122                .map_err(|e| LinkError::ConnectionFailed(e.to_string()))?;
123            Ok((
124                Box::new(P2pSendStream::new(send)) as Box<dyn LinkSendStream>,
125                Box::new(P2pRecvStream::new(recv)) as Box<dyn LinkRecvStream>,
126            ))
127        })
128    }
129
130    fn send_datagram(&self, data: Bytes) -> LinkResult<()> {
131        self.inner
132            .send_datagram(data)
133            .map_err(|e| LinkError::Io(e.to_string()))
134    }
135
136    fn recv_datagrams(&self) -> BoxStream<'_, Bytes> {
137        // Create a stream that polls for datagrams
138        let conn = self.inner.clone();
139        Box::pin(futures_util::stream::unfold(conn, |conn| async move {
140            match conn.read_datagram().await {
141                Ok(data) => Some((data, conn)),
142                Err(_) => None,
143            }
144        }))
145    }
146
147    fn close(&self, error_code: u64, reason: &str) {
148        self.inner.close(
149            crate::VarInt::from_u64(error_code).unwrap_or(crate::VarInt::MAX),
150            reason.as_bytes(),
151        );
152    }
153
154    fn is_open(&self) -> bool {
155        // Check if connection is still alive by examining the close reason
156        self.inner.close_reason().is_none()
157    }
158
159    fn stats(&self) -> ConnectionStats {
160        let quic_stats = self.inner.stats();
161        ConnectionStats {
162            bytes_sent: quic_stats.udp_tx.bytes,
163            bytes_received: quic_stats.udp_rx.bytes,
164            rtt: quic_stats.path.rtt,
165            connected_duration: self.connected_at.elapsed(),
166            streams_opened: 0, // Would need to track this separately
167            packets_lost: quic_stats.path.lost_packets,
168        }
169    }
170}
171
172// ============================================================================
173// P2pSendStream - Send stream wrapper
174// ============================================================================
175
176/// A [`LinkSendStream`] implementation wrapping a high-level send stream.
177pub struct P2pSendStream {
178    inner: HighLevelSendStream,
179}
180
181impl P2pSendStream {
182    /// Create a new send stream wrapper.
183    pub fn new(inner: HighLevelSendStream) -> Self {
184        Self { inner }
185    }
186}
187
188impl LinkSendStream for P2pSendStream {
189    fn write<'a>(&'a mut self, data: &'a [u8]) -> BoxFuture<'a, LinkResult<usize>> {
190        Box::pin(async move {
191            self.inner
192                .write(data)
193                .await
194                .map_err(|e| LinkError::Io(e.to_string()))
195        })
196    }
197
198    fn write_all<'a>(&'a mut self, data: &'a [u8]) -> BoxFuture<'a, LinkResult<()>> {
199        Box::pin(async move {
200            self.inner
201                .write_all(data)
202                .await
203                .map_err(|e| LinkError::Io(e.to_string()))
204        })
205    }
206
207    fn finish(&mut self) -> LinkResult<()> {
208        self.inner
209            .finish()
210            .map_err(|_| LinkError::ConnectionClosed)
211    }
212
213    fn reset(&mut self, error_code: u64) -> LinkResult<()> {
214        let code = crate::VarInt::from_u64(error_code).unwrap_or(crate::VarInt::MAX);
215        self.inner
216            .reset(code)
217            .map_err(|_| LinkError::ConnectionClosed)
218    }
219
220    fn id(&self) -> u64 {
221        self.inner.id().into()
222    }
223}
224
225// ============================================================================
226// P2pRecvStream - Receive stream wrapper
227// ============================================================================
228
229/// A [`LinkRecvStream`] implementation wrapping a high-level receive stream.
230pub struct P2pRecvStream {
231    inner: HighLevelRecvStream,
232}
233
234impl P2pRecvStream {
235    /// Create a new receive stream wrapper.
236    pub fn new(inner: HighLevelRecvStream) -> Self {
237        Self { inner }
238    }
239}
240
241impl LinkRecvStream for P2pRecvStream {
242    fn read<'a>(&'a mut self, buf: &'a mut [u8]) -> BoxFuture<'a, LinkResult<Option<usize>>> {
243        Box::pin(async move {
244            self.inner
245                .read(buf)
246                .await
247                .map_err(|e| LinkError::Io(e.to_string()))
248        })
249    }
250
251    fn read_to_end(&mut self, size_limit: usize) -> BoxFuture<'_, LinkResult<Vec<u8>>> {
252        Box::pin(async move {
253            self.inner
254                .read_to_end(size_limit)
255                .await
256                .map_err(|e| LinkError::Io(e.to_string()))
257        })
258    }
259
260    fn stop(&mut self, error_code: u64) -> LinkResult<()> {
261        let code = crate::VarInt::from_u64(error_code).unwrap_or(crate::VarInt::MAX);
262        self.inner
263            .stop(code)
264            .map_err(|_| LinkError::ConnectionClosed)
265    }
266
267    fn id(&self) -> u64 {
268        self.inner.id().into()
269    }
270}
271
272// ============================================================================
273// P2pLinkTransport - LinkTransport Implementation
274// ============================================================================
275
276/// Internal state for the LinkTransport implementation.
277struct LinkTransportState {
278    /// Registered protocols.
279    protocols: Vec<ProtocolId>,
280    /// Peer capabilities cache.
281    capabilities: HashMap<PeerId, Capabilities>,
282    /// Event broadcaster for LinkEvents.
283    event_tx: broadcast::Sender<LinkEvent>,
284}
285
286impl Default for LinkTransportState {
287    fn default() -> Self {
288        let (event_tx, _) = broadcast::channel(256);
289        Self {
290            protocols: vec![ProtocolId::DEFAULT],
291            capabilities: HashMap::new(),
292            event_tx,
293        }
294    }
295}
296
297/// A [`LinkTransport`] implementation wrapping [`P2pEndpoint`].
298///
299/// This provides a stable abstraction layer for overlay networks to use,
300/// decoupling them from specific ant-quic versions.
301pub struct P2pLinkTransport {
302    /// The underlying P2pEndpoint.
303    endpoint: Arc<P2pEndpoint>,
304    /// Additional state for LinkTransport.
305    state: Arc<RwLock<LinkTransportState>>,
306}
307
308impl P2pLinkTransport {
309    /// Create a new LinkTransport from a P2pConfig.
310    pub async fn new(config: P2pConfig) -> Result<Self, crate::p2p_endpoint::EndpointError> {
311        let endpoint = Arc::new(P2pEndpoint::new(config).await?);
312        let state = Arc::new(RwLock::new(LinkTransportState::default()));
313
314        // Spawn event forwarder
315        let endpoint_clone = endpoint.clone();
316        let state_clone = state.clone();
317        tokio::spawn(async move {
318            Self::event_forwarder(endpoint_clone, state_clone).await;
319        });
320
321        Ok(Self { endpoint, state })
322    }
323
324    /// Create from an existing P2pEndpoint.
325    pub fn from_endpoint(endpoint: Arc<P2pEndpoint>) -> Self {
326        let state = Arc::new(RwLock::new(LinkTransportState::default()));
327
328        // Spawn event forwarder
329        let endpoint_clone = endpoint.clone();
330        let state_clone = state.clone();
331        tokio::spawn(async move {
332            Self::event_forwarder(endpoint_clone, state_clone).await;
333        });
334
335        Self { endpoint, state }
336    }
337
338    /// Forward P2pEvents to LinkEvents.
339    async fn event_forwarder(
340        endpoint: Arc<P2pEndpoint>,
341        state: Arc<RwLock<LinkTransportState>>,
342    ) {
343        let mut rx = endpoint.subscribe();
344        loop {
345            match rx.recv().await {
346                Ok(event) => {
347                    let link_event = match event {
348                        P2pEvent::PeerConnected { peer_id, addr } => {
349                            let caps = Capabilities::new_connected(addr);
350                            // Update capabilities cache
351                            if let Ok(mut state) = state.write() {
352                                state.capabilities.insert(peer_id, caps.clone());
353                            }
354                            Some(LinkEvent::PeerConnected { peer: peer_id, caps })
355                        }
356                        P2pEvent::PeerDisconnected { peer_id, reason } => {
357                            let disconnect_reason = match reason {
358                                crate::p2p_endpoint::DisconnectReason::Normal => {
359                                    DisconnectReason::LocalClose
360                                }
361                                crate::p2p_endpoint::DisconnectReason::RemoteClosed => {
362                                    DisconnectReason::RemoteClose
363                                }
364                                crate::p2p_endpoint::DisconnectReason::Timeout => {
365                                    DisconnectReason::Timeout
366                                }
367                                crate::p2p_endpoint::DisconnectReason::ProtocolError(msg) => {
368                                    DisconnectReason::TransportError(msg)
369                                }
370                                crate::p2p_endpoint::DisconnectReason::AuthenticationFailed => {
371                                    DisconnectReason::TransportError(
372                                        "Authentication failed".to_string(),
373                                    )
374                                }
375                                crate::p2p_endpoint::DisconnectReason::ConnectionLost => {
376                                    DisconnectReason::Reset
377                                }
378                            };
379                            // Update capabilities cache
380                            if let Ok(mut state) = state.write() {
381                                if let Some(caps) = state.capabilities.get_mut(&peer_id) {
382                                    caps.is_connected = false;
383                                }
384                            }
385                            Some(LinkEvent::PeerDisconnected {
386                                peer: peer_id,
387                                reason: disconnect_reason,
388                            })
389                        }
390                        P2pEvent::ExternalAddressDiscovered { addr } => {
391                            Some(LinkEvent::ExternalAddressUpdated { addr })
392                        }
393                        _ => None,
394                    };
395
396                    if let Some(event) = link_event {
397                        if let Ok(state) = state.read() {
398                            let _ = state.event_tx.send(event);
399                        }
400                    }
401                }
402                Err(broadcast::error::RecvError::Lagged(n)) => {
403                    warn!("Event forwarder lagged by {} events", n);
404                }
405                Err(broadcast::error::RecvError::Closed) => {
406                    debug!("Event forwarder channel closed");
407                    break;
408                }
409            }
410        }
411    }
412
413    /// Get the underlying P2pEndpoint.
414    pub fn endpoint(&self) -> &P2pEndpoint {
415        &self.endpoint
416    }
417}
418
419impl LinkTransport for P2pLinkTransport {
420    type Conn = P2pLinkConn;
421
422    fn local_peer(&self) -> PeerId {
423        self.endpoint.peer_id()
424    }
425
426    fn external_address(&self) -> Option<SocketAddr> {
427        self.endpoint.external_addr()
428    }
429
430    fn peer_table(&self) -> Vec<(PeerId, Capabilities)> {
431        self.state
432            .read()
433            .map(|state| {
434                state
435                    .capabilities
436                    .iter()
437                    .map(|(k, v)| (*k, v.clone()))
438                    .collect()
439            })
440            .unwrap_or_default()
441    }
442
443    fn peer_capabilities(&self, peer: &PeerId) -> Option<Capabilities> {
444        self.state
445            .read()
446            .ok()
447            .and_then(|state| state.capabilities.get(peer).cloned())
448    }
449
450    fn subscribe(&self) -> broadcast::Receiver<LinkEvent> {
451        self.state
452            .read()
453            .map(|state| state.event_tx.subscribe())
454            .unwrap_or_else(|_| {
455                let (tx, rx) = broadcast::channel(1);
456                drop(tx);
457                rx
458            })
459    }
460
461    fn accept(&self, _proto: ProtocolId) -> Incoming<Self::Conn> {
462        // TODO: Implement protocol-based accept filtering
463        // For now, accept all incoming connections
464        let endpoint = self.endpoint.clone();
465
466        Box::pin(futures_util::stream::unfold(endpoint, |endpoint| async move {
467            // Wait for an incoming connection
468            if let Some(peer_conn) = endpoint.accept().await {
469                // Get the underlying QUIC connection
470                if let Some(conn) = endpoint
471                    .get_quic_connection(&peer_conn.peer_id)
472                    .ok()
473                    .flatten()
474                {
475                    let link_conn =
476                        P2pLinkConn::new(conn, peer_conn.peer_id, peer_conn.remote_addr);
477                    Some((Ok(link_conn), endpoint))
478                } else {
479                    // Connection not found, try again
480                    Some((
481                        Err(LinkError::ConnectionFailed(
482                            "Connection not found".to_string(),
483                        )),
484                        endpoint,
485                    ))
486                }
487            } else {
488                // Endpoint is shutting down
489                None
490            }
491        }))
492    }
493
494    fn dial(&self, peer: PeerId, _proto: ProtocolId) -> BoxFuture<'_, LinkResult<Self::Conn>> {
495        Box::pin(async move {
496            // Look up peer address from capabilities
497            let addr = self
498                .state
499                .read()
500                .ok()
501                .and_then(|state| {
502                    state
503                        .capabilities
504                        .get(&peer)
505                        .and_then(|caps| caps.observed_addrs.first().copied())
506                });
507
508            match addr {
509                Some(addr) => {
510                    // Connect through P2pEndpoint
511                    let peer_conn = self
512                        .endpoint
513                        .connect(addr)
514                        .await
515                        .map_err(|e| LinkError::ConnectionFailed(e.to_string()))?;
516
517                    // Get the underlying QUIC connection
518                    let conn = self
519                        .endpoint
520                        .get_quic_connection(&peer_conn.peer_id)
521                        .map_err(|e| LinkError::ConnectionFailed(e.to_string()))?
522                        .ok_or_else(|| {
523                            LinkError::ConnectionFailed("Connection not found".to_string())
524                        })?;
525
526                    Ok(P2pLinkConn::new(conn, peer_conn.peer_id, addr))
527                }
528                None => Err(LinkError::PeerNotFound(format!("{:?}", peer))),
529            }
530        })
531    }
532
533    fn dial_addr(
534        &self,
535        addr: SocketAddr,
536        _proto: ProtocolId,
537    ) -> BoxFuture<'_, LinkResult<Self::Conn>> {
538        Box::pin(async move {
539            // Connect through P2pEndpoint
540            let peer_conn = self
541                .endpoint
542                .connect(addr)
543                .await
544                .map_err(|e| LinkError::ConnectionFailed(e.to_string()))?;
545
546            // Get the underlying QUIC connection
547            let conn = self
548                .endpoint
549                .get_quic_connection(&peer_conn.peer_id)
550                .map_err(|e| LinkError::ConnectionFailed(e.to_string()))?
551                .ok_or_else(|| LinkError::ConnectionFailed("Connection not found".to_string()))?;
552
553            Ok(P2pLinkConn::new(conn, peer_conn.peer_id, addr))
554        })
555    }
556
557    fn supported_protocols(&self) -> Vec<ProtocolId> {
558        self.state
559            .read()
560            .map(|state| state.protocols.clone())
561            .unwrap_or_default()
562    }
563
564    fn register_protocol(&self, proto: ProtocolId) {
565        if let Ok(mut state) = self.state.write() {
566            if !state.protocols.contains(&proto) {
567                state.protocols.push(proto);
568            }
569        }
570    }
571
572    fn unregister_protocol(&self, proto: ProtocolId) {
573        if let Ok(mut state) = self.state.write() {
574            state.protocols.retain(|p| p != &proto);
575        }
576    }
577
578    fn is_connected(&self, peer: &PeerId) -> bool {
579        self.state
580            .read()
581            .ok()
582            .and_then(|state| state.capabilities.get(peer).map(|caps| caps.is_connected))
583            .unwrap_or(false)
584    }
585
586    fn active_connections(&self) -> usize {
587        self.state
588            .read()
589            .map(|state| {
590                state
591                    .capabilities
592                    .values()
593                    .filter(|caps| caps.is_connected)
594                    .count()
595            })
596            .unwrap_or(0)
597    }
598
599    fn shutdown(&self) -> BoxFuture<'_, ()> {
600        Box::pin(async move {
601            self.endpoint.shutdown().await;
602        })
603    }
604}
605
606// ============================================================================
607// Tests
608// ============================================================================
609
#[cfg(test)]
mod tests {
    use super::*;

    /// The built-in protocol identifiers render to their documented names.
    #[test]
    fn test_protocol_id_constants() {
        let default_name = ProtocolId::DEFAULT.to_string();
        let nat_name = ProtocolId::NAT_TRAVERSAL.to_string();
        let relay_name = ProtocolId::RELAY.to_string();

        assert_eq!(default_name, "ant-quic/default");
        assert_eq!(nat_name, "ant-quic/nat");
        assert_eq!(relay_name, "ant-quic/relay");
    }

    /// `new_connected` marks the peer as connected with exactly the given address.
    #[test]
    fn test_capabilities_connected() {
        let addr: SocketAddr = "127.0.0.1:9000".parse().expect("valid addr");
        let caps = Capabilities::new_connected(addr);

        assert!(caps.is_connected);
        assert_eq!(caps.observed_addrs.len(), 1);
        assert_eq!(caps.observed_addrs.first(), Some(&addr));
    }

    /// Default statistics start with zeroed byte counters.
    #[test]
    fn test_connection_stats_default() {
        let stats = ConnectionStats::default();
        assert_eq!((stats.bytes_sent, stats.bytes_received), (0, 0));
    }

    /// Fresh transport state has only the default protocol and no known peers.
    #[test]
    fn test_link_transport_state_default() {
        let state = LinkTransportState::default();
        assert_eq!(state.protocols.len(), 1);
        assert_eq!(state.protocols.first(), Some(&ProtocolId::DEFAULT));
        assert!(state.capabilities.is_empty());
    }
}
645}