Skip to main content

rswxpay/
config.rs

1use zeroize::Zeroize;
2
3use crate::error::WxPayError;
4
/// Base URL stored in [`ClientConfig`] when the builder is not given one
/// through `ClientConfigBuilder::base_url`.
const DEFAULT_BASE_URL: &str = "https://api.mch.weixin.qq.com";
6
7pub struct ClientConfig {
8    pub(crate) mch_id: String,
9    pub(crate) serial_no: String,
10    pub(crate) api_v3_key: String,
11    pub(crate) private_key_pem: String,
12    pub(crate) http_client: Option<reqwest::Client>,
13    pub(crate) base_url: String,
14}
15
16impl ClientConfig {
17    pub fn builder() -> ClientConfigBuilder {
18        ClientConfigBuilder {
19            mch_id: None,
20            serial_no: None,
21            api_v3_key: None,
22            private_key_pem: None,
23            http_client: None,
24            base_url: None,
25        }
26    }
27
28    /// Returns the merchant ID.
29    pub fn mch_id(&self) -> &str {
30        &self.mch_id
31    }
32
33    /// Returns the certificate serial number.
34    pub fn serial_no(&self) -> &str {
35        &self.serial_no
36    }
37
38    /// Returns the base URL.
39    pub fn base_url(&self) -> &str {
40        &self.base_url
41    }
42}
43
44impl Drop for ClientConfig {
45    fn drop(&mut self) {
46        self.api_v3_key.zeroize();
47        self.private_key_pem.zeroize();
48    }
49}
50
51pub struct ClientConfigBuilder {
52    mch_id: Option<String>,
53    serial_no: Option<String>,
54    api_v3_key: Option<String>,
55    private_key_pem: Option<String>,
56    http_client: Option<reqwest::Client>,
57    base_url: Option<String>,
58}
59
60impl ClientConfigBuilder {
61    pub fn mch_id(mut self, mch_id: impl Into<String>) -> Self {
62        self.mch_id = Some(mch_id.into());
63        self
64    }
65
66    pub fn serial_no(mut self, serial_no: impl Into<String>) -> Self {
67        self.serial_no = Some(serial_no.into());
68        self
69    }
70
71    pub fn api_v3_key(mut self, api_v3_key: impl Into<String>) -> Self {
72        self.api_v3_key = Some(api_v3_key.into());
73        self
74    }
75
76    pub fn private_key_pem(mut self, private_key_pem: impl Into<String>) -> Self {
77        self.private_key_pem = Some(private_key_pem.into());
78        self
79    }
80
81    pub fn http_client(mut self, client: reqwest::Client) -> Self {
82        self.http_client = Some(client);
83        self
84    }
85
86    pub fn base_url(mut self, base_url: impl Into<String>) -> Self {
87        self.base_url = Some(base_url.into());
88        self
89    }
90
91    pub fn build(self) -> Result<ClientConfig, WxPayError> {
92        let mch_id = self
93            .mch_id
94            .ok_or_else(|| WxPayError::Config("mch_id is required".into()))?;
95        let serial_no = self
96            .serial_no
97            .ok_or_else(|| WxPayError::Config("serial_no is required".into()))?;
98        let api_v3_key = self
99            .api_v3_key
100            .ok_or_else(|| WxPayError::Config("api_v3_key is required".into()))?;
101        let private_key_pem = self
102            .private_key_pem
103            .ok_or_else(|| WxPayError::Config("private_key_pem is required".into()))?;
104
105        if !api_v3_key.is_ascii() {
106            return Err(WxPayError::Config(
107                "api_v3_key must contain only ASCII characters".into(),
108            ));
109        }
110        if api_v3_key.len() != 32 {
111            return Err(WxPayError::Config(format!(
112                "api_v3_key must be 32 bytes, got {}",
113                api_v3_key.len()
114            )));
115        }
116
117        Ok(ClientConfig {
118            mch_id,
119            serial_no,
120            api_v3_key,
121            private_key_pem,
122            http_client: self.http_client,
123            base_url: self
124                .base_url
125                .unwrap_or_else(|| DEFAULT_BASE_URL.to_string()),
126        })
127    }
128}
129
#[cfg(test)]
mod tests {
    use super::*;

    /// Returns a valid RSA PKCS#1 PEM private key for testing.
    ///
    /// RSA key generation is slow, especially in debug builds, and `build()`
    /// never parses the PEM. So one key is generated lazily and shared by
    /// every test instead of generating a fresh key per test.
    fn test_private_key_pem() -> String {
        use rsa::RsaPrivateKey;
        use rsa::pkcs1::EncodeRsaPrivateKey;
        use std::sync::OnceLock;

        static PEM: OnceLock<String> = OnceLock::new();

        PEM.get_or_init(|| {
            let mut rng = rand::thread_rng();
            let key = RsaPrivateKey::new(&mut rng, 2048).unwrap();
            key.to_pkcs1_pem(rsa::pkcs1::LineEnding::LF)
                .unwrap()
                .to_string()
        })
        .clone()
    }

