Skip to main content

fraiseql_wire/auth/
scram.rs

//! SCRAM-SHA-256 authentication implementation
//!
//! Implements the client side of SCRAM-SHA-256 (Salted Challenge Response Authentication
//! Mechanism), as defined in RFC 5802 and RFC 7677, for PostgreSQL authentication
//! (PostgreSQL 10+).
5
6use base64::{engine::general_purpose::STANDARD as BASE64, Engine};
7use hmac::{Hmac, Mac};
8use pbkdf2::pbkdf2;
9use rand::Rng;
10use sha2::{Digest, Sha256};
11use std::fmt;
12
/// HMAC instantiated with SHA-256; used below for PBKDF2, ClientKey/ServerKey
/// derivation and the client/server signatures.
type HmacSha256 = Hmac<Sha256>;
14
/// Errors that can occur while running the client side of a SCRAM exchange.
#[derive(Debug, Clone)]
pub enum ScramError {
    /// The server's signature (`v=`) did not match the expected value.
    InvalidServerProof(String),
    /// A server message was malformed or failed a protocol check.
    InvalidServerMessage(String),
    /// UTF-8 encoding/decoding problem.
    Utf8Error(String),
    /// A base64-encoded field could not be decoded.
    Base64Error(String),
}

impl fmt::Display for ScramError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Every variant renders as "<category>: <detail>".
        let (category, detail) = match self {
            Self::InvalidServerProof(detail) => ("invalid server proof", detail),
            Self::InvalidServerMessage(detail) => ("invalid server message", detail),
            Self::Utf8Error(detail) => ("UTF-8 error", detail),
            Self::Base64Error(detail) => ("Base64 error", detail),
        };
        write!(f, "{}: {}", category, detail)
    }
}

impl std::error::Error for ScramError {}
40
/// Internal state carried from `ScramClient::client_final` to
/// `ScramClient::verify_server_final`.
///
/// `Debug` is implemented by hand rather than derived: `server_key` is derived
/// from the password, so only its length is ever printed.
#[derive(Clone)]
pub struct ScramState {
    /// Combined authentication message (for verification)
    auth_message: Vec<u8>,
    /// Server key (for verification calculation)
    server_key: Vec<u8>,
}

