datafusion_table_providers/sql/db_connection_pool/
mysqlpool.rs1use std::{collections::HashMap, path::PathBuf, sync::Arc};
2
3use async_trait::async_trait;
4use mysql_async::{
5 prelude::{Queryable, ToValue},
6 DriverError, Metrics, Opts, Params, PoolConstraints, PoolOpts, Row, SslOpts,
7 DEFAULT_POOL_CONSTRAINTS,
8};
9use secrecy::{ExposeSecret, SecretBox, SecretString};
10use snafu::{ResultExt, Snafu};
11
12use crate::{
13 sql::db_connection_pool::{
14 dbconnection::{mysqlconn::MySQLConnection, AsyncDbConnection, DbConnection},
15 JoinPushDown,
16 },
17 util::{self, ns_lookup::verify_ns_lookup_and_tcp_connect},
18};
19
20use super::DbConnectionPool;
21
22pub type Result<T, E = Error> = std::result::Result<T, E>;
23
24#[derive(Debug, Snafu)]
25pub enum Error {
26 #[snafu(display("MySQL connection failed.\n{source}\nFor further information, refer to the MySQL manual: https://dev.mysql.com/doc/mysql-errors/9.1/en/error-reference-introduction.html"))]
27 MySQLConnectionError { source: mysql_async::Error },
28
29 #[snafu(display(
30 "Invalid MySQL connection string.\n{source}\nEnsure the MySQL connection string is valid"
31 ))]
32 InvalidConnectionString { source: mysql_async::UrlError },
33
34 #[snafu(display("Invalid value for parameter {parameter_name}\nEnsure the value is valid for parameter {parameter_name}"))]
35 InvalidParameterError { parameter_name: String },
36
37 #[snafu(display("Invalid root cert path: {path}\nEnsure the root cert path is valid"))]
38 InvalidRootCertPathError { path: String },
39
40 #[snafu(display("Cannot connect to MySQL on {host}:{port}. Ensure the host and port are correct and reachable."))]
41 InvalidHostOrPortError {
42 source: crate::util::ns_lookup::Error,
43 host: String,
44 port: u16,
45 },
46
47 #[snafu(display("Authentication failed. Verify username and password."))]
48 InvalidUsernameOrPassword,
49
50 #[snafu(display("{message}\nEnsure the given MySQL database exists"))]
51 UnknownMySQLDatabase { message: String },
52}
53
54#[derive(Debug, Clone)]
55pub struct MySQLConnectionPool {
56 pool: Arc<mysql_async::Pool>,
57 join_push_down: JoinPushDown,
58}
59
/// Session statements registered with `OptsBuilder::setup` when the pool is built, so every
/// connection starts with the same time zone and character-set configuration.
const SETUP_QUERIES: [&str; 4] = [
    // Session time zone pinned to a UTC offset.
    "SET time_zone = '+00:00'",
    // utf8mb4 for result sets, client statements, and connection-level conversions.
    "SET character_set_results = 'utf8mb4'",
    "SET character_set_client = 'utf8mb4'",
    "SET character_set_connection = 'utf8mb4'",
];
66
67impl MySQLConnectionPool {
68 #[allow(clippy::unused_async)]
88 pub async fn new(params: HashMap<String, SecretString>) -> Result<Self> {
89 let params = util::remove_prefix_from_hashmap_keys(params, "mysql_");
91
92 let mut connection_string = mysql_async::OptsBuilder::default();
93 let mut ssl_mode = "required";
94 let mut ssl_rootcert_path: Option<PathBuf> = None;
95
96 if let Some(mysql_connection_string) = params
97 .get("connection_string")
98 .map(SecretBox::expose_secret)
99 {
100 connection_string = mysql_async::OptsBuilder::from_opts(
101 mysql_async::Opts::from_url(mysql_connection_string)
102 .context(InvalidConnectionStringSnafu)?,
103 );
104 } else {
105 if let Some(mysql_host) = params.get("host").map(SecretBox::expose_secret) {
106 connection_string = connection_string.ip_or_hostname(mysql_host);
107 }
108 if let Some(mysql_user) = params.get("user").map(SecretBox::expose_secret) {
109 connection_string = connection_string.user(Some(mysql_user));
110 }
111 if let Some(mysql_db) = params.get("db").map(SecretBox::expose_secret) {
112 connection_string = connection_string.db_name(Some(mysql_db));
113 }
114 if let Some(mysql_pass) = params.get("pass").map(SecretBox::expose_secret) {
115 connection_string = connection_string.pass(Some(mysql_pass));
116 }
117 if let Some(mysql_tcp_port) = params.get("tcp_port").map(SecretBox::expose_secret) {
118 connection_string =
119 connection_string.tcp_port(mysql_tcp_port.parse::<u16>().unwrap_or(3306));
120 }
121
122 let pool_min = params
123 .get("pool_min")
124 .map(SecretBox::expose_secret)
125 .unwrap_or_default()
126 .parse::<usize>()
127 .unwrap_or(DEFAULT_POOL_CONSTRAINTS.min());
128
129 let pool_max = params
130 .get("pool_max")
131 .map(SecretBox::expose_secret)
132 .unwrap_or_default()
133 .parse::<usize>()
134 .unwrap_or(DEFAULT_POOL_CONSTRAINTS.max());
135
136 connection_string = connection_string
137 .pool_opts(PoolOpts::default().with_constraints(
138 PoolConstraints::new(pool_min, pool_max).unwrap_or_default(),
139 ));
140 }
141
142 if let Some(mysql_sslmode) = params.get("sslmode").map(SecretBox::expose_secret) {
143 match mysql_sslmode.to_lowercase().as_str() {
144 "disabled" | "required" | "preferred" => {
145 ssl_mode = mysql_sslmode;
146 }
147 _ => {
148 InvalidParameterSnafu {
149 parameter_name: "sslmode".to_string(),
150 }
151 .fail()?;
152 }
153 }
154 }
155 if let Some(mysql_sslrootcert) = params.get("sslrootcert").map(SecretBox::expose_secret) {
156 if !std::path::Path::new(mysql_sslrootcert).exists() {
157 InvalidRootCertPathSnafu {
158 path: mysql_sslrootcert,
159 }
160 .fail()?;
161 }
162
163 ssl_rootcert_path = Some(PathBuf::from(mysql_sslrootcert));
164 }
165
166 let ssl_opts = get_ssl_opts(ssl_mode, ssl_rootcert_path);
167
168 connection_string = connection_string.ssl_opts(ssl_opts);
169
170 connection_string = connection_string.setup(SETUP_QUERIES.to_vec());
171
172 let opts = mysql_async::Opts::from(connection_string);
173
174 verify_mysql_opts(&opts).await?;
175
176 let join_push_down = get_join_context(&opts);
177
178 let pool = mysql_async::Pool::new(opts);
179
180 let mut conn = pool.get_conn().await.map_err(|err| match err {
182 mysql_async::Error::Driver(DriverError::UnknownAuthPlugin { .. }) => {
185 Error::InvalidUsernameOrPassword
186 }
187 mysql_async::Error::Server(server_error) => {
188 match server_error.code {
189 1049 => Error::UnknownMySQLDatabase {
191 message: server_error.message,
192 },
193 1045 => Error::InvalidUsernameOrPassword,
195 _ => Error::MySQLConnectionError {
196 source: mysql_async::Error::Server(server_error),
197 },
198 }
199 }
200 _ => Error::MySQLConnectionError { source: err },
201 })?;
202
203 let _rows: Vec<Row> = conn
204 .exec("SELECT 1", Params::Empty)
205 .await
206 .context(MySQLConnectionSnafu)?;
207
208 Ok(Self {
209 pool: Arc::new(pool),
210 join_push_down,
211 })
212 }
213
214 pub async fn connect_direct(&self) -> super::Result<MySQLConnection> {
220 let pool = Arc::clone(&self.pool);
221 let conn = pool.get_conn().await.context(MySQLConnectionSnafu)?;
222
223 Ok(MySQLConnection::new(conn))
224 }
225
226 pub fn metrics(&self) -> Arc<Metrics> {
227 self.pool.metrics()
228 }
229}
230
231async fn verify_mysql_opts(opts: &Opts) -> Result<()> {
232 let host = opts.ip_or_hostname();
234 let port = opts.tcp_port();
235
236 verify_ns_lookup_and_tcp_connect(host, port)
237 .await
238 .context(InvalidHostOrPortSnafu { host, port })?;
239
240 Ok(())
241}
242
243fn get_join_context(opts: &mysql_async::Opts) -> JoinPushDown {
244 let mut join_context = format!("host={},port={}", opts.ip_or_hostname(), opts.tcp_port());
245 if let Some(db_name) = opts.db_name() {
246 join_context.push_str(&format!(",db={db_name}"));
247 }
248 if let Some(user) = opts.user() {
249 join_context.push_str(&format!(",user={user}"));
250 }
251
252 JoinPushDown::AllowedFor(join_context)
253}
254
255fn get_ssl_opts(ssl_mode: &str, rootcert_path: Option<PathBuf>) -> Option<SslOpts> {
256 if ssl_mode == "disabled" {
257 return None;
258 }
259
260 let mut opts = SslOpts::default();
261
262 if let Some(rootcert_path) = rootcert_path {
263 let path = rootcert_path;
264 opts = opts.with_root_certs(vec![path.into()]);
265 }
266
267 if ssl_mode == "preferred" {
270 opts = opts
271 .with_danger_accept_invalid_certs(true)
272 .with_danger_skip_domain_validation(true);
273 }
274
275 Some(opts)
276}
277
278#[async_trait]
279impl DbConnectionPool<mysql_async::Conn, &'static (dyn ToValue + Sync)> for MySQLConnectionPool {
280 async fn connect(
281 &self,
282 ) -> super::Result<Box<dyn DbConnection<mysql_async::Conn, &'static (dyn ToValue + Sync)>>>
283 {
284 let pool = Arc::clone(&self.pool);
285 let conn = pool.get_conn().await.context(MySQLConnectionSnafu)?;
286
287 Ok(Box::new(MySQLConnection::new(conn)))
288 }
289
290 fn join_push_down(&self) -> JoinPushDown {
291 self.join_push_down.clone()
292 }
293}