fluvio_future/openssl/
stream.rs

1use std::fmt::Debug;
2use std::io;
3use std::io::Read;
4use std::io::Write;
5use std::pin::Pin;
6use std::task::{Context, Poll};
7
8use futures_lite::{AsyncRead, AsyncWrite};
9use openssl::ssl;
10
11use super::async_to_sync_wrapper::AsyncToSyncWrapper;
12use super::certificate::Certificate;
13
14#[derive(Debug)]
15pub struct TlsStream<S>(pub(super) ssl::SslStream<AsyncToSyncWrapper<S>>);
16
17impl<S: Unpin> TlsStream<S> {
18    pub fn peer_certificate(&self) -> Option<Certificate> {
19        self.0.ssl().peer_certificate().map(Certificate)
20    }
21
22    fn with_context<F, R>(&mut self, cx: &mut Context<'_>, f: F) -> Poll<io::Result<R>>
23    where
24        F: FnOnce(&mut ssl::SslStream<AsyncToSyncWrapper<S>>) -> io::Result<R>,
25    {
26        self.0.get_mut().set_context(cx);
27        let r = f(&mut self.0);
28        self.0.get_mut().unset_context();
29        result_to_poll(r)
30    }
31}
32
33impl<S: AsyncRead + AsyncWrite + Unpin> AsyncRead for TlsStream<S> {
34    fn poll_read(
35        mut self: Pin<&mut Self>,
36        cx: &mut Context<'_>,
37        buf: &mut [u8],
38    ) -> Poll<io::Result<usize>> {
39        self.with_context(cx, |stream| stream.read(buf))
40    }
41}
42
43impl<S: AsyncRead + AsyncWrite + Unpin> AsyncWrite for TlsStream<S> {
44    fn poll_write(
45        mut self: Pin<&mut Self>,
46        cx: &mut Context<'_>,
47        buf: &[u8],
48    ) -> Poll<io::Result<usize>> {
49        self.with_context(cx, |stream| stream.write(buf))
50    }
51
52    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
53        self.with_context(cx, |stream| stream.flush())
54    }
55
56    fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
57        self.with_context(cx, |stream| match stream.shutdown() {
58            Ok(_) => Ok(()),
59            Err(ref e) if e.code() == openssl::ssl::ErrorCode::ZERO_RETURN => Ok(()),
60            Err(e) => Err(io::Error::other(e)),
61        })
62    }
63}
64
/// Converts a synchronous I/O result into a `Poll`.
///
/// An `io::ErrorKind::WouldBlock` error means "not ready yet" and becomes
/// `Poll::Pending`. Any other result, success or failure, is passed through
/// unchanged inside `Poll::Ready`.
fn result_to_poll<T>(r: io::Result<T>) -> Poll<io::Result<T>> {
    match r {
        Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => Poll::Pending,
        settled => Poll::Ready(settled),
    }
}