1use std::collections::HashMap;
40use std::net::SocketAddr;
41use std::sync::{Arc, RwLock};
42use std::time::Duration;
43
44use bytes::Bytes;
45use futures_util::StreamExt;
46use tokio::sync::broadcast;
47use tracing::{debug, warn};
48
49use crate::high_level::{
50 Connection as HighLevelConnection, RecvStream as HighLevelRecvStream,
51 SendStream as HighLevelSendStream,
52};
53use crate::link_transport::{
54 BoxFuture, BoxStream, Capabilities, ConnectionStats, DisconnectReason, Incoming, LinkConn,
55 LinkError, LinkEvent, LinkRecvStream, LinkResult, LinkSendStream, LinkTransport, ProtocolId,
56};
57use crate::nat_traversal_api::PeerId;
58use crate::p2p_endpoint::{P2pEndpoint, P2pEvent};
59use crate::unified_config::P2pConfig;
60
61pub struct P2pLinkConn {
67 inner: HighLevelConnection,
69 peer_id: PeerId,
71 remote_addr: SocketAddr,
73 connected_at: std::time::Instant,
75}
76
77impl P2pLinkConn {
78 pub fn new(inner: HighLevelConnection, peer_id: PeerId, remote_addr: SocketAddr) -> Self {
80 Self {
81 inner,
82 peer_id,
83 remote_addr,
84 connected_at: std::time::Instant::now(),
85 }
86 }
87
88 pub fn inner(&self) -> &HighLevelConnection {
90 &self.inner
91 }
92}
93
94impl LinkConn for P2pLinkConn {
95 fn peer(&self) -> PeerId {
96 self.peer_id
97 }
98
99 fn remote_addr(&self) -> SocketAddr {
100 self.remote_addr
101 }
102
103 fn open_uni(&self) -> BoxFuture<'_, LinkResult<Box<dyn LinkSendStream>>> {
104 Box::pin(async move {
105 let stream = self
106 .inner
107 .open_uni()
108 .await
109 .map_err(|e| LinkError::ConnectionFailed(e.to_string()))?;
110 Ok(Box::new(P2pSendStream::new(stream)) as Box<dyn LinkSendStream>)
111 })
112 }
113
114 fn open_bi(
115 &self,
116 ) -> BoxFuture<'_, LinkResult<(Box<dyn LinkSendStream>, Box<dyn LinkRecvStream>)>> {
117 Box::pin(async move {
118 let (send, recv) = self
119 .inner
120 .open_bi()
121 .await
122 .map_err(|e| LinkError::ConnectionFailed(e.to_string()))?;
123 Ok((
124 Box::new(P2pSendStream::new(send)) as Box<dyn LinkSendStream>,
125 Box::new(P2pRecvStream::new(recv)) as Box<dyn LinkRecvStream>,
126 ))
127 })
128 }
129
130 fn send_datagram(&self, data: Bytes) -> LinkResult<()> {
131 self.inner
132 .send_datagram(data)
133 .map_err(|e| LinkError::Io(e.to_string()))
134 }
135
136 fn recv_datagrams(&self) -> BoxStream<'_, Bytes> {
137 let conn = self.inner.clone();
139 Box::pin(futures_util::stream::unfold(conn, |conn| async move {
140 match conn.read_datagram().await {
141 Ok(data) => Some((data, conn)),
142 Err(_) => None,
143 }
144 }))
145 }
146
147 fn close(&self, error_code: u64, reason: &str) {
148 self.inner.close(
149 crate::VarInt::from_u64(error_code).unwrap_or(crate::VarInt::MAX),
150 reason.as_bytes(),
151 );
152 }
153
154 fn is_open(&self) -> bool {
155 self.inner.close_reason().is_none()
157 }
158
159 fn stats(&self) -> ConnectionStats {
160 let quic_stats = self.inner.stats();
161 ConnectionStats {
162 bytes_sent: quic_stats.udp_tx.bytes,
163 bytes_received: quic_stats.udp_rx.bytes,
164 rtt: quic_stats.path.rtt,
165 connected_duration: self.connected_at.elapsed(),
166 streams_opened: 0, packets_lost: quic_stats.path.lost_packets,
168 }
169 }
170}
171
172pub struct P2pSendStream {
178 inner: HighLevelSendStream,
179}
180
181impl P2pSendStream {
182 pub fn new(inner: HighLevelSendStream) -> Self {
184 Self { inner }
185 }
186}
187
188impl LinkSendStream for P2pSendStream {
189 fn write<'a>(&'a mut self, data: &'a [u8]) -> BoxFuture<'a, LinkResult<usize>> {
190 Box::pin(async move {
191 self.inner
192 .write(data)
193 .await
194 .map_err(|e| LinkError::Io(e.to_string()))
195 })
196 }
197
198 fn write_all<'a>(&'a mut self, data: &'a [u8]) -> BoxFuture<'a, LinkResult<()>> {
199 Box::pin(async move {
200 self.inner
201 .write_all(data)
202 .await
203 .map_err(|e| LinkError::Io(e.to_string()))
204 })
205 }
206
207 fn finish(&mut self) -> LinkResult<()> {
208 self.inner.finish().map_err(|_| LinkError::ConnectionClosed)
209 }
210
211 fn reset(&mut self, error_code: u64) -> LinkResult<()> {
212 let code = crate::VarInt::from_u64(error_code).unwrap_or(crate::VarInt::MAX);
213 self.inner
214 .reset(code)
215 .map_err(|_| LinkError::ConnectionClosed)
216 }
217
218 fn id(&self) -> u64 {
219 self.inner.id().into()
220 }
221}
222
223pub struct P2pRecvStream {
229 inner: HighLevelRecvStream,
230}
231
232impl P2pRecvStream {
233 pub fn new(inner: HighLevelRecvStream) -> Self {
235 Self { inner }
236 }
237}
238
239impl LinkRecvStream for P2pRecvStream {
240 fn read<'a>(&'a mut self, buf: &'a mut [u8]) -> BoxFuture<'a, LinkResult<Option<usize>>> {
241 Box::pin(async move {
242 self.inner
243 .read(buf)
244 .await
245 .map_err(|e| LinkError::Io(e.to_string()))
246 })
247 }
248
249 fn read_to_end(&mut self, size_limit: usize) -> BoxFuture<'_, LinkResult<Vec<u8>>> {
250 Box::pin(async move {
251 self.inner
252 .read_to_end(size_limit)
253 .await
254 .map_err(|e| LinkError::Io(e.to_string()))
255 })
256 }
257
258 fn stop(&mut self, error_code: u64) -> LinkResult<()> {
259 let code = crate::VarInt::from_u64(error_code).unwrap_or(crate::VarInt::MAX);
260 self.inner
261 .stop(code)
262 .map_err(|_| LinkError::ConnectionClosed)
263 }
264
265 fn id(&self) -> u64 {
266 self.inner.id().into()
267 }
268}
269
270struct LinkTransportState {
276 protocols: Vec<ProtocolId>,
278 capabilities: HashMap<PeerId, Capabilities>,
280 event_tx: broadcast::Sender<LinkEvent>,
282}
283
284impl Default for LinkTransportState {
285 fn default() -> Self {
286 let (event_tx, _) = broadcast::channel(256);
287 Self {
288 protocols: vec![ProtocolId::DEFAULT],
289 capabilities: HashMap::new(),
290 event_tx,
291 }
292 }
293}
294
295pub struct P2pLinkTransport {
300 endpoint: Arc<P2pEndpoint>,
302 state: Arc<RwLock<LinkTransportState>>,
304}
305
306impl P2pLinkTransport {
307 pub async fn new(config: P2pConfig) -> Result<Self, crate::p2p_endpoint::EndpointError> {
309 let endpoint = Arc::new(P2pEndpoint::new(config).await?);
310 let state = Arc::new(RwLock::new(LinkTransportState::default()));
311
312 let endpoint_clone = endpoint.clone();
314 let state_clone = state.clone();
315 tokio::spawn(async move {
316 Self::event_forwarder(endpoint_clone, state_clone).await;
317 });
318
319 Ok(Self { endpoint, state })
320 }
321
322 pub fn from_endpoint(endpoint: Arc<P2pEndpoint>) -> Self {
324 let state = Arc::new(RwLock::new(LinkTransportState::default()));
325
326 let endpoint_clone = endpoint.clone();
328 let state_clone = state.clone();
329 tokio::spawn(async move {
330 Self::event_forwarder(endpoint_clone, state_clone).await;
331 });
332
333 Self { endpoint, state }
334 }
335
336 async fn event_forwarder(endpoint: Arc<P2pEndpoint>, state: Arc<RwLock<LinkTransportState>>) {
338 let mut rx = endpoint.subscribe();
339 loop {
340 match rx.recv().await {
341 Ok(event) => {
342 let link_event = match event {
343 P2pEvent::PeerConnected { peer_id, addr } => {
344 let caps = Capabilities::new_connected(addr);
345 if let Ok(mut state) = state.write() {
347 state.capabilities.insert(peer_id, caps.clone());
348 }
349 Some(LinkEvent::PeerConnected {
350 peer: peer_id,
351 caps,
352 })
353 }
354 P2pEvent::PeerDisconnected { peer_id, reason } => {
355 let disconnect_reason = match reason {
356 crate::p2p_endpoint::DisconnectReason::Normal => {
357 DisconnectReason::LocalClose
358 }
359 crate::p2p_endpoint::DisconnectReason::RemoteClosed => {
360 DisconnectReason::RemoteClose
361 }
362 crate::p2p_endpoint::DisconnectReason::Timeout => {
363 DisconnectReason::Timeout
364 }
365 crate::p2p_endpoint::DisconnectReason::ProtocolError(msg) => {
366 DisconnectReason::TransportError(msg)
367 }
368 crate::p2p_endpoint::DisconnectReason::AuthenticationFailed => {
369 DisconnectReason::TransportError(
370 "Authentication failed".to_string(),
371 )
372 }
373 crate::p2p_endpoint::DisconnectReason::ConnectionLost => {
374 DisconnectReason::Reset
375 }
376 };
377 if let Ok(mut state) = state.write() {
379 if let Some(caps) = state.capabilities.get_mut(&peer_id) {
380 caps.is_connected = false;
381 }
382 }
383 Some(LinkEvent::PeerDisconnected {
384 peer: peer_id,
385 reason: disconnect_reason,
386 })
387 }
388 P2pEvent::ExternalAddressDiscovered { addr } => {
389 Some(LinkEvent::ExternalAddressUpdated { addr })
390 }
391 _ => None,
392 };
393
394 if let Some(event) = link_event {
395 if let Ok(state) = state.read() {
396 let _ = state.event_tx.send(event);
397 }
398 }
399 }
400 Err(broadcast::error::RecvError::Lagged(n)) => {
401 warn!("Event forwarder lagged by {} events", n);
402 }
403 Err(broadcast::error::RecvError::Closed) => {
404 debug!("Event forwarder channel closed");
405 break;
406 }
407 }
408 }
409 }
410
411 pub fn endpoint(&self) -> &P2pEndpoint {
413 &self.endpoint
414 }
415}
416
417impl LinkTransport for P2pLinkTransport {
418 type Conn = P2pLinkConn;
419
420 fn local_peer(&self) -> PeerId {
421 self.endpoint.peer_id()
422 }
423
424 fn external_address(&self) -> Option<SocketAddr> {
425 self.endpoint.external_addr()
426 }
427
428 fn peer_table(&self) -> Vec<(PeerId, Capabilities)> {
429 self.state
430 .read()
431 .map(|state| {
432 state
433 .capabilities
434 .iter()
435 .map(|(k, v)| (*k, v.clone()))
436 .collect()
437 })
438 .unwrap_or_default()
439 }
440
441 fn peer_capabilities(&self, peer: &PeerId) -> Option<Capabilities> {
442 self.state
443 .read()
444 .ok()
445 .and_then(|state| state.capabilities.get(peer).cloned())
446 }
447
448 fn subscribe(&self) -> broadcast::Receiver<LinkEvent> {
449 self.state
450 .read()
451 .map(|state| state.event_tx.subscribe())
452 .unwrap_or_else(|_| {
453 let (tx, rx) = broadcast::channel(1);
454 drop(tx);
455 rx
456 })
457 }
458
459 fn accept(&self, _proto: ProtocolId) -> Incoming<Self::Conn> {
460 let endpoint = self.endpoint.clone();
463
464 Box::pin(futures_util::stream::unfold(
465 endpoint,
466 |endpoint| async move {
467 if let Some(peer_conn) = endpoint.accept().await {
469 if let Some(conn) = endpoint
471 .get_quic_connection(&peer_conn.peer_id)
472 .ok()
473 .flatten()
474 {
475 let link_conn =
476 P2pLinkConn::new(conn, peer_conn.peer_id, peer_conn.remote_addr);
477 Some((Ok(link_conn), endpoint))
478 } else {
479 Some((
481 Err(LinkError::ConnectionFailed(
482 "Connection not found".to_string(),
483 )),
484 endpoint,
485 ))
486 }
487 } else {
488 None
490 }
491 },
492 ))
493 }
494
495 fn dial(&self, peer: PeerId, _proto: ProtocolId) -> BoxFuture<'_, LinkResult<Self::Conn>> {
496 Box::pin(async move {
497 let addr = self.state.read().ok().and_then(|state| {
499 state
500 .capabilities
501 .get(&peer)
502 .and_then(|caps| caps.observed_addrs.first().copied())
503 });
504
505 match addr {
506 Some(addr) => {
507 let peer_conn = self
509 .endpoint
510 .connect(addr)
511 .await
512 .map_err(|e| LinkError::ConnectionFailed(e.to_string()))?;
513
514 let conn = self
516 .endpoint
517 .get_quic_connection(&peer_conn.peer_id)
518 .map_err(|e| LinkError::ConnectionFailed(e.to_string()))?
519 .ok_or_else(|| {
520 LinkError::ConnectionFailed("Connection not found".to_string())
521 })?;
522
523 Ok(P2pLinkConn::new(conn, peer_conn.peer_id, addr))
524 }
525 None => Err(LinkError::PeerNotFound(format!("{:?}", peer))),
526 }
527 })
528 }
529
530 fn dial_addr(
531 &self,
532 addr: SocketAddr,
533 _proto: ProtocolId,
534 ) -> BoxFuture<'_, LinkResult<Self::Conn>> {
535 Box::pin(async move {
536 let peer_conn = self
538 .endpoint
539 .connect(addr)
540 .await
541 .map_err(|e| LinkError::ConnectionFailed(e.to_string()))?;
542
543 let conn = self
545 .endpoint
546 .get_quic_connection(&peer_conn.peer_id)
547 .map_err(|e| LinkError::ConnectionFailed(e.to_string()))?
548 .ok_or_else(|| LinkError::ConnectionFailed("Connection not found".to_string()))?;
549
550 Ok(P2pLinkConn::new(conn, peer_conn.peer_id, addr))
551 })
552 }
553
554 fn supported_protocols(&self) -> Vec<ProtocolId> {
555 self.state
556 .read()
557 .map(|state| state.protocols.clone())
558 .unwrap_or_default()
559 }
560
561 fn register_protocol(&self, proto: ProtocolId) {
562 if let Ok(mut state) = self.state.write() {
563 if !state.protocols.contains(&proto) {
564 state.protocols.push(proto);
565 }
566 }
567 }
568
569 fn unregister_protocol(&self, proto: ProtocolId) {
570 if let Ok(mut state) = self.state.write() {
571 state.protocols.retain(|p| p != &proto);
572 }
573 }
574
575 fn is_connected(&self, peer: &PeerId) -> bool {
576 self.state
577 .read()
578 .ok()
579 .and_then(|state| state.capabilities.get(peer).map(|caps| caps.is_connected))
580 .unwrap_or(false)
581 }
582
583 fn active_connections(&self) -> usize {
584 self.state
585 .read()
586 .map(|state| {
587 state
588 .capabilities
589 .values()
590 .filter(|caps| caps.is_connected)
591 .count()
592 })
593 .unwrap_or(0)
594 }
595
596 fn shutdown(&self) -> BoxFuture<'_, ()> {
597 Box::pin(async move {
598 self.endpoint.shutdown().await;
599 })
600 }
601}
602
#[cfg(test)]
mod tests {
    use super::*;

