1use std::collections::HashMap;
2
3use serde::Deserialize;
4
5use super::Validate;
6
7#[derive(Debug, Clone, Deserialize)]
13pub struct MysqlConfig {
14 pub host: String,
15 #[serde(default = "default_mysql_port")]
16 pub port: u16,
17 #[serde(alias = "username")]
18 pub user: String,
19 pub password: String,
20 #[serde(default)]
21 pub database: String,
22 #[serde(default = "default_max_connections")]
23 pub max_connections: u32,
24 #[serde(default = "default_ssl_mode")]
25 pub ssl_mode: String,
26}
27
28impl Default for MysqlConfig {
29 fn default() -> Self {
30 Self {
31 host: String::new(),
32 port: default_mysql_port(),
33 user: String::new(),
34 password: String::new(),
35 database: String::new(),
36 max_connections: default_max_connections(),
37 ssl_mode: default_ssl_mode(),
38 }
39 }
40}
41
/// Serde/`Default` fallback for `MysqlConfig::port`: MySQL's conventional port 3306.
fn default_mysql_port() -> u16 {
    const MYSQL_PORT: u16 = 3306;
    MYSQL_PORT
}

/// Serde/`Default` fallback for `MysqlConfig::max_connections`.
fn default_max_connections() -> u32 {
    const MAX_CONNECTIONS: u32 = 5;
    MAX_CONNECTIONS
}

/// Serde/`Default` fallback for `MysqlConfig::ssl_mode`: `"preferred"`, one of the modes
/// accepted by `validate`.
fn default_ssl_mode() -> String {
    String::from("preferred")
}
51
52impl Validate for MysqlConfig {
53 fn validate(&self) -> anyhow::Result<()> {
54 if self.host.is_empty() {
55 anyhow::bail!("MySQL host 不能为空");
56 }
57 if self.database.is_empty() {
58 anyhow::bail!("MySQL database 不能为空");
59 }
60 if self.user.is_empty() {
61 anyhow::bail!("MySQL user 不能为空");
62 }
63 if self.port == 0 {
64 anyhow::bail!("MySQL port 不能为 0");
65 }
66 if self.max_connections == 0 {
67 anyhow::bail!("MySQL max_connections 不能为 0");
68 }
69 let valid_modes = [
70 "disabled",
71 "disable",
72 "off",
73 "preferred",
74 "required",
75 "require",
76 "verify-ca",
77 "verify_ca",
78 "verify-identity",
79 "verify_identity",
80 ];
81 if !valid_modes.contains(&self.ssl_mode.as_str()) {
82 anyhow::bail!(
83 "MySQL ssl_mode 无效: `{}`,可选: disabled, preferred, required, verify-ca, verify-identity",
84 self.ssl_mode
85 );
86 }
87 Ok(())
88 }
89}
90
91pub struct MysqlConfigBuilder(pub(crate) MysqlConfig);
97
98impl MysqlConfigBuilder {
99 pub fn host(mut self, v: impl Into<String>) -> Self {
100 self.0.host = v.into();
101 self
102 }
103 pub fn port(mut self, v: u16) -> Self {
104 self.0.port = v;
105 self
106 }
107 pub fn user(mut self, v: impl Into<String>) -> Self {
108 self.0.user = v.into();
109 self
110 }
111 pub fn password(mut self, v: impl Into<String>) -> Self {
112 self.0.password = v.into();
113 self
114 }
115 pub fn database(mut self, v: impl Into<String>) -> Self {
116 self.0.database = v.into();
117 self
118 }
119 pub fn max_connections(mut self, v: u32) -> Self {
120 self.0.max_connections = v;
121 self
122 }
123 pub fn ssl_mode(mut self, v: impl Into<String>) -> Self {
124 self.0.ssl_mode = v.into();
125 self
126 }
127}
128
129pub(crate) fn collect_env_mysql(
134 prefix: &str,
135 existing: &HashMap<String, MysqlConfig>,
136) -> anyhow::Result<HashMap<String, MysqlConfig>> {
137 let mut result = HashMap::new();
138 let pfx_upper = prefix.to_uppercase();
139
140 for (key, val) in std::env::vars() {
141 let upper = key.to_uppercase();
142 let rest = match upper.strip_prefix(&format!("{pfx_upper}_MYSQL_")) {
144 Some(r) => r,
145 None => continue,
146 };
147 let (name, field) = match rest.rsplit_once('_') {
148 Some((n, f)) => (n.to_lowercase(), f),
149 None => continue,
150 };
151
152 let entry = result
153 .entry(name.clone())
154 .or_insert_with(|| existing.get(&name).cloned().unwrap_or_default());
155
156 match field {
157 "HOST" => entry.host = val,
158 "PORT" => {
159 entry.port = val
160 .parse()
161 .map_err(|e| anyhow::anyhow!("PORT 解析失败: {}", e))?
162 }
163 "USER" => entry.user = val,
164 "PASSWORD" => entry.password = val,
165 "DATABASE" => entry.database = val,
166 "MAX_CONNECTIONS" => {
167 entry.max_connections = val
168 .parse()
169 .map_err(|e| anyhow::anyhow!("MAX_CONNECTIONS 解析失败: {}", e))?
170 }
171 "SSL_MODE" => entry.ssl_mode = val,
172 _ => {}
173 }
174 }
175 Ok(result)
176}
177
#[cfg(test)]
mod tests {
    use crate::ConfigBuilder;

    /// Building with an empty MySQL `host` must fail, and the error must name the host check.
    #[test]
    fn validation_rejects_empty_host() {
        let outcome = ConfigBuilder::new()
            .with_mysql("default", |m| {
                m.host("").user("u").password("p").database("db")
            })
            .build();

        let message = match outcome {
            Ok(_) => panic!("expected build() to fail for an empty MySQL host"),
            Err(err) => err.to_string(),
        };
        assert!(message.contains("host 不能为空"));
    }
}