count_write/
futures.rs

1// count-write
2// Copyright (C) SOFe
3//
// Licensed under the Apache License, Version 2.0 (the "License");
5// you may not use this file except in compliance with the License.
6// You may obtain a copy of the License at
7//
8//     http://www.apache.org/licenses/LICENSE-2.0
9//
10// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
12// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13// See the License for the specific language governing permissions and
14// limitations under the License.
15
16use core::pin::Pin;
17use core::task::{Context, Poll};
18
19use futures_io::AsyncWrite;
20
21use crate::{CountWrite, Result};
22
#[cfg(feature = "futures")]
/// Byte-counting wrapper for `futures_io::AsyncWrite`, used in the `futures-preview` family.
///
/// Every call is forwarded to the inner writer. The byte count returned by each
/// successful `poll_write` is added to the running total. Polls that return
/// `Poll::Pending` or an error leave the total unchanged.
///
/// *Only available with the `"futures"` feature*
impl<W: AsyncWrite + Unpin> AsyncWrite for CountWrite<W> {
    fn poll_write(self: Pin<&mut Self>, ctx: &mut Context<'_>, buf: &[u8]) -> Poll<Result<usize>> {
        // `CountWrite<W>` is `Unpin` whenever `W` is, because its only other field is a
        // plain counter. The pin can therefore be unwrapped safely, with no `unsafe` needed.
        let this = self.get_mut();
        let ret = Pin::new(&mut this.inner).poll_write(ctx, buf);
        // Only bytes the inner writer actually accepted are counted.
        if let Poll::Ready(Ok(written)) = &ret {
            this.count += *written as u64;
        }
        ret
    }

    fn poll_flush(self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<Result> {
        Pin::new(&mut self.get_mut().inner).poll_flush(ctx)
    }

    fn poll_close(self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<Result> {
        Pin::new(&mut self.get_mut().inner).poll_close(ctx)
    }
}