Skip to main content

fraiseql_wire/auth/
scram.rs

//! SCRAM-SHA-256 authentication implementation
//!
//! Implements the client side of SCRAM-SHA-256 (Salted Challenge Response Authentication
//! Mechanism, defined in RFC 5802, with the SHA-256 variant registered by RFC 7677) for
//! PostgreSQL password authentication (PostgreSQL 10+).
5
6use base64::{engine::general_purpose::STANDARD as BASE64, Engine};
7use hmac::{Hmac, Mac};
8use pbkdf2::pbkdf2;
9use rand::Rng;
10use sha2::{Digest, Sha256};
11use std::fmt;
12
13type HmacSha256 = Hmac<Sha256>;
14
/// Errors produced while running a SCRAM-SHA-256 exchange.
#[derive(Debug, Clone)]
pub enum ScramError {
    /// The server's final signature did not match the expected value.
    InvalidServerProof(String),
    /// A server message was malformed or violated the protocol.
    InvalidServerMessage(String),
    /// UTF-8 encoding/decoding error.
    Utf8Error(String),
    /// A base64-encoded field could not be decoded.
    Base64Error(String),
}

impl fmt::Display for ScramError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Every variant renders as "<category>: <detail>".
        let (label, detail) = match self {
            Self::InvalidServerProof(detail) => ("invalid server proof", detail),
            Self::InvalidServerMessage(detail) => ("invalid server message", detail),
            Self::Utf8Error(detail) => ("UTF-8 error", detail),
            Self::Base64Error(detail) => ("Base64 error", detail),
        };
        write!(f, "{}: {}", label, detail)
    }
}

impl std::error::Error for ScramError {}
40
/// State carried from `ScramClient::client_final` to `ScramClient::verify_server_final`.
///
/// `Debug` is implemented by hand rather than derived: the server key is enough to
/// compute `HMAC(ServerKey, AuthMessage)`, i.e. to forge the server's signature, so it
/// must never be printed into logs.
#[derive(Clone)]
pub struct ScramState {
    /// AuthMessage bytes (client-first-bare, server-first, client-final-without-proof)
    /// that the server signature covers.
    auth_message: Vec<u8>,
    /// ServerKey derived from the salted password; secret.
    server_key: Vec<u8>,
}

