chitey_server/server/http3_stream_wrapper.rs

1use bytes::{Bytes, Buf};
2use h3::{quic::BidiStream, server::RequestStream};
3use core::pin::Pin;
4use std::{task::{Context, Poll}, sync::{Mutex, Arc}};
5use futures_util::{Future, Stream};
6
7pub struct StreamWrapper<W> 
8where
9W: BidiStream<Bytes> + 'static + Send + Sync
10{
11    inner: Arc<Mutex<RequestStream<W::RecvStream, Bytes>>>,
12    inner_read: Arc<Mutex<Pin<Box<dyn Future<Output = Option<Result<Bytes, h3::Error>>>>>>>,
13}
14
15impl<W> StreamWrapper<W>
16where
17    W: BidiStream<Bytes> + 'static + Send + Sync
18{
19    #[inline]
20    pub fn new(inner: RequestStream<W::RecvStream, Bytes>) -> Self {
21        let inner = Arc::new(Mutex::new(inner));
22        Self {
23            inner: inner.clone(),
24            inner_read: Arc::new(Mutex::new(Box::pin(Self::recv_data_wrap(inner)))),
25        }
26    }
27
28    #[inline]
29    async fn recv_data_wrap(inner: Arc<Mutex<RequestStream<W::RecvStream, Bytes>>>) -> Option<Result<Bytes, h3::Error>>
30    where
31        W: BidiStream<Bytes> + 'static + Send + Sync
32    {
33        match inner.lock().unwrap().recv_data().await {
34        Ok(v) => {
35            match v {
36                Some(data) => Some(Ok(Bytes::copy_from_slice(data.chunk()))),
37                None => {None}
38            }
39        },
40        Err(e) => {Some(Err(e))}
41        }
42    }
43}
44
45impl<W> Stream for StreamWrapper<W>
46where
47W: BidiStream<Bytes> + 'static + Send + Sync
48{
49    type Item = Result<Bytes, h3::Error>;
50
51    #[inline]
52    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>)
53        -> Poll<Option<Self::Item>>
54    {
55        let mut p = self.inner_read.lock().unwrap();
56        match Pin::new(&mut *p).poll(cx) {
57            Poll::Ready(bytes) => {
58                *p = Box::pin(Self::recv_data_wrap(self.inner.clone()));
59                Poll::Ready(bytes)
60            }
61            Poll::Pending => Poll::Pending,
62        }
63    }
64}
65
66unsafe impl<W> Send for StreamWrapper<W>
67where
68W: BidiStream<Bytes> + 'static + Send + Sync { }