    /// Built-in protocol ids render to their canonical names.
    #[test]
    fn test_protocol_id_constants() {
        let expected = [
            (ProtocolId::DEFAULT, "ant-quic/default"),
            (ProtocolId::NAT_TRAVERSAL, "ant-quic/nat"),
            (ProtocolId::RELAY, "ant-quic/relay"),
        ];
        for (proto, name) in expected.iter() {
            assert_eq!(proto.to_string(), *name);
        }
    }

    /// A freshly connected capability record holds exactly the observed address.
    #[test]
    fn test_capabilities_connected() {
        let addr = SocketAddr::from(([127, 0, 0, 1], 9000));
        let caps = Capabilities::new_connected(addr);

        assert!(caps.is_connected);
        assert_eq!(caps.observed_addrs.len(), 1);
        assert_eq!(caps.observed_addrs.first(), Some(&addr));
    }

    /// Default statistics start with zero traffic.
    #[test]
    fn test_connection_stats_default() {
        let stats = ConnectionStats::default();
        assert_eq!((stats.bytes_sent, stats.bytes_received), (0, 0));
    }

    /// Default transport state registers only the default protocol and knows no peers.
    #[test]
    fn test_link_transport_state_default() {
        let state = LinkTransportState::default();
        assert_eq!(state.protocols, vec![ProtocolId::DEFAULT]);
        assert!(state.capabilities.is_empty());
    }
}