1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
use std::{
    io::{Error, ErrorKind},
    os::unix::{
        io::{AsRawFd, RawFd},
        net::UnixStream as OsUnixStream,
        prelude::{FromRawFd, IntoRawFd},
    },
};

use tokio::{
    io::Interest,
    net::{
        unix::{OwnedReadHalf, OwnedWriteHalf, ReadHalf, WriteHalf},
        UnixStream,
    },
};

use passfd::FdPassingExt;

use crate::{AsyncRecvFd, AsyncSendFd};

/// A trait to send a Tokio [`UnixStream`] to the peer of a Unix socket.
///
/// Unlike `AsyncSendFd`, which takes a bare `RawFd`, this trait takes the
/// stream to transfer by value.
pub trait AsyncSendTokioStream {
    /// Sends the given stream over `self`.
    ///
    /// The parameter is named `fd` here, but it is a whole [`UnixStream`]
    /// passed by value, not a raw descriptor.
    ///
    /// # Errors
    ///
    /// The returned future resolves to an [`Error`] when the transfer fails.
    fn send_stream(
        &self,
        fd: UnixStream,
    ) -> impl std::future::Future<Output = Result<(), Error>> + Send;
}

/// A trait to receive a Tokio [`UnixStream`] from the peer of a Unix socket.
///
/// Unlike `AsyncRecvFd`, which yields a bare `RawFd`, this trait yields a
/// ready-to-use [`UnixStream`].
pub trait AsyncRecvTokioStream {
    /// Receives a stream over `self`.
    ///
    /// # Errors
    ///
    /// The returned future resolves to an [`Error`] when the transfer fails.
    fn recv_stream(&self) -> impl std::future::Future<Output = Result<UnixStream, Error>> + Send;
}

impl AsyncRecvFd for UnixStream {
    /// Waits until the socket is readable and then tries to receive one file
    /// descriptor from it, retrying whenever the attempt reports
    /// `WouldBlock`.
    ///
    /// # Errors
    ///
    /// Returns any error from waiting for readiness, or any error other than
    /// `WouldBlock` from the receive attempt itself.
    async fn recv_fd(&self) -> Result<RawFd, Error> {
        // The descriptor backing `self` does not change between attempts.
        let raw = self.as_raw_fd();

        loop {
            self.readable().await?;

            // `try_io` runs the closure against the raw descriptor while
            // keeping Tokio's readiness tracking informed of the outcome.
            let attempt = self.try_io(Interest::READABLE, || raw.recv_fd());
            let would_block =
                matches!(&attempt, Err(err) if err.kind() == ErrorKind::WouldBlock);

            if !would_block {
                return attempt;
            }
            // Readiness was stale: wait for the socket again.
        }
    }
}

impl AsyncSendFd for UnixStream {
    /// Waits until the socket is writable and then tries to send `fd` over
    /// it, retrying whenever the attempt reports `WouldBlock`.
    ///
    /// # Errors
    ///
    /// Returns any error from waiting for readiness, or any error other than
    /// `WouldBlock` from the send attempt itself.
    async fn send_fd(&self, fd: RawFd) -> Result<(), Error> {
        // The descriptor backing `self` does not change between attempts.
        let raw = self.as_raw_fd();

        loop {
            self.writable().await?;

            // `try_io` runs the closure against the raw descriptor while
            // keeping Tokio's readiness tracking informed of the outcome.
            let attempt = self.try_io(Interest::WRITABLE, || raw.send_fd(fd));
            let would_block =
                matches!(&attempt, Err(err) if err.kind() == ErrorKind::WouldBlock);

            if !would_block {
                return attempt;
            }
            // Readiness was stale: wait for the socket again.
        }
    }
}

impl AsyncSendTokioStream for UnixStream {
    async fn send_stream(&self, stream: UnixStream) -> Result<(), Error> {
        let fd = stream.into_std()?.into_raw_fd();

        self.send_fd(fd).await
    }
}

impl AsyncRecvTokioStream for UnixStream {
    /// Receives a file descriptor over `self` and wraps it in a Tokio
    /// [`UnixStream`].
    ///
    /// The received descriptor is switched to non-blocking mode before it is
    /// handed to Tokio.
    ///
    /// # Errors
    ///
    /// Returns an error if receiving the descriptor fails, if it cannot be
    /// put into non-blocking mode, or if Tokio cannot register it. In the
    /// latter two cases the received descriptor is closed.
    async fn recv_stream(&self) -> Result<UnixStream, Error> {
        let fd = self.recv_fd().await?;

        // SAFETY: assumes `recv_fd` yields a freshly received, open descriptor
        // that nothing else in this process owns, so taking ownership here
        // cannot cause a double close.
        // NOTE(review): nothing checks that the peer really sent a Unix
        // stream socket; confirm callers trust the sender.
        let os_stream = unsafe { OsUnixStream::from_raw_fd(fd) };

        // `UnixStream::from_std` requires a non-blocking socket; the mode of
        // a descriptor received from another process is whatever the sender
        // left it in, so set it explicitly.
        os_stream.set_nonblocking(true)?;

        UnixStream::from_std(os_stream)
    }
}

impl AsyncRecvFd for ReadHalf<'_> {
    async fn recv_fd(&self) -> Result<RawFd, Error> {
        self.as_ref().recv_fd().await
    }
}

impl AsyncRecvTokioStream for ReadHalf<'_> {
    /// Receives a stream by delegating to the [`UnixStream`] this borrowed
    /// read half refers to.
    async fn recv_stream(&self) -> Result<UnixStream, Error> {
        let stream: &UnixStream = self.as_ref();
        stream.recv_stream().await
    }
}

impl AsyncSendFd for WriteHalf<'_> {
    async fn send_fd(&self, fd: RawFd) -> Result<(), Error> {
        self.as_ref().send_fd(fd).await
    }
}

impl AsyncSendTokioStream for WriteHalf<'_> {
    /// Sends `stream` by delegating to the [`UnixStream`] this borrowed write
    /// half refers to.
    async fn send_stream(&self, stream: UnixStream) -> Result<(), Error> {
        let carrier: &UnixStream = self.as_ref();
        carrier.send_stream(stream).await
    }
}

impl AsyncRecvFd for OwnedReadHalf {
    async fn recv_fd(&self) -> Result<RawFd, Error> {
        self.as_ref().recv_fd().await
    }
}

impl AsyncRecvTokioStream for OwnedReadHalf {
    /// Receives a stream by delegating to the [`UnixStream`] underlying this
    /// owned read half.
    async fn recv_stream(&self) -> Result<UnixStream, Error> {
        let stream: &UnixStream = self.as_ref();
        stream.recv_stream().await
    }
}

impl AsyncSendFd for OwnedWriteHalf {
    async fn send_fd(&self, fd: RawFd) -> Result<(), Error> {
        self.as_ref().send_fd(fd).await
    }
}

impl AsyncSendTokioStream for OwnedWriteHalf {
    /// Sends `stream` by delegating to the [`UnixStream`] underlying this
    /// owned write half.
    async fn send_stream(&self, stream: UnixStream) -> Result<(), Error> {
        let carrier: &UnixStream = self.as_ref();
        carrier.send_stream(stream).await
    }
}