Skip to main content

sqlmodel_mysql/
auth.rs

1//! MySQL authentication implementations.
2//!
3//! This module implements the MySQL authentication plugins:
4//! - `mysql_native_password`: SHA1-based (legacy, MySQL < 8.0 default)
5//! - `caching_sha2_password`: SHA256-based (MySQL 8.0+ default)
6//!
7//! # mysql_native_password
8//!
9//! Password scramble algorithm:
10//! ```text
11//! SHA1(password) XOR SHA1(seed + SHA1(SHA1(password)))
12//! ```
13//!
14//! # caching_sha2_password
15//!
16//! Fast auth (if cached on server):
17//! ```text
18//! XOR(SHA256(password), SHA256(SHA256(SHA256(password)) + seed))
19//! ```
20//!
21//! Full auth requires TLS or RSA public key encryption.
22
23use sha1::Sha1;
24use sha2::{Digest, Sha256};
25
26use rand::rngs::OsRng;
27
28use rsa::RsaPublicKey;
29use rsa::pkcs1::DecodeRsaPublicKey;
30use rsa::pkcs8::DecodePublicKey;
31
/// Names of the authentication plugins referenced by this crate.
///
/// Each constant holds the plugin name as a plain string so callers can
/// compare against it without scattering string literals around.
pub mod plugins {
    /// Clear-text password plugin (for debugging/testing only).
    pub const MYSQL_CLEAR_PASSWORD: &str = "mysql_clear_password";
    /// SHA1-based plugin (the legacy default).
    pub const MYSQL_NATIVE_PASSWORD: &str = "mysql_native_password";
    /// RSA-based SHA256 plugin.
    pub const SHA256_PASSWORD: &str = "sha256_password";
    /// SHA256-based plugin (the MySQL 8.0+ default).
    pub const CACHING_SHA2_PASSWORD: &str = "caching_sha2_password";
}
43
/// Single-byte codes exchanged during the `caching_sha2_password` handshake.
pub mod caching_sha2 {
    /// Fast authentication succeeded.
    pub const FAST_AUTH_SUCCESS: u8 = 0x03;
    /// Full authentication is required (switch to a secure channel or RSA).
    pub const PERFORM_FULL_AUTH: u8 = 0x04;

