libp2p_iroh/
connection.rs

1use std::{error::Error, fmt::Display, pin::Pin, task::Poll};
2
3use crate::{
4    TransportError,
5    stream::{Stream, StreamError},
6};
7use futures::{
8    FutureExt,
9    future::BoxFuture,
10};
11use iroh::
12    endpoint::{RecvStream, SendStream}
13;
14use libp2p_core::StreamMuxer;
15use tokio::io::{AsyncReadExt, AsyncWriteExt};
16
/// Error type produced by the iroh-backed `Connection` stream muxer.
///
/// The failure category and its message are stored in a [`ConnectionErrorKind`], which
/// callers can inspect through [`ConnectionError::kind`].
#[derive(Debug)]
pub struct ConnectionError {
    kind: ConnectionErrorKind,
}

impl ConnectionError {
    /// Returns the category of this error together with its message.
    ///
    /// `ConnectionErrorKind` is public, but the `kind` field is private; without this
    /// accessor the category could never be observed outside this module.
    pub fn kind(&self) -> &ConnectionErrorKind {
        &self.kind
    }
}

/// Category of a [`ConnectionError`]; every variant carries a human-readable message.
#[derive(Debug)]
pub enum ConnectionErrorKind {
    /// Failure while accepting an inbound substream. The generic `From<&str>` and
    /// `From<iroh::endpoint::ConnectionError>` conversions also produce this variant.
    Accept(String),
    /// Intended for failures while opening an outbound substream.
    Open(String),
    /// Failure reported by the substream wrapper (converted from `StreamError`).
    Stream(String),
}

impl Display for ConnectionError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // The kind is rendered with `Debug`, e.g. `ConnectionError: Accept("message")`.
        write!(f, "ConnectionError: {:?}", self.kind)
    }
}

