1#![cfg(target_os = "windows")]
2use std::{
5 fmt,
6 future::Future,
7 io::{self, Read, Write},
8 pin::Pin,
9 task::{Context, Poll},
10};
11
12use schannel::tls_stream::{HandshakeError, MidHandshakeTlsStream};
13use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
14
/// Adapter that lets an asynchronous stream be driven through the blocking
/// [`Read`]/[`Write`] traits.
///
/// The wrapper does not own a task context. Instead the address of the
/// `Context` for the poll currently in progress is stored in `context`, and the
/// blocking trait methods turn `Poll::Pending` into `io::ErrorKind::WouldBlock`.
pub struct StreamWrapper<S> {
    /// The wrapped stream.
    stream: S,
    /// Address of the `Context` of the poll in progress, or `0` when no poll is
    /// in progress.
    // NOTE(review): presumably kept as `usize` rather than a raw pointer so the
    // wrapper's auto traits follow `S` alone — confirm.
    context: usize,
}

impl<S> fmt::Debug for StreamWrapper<S>
where
    S: fmt::Debug,
{
    /// Formats the wrapper exactly like the wrapped stream.
    fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
        fmt::Debug::fmt(&self.stream, fmt)
    }
}

impl<S> StreamWrapper<S> {
    /// Splits the wrapper into the pinned stream and the task context whose
    /// address is currently stored in `context`.
    ///
    /// # Safety
    ///
    /// * `self.context` must hold the address of a `Context` that is alive for
    ///   as long as the returned reference is used.
    /// * The wrapper must not be moved while the returned `Pin` is in use,
    ///   unless `S: Unpin`.
    ///
    /// # Panics
    ///
    /// Panics if no context is installed (`self.context == 0`). This is a hard
    /// assertion rather than a debug-only one: the blocking `Read`/`Write`
    /// impls are reachable from safe code, and dereferencing a null `Context`
    /// pointer in release builds would be undefined behaviour.
    unsafe fn parts(&mut self) -> (Pin<&mut S>, &mut Context<'_>) {
        assert_ne!(
            self.context, 0,
            "StreamWrapper used outside of a poll: no task context installed"
        );
        // SAFETY: the caller guarantees the wrapper stays in place while the
        // pinned reference is in use.
        let stream = unsafe { Pin::new_unchecked(&mut self.stream) };
        // SAFETY: `context` is non-zero (checked above) and the caller
        // guarantees it is the address of a live `Context`.
        let context = unsafe { &mut *(self.context as *mut Context<'_>) };
        (stream, context)
    }
}
53
54impl<S> Read for StreamWrapper<S>
55where
56 S: AsyncRead,
57{
58 fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
59 let (stream, cx) = unsafe { self.parts() };
60 let mut buf = ReadBuf::new(buf);
61 match stream.poll_read(cx, &mut buf)? {
62 Poll::Ready(()) => Ok(buf.filled().len()),
63 Poll::Pending => Err(io::Error::from(io::ErrorKind::WouldBlock)),
64 }
65 }
66}
67
68impl<S> Write for StreamWrapper<S>
69where
70 S: AsyncWrite,
71{
72 fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
73 let (stream, cx) = unsafe { self.parts() };
74 match stream.poll_write(cx, buf) {
75 Poll::Ready(r) => r,
76 Poll::Pending => Err(io::Error::from(io::ErrorKind::WouldBlock)),
77 }
78 }
79
80 fn flush(&mut self) -> io::Result<()> {
81 let (stream, cx) = unsafe { self.parts() };
82 match stream.poll_flush(cx) {
83 Poll::Ready(r) => r,
84 Poll::Pending => Err(io::Error::from(io::ErrorKind::WouldBlock)),
85 }
86 }
87}
88
/// Converts the result of a blocking-style I/O call into a `Poll`.
///
/// An error of kind `io::ErrorKind::WouldBlock` becomes `Poll::Pending`; every
/// other result, success or failure, is passed through as `Poll::Ready`.
fn cvt<T>(r: io::Result<T>) -> Poll<io::Result<T>> {
    match r {
        Err(err) if err.kind() == io::ErrorKind::WouldBlock => Poll::Pending,
        settled => Poll::Ready(settled),
    }
}
97
98impl<S> StreamWrapper<S> {
99 pub fn get_ref(&self) -> &S {
101 &self.stream
102 }
103
104 pub fn get_mut(&mut self) -> &mut S {
106 &mut self.stream
107 }
108}
109
/// A TLS stream backed by `schannel`, usable with Tokio's `AsyncRead` and
/// `AsyncWrite`.
///
/// It wraps a `schannel::tls_stream::TlsStream` whose transport is a
/// [`StreamWrapper`] around the user-supplied stream `S`.
#[derive(Debug)]
pub struct TlsStream<S>(schannel::tls_stream::TlsStream<StreamWrapper<S>>);
113
114impl<S> TlsStream<S> {
115 fn with_context<F, R>(self: Pin<&mut Self>, ctx: &mut Context<'_>, f: F) -> R
123 where
124 F: FnOnce(&mut schannel::tls_stream::TlsStream<StreamWrapper<S>>) -> R,
125 {
126 let this = unsafe { self.get_unchecked_mut() };
127 this.0.get_mut().context = ctx as *mut _ as usize;
128 let r = f(&mut this.0);
129 this.0.get_mut().context = 0;
130 r
131 }
132
133 pub fn get_ref(&self) -> &schannel::tls_stream::TlsStream<StreamWrapper<S>> {
135 &self.0
136 }
137
138 pub fn get_mut(&mut self) -> &mut schannel::tls_stream::TlsStream<StreamWrapper<S>> {
140 &mut self.0
141 }
142}
143
144impl<S> AsyncRead for TlsStream<S>
145where
146 S: AsyncRead + AsyncWrite + Unpin,
147{
148 fn poll_read(
149 self: Pin<&mut Self>,
150 ctx: &mut Context<'_>,
151 buf: &mut ReadBuf<'_>,
152 ) -> Poll<io::Result<()>> {
153 self.with_context(ctx, |s| {
154 match cvt(s.read(buf.initialize_unfilled()))? {
156 Poll::Ready(nread) => {
157 buf.advance(nread);
158 Poll::Ready(Ok(()))
159 }
160 Poll::Pending => Poll::Pending,
161 }
162 })
163 }
164}
165
166impl<S> AsyncWrite for TlsStream<S>
167where
168 S: AsyncRead + AsyncWrite,
169{
170 fn poll_write(self: Pin<&mut Self>, ctx: &mut Context, buf: &[u8]) -> Poll<io::Result<usize>> {
171 self.with_context(ctx, |s| cvt(s.write(buf)))
172 }
173
174 fn poll_flush(self: Pin<&mut Self>, ctx: &mut Context) -> Poll<io::Result<()>> {
175 self.with_context(ctx, |s| cvt(s.flush()))
176 }
177
178 fn poll_shutdown(self: Pin<&mut Self>, ctx: &mut Context) -> Poll<io::Result<()>> {
179 self.with_context(ctx, |s| cvt(s.shutdown()))
181 }
182}
183
/// Asynchronous front end for `schannel::tls_stream::Builder::accept`.
pub struct TlsAcceptor {
    /// The schannel builder whose `accept` method is called for each stream.
    inner: schannel::tls_stream::Builder,
}
188
189impl TlsAcceptor {
190 pub fn new(inner: schannel::tls_stream::Builder) -> Self {
191 Self { inner }
192 }
193
194 pub async fn accept<S>(
195 &mut self,
196 cred: schannel::schannel_cred::SchannelCred,
197 stream: S,
198 ) -> Result<TlsStream<S>, std::io::Error>
199 where
200 S: AsyncRead + AsyncWrite + Unpin,
201 {
202 handshake(move |s| self.inner.accept(cred, s), stream).await
203 }
204}
205
/// Asynchronous front end for `schannel::tls_stream::Builder::connect`.
pub struct TlsConnector {
    /// The schannel builder whose `connect` method is called for each stream.
    inner: schannel::tls_stream::Builder,
}
210
211impl TlsConnector {
212 pub fn new(inner: schannel::tls_stream::Builder) -> Self {
213 Self { inner }
214 }
215
216 pub async fn connect<IO>(
217 &mut self,
218 cred: schannel::schannel_cred::SchannelCred,
219 stream: IO,
220 ) -> io::Result<TlsStream<IO>>
221 where
222 IO: AsyncRead + AsyncWrite + Unpin,
223 {
224 handshake(move |s| self.inner.connect(cred, s), stream).await
225 }
226}
227
/// Future that resumes an interrupted schannel handshake.
///
/// The slot is `None` only while a poll is in flight or after the future has
/// completed; polling it in that state panics.
struct MidHandshake<S>(Option<MidHandshakeTlsStream<StreamWrapper<S>>>);
229
/// Outcome of the first handshake step performed by `StartedHandshakeFuture`.
enum StartedHandshake<S> {
    /// The handshake finished during the first step.
    Done(TlsStream<S>),
    /// The handshake was interrupted and must be resumed via `MidHandshake`.
    Mid(MidHandshakeTlsStream<StreamWrapper<S>>),
}
234
/// Future that performs the first handshake step by calling the start closure.
///
/// The payload is taken out on the first poll; polling again panics.
struct StartedHandshakeFuture<F, S>(Option<StartedHandshakeFutureInner<F, S>>);
/// Inputs consumed by the first poll of `StartedHandshakeFuture`.
struct StartedHandshakeFutureInner<F, S> {
    /// Closure that starts the handshake on a wrapped stream.
    f: F,
    /// The stream to wrap and hand to `f`.
    stream: S,
}
240
241async fn handshake<F, S>(f: F, stream: S) -> Result<TlsStream<S>, std::io::Error>
242where
243 F: FnOnce(
244 StreamWrapper<S>,
245 ) -> Result<
246 schannel::tls_stream::TlsStream<StreamWrapper<S>>,
247 schannel::tls_stream::HandshakeError<StreamWrapper<S>>,
248 > + Unpin,
249 S: AsyncRead + AsyncWrite + Unpin,
250{
251 let start = StartedHandshakeFuture(Some(StartedHandshakeFutureInner { f, stream }));
252
253 match start.await {
254 Err(e) => Err(e),
255 Ok(StartedHandshake::Done(s)) => Ok(s),
256 Ok(StartedHandshake::Mid(s)) => MidHandshake(Some(s)).await,
257 }
258}
259
260impl<F, S> Future for StartedHandshakeFuture<F, S>
261where
262 F: FnOnce(
263 StreamWrapper<S>,
264 ) -> Result<
265 schannel::tls_stream::TlsStream<StreamWrapper<S>>,
266 schannel::tls_stream::HandshakeError<StreamWrapper<S>>,
267 > + Unpin,
268 S: Unpin,
269 StreamWrapper<S>: Read + Write,
270{
271 type Output = Result<StartedHandshake<S>, std::io::Error>;
272
273 fn poll(
274 mut self: Pin<&mut Self>,
275 ctx: &mut Context<'_>,
276 ) -> Poll<Result<StartedHandshake<S>, std::io::Error>> {
277 let inner = self.0.take().expect("future polled after completion");
278 let stream = StreamWrapper {
279 stream: inner.stream,
280 context: ctx as *mut _ as usize,
281 };
282
283 match (inner.f)(stream) {
284 Ok(mut s) => {
285 s.get_mut().context = 0;
286 Poll::Ready(Ok(StartedHandshake::Done(TlsStream(s))))
287 }
288 Err(HandshakeError::Interrupted(mut s)) => {
289 s.get_mut().context = 0;
290 Poll::Ready(Ok(StartedHandshake::Mid(s)))
291 }
292 Err(HandshakeError::Failure(e)) => Poll::Ready(Err(e)),
293 }
294 }
295}
296
297impl<S: AsyncRead + AsyncWrite + Unpin> Future for MidHandshake<S> {
307 type Output = Result<TlsStream<S>, std::io::Error>;
308
309 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
310 let mut_self = self.get_mut();
311 let mut s = mut_self.0.take().expect("future polled after completion");
312
313 s.get_mut().context = cx as *mut _ as usize;
314 match s.handshake() {
315 Ok(mut s) => {
316 s.get_mut().context = 0;
317 Poll::Ready(Ok(TlsStream(s)))
318 }
319 Err(HandshakeError::Interrupted(mut s)) => {
320 s.get_mut().context = 0;
321 mut_self.0 = Some(s);
322 Poll::Pending
323 }
324 Err(HandshakeError::Failure(e)) => Poll::Ready(Err(e)),
325 }
326 }
327}
328
329#[cfg(test)]
330mod tests;