hyper_inspect_io/
lib.rs

1use hyper::rt::{Read, ReadBuf, ReadBufCursor, Write};
2use std::cmp;
3use std::io::{self, IoSlice};
4use std::pin::Pin;
5use std::task::{ready, Context, Poll};
6
/// Observer for data read through the `Read` side of an [`Io`] wrapper.
///
/// The hook is invoked only when a read completes; a `Poll::Pending` result is
/// never reported. The default implementation ignores everything.
pub trait InspectRead {
    /// Receives the bytes produced by one completed read, or the error it returned.
    fn inspect_read(&mut self, _value: Result<&[u8], &io::Error>) {
        // Default: observe nothing.
    }
}
10
/// Observer for operations performed through the `Write` side of an [`Io`] wrapper.
///
/// Every method has a no-op default, so implementors override only the hooks
/// they care about. Hooks are invoked only when the corresponding poll
/// completes; a `Poll::Pending` result is never reported.
pub trait InspectWrite {
    /// Receives the bytes accepted by a completed write, or its error.
    fn inspect_write(&mut self, _value: Result<&[u8], &io::Error>) {}

    /// Receives the outcome of a completed flush.
    fn inspect_flush(&mut self, _value: Result<(), &io::Error>) {}

    /// Receives the outcome of a completed shutdown.
    fn inspect_shutdown(&mut self, _value: Result<(), &io::Error>) {}

    /// Receives the written pieces of a completed vectored write, or its error.
    ///
    /// The default forwards each piece, in order, to [`inspect_write`] and hands
    /// an error to it unchanged. An implementor interested only in the byte
    /// stream can therefore override just that one method.
    ///
    /// [`inspect_write`]: InspectWrite::inspect_write
    fn inspect_write_vectored<'a, I>(&mut self, value: Result<I, &io::Error>)
    where
        I: Iterator<Item = &'a [u8]>,
    {
        match value {
            Err(err) => self.inspect_write(Err(err)),
            Ok(chunks) => chunks.for_each(|chunk| self.inspect_write(Ok(chunk))),
        }
    }
}
29
/// An I/O wrapper that forwards every operation to `inner` and reports the
/// outcome to `inspect`.
///
/// Reads are reported through [`InspectRead`]; writes, flushes and shutdowns
/// are reported through [`InspectWrite`].
#[pin_project::pin_project]
#[derive(Clone, Debug)]
pub struct Io<T, I> {
    /// The wrapped transport. It is structurally pinned so it can be polled
    /// through the pin projection of `Pin<&mut Io<T, I>>`.
    #[pin]
    inner: T,
    /// The observer notified of each completed operation (not pinned).
    inspect: I,
}
37
38impl<T, I> Io<T, I> {
39    pub fn new(inner: T, inspect: I) -> Self {
40        Self { inner, inspect }
41    }
42}
43
44impl<T, I> Read for Io<T, I>
45where
46    T: Read,
47    I: InspectRead,
48{
49    fn poll_read(
50        self: Pin<&mut Self>,
51        cx: &mut Context<'_>,
52        mut buf: ReadBufCursor<'_>,
53    ) -> Poll<Result<(), io::Error>> {
54        let this = self.project();
55        unsafe {
56            let len = {
57                let mut buf = ReadBuf::uninit(buf.as_mut());
58                let value = ready!(this.inner.poll_read(cx, buf.unfilled()));
59                this.inspect
60                    .inspect_read(value.as_ref().map(|_| buf.filled()));
61                value.map(|_| buf.filled().len())?
62            };
63            buf.advance(len);
64        }
65        Poll::Ready(Ok(()))
66    }
67}
68
69impl<T, I> Write for Io<T, I>
70where
71    T: Write,
72    I: InspectWrite,
73{
74    fn poll_write(
75        self: Pin<&mut Self>,
76        cx: &mut Context<'_>,
77        buf: &[u8],
78    ) -> Poll<io::Result<usize>> {
79        let this = self.project();
80        this.inner.poll_write(cx, buf).map(|value| {
81            this.inspect
82                .inspect_write(value.as_ref().map(|len| &buf[..*len]));
83            value
84        })
85    }
86
87    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
88        let this = self.project();
89        this.inner.poll_flush(cx).map(|value| {
90            this.inspect.inspect_flush(value.as_ref().map(|_| ()));
91            value
92        })
93    }
94
95    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
96        let this = self.project();
97        this.inner.poll_shutdown(cx).map(|value| {
98            this.inspect.inspect_shutdown(value.as_ref().map(|_| ()));
99            value
100        })
101    }
102
103    fn is_write_vectored(&self) -> bool {
104        self.inner.is_write_vectored()
105    }
106
107    fn poll_write_vectored(
108        self: Pin<&mut Self>,
109        cx: &mut Context<'_>,
110        bufs: &[IoSlice<'_>],
111    ) -> Poll<Result<usize, io::Error>> {
112        let this = self.project();
113        this.inner.poll_write_vectored(cx, bufs).map(|value| {
114            this.inspect
115                .inspect_write_vectored(value.as_ref().map(|len| {
116                    bufs.iter().scan(*len, |len, buf| {
117                        let buf = &buf[..cmp::min(*len, buf.len())];
118                        *len -= buf.len();
119                        (!buf.is_empty()).then_some(buf)
120                    })
121                }));
122            value
123        })
124    }
125}
126
#[cfg(feature = "hyper-util")]
impl<T, I> hyper_util::client::legacy::connect::Connection for Io<T, I>
where
    T: hyper_util::client::legacy::connect::Connection,
{
    /// Returns the wrapped transport's connection metadata; the inspector
    /// plays no part in it.
    fn connected(&self) -> hyper_util::client::legacy::connect::Connected {
        T::connected(&self.inner)
    }
}
136
137#[cfg(feature = "__examples")]
138pub mod __examples;
139#[cfg(test)]
140mod tests;