// No `source()`: underlying errors are flattened into the kind's message string.
impl Error for ConnectionError {}
36
37impl From<iroh::endpoint::ConnectionError> for ConnectionError {
38    fn from(err: iroh::endpoint::ConnectionError) -> Self {
39        Self {
40            kind: ConnectionErrorKind::Accept(err.to_string()),
41        }
42    }
43}
44
45impl From<&str> for ConnectionError {
46    fn from(err: &str) -> Self {
47        Self {
48            kind: ConnectionErrorKind::Accept(err.to_string()),
49        }
50    }
51}
52
53impl From<StreamError> for ConnectionError {
54    fn from(err: StreamError) -> Self {
55        Self {
56            kind: ConnectionErrorKind::Stream(err.to_string()),
57        }
58    }
59}
60
61pub struct Connection {
62    connection: iroh::endpoint::Connection,
63    incoming: Option<BoxFuture<'static, Result<(SendStream, RecvStream), ConnectionError>>>,
64    outgoing: Option<BoxFuture<'static, Result<(SendStream, RecvStream), ConnectionError>>>,
65    closing: Option<BoxFuture<'static, ConnectionError>>,
66}
67
68pub struct Connecting {
69    pub connecting: BoxFuture<'static, Result<(libp2p::PeerId, iroh::endpoint::Connection), TransportError>>,
70}
71
72impl Connection {
73    pub fn new(connection: iroh::endpoint::Connection) -> Self {
74        tracing::debug!("Connection::new - Creating new connection wrapper");
75        Self {
76            connection,
77            incoming: None,
78            outgoing: None,
79            closing: None,
80        }
81    }
82}
83
84impl StreamMuxer for Connection {
85    type Substream = Stream;
86    type Error = ConnectionError;
87
88    fn poll_inbound(
89        self: Pin<&mut Self>,
90        cx: &mut std::task::Context<'_>,
91    ) -> Poll<Result<Self::Substream, Self::Error>> {
92        let this = self.get_mut();
93
94        let incoming = this.incoming.get_or_insert_with(|| {
95            tracing::debug!("Connection::poll_inbound - Setting up incoming stream future");
96            let connection = this.connection.clone();
97            async move { 
98                tracing::debug!("Connection::poll_inbound - Accepting bidirectional stream");
99                match connection.accept_bi().await {
100                    Ok((s, mut r)) => {
101                        tracing::debug!("Connection::poll_inbound - Bidirectional stream accepted, reading handshake byte");
102                        r.read_u8().await.map_err(|e| {
103                            tracing::error!("Connection::poll_inbound - Failed to read handshake byte: {}", e);
104                            ConnectionError::from("Failed to read from stream")
105                        })?;
106                        tracing::debug!("Connection::poll_inbound - Handshake byte read successfully");
107                        Ok((s, r))
108                    },
109                    Err(e) => {
110                        tracing::error!("Connection::poll_inbound - Failed to accept bidirectional stream: {}", e);
111                        Err(ConnectionError::from("Iroh handshake failed during accept"))
112                    }
113                }
114             }.boxed()
115        });
116
117        let (send, recv) = futures::ready!(incoming.poll_unpin(cx))?;
118        this.incoming.take();
119        tracing::debug!("Connection::poll_inbound - Inbound stream ready, creating Stream wrapper");
120        Poll::Ready(Stream::new(send, recv).map_err(Into::into))
121    }
122
123    fn poll_outbound(
124        self: Pin<&mut Self>,
125        cx: &mut std::task::Context<'_>,
126    ) -> Poll<Result<Self::Substream, Self::Error>> {
127        let this = self.get_mut();
128
129        let outgoing = this.outgoing.get_or_insert_with(|| {
130            tracing::debug!("Connection::poll_outbound - Setting up outgoing stream future");
131            let connection = this.connection.clone();
132            async move { 
133                tracing::debug!("Connection::poll_outbound - Opening bidirectional stream");
134                match connection.open_bi().await {
135                    Ok((mut s, r)) => {
136                        tracing::debug!("Connection::poll_outbound - Bidirectional stream opened, writing handshake byte");
137                        // one byte iroh-handshake since accept only connects after open and write, not just open
138                        s.write_u8(0).await.map_err(|e| {
139                            tracing::error!("Connection::poll_outbound - Failed to write handshake byte: {}", e);
140                            ConnectionError::from("Failed to write to stream")
141                        })?;
142                        tracing::debug!("Connection::poll_outbound - Handshake byte written successfully");
143                        Ok((s, r))
144                    }
145                    Err(e) => {
146                        tracing::error!("Connection::poll_outbound - Failed to open bidirectional stream: {}", e);
147                        Err(ConnectionError::from("Iroh handshake failed during open"))
148                    }
149                }
150            }.boxed()
151        });
152
153        let (send, recv) = futures::ready!(outgoing.poll_unpin(cx))?;
154        this.outgoing.take();
155        tracing::debug!("Connection::poll_outbound - Outbound stream ready, creating Stream wrapper");
156        Poll::Ready(Stream::new(send, recv).map_err(Into::into))
157    }
158
159    fn poll_close(
160        self: Pin<&mut Self>,
161        cx: &mut std::task::Context<'_>,
162    ) -> Poll<Result<(), Self::Error>> {
163        let this = self.get_mut();
164
165        let closing = this.closing.get_or_insert_with(|| {
166            tracing::debug!("Connection::poll_close - Closing connection");
167            this.connection.close(From::from(0u32), &[]);
168            let connection = this.connection.clone();
169            async move { 
170                tracing::debug!("Connection::poll_close - Waiting for connection to close");
171                connection.closed().await.into() 
172            }.boxed()
173        });
174
175
176        if matches!(
177            futures::ready!(closing.poll_unpin(cx)),
178            crate::ConnectionError { .. }
179        ) {
180            tracing::error!("Connection::poll_close - Failed to close connection");
181            return Poll::Ready(Err("failed to close connection".into()));
182        };
183
184        tracing::debug!("Connection::poll_close - Connection closed successfully");
185        Poll::Ready(Ok(()))
186    }
187
188    fn poll(
189        self: Pin<&mut Self>,
190        _cx: &mut std::task::Context<'_>,
191    ) -> Poll<Result<libp2p_core::muxing::StreamMuxerEvent, Self::Error>> {
192        Poll::Pending
193    }
194}
195
196impl Future for Connecting {
197    type Output = Result<(libp2p::PeerId, libp2p_core::muxing::StreamMuxerBox), TransportError>;
198
199    fn poll(mut self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<Self::Output> {
200        tracing::debug!("Connecting::poll - Polling connection future");
201        let (peer_id, conn) = match self.connecting.poll_unpin(cx) {
202            Poll::Ready(Ok((peer_id, conn))) => {
203                tracing::debug!("Connecting::poll - Connection established");
204                (peer_id, conn)
205            },
206            Poll::Ready(Err(e)) => {
207                tracing::error!("Connecting::poll - Connection failed: {}", e);
208                return Poll::Ready(Err(e));
209            },
210            Poll::Pending => {
211                tracing::trace!("Connecting::poll - Connection still pending");
212                return Poll::Pending;
213            }
214        };
215
216        let muxer = Connection::new(conn);
217
218        tracing::debug!("Connecting::poll - Connection muxer created");
219        Poll::Ready(Ok((peer_id, libp2p_core::muxing::StreamMuxerBox::new(muxer))))
220    }
221}