libp2p_iroh/
connection.rs1use std::{error::Error, fmt::Display, pin::Pin, task::Poll};
2
3use crate::{
4 TransportError,
5 stream::{Stream, StreamError},
6};
7use futures::{
8 FutureExt,
9 future::BoxFuture,
10};
11use iroh::
12 endpoint::{RecvStream, SendStream}
13;
14use libp2p_core::StreamMuxer;
15use tokio::io::{AsyncReadExt, AsyncWriteExt};
16
17#[derive(Debug)]
18pub struct ConnectionError {
19 kind: ConnectionErrorKind,
20}
21
/// Which operation a [`ConnectionError`] originated from, each carrying a descriptive message.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ConnectionErrorKind {
    /// Failure while accepting an inbound substream. The generic `From<&str>` and
    /// `From<iroh::endpoint::ConnectionError>` conversions also produce this variant.
    Accept(String),
    /// Intended for failures while opening an outbound substream.
    Open(String),
    /// Failure reported by the substream wrapper (converted from `StreamError`).
    Stream(String),
}
28
29impl Display for ConnectionError {
30 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
31 write!(f, "ConnectionError: {:?}", self.kind)
32 }
33}
34
35impl Error for ConnectionError {}
36
37impl From<iroh::endpoint::ConnectionError> for ConnectionError {
38 fn from(err: iroh::endpoint::ConnectionError) -> Self {
39 Self {
40 kind: ConnectionErrorKind::Accept(err.to_string()),
41 }
42 }
43}
44
45impl From<&str> for ConnectionError {
46 fn from(err: &str) -> Self {
47 Self {
48 kind: ConnectionErrorKind::Accept(err.to_string()),
49 }
50 }
51}
52
53impl From<StreamError> for ConnectionError {
54 fn from(err: StreamError) -> Self {
55 Self {
56 kind: ConnectionErrorKind::Stream(err.to_string()),
57 }
58 }
59}
60
61pub struct Connection {
62 connection: iroh::endpoint::Connection,
63 incoming: Option<BoxFuture<'static, Result<(SendStream, RecvStream), ConnectionError>>>,
64 outgoing: Option<BoxFuture<'static, Result<(SendStream, RecvStream), ConnectionError>>>,
65 closing: Option<BoxFuture<'static, ConnectionError>>,
66}
67
68pub struct Connecting {
69 pub connecting: BoxFuture<'static, Result<(libp2p::PeerId, iroh::endpoint::Connection), TransportError>>,
70}
71
72impl Connection {
73 pub fn new(connection: iroh::endpoint::Connection) -> Self {
74 tracing::debug!("Connection::new - Creating new connection wrapper");
75 Self {
76 connection,
77 incoming: None,
78 outgoing: None,
79 closing: None,
80 }
81 }
82}
83
84impl StreamMuxer for Connection {
85 type Substream = Stream;
86 type Error = ConnectionError;
87
88 fn poll_inbound(
89 self: Pin<&mut Self>,
90 cx: &mut std::task::Context<'_>,
91 ) -> Poll<Result<Self::Substream, Self::Error>> {
92 let this = self.get_mut();
93
94 let incoming = this.incoming.get_or_insert_with(|| {
95 tracing::debug!("Connection::poll_inbound - Setting up incoming stream future");
96 let connection = this.connection.clone();
97 async move {
98 tracing::debug!("Connection::poll_inbound - Accepting bidirectional stream");
99 match connection.accept_bi().await {
100 Ok((s, mut r)) => {
101 tracing::debug!("Connection::poll_inbound - Bidirectional stream accepted, reading handshake byte");
102 r.read_u8().await.map_err(|e| {
103 tracing::error!("Connection::poll_inbound - Failed to read handshake byte: {}", e);
104 ConnectionError::from("Failed to read from stream")
105 })?;
106 tracing::debug!("Connection::poll_inbound - Handshake byte read successfully");
107 Ok((s, r))
108 },
109 Err(e) => {
110 tracing::error!("Connection::poll_inbound - Failed to accept bidirectional stream: {}", e);
111 Err(ConnectionError::from("Iroh handshake failed during accept"))
112 }
113 }
114 }.boxed()
115 });
116
117 let (send, recv) = futures::ready!(incoming.poll_unpin(cx))?;
118 this.incoming.take();
119 tracing::debug!("Connection::poll_inbound - Inbound stream ready, creating Stream wrapper");
120 Poll::Ready(Stream::new(send, recv).map_err(Into::into))
121 }
122
123 fn poll_outbound(
124 self: Pin<&mut Self>,
125 cx: &mut std::task::Context<'_>,
126 ) -> Poll<Result<Self::Substream, Self::Error>> {
127 let this = self.get_mut();
128
129 let outgoing = this.outgoing.get_or_insert_with(|| {
130 tracing::debug!("Connection::poll_outbound - Setting up outgoing stream future");
131 let connection = this.connection.clone();
132 async move {
133 tracing::debug!("Connection::poll_outbound - Opening bidirectional stream");
134 match connection.open_bi().await {
135 Ok((mut s, r)) => {
136 tracing::debug!("Connection::poll_outbound - Bidirectional stream opened, writing handshake byte");
137 s.write_u8(0).await.map_err(|e| {
139 tracing::error!("Connection::poll_outbound - Failed to write handshake byte: {}", e);
140 ConnectionError::from("Failed to write to stream")
141 })?;
142 tracing::debug!("Connection::poll_outbound - Handshake byte written successfully");
143 Ok((s, r))
144 }
145 Err(e) => {
146 tracing::error!("Connection::poll_outbound - Failed to open bidirectional stream: {}", e);
147 Err(ConnectionError::from("Iroh handshake failed during open"))
148 }
149 }
150 }.boxed()
151 });
152
153 let (send, recv) = futures::ready!(outgoing.poll_unpin(cx))?;
154 this.outgoing.take();
155 tracing::debug!("Connection::poll_outbound - Outbound stream ready, creating Stream wrapper");
156 Poll::Ready(Stream::new(send, recv).map_err(Into::into))
157 }
158
159 fn poll_close(
160 self: Pin<&mut Self>,
161 cx: &mut std::task::Context<'_>,
162 ) -> Poll<Result<(), Self::Error>> {
163 let this = self.get_mut();
164
165 let closing = this.closing.get_or_insert_with(|| {
166 tracing::debug!("Connection::poll_close - Closing connection");
167 this.connection.close(From::from(0u32), &[]);
168 let connection = this.connection.clone();
169 async move {
170 tracing::debug!("Connection::poll_close - Waiting for connection to close");
171 connection.closed().await.into()
172 }.boxed()
173 });
174
175
176 if matches!(
177 futures::ready!(closing.poll_unpin(cx)),
178 crate::ConnectionError { .. }
179 ) {
180 tracing::error!("Connection::poll_close - Failed to close connection");
181 return Poll::Ready(Err("failed to close connection".into()));
182 };
183
184 tracing::debug!("Connection::poll_close - Connection closed successfully");
185 Poll::Ready(Ok(()))
186 }
187
188 fn poll(
189 self: Pin<&mut Self>,
190 _cx: &mut std::task::Context<'_>,
191 ) -> Poll<Result<libp2p_core::muxing::StreamMuxerEvent, Self::Error>> {
192 Poll::Pending
193 }
194}
195
196impl Future for Connecting {
197 type Output = Result<(libp2p::PeerId, libp2p_core::muxing::StreamMuxerBox), TransportError>;
198
199 fn poll(mut self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<Self::Output> {
200 tracing::debug!("Connecting::poll - Polling connection future");
201 let (peer_id, conn) = match self.connecting.poll_unpin(cx) {
202 Poll::Ready(Ok((peer_id, conn))) => {
203 tracing::debug!("Connecting::poll - Connection established");
204 (peer_id, conn)
205 },
206 Poll::Ready(Err(e)) => {
207 tracing::error!("Connecting::poll - Connection failed: {}", e);
208 return Poll::Ready(Err(e));
209 },
210 Poll::Pending => {
211 tracing::trace!("Connecting::poll - Connection still pending");
212 return Poll::Pending;
213 }
214 };
215
216 let muxer = Connection::new(conn);
217
218 tracing::debug!("Connecting::poll - Connection muxer created");
219 Poll::Ready(Ok((peer_id, libp2p_core::muxing::StreamMuxerBox::new(muxer))))
220 }
221}