1pub use self::{read_packet::ReadPacket, write_packet::WritePacket};
10
11use bytes::BytesMut;
12use futures_core::{ready, stream};
13use mysql_common::proto::codec::PacketCodec as PacketCodecInner;
14use pin_project::pin_project;
15#[cfg(not(target_os = "wasi"))]
16use socket2::{Socket as Socket2Socket, TcpKeepalive};
17#[cfg(unix)]
18use tokio::io::AsyncWriteExt;
19use tokio::{
20 io::{AsyncRead, AsyncWrite, ErrorKind::Interrupted, ReadBuf},
21 net::TcpStream,
22};
23use tokio_util::codec::{Decoder, Encoder, Framed};
24
25#[cfg(unix)]
26use std::path::Path;
27use std::{
28 fmt,
29 future::Future,
30 io::{
31 self,
32 ErrorKind::{BrokenPipe, NotConnected, Other},
33 },
34 mem::replace,
35 ops::{Deref, DerefMut},
36 pin::Pin,
37 task::{Context, Poll},
38 time::Duration,
39};
40
41use crate::{buffer_pool::PooledBuf, error::IoError, opts::HostPortOrUrl};
42
43#[cfg(unix)]
44use crate::io::socket::Socket;
45
46mod tls;
47
/// Re-evaluates `$e` (an expression producing `Poll<io::Result<_>>`) for as long
/// as it yields `Poll::Ready(Err(e))` with `e.kind() == Interrupted`, and
/// evaluates to the first result that is not such an error.
///
/// `$e` is evaluated again on every retry. `Poll` and `Interrupted` are resolved
/// at the call site.
macro_rules! with_interrupted {
    ($e:expr) => {
        loop {
            let polled = $e;
            if let Poll::Ready(Err(ref failure)) = polled {
                if failure.kind() == Interrupted {
                    continue;
                }
            }
            break polled;
        }
    };
}
58
59mod read_packet;
60mod socket;
61mod write_packet;
62
/// Tokio codec for MySQL packets.
///
/// Wraps `mysql_common`'s packet codec and decodes each complete packet into a
/// buffer obtained from `crate::BUFFER_POOL`.
#[derive(Debug)]
pub struct PacketCodec {
    /// The wrapped `mysql_common` codec, also reachable through `Deref`/`DerefMut`.
    inner: PacketCodecInner,
    /// Destination buffer handed to `inner.decode`; it is swapped for a fresh
    /// pooled buffer each time a packet is completed.
    decode_buf: PooledBuf,
}
68
69impl Default for PacketCodec {
70 fn default() -> Self {
71 Self {
72 inner: Default::default(),
73 decode_buf: crate::BUFFER_POOL.get(),
74 }
75 }
76}
77
78impl Deref for PacketCodec {
79 type Target = PacketCodecInner;
80
81 fn deref(&self) -> &Self::Target {
82 &self.inner
83 }
84}
85
86impl DerefMut for PacketCodec {
87 fn deref_mut(&mut self) -> &mut Self::Target {
88 &mut self.inner
89 }
90}
91
92impl Decoder for PacketCodec {
93 type Item = PooledBuf;
94 type Error = IoError;
95
96 fn decode(&mut self, src: &mut BytesMut) -> std::result::Result<Option<Self::Item>, IoError> {
97 if self.inner.decode(src, self.decode_buf.as_mut())? {
98 let new_buf = crate::BUFFER_POOL.get();
99 Ok(Some(replace(&mut self.decode_buf, new_buf)))
100 } else {
101 Ok(None)
102 }
103 }
104}
105
106impl Encoder<PooledBuf> for PacketCodec {
107 type Error = IoError;
108
109 fn encode(&mut self, item: PooledBuf, dst: &mut BytesMut) -> std::result::Result<(), IoError> {
110 Ok(self.inner.encode(&mut item.as_ref(), dst)?)
111 }
112}
113
/// The transport a MySQL connection runs over.
///
/// The stream type of `Secure` depends on the enabled TLS feature. Each feature
/// declares its own `Secure` variant, so only one of them can be enabled at a time.
#[pin_project(project = EndpointProj)]
#[derive(Debug)]
pub(crate) enum Endpoint {
    /// Unencrypted TCP. `None` means the stream is no longer present
    /// (presumably taken during a TLS upgrade — TODO confirm in `tls`). The I/O
    /// impls `unwrap` it, and `check`/`set_tcp_nodelay` treat it as unreachable.
    Plain(Option<TcpStream>),
    /// TLS over TCP via `tokio-native-tls`.
    #[cfg(feature = "native-tls-tls")]
    Secure(#[pin] tokio_native_tls::TlsStream<TcpStream>),
    /// TLS over TCP via `tokio-rustls`.
    #[cfg(feature = "rustls-tls")]
    Secure(#[pin] tokio_rustls::client::TlsStream<tokio::net::TcpStream>),
    /// TLS over TCP via the WasmEdge rustls bindings.
    #[cfg(feature = "wasmedge-tls")]
    Secure(#[pin] wasmedge_rustls_api::stream::async_stream::TlsStream<tokio::net::TcpStream>),
    /// Unix domain socket.
    #[cfg(unix)]
    Socket(#[pin] Socket),
}
127
/// Future that probes a `TcpStream` once, without blocking, for signs of a dead
/// connection.
///
/// Resolves to `Ok(())` if the stream is not read-ready or has no bytes available.
/// It resolves to an error in three cases:
/// - the peer closed the stream (a read returns 0 bytes);
/// - unexpected data is pending (up to one byte is read, and thereby consumed);
/// - readiness polling or the read itself fails.
#[derive(Debug)]
struct CheckTcpStream<'a>(&'a mut TcpStream);
133
134impl Future for CheckTcpStream<'_> {
135 type Output = io::Result<()>;
136 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
137 match self.0.poll_read_ready(cx) {
138 Poll::Ready(Ok(())) => {
139 let mut buf = [0_u8; 1];
141 match self.0.try_read(&mut buf) {
142 Ok(0) => Poll::Ready(Err(io::Error::new(BrokenPipe, "broken pipe"))),
143 Ok(_) => Poll::Ready(Err(io::Error::new(Other, "stream should be empty"))),
144 Err(err) if err.kind() == io::ErrorKind::WouldBlock => Poll::Ready(Ok(())),
145 Err(err) => Poll::Ready(Err(err)),
146 }
147 }
148 Poll::Ready(Err(err)) => Poll::Ready(Err(err)),
149 Poll::Pending => Poll::Ready(Ok(())),
150 }
151 }
152}
153
154impl Endpoint {
155 #[cfg(unix)]
156 fn is_socket(&self) -> bool {
157 matches!(self, Self::Socket(_))
158 }
159
160 async fn check(&mut self) -> std::result::Result<(), IoError> {
162 match self {
164 Endpoint::Plain(Some(stream)) => {
165 CheckTcpStream(stream).await?;
166 Ok(())
167 }
168 #[cfg(feature = "native-tls-tls")]
169 Endpoint::Secure(tls_stream) => {
170 CheckTcpStream(tls_stream.get_mut().get_mut().get_mut()).await?;
171 Ok(())
172 }
173 #[cfg(feature = "rustls-tls")]
174 Endpoint::Secure(tls_stream) => {
175 let stream = tls_stream.get_mut().0;
176 CheckTcpStream(stream).await?;
177 Ok(())
178 }
179 #[cfg(feature = "wasmedge-tls")]
180 Endpoint::Secure(tls_stream) => {
181 let stream = tls_stream.get_mut().0;
182 CheckTcpStream(stream).await?;
183 Ok(())
184 }
185 #[cfg(unix)]
186 Endpoint::Socket(socket) => {
187 let _ = socket.write(&[]).await?;
188 Ok(())
189 }
190 Endpoint::Plain(None) => unreachable!(),
191 }
192 }
193
194 #[cfg(any(
195 feature = "native-tls-tls",
196 feature = "rustls-tls",
197 feature = "wasmedge-tls"
198 ))]
199 pub fn is_secure(&self) -> bool {
200 matches!(self, Endpoint::Secure(_))
201 }
202
203 #[cfg(all(
204 not(feature = "native-tls"),
205 not(feature = "rustls"),
206 not(feature = "wasmedge-tls")
207 ))]
208 pub async fn make_secure(
209 &mut self,
210 _domain: String,
211 _ssl_opts: crate::SslOpts,
212 ) -> crate::error::Result<()> {
213 panic!(
214 "Client had asked for TLS connection but TLS support is disabled. \
215 Please enable one of the following features: [\"native-tls-tls\", \"rustls-tls\"]"
216 )
217 }
218
219 pub fn set_tcp_nodelay(&self, val: bool) -> io::Result<()> {
220 match *self {
221 Endpoint::Plain(Some(ref stream)) => stream.set_nodelay(val)?,
222 Endpoint::Plain(None) => unreachable!(),
223 #[cfg(feature = "native-tls-tls")]
224 Endpoint::Secure(ref stream) => {
225 stream.get_ref().get_ref().get_ref().set_nodelay(val)?
226 }
227 #[cfg(feature = "rustls-tls")]
228 Endpoint::Secure(ref stream) => {
229 let stream = stream.get_ref().0;
230 stream.set_nodelay(val)?;
231 }
232 #[cfg(feature = "wasmedge-tls")]
233 Endpoint::Secure(ref stream) => {
234 let stream = stream.get_ref().0;
235 stream.set_nodelay(val)?;
236 }
237 #[cfg(unix)]
238 Endpoint::Socket(_) => (),
239 }
240 Ok(())
241 }
242}
243
244impl From<TcpStream> for Endpoint {
245 fn from(stream: TcpStream) -> Self {
246 Endpoint::Plain(Some(stream))
247 }
248}
249
250#[cfg(unix)]
251impl From<Socket> for Endpoint {
252 fn from(socket: Socket) -> Self {
253 Endpoint::Socket(socket)
254 }
255}
256
#[cfg(feature = "native-tls-tls")]
impl From<tokio_native_tls::TlsStream<TcpStream>> for Endpoint {
    /// Wraps an established native-tls stream as a secure endpoint.
    fn from(tls: tokio_native_tls::TlsStream<TcpStream>) -> Self {
        Self::Secure(tls)
    }
}
263
/// Forwards reads to the wrapped transport.
///
/// A poll that fails with `ErrorKind::Interrupted` is retried via `with_interrupted!`.
impl AsyncRead for Endpoint {
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut ReadBuf<'_>,
    ) -> Poll<std::result::Result<(), tokio::io::Error>> {
        // Project once, outside the retry loop; the `ref mut` match on `this`
        // can then be repeated on every retry.
        let mut this = self.project();
        with_interrupted!(match this {
            // Panics if the plain stream is absent (`Plain(None)`).
            EndpointProj::Plain(ref mut stream) => {
                Pin::new(stream.as_mut().unwrap()).poll_read(cx, buf)
            }
            #[cfg(feature = "native-tls-tls")]
            EndpointProj::Secure(ref mut stream) => stream.as_mut().poll_read(cx, buf),
            #[cfg(feature = "rustls-tls")]
            EndpointProj::Secure(ref mut stream) => stream.as_mut().poll_read(cx, buf),
            #[cfg(feature = "wasmedge-tls")]
            EndpointProj::Secure(ref mut stream) => stream.as_mut().poll_read(cx, buf),
            #[cfg(unix)]
            EndpointProj::Socket(ref mut stream) => stream.as_mut().poll_read(cx, buf),
        })
    }
}
290
/// Forwards writes, flushes and shutdowns to the wrapped transport.
///
/// Each poll that fails with `ErrorKind::Interrupted` is retried via
/// `with_interrupted!`. Every method panics on `Plain(None)`, because the plain
/// stream is `unwrap`ped.
impl AsyncWrite for Endpoint {
    fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut Context,
        buf: &[u8],
    ) -> Poll<std::result::Result<usize, tokio::io::Error>> {
        // Project once; the `ref mut` match below is re-run on each retry.
        let mut this = self.project();
        with_interrupted!(match this {
            EndpointProj::Plain(ref mut stream) => {
                Pin::new(stream.as_mut().unwrap()).poll_write(cx, buf)
            }
            #[cfg(feature = "native-tls-tls")]
            EndpointProj::Secure(ref mut stream) => stream.as_mut().poll_write(cx, buf),
            #[cfg(feature = "rustls-tls")]
            EndpointProj::Secure(ref mut stream) => stream.as_mut().poll_write(cx, buf),
            #[cfg(feature = "wasmedge-tls")]
            EndpointProj::Secure(ref mut stream) => stream.as_mut().poll_write(cx, buf),
            #[cfg(unix)]
            EndpointProj::Socket(ref mut stream) => stream.as_mut().poll_write(cx, buf),
        })
    }

    fn poll_flush(
        self: Pin<&mut Self>,
        cx: &mut Context,
    ) -> Poll<std::result::Result<(), tokio::io::Error>> {
        // Project once; the `ref mut` match below is re-run on each retry.
        let mut this = self.project();
        with_interrupted!(match this {
            EndpointProj::Plain(ref mut stream) => {
                Pin::new(stream.as_mut().unwrap()).poll_flush(cx)
            }
            #[cfg(feature = "native-tls-tls")]
            EndpointProj::Secure(ref mut stream) => stream.as_mut().poll_flush(cx),
            #[cfg(feature = "rustls-tls")]
            EndpointProj::Secure(ref mut stream) => stream.as_mut().poll_flush(cx),
            #[cfg(feature = "wasmedge-tls")]
            EndpointProj::Secure(ref mut stream) => stream.as_mut().poll_flush(cx),
            #[cfg(unix)]
            EndpointProj::Socket(ref mut stream) => stream.as_mut().poll_flush(cx),
        })
    }

    fn poll_shutdown(
        self: Pin<&mut Self>,
        cx: &mut Context,
    ) -> Poll<std::result::Result<(), tokio::io::Error>> {
        // Project once; the `ref mut` match below is re-run on each retry.
        let mut this = self.project();
        with_interrupted!(match this {
            EndpointProj::Plain(ref mut stream) => {
                Pin::new(stream.as_mut().unwrap()).poll_shutdown(cx)
            }
            #[cfg(feature = "native-tls-tls")]
            EndpointProj::Secure(ref mut stream) => stream.as_mut().poll_shutdown(cx),
            #[cfg(feature = "rustls-tls")]
            EndpointProj::Secure(ref mut stream) => stream.as_mut().poll_shutdown(cx),
            #[cfg(feature = "wasmedge-tls")]
            EndpointProj::Secure(ref mut stream) => stream.as_mut().poll_shutdown(cx),
            #[cfg(unix)]
            EndpointProj::Socket(ref mut stream) => stream.as_mut().poll_shutdown(cx),
        })
    }
}
353
/// A MySQL connection stream: an [`Endpoint`] framed with a [`PacketCodec`].
pub struct Stream {
    /// Set to `true` by `close`; `poll_next` yields `None` while it is set.
    closed: bool,
    /// The framed endpoint. It is `None` while `make_secure` has taken it, and
    /// stays `None` if that TLS upgrade fails. Most accessors `unwrap` it.
    pub(crate) codec: Option<Box<Framed<Endpoint, PacketCodec>>>,
}
359
360impl fmt::Debug for Stream {
361 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
362 write!(
363 f,
364 "Stream (endpoint={:?})",
365 self.codec.as_ref().unwrap().get_ref()
366 )
367 }
368}
369
370impl Stream {
371 #[cfg(unix)]
372 fn new<T: Into<Endpoint>>(endpoint: T) -> Self {
373 let endpoint = endpoint.into();
374
375 Self {
376 closed: false,
377 codec: Box::new(Framed::new(endpoint, PacketCodec::default())).into(),
378 }
379 }
380
381 pub(crate) async fn connect_tcp(
382 addr: &HostPortOrUrl,
383 _keepalive: Option<Duration>,
384 ) -> io::Result<Stream> {
385 let tcp_stream = match addr {
386 HostPortOrUrl::HostPort(host, port) => {
387 TcpStream::connect((host.as_str(), *port)).await?
388 }
389 HostPortOrUrl::Url(url) => {
390 #[cfg(not(target_os = "wasi"))]
391 {
392 let addrs = url.socket_addrs(|| Some(DEFAULT_PORT))?;
393 TcpStream::connect(addrs).await?
394 }
395
396 #[cfg(target_os = "wasi")]
397 {
398 let addrs = (
399 url.host_str().expect("Unable to get host"),
400 url.port_or_known_default().expect("No port found in url"),
401 );
402 TcpStream::connect(addrs).await?
403 }
404 }
405 };
406 #[cfg(not(target_os = "wasi"))]
407 if let Some(duration) = keepalive {
408 #[cfg(unix)]
409 let socket = {
410 use std::os::unix::prelude::*;
411 let fd = tcp_stream.as_raw_fd();
412 unsafe { Socket2Socket::from_raw_fd(fd) }
413 };
414 #[cfg(windows)]
415 let socket = {
416 use std::os::windows::prelude::*;
417 let sock = tcp_stream.as_raw_socket();
418 unsafe { Socket2Socket::from_raw_socket(sock) }
419 };
420 socket.set_tcp_keepalive(&TcpKeepalive::new().with_time(duration))?;
421 std::mem::forget(socket);
422 }
423
424 Ok(Stream {
425 closed: false,
426 codec: Box::new(Framed::new(tcp_stream.into(), PacketCodec::default())).into(),
427 })
428 }
429
430 #[cfg(unix)]
431 pub(crate) async fn connect_socket<P: AsRef<Path>>(path: P) -> io::Result<Stream> {
432 Ok(Stream::new(Socket::new(path).await?))
433 }
434
435 pub(crate) fn set_tcp_nodelay(&self, val: bool) -> io::Result<()> {
436 self.codec.as_ref().unwrap().get_ref().set_tcp_nodelay(val)
437 }
438 pub(crate) async fn make_secure(
439 &mut self,
440 domain: String,
441 ssl_opts: crate::SslOpts,
442 ) -> crate::error::Result<()> {
443 use tokio_util::codec::FramedParts;
444
445 let codec = self.codec.take().unwrap();
446 let FramedParts { mut io, codec, .. } = codec.into_parts();
447 io.make_secure(domain, ssl_opts).await?;
448 let codec = Framed::new(io, codec);
449 self.codec = Some(Box::new(codec));
450 Ok(())
451 }
452
453 #[cfg(any(
454 feature = "native-tls-tls",
455 feature = "rustls-tls",
456 feature = "wasmedge-tls"
457 ))]
458 pub(crate) fn is_secure(&self) -> bool {
459 self.codec.as_ref().unwrap().get_ref().is_secure()
460 }
461
462 #[cfg(unix)]
463 pub(crate) fn is_socket(&self) -> bool {
464 self.codec.as_ref().unwrap().get_ref().is_socket()
465 }
466
467 pub(crate) fn reset_seq_id(&mut self) {
468 if let Some(codec) = self.codec.as_mut() {
469 codec.codec_mut().reset_seq_id();
470 }
471 }
472
473 pub(crate) fn sync_seq_id(&mut self) {
474 if let Some(codec) = self.codec.as_mut() {
475 codec.codec_mut().sync_seq_id();
476 }
477 }
478
479 pub(crate) fn set_max_allowed_packet(&mut self, max_allowed_packet: usize) {
480 if let Some(codec) = self.codec.as_mut() {
481 codec.codec_mut().max_allowed_packet = max_allowed_packet;
482 }
483 }
484
485 pub(crate) fn compress(&mut self, level: crate::Compression) {
486 if let Some(codec) = self.codec.as_mut() {
487 codec.codec_mut().compress(level);
488 }
489 }
490
491 pub(crate) async fn check(&mut self) -> std::result::Result<(), IoError> {
493 if let Some(codec) = self.codec.as_mut() {
494 codec.get_mut().check().await?;
495 }
496 Ok(())
497 }
498
499 pub(crate) async fn close(mut self) -> std::result::Result<(), IoError> {
500 self.closed = true;
501 if let Some(mut codec) = self.codec {
502 use futures_sink::Sink;
503 futures_util::future::poll_fn(|cx| match Pin::new(&mut *codec).poll_close(cx) {
504 Poll::Ready(Err(IoError::Io(err))) if err.kind() == NotConnected => {
505 Poll::Ready(Ok(()))
506 }
507 x => x,
508 })
509 .await?;
510 }
511 Ok(())
512 }
513}
514
515impl stream::Stream for Stream {
516 type Item = std::result::Result<PooledBuf, IoError>;
517
518 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
519 if !self.closed {
520 let item = ready!(Pin::new(self.codec.as_mut().unwrap()).poll_next(cx)).transpose()?;
521 Poll::Ready(Ok(item).transpose())
522 } else {
523 Poll::Ready(None)
524 }
525 }
526}
527
#[cfg(test)]
mod test {
    /// Connects over TCP with a 42 s keepalive and verifies that the keepalive
    /// time was applied to the underlying socket.
    #[cfg(unix)]
    #[tokio::test]
    async fn should_connect_with_keepalive() {
        use crate::{test_misc::get_opts, Conn};

        let opts = get_opts()
            .tcp_keepalive(Some(42_000_u32))
            .prefer_socket(false);
        let mut conn: Conn = Conn::new(opts).await.unwrap();
        let stream = conn.stream_mut().unwrap();
        let endpoint = stream.codec.as_mut().unwrap().get_ref();
        let stream = match endpoint {
            super::Endpoint::Plain(Some(stream)) => stream,
            #[cfg(feature = "rustls-tls")]
            super::Endpoint::Secure(tls_stream) => tls_stream.get_ref().0,
            #[cfg(feature = "wasmedge-tls")]
            super::Endpoint::Secure(tls_stream) => tls_stream.get_ref().0,
            #[cfg(feature = "native-tls")]
            super::Endpoint::Secure(tls_stream) => tls_stream.get_ref().get_ref().get_ref(),
            _ => unreachable!(),
        };
        // `ManuallyDrop` keeps the borrowed descriptor open even if the assertion
        // below panics; the connection remains its only owner.
        let sock = std::mem::ManuallyDrop::new(unsafe {
            use std::os::unix::prelude::*;
            // SAFETY: `raw` is an open socket owned by `stream`, which outlives
            // `sock`; `ManuallyDrop` ensures it is never closed through `sock`.
            let raw = stream.as_raw_fd();
            socket2::Socket::from_raw_fd(raw)
        });

        assert_eq!(
            sock.keepalive_time().unwrap(),
            std::time::Duration::from_millis(42_000),
        );

        conn.disconnect().await.unwrap();
    }
}