1use std::io::IoSlice;
16use std::pin::Pin;
17use std::sync::Arc;
18use std::task::{Context, Poll};
19
20use pin_project_lite::pin_project;
21use tokio::io::{self, AsyncRead, AsyncWrite, ReadBuf, ReadHalf, WriteHalf};
22use tokio_rustls::server::TlsStream;
23use tokio_rustls::{rustls::ServerConfig, TlsAcceptor};
24
25use crate::commands::ClientHandshake;
26use crate::myc::constants::CapabilityFlags;
27use crate::packet_reader::PacketReader;
28use crate::packet_writer::PacketWriter;
29use crate::{AsyncMysqlIntermediary, AsyncMysqlShim, IntermediaryOptions};
30
31pub async fn plain_run_with_options<B, R, W>(
32 shim: B,
33 writer: W,
34 opts: IntermediaryOptions,
35 init_params: (ClientHandshake, u8, CapabilityFlags, PacketReader<R>),
36) -> Result<(), B::Error>
37where
38 B: AsyncMysqlShim<W> + Send + Sync,
39 R: AsyncRead + Send + Unpin,
40 W: AsyncWrite + Send + Unpin,
41{
42 let (handshake, seq, client_capabilities, reader) = init_params;
43 let reader = PacketReader::new(reader);
44 let writer = PacketWriter::new(writer);
45
46 let process_use_statement_on_query = opts.process_use_statement_on_query;
47 let reject_connection_on_dbname_absence = opts.reject_connection_on_dbname_absence;
48 let mut mi = AsyncMysqlIntermediary {
49 client_capabilities,
50 process_use_statement_on_query,
51 reject_connection_on_dbname_absence,
52 shim,
53 reader,
54 writer,
55 };
56 mi.init_after_ssl(handshake, seq).await?;
57 mi.run().await
58}
59
60pub async fn secure_run_with_options<B, R, W>(
61 shim: B,
62 writer: W,
63 opts: IntermediaryOptions,
64 tls_config: Arc<ServerConfig>,
65 init_params: (ClientHandshake, u8, CapabilityFlags, PacketReader<R>),
66) -> Result<(), B::Error>
67where
68 B: AsyncMysqlShim<WriteHalf<TlsStream<Duplex<PacketReader<R>, W>>>> + Send + Sync,
69 R: AsyncRead + Send + Unpin,
70 W: AsyncWrite + Send + Unpin,
71{
72 let (handshake, seq, client_capabilities, reader) = init_params;
73 let (reader, writer) = switch_to_tls(tls_config, reader, writer).await?;
74 let reader = PacketReader::new(reader);
75 let writer = PacketWriter::new(writer);
76
77 let process_use_statement_on_query = opts.process_use_statement_on_query;
78 let reject_connection_on_dbname_absence = opts.reject_connection_on_dbname_absence;
79 let mut mi = AsyncMysqlIntermediary {
80 client_capabilities,
81 process_use_statement_on_query,
82 reject_connection_on_dbname_absence,
83 shim,
84 reader,
85 writer,
86 };
87 mi.init_after_ssl(handshake, seq).await?;
88 mi.run().await
89}
90
91pub async fn switch_to_tls<R: AsyncRead + Send + Unpin, W: AsyncWrite + Send + Unpin>(
92 config: Arc<ServerConfig>,
93 reader: R,
94 writer: W,
95) -> std::io::Result<(
96 ReadHalf<TlsStream<Duplex<R, W>>>,
97 WriteHalf<TlsStream<Duplex<R, W>>>,
98)> {
99 let stream = Duplex::new(reader, writer);
100 let acceptor = TlsAcceptor::from(config);
101 let stream = acceptor.accept(stream).await?;
102 let (r, w) = tokio::io::split(stream);
103 Ok((r, w))
104}
105
pin_project! {
    /// Joins a separate reader and writer into a single stream value.
    ///
    /// Reads are forwarded to `reader`. Writes, including flush, shutdown and
    /// vectored writes, are forwarded to `writer`. `switch_to_tls` uses this to
    /// hand the TLS acceptor one stream built from separate read and write sides.
    #[derive(Clone, Debug)]
    pub struct Duplex<R, W> {
        // Target of every `AsyncRead` call on the `Duplex`.
        #[pin]
        reader: R,
        // Target of every `AsyncWrite` call on the `Duplex`.
        #[pin]
        writer: W,
    }
}
115
116impl<R, W> Duplex<R, W> {
117 pub fn new(reader: R, writer: W) -> Self {
118 Self { reader, writer }
119 }
120}
121
122impl<R: AsyncRead, W> AsyncRead for Duplex<R, W> {
123 fn poll_read(
124 self: Pin<&mut Self>,
125 cx: &mut Context<'_>,
126 buf: &mut ReadBuf<'_>,
127 ) -> Poll<io::Result<()>> {
128 AsyncRead::poll_read(self.project().reader, cx, buf)
129 }
130}
131
132impl<R, W: AsyncWrite> AsyncWrite for Duplex<R, W> {
133 fn poll_write(
134 self: Pin<&mut Self>,
135 cx: &mut Context<'_>,
136 buf: &[u8],
137 ) -> Poll<io::Result<usize>> {
138 AsyncWrite::poll_write(self.project().writer, cx, buf)
139 }
140
141 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
142 AsyncWrite::poll_flush(self.project().writer, cx)
143 }
144
145 fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
146 AsyncWrite::poll_shutdown(self.project().writer, cx)
147 }
148
149 fn poll_write_vectored(
150 self: Pin<&mut Self>,
151 cx: &mut Context<'_>,
152 bufs: &[IoSlice<'_>],
153 ) -> Poll<Result<usize, io::Error>> {
154 AsyncWrite::poll_write_vectored(self.project().writer, cx, bufs)
155 }
156
157 fn is_write_vectored(&self) -> bool {
158 AsyncWrite::is_write_vectored(&self.writer)
159 }
160}