1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
use std::fmt::Debug;
use std::io;
use std::io::Read;
use std::io::Result as IoResult;
use std::io::Write;
use std::pin::Pin;
use std::task::{Context, Poll};

use futures_lite::{AsyncRead, AsyncWrite};
use openssl::ssl;
use pin_project::pin_project;

use crate::net::TcpStream;

use super::async_to_sync_wrapper::AsyncToSyncWrapper;
use super::certificate::Certificate;

/// A TLS session layered on top of an async byte stream `S`.
///
/// The session itself is OpenSSL's blocking-style `SslStream`. Its transport is
/// `S` wrapped in an `AsyncToSyncWrapper`, which is given the current task
/// `Context` around each OpenSSL call (see `with_context`). That lets the
/// synchronous `Read`/`Write` calls be driven from `poll_*` methods.
#[derive(Debug)]
pub struct TlsStream<S>(pub(super) ssl::SslStream<AsyncToSyncWrapper<S>>);

impl<S: Unpin> TlsStream<S> {
    /// Returns the certificate presented by the peer, if it sent one.
    pub fn peer_certificate(&self) -> Option<Certificate> {
        let session = self.0.ssl();
        session.peer_certificate().map(Certificate)
    }

    /// Runs the blocking-style OpenSSL operation `op` on the inner `SslStream`
    /// and converts its outcome into a `Poll`.
    ///
    /// `cx` is handed to the I/O wrapper with `set_context` before `op` runs and
    /// withdrawn with `unset_context` right after it returns. A `WouldBlock`
    /// error from `op` is reported as `Poll::Pending`; any other result is
    /// `Poll::Ready`.
    ///
    /// NOTE(review): if `op` panics, `unset_context` is never reached. A drop
    /// guard would make this unwind-safe; confirm whether the wrapper relies on it.
    fn with_context<F, R>(&mut self, cx: &mut Context<'_>, op: F) -> Poll<io::Result<R>>
    where
        F: FnOnce(&mut ssl::SslStream<AsyncToSyncWrapper<S>>) -> io::Result<R>,
    {
        let tls = &mut self.0;
        tls.get_mut().set_context(cx);
        let outcome = op(&mut *tls);
        tls.get_mut().unset_context();
        result_to_poll(outcome)
    }
}

impl<S: AsyncRead + AsyncWrite + Unpin> AsyncRead for TlsStream<S> {
    /// Reads decrypted application data into `buf`.
    ///
    /// Delegates to `SslStream`'s synchronous `Read` impl with `cx` installed on
    /// the transport wrapper. `WouldBlock` surfaces as `Poll::Pending`.
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut [u8],
    ) -> Poll<io::Result<usize>> {
        let this = Pin::get_mut(self);
        this.with_context(cx, move |tls| Read::read(tls, buf))
    }
}

impl<S: AsyncRead + AsyncWrite + Unpin> AsyncWrite for TlsStream<S> {
    /// Encrypts `buf` and writes it to the transport.
    ///
    /// `WouldBlock` from the transport surfaces as `Poll::Pending`.
    fn poll_write(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<io::Result<usize>> {
        self.with_context(cx, |stream| stream.write(buf))
    }

    /// Flushes any buffered data through `SslStream`'s `Write::flush`.
    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        self.with_context(cx, |stream| stream.flush())
    }

    /// Shuts down the TLS session via `SslStream::shutdown`.
    ///
    /// The underlying transport is not closed here.
    fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        self.with_context(cx, |stream| match stream.shutdown() {
            Ok(_) => Ok(()),
            // The session is already closed; treat it as a successful close.
            Err(ref e) if e.code() == openssl::ssl::ErrorCode::ZERO_RETURN => Ok(()),
            // Keep the transport's own I/O error when there is one. In particular,
            // a `WouldBlock` must stay `WouldBlock` so `with_context` reports
            // `Poll::Pending` rather than a hard failure. Only pure TLS-level
            // errors are wrapped as `ErrorKind::Other`.
            Err(e) => Err(e
                .into_io_error()
                .unwrap_or_else(|e| io::Error::new(io::ErrorKind::Other, e))),
        })
    }
}

/// Converts a blocking-style I/O result into a `Poll`.
///
/// An error of kind `io::ErrorKind::WouldBlock` means the operation cannot make
/// progress yet and becomes `Poll::Pending`. Every other outcome, success or
/// failure, is passed through unchanged inside `Poll::Ready`.
fn result_to_poll<T>(r: io::Result<T>) -> Poll<io::Result<T>> {
    if let Err(err) = &r {
        if err.kind() == io::ErrorKind::WouldBlock {
            return Poll::Pending;
        }
    }
    Poll::Ready(r)
}

/// A TCP connection that is either plain or TLS-encrypted.
///
/// Implements `AsyncRead`/`AsyncWrite` by forwarding to whichever variant is
/// active. `pin_project` generates `EnumProj` so each variant can be accessed
/// as a pinned reference.
#[pin_project(project = EnumProj)]
pub enum AllTcpStream {
    /// Unencrypted TCP stream.
    Tcp(#[pin] TcpStream),
    /// TLS session running over a TCP stream.
    Tls(#[pin] TlsStream<TcpStream>),
}

impl AllTcpStream {
    pub fn tcp(stream: TcpStream) -> Self {
        Self::Tcp(stream)
    }

    pub fn tls(stream: TlsStream<TcpStream>) -> Self {
        Self::Tls(stream)
    }
}

impl AsyncRead for AllTcpStream {
    /// Forwards the read to the active transport (TLS or plain TCP).
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut [u8],
    ) -> Poll<io::Result<usize>> {
        let projected = self.project();
        match projected {
            EnumProj::Tls(secure) => secure.poll_read(cx, buf),
            EnumProj::Tcp(plain) => plain.poll_read(cx, buf),
        }
    }
}

impl AsyncWrite for AllTcpStream {
    /// Forwards the write to the active transport (TLS or plain TCP).
    fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<IoResult<usize>> {
        let projected = self.project();
        match projected {
            EnumProj::Tls(secure) => secure.poll_write(cx, buf),
            EnumProj::Tcp(plain) => plain.poll_write(cx, buf),
        }
    }

    /// Forwards the flush to the active transport.
    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<IoResult<()>> {
        let projected = self.project();
        match projected {
            EnumProj::Tls(secure) => secure.poll_flush(cx),
            EnumProj::Tcp(plain) => plain.poll_flush(cx),
        }
    }

    /// Forwards the close to the active transport.
    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<IoResult<()>> {
        let projected = self.project();
        match projected {
            EnumProj::Tls(secure) => secure.poll_close(cx),
            EnumProj::Tcp(plain) => plain.poll_close(cx),
        }
    }
}