libp2p_iroh/
stream.rs

1use std::{fmt::Display, pin::Pin};
2
3use tokio::io::AsyncWrite;
4
5// IrohStream error:
6#[derive(Debug, Clone)]
7pub struct StreamError {
8    kind: StreamErrorKind,
9}
10
/// Category of a [`StreamError`]; each variant carries a human-readable
/// description of the underlying failure.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum StreamErrorKind {
    /// Failure while reading; `std::io::Error` conversions also land here.
    Read(String),
    /// Failure while writing to the send half.
    Write(String),
    /// Failure of the underlying connection, or an ad-hoc error built from a `&str`.
    Connection(String),
}
17
18impl From<std::io::Error> for StreamError {
19    fn from(err: std::io::Error) -> Self {
20        Self {
21            kind: StreamErrorKind::Read(err.to_string()),
22        }
23    }
24}
25
26impl From<iroh::endpoint::ConnectionError> for StreamError {
27    fn from(err: iroh::endpoint::ConnectionError) -> Self {
28        Self {
29            kind: StreamErrorKind::Connection(err.to_string()),
30        }
31    }
32}
33
34impl From<iroh::endpoint::WriteError> for StreamError {
35    fn from(err: iroh::endpoint::WriteError) -> Self {
36        Self {
37            kind: StreamErrorKind::Write(err.to_string()),
38        }
39    }
40}
41
42impl From<iroh::endpoint::ReadError> for StreamError {
43    fn from(err: iroh::endpoint::ReadError) -> Self {
44        Self {
45            kind: StreamErrorKind::Read(err.to_string()),
46        }
47    }
48}
49
50
51impl From<&str> for StreamError {
52    fn from(err: &str) -> Self {
53        Self {
54            kind: StreamErrorKind::Connection(err.to_string()),
55        }
56    }
57}
58
59impl Display for StreamError {
60    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
61        match &self.kind {
62            StreamErrorKind::Read(msg) => write!(f, "IrohStream Read Error: {msg}"),
63            StreamErrorKind::Write(msg) => write!(f, "IrohStream Write Error: {msg}"),
64            StreamErrorKind::Connection(msg) => {
65                write!(f, "IrohStream Connection Error: {msg}")
66            }
67        }
68    }
69}
70
71impl std::error::Error for StreamError {}
72
73#[derive(Debug)]
74pub struct Stream {
75    sender: Option<iroh::endpoint::SendStream>,
76    receiver: Option<iroh::endpoint::RecvStream>,
77    closing: bool,
78}
79
80impl Stream {
81    pub fn new(
82        sender: iroh::endpoint::SendStream,
83        receiver: iroh::endpoint::RecvStream,
84    ) -> Result<Self, StreamError> {
85        tracing::debug!("Stream::new - Creating new stream wrapper");
86        Ok(Self {
87            sender: Some(sender),
88            receiver: Some(receiver),
89            closing: false,
90        })
91    }
92}
93
94impl futures::AsyncRead for Stream {
95    fn poll_read(
96        mut self: std::pin::Pin<&mut Self>,
97        cx: &mut std::task::Context<'_>,
98        buf: &mut [u8],
99    ) -> std::task::Poll<std::io::Result<usize>> {
100        if let Some(receiver) = &mut self.receiver {
101            match Pin::new(receiver).poll_read(cx, buf) {
102                std::task::Poll::Ready(Ok(n)) => {
103                    if n == 0 {
104                        tracing::debug!("Stream::poll_read - EOF reached (0 bytes)");
105                    } else {
106                        tracing::trace!("Stream::poll_read - Read {} bytes", n);
107                    }
108                    std::task::Poll::Ready(Ok(n))
109                }
110                std::task::Poll::Ready(Err(e)) => {
111                    tracing::debug!("Stream::poll_read - Read error: {}", e);
112                    std::task::Poll::Ready(Err(std::io::Error::other(
113                        e,
114                    )))
115                }
116                std::task::Poll::Pending => std::task::Poll::Pending,
117            }
118        } else {
119            tracing::debug!("Stream::poll_read - Stream receiver already closed locally");
120            std::task::Poll::Ready(Err(std::io::Error::new(
121                std::io::ErrorKind::BrokenPipe,
122                "stream receiver closed",
123            )))
124        }
125    }
126}
127
128impl futures::AsyncWrite for Stream {
129    fn poll_write(
130        mut self: Pin<&mut Self>,
131        cx: &mut std::task::Context<'_>,
132        buf: &[u8],
133    ) -> std::task::Poll<std::io::Result<usize>> {
134        if let Some(sender) = &mut self.sender {
135            match Pin::new(sender).poll_write(cx, buf) {
136                std::task::Poll::Ready(Ok(n)) => {
137                    tracing::trace!("Stream::poll_write - Wrote {} bytes", n);
138                    std::task::Poll::Ready(Ok(n))
139                }
140                std::task::Poll::Ready(Err(e)) => {
141                    // Check if this is a "stopped" error (remote side closed)
142                    let err_str = e.to_string();
143                    if err_str.contains("stopped") || err_str.contains("error 0") {
144                        tracing::debug!("Stream::poll_write - Remote peer closed stream: {}", e);
145                    } else {
146                        tracing::error!("Stream::poll_write - Write error: {}", e);
147                    }
148                    std::task::Poll::Ready(Err(std::io::Error::other(
149                        e,
150                    )))
151                }
152                std::task::Poll::Pending => std::task::Poll::Pending,
153            }
154        } else {
155            tracing::debug!("Stream::poll_write - Stream sender already closed locally");
156            std::task::Poll::Ready(Err(std::io::Error::new(
157                std::io::ErrorKind::BrokenPipe,
158                "stream sender closed",
159            )))
160        }
161    }
162
163    fn poll_flush(
164        mut self: Pin<&mut Self>,
165        cx: &mut std::task::Context<'_>,
166    ) -> std::task::Poll<std::io::Result<()>> {
167        if let Some(sender) = &mut self.sender {
168            match Pin::new(sender).poll_flush(cx) {
169                std::task::Poll::Ready(Ok(())) => {
170                    tracing::trace!("Stream::poll_flush - Flush successful");
171                    std::task::Poll::Ready(Ok(()))
172                }
173                std::task::Poll::Ready(Err(e)) => {
174                    tracing::debug!("Stream::poll_flush - Flush error: {}", e);
175                    std::task::Poll::Ready(Err(std::io::Error::other(
176                        e,
177                    )))
178                }
179                std::task::Poll::Pending => std::task::Poll::Pending,
180            }
181        } else {
182            tracing::debug!("Stream::poll_flush - Stream sender already closed locally");
183            std::task::Poll::Ready(Err(std::io::Error::new(
184                std::io::ErrorKind::BrokenPipe,
185                "stream sender closed",
186            )))
187        }
188    }
189
190    fn poll_close(
191        mut self: Pin<&mut Self>,
192        _cx: &mut std::task::Context<'_>,
193    ) -> std::task::Poll<std::io::Result<()>> {
194        if !self.closing {
195            tracing::debug!("Stream::poll_close - Starting to close stream (write side)");
196            self.closing = true;
197            
198            // Finish the sender to signal we're done writing
199            if let Some(mut sender) = self.sender.take() {
200                if let Err(e) = sender.finish() {
201                    tracing::warn!("Stream::poll_close - Error finishing sender: {}", e);
202                } else {
203                    tracing::debug!("Stream::poll_close - Sender finished successfully");
204                }
205            }
206        }
207        tracing::debug!("Stream::poll_close - Write side closed");
208        std::task::Poll::Ready(Ok(()))
209    }
210}