1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
//! Primitives for asynchronous writes.

use std::{
    future::Future,
    io::{ErrorKind, Result},
};

/// Writes some bytes into an object asynchronously.
///
/// The concrete future type is chosen by the implementor through the generic
/// associated type `Write<'a>`. Writing takes `&mut self`, so the writer is
/// exclusively borrowed for as long as the returned future is alive.
pub trait Write {
    /// A future that resolves to the result of [`Self::write`].
    ///
    /// On success the output is the number of bytes written. The future may
    /// hold borrows of the writer and of the input buffer for `'a`.
    type Write<'a>: Future<Output = Result<usize>> + 'a
    where
        Self: 'a;

    /// Writes some bytes from `buf` into this object.
    ///
    /// Returns the number of bytes written. The count may be smaller than
    /// `buf.len()`; `WriteExt::write_all` in this module calls this method
    /// repeatedly until the whole buffer has been written.
    fn write<'a>(&'a mut self, buf: &'a [u8]) -> Self::Write<'a>;
}

/// Provides extension methods for [`Write`].
///
/// This module implements the trait for every `T: Write` through a blanket
/// impl, so writers only need to implement [`Write`] itself.
pub trait WriteExt {
    /// A future that resolves to the result of [`Self::write_all`].
    ///
    /// The output carries no byte count: it is `Ok(())` on success or the
    /// error that stopped the operation.
    type WriteAll<'a>: Future<Output = Result<()>> + 'a
    where
        Self: 'a;

    /// Writes all bytes from `buf` into this object.
    ///
    /// The returned future resolves to `Ok(())` once the entire buffer has
    /// been written, or to an error if the write could not be completed.
    fn write_all<'a>(&'a mut self, buf: &'a [u8]) -> Self::WriteAll<'a>;
}

/// Blanket implementation: every [`Write`] automatically gains [`WriteExt`].
impl<T> WriteExt for T
where
    T: Write,
{
    /// Opaque future produced by the `async` block in `write_all`.
    type WriteAll<'a> = impl Future<Output = Result<()>> + 'a
    where
        Self: 'a;

    /// Repeatedly calls [`Write::write`] until `buf` has been fully written.
    ///
    /// # Errors
    ///
    /// * Resolves to an [`ErrorKind::WriteZero`] error if the writer reports
    ///   `Ok(0)` while bytes are still pending.
    /// * [`ErrorKind::Interrupted`] errors are swallowed and the write is
    ///   retried with the same remaining bytes.
    /// * Any other error is returned unchanged; the number of bytes written
    ///   before the failure is not reported.
    ///
    /// # Panics
    ///
    /// Panics if the writer reports more bytes written than it was given,
    /// because the remaining slice is re-sliced with that count.
    fn write_all<'a>(&'a mut self, mut buf: &'a [u8]) -> Self::WriteAll<'a> {
        async move {
            loop {
                // Nothing left to send (also covers an initially empty buffer).
                if buf.is_empty() {
                    return Ok(());
                }

                let written = match self.write(buf).await {
                    Ok(count) => count,
                    // An interrupted write made no progress; try again as-is.
                    Err(err) if err.kind() == ErrorKind::Interrupted => continue,
                    Err(err) => return Err(err),
                };

                // A zero-length write with data pending would loop forever.
                if written == 0 {
                    return Err(ErrorKind::WriteZero.into());
                }

                // Drop the prefix that has been accepted by the writer.
                buf = &buf[written..];
            }
        }
    }
}

/// Writes some bytes into an object at a given position.
///
/// Unlike [`Write`], the method takes `&self` and an explicit position, so
/// the trait itself keeps no cursor and does not require exclusive access to
/// the object.
pub trait WriteAt {
    /// A future that resolves to the result of [`Self::write_at`].
    ///
    /// On success the output is the number of bytes written. The future may
    /// hold borrows of the object and of the input buffer for `'a`.
    type WriteAt<'a>: Future<Output = Result<usize>> + 'a
    where
        Self: 'a;

    /// Writes some bytes from `buf` into this object at `pos`.
    ///
    /// Returns the number of bytes written. The count may be smaller than
    /// `buf.len()`; `WriteAtExt::write_all_at` in this module calls this
    /// method repeatedly, advancing `pos` by the returned count each time,
    /// until the whole buffer has been written.
    fn write_at<'a>(&'a self, buf: &'a [u8], pos: u64) -> Self::WriteAt<'a>;
}

/// Provides extension methods for [`WriteAt`].
///
/// This module implements the trait for every `T: WriteAt` through a blanket
/// impl, so implementors only need to provide [`WriteAt`] itself.
pub trait WriteAtExt {
    /// A future that resolves to the result of [`Self::write_all_at`].
    ///
    /// The output carries no byte count: it is `Ok(())` on success or the
    /// error that stopped the operation.
    type WriteAllAt<'a>: Future<Output = Result<()>> + 'a
    where
        Self: 'a;

    /// Writes all bytes from `buf` into this object at `pos`.
    ///
    /// The returned future resolves to `Ok(())` once the entire buffer has
    /// been written starting at `pos`, or to an error if the write could not
    /// be completed.
    fn write_all_at<'a>(&'a self, buf: &'a [u8], pos: u64) -> Self::WriteAllAt<'a>;
}

/// Blanket implementation: every [`WriteAt`] automatically gains
/// [`WriteAtExt`].
impl<T> WriteAtExt for T
where
    T: WriteAt,
{
    /// Opaque future produced by the `async` block in `write_all_at`.
    type WriteAllAt<'a> = impl Future<Output = Result<()>> + 'a
    where
        Self: 'a;

    /// Repeatedly calls [`WriteAt::write_at`] until `buf` has been fully
    /// written, moving `pos` forward by the bytes written on each step.
    ///
    /// # Errors
    ///
    /// * Resolves to an [`ErrorKind::WriteZero`] error if the object reports
    ///   `Ok(0)` while bytes are still pending.
    /// * [`ErrorKind::Interrupted`] errors are swallowed and the write is
    ///   retried with the same remaining bytes and position.
    /// * Any other error is returned unchanged; the number of bytes written
    ///   before the failure is not reported.
    ///
    /// # Panics
    ///
    /// Panics if the object reports more bytes written than it was given,
    /// because the remaining slice is re-sliced with that count.
    fn write_all_at<'a>(&'a self, mut buf: &'a [u8], mut pos: u64) -> Self::WriteAllAt<'a> {
        async move {
            loop {
                // Nothing left to send (also covers an initially empty buffer).
                if buf.is_empty() {
                    return Ok(());
                }

                let written = match self.write_at(buf, pos).await {
                    Ok(count) => count,
                    // An interrupted write made no progress; try again as-is.
                    Err(err) if err.kind() == ErrorKind::Interrupted => continue,
                    Err(err) => return Err(err),
                };

                // A zero-length write with data pending would loop forever.
                if written == 0 {
                    return Err(ErrorKind::WriteZero.into());
                }

                // Advance the slice and the target position in lock-step so
                // the next attempt continues right after the written bytes.
                buf = &buf[written..];
                pos += written as u64;
            }
        }
    }
}