    /// Public-key request: the client sends this byte (0x02) to ask for the
    /// public key.
    pub const REQUEST_PUBLIC_KEY: u8 = 0x02;
}
53
54/// Compute mysql_native_password authentication response.
55///
56/// Algorithm: `SHA1(password) XOR SHA1(seed + SHA1(SHA1(password)))`
57///
58/// # Arguments
59/// * `password` - The user's password (UTF-8)
60/// * `auth_data` - The 20-byte scramble from the server
61///
62/// # Returns
63/// The 20-byte authentication response, or empty vec if password is empty.
64pub fn mysql_native_password(password: &str, auth_data: &[u8]) -> Vec<u8> {
65    if password.is_empty() {
66        return vec![];
67    }
68
69    // Ensure we only use first 20 bytes of auth_data
70    let seed = if auth_data.len() > 20 {
71        &auth_data[..20]
72    } else {
73        auth_data
74    };
75
76    // Stage 1: SHA1(password)
77    let mut hasher = Sha1::new();
78    hasher.update(password.as_bytes());
79    let stage1: [u8; 20] = hasher.finalize().into();
80
81    // Stage 2: SHA1(SHA1(password))
82    let mut hasher = Sha1::new();
83    hasher.update(stage1);
84    let stage2: [u8; 20] = hasher.finalize().into();
85
86    // Stage 3: SHA1(seed + stage2)
87    let mut hasher = Sha1::new();
88    hasher.update(seed);
89    hasher.update(stage2);
90    let stage3: [u8; 20] = hasher.finalize().into();
91
92    // Final: stage1 XOR stage3
93    stage1
94        .iter()
95        .zip(stage3.iter())
96        .map(|(a, b)| a ^ b)
97        .collect()
98}
99
100/// Compute caching_sha2_password fast authentication response.
101///
102/// Algorithm: `XOR(SHA256(password), SHA256(SHA256(SHA256(password)) + seed))`
103///
104/// # Arguments
105/// * `password` - The user's password (UTF-8)
106/// * `auth_data` - The scramble from the server (typically 20 bytes + NUL)
107///
108/// # Returns
109/// The 32-byte authentication response, or empty vec if password is empty.
110pub fn caching_sha2_password(password: &str, auth_data: &[u8]) -> Vec<u8> {
111    if password.is_empty() {
112        return vec![];
113    }
114
115    // Remove trailing NUL if present (MySQL sends 20-byte scramble + NUL = 21 bytes)
116    // Only strip if length is 21 and ends with NUL, to avoid modifying valid 20-byte seeds
117    let seed = if auth_data.len() == 21 && auth_data.last() == Some(&0) {
118        &auth_data[..20]
119    } else {
120        auth_data
121    };
122
123    // SHA256(password)
124    let mut hasher = Sha256::new();
125    hasher.update(password.as_bytes());
126    let password_hash: [u8; 32] = hasher.finalize().into();
127
128    // SHA256(SHA256(password))
129    let mut hasher = Sha256::new();
130    hasher.update(password_hash);
131    let password_hash_hash: [u8; 32] = hasher.finalize().into();
132
133    // SHA256(SHA256(SHA256(password)) + seed)
134    let mut hasher = Sha256::new();
135    hasher.update(password_hash_hash);
136    hasher.update(seed);
137    let scramble: [u8; 32] = hasher.finalize().into();
138
139    // XOR(SHA256(password), scramble)
140    password_hash
141        .iter()
142        .zip(scramble.iter())
143        .map(|(a, b)| a ^ b)
144        .collect()
145}
146
147/// Generate a random nonce for client-side use.
148///
149/// Uses `OsRng` for cryptographically secure random generation.
150pub fn generate_nonce(length: usize) -> Vec<u8> {
151    use rand::RngCore;
152    use rand::rngs::OsRng;
153    let mut bytes = vec![0u8; length];
154    OsRng.fill_bytes(&mut bytes);
155    bytes
156}
157
158/// Scramble password for sha256_password plugin using RSA encryption.
159///
160/// This is used when full authentication is required for caching_sha2_password
161/// or sha256_password plugins without TLS.
162///
163/// # Arguments
164/// * `password` - The user's password
165/// * `seed` - The authentication seed from server
166/// * `public_key` - RSA public key from server (PEM format)
167///
168/// # Returns
169/// The encrypted password, or error if encryption fails.
170///
171/// This is used for full authentication for `caching_sha2_password`/`sha256_password`
172/// when the connection is not secured by TLS.
173pub fn sha256_password_rsa(
174    password: &str,
175    seed: &[u8],
176    public_key_pem: &[u8],
177    use_oaep: bool,
178) -> Result<Vec<u8>, String> {
179    // MySQL expects: RSA_encrypt(password_with_nul XOR seed_rotation)
180    let mut pw = password.as_bytes().to_vec();
181    pw.push(0); // NUL terminator
182
183    if seed.is_empty() {
184        return Err("Seed is empty".to_string());
185    }
186
187    for (i, b) in pw.iter_mut().enumerate() {
188        *b ^= seed[i % seed.len()];
189    }
190
191    // Server usually returns a PEM public key for sha256_password/caching_sha2_password.
192    let pem = std::str::from_utf8(public_key_pem)
193        .map_err(|e| format!("Public key is not valid UTF-8 PEM: {e}"))?;
194
195    // Try both common encodings.
196    let pub_key = RsaPublicKey::from_public_key_pem(pem)
197        .or_else(|_| RsaPublicKey::from_pkcs1_pem(pem))
198        .map_err(|e| format!("Failed to parse RSA public key PEM: {e}"))?;
199
200    let encrypted = if use_oaep {
201        // MySQL 8.0.5+ uses OAEP padding for caching_sha2_password.
202        let padding = rsa::Oaep::new::<Sha1>();
203        pub_key
204            .encrypt(&mut OsRng, padding, &pw)
205            .map_err(|e| format!("RSA OAEP encryption failed: {e}"))?
206    } else {
207        let padding = rsa::Pkcs1v15Encrypt;
208        pub_key
209            .encrypt(&mut OsRng, padding, &pw)
210            .map_err(|e| format!("RSA PKCS1v1.5 encryption failed: {e}"))?
211    };
212
213    Ok(encrypted)
214}
215
/// XOR password with seed for cleartext transmission over TLS.
///
/// When the connection is secured with TLS, some auth methods allow sending
/// the password XOR'd with the seed (or even cleartext).
///
/// # Arguments
/// * `password` - The user's password (UTF-8)
/// * `seed` - The authentication seed; it is repeated cyclically when shorter
///   than the password. An empty seed is treated as an all-zero mask, so the
///   password bytes pass through unchanged.
///
/// # Returns
/// The masked password bytes followed by a single NUL terminator. The
/// terminator itself is appended after masking and is therefore always `0`.
pub fn xor_password_with_seed(password: &str, seed: &[u8]) -> Vec<u8> {
    let password_bytes = password.as_bytes();
    let mut result = Vec::with_capacity(password_bytes.len() + 1);

    if seed.is_empty() {
        // Indexing with `i % seed.len()` would divide by zero here. XOR with
        // zero is the identity, so the bytes are copied verbatim.
        result.extend_from_slice(password_bytes);
    } else {
        result.extend(
            password_bytes
                .iter()
                .zip(seed.iter().cycle())
                .map(|(byte, seed_byte)| byte ^ seed_byte),
        );
    }

    // NUL terminator
    result.push(0);

    result
}
234
#[cfg(test)]
mod tests {
    use super::*;

