1use anyhow::Result;
2use bytes::Bytes;
3use foctet_core::connection::SessionId;
4use foctet_core::frame::{Frame, FrameBuilder, FrameFlags, FrameType};
5use foctet_core::stream::StreamId;
6use std::io;
7use std::pin::Pin;
8use std::sync::{Arc, Mutex};
9use std::task::{Context, Poll};
10use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
11use tokio::sync::mpsc::{Receiver, Sender};
12
/// Lifecycle state of a logical stream.
///
/// Local transitions happen in the `close()` methods; remote transitions are derived from
/// the flags carried by received frames (see the `set_state_from_flags` helpers).
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
pub enum StreamState {
    /// Initial state; `close()` treats it like an open stream.
    Init,
    /// Presumably set after this side sent an open request — nothing in this file sets it.
    OpenSent,
    /// A frame carrying the open-request flag has been received.
    OpenReceived,
    /// A frame carrying the open-response flag has been received.
    Established,
    /// `close()` was called locally and a close request was sent.
    LocalClosing,
    /// A frame carrying the close-request flag has been received from the peer.
    RemoteClosing,
    /// Stream is fully closed. Also reported when the shared state lock is poisoned.
    Closed,
    /// A frame carrying the open-reset flag has been received.
    Reset,
}
33
34#[derive(Debug, Eq, PartialEq)]
36pub enum StreamEvent {
37 Frame(Frame),
38 Closed(StreamId),
39 Error,
40}
41
42#[derive(Debug)]
44pub struct LogicalStreamWriter {
45 session_id: SessionId,
46 stream_id: StreamId,
47 state: Arc<Mutex<StreamState>>,
48 frame_sender: Sender<StreamEvent>,
49}
50
51impl LogicalStreamWriter {
52 pub fn set_state(&mut self, state: StreamState) {
53 match self.state.lock() {
54 Ok(mut state_guard) => {
55 *state_guard = state;
56 }
57 Err(_) => (),
58 }
59 }
60 pub fn session_id(&self) -> SessionId {
62 self.session_id
63 }
64 pub fn stream_id(&self) -> StreamId {
66 self.stream_id
67 }
68 pub fn state(&self) -> StreamState {
70 match self.state.lock() {
71 Ok(state_guard) => *state_guard,
72 Err(_) => StreamState::Closed,
73 }
74 }
75 pub async fn send_event(&self, event: StreamEvent) -> Result<()> {
76 match self.frame_sender.send(event).await {
77 Ok(_) => Ok(()),
78 Err(_) => anyhow::bail!(io::Error::new(
79 io::ErrorKind::BrokenPipe,
80 "Failed to send event"
81 )),
82 }
83 }
84 pub async fn send_frame(&self, frame: Frame) -> Result<()> {
86 self.frame_sender
87 .send(StreamEvent::Frame(frame))
88 .await
89 .map_err(|_| {
90 anyhow::anyhow!(io::Error::new(
91 io::ErrorKind::BrokenPipe,
92 "Failed to send frame"
93 ))
94 })
95 }
96
97 pub async fn send_bytes(&self, bytes: Bytes) -> Result<()> {
99 let frame = FrameBuilder::new()
100 .with_stream_id(self.stream_id)
101 .with_frame_type(FrameType::Data)
102 .with_payload(bytes)
103 .build();
104 self.send_frame(frame).await
105 }
106
107 async fn send_close_request(&mut self) -> Result<()> {
108 let frame_flags = FrameFlags::close_request();
109 let close_frame: Frame = Frame::builder()
110 .with_stream_id(self.stream_id)
111 .with_flags(frame_flags)
112 .build();
113 match self
114 .frame_sender
115 .send(StreamEvent::Frame(close_frame))
116 .await
117 {
118 Ok(_) => Ok(()),
119 Err(_) => anyhow::bail!(io::Error::new(
120 io::ErrorKind::BrokenPipe,
121 "Failed to send close frame"
122 )),
123 }
124 }
125
126 pub async fn close(&mut self) -> Result<()> {
128 let state = match self.state.lock() {
129 Ok(state_guard) => *state_guard,
130 Err(_) => StreamState::Closed,
131 };
132 match state {
133 StreamState::OpenSent
134 | StreamState::OpenReceived
135 | StreamState::Established
136 | StreamState::Init => {
137 self.set_state(StreamState::LocalClosing);
138 self.send_close_request().await?;
139 }
140 StreamState::RemoteClosing => {
141 self.set_state(StreamState::Closed);
142 self.send_close_request().await?;
143 let event = StreamEvent::Closed(self.stream_id);
144 self.send_event(event).await?;
145 }
146 StreamState::Reset | StreamState::Closed => {
147 self.set_state(StreamState::Closed);
148 let event = StreamEvent::Closed(self.stream_id);
149 self.send_event(event).await?;
150 }
151 StreamState::LocalClosing => {
152 self.set_state(StreamState::Closed);
153 let event = StreamEvent::Closed(self.stream_id);
154 self.send_event(event).await?;
155 }
156 }
157 Ok(())
158 }
159}
160
161impl AsyncWrite for LogicalStreamWriter {
162 fn poll_write(
163 self: Pin<&mut Self>,
164 _cx: &mut Context<'_>,
165 buf: &[u8],
166 ) -> Poll<std::io::Result<usize>> {
167 let payload = Bytes::copy_from_slice(buf);
168 let frame = FrameBuilder::new()
169 .with_stream_id(self.stream_id)
170 .with_frame_type(FrameType::Data)
171 .with_payload(payload)
172 .build();
173
174 match self.frame_sender.try_send(StreamEvent::Frame(frame)) {
175 Ok(_) => {
176 Poll::Ready(Ok(buf.len()))
177 }
178 Err(_) => {
179 tracing::error!("Failed to send frame");
180 return Poll::Ready(Err(io::Error::new(
181 io::ErrorKind::BrokenPipe,
182 "Failed to send frame",
183 )));
184 }
185 }
186 }
187
188 fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
189 Poll::Ready(Ok(()))
190 }
191
192 fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
193 Poll::Ready(Ok(()))
195 }
196}
197
198#[derive(Debug)]
200pub struct LogicalStreamReader {
201 session_id: SessionId,
202 stream_id: StreamId,
203 state: Arc<Mutex<StreamState>>,
204 frame_receiver: Receiver<Frame>,
205}
206
207impl LogicalStreamReader {
208 pub fn set_state(&mut self, state: StreamState) {
209 match self.state.lock() {
210 Ok(mut state_guard) => {
211 *state_guard = state;
212 }
213 Err(_) => (),
214 }
215 }
216 pub fn session_id(&self) -> SessionId {
218 self.session_id
219 }
220 pub fn stream_id(&self) -> StreamId {
222 self.stream_id
223 }
224 pub fn state(&self) -> StreamState {
226 match self.state.lock() {
227 Ok(state_guard) => *state_guard,
228 Err(_) => StreamState::Closed,
229 }
230 }
231 pub async fn recv_frame(&mut self) -> Result<Frame> {
233 match self.frame_receiver.recv().await {
234 Some(frame) => {
235 self.set_state_from_flags(frame.header.flags);
236 Ok(frame)
237 }
238 None => {
239 if self.frame_receiver.is_closed() {
240 anyhow::bail!(io::Error::new(io::ErrorKind::BrokenPipe, "Channel closed"))
241 } else {
242 anyhow::bail!(io::Error::new(
243 io::ErrorKind::BrokenPipe,
244 "Failed to receive frame"
245 ))
246 }
247 }
248 }
249 }
250
251 pub async fn recv_bytes(&mut self) -> Result<Bytes> {
253 let frame = self.recv_frame().await?;
254 Ok(frame.payload)
255 }
256
257 fn set_state_from_flags(&mut self, flags: FrameFlags) {
258 if flags.is_open_request() {
259 self.set_state(StreamState::OpenReceived);
260 } else if flags.is_open_response() {
261 self.set_state(StreamState::Established);
262 } else if flags.is_open_reset() {
263 self.set_state(StreamState::Reset);
264 }
265 }
266}
267
268impl AsyncRead for LogicalStreamReader {
269 fn poll_read(
270 mut self: Pin<&mut Self>,
271 _cx: &mut Context<'_>,
272 buf: &mut ReadBuf<'_>,
273 ) -> Poll<io::Result<()>> {
274 match self.frame_receiver.try_recv() {
275 Ok(frame) => {
276 self.set_state_from_flags(frame.header.flags);
277 buf.put_slice(&frame.payload);
278 Poll::Ready(Ok(()))
279 }
280 Err(tokio::sync::mpsc::error::TryRecvError::Empty) => Poll::Pending,
281 Err(tokio::sync::mpsc::error::TryRecvError::Disconnected) => {
282 self.set_state(StreamState::Closed);
284 Poll::Ready(Ok(()))
285 }
286 }
287 }
288}
289
290#[derive(Debug)]
291pub struct LogicalStream {
292 session_id: SessionId,
294 stream_id: StreamId,
296 state: StreamState,
298 frame_sender: Sender<StreamEvent>,
300 frame_receiver: Receiver<Frame>,
303}
304
305impl LogicalStream {
306 pub fn new(
307 session_id: SessionId,
308 stream_id: StreamId,
309 state: StreamState,
310 frame_sender: Sender<StreamEvent>,
311 frame_receiver: Receiver<Frame>,
312 ) -> Self {
313 Self {
314 session_id,
315 stream_id,
316 state,
317 frame_sender,
318 frame_receiver,
319 }
320 }
321
322 pub async fn send_frame(&self, frame: Frame) -> Result<()> {
323 match self.frame_sender.send(StreamEvent::Frame(frame)).await {
324 Ok(_) => Ok(()),
325 Err(_) => anyhow::bail!(io::Error::new(
326 io::ErrorKind::BrokenPipe,
327 "Failed to send frame"
328 )),
329 }
330 }
331
332 pub async fn recv_frame(&mut self) -> Result<Frame> {
333 match self.frame_receiver.recv().await {
334 Some(frame) => {
335 self.set_state_from_flags(frame.header.flags);
336 Ok(frame)
337 }
338 None => anyhow::bail!(io::Error::new(
339 io::ErrorKind::BrokenPipe,
340 "Failed to receive frame"
341 )),
342 }
343 }
344
345 async fn send_event(&self, event: StreamEvent) -> Result<()> {
346 match self.frame_sender.send(event).await {
347 Ok(_) => Ok(()),
348 Err(_) => anyhow::bail!(io::Error::new(
349 io::ErrorKind::BrokenPipe,
350 "Failed to send event"
351 )),
352 }
353 }
354
355 async fn send_close_request(&mut self) -> Result<()> {
356 let frame_flags = FrameFlags::close_request();
357 let close_frame: Frame = Frame::builder()
358 .with_stream_id(self.stream_id)
359 .with_flags(frame_flags)
360 .build();
361 match self
362 .frame_sender
363 .send(StreamEvent::Frame(close_frame))
364 .await
365 {
366 Ok(_) => Ok(()),
367 Err(_) => anyhow::bail!(io::Error::new(
368 io::ErrorKind::BrokenPipe,
369 "Failed to send close frame"
370 )),
371 }
372 }
373
374 pub async fn close(&mut self) -> Result<()> {
376 match self.state {
377 StreamState::OpenSent
378 | StreamState::OpenReceived
379 | StreamState::Established
380 | StreamState::Init => {
381 self.state = StreamState::LocalClosing;
382 self.send_close_request().await?;
383 }
384 StreamState::RemoteClosing => {
385 self.state = StreamState::Closed;
386 self.send_close_request().await?;
387 let event = StreamEvent::Closed(self.stream_id);
388 self.send_event(event).await?;
389 }
390 StreamState::Reset | StreamState::Closed => {
391 self.state = StreamState::Closed;
392 let event = StreamEvent::Closed(self.stream_id);
393 self.send_event(event).await?;
394 }
395 StreamState::LocalClosing => {
396 self.state = StreamState::Closed;
397 let event = StreamEvent::Closed(self.stream_id);
398 self.send_event(event).await?;
399 }
400 }
401 Ok(())
402 }
403
404 fn set_state_from_flags(&mut self, flags: FrameFlags) {
405 if flags.is_open_request() {
406 self.state = StreamState::OpenReceived;
407 } else if flags.is_open_response() {
408 self.state = StreamState::Established;
409 } else if flags.is_open_reset() {
410 self.state = StreamState::Reset;
411 } else if flags.is_close_request() {
412 self.state = StreamState::RemoteClosing;
413 } else if flags.is_close_response() {
414 self.state = StreamState::Closed;
415 }
416 }
417
418 pub fn split(self) -> (LogicalStreamWriter, LogicalStreamReader) {
420 let state = Arc::new(Mutex::new(self.state));
421 let writer = LogicalStreamWriter {
422 session_id: self.session_id,
423 stream_id: self.stream_id,
424 state: Arc::clone(&state),
425 frame_sender: self.frame_sender,
426 };
427 let reader = LogicalStreamReader {
428 session_id: self.session_id,
429 stream_id: self.stream_id,
430 state,
431 frame_receiver: self.frame_receiver,
432 };
433 (writer, reader)
434 }
435
436 pub fn set_state(&mut self, state: StreamState) {
437 self.state = state;
438 }
439
440 pub fn stream_id(&self) -> StreamId {
441 self.stream_id
442 }
443
444 pub fn state(&self) -> StreamState {
445 self.state
446 }
447
448 pub fn session_id(&self) -> SessionId {
449 self.session_id
450 }
451}