1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
//! Core stream type for braid providing [AsyncRead] and [AsyncWrite].

use std::pin::pin;

use pin_project::pin_project;
use tokio::io::{AsyncRead, AsyncWrite};

use crate::info::BraidAddr;
use crate::info::{ConnectionInfo, HasConnectionInfo};
use crate::stream::duplex::DuplexStream;
use crate::stream::tcp::TcpStream;
use crate::stream::unix::UnixStream;

/// Dispatching wrapper over the stream connection types braid supports.
///
/// This is a hand-rolled enum dispatch: [`AsyncRead`], [`AsyncWrite`] and
/// [`HasConnectionInfo`] are implemented for `Braid` by forwarding each call
/// to the stream held by the active variant.
///
/// Every variant's stream is structurally pinned (`#[pin]` via `pin_project`),
/// so a pinned `Braid` can hand out a pinned reference to its inner stream.
///
/// This core type is used in the server and client modules.
#[derive(Debug)]
#[pin_project(project = BraidCoreProjection)]
pub enum Braid {
    /// A TCP stream
    Tcp(#[pin] TcpStream),

    /// A duplex stream
    Duplex(#[pin] DuplexStream),

    /// A Unix stream
    Unix(#[pin] UnixStream),
}

impl HasConnectionInfo for Braid {
    type Addr = BraidAddr;

    /// Connection info for the wrapped transport, with its address lifted into
    /// the matching [`BraidAddr`] variant.
    fn info(&self) -> ConnectionInfo<BraidAddr> {
        match self {
            Self::Tcp(tcp) => tcp.info().map(BraidAddr::Tcp),
            Self::Unix(unix) => unix.info().map(BraidAddr::Unix),
            Self::Duplex(duplex) => {
                // Called through the trait explicitly (presumably to avoid picking up a
                // different `info` on `DuplexStream` — confirm). The duplex address
                // value is discarded and replaced with `BraidAddr::Duplex`.
                <DuplexStream as HasConnectionInfo>::info(duplex).map(|_| BraidAddr::Duplex)
            }
        }
    }
}

/// Forwards a pinned method call on a `Braid` to the stream held by its
/// active variant.
///
/// `dispatch_core!(self.poll_read(cx, buf))` projects the pinned receiver via
/// `pin_project` and calls `poll_read(cx, buf)` on the pinned inner stream.
/// At least one argument must be supplied.
macro_rules! dispatch_core {
    ($this:ident . $method:ident ( $($arg:expr),+ )) => {
        match $this.project() {
            BraidCoreProjection::Tcp(inner) => inner.$method($($arg),+),
            BraidCoreProjection::Duplex(inner) => inner.$method($($arg),+),
            BraidCoreProjection::Unix(inner) => inner.$method($($arg),+),
        }
    };
}

impl AsyncRead for Braid {
    /// Reads into `dst` by delegating to the wrapped transport.
    fn poll_read(
        self: std::pin::Pin<&mut Self>,
        ctx: &mut std::task::Context<'_>,
        dst: &mut tokio::io::ReadBuf<'_>,
    ) -> std::task::Poll<Result<(), std::io::Error>> {
        dispatch_core!(self.poll_read(ctx, dst))
    }
}

impl AsyncWrite for Braid {
    fn poll_write(
        self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
        buf: &[u8],
    ) -> std::task::Poll<Result<usize, std::io::Error>> {
        dispatch_core!(self.poll_write(cx, buf))
    }

    fn poll_flush(
        self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Result<(), std::io::Error>> {
        dispatch_core!(self.poll_flush(cx))
    }

    fn poll_shutdown(
        self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Result<(), std::io::Error>> {
        dispatch_core!(self.poll_shutdown(cx))
    }
}

impl From<TcpStream> for Braid {
    fn from(stream: TcpStream) -> Self {
        Self::Tcp(stream)
    }
}

impl From<DuplexStream> for Braid {
    fn from(stream: DuplexStream) -> Self {
        Self::Duplex(stream)
    }
}

impl From<UnixStream> for Braid {
    fn from(stream: UnixStream) -> Self {
        Self::Unix(stream)
    }
}