mysql_async/io/
mod.rs

1// Copyright (c) 2016 Anatoly Ikorsky
2//
3// Licensed under the Apache License, Version 2.0
4// <LICENSE-APACHE or http://www.apache.org/licenses/LICENSE-2.0> or the MIT
5// license <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
6// option. All files in the project carrying such notice may not be copied,
7// modified, or distributed except according to those terms.
8
9pub use self::{read_packet::ReadPacket, write_packet::WritePacket};
10
11use bytes::BytesMut;
12use futures_core::{ready, stream};
13use mysql_common::proto::codec::PacketCodec as PacketCodecInner;
14use pin_project::pin_project;
15#[cfg(not(target_os = "wasi"))]
16use socket2::{Socket as Socket2Socket, TcpKeepalive};
17#[cfg(unix)]
18use tokio::io::AsyncWriteExt;
19use tokio::{
20    io::{AsyncRead, AsyncWrite, ErrorKind::Interrupted, ReadBuf},
21    net::TcpStream,
22};
23use tokio_util::codec::{Decoder, Encoder, Framed};
24
25#[cfg(unix)]
26use std::path::Path;
27use std::{
28    fmt,
29    future::Future,
30    io::{
31        self,
32        ErrorKind::{BrokenPipe, NotConnected, Other},
33    },
34    mem::replace,
35    ops::{Deref, DerefMut},
36    pin::Pin,
37    task::{Context, Poll},
38    time::Duration,
39};
40
41use crate::{buffer_pool::PooledBuf, error::IoError, opts::HostPortOrUrl};
42
43#[cfg(unix)]
44use crate::io::socket::Socket;
45
46mod tls;
47
/// Re-evaluates a poll expression for as long as it reports `ErrorKind::Interrupted`,
/// yielding the first result that is anything else (success, another error, or `Pending`).
///
/// The expression is evaluated anew on every retry, so it must not move out of anything.
macro_rules! with_interrupted {
    ($e:expr) => {
        loop {
            let polled = $e;
            let interrupted =
                matches!(&polled, Poll::Ready(Err(err)) if err.kind() == Interrupted);
            if !interrupted {
                break polled;
            }
        }
    };
}
58
59mod read_packet;
60mod socket;
61mod write_packet;
62
63#[derive(Debug)]
64pub struct PacketCodec {
65    inner: PacketCodecInner,
66    decode_buf: PooledBuf,
67}
68
69impl Default for PacketCodec {
70    fn default() -> Self {
71        Self {
72            inner: Default::default(),
73            decode_buf: crate::BUFFER_POOL.get(),
74        }
75    }
76}
77
78impl Deref for PacketCodec {
79    type Target = PacketCodecInner;
80
81    fn deref(&self) -> &Self::Target {
82        &self.inner
83    }
84}
85
86impl DerefMut for PacketCodec {
87    fn deref_mut(&mut self) -> &mut Self::Target {
88        &mut self.inner
89    }
90}
91
92impl Decoder for PacketCodec {
93    type Item = PooledBuf;
94    type Error = IoError;
95
96    fn decode(&mut self, src: &mut BytesMut) -> std::result::Result<Option<Self::Item>, IoError> {
97        if self.inner.decode(src, self.decode_buf.as_mut())? {
98            let new_buf = crate::BUFFER_POOL.get();
99            Ok(Some(replace(&mut self.decode_buf, new_buf)))
100        } else {
101            Ok(None)
102        }
103    }
104}
105
106impl Encoder<PooledBuf> for PacketCodec {
107    type Error = IoError;
108
109    fn encode(&mut self, item: PooledBuf, dst: &mut BytesMut) -> std::result::Result<(), IoError> {
110        Ok(self.inner.encode(&mut item.as_ref(), dst)?)
111    }
112}
113
/// Transport carrying a MySQL connection: plain TCP, TLS over TCP, or (on Unix) a
/// socket at a filesystem path (see `Stream::connect_socket`).
///
/// The `Secure` variants are mutually exclusive: enabling more than one TLS feature
/// would declare `Secure` twice and fail to compile.
#[pin_project(project = EndpointProj)]
#[derive(Debug)]
pub(crate) enum Endpoint {
    /// Plain TCP. NOTE(review): the `Option` presumably lets the stream be moved out
    /// (e.g. during a TLS upgrade); code in this file treats `Plain(None)` as unreachable.
    Plain(Option<TcpStream>),
    /// TLS over TCP via `tokio-native-tls`.
    #[cfg(feature = "native-tls-tls")]
    Secure(#[pin] tokio_native_tls::TlsStream<TcpStream>),
    /// TLS over TCP via `tokio-rustls`.
    #[cfg(feature = "rustls-tls")]
    Secure(#[pin] tokio_rustls::client::TlsStream<tokio::net::TcpStream>),
    /// TLS over TCP via `wasmedge_rustls_api`.
    #[cfg(feature = "wasmedge-tls")]
    Secure(#[pin] wasmedge_rustls_api::stream::async_stream::TlsStream<tokio::net::TcpStream>),
    /// Socket connection (Unix only).
    #[cfg(unix)]
    Socket(#[pin] Socket),
}
127
128/// This future will check that TcpStream is live.
129///
130/// This check is similar to a one, implemented by GitHub team for the go-sql-driver/mysql.
131#[derive(Debug)]
132struct CheckTcpStream<'a>(&'a mut TcpStream);
133
134impl Future for CheckTcpStream<'_> {
135    type Output = io::Result<()>;
136    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
137        match self.0.poll_read_ready(cx) {
138            Poll::Ready(Ok(())) => {
139                // stream is readable
140                let mut buf = [0_u8; 1];
141                match self.0.try_read(&mut buf) {
142                    Ok(0) => Poll::Ready(Err(io::Error::new(BrokenPipe, "broken pipe"))),
143                    Ok(_) => Poll::Ready(Err(io::Error::new(Other, "stream should be empty"))),
144                    Err(err) if err.kind() == io::ErrorKind::WouldBlock => Poll::Ready(Ok(())),
145                    Err(err) => Poll::Ready(Err(err)),
146                }
147            }
148            Poll::Ready(Err(err)) => Poll::Ready(Err(err)),
149            Poll::Pending => Poll::Ready(Ok(())),
150        }
151    }
152}
153
154impl Endpoint {
155    #[cfg(unix)]
156    fn is_socket(&self) -> bool {
157        matches!(self, Self::Socket(_))
158    }
159
160    /// Checks, that connection is alive.
161    async fn check(&mut self) -> std::result::Result<(), IoError> {
162        //return Ok(());
163        match self {
164            Endpoint::Plain(Some(stream)) => {
165                CheckTcpStream(stream).await?;
166                Ok(())
167            }
168            #[cfg(feature = "native-tls-tls")]
169            Endpoint::Secure(tls_stream) => {
170                CheckTcpStream(tls_stream.get_mut().get_mut().get_mut()).await?;
171                Ok(())
172            }
173            #[cfg(feature = "rustls-tls")]
174            Endpoint::Secure(tls_stream) => {
175                let stream = tls_stream.get_mut().0;
176                CheckTcpStream(stream).await?;
177                Ok(())
178            }
179            #[cfg(feature = "wasmedge-tls")]
180            Endpoint::Secure(tls_stream) => {
181                let stream = tls_stream.get_mut().0;
182                CheckTcpStream(stream).await?;
183                Ok(())
184            }
185            #[cfg(unix)]
186            Endpoint::Socket(socket) => {
187                let _ = socket.write(&[]).await?;
188                Ok(())
189            }
190            Endpoint::Plain(None) => unreachable!(),
191        }
192    }
193
194    #[cfg(any(
195        feature = "native-tls-tls",
196        feature = "rustls-tls",
197        feature = "wasmedge-tls"
198    ))]
199    pub fn is_secure(&self) -> bool {
200        matches!(self, Endpoint::Secure(_))
201    }
202
203    #[cfg(all(
204        not(feature = "native-tls"),
205        not(feature = "rustls"),
206        not(feature = "wasmedge-tls")
207    ))]
208    pub async fn make_secure(
209        &mut self,
210        _domain: String,
211        _ssl_opts: crate::SslOpts,
212    ) -> crate::error::Result<()> {
213        panic!(
214            "Client had asked for TLS connection but TLS support is disabled. \
215            Please enable one of the following features: [\"native-tls-tls\", \"rustls-tls\"]"
216        )
217    }
218
219    pub fn set_tcp_nodelay(&self, val: bool) -> io::Result<()> {
220        match *self {
221            Endpoint::Plain(Some(ref stream)) => stream.set_nodelay(val)?,
222            Endpoint::Plain(None) => unreachable!(),
223            #[cfg(feature = "native-tls-tls")]
224            Endpoint::Secure(ref stream) => {
225                stream.get_ref().get_ref().get_ref().set_nodelay(val)?
226            }
227            #[cfg(feature = "rustls-tls")]
228            Endpoint::Secure(ref stream) => {
229                let stream = stream.get_ref().0;
230                stream.set_nodelay(val)?;
231            }
232            #[cfg(feature = "wasmedge-tls")]
233            Endpoint::Secure(ref stream) => {
234                let stream = stream.get_ref().0;
235                stream.set_nodelay(val)?;
236            }
237            #[cfg(unix)]
238            Endpoint::Socket(_) => (/* inapplicable */),
239        }
240        Ok(())
241    }
242}
243
244impl From<TcpStream> for Endpoint {
245    fn from(stream: TcpStream) -> Self {
246        Endpoint::Plain(Some(stream))
247    }
248}
249
250#[cfg(unix)]
251impl From<Socket> for Endpoint {
252    fn from(socket: Socket) -> Self {
253        Endpoint::Socket(socket)
254    }
255}
256
257#[cfg(feature = "native-tls-tls")]
258impl From<tokio_native_tls::TlsStream<TcpStream>> for Endpoint {
259    fn from(stream: tokio_native_tls::TlsStream<TcpStream>) -> Self {
260        Endpoint::Secure(stream)
261    }
262}
263
/* TODO: presumably add the rustls counterpart of the `From<TlsStream<TcpStream>>`
impl above, gated on #[cfg(feature = "rustls-tls")]. */
267
268impl AsyncRead for Endpoint {
269    fn poll_read(
270        self: Pin<&mut Self>,
271        cx: &mut Context<'_>,
272        buf: &mut ReadBuf<'_>,
273    ) -> Poll<std::result::Result<(), tokio::io::Error>> {
274        let mut this = self.project();
275        with_interrupted!(match this {
276            EndpointProj::Plain(ref mut stream) => {
277                Pin::new(stream.as_mut().unwrap()).poll_read(cx, buf)
278            }
279            #[cfg(feature = "native-tls-tls")]
280            EndpointProj::Secure(ref mut stream) => stream.as_mut().poll_read(cx, buf),
281            #[cfg(feature = "rustls-tls")]
282            EndpointProj::Secure(ref mut stream) => stream.as_mut().poll_read(cx, buf),
283            #[cfg(feature = "wasmedge-tls")]
284            EndpointProj::Secure(ref mut stream) => stream.as_mut().poll_read(cx, buf),
285            #[cfg(unix)]
286            EndpointProj::Socket(ref mut stream) => stream.as_mut().poll_read(cx, buf),
287        })
288    }
289}
290
291impl AsyncWrite for Endpoint {
292    fn poll_write(
293        self: Pin<&mut Self>,
294        cx: &mut Context,
295        buf: &[u8],
296    ) -> Poll<std::result::Result<usize, tokio::io::Error>> {
297        let mut this = self.project();
298        with_interrupted!(match this {
299            EndpointProj::Plain(ref mut stream) => {
300                Pin::new(stream.as_mut().unwrap()).poll_write(cx, buf)
301            }
302            #[cfg(feature = "native-tls-tls")]
303            EndpointProj::Secure(ref mut stream) => stream.as_mut().poll_write(cx, buf),
304            #[cfg(feature = "rustls-tls")]
305            EndpointProj::Secure(ref mut stream) => stream.as_mut().poll_write(cx, buf),
306            #[cfg(feature = "wasmedge-tls")]
307            EndpointProj::Secure(ref mut stream) => stream.as_mut().poll_write(cx, buf),
308            #[cfg(unix)]
309            EndpointProj::Socket(ref mut stream) => stream.as_mut().poll_write(cx, buf),
310        })
311    }
312
313    fn poll_flush(
314        self: Pin<&mut Self>,
315        cx: &mut Context,
316    ) -> Poll<std::result::Result<(), tokio::io::Error>> {
317        let mut this = self.project();
318        with_interrupted!(match this {
319            EndpointProj::Plain(ref mut stream) => {
320                Pin::new(stream.as_mut().unwrap()).poll_flush(cx)
321            }
322            #[cfg(feature = "native-tls-tls")]
323            EndpointProj::Secure(ref mut stream) => stream.as_mut().poll_flush(cx),
324            #[cfg(feature = "rustls-tls")]
325            EndpointProj::Secure(ref mut stream) => stream.as_mut().poll_flush(cx),
326            #[cfg(feature = "wasmedge-tls")]
327            EndpointProj::Secure(ref mut stream) => stream.as_mut().poll_flush(cx),
328            #[cfg(unix)]
329            EndpointProj::Socket(ref mut stream) => stream.as_mut().poll_flush(cx),
330        })
331    }
332
333    fn poll_shutdown(
334        self: Pin<&mut Self>,
335        cx: &mut Context,
336    ) -> Poll<std::result::Result<(), tokio::io::Error>> {
337        let mut this = self.project();
338        with_interrupted!(match this {
339            EndpointProj::Plain(ref mut stream) => {
340                Pin::new(stream.as_mut().unwrap()).poll_shutdown(cx)
341            }
342            #[cfg(feature = "native-tls-tls")]
343            EndpointProj::Secure(ref mut stream) => stream.as_mut().poll_shutdown(cx),
344            #[cfg(feature = "rustls-tls")]
345            EndpointProj::Secure(ref mut stream) => stream.as_mut().poll_shutdown(cx),
346            #[cfg(feature = "wasmedge-tls")]
347            EndpointProj::Secure(ref mut stream) => stream.as_mut().poll_shutdown(cx),
348            #[cfg(unix)]
349            EndpointProj::Socket(ref mut stream) => stream.as_mut().poll_shutdown(cx),
350        })
351    }
352}
353
354/// A Stream, connected to MySql server.
355pub struct Stream {
356    closed: bool,
357    pub(crate) codec: Option<Box<Framed<Endpoint, PacketCodec>>>,
358}
359
360impl fmt::Debug for Stream {
361    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
362        write!(
363            f,
364            "Stream (endpoint={:?})",
365            self.codec.as_ref().unwrap().get_ref()
366        )
367    }
368}
369
370impl Stream {
371    #[cfg(unix)]
372    fn new<T: Into<Endpoint>>(endpoint: T) -> Self {
373        let endpoint = endpoint.into();
374
375        Self {
376            closed: false,
377            codec: Box::new(Framed::new(endpoint, PacketCodec::default())).into(),
378        }
379    }
380
381    pub(crate) async fn connect_tcp(
382        addr: &HostPortOrUrl,
383        _keepalive: Option<Duration>,
384    ) -> io::Result<Stream> {
385        let tcp_stream = match addr {
386            HostPortOrUrl::HostPort(host, port) => {
387                TcpStream::connect((host.as_str(), *port)).await?
388            }
389            HostPortOrUrl::Url(url) => {
390                #[cfg(not(target_os = "wasi"))]
391                {
392                    let addrs = url.socket_addrs(|| Some(DEFAULT_PORT))?;
393                    TcpStream::connect(addrs).await?
394                }
395
396                #[cfg(target_os = "wasi")]
397                {
398                    let addrs = (
399                        url.host_str().expect("Unable to get host"),
400                        url.port_or_known_default().expect("No port found in url"),
401                    );
402                    TcpStream::connect(addrs).await?
403                }
404            }
405        };
406        #[cfg(not(target_os = "wasi"))]
407        if let Some(duration) = keepalive {
408            #[cfg(unix)]
409            let socket = {
410                use std::os::unix::prelude::*;
411                let fd = tcp_stream.as_raw_fd();
412                unsafe { Socket2Socket::from_raw_fd(fd) }
413            };
414            #[cfg(windows)]
415            let socket = {
416                use std::os::windows::prelude::*;
417                let sock = tcp_stream.as_raw_socket();
418                unsafe { Socket2Socket::from_raw_socket(sock) }
419            };
420            socket.set_tcp_keepalive(&TcpKeepalive::new().with_time(duration))?;
421            std::mem::forget(socket);
422        }
423
424        Ok(Stream {
425            closed: false,
426            codec: Box::new(Framed::new(tcp_stream.into(), PacketCodec::default())).into(),
427        })
428    }
429
430    #[cfg(unix)]
431    pub(crate) async fn connect_socket<P: AsRef<Path>>(path: P) -> io::Result<Stream> {
432        Ok(Stream::new(Socket::new(path).await?))
433    }
434
435    pub(crate) fn set_tcp_nodelay(&self, val: bool) -> io::Result<()> {
436        self.codec.as_ref().unwrap().get_ref().set_tcp_nodelay(val)
437    }
438    pub(crate) async fn make_secure(
439        &mut self,
440        domain: String,
441        ssl_opts: crate::SslOpts,
442    ) -> crate::error::Result<()> {
443        use tokio_util::codec::FramedParts;
444
445        let codec = self.codec.take().unwrap();
446        let FramedParts { mut io, codec, .. } = codec.into_parts();
447        io.make_secure(domain, ssl_opts).await?;
448        let codec = Framed::new(io, codec);
449        self.codec = Some(Box::new(codec));
450        Ok(())
451    }
452
453    #[cfg(any(
454        feature = "native-tls-tls",
455        feature = "rustls-tls",
456        feature = "wasmedge-tls"
457    ))]
458    pub(crate) fn is_secure(&self) -> bool {
459        self.codec.as_ref().unwrap().get_ref().is_secure()
460    }
461
462    #[cfg(unix)]
463    pub(crate) fn is_socket(&self) -> bool {
464        self.codec.as_ref().unwrap().get_ref().is_socket()
465    }
466
467    pub(crate) fn reset_seq_id(&mut self) {
468        if let Some(codec) = self.codec.as_mut() {
469            codec.codec_mut().reset_seq_id();
470        }
471    }
472
473    pub(crate) fn sync_seq_id(&mut self) {
474        if let Some(codec) = self.codec.as_mut() {
475            codec.codec_mut().sync_seq_id();
476        }
477    }
478
479    pub(crate) fn set_max_allowed_packet(&mut self, max_allowed_packet: usize) {
480        if let Some(codec) = self.codec.as_mut() {
481            codec.codec_mut().max_allowed_packet = max_allowed_packet;
482        }
483    }
484
485    pub(crate) fn compress(&mut self, level: crate::Compression) {
486        if let Some(codec) = self.codec.as_mut() {
487            codec.codec_mut().compress(level);
488        }
489    }
490
491    /// Checks, that connection is alive.
492    pub(crate) async fn check(&mut self) -> std::result::Result<(), IoError> {
493        if let Some(codec) = self.codec.as_mut() {
494            codec.get_mut().check().await?;
495        }
496        Ok(())
497    }
498
499    pub(crate) async fn close(mut self) -> std::result::Result<(), IoError> {
500        self.closed = true;
501        if let Some(mut codec) = self.codec {
502            use futures_sink::Sink;
503            futures_util::future::poll_fn(|cx| match Pin::new(&mut *codec).poll_close(cx) {
504                Poll::Ready(Err(IoError::Io(err))) if err.kind() == NotConnected => {
505                    Poll::Ready(Ok(()))
506                }
507                x => x,
508            })
509            .await?;
510        }
511        Ok(())
512    }
513}
514
515impl stream::Stream for Stream {
516    type Item = std::result::Result<PooledBuf, IoError>;
517
518    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
519        if !self.closed {
520            let item = ready!(Pin::new(self.codec.as_mut().unwrap()).poll_next(cx)).transpose()?;
521            Poll::Ready(Ok(item).transpose())
522        } else {
523            Poll::Ready(None)
524        }
525    }
526}
527
#[cfg(test)]
mod test {
    /// Verifies that the `tcp_keepalive` connection option is applied to the TCP socket.
    #[cfg(unix)] // no sane way to retrieve current keepalive value on windows
    #[tokio::test]
    async fn should_connect_with_keepalive() {
        use crate::{test_misc::get_opts, Conn};

        let opts = get_opts()
            .tcp_keepalive(Some(42_000_u32))
            .prefer_socket(false);
        let mut conn: Conn = Conn::new(opts).await.unwrap();
        let stream = conn.stream_mut().unwrap();
        let endpoint = stream.codec.as_mut().unwrap().get_ref();
        let stream = match endpoint {
            super::Endpoint::Plain(Some(stream)) => stream,
            #[cfg(feature = "rustls-tls")]
            super::Endpoint::Secure(tls_stream) => tls_stream.get_ref().0,
            #[cfg(feature = "wasmedge-tls")]
            super::Endpoint::Secure(tls_stream) => tls_stream.get_ref().0,
            // Gated on the same feature as the `Endpoint::Secure` native-tls variant.
            #[cfg(feature = "native-tls-tls")]
            super::Endpoint::Secure(tls_stream) => tls_stream.get_ref().get_ref().get_ref(),
            _ => unreachable!(),
        };
        let sock = {
            use std::os::unix::prelude::*;
            let raw = stream.as_raw_fd();
            // SAFETY: `raw` is owned by the connection's `TcpStream`, which outlives `sock`.
            // `ManuallyDrop` keeps `sock` from closing it, even if the assertion below
            // panics and unwinds.
            std::mem::ManuallyDrop::new(unsafe { socket2::Socket::from_raw_fd(raw) })
        };

        assert_eq!(
            sock.keepalive_time().unwrap(),
            std::time::Duration::from_millis(42_000),
        );

        conn.disconnect().await.unwrap();
    }
}