opensrv_mysql/
tls.rs

1// Copyright 2021 Datafuse Labs.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::io::IoSlice;
16use std::pin::Pin;
17use std::sync::Arc;
18use std::task::{Context, Poll};
19
20use pin_project_lite::pin_project;
21use tokio::io::{self, AsyncRead, AsyncWrite, ReadBuf, ReadHalf, WriteHalf};
22use tokio_rustls::server::TlsStream;
23use tokio_rustls::{rustls::ServerConfig, TlsAcceptor};
24
25use crate::commands::ClientHandshake;
26use crate::myc::constants::CapabilityFlags;
27use crate::packet_reader::PacketReader;
28use crate::packet_writer::PacketWriter;
29use crate::{AsyncMysqlIntermediary, AsyncMysqlShim, IntermediaryOptions};
30
31pub async fn plain_run_with_options<B, R, W>(
32    shim: B,
33    writer: W,
34    opts: IntermediaryOptions,
35    init_params: (ClientHandshake, u8, CapabilityFlags, PacketReader<R>),
36) -> Result<(), B::Error>
37where
38    B: AsyncMysqlShim<W> + Send + Sync,
39    R: AsyncRead + Send + Unpin,
40    W: AsyncWrite + Send + Unpin,
41{
42    let (handshake, seq, client_capabilities, reader) = init_params;
43    let reader = PacketReader::new(reader);
44    let writer = PacketWriter::new(writer);
45
46    let process_use_statement_on_query = opts.process_use_statement_on_query;
47    let reject_connection_on_dbname_absence = opts.reject_connection_on_dbname_absence;
48    let mut mi = AsyncMysqlIntermediary {
49        client_capabilities,
50        process_use_statement_on_query,
51        reject_connection_on_dbname_absence,
52        shim,
53        reader,
54        writer,
55    };
56    mi.init_after_ssl(handshake, seq).await?;
57    mi.run().await
58}
59
60pub async fn secure_run_with_options<B, R, W>(
61    shim: B,
62    writer: W,
63    opts: IntermediaryOptions,
64    tls_config: Arc<ServerConfig>,
65    init_params: (ClientHandshake, u8, CapabilityFlags, PacketReader<R>),
66) -> Result<(), B::Error>
67where
68    B: AsyncMysqlShim<WriteHalf<TlsStream<Duplex<PacketReader<R>, W>>>> + Send + Sync,
69    R: AsyncRead + Send + Unpin,
70    W: AsyncWrite + Send + Unpin,
71{
72    let (handshake, seq, client_capabilities, reader) = init_params;
73    let (reader, writer) = switch_to_tls(tls_config, reader, writer).await?;
74    let reader = PacketReader::new(reader);
75    let writer = PacketWriter::new(writer);
76
77    let process_use_statement_on_query = opts.process_use_statement_on_query;
78    let reject_connection_on_dbname_absence = opts.reject_connection_on_dbname_absence;
79    let mut mi = AsyncMysqlIntermediary {
80        client_capabilities,
81        process_use_statement_on_query,
82        reject_connection_on_dbname_absence,
83        shim,
84        reader,
85        writer,
86    };
87    mi.init_after_ssl(handshake, seq).await?;
88    mi.run().await
89}
90
91pub async fn switch_to_tls<R: AsyncRead + Send + Unpin, W: AsyncWrite + Send + Unpin>(
92    config: Arc<ServerConfig>,
93    reader: R,
94    writer: W,
95) -> std::io::Result<(
96    ReadHalf<TlsStream<Duplex<R, W>>>,
97    WriteHalf<TlsStream<Duplex<R, W>>>,
98)> {
99    let stream = Duplex::new(reader, writer);
100    let acceptor = TlsAcceptor::from(config);
101    let stream = acceptor.accept(stream).await?;
102    let (r, w) = tokio::io::split(stream);
103    Ok((r, w))
104}
105
106pin_project! {
107    #[derive(Clone, Debug)]
108    pub struct Duplex<R, W> {
109        #[pin]
110        reader: R,
111        #[pin]
112        writer: W,
113    }
114}
115
116impl<R, W> Duplex<R, W> {
117    pub fn new(reader: R, writer: W) -> Self {
118        Self { reader, writer }
119    }
120}
121
122impl<R: AsyncRead, W> AsyncRead for Duplex<R, W> {
123    fn poll_read(
124        self: Pin<&mut Self>,
125        cx: &mut Context<'_>,
126        buf: &mut ReadBuf<'_>,
127    ) -> Poll<io::Result<()>> {
128        AsyncRead::poll_read(self.project().reader, cx, buf)
129    }
130}
131
132impl<R, W: AsyncWrite> AsyncWrite for Duplex<R, W> {
133    fn poll_write(
134        self: Pin<&mut Self>,
135        cx: &mut Context<'_>,
136        buf: &[u8],
137    ) -> Poll<io::Result<usize>> {
138        AsyncWrite::poll_write(self.project().writer, cx, buf)
139    }
140
141    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
142        AsyncWrite::poll_flush(self.project().writer, cx)
143    }
144
145    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
146        AsyncWrite::poll_shutdown(self.project().writer, cx)
147    }
148
149    fn poll_write_vectored(
150        self: Pin<&mut Self>,
151        cx: &mut Context<'_>,
152        bufs: &[IoSlice<'_>],
153    ) -> Poll<Result<usize, io::Error>> {
154        AsyncWrite::poll_write_vectored(self.project().writer, cx, bufs)
155    }
156
157    fn is_write_vectored(&self) -> bool {
158        AsyncWrite::is_write_vectored(&self.writer)
159    }
160}