impl fmt::Debug for ScramState {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("ScramState")
            .field("auth_message", &self.auth_message)
            // Never print key material; the length is enough for diagnostics.
            .field(
                "server_key",
                &format_args!("<redacted, {} bytes>", self.server_key.len()),
            )
            .finish()
    }
}
49
50/// SCRAM-SHA-256 client implementation
51pub struct ScramClient {
52    username: String,
53    password: String,
54    nonce: String,
55}
56
57impl ScramClient {
58    /// Create a new SCRAM client
59    pub fn new(username: String, password: String) -> Self {
60        // Generate random client nonce (24 bytes, base64 encoded = 32 chars)
61        let mut rng = rand::thread_rng();
62        let nonce_bytes: Vec<u8> = (0..24).map(|_| rng.gen()).collect();
63        let nonce = BASE64.encode(&nonce_bytes);
64
65        Self {
66            username,
67            password,
68            nonce,
69        }
70    }
71
72    /// Generate client first message (no proof)
73    pub fn client_first(&self) -> String {
74        // RFC 5802 format: gs2-header client-first-message-bare
75        // gs2-header = "n,," (n = no channel binding, empty authorization identity)
76        // client-first-message-bare = "n=<username>,r=<nonce>"
77        // Note: "n=" is the username, "a=" would be authorization identity (not supported by PostgreSQL)
78        format!("n,,n={},r={}", self.username, self.nonce)
79    }
80
81    /// Process server first message and generate client final message
82    ///
83    /// Returns (client_final_message, internal_state)
84    pub fn client_final(&mut self, server_first: &str) -> Result<(String, ScramState), ScramError> {
85        // Parse server first message: r=<client_nonce><server_nonce>,s=<salt>,i=<iterations>
86        let (server_nonce, salt, iterations) = parse_server_first(server_first)?;
87
88        // Verify server nonce starts with our client nonce
89        if !server_nonce.starts_with(&self.nonce) {
90            return Err(ScramError::InvalidServerMessage(
91                "server nonce doesn't contain client nonce".to_string(),
92            ));
93        }
94
95        // Decode salt and iterations
96        let salt_bytes = BASE64
97            .decode(&salt)
98            .map_err(|_| ScramError::Base64Error("invalid salt encoding".to_string()))?;
99        let iterations = iterations
100            .parse::<u32>()
101            .map_err(|_| ScramError::InvalidServerMessage("invalid iteration count".to_string()))?;
102
103        // Build channel binding (no channel binding for SCRAM-SHA-256)
104        let channel_binding = BASE64.encode(b"n,,");
105
106        // Build client final without proof
107        let client_final_without_proof = format!("c={},r={}", channel_binding, server_nonce);
108
109        // Build auth message for signature calculation
110        // client-first-message-bare is "n=<username>,r=<nonce>" (without gs2-header)
111        let client_first_bare = format!("n={},r={}", self.username, self.nonce);
112        let auth_message = format!(
113            "{},{},{}",
114            client_first_bare, server_first, client_final_without_proof
115        );
116
117        // Calculate proof
118        let proof = calculate_client_proof(
119            &self.password,
120            &salt_bytes,
121            iterations,
122            auth_message.as_bytes(),
123        )?;
124
125        // Calculate server signature for later verification
126        let server_key = calculate_server_key(&self.password, &salt_bytes, iterations)?;
127
128        // Build client final message
129        let client_final = format!("{},p={}", client_final_without_proof, BASE64.encode(&proof));
130
131        let state = ScramState {
132            auth_message: auth_message.into_bytes(),
133            server_key,
134        };
135
136        Ok((client_final, state))
137    }
138
139    /// Verify server final message and confirm authentication
140    pub fn verify_server_final(
141        &self,
142        server_final: &str,
143        state: &ScramState,
144    ) -> Result<(), ScramError> {
145        // Parse server final: v=<server_signature>
146        let server_sig_encoded = server_final
147            .strip_prefix("v=")
148            .ok_or_else(|| ScramError::InvalidServerMessage("missing 'v=' prefix".to_string()))?;
149
150        let server_signature = BASE64.decode(server_sig_encoded).map_err(|_| {
151            ScramError::Base64Error("invalid server signature encoding".to_string())
152        })?;
153
154        // Calculate expected server signature
155        let expected_signature = calculate_server_signature(&state.server_key, &state.auth_message);
156
157        // Constant-time comparison
158        if constant_time_compare(&server_signature, &expected_signature) {
159            Ok(())
160        } else {
161            Err(ScramError::InvalidServerProof(
162                "server signature verification failed".to_string(),
163            ))
164        }
165    }
166}
167
168/// Parse server first message format: r=<nonce>,s=<salt>,i=<iterations>
169fn parse_server_first(msg: &str) -> Result<(String, String, String), ScramError> {
170    let mut nonce = String::new();
171    let mut salt = String::new();
172    let mut iterations = String::new();
173
174    for part in msg.split(',') {
175        if let Some(value) = part.strip_prefix("r=") {
176            nonce = value.to_string();
177        } else if let Some(value) = part.strip_prefix("s=") {
178            salt = value.to_string();
179        } else if let Some(value) = part.strip_prefix("i=") {
180            iterations = value.to_string();
181        }
182    }
183
184    if nonce.is_empty() || salt.is_empty() || iterations.is_empty() {
185        return Err(ScramError::InvalidServerMessage(
186            "missing required fields in server first message".to_string(),
187        ));
188    }
189
190    Ok((nonce, salt, iterations))
191}
192
193/// Calculate SCRAM client proof
194fn calculate_client_proof(
195    password: &str,
196    salt: &[u8],
197    iterations: u32,
198    auth_message: &[u8],
199) -> Result<Vec<u8>, ScramError> {
200    // SaltedPassword := PBKDF2(password, salt, iterations, HMAC-SHA256)
201    let password_bytes = password.as_bytes();
202    let mut salted_password = vec![0u8; 32]; // SHA256 produces 32 bytes
203    let _ = pbkdf2::<HmacSha256>(password_bytes, salt, iterations, &mut salted_password);
204
205    // ClientKey := HMAC(SaltedPassword, "Client Key")
206    let mut client_key_hmac = HmacSha256::new_from_slice(&salted_password)
207        .map_err(|_| ScramError::Utf8Error("HMAC key error".to_string()))?;
208    client_key_hmac.update(b"Client Key");
209    let client_key = client_key_hmac.finalize().into_bytes();
210
211    // StoredKey := SHA256(ClientKey)
212    let stored_key = Sha256::digest(client_key.to_vec().as_slice());
213
214    // ClientSignature := HMAC(StoredKey, AuthMessage)
215    let mut client_sig_hmac = HmacSha256::new_from_slice(&stored_key)
216        .map_err(|_| ScramError::Utf8Error("HMAC key error".to_string()))?;
217    client_sig_hmac.update(auth_message);
218    let client_signature = client_sig_hmac.finalize().into_bytes();
219
220    // ClientProof := ClientKey XOR ClientSignature
221    let mut proof = client_key.to_vec();
222    for (proof_byte, sig_byte) in proof.iter_mut().zip(client_signature.iter()) {
223        *proof_byte ^= sig_byte;
224    }
225
226    Ok(proof.to_vec())
227}
228
229/// Calculate server key for server signature verification
230fn calculate_server_key(
231    password: &str,
232    salt: &[u8],
233    iterations: u32,
234) -> Result<Vec<u8>, ScramError> {
235    // SaltedPassword := PBKDF2(password, salt, iterations, HMAC-SHA256)
236    let password_bytes = password.as_bytes();
237    let mut salted_password = vec![0u8; 32];
238    let _ = pbkdf2::<HmacSha256>(password_bytes, salt, iterations, &mut salted_password);
239
240    // ServerKey := HMAC(SaltedPassword, "Server Key")
241    let mut server_key_hmac = HmacSha256::new_from_slice(&salted_password)
242        .map_err(|_| ScramError::Utf8Error("HMAC key error".to_string()))?;
243    server_key_hmac.update(b"Server Key");
244
245    Ok(server_key_hmac.finalize().into_bytes().to_vec())
246}
247
248/// Calculate server signature for verification
249fn calculate_server_signature(server_key: &[u8], auth_message: &[u8]) -> Vec<u8> {
250    let mut hmac = HmacSha256::new_from_slice(server_key).expect("HMAC key should be valid");
251    hmac.update(auth_message);
252    hmac.finalize().into_bytes().to_vec()
253}
254
/// Compare two byte slices for equality without stopping at the first
/// differing byte, so the running time is intended not to reveal where they differ.
///
/// A length mismatch is rejected immediately; only the content comparison
/// avoids early exit.
fn constant_time_compare(a: &[u8], b: &[u8]) -> bool {
    a.len() == b.len()
        && a.iter()
            .zip(b.iter())
            .fold(0u8, |diff, (x, y)| diff | (x ^ y))
            == 0
}
266
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_scram_client_creation() {
        let client = ScramClient::new("user".to_string(), "password".to_string());
        assert_eq!(client.username, "user");
        assert_eq!(client.password, "password");
        assert!(!client.nonce.is_empty());
    }

    #[test]
    fn test_client_first_message_format() {
        let client = ScramClient::new("alice".to_string(), "secret".to_string());
        let first = client.client_first();

        // RFC 5802 format: "n,,n=<username>,r=<nonce>"
        assert!(first.starts_with("n,,n=alice,r="));
        assert!(first.len() > 20);
    }

    #[test]
    fn test_parse_server_first_valid() {
        let server_first = "r=client_nonce_server_nonce,s=aW1hZ2luYXJ5c2FsdA==,i=4096";
        let (nonce, salt, iterations) = parse_server_first(server_first).unwrap();

        assert_eq!(nonce, "client_nonce_server_nonce");
        assert_eq!(salt, "aW1hZ2luYXJ5c2FsdA==");
        assert_eq!(iterations, "4096");
    }

    #[test]
    fn test_parse_server_first_invalid() {
        let server_first = "r=nonce,s=salt"; // missing iterations
        assert!(parse_server_first(server_first).is_err());
    }

    #[test]
    fn test_constant_time_compare_equal() {
        let a = b"test_value";
        let b_arr = b"test_value";
        assert!(constant_time_compare(a, b_arr));
    }

    #[test]
    fn test_constant_time_compare_different() {
        let a = b"test_value";
        let b_arr = b"test_wrong";
        assert!(!constant_time_compare(a, b_arr));
    }

    #[test]
    fn test_constant_time_compare_different_length() {
        let a = b"test";
        let b_arr = b"test_longer";
        assert!(!constant_time_compare(a, b_arr));
    }

    #[test]
    fn test_scram_client_final_flow() {
        let mut client = ScramClient::new("user".to_string(), "password".to_string());
        let _client_first = client.client_first();

        // Simulate server response
        let server_nonce = format!("{}server_nonce_part", client.nonce);
        let server_first = format!("r={},s={},i=4096", server_nonce, BASE64.encode(b"salty"));

        // Should succeed with valid format
        let result = client.client_final(&server_first);
        assert!(result.is_ok());

        let (client_final, state) = result.unwrap();
        assert!(client_final.starts_with("c="));
        assert!(!state.auth_message.is_empty());
    }

    /// Known-answer test from RFC 7677 §3 (username "user", password "pencil").
    #[test]
    fn test_rfc7677_client_final_vector() {
        let mut client = ScramClient {
            username: "user".to_string(),
            password: "pencil".to_string(),
            nonce: "rOprNGfwEbeRWgbNEkqO".to_string(),
        };
        assert_eq!(client.client_first(), "n,,n=user,r=rOprNGfwEbeRWgbNEkqO");

        let server_first = "r=rOprNGfwEbeRWgbNEkqO%hvYDpWUa2RaTCAfuxFIlj)hNlF$k0,\
                            s=W22ZaJ0SNY7soEsUEjb6gQ==,i=4096";
        let (client_final, _state) = client.client_final(server_first).unwrap();
        assert_eq!(
            client_final,
            "c=biws,r=rOprNGfwEbeRWgbNEkqO%hvYDpWUa2RaTCAfuxFIlj)hNlF$k0,\
             p=dHzbZapWIk4jUhN+Ute9ytag9zjfMHgsqmmiz7AndVQ="
        );
    }

    #[test]
    fn test_client_final_rejects_foreign_nonce() {
        let mut client = ScramClient::new("user".to_string(), "password".to_string());
        let server_first = format!("r=not_our_nonce,s={},i=4096", BASE64.encode(b"salty"));
        assert!(matches!(
            client.client_final(&server_first),
            Err(ScramError::InvalidServerMessage(_))
        ));
    }

    #[test]
    fn test_verify_server_final_round_trip() {
        let mut client = ScramClient::new("user".to_string(), "password".to_string());
        let server_first = format!("r={}srv,s={},i=4096", client.nonce, BASE64.encode(b"salty"));
        let (_client_final, state) = client.client_final(&server_first).unwrap();

        let signature = calculate_server_signature(&state.server_key, &state.auth_message);
        let accepted = format!("v={}", BASE64.encode(&signature));
        assert!(client.verify_server_final(&accepted, &state).is_ok());

        let mut forged = signature;
        forged[0] ^= 0x01;
        let rejected = format!("v={}", BASE64.encode(&forged));
        assert!(matches!(
            client.verify_server_final(&rejected, &state),
            Err(ScramError::InvalidServerProof(_))
        ));

        assert!(client.verify_server_final("e=invalid-proof", &state).is_err());
        assert!(client.verify_server_final("garbage", &state).is_err());
    }
}