    /// An empty password yields an empty native-password response.
    #[test]
    fn test_mysql_native_password_empty() {
        assert!(mysql_native_password("", &[0; 20]).is_empty());
    }

    /// Smoke test with an all-zero scramble: checks output length and
    /// determinism only (no reference vector is asserted).
    #[test]
    fn test_mysql_native_password() {
        let scramble = [0u8; 20];

        let first = mysql_native_password("secret", &scramble);
        let second = mysql_native_password("secret", &scramble);

        // SHA-1 output is 20 bytes.
        assert_eq!(first.len(), 20);
        // Same inputs must give the same response.
        assert_eq!(first, second);
    }

    /// With a realistic scramble, different passwords give different
    /// responses.
    #[test]
    fn test_mysql_native_password_real_seed() {
        let scramble = [
            0x3d, 0x4c, 0x5e, 0x2f, 0x1a, 0x0b, 0x7c, 0x8d, 0x9e, 0xaf, 0x10, 0x21, 0x32, 0x43,
            0x54, 0x65, 0x76, 0x87, 0x98, 0xa9,
        ];

        let mine = mysql_native_password("mypassword", &scramble);
        assert_eq!(mine.len(), 20);

        let other = mysql_native_password("otherpassword", &scramble);
        assert_ne!(mine, other);
    }

    /// An empty password yields an empty caching_sha2 response.
    #[test]
    fn test_caching_sha2_password_empty() {
        assert!(caching_sha2_password("", &[0; 20]).is_empty());
    }

    /// Checks output length and determinism with an all-zero scramble.
    #[test]
    fn test_caching_sha2_password() {
        let scramble = [0u8; 20];

        let first = caching_sha2_password("secret", &scramble);
        let second = caching_sha2_password("secret", &scramble);

        // SHA-256 output is 32 bytes.
        assert_eq!(first.len(), 32);
        assert_eq!(first, second);
    }

    /// A 21-byte scramble with trailing NUL behaves like its 20-byte prefix.
    #[test]
    fn test_caching_sha2_password_with_nul() {
        // 20 seed bytes followed by the NUL terminator MySQL often sends.
        let scramble = vec![0u8; 21];

        let with_nul = caching_sha2_password("secret", &scramble);
        assert_eq!(with_nul.len(), 32);

        let without_nul = caching_sha2_password("secret", &scramble[..20]);
        assert_eq!(with_nul, without_nul);
    }

    /// Nonces have the requested length and differ between calls.
    #[test]
    fn test_generate_nonce() {
        let a = generate_nonce(20);
        let b = generate_nonce(20);

        assert_eq!(a.len(), 20);
        assert_eq!(b.len(), 20);

        // Equal nonces are possible in principle but astronomically unlikely.
        assert_ne!(a, b);
    }

    /// The XOR mask is reversible and the output ends with a NUL byte.
    #[test]
    fn test_xor_password_with_seed() {
        let password = "test";
        let seed = [1, 2, 3, 4, 5, 6, 7, 8];

        let masked = xor_password_with_seed(password, &seed);

        // Password length plus the NUL terminator.
        assert_eq!(masked.len(), 5);
        assert_eq!(masked[4], 0);

        // Applying the same mask again restores the original bytes.
        let mut recovered = Vec::new();
        for (i, &b) in masked[..4].iter().enumerate() {
            recovered.push(b ^ seed[i % seed.len()]);
        }
        assert_eq!(recovered, password.as_bytes());
    }

    /// Plugin name constants match the names used on the wire.
    #[test]
    fn test_plugin_names() {
        assert_eq!(plugins::MYSQL_NATIVE_PASSWORD, "mysql_native_password");
        assert_eq!(plugins::CACHING_SHA2_PASSWORD, "caching_sha2_password");
        assert_eq!(plugins::SHA256_PASSWORD, "sha256_password");
    }
}