1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
use tokio_util::sync::CancellationToken;

use std::io::{Error as IoError, ErrorKind, IoSlice, Result as IoResult};
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};

use pin_project::pin_project;
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};

use crate::conn::HttpBuilder;
use crate::fuse::{ArcFusewire, FuseEvent};
use crate::http::HttpConnection;
use crate::service::HyperHandler;

/// A stream that can be fused.
#[pin_project]
pub struct StraightStream<C> {
    #[pin]
    inner: C,
    fusewire: Option<ArcFusewire>,
}

impl<C> StraightStream<C>
where
    C: AsyncRead + AsyncWrite + Unpin + Send + 'static,
{
    /// Create a new `StraightStream`.
    pub fn new(inner: C, fusewire: Option<ArcFusewire>) -> Self {
        Self { inner, fusewire }
    }
}

impl<C> HttpConnection for StraightStream<C>
where
    C: AsyncRead + AsyncWrite + Unpin + Send + 'static,
{
    async fn serve(
        self,
        handler: HyperHandler,
        builder: Arc<HttpBuilder>,
        graceful_stop_token: CancellationToken,
    ) -> std::io::Result<()> {
        let fusewire = self.fusewire.clone();
        if let Some(fusewire) = &fusewire {
            fusewire.event(FuseEvent::Alive);
        }
        builder
            .serve_connection(self, handler, fusewire, graceful_stop_token)
            .await
            .map_err(|e| IoError::new(ErrorKind::Other, e.to_string()))
    }
    fn fusewire(&self) -> Option<ArcFusewire> {
        self.fusewire.clone()
    }
}

impl<C> AsyncRead for StraightStream<C>
where
    C: AsyncRead,
{
    fn poll_read(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>) -> Poll<IoResult<()>> {
        let this = self.project();
        let remaining = buf.remaining();
        match this.inner.poll_read(cx, buf) {
            Poll::Ready(Ok(())) => {
                if let Some(fusewire) = &this.fusewire {
                    fusewire.event(FuseEvent::ReadData(remaining - buf.remaining()));
                }
                Poll::Ready(Ok(()))
            }
            Poll::Ready(Err(e)) => Poll::Ready(Err(e)),
            Poll::Pending => {
                if let Some(fusewire) = &this.fusewire {
                    fusewire.event(FuseEvent::Alive);
                }
                Poll::Pending
            }
        }
    }
}

impl<C> AsyncWrite for StraightStream<C>
where
    C: AsyncWrite,
{
    fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<IoResult<usize>> {
        let this = self.project();
        match this.inner.poll_write(cx, buf) {
            Poll::Ready(Ok(len)) => {
                if let Some(fusewire) = &this.fusewire {
                    fusewire.event(FuseEvent::WriteData(len));
                }
                Poll::Ready(Ok(len))
            }
            Poll::Ready(Err(e)) => Poll::Ready(Err(e)),
            Poll::Pending => {
                if let Some(fusewire) = &this.fusewire {
                    fusewire.event(FuseEvent::Alive);
                }
                Poll::Pending
            }
        }
    }

    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<IoResult<()>> {
        let this = self.project();
        if let Some(fusewire) = &this.fusewire {
            fusewire.event(FuseEvent::Alive);
        }
        this.inner.poll_flush(cx)
    }

    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<IoResult<()>> {
        let this = self.project();
        if let Some(fusewire) = &this.fusewire {
            fusewire.event(FuseEvent::Alive);
        }
        this.inner.poll_shutdown(cx)
    }

    fn poll_write_vectored(self: Pin<&mut Self>, cx: &mut Context<'_>, bufs: &[IoSlice<'_>]) -> Poll<IoResult<usize>> {
        let this = self.project();
        if let Some(fusewire) = &this.fusewire {
            fusewire.event(FuseEvent::Alive);
        }
        this.inner.poll_write_vectored(cx, bufs)
    }

    fn is_write_vectored(&self) -> bool {
        self.inner.is_write_vectored()
    }
}