    /// A valid 32-byte API v3 key for testing.
    fn test_api_v3_key() -> &'static str {
        "01234567890123456789012345678901" // exactly 32 bytes
    }

    #[test]
    fn test_builder_success() {
        let pem = test_private_key_pem();
        let config = ClientConfig::builder()
            .mch_id("1900000001")
            .serial_no("SERIAL123")
            .api_v3_key(test_api_v3_key())
            .private_key_pem(pem)
            .build();

        assert!(config.is_ok());
    }

    /// Extract error from a Result<ClientConfig, WxPayError>, panicking if Ok.
    fn expect_err(result: Result<ClientConfig, WxPayError>) -> WxPayError {
        match result {
            Err(e) => e,
            Ok(_) => panic!("expected Err, got Ok"),
        }
    }

    #[test]
    fn test_builder_missing_mch_id() {
        let pem = test_private_key_pem();
        let result = ClientConfig::builder()
            .serial_no("SERIAL123")
            .api_v3_key(test_api_v3_key())
            .private_key_pem(pem)
            .build();

        let err = expect_err(result);
        assert!(matches!(err, WxPayError::Config(msg) if msg.contains("mch_id")));
    }

    #[test]
    fn test_builder_missing_serial_no() {
        let pem = test_private_key_pem();
        let result = ClientConfig::builder()
            .mch_id("1900000001")
            .api_v3_key(test_api_v3_key())
            .private_key_pem(pem)
            .build();

        let err = expect_err(result);
        assert!(matches!(err, WxPayError::Config(msg) if msg.contains("serial_no")));
    }

    #[test]
    fn test_builder_missing_api_v3_key() {
        let pem = test_private_key_pem();
        let result = ClientConfig::builder()
            .mch_id("1900000001")
            .serial_no("SERIAL123")
            .private_key_pem(pem)
            .build();

        let err = expect_err(result);
        assert!(matches!(err, WxPayError::Config(msg) if msg.contains("api_v3_key")));
    }

    #[test]
    fn test_builder_missing_private_key_pem() {
        let result = ClientConfig::builder()
            .mch_id("1900000001")
            .serial_no("SERIAL123")
            .api_v3_key(test_api_v3_key())
            .build();

        let err = expect_err(result);
        assert!(matches!(err, WxPayError::Config(msg) if msg.contains("private_key_pem")));
    }

    #[test]
    fn test_builder_empty_reports_mch_id_first() {
        // With nothing set, the first required field (mch_id) is reported.
        let err = expect_err(ClientConfig::builder().build());
        assert!(matches!(err, WxPayError::Config(msg) if msg == "mch_id is required"));
    }

    #[test]
    fn test_builder_invalid_api_v3_key_length() {
        let pem = test_private_key_pem();
        let result = ClientConfig::builder()
            .mch_id("1900000001")
            .serial_no("SERIAL123")
            .api_v3_key("too_short")
            .private_key_pem(pem)
            .build();

        let err = expect_err(result);
        assert!(matches!(err, WxPayError::Config(msg) if msg.contains("32 bytes")));
    }

    #[test]
    fn test_builder_default_base_url() {
        let pem = test_private_key_pem();
        let config = ClientConfig::builder()
            .mch_id("1900000001")
            .serial_no("SERIAL123")
            .api_v3_key(test_api_v3_key())
            .private_key_pem(pem)
            .build()
            .unwrap();

        assert_eq!(config.base_url(), "https://api.mch.weixin.qq.com");
    }

    #[test]
    fn test_builder_custom_base_url() {
        let pem = test_private_key_pem();
        let custom_url = "https://custom.example.com";
        let config = ClientConfig::builder()
            .mch_id("1900000001")
            .serial_no("SERIAL123")
            .api_v3_key(test_api_v3_key())
            .private_key_pem(pem)
            .base_url(custom_url)
            .build()
            .unwrap();

        assert_eq!(config.base_url(), custom_url);
    }

    #[test]
    fn test_getters() {
        let pem = test_private_key_pem();
        let config = ClientConfig::builder()
            .mch_id("1900000001")
            .serial_no("SERIAL123")
            .api_v3_key(test_api_v3_key())
            .private_key_pem(pem)
            .build()
            .unwrap();

        assert_eq!(config.mch_id(), "1900000001");
        assert_eq!(config.serial_no(), "SERIAL123");
        assert_eq!(config.base_url(), "https://api.mch.weixin.qq.com");
    }

    #[test]
    fn test_builder_non_ascii_api_v3_key() {
        let pem = test_private_key_pem();
        let result = ClientConfig::builder()
            .mch_id("1900000001")
            .serial_no("SERIAL123")
            .api_v3_key("非ASCII密钥")
            .private_key_pem(pem)
            .build();

        let err = expect_err(result);
        assert!(matches!(err, WxPayError::Config(msg) if msg.contains("ASCII")));
    }

    #[test]
    fn test_builder_non_ascii_api_v3_key_32_bytes() {
        let pem = test_private_key_pem();
        // 10 Chinese chars (30 UTF-8 bytes) + 2 ASCII = 32 bytes total,
        // but is_ascii() returns false, so the key must be rejected.
        let result = ClientConfig::builder()
            .mch_id("1900000001")
            .serial_no("SERIAL123")
            .api_v3_key("密钥密钥密钥密钥密钥ab")
            .private_key_pem(pem)
            .build();

        let err = expect_err(result);
        assert!(matches!(err, WxPayError::Config(msg) if msg.contains("ASCII")));
    }

    #[test]
    fn test_zeroize_clears_sensitive_fields() {
        use zeroize::Zeroize;

        let pem = test_private_key_pem();
        let mut config = ClientConfig::builder()
            .mch_id("1900000001")
            .serial_no("SERIAL123")
            .api_v3_key(test_api_v3_key())
            .private_key_pem(pem)
            .build()
            .unwrap();

        assert!(!config.api_v3_key.is_empty());
        assert!(!config.private_key_pem.is_empty());

        // Perform the same calls `Drop for ClientConfig` makes.
        config.api_v3_key.zeroize();
        config.private_key_pem.zeroize();

        assert!(config.api_v3_key.is_empty());
        assert!(config.private_key_pem.is_empty());
    }
}