1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
use super::{Io, MayBeTls, TlsUpgrade};
use crate::common::*;
use core::panic;
use std::fmt;

/// An I/O stream that starts out plaintext or encrypted and may optionally be
/// upgraded to TLS in place (e.g. after a STARTTLS-like command).
///
/// All reads, writes, flushes and closes first drive any pending TLS handshake,
/// then delegate to whichever underlying stream the current state holds.
pub struct TlsCapable {
    /// Current phase of the stream (plaintext/encrypted, upgradable, handshaking, failed).
    state: State,
}

/// Internal phases of a [`TlsCapable`] stream.
enum State {
    /// TLS upgrade is not enabled (plaintext or wrapper mode), or the stream is
    /// already encrypted. The `bool` is `true` when the stream is encrypted; it is
    /// what `is_encrypted()` reports for this state.
    Done(Box<dyn Io>, bool),
    /// Stream that can be upgraded to TLS by `encrypt()`: holds the stream, the
    /// upgrade provider, and the peer name handed to `TlsUpgrade::upgrade_to_tls`.
    Enabled(Box<dyn Io>, Box<dyn TlsUpgrade>, String),
    /// TLS handshake in progress. The future resolves to the encrypted stream and
    /// is driven by `poll_tls()` before any read/write/flush/close.
    Handshake(S3Fut<std::io::Result<Box<dyn Io>>>),
    /// Error state with no way out. It is also used as a temporary placeholder
    /// while the previous state is moved out with `std::mem::replace`. I/O in this
    /// state returns a `BrokenPipe` error.
    Failed,
}

impl MayBeTls for TlsCapable {
    /// Starts the TLS handshake if an upgrade is enabled.
    ///
    /// On success the state becomes `Handshake`; the handshake itself is driven
    /// lazily by the next I/O poll. Calling this in any other state logs an error
    /// and moves the stream to `Failed`, dropping whatever it held.
    fn encrypt(mut self: Pin<&mut Self>) {
        // Take ownership of the current state; `Failed` is only a placeholder here.
        let previous = std::mem::replace(&mut self.state, State::Failed);
        let (stream, upgrader, peer) = match previous {
            State::Enabled(stream, upgrader, peer) => (stream, upgrader, peer),
            State::Done(_, encrypted) => {
                let message = format!(
                    "start_tls: TLS upgrade is not enabled. encrypted: {}",
                    encrypted,
                );
                return self.fail(&message);
            }
            State::Handshake(_) => {
                return self.fail("start_tls: TLS handshake already in progress")
            }
            State::Failed => return self.fail("start_tls: TLS setup failed"),
        };
        trace!("Switching to TLS");
        // The provider returns the handshake future; awaiting it yields the
        // encrypted stream.
        let handshake = upgrader.upgrade_to_tls(stream, peer);
        self.state = State::Handshake(Box::pin(handshake));
    }

    /// True only while an upgrade is enabled but not yet started.
    fn can_encrypt(&self) -> bool {
        matches!(self.state, State::Enabled(..))
    }

    /// Reports whether the stream is (or is becoming) encrypted.
    ///
    /// Note: a started-but-unfinished handshake already counts as encrypted.
    /// NOTE(review): presumably so callers treat the session as secured right
    /// after requesting the upgrade — confirm against callers.
    fn is_encrypted(&self) -> bool {
        match &self.state {
            State::Done(_, encrypted) => *encrypted,
            State::Handshake(_) => true,
            State::Enabled(..) | State::Failed => false,
        }
    }

