Skip to main content

cc_core/config/
mysql.rs

1use std::collections::HashMap;
2
3use serde::Deserialize;
4
5use super::Validate;
6
7// ──────────────────────────────────────────────
8// MySQL 配置
9// ──────────────────────────────────────────────
10
11/// 单个 MySQL 连接的配置。
12#[derive(Debug, Clone, Deserialize)]
13pub struct MysqlConfig {
14    pub host: String,
15    #[serde(default = "default_mysql_port")]
16    pub port: u16,
17    #[serde(alias = "username")]
18    pub user: String,
19    pub password: String,
20    #[serde(default)]
21    pub database: String,
22    #[serde(default = "default_max_connections")]
23    pub max_connections: u32,
24    #[serde(default = "default_ssl_mode")]
25    pub ssl_mode: String,
26    #[serde(default)]
27    pub disable_sql_mode: bool,
28}
29
30impl Default for MysqlConfig {
31    fn default() -> Self {
32        Self {
33            host: String::new(),
34            port: default_mysql_port(),
35            user: String::new(),
36            password: String::new(),
37            database: String::new(),
38            max_connections: default_max_connections(),
39            ssl_mode: default_ssl_mode(),
40            disable_sql_mode: false,
41        }
42    }
43}
44
/// Default MySQL port used when `port` is omitted: 3306.
fn default_mysql_port() -> u16 {
    3_306
}

/// Default `max_connections` used when the key is omitted: 5.
fn default_max_connections() -> u32 {
    5
}

/// Default `ssl_mode` used when the key is omitted: `"preferred"`.
fn default_ssl_mode() -> String {
    String::from("preferred")
}
54
55impl Validate for MysqlConfig {
56    fn validate(&self) -> anyhow::Result<()> {
57        if self.host.is_empty() {
58            anyhow::bail!("MySQL host 不能为空");
59        }
60        if self.database.is_empty() {
61            anyhow::bail!("MySQL database 不能为空");
62        }
63        if self.user.is_empty() {
64            anyhow::bail!("MySQL user 不能为空");
65        }
66        if self.port == 0 {
67            anyhow::bail!("MySQL port 不能为 0");
68        }
69        if self.max_connections == 0 {
70            anyhow::bail!("MySQL max_connections 不能为 0");
71        }
72        let valid_modes = [
73            "disabled",
74            "disable",
75            "off",
76            "preferred",
77            "required",
78            "require",
79            "verify-ca",
80            "verify_ca",
81            "verify-identity",
82            "verify_identity",
83        ];
84        if !valid_modes.contains(&self.ssl_mode.as_str()) {
85            anyhow::bail!(
86                "MySQL ssl_mode 无效: `{}`,可选: disabled, preferred, required, verify-ca, verify-identity",
87                self.ssl_mode
88            );
89        }
90        Ok(())
91    }
92}
93
94// ──────────────────────────────────────────────
95// MySQL 子构建器
96// ──────────────────────────────────────────────
97
98/// MySQL 单连接构建器,提供链式 API。
99pub struct MysqlConfigBuilder(pub(crate) MysqlConfig);
100
101impl MysqlConfigBuilder {
102    pub fn host(mut self, v: impl Into<String>) -> Self {
103        self.0.host = v.into();
104        self
105    }
106    pub fn port(mut self, v: u16) -> Self {
107        self.0.port = v;
108        self
109    }
110    pub fn user(mut self, v: impl Into<String>) -> Self {
111        self.0.user = v.into();
112        self
113    }
114    pub fn password(mut self, v: impl Into<String>) -> Self {
115        self.0.password = v.into();
116        self
117    }
118    pub fn database(mut self, v: impl Into<String>) -> Self {
119        self.0.database = v.into();
120        self
121    }
122    pub fn max_connections(mut self, v: u32) -> Self {
123        self.0.max_connections = v;
124        self
125    }
126    pub fn ssl_mode(mut self, v: impl Into<String>) -> Self {
127        self.0.ssl_mode = v.into();
128        self
129    }
130    pub fn disable_sql_mode(mut self, v: bool) -> Self {
131        self.0.disable_sql_mode = v;
132        self
133    }
134}
135
136// ──────────────────────────────────────────────
137// 环境变量解析
138// ──────────────────────────────────────────────
139
140pub(crate) fn collect_env_mysql(
141    prefix: &str,
142    existing: &HashMap<String, MysqlConfig>,
143) -> anyhow::Result<HashMap<String, MysqlConfig>> {
144    let mut result = HashMap::new();
145    let pfx_upper = prefix.to_uppercase();
146
147    for (key, val) in std::env::vars() {
148        let upper = key.to_uppercase();
149        // 匹配 <PREFIX>_MYSQL_<NAME>_<FIELD>
150        let rest = match upper.strip_prefix(&format!("{pfx_upper}_MYSQL_")) {
151            Some(r) => r,
152            None => continue,
153        };
154        let (name, field) = match rest.rsplit_once('_') {
155            Some((n, f)) => (n.to_lowercase(), f),
156            None => continue,
157        };
158
159        let entry = result
160            .entry(name.clone())
161            .or_insert_with(|| existing.get(&name).cloned().unwrap_or_default());
162
163        match field {
164            "HOST" => entry.host = val,
165            "PORT" => {
166                entry.port = val
167                    .parse()
168                    .map_err(|e| anyhow::anyhow!("PORT 解析失败: {}", e))?
169            }
170            "USER" => entry.user = val,
171            "PASSWORD" => entry.password = val,
172            "DATABASE" => entry.database = val,
173            "MAX_CONNECTIONS" => {
174                entry.max_connections = val
175                    .parse()
176                    .map_err(|e| anyhow::anyhow!("MAX_CONNECTIONS 解析失败: {}", e))?
177            }
178            "SSL_MODE" => entry.ssl_mode = val,
179            "DISABLE_SQL_MODE" => {
180                entry.disable_sql_mode = matches!(val.as_str(), "1" | "true" | "TRUE")
181            }
182            _ => {}
183        }
184    }
185    Ok(result)
186}
187
188// ──────────────────────────────────────────────
189// Tests
190// ──────────────────────────────────────────────
191
#[cfg(test)]
mod tests {
    use crate::ConfigBuilder;

    /// Building a config whose MySQL entry has an empty host must fail, and the
    /// error must mention the host check.
    #[test]
    fn validation_rejects_empty_host() {
        let outcome = ConfigBuilder::new()
            .with_mysql("default", |m| m.host("").user("u").password("p").database("db"))
            .build();
        assert!(outcome.is_err());

        let message = outcome.unwrap_err().to_string();
        assert!(message.contains("host 不能为空"));
    }
}