impl fmt::Debug for ScramState {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("ScramState")
            // The AuthMessage is assembled from a `String`, so show it as text.
            .field("auth_message", &String::from_utf8_lossy(&self.auth_message))
            .field("server_key", &"<redacted>")
            .finish()
    }
}
49
50/// SCRAM-SHA-256 client implementation
51pub struct ScramClient {
52    username: String,
53    password: String,
54    nonce: String,
55}
56
57impl ScramClient {
58    /// Create a new SCRAM client
59    pub fn new(username: String, password: String) -> Self {
60        // Generate random client nonce (24 bytes, base64 encoded = 32 chars)
61        let mut rng = rand::thread_rng();
62        let nonce_bytes: Vec<u8> = (0..24).map(|_| rng.gen()).collect();
63        let nonce = BASE64.encode(&nonce_bytes);
64
65        Self {
66            username,
67            password,
68            nonce,
69        }
70    }
71
72    /// Generate client first message (no proof)
73    pub fn client_first(&self) -> String {
74        // Format: n,a=<username>,r=<nonce>
75        // RFC 5802: Channels binding doesn't apply, so use "n" (no channel binding)
76        format!("n,a={},r={}", self.username, self.nonce)
77    }
78
79    /// Process server first message and generate client final message
80    ///
81    /// Returns (client_final_message, internal_state)
82    pub fn client_final(&mut self, server_first: &str) -> Result<(String, ScramState), ScramError> {
83        // Parse server first message: r=<client_nonce><server_nonce>,s=<salt>,i=<iterations>
84        let (server_nonce, salt, iterations) = parse_server_first(server_first)?;
85
86        // Verify server nonce starts with our client nonce
87        if !server_nonce.starts_with(&self.nonce) {
88            return Err(ScramError::InvalidServerMessage(
89                "server nonce doesn't contain client nonce".to_string(),
90            ));
91        }
92
93        // Decode salt and iterations
94        let salt_bytes = BASE64
95            .decode(&salt)
96            .map_err(|_| ScramError::Base64Error("invalid salt encoding".to_string()))?;
97        let iterations = iterations
98            .parse::<u32>()
99            .map_err(|_| ScramError::InvalidServerMessage("invalid iteration count".to_string()))?;
100
101        // Build channel binding (no channel binding for SCRAM-SHA-256)
102        let channel_binding = BASE64.encode(b"n,,");
103
104        // Build client final without proof
105        let client_final_without_proof = format!("c={},r={}", channel_binding, server_nonce);
106
107        // Build auth message for signature calculation
108        let client_first_bare = format!("a={},r={}", self.username, self.nonce);
109        let auth_message = format!(
110            "{},{},{}",
111            client_first_bare, server_first, client_final_without_proof
112        );
113
114        // Calculate proof
115        let proof = calculate_client_proof(
116            &self.password,
117            &salt_bytes,
118            iterations,
119            auth_message.as_bytes(),
120        )?;
121
122        // Calculate server signature for later verification
123        let server_key = calculate_server_key(&self.password, &salt_bytes, iterations)?;
124
125        // Build client final message
126        let client_final = format!("{},p={}", client_final_without_proof, BASE64.encode(&proof));
127
128        let state = ScramState {
129            auth_message: auth_message.into_bytes(),
130            server_key,
131        };
132
133        Ok((client_final, state))
134    }
135
136    /// Verify server final message and confirm authentication
137    pub fn verify_server_final(
138        &self,
139        server_final: &str,
140        state: &ScramState,
141    ) -> Result<(), ScramError> {
142        // Parse server final: v=<server_signature>
143        let server_sig_encoded = server_final
144            .strip_prefix("v=")
145            .ok_or_else(|| ScramError::InvalidServerMessage("missing 'v=' prefix".to_string()))?;
146
147        let server_signature = BASE64.decode(server_sig_encoded).map_err(|_| {
148            ScramError::Base64Error("invalid server signature encoding".to_string())
149        })?;
150
151        // Calculate expected server signature
152        let expected_signature = calculate_server_signature(&state.server_key, &state.auth_message);
153
154        // Constant-time comparison
155        if constant_time_compare(&server_signature, &expected_signature) {
156            Ok(())
157        } else {
158            Err(ScramError::InvalidServerProof(
159                "server signature verification failed".to_string(),
160            ))
161        }
162    }
163}
164
165/// Parse server first message format: r=<nonce>,s=<salt>,i=<iterations>
166fn parse_server_first(msg: &str) -> Result<(String, String, String), ScramError> {
167    let mut nonce = String::new();
168    let mut salt = String::new();
169    let mut iterations = String::new();
170
171    for part in msg.split(',') {
172        if let Some(value) = part.strip_prefix("r=") {
173            nonce = value.to_string();
174        } else if let Some(value) = part.strip_prefix("s=") {
175            salt = value.to_string();
176        } else if let Some(value) = part.strip_prefix("i=") {
177            iterations = value.to_string();
178        }
179    }
180
181    if nonce.is_empty() || salt.is_empty() || iterations.is_empty() {
182        return Err(ScramError::InvalidServerMessage(
183            "missing required fields in server first message".to_string(),
184        ));
185    }
186
187    Ok((nonce, salt, iterations))
188}
189
190/// Calculate SCRAM client proof
191fn calculate_client_proof(
192    password: &str,
193    salt: &[u8],
194    iterations: u32,
195    auth_message: &[u8],
196) -> Result<Vec<u8>, ScramError> {
197    // SaltedPassword := PBKDF2(password, salt, iterations, HMAC-SHA256)
198    let password_bytes = password.as_bytes();
199    let mut salted_password = vec![0u8; 32]; // SHA256 produces 32 bytes
200    let _ = pbkdf2::<HmacSha256>(password_bytes, salt, iterations, &mut salted_password);
201
202    // ClientKey := HMAC(SaltedPassword, "Client Key")
203    let mut client_key_hmac = HmacSha256::new_from_slice(&salted_password)
204        .map_err(|_| ScramError::Utf8Error("HMAC key error".to_string()))?;
205    client_key_hmac.update(b"Client Key");
206    let client_key = client_key_hmac.finalize().into_bytes();
207
208    // StoredKey := SHA256(ClientKey)
209    let stored_key = Sha256::digest(client_key.to_vec().as_slice());
210
211    // ClientSignature := HMAC(StoredKey, AuthMessage)
212    let mut client_sig_hmac = HmacSha256::new_from_slice(&stored_key)
213        .map_err(|_| ScramError::Utf8Error("HMAC key error".to_string()))?;
214    client_sig_hmac.update(auth_message);
215    let client_signature = client_sig_hmac.finalize().into_bytes();
216
217    // ClientProof := ClientKey XOR ClientSignature
218    let mut proof = client_key.to_vec();
219    for (proof_byte, sig_byte) in proof.iter_mut().zip(client_signature.iter()) {
220        *proof_byte ^= sig_byte;
221    }
222
223    Ok(proof.to_vec())
224}
225
226/// Calculate server key for server signature verification
227fn calculate_server_key(
228    password: &str,
229    salt: &[u8],
230    iterations: u32,
231) -> Result<Vec<u8>, ScramError> {
232    // SaltedPassword := PBKDF2(password, salt, iterations, HMAC-SHA256)
233    let password_bytes = password.as_bytes();
234    let mut salted_password = vec![0u8; 32];
235    let _ = pbkdf2::<HmacSha256>(password_bytes, salt, iterations, &mut salted_password);
236
237    // ServerKey := HMAC(SaltedPassword, "Server Key")
238    let mut server_key_hmac = HmacSha256::new_from_slice(&salted_password)
239        .map_err(|_| ScramError::Utf8Error("HMAC key error".to_string()))?;
240    server_key_hmac.update(b"Server Key");
241
242    Ok(server_key_hmac.finalize().into_bytes().to_vec())
243}
244
245/// Calculate server signature for verification
246fn calculate_server_signature(server_key: &[u8], auth_message: &[u8]) -> Vec<u8> {
247    let mut hmac = HmacSha256::new_from_slice(server_key).expect("HMAC key should be valid");
248    hmac.update(auth_message);
249    hmac.finalize().into_bytes().to_vec()
250}
251
252/// Constant-time comparison to prevent timing attacks
253fn constant_time_compare(a: &[u8], b: &[u8]) -> bool {
254    if a.len() != b.len() {
255        return false;
256    }
257    let mut result = 0u8;
258    for (x, y) in a.iter().zip(b.iter()) {
259        result |= x ^ y;
260    }
261    result == 0
262}
263
264#[cfg(test)]
265mod tests {
266    use super::*;
267
268    #[test]
269    fn test_scram_client_creation() {
270        let client = ScramClient::new("user".to_string(), "password".to_string());
271        assert_eq!(client.username, "user");
272        assert_eq!(client.password, "password");
273        assert!(!client.nonce.is_empty());
274    }
275
276    #[test]
277    fn test_client_first_message_format() {
278        let client = ScramClient::new("alice".to_string(), "secret".to_string());
279        let first = client.client_first();
280
281        assert!(first.starts_with("n,a=alice,r="));
282        assert!(first.len() > 20);
283    }
284
285    #[test]
286    fn test_parse_server_first_valid() {
287        let server_first = "r=client_nonce_server_nonce,s=aW1hZ2luYXJ5c2FsdA==,i=4096";
288        let (nonce, salt, iterations) = parse_server_first(server_first).unwrap();
289
290        assert_eq!(nonce, "client_nonce_server_nonce");
291        assert_eq!(salt, "aW1hZ2luYXJ5c2FsdA==");
292        assert_eq!(iterations, "4096");
293    }
294
295    #[test]
296    fn test_parse_server_first_invalid() {
297        let server_first = "r=nonce,s=salt"; // missing iterations
298        assert!(parse_server_first(server_first).is_err());
299    }
300
301    #[test]
302    fn test_constant_time_compare_equal() {
303        let a = b"test_value";
304        let b_arr = b"test_value";
305        assert!(constant_time_compare(a, b_arr));
306    }
307
308    #[test]
309    fn test_constant_time_compare_different() {
310        let a = b"test_value";
311        let b_arr = b"test_wrong";
312        assert!(!constant_time_compare(a, b_arr));
313    }
314
315    #[test]
316    fn test_constant_time_compare_different_length() {
317        let a = b"test";
318        let b_arr = b"test_longer";
319        assert!(!constant_time_compare(a, b_arr));
320    }
321
322    #[test]
323    fn test_scram_client_final_flow() {
324        let mut client = ScramClient::new("user".to_string(), "password".to_string());
325        let _client_first = client.client_first();
326
327        // Simulate server response
328        let server_nonce = format!("{}server_nonce_part", client.nonce);
329        let server_first = format!("r={},s={},i=4096", server_nonce, BASE64.encode(b"salty"));
330
331        // Should succeed with valid format
332        let result = client.client_final(&server_first);
333        assert!(result.is_ok());
334
335        let (client_final, state) = result.unwrap();
336        assert!(client_final.starts_with("c="));
337        assert!(!state.auth_message.is_empty());
338    }
339}