    /// Installs `upgrade` (with peer `name`) so that a later `encrypt()` can
    /// switch the current stream to TLS. Replaces any previously enabled upgrade.
    ///
    /// # Panics
    /// Panics if a handshake is in progress or the stream has failed.
    fn enable_encryption(&mut self, upgrade: Box<dyn super::TlsUpgrade>, name: String) {
        let previous = std::mem::replace(&mut self.state, State::Failed);
        let stream = match previous {
            State::Enabled(stream, _, _) | State::Done(stream, _) => stream,
            State::Handshake(_) => panic!("currently upgrading"),
            State::Failed => panic!("IO failed"),
        };
        self.state = State::Enabled(stream, upgrade, name);
    }
}
impl TlsCapable {
    pub fn plaintext(io: Box<dyn Io>) -> Self {
        TlsCapable {
            state: State::Done(io, false),
        }
    }
    pub fn encrypted(io: Box<dyn Io>) -> Self {
        TlsCapable {
            state: State::Done(io, true),
        }
    }
    pub fn enabled(io: Box<dyn Io>, upgrade: Box<dyn TlsUpgrade>, peer_name: String) -> Self {
        TlsCapable {
            state: State::Enabled(io, upgrade, peer_name),
        }
    }
    fn poll_tls(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
        let this = self.get_mut();
        match this.state {
            State::Handshake(ref mut h) => {
                trace!("Waiting for TLS handshake");
                match Pin::new(h).poll(cx)? {
                    Poll::Pending => {
                        trace!("TLS is not ready yet");
                        Poll::Pending
                    }
                    Poll::Ready(encrypted) => {
                        trace!("TLS is on!");
                        this.state = State::Done(encrypted, true);
                        Poll::Ready(Ok(()))
                    }
                }
            }
            _ => Poll::Ready(Ok(())),
        }
    }
    fn fail(mut self: Pin<&mut Self>, msg: &str) {
        error!("{}", msg);
        self.state = State::Failed;
    }
    fn failed() -> std::io::Error {
        std::io::Error::new(std::io::ErrorKind::BrokenPipe, "Tls setup failed")
    }
    fn ready_failed<T>() -> Poll<std::io::Result<T>> {
        Poll::Ready(Err(Self::failed()))
    }
}

impl io::Read for TlsCapable {
    /// Reads into `buf` from the current underlying stream, first finishing any
    /// pending TLS handshake. Reading in the `Failed` state yields `BrokenPipe`.
    fn poll_read(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut [u8],
    ) -> Poll<std::io::Result<usize>> {
        trace!("poll_read on {:?}", self.state);
        // A handshake error is propagated by `?`; a still-running one parks us.
        if self.as_mut().poll_tls(cx)?.is_pending() {
            return Poll::Pending;
        }
        let outcome = match &mut self.state {
            State::Done(stream, _) | State::Enabled(stream, _, _) => {
                Pin::new(stream).poll_read(cx, buf)
            }
            State::Handshake(_) => unreachable!("poll_read: This path is handled in poll_tls()"),
            State::Failed => Self::ready_failed(),
        };
        trace!("poll_read got {:?}", outcome);
        outcome
    }
}

impl io::Write for TlsCapable {
    /// Writes `buf` to the current underlying stream, first finishing any pending
    /// TLS handshake. Writing in the `Failed` state yields `BrokenPipe`.
    fn poll_write(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<std::io::Result<usize>> {
        if self.as_mut().poll_tls(cx)?.is_pending() {
            return Poll::Pending;
        }
        match &mut self.state {
            State::Done(stream, _) | State::Enabled(stream, _, _) => {
                Pin::new(stream).poll_write(cx, buf)
            }
            State::Handshake(_) => unreachable!("poll_write: This path is handled in poll_tls()"),
            State::Failed => Self::ready_failed(),
        }
    }
    /// Flushes the current underlying stream after any pending TLS handshake.
    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
        if self.as_mut().poll_tls(cx)?.is_pending() {
            return Poll::Pending;
        }
        match &mut self.state {
            State::Done(stream, _) | State::Enabled(stream, _, _) => Pin::new(stream).poll_flush(cx),
            State::Handshake(_) => unreachable!("poll_flush: This path is handled in poll_tls()"),
            State::Failed => Self::ready_failed(),
        }
    }
    /// Closes the current underlying stream after any pending TLS handshake.
    fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
        if self.as_mut().poll_tls(cx)?.is_pending() {
            return Poll::Pending;
        }
        match &mut self.state {
            State::Done(stream, _) | State::Enabled(stream, _, _) => Pin::new(stream).poll_close(cx),
            State::Handshake(_) => unreachable!("poll_close: This path is handled in poll_tls()"),
            State::Failed => Self::ready_failed(),
        }
    }
}

impl std::fmt::Debug for TlsCapable {
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::result::Result<(), std::fmt::Error> {
        self.state.fmt(f)
    }
}

impl fmt::Debug for State {
    fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
        use State as S;
        fmt.write_str(match self {
            S::Done(_, encrypted) => {
                if *encrypted {
                    "Done(stream,encrypted)"
                } else {
                    "Done(stream,plaintext)"
                }
            }
            S::Enabled(_, _, _) => "Enabled(stream, upgrade, peer_name)",
            S::Handshake(_) => "Handshake(tls_handshake)",
            S::Failed => "Failed",
        })
    }
}