1use hyper::rt::{Read, ReadBuf, ReadBufCursor, Write};
2use std::cmp;
3use std::io::{self, IoSlice};
4use std::pin::Pin;
5use std::task::{ready, Context, Poll};
6
/// Hook for observing the outcome of read operations.
///
/// Implementors override [`InspectRead::inspect_read`] to look at the bytes
/// that were read, or at the error that was produced instead.
pub trait InspectRead {
    /// Called with the outcome of a completed read.
    ///
    /// `Ok` carries the bytes produced by the read; `Err` carries a reference
    /// to the I/O error. The default implementation ignores the value.
    fn inspect_read(&mut self, _value: Result<&[u8], &io::Error>) {
        // Intentionally a no-op: observers opt in by overriding this method.
    }
}
10
/// Hooks for observing the outcome of write-side operations.
///
/// Every method has a default implementation, so implementors override only
/// the events they care about.
pub trait InspectWrite {
    /// Called with the outcome of a write: the written bytes on success, or
    /// a reference to the error. Does nothing by default.
    fn inspect_write(&mut self, _value: Result<&[u8], &io::Error>) {}

    /// Called with the outcome of a flush. Does nothing by default.
    fn inspect_flush(&mut self, _value: Result<(), &io::Error>) {}

    /// Called with the outcome of a shutdown. Does nothing by default.
    fn inspect_shutdown(&mut self, _value: Result<(), &io::Error>) {}

    /// Called with the outcome of a vectored write.
    ///
    /// On success `value` yields the written chunks in order. The default
    /// implementation forwards each chunk, or the error, to
    /// [`InspectWrite::inspect_write`].
    fn inspect_write_vectored<'a, I>(&mut self, value: Result<I, &io::Error>)
    where
        I: Iterator<Item = &'a [u8]>,
    {
        let chunks = match value {
            Ok(chunks) => chunks,
            // A failed vectored write is reported as a single failed write.
            Err(err) => return self.inspect_write(Err(err)),
        };

        for chunk in chunks {
            self.inspect_write(Ok(chunk));
        }
    }
}
29
/// I/O wrapper that pairs an inner I/O object with an inspector value.
///
/// The `Read` and `Write` implementations in this file forward every
/// operation to `inner` and report the outcome to `inspect`.
#[pin_project::pin_project]
#[derive(Clone, Debug)]
pub struct Io<T, I> {
    /// The wrapped I/O object; marked `#[pin]` so it is projected as a
    /// pinned reference.
    #[pin]
    inner: T,
    /// The inspector that is notified of operation outcomes.
    inspect: I,
}
37
38impl<T, I> Io<T, I> {
39 pub fn new(inner: T, inspect: I) -> Self {
40 Self { inner, inspect }
41 }
42}
43
44impl<T, I> Read for Io<T, I>
45where
46 T: Read,
47 I: InspectRead,
48{
49 fn poll_read(
50 self: Pin<&mut Self>,
51 cx: &mut Context<'_>,
52 mut buf: ReadBufCursor<'_>,
53 ) -> Poll<Result<(), io::Error>> {
54 let this = self.project();
55 unsafe {
56 let len = {
57 let mut buf = ReadBuf::uninit(buf.as_mut());
58 let value = ready!(this.inner.poll_read(cx, buf.unfilled()));
59 this.inspect
60 .inspect_read(value.as_ref().map(|_| buf.filled()));
61 value.map(|_| buf.filled().len())?
62 };
63 buf.advance(len);
64 }
65 Poll::Ready(Ok(()))
66 }
67}
68
69impl<T, I> Write for Io<T, I>
70where
71 T: Write,
72 I: InspectWrite,
73{
74 fn poll_write(
75 self: Pin<&mut Self>,
76 cx: &mut Context<'_>,
77 buf: &[u8],
78 ) -> Poll<io::Result<usize>> {
79 let this = self.project();
80 this.inner.poll_write(cx, buf).map(|value| {
81 this.inspect
82 .inspect_write(value.as_ref().map(|len| &buf[..*len]));
83 value
84 })
85 }
86
87 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
88 let this = self.project();
89 this.inner.poll_flush(cx).map(|value| {
90 this.inspect.inspect_flush(value.as_ref().map(|_| ()));
91 value
92 })
93 }
94
95 fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
96 let this = self.project();
97 this.inner.poll_shutdown(cx).map(|value| {
98 this.inspect.inspect_shutdown(value.as_ref().map(|_| ()));
99 value
100 })
101 }
102
103 fn is_write_vectored(&self) -> bool {
104 self.inner.is_write_vectored()
105 }
106
107 fn poll_write_vectored(
108 self: Pin<&mut Self>,
109 cx: &mut Context<'_>,
110 bufs: &[IoSlice<'_>],
111 ) -> Poll<Result<usize, io::Error>> {
112 let this = self.project();
113 this.inner.poll_write_vectored(cx, bufs).map(|value| {
114 this.inspect
115 .inspect_write_vectored(value.as_ref().map(|len| {
116 bufs.iter().scan(*len, |len, buf| {
117 let buf = &buf[..cmp::min(*len, buf.len())];
118 *len -= buf.len();
119 (!buf.is_empty()).then_some(buf)
120 })
121 }));
122 value
123 })
124 }
125}
126
/// Exposes the inner connection's metadata when the `hyper-util` feature is
/// enabled; the inspector plays no part here.
#[cfg(feature = "hyper-util")]
impl<T, I> hyper_util::client::legacy::connect::Connection for Io<T, I>
where
    T: hyper_util::client::legacy::connect::Connection,
{
    /// Returns whatever the wrapped connection reports.
    fn connected(&self) -> hyper_util::client::legacy::connect::Connected {
        T::connected(&self.inner)
    }
}
136
// Public `__examples` module, compiled only when the `__examples` feature is enabled.
#[cfg(feature = "__examples")]
pub mod __examples;
// Private `tests` module, compiled only in test builds.
#[cfg(test)]
mod tests;