datafusion_table_providers/sql/db_connection_pool/mysqlpool.rs

1use std::{collections::HashMap, path::PathBuf, sync::Arc};
2
3use async_trait::async_trait;
4use mysql_async::{
5    prelude::{Queryable, ToValue},
6    DriverError, Metrics, Opts, Params, PoolConstraints, PoolOpts, Row, SslOpts,
7    DEFAULT_POOL_CONSTRAINTS,
8};
9use secrecy::{ExposeSecret, SecretBox, SecretString};
10use snafu::{ResultExt, Snafu};
11
12use crate::{
13    sql::db_connection_pool::{
14        dbconnection::{mysqlconn::MySQLConnection, AsyncDbConnection, DbConnection},
15        JoinPushDown,
16    },
17    util::{self, ns_lookup::verify_ns_lookup_and_tcp_connect},
18};
19
20use super::DbConnectionPool;
21
/// Result type used throughout this module; the error type defaults to this module's [`Error`].
pub type Result<T, E = Error> = std::result::Result<T, E>;

/// Errors that can occur while configuring, verifying, or using a [`MySQLConnectionPool`].
#[derive(Debug, Snafu)]
pub enum Error {
    /// A `mysql_async` operation failed, e.g. checking a connection out of the pool or running
    /// the `SELECT 1` connection test, and no more specific variant applies.
    #[snafu(display("MySQL connection failed.\n{source}\nFor further information, refer to the MySQL manual: https://dev.mysql.com/doc/mysql-errors/9.1/en/error-reference-introduction.html"))]
    MySQLConnectionError { source: mysql_async::Error },

    /// The `connection_string` parameter could not be parsed as a MySQL URL.
    #[snafu(display(
        "Invalid MySQL connection string.\n{source}\nEnsure the MySQL connection string is valid"
    ))]
    InvalidConnectionString { source: mysql_async::UrlError },

    /// A parameter had a value outside its accepted set (in this file: `sslmode`).
    #[snafu(display("Invalid value for parameter {parameter_name}\nEnsure the value is valid for parameter {parameter_name}"))]
    InvalidParameterError { parameter_name: String },

    /// The `sslrootcert` parameter points to a path that does not exist.
    #[snafu(display("Invalid root cert path: {path}\nEnsure the root cert path is valid"))]
    InvalidRootCertPathError { path: String },

    /// The DNS lookup or TCP connect to the configured host/port failed before any MySQL
    /// handshake was attempted.
    #[snafu(display("Cannot connect to MySQL on {host}:{port}. Ensure the host and port are correct and reachable."))]
    InvalidHostOrPortError {
        source: crate::util::ns_lookup::Error,
        host: String,
        port: u16,
    },

    /// The server rejected the credentials (server error 1045), or the driver reported an
    /// unknown authentication plugin, which this module treats as a bad user name.
    #[snafu(display("Authentication failed. Verify username and password."))]
    InvalidUsernameOrPassword,

    /// The requested database does not exist (server error 1049); `message` is the server's
    /// own error message.
    #[snafu(display("{message}\nEnsure the given MySQL database exists"))]
    UnknownMySQLDatabase { message: String },
}
53
/// A pool of MySQL connections backed by [`mysql_async::Pool`].
///
/// Cloning shares the same underlying pool through an [`Arc`]; it does not open a new pool.
#[derive(Debug, Clone)]
pub struct MySQLConnectionPool {
    /// Shared handle to the underlying `mysql_async` connection pool.
    pool: Arc<mysql_async::Pool>,
    /// Join push-down context built from the connection options (host, port, and, when set,
    /// database and user) when the pool is created.
    join_push_down: JoinPushDown,
}
59
/// Session statements handed to `OptsBuilder::setup` when the pool is created, so that
/// `mysql_async` runs them after establishing a connection.
///
/// They pin the session time zone to UTC (`+00:00`) and use `utf8mb4` for results, client input,
/// and the connection character set.
const SETUP_QUERIES: [&str; 4] = [
    "SET time_zone = '+00:00'",
    "SET character_set_results = 'utf8mb4'",
    "SET character_set_client = 'utf8mb4'",
    "SET character_set_connection = 'utf8mb4'",
];
66
67impl MySQLConnectionPool {
68    /// Creates a new instance of `MySQLConnectionPool`.
69    ///
70    /// # Arguments
71    ///
72    /// * `params` - A map of parameters to create the connection pool.
73    ///   * `connection_string` - The connection string to use to connect to the MySQL database, or can be specified with the below individual parameters.
74    ///   * `host` - The host of the MySQL database.
75    ///   * `user` - The user to use when connecting to the MySQL database.
76    ///   * `db` - The database to connect to.
77    ///   * `pass` - The password to use when connecting to the MySQL database.
78    ///   * `tcp_port` - The TCP port to use when connecting to the MySQL database.
79    ///   * `sslmode` - The SSL mode to use when connecting to the MySQL database. Can be "disabled", "required", or "preferred".
80    ///   * `sslrootcert` - The path to the root certificate to use when connecting to the MySQL database.
81    ///   * `pool_min` - The minimum number of connections to keep open in the pool, lazily created when requested.
82    ///   * `pool_max` - The maximum number of connections to allow in the pool.
83    ///
84    /// # Errors
85    ///
86    /// Returns an error if there is a problem creating the connection pool.
87    #[allow(clippy::unused_async)]
88    pub async fn new(params: HashMap<String, SecretString>) -> Result<Self> {
89        // Remove the "mysql_" prefix from the keys to keep backward compatibility
90        let params = util::remove_prefix_from_hashmap_keys(params, "mysql_");
91
92        let mut connection_string = mysql_async::OptsBuilder::default();
93        let mut ssl_mode = "required";
94        let mut ssl_rootcert_path: Option<PathBuf> = None;
95
96        if let Some(mysql_connection_string) = params
97            .get("connection_string")
98            .map(SecretBox::expose_secret)
99        {
100            connection_string = mysql_async::OptsBuilder::from_opts(
101                mysql_async::Opts::from_url(mysql_connection_string)
102                    .context(InvalidConnectionStringSnafu)?,
103            );
104        } else {
105            if let Some(mysql_host) = params.get("host").map(SecretBox::expose_secret) {
106                connection_string = connection_string.ip_or_hostname(mysql_host);
107            }
108            if let Some(mysql_user) = params.get("user").map(SecretBox::expose_secret) {
109                connection_string = connection_string.user(Some(mysql_user));
110            }
111            if let Some(mysql_db) = params.get("db").map(SecretBox::expose_secret) {
112                connection_string = connection_string.db_name(Some(mysql_db));
113            }
114            if let Some(mysql_pass) = params.get("pass").map(SecretBox::expose_secret) {
115                connection_string = connection_string.pass(Some(mysql_pass));
116            }
117            if let Some(mysql_tcp_port) = params.get("tcp_port").map(SecretBox::expose_secret) {
118                connection_string =
119                    connection_string.tcp_port(mysql_tcp_port.parse::<u16>().unwrap_or(3306));
120            }
121
122            let pool_min = params
123                .get("pool_min")
124                .map(SecretBox::expose_secret)
125                .unwrap_or_default()
126                .parse::<usize>()
127                .unwrap_or(DEFAULT_POOL_CONSTRAINTS.min());
128
129            let pool_max = params
130                .get("pool_max")
131                .map(SecretBox::expose_secret)
132                .unwrap_or_default()
133                .parse::<usize>()
134                .unwrap_or(DEFAULT_POOL_CONSTRAINTS.max());
135
136            connection_string = connection_string
137                .pool_opts(PoolOpts::default().with_constraints(
138                    PoolConstraints::new(pool_min, pool_max).unwrap_or_default(),
139                ));
140        }
141
142        if let Some(mysql_sslmode) = params.get("sslmode").map(SecretBox::expose_secret) {
143            match mysql_sslmode.to_lowercase().as_str() {
144                "disabled" | "required" | "preferred" => {
145                    ssl_mode = mysql_sslmode;
146                }
147                _ => {
148                    InvalidParameterSnafu {
149                        parameter_name: "sslmode".to_string(),
150                    }
151                    .fail()?;
152                }
153            }
154        }
155        if let Some(mysql_sslrootcert) = params.get("sslrootcert").map(SecretBox::expose_secret) {
156            if !std::path::Path::new(mysql_sslrootcert).exists() {
157                InvalidRootCertPathSnafu {
158                    path: mysql_sslrootcert,
159                }
160                .fail()?;
161            }
162
163            ssl_rootcert_path = Some(PathBuf::from(mysql_sslrootcert));
164        }
165
166        let ssl_opts = get_ssl_opts(ssl_mode, ssl_rootcert_path);
167
168        connection_string = connection_string.ssl_opts(ssl_opts);
169
170        connection_string = connection_string.setup(SETUP_QUERIES.to_vec());
171
172        let opts = mysql_async::Opts::from(connection_string);
173
174        verify_mysql_opts(&opts).await?;
175
176        let join_push_down = get_join_context(&opts);
177
178        let pool = mysql_async::Pool::new(opts);
179
180        // Test the connection
181        let mut conn = pool.get_conn().await.map_err(|err| match err {
182            // In case of an incorrect user name, the error `Unknown authentication plugin 'sha256_password'` is returned.
183            // We override it with a more user-friendly error message.
184            mysql_async::Error::Driver(DriverError::UnknownAuthPlugin { .. }) => {
185                Error::InvalidUsernameOrPassword
186            }
187            mysql_async::Error::Server(server_error) => {
188                match server_error.code {
189                    // Code 1049: Server error: `ERROR 42000 (1049): Unknown database <database>
190                    1049 => Error::UnknownMySQLDatabase {
191                        message: server_error.message,
192                    },
193                    // Code 1045: Server error: ERROR 1045 (28000): Access denied for user <user> (using password: YES / NO)
194                    1045 => Error::InvalidUsernameOrPassword,
195                    _ => Error::MySQLConnectionError {
196                        source: mysql_async::Error::Server(server_error),
197                    },
198                }
199            }
200            _ => Error::MySQLConnectionError { source: err },
201        })?;
202
203        let _rows: Vec<Row> = conn
204            .exec("SELECT 1", Params::Empty)
205            .await
206            .context(MySQLConnectionSnafu)?;
207
208        Ok(Self {
209            pool: Arc::new(pool),
210            join_push_down,
211        })
212    }
213
214    /// Returns a direct connection to the underlying database.
215    ///
216    /// # Errors
217    ///
218    /// Returns an error if there is a problem creating the connection pool.
219    pub async fn connect_direct(&self) -> super::Result<MySQLConnection> {
220        let pool = Arc::clone(&self.pool);
221        let conn = pool.get_conn().await.context(MySQLConnectionSnafu)?;
222
223        Ok(MySQLConnection::new(conn))
224    }
225
226    pub fn metrics(&self) -> Arc<Metrics> {
227        self.pool.metrics()
228    }
229}
230
231async fn verify_mysql_opts(opts: &Opts) -> Result<()> {
232    // Verify the host and port are correct
233    let host = opts.ip_or_hostname();
234    let port = opts.tcp_port();
235
236    verify_ns_lookup_and_tcp_connect(host, port)
237        .await
238        .context(InvalidHostOrPortSnafu { host, port })?;
239
240    Ok(())
241}
242
243fn get_join_context(opts: &mysql_async::Opts) -> JoinPushDown {
244    let mut join_context = format!("host={},port={}", opts.ip_or_hostname(), opts.tcp_port());
245    if let Some(db_name) = opts.db_name() {
246        join_context.push_str(&format!(",db={db_name}"));
247    }
248    if let Some(user) = opts.user() {
249        join_context.push_str(&format!(",user={user}"));
250    }
251
252    JoinPushDown::AllowedFor(join_context)
253}
254
255fn get_ssl_opts(ssl_mode: &str, rootcert_path: Option<PathBuf>) -> Option<SslOpts> {
256    if ssl_mode == "disabled" {
257        return None;
258    }
259
260    let mut opts = SslOpts::default();
261
262    if let Some(rootcert_path) = rootcert_path {
263        let path = rootcert_path;
264        opts = opts.with_root_certs(vec![path.into()]);
265    }
266
267    // If ssl_mode is "preferred", we will accept invalid certs and skip domain validation
268    // mysql_async does not have a "ssl_mode" https://github.com/blackbeam/mysql_async/issues/225#issuecomment-1409922237
269    if ssl_mode == "preferred" {
270        opts = opts
271            .with_danger_accept_invalid_certs(true)
272            .with_danger_skip_domain_validation(true);
273    }
274
275    Some(opts)
276}
277
278#[async_trait]
279impl DbConnectionPool<mysql_async::Conn, &'static (dyn ToValue + Sync)> for MySQLConnectionPool {
280    async fn connect(
281        &self,
282    ) -> super::Result<Box<dyn DbConnection<mysql_async::Conn, &'static (dyn ToValue + Sync)>>>
283    {
284        let pool = Arc::clone(&self.pool);
285        let conn = pool.get_conn().await.context(MySQLConnectionSnafu)?;
286
287        Ok(Box::new(MySQLConnection::new(conn)))
288    }
289
290    fn join_push_down(&self) -> JoinPushDown {
291        self.join_push_down.clone()
292    }
293}