count_write/
tokio.rs

1// count-write
2// Copyright (C) SOFe
3//
4// Licensed under the Apache License, Version 2.0 (the License);
5// you may not use this file except in compliance with the License.
6// You may obtain a copy of the License at
7//
8//     http://www.apache.org/licenses/LICENSE-2.0
9//
10// Unless required by applicable law or agreed to in writing, software
11// distributed under the License is distributed on an AS IS BASIS,
12// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13// See the License for the specific language governing permissions and
14// limitations under the License.
15
16use core::pin::Pin;
17use core::task::{Context, Poll};
18
19use tokio_io::AsyncWrite;
20
21use crate::{CountWrite, Result};
22
23/// Wrapper for `tokio_io::AsyncWrite`, used in the `tokio` family
24///
25/// *Only available with the `"tokio"` feature*
26impl<W: AsyncWrite> AsyncWrite for CountWrite<W> {
27    fn poll_write(self: Pin<&mut Self>, ctx: &mut Context, buf: &[u8]) -> Poll<Result<usize>> {
28        let Self { inner, count } = unsafe { self.get_unchecked_mut() };
29        let pin = unsafe { Pin::new_unchecked(inner) };
30        let ret = pin.poll_write(ctx, buf);
31        if let Poll::Ready(ret) = &ret {
32            if let Ok(written) = &ret {
33                *count += *written as u64;
34            }
35        }
36        ret
37    }
38
39    fn poll_flush(self: Pin<&mut Self>, ctx: &mut Context) -> Poll<Result> {
40        unsafe { self.map_unchecked_mut(|cw| &mut cw.inner) }.poll_flush(ctx)
41    }
42
43    fn poll_shutdown(self: Pin<&mut Self>, ctx: &mut Context) -> Poll<Result> {
44        unsafe { self.map_unchecked_mut(|cw| &mut cw.inner) }.poll_shutdown(ctx)
45    }
46}