1use std::{fmt::Display, pin::Pin};
2
3use tokio::io::AsyncWrite;
4
/// Error produced by the iroh [`Stream`] wrapper.
///
/// The category of the failure and its message are available through
/// [`StreamError::kind`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct StreamError {
    kind: StreamErrorKind,
}

impl StreamError {
    /// Returns the category of this error, including its message.
    ///
    /// Without this accessor the public [`StreamErrorKind`] enum could not be
    /// observed by callers, because the `kind` field is private.
    #[must_use]
    pub fn kind(&self) -> &StreamErrorKind {
        &self.kind
    }
}

/// Category of a [`StreamError`]. Each variant carries a human-readable message.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum StreamErrorKind {
    /// A failure while reading from the stream.
    Read(String),
    /// A failure while writing to the stream.
    Write(String),
    /// A failure at the connection level.
    Connection(String),
}
17
18impl From<std::io::Error> for StreamError {
19 fn from(err: std::io::Error) -> Self {
20 Self {
21 kind: StreamErrorKind::Read(err.to_string()),
22 }
23 }
24}
25
26impl From<iroh::endpoint::ConnectionError> for StreamError {
27 fn from(err: iroh::endpoint::ConnectionError) -> Self {
28 Self {
29 kind: StreamErrorKind::Connection(err.to_string()),
30 }
31 }
32}
33
34impl From<iroh::endpoint::WriteError> for StreamError {
35 fn from(err: iroh::endpoint::WriteError) -> Self {
36 Self {
37 kind: StreamErrorKind::Write(err.to_string()),
38 }
39 }
40}
41
42impl From<iroh::endpoint::ReadError> for StreamError {
43 fn from(err: iroh::endpoint::ReadError) -> Self {
44 Self {
45 kind: StreamErrorKind::Read(err.to_string()),
46 }
47 }
48}
49
50
51impl From<&str> for StreamError {
52 fn from(err: &str) -> Self {
53 Self {
54 kind: StreamErrorKind::Connection(err.to_string()),
55 }
56 }
57}
58
59impl Display for StreamError {
60 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
61 match &self.kind {
62 StreamErrorKind::Read(msg) => write!(f, "IrohStream Read Error: {msg}"),
63 StreamErrorKind::Write(msg) => write!(f, "IrohStream Write Error: {msg}"),
64 StreamErrorKind::Connection(msg) => {
65 write!(f, "IrohStream Connection Error: {msg}")
66 }
67 }
68 }
69}
70
71impl std::error::Error for StreamError {}
72
73#[derive(Debug)]
74pub struct Stream {
75 sender: Option<iroh::endpoint::SendStream>,
76 receiver: Option<iroh::endpoint::RecvStream>,
77 closing: bool,
78}
79
80impl Stream {
81 pub fn new(
82 sender: iroh::endpoint::SendStream,
83 receiver: iroh::endpoint::RecvStream,
84 ) -> Result<Self, StreamError> {
85 tracing::debug!("Stream::new - Creating new stream wrapper");
86 Ok(Self {
87 sender: Some(sender),
88 receiver: Some(receiver),
89 closing: false,
90 })
91 }
92}
93
94impl futures::AsyncRead for Stream {
95 fn poll_read(
96 mut self: std::pin::Pin<&mut Self>,
97 cx: &mut std::task::Context<'_>,
98 buf: &mut [u8],
99 ) -> std::task::Poll<std::io::Result<usize>> {
100 if let Some(receiver) = &mut self.receiver {
101 match Pin::new(receiver).poll_read(cx, buf) {
102 std::task::Poll::Ready(Ok(n)) => {
103 if n == 0 {
104 tracing::debug!("Stream::poll_read - EOF reached (0 bytes)");
105 } else {
106 tracing::trace!("Stream::poll_read - Read {} bytes", n);
107 }
108 std::task::Poll::Ready(Ok(n))
109 }
110 std::task::Poll::Ready(Err(e)) => {
111 tracing::debug!("Stream::poll_read - Read error: {}", e);
112 std::task::Poll::Ready(Err(std::io::Error::other(
113 e,
114 )))
115 }
116 std::task::Poll::Pending => std::task::Poll::Pending,
117 }
118 } else {
119 tracing::debug!("Stream::poll_read - Stream receiver already closed locally");
120 std::task::Poll::Ready(Err(std::io::Error::new(
121 std::io::ErrorKind::BrokenPipe,
122 "stream receiver closed",
123 )))
124 }
125 }
126}
127
128impl futures::AsyncWrite for Stream {
129 fn poll_write(
130 mut self: Pin<&mut Self>,
131 cx: &mut std::task::Context<'_>,
132 buf: &[u8],
133 ) -> std::task::Poll<std::io::Result<usize>> {
134 if let Some(sender) = &mut self.sender {
135 match Pin::new(sender).poll_write(cx, buf) {
136 std::task::Poll::Ready(Ok(n)) => {
137 tracing::trace!("Stream::poll_write - Wrote {} bytes", n);
138 std::task::Poll::Ready(Ok(n))
139 }
140 std::task::Poll::Ready(Err(e)) => {
141 let err_str = e.to_string();
143 if err_str.contains("stopped") || err_str.contains("error 0") {
144 tracing::debug!("Stream::poll_write - Remote peer closed stream: {}", e);
145 } else {
146 tracing::error!("Stream::poll_write - Write error: {}", e);
147 }
148 std::task::Poll::Ready(Err(std::io::Error::other(
149 e,
150 )))
151 }
152 std::task::Poll::Pending => std::task::Poll::Pending,
153 }
154 } else {
155 tracing::debug!("Stream::poll_write - Stream sender already closed locally");
156 std::task::Poll::Ready(Err(std::io::Error::new(
157 std::io::ErrorKind::BrokenPipe,
158 "stream sender closed",
159 )))
160 }
161 }
162
163 fn poll_flush(
164 mut self: Pin<&mut Self>,
165 cx: &mut std::task::Context<'_>,
166 ) -> std::task::Poll<std::io::Result<()>> {
167 if let Some(sender) = &mut self.sender {
168 match Pin::new(sender).poll_flush(cx) {
169 std::task::Poll::Ready(Ok(())) => {
170 tracing::trace!("Stream::poll_flush - Flush successful");
171 std::task::Poll::Ready(Ok(()))
172 }
173 std::task::Poll::Ready(Err(e)) => {
174 tracing::debug!("Stream::poll_flush - Flush error: {}", e);
175 std::task::Poll::Ready(Err(std::io::Error::other(
176 e,
177 )))
178 }
179 std::task::Poll::Pending => std::task::Poll::Pending,
180 }
181 } else {
182 tracing::debug!("Stream::poll_flush - Stream sender already closed locally");
183 std::task::Poll::Ready(Err(std::io::Error::new(
184 std::io::ErrorKind::BrokenPipe,
185 "stream sender closed",
186 )))
187 }
188 }
189
190 fn poll_close(
191 mut self: Pin<&mut Self>,
192 _cx: &mut std::task::Context<'_>,
193 ) -> std::task::Poll<std::io::Result<()>> {
194 if !self.closing {
195 tracing::debug!("Stream::poll_close - Starting to close stream (write side)");
196 self.closing = true;
197
198 if let Some(mut sender) = self.sender.take() {
200 if let Err(e) = sender.finish() {
201 tracing::warn!("Stream::poll_close - Error finishing sender: {}", e);
202 } else {
203 tracing::debug!("Stream::poll_close - Sender finished successfully");
204 }
205 }
206 }
207 tracing::debug!("Stream::poll_close - Write side closed");
208 std::task::Poll::Ready(Ok(()))
209 }
210}