1use std::collections::HashMap;
4
5use sqlx::mysql::{MySqlConnectOptions, MySqlPool, MySqlPoolOptions, MySqlSslMode};
6
7use crate::config::{Config, IntoMysqlName, MysqlConfig};
8
9pub fn ssl_mode_from_str(s: &str) -> MySqlSslMode {
11 match s.trim().to_ascii_lowercase().as_str() {
12 "disabled" | "disable" | "off" => MySqlSslMode::Disabled,
13 "required" | "require" => MySqlSslMode::Required,
14 "verify-ca" | "verify_ca" => MySqlSslMode::VerifyCa,
15 "verify-identity" | "verify_identity" => MySqlSslMode::VerifyIdentity,
16 _ => MySqlSslMode::Preferred,
17 }
18}
19
20pub fn connect_options(cfg: &MysqlConfig) -> MySqlConnectOptions {
22 let mut opts = MySqlConnectOptions::new()
23 .host(&cfg.host)
24 .port(cfg.port)
25 .username(&cfg.user)
26 .password(&cfg.password)
27 .ssl_mode(ssl_mode_from_str(&cfg.ssl_mode));
28
29 if cfg.disable_sql_mode {
30 opts = opts.no_engine_substitution(false).pipes_as_concat(false);
31 }
32
33 if !cfg.database.is_empty() {
34 opts = opts.database(&cfg.database);
35 }
36 opts
37}
38
39pub async fn connect(cfg: &MysqlConfig) -> anyhow::Result<MySqlPool> {
41 let pool = MySqlPoolOptions::new()
42 .max_connections(cfg.max_connections)
43 .connect_with(connect_options(cfg))
44 .await?;
45 Ok(pool)
46}
47
48pub struct MysqlPools {
50 pools: HashMap<String, MySqlPool>,
51}
52
53impl MysqlPools {
54 pub async fn from_config(cfg: &Config) -> anyhow::Result<Self> {
56 let mut pools = HashMap::new();
57 for (name, mc) in &cfg.mysql {
58 pools.insert(name.clone(), connect(mc).await?);
59 }
60 Ok(Self { pools })
61 }
62
63 pub fn get(&self, name: impl IntoMysqlName) -> Option<&MySqlPool> {
65 self.pools.get(&name.into_name())
66 }
67
68 pub fn require(&self, name: impl IntoMysqlName) -> anyhow::Result<&MySqlPool> {
70 let name = name.into_name();
71 self.pools
72 .get(&name)
73 .ok_or_else(|| anyhow::anyhow!("未找到名为 `{}` 的 MySQL 连接", name))
74 }
75
76 pub fn default(&self) -> anyhow::Result<&MySqlPool> {
78 self.require("default")
79 }
80}