tokio_schannel/
lib.rs

1#![cfg(target_os = "windows")]
2// schannel
3
4use std::{
5    fmt,
6    future::Future,
7    io::{self, Read, Write},
8    pin::Pin,
9    task::{Context, Poll},
10};
11
12use schannel::tls_stream::{HandshakeError, MidHandshakeTlsStream};
13use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
14
/// Adapts a poll-based (tokio) stream to the blocking `std::io::Read`/`Write` traits
/// that schannel expects.
///
/// While a poll call is in progress, `context` holds the address of that call's task
/// `Context`, so the sync `Read`/`Write` methods can forward to the stream's poll
/// methods; a `Poll::Pending` is surfaced to schannel as `io::ErrorKind::WouldBlock`.
pub struct StreamWrapper<S> {
    /// The wrapped async stream.
    stream: S,
    /// Address of the currently installed task `Context`, or `0` when none is installed.
    context: usize,
}
20
21impl<S> fmt::Debug for StreamWrapper<S>
22where
23    S: fmt::Debug,
24{
25    fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
26        fmt::Debug::fmt(&self.stream, fmt)
27    }
28}
29
30impl<S> StreamWrapper<S> {
31    /// # Safety
32    ///
33    /// Must be called with `context` set to a valid pointer to a live `Context` object, and the
34    /// wrapper must be pinned in memory.
35    unsafe fn parts(&mut self) -> (Pin<&mut S>, &mut Context<'_>) {
36        debug_assert_ne!(self.context, 0);
37        let stream = unsafe { Pin::new_unchecked(&mut self.stream) };
38        let context = unsafe { &mut *(self.context as *mut Context<'_>) };
39        (stream, context)
40    }
41
42    // // internal helper to set context and execute sync function.
43    // fn with_context<F, R>(&mut self, ctx: &mut Context<'_>, f: F) -> R
44    // where
45    //     F: FnOnce(&mut Self) -> R,
46    // {
47    //     self.context = ctx as *mut _ as usize;
48    //     let r = f(self);
49    //     self.context = 0;
50    //     r
51    // }
52}
53
54impl<S> Read for StreamWrapper<S>
55where
56    S: AsyncRead,
57{
58    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
59        let (stream, cx) = unsafe { self.parts() };
60        let mut buf = ReadBuf::new(buf);
61        match stream.poll_read(cx, &mut buf)? {
62            Poll::Ready(()) => Ok(buf.filled().len()),
63            Poll::Pending => Err(io::Error::from(io::ErrorKind::WouldBlock)),
64        }
65    }
66}
67
68impl<S> Write for StreamWrapper<S>
69where
70    S: AsyncWrite,
71{
72    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
73        let (stream, cx) = unsafe { self.parts() };
74        match stream.poll_write(cx, buf) {
75            Poll::Ready(r) => r,
76            Poll::Pending => Err(io::Error::from(io::ErrorKind::WouldBlock)),
77        }
78    }
79
80    fn flush(&mut self) -> io::Result<()> {
81        let (stream, cx) = unsafe { self.parts() };
82        match stream.poll_flush(cx) {
83            Poll::Ready(r) => r,
84            Poll::Pending => Err(io::Error::from(io::ErrorKind::WouldBlock)),
85        }
86    }
87}
88
/// Converts a blocking-style I/O result back into a `Poll`.
///
/// `ErrorKind::WouldBlock` becomes `Poll::Pending`; every other result, success or
/// error, is returned as `Poll::Ready`.
fn cvt<T>(r: io::Result<T>) -> Poll<io::Result<T>> {
    match r {
        Err(e) if e.kind() == io::ErrorKind::WouldBlock => Poll::Pending,
        ready => Poll::Ready(ready),
    }
}
97
98impl<S> StreamWrapper<S> {
99    /// Returns a shared reference to the inner stream.
100    pub fn get_ref(&self) -> &S {
101        &self.stream
102    }
103
104    /// Returns a mutable reference to the inner stream.
105    pub fn get_mut(&mut self) -> &mut S {
106        &mut self.stream
107    }
108}
109
110/// Wrapper around schannels' tls stream and provide async apis.
111#[derive(Debug)]
112pub struct TlsStream<S>(schannel::tls_stream::TlsStream<StreamWrapper<S>>);
113
114impl<S> TlsStream<S> {
115    // /// Like [`TlsStream::new`](schannel::tls_stream::TlsStream).
116    // pub fn new( stream: S) -> Result<Self, ErrorStack> {
117    //     ssl::SslStream::new(ssl, StreamWrapper { stream, context: 0 }).map(SslStream)
118    // }
119    //pub fn poll_connect()
120
121    // pass the ctx in the wrapper and invoke f
122    fn with_context<F, R>(self: Pin<&mut Self>, ctx: &mut Context<'_>, f: F) -> R
123    where
124        F: FnOnce(&mut schannel::tls_stream::TlsStream<StreamWrapper<S>>) -> R,
125    {
126        let this = unsafe { self.get_unchecked_mut() };
127        this.0.get_mut().context = ctx as *mut _ as usize;
128        let r = f(&mut this.0);
129        this.0.get_mut().context = 0;
130        r
131    }
132
133    /// Returns a shared reference to the inner stream.
134    pub fn get_ref(&self) -> &schannel::tls_stream::TlsStream<StreamWrapper<S>> {
135        &self.0
136    }
137
138    /// Returns a mutable reference to the inner stream.
139    pub fn get_mut(&mut self) -> &mut schannel::tls_stream::TlsStream<StreamWrapper<S>> {
140        &mut self.0
141    }
142}
143
144impl<S> AsyncRead for TlsStream<S>
145where
146    S: AsyncRead + AsyncWrite + Unpin,
147{
148    fn poll_read(
149        self: Pin<&mut Self>,
150        ctx: &mut Context<'_>,
151        buf: &mut ReadBuf<'_>,
152    ) -> Poll<io::Result<()>> {
153        self.with_context(ctx, |s| {
154            // TODO: read into uninitialized for optimize
155            match cvt(s.read(buf.initialize_unfilled()))? {
156                Poll::Ready(nread) => {
157                    buf.advance(nread);
158                    Poll::Ready(Ok(()))
159                }
160                Poll::Pending => Poll::Pending,
161            }
162        })
163    }
164}
165
166impl<S> AsyncWrite for TlsStream<S>
167where
168    S: AsyncRead + AsyncWrite,
169{
170    fn poll_write(self: Pin<&mut Self>, ctx: &mut Context, buf: &[u8]) -> Poll<io::Result<usize>> {
171        self.with_context(ctx, |s| cvt(s.write(buf)))
172    }
173
174    fn poll_flush(self: Pin<&mut Self>, ctx: &mut Context) -> Poll<io::Result<()>> {
175        self.with_context(ctx, |s| cvt(s.flush()))
176    }
177
178    fn poll_shutdown(self: Pin<&mut Self>, ctx: &mut Context) -> Poll<io::Result<()>> {
179        // TODO: May need to check error and retry
180        self.with_context(ctx, |s| cvt(s.shutdown()))
181    }
182}
183
184// acceptor
185pub struct TlsAcceptor {
186    inner: schannel::tls_stream::Builder,
187}
188
189impl TlsAcceptor {
190    pub fn new(inner: schannel::tls_stream::Builder) -> Self {
191        Self { inner }
192    }
193
194    pub async fn accept<S>(
195        &mut self,
196        cred: schannel::schannel_cred::SchannelCred,
197        stream: S,
198    ) -> Result<TlsStream<S>, std::io::Error>
199    where
200        S: AsyncRead + AsyncWrite + Unpin,
201    {
202        handshake(move |s| self.inner.accept(cred, s), stream).await
203    }
204}
205
206// connector
207pub struct TlsConnector {
208    inner: schannel::tls_stream::Builder,
209}
210
211impl TlsConnector {
212    pub fn new(inner: schannel::tls_stream::Builder) -> Self {
213        Self { inner }
214    }
215
216    pub async fn connect<IO>(
217        &mut self,
218        cred: schannel::schannel_cred::SchannelCred,
219        stream: IO,
220    ) -> io::Result<TlsStream<IO>>
221    where
222        IO: AsyncRead + AsyncWrite + Unpin,
223    {
224        handshake(move |s| self.inner.connect(cred, s), stream).await
225    }
226}
227
228struct MidHandshake<S>(Option<MidHandshakeTlsStream<StreamWrapper<S>>>);
229
230enum StartedHandshake<S> {
231    Done(TlsStream<S>),
232    Mid(MidHandshakeTlsStream<StreamWrapper<S>>),
233}
234
/// One-shot future that makes the first handshake attempt.
///
/// Holds `None` once it has been polled.
struct StartedHandshakeFuture<F, S>(Option<StartedHandshakeFutureInner<F, S>>);

/// Inputs for the first handshake attempt.
struct StartedHandshakeFutureInner<F, S> {
    /// Closure that starts the schannel handshake on the wrapped stream.
    f: F,
    /// The raw async stream, wrapped when the future is polled.
    stream: S,
}
240
241async fn handshake<F, S>(f: F, stream: S) -> Result<TlsStream<S>, std::io::Error>
242where
243    F: FnOnce(
244            StreamWrapper<S>,
245        ) -> Result<
246            schannel::tls_stream::TlsStream<StreamWrapper<S>>,
247            schannel::tls_stream::HandshakeError<StreamWrapper<S>>,
248        > + Unpin,
249    S: AsyncRead + AsyncWrite + Unpin,
250{
251    let start = StartedHandshakeFuture(Some(StartedHandshakeFutureInner { f, stream }));
252
253    match start.await {
254        Err(e) => Err(e),
255        Ok(StartedHandshake::Done(s)) => Ok(s),
256        Ok(StartedHandshake::Mid(s)) => MidHandshake(Some(s)).await,
257    }
258}
259
260impl<F, S> Future for StartedHandshakeFuture<F, S>
261where
262    F: FnOnce(
263            StreamWrapper<S>,
264        ) -> Result<
265            schannel::tls_stream::TlsStream<StreamWrapper<S>>,
266            schannel::tls_stream::HandshakeError<StreamWrapper<S>>,
267        > + Unpin,
268    S: Unpin,
269    StreamWrapper<S>: Read + Write,
270{
271    type Output = Result<StartedHandshake<S>, std::io::Error>;
272
273    fn poll(
274        mut self: Pin<&mut Self>,
275        ctx: &mut Context<'_>,
276    ) -> Poll<Result<StartedHandshake<S>, std::io::Error>> {
277        let inner = self.0.take().expect("future polled after completion");
278        let stream = StreamWrapper {
279            stream: inner.stream,
280            context: ctx as *mut _ as usize,
281        };
282
283        match (inner.f)(stream) {
284            Ok(mut s) => {
285                s.get_mut().context = 0;
286                Poll::Ready(Ok(StartedHandshake::Done(TlsStream(s))))
287            }
288            Err(HandshakeError::Interrupted(mut s)) => {
289                s.get_mut().context = 0;
290                Poll::Ready(Ok(StartedHandshake::Mid(s)))
291            }
292            Err(HandshakeError::Failure(e)) => Poll::Ready(Err(e)),
293        }
294    }
295}
296
297// pub struct StartHandShakeFu<S> {
298//     f: dyn FnOnce() -> StartHandShake<S>,
299// }
300
301// impl<S> Future for StartHandShakeFu<S>
302// {
303//     type Output = StartHandShake<S>;
304// }
305
306impl<S: AsyncRead + AsyncWrite + Unpin> Future for MidHandshake<S> {
307    type Output = Result<TlsStream<S>, std::io::Error>;
308
309    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
310        let mut_self = self.get_mut();
311        let mut s = mut_self.0.take().expect("future polled after completion");
312
313        s.get_mut().context = cx as *mut _ as usize;
314        match s.handshake() {
315            Ok(mut s) => {
316                s.get_mut().context = 0;
317                Poll::Ready(Ok(TlsStream(s)))
318            }
319            Err(HandshakeError::Interrupted(mut s)) => {
320                s.get_mut().context = 0;
321                mut_self.0 = Some(s);
322                Poll::Pending
323            }
324            Err(HandshakeError::Failure(e)) => Poll::Ready(Err(e)),
325        }
326    }
327}
328
329#[cfg(test)]
330mod tests;