count_write/tokio.rs
1// count-write
2// Copyright (C) SOFe
3//
4// Licensed under the Apache License, Version 2.0 (the License);
5// you may not use this file except in compliance with the License.
6// You may obtain a copy of the License at
7//
8// http://www.apache.org/licenses/LICENSE-2.0
9//
10// Unless required by applicable law or agreed to in writing, software
11// distributed under the License is distributed on an AS IS BASIS,
12// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13// See the License for the specific language governing permissions and
14// limitations under the License.
15
16use core::pin::Pin;
17use core::task::{Context, Poll};
18
19use tokio_io::AsyncWrite;
20
21use crate::{CountWrite, Result};
22
23/// Wrapper for `tokio_io::AsyncWrite`, used in the `tokio` family
24///
25/// *Only available with the `"tokio"` feature*
26impl<W: AsyncWrite> AsyncWrite for CountWrite<W> {
27 fn poll_write(self: Pin<&mut Self>, ctx: &mut Context, buf: &[u8]) -> Poll<Result<usize>> {
28 let Self { inner, count } = unsafe { self.get_unchecked_mut() };
29 let pin = unsafe { Pin::new_unchecked(inner) };
30 let ret = pin.poll_write(ctx, buf);
31 if let Poll::Ready(ret) = &ret {
32 if let Ok(written) = &ret {
33 *count += *written as u64;
34 }
35 }
36 ret
37 }
38
39 fn poll_flush(self: Pin<&mut Self>, ctx: &mut Context) -> Poll<Result> {
40 unsafe { self.map_unchecked_mut(|cw| &mut cw.inner) }.poll_flush(ctx)
41 }
42
43 fn poll_shutdown(self: Pin<&mut Self>, ctx: &mut Context) -> Poll<Result> {
44 unsafe { self.map_unchecked_mut(|cw| &mut cw.inner) }.poll_shutdown(ctx)
45 }
46}