fraiseql-wire 2.2.1

Streaming JSON query engine for Postgres 17
Documentation
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
//! SCRAM-SHA-256 authentication implementation
//!
//! Implements the SCRAM-SHA-256 (Salted Challenge Response Authentication Mechanism)
//! as defined in RFC 5802 for PostgreSQL authentication (Postgres 10+).

use base64::{engine::general_purpose::STANDARD as BASE64, Engine};
use hmac::{Hmac, Mac};
use pbkdf2::pbkdf2;
use rand::{rngs::OsRng, Rng};
use sha2::{Digest, Sha256};
use std::fmt;

/// HMAC keyed with SHA-256: the PRF fed to PBKDF2 and the MAC used for every SCRAM
/// `HMAC(...)` step in this module (ClientKey, ServerKey, ClientSignature, ServerSignature).
type HmacSha256 = Hmac<Sha256>;

/// Upper bound on the server-chosen PBKDF2 iteration count (`i=` attribute).
///
/// The iteration count comes from the server, so a hostile or misconfigured server could
/// request an enormous value and pin the client's CPU inside PBKDF2 long before the
/// handshake fails. Anything above this bound is rejected outright. One million leaves
/// ample headroom over the counts PostgreSQL deployments use in practice (4096 by default,
/// up to a few hundred thousand when hardened).
const MAX_SCRAM_ITERATIONS: u32 = 1_000_000;

/// Failure modes of the client side of a SCRAM-SHA-256 exchange.
#[derive(Debug, Clone)]
#[non_exhaustive]
pub enum ScramError {
    /// The server's final signature (`v=`) did not match the expected value.
    InvalidServerProof(String),
    /// A server message was malformed or violated a protocol constraint.
    InvalidServerMessage(String),
    /// UTF-8 encoding/decoding error. This module also reports HMAC key-setup
    /// failures through this variant.
    Utf8Error(String),
    /// A base64 field (salt or server signature) failed to decode.
    Base64Error(String),
}

impl fmt::Display for ScramError {
    /// Renders as `"<category>: <detail>"`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let (category, detail) = match self {
            Self::InvalidServerProof(detail) => ("invalid server proof", detail),
            Self::InvalidServerMessage(detail) => ("invalid server message", detail),
            Self::Utf8Error(detail) => ("UTF-8 error", detail),
            Self::Base64Error(detail) => ("Base64 error", detail),
        };
        write!(f, "{category}: {detail}")
    }
}

impl std::error::Error for ScramError {}

/// State carried from [`ScramClient::client_final`] to [`ScramClient::verify_server_final`].
///
/// `Debug` is implemented by hand rather than derived. `server_key` is
/// `ServerKey := HMAC(SaltedPassword, "Server Key")`, key material derived from the
/// password, and it must not be written to logs through `{:?}`.
#[derive(Clone)]
pub struct ScramState {
    /// `AuthMessage`: client-first-bare, server-first and client-final-without-proof,
    /// joined with ','.
    auth_message: Vec<u8>,
    /// `ServerKey`, used to recompute the expected server signature.
    server_key: Vec<u8>,
}

impl fmt::Debug for ScramState {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("ScramState")
            // The AuthMessage carries only protocol data (nonces, salt, iteration count,
            // escaped username), so show it as text for easier debugging.
            .field("auth_message", &String::from_utf8_lossy(&self.auth_message))
            // SECURITY: never render derived key material.
            .field("server_key", &"<redacted>")
            .finish()
    }
}

/// Client side of a single SCRAM-SHA-256 authentication exchange.
///
/// Holds the credentials plus the client nonce chosen at construction time. That same
/// nonce goes into the first message and is checked against the server's reply.
pub struct ScramClient {
    /// User name, sent RFC 5802-escaped in `client-first-message-bare`.
    username: String,
    /// Plaintext password; the PBKDF2 input for `SaltedPassword`.
    password: String,
    /// Base64 encoding of the random client nonce.
    nonce: String,
}

impl ScramClient {
    /// Create a new SCRAM client
    pub fn new(username: String, password: String) -> Self {
        // SECURITY: OsRng guarantees OS-level entropy for SCRAM nonces.
        let mut rng = OsRng;
        let nonce_bytes: Vec<u8> = (0..24).map(|_| rng.gen()).collect();
        let nonce = BASE64.encode(&nonce_bytes);

        Self {
            username,
            password,
            nonce,
        }
    }

    /// Generate client first message (no proof)
    pub fn client_first(&self) -> String {
        // RFC 5802 format: gs2-header client-first-message-bare
        // gs2-header = "n,," (n = no channel binding, empty authorization identity)
        // client-first-message-bare = "n=<username>,r=<nonce>"
        // RFC 5802 §5.1: username must have ',' escaped as '=2C' and '=' escaped as '=3D'.
        let escaped_username = self.username.replace('=', "=3D").replace(',', "=2C");
        format!("n,,n={},r={}", escaped_username, self.nonce)
    }

    /// Process server first message and generate client final message
    ///
    /// Returns (`client_final_message`, `internal_state`)
    ///
    /// # Errors
    ///
    /// Returns [`ScramError::InvalidServerMessage`] if the server message cannot be parsed,
    /// the server nonce does not start with the client nonce, or the iteration count is
    /// invalid or exceeds `MAX_SCRAM_ITERATIONS`. Returns [`ScramError::Base64Error`] if
    /// the salt is not valid base64.
    pub fn client_final(&mut self, server_first: &str) -> Result<(String, ScramState), ScramError> {
        // Parse server first message: r=<client_nonce><server_nonce>,s=<salt>,i=<iterations>
        let (server_nonce, salt, iterations) = parse_server_first(server_first)?;

        // Verify server nonce starts with our client nonce
        if !server_nonce.starts_with(&self.nonce) {
            return Err(ScramError::InvalidServerMessage(
                "server nonce doesn't contain client nonce".to_string(),
            ));
        }

        // Decode salt and iterations
        let salt_bytes = BASE64
            .decode(&salt)
            .map_err(|_| ScramError::Base64Error("invalid salt encoding".to_string()))?;
        let iterations = iterations
            .parse::<u32>()
            .map_err(|_| ScramError::InvalidServerMessage("invalid iteration count".to_string()))?;

        // SECURITY: Guard against server-supplied iteration counts large enough to
        // cause a denial-of-service via excessive PBKDF2 CPU time.
        if iterations > MAX_SCRAM_ITERATIONS {
            return Err(ScramError::InvalidServerMessage(format!(
                "server iteration count {iterations} exceeds maximum of {MAX_SCRAM_ITERATIONS}"
            )));
        }

        // Build channel binding (no channel binding for SCRAM-SHA-256)
        let channel_binding = BASE64.encode(b"n,,");

        // Build client final without proof
        let client_final_without_proof = format!("c={},r={}", channel_binding, server_nonce);

        // Build auth message for signature calculation.
        // client-first-message-bare is "n=<escaped_username>,r=<nonce>" (without gs2-header).
        // SECURITY: Must use the RFC 5802 §5.1-escaped username (same as client_first()),
        // not the raw username — otherwise an attacker who controls ',' or '=' in a username
        // can inject arbitrary SCRAM attributes and break authentication.
        let escaped_username = self.username.replace('=', "=3D").replace(',', "=2C");
        let client_first_bare = format!("n={},r={}", escaped_username, self.nonce);
        let auth_message = format!(
            "{},{},{}",
            client_first_bare, server_first, client_final_without_proof
        );

        // Calculate proof
        let proof = calculate_client_proof(
            &self.password,
            &salt_bytes,
            iterations,
            auth_message.as_bytes(),
        )?;

        // Calculate server signature for later verification
        let server_key = calculate_server_key(&self.password, &salt_bytes, iterations)?;

        // Build client final message
        let client_final = format!("{},p={}", client_final_without_proof, BASE64.encode(&proof));

        let state = ScramState {
            auth_message: auth_message.into_bytes(),
            server_key,
        };

        Ok((client_final, state))
    }

    /// Verify server final message and confirm authentication
    ///
    /// # Errors
    ///
    /// Returns `ScramError::InvalidServerMessage` if the server final message is malformed.
    /// Returns `ScramError::Base64Error` if the server signature is not valid base64.
    /// Returns `ScramError::AuthenticationFailed` if the server signature does not match.
    pub fn verify_server_final(
        &self,
        server_final: &str,
        state: &ScramState,
    ) -> Result<(), ScramError> {
        // Parse server final: v=<server_signature>
        let server_sig_encoded = server_final
            .strip_prefix("v=")
            .ok_or_else(|| ScramError::InvalidServerMessage("missing 'v=' prefix".to_string()))?;

        let server_signature = BASE64.decode(server_sig_encoded).map_err(|_| {
            ScramError::Base64Error("invalid server signature encoding".to_string())
        })?;

        // Calculate expected server signature
        let expected_signature =
            calculate_server_signature(&state.server_key, &state.auth_message)?;

        // Constant-time comparison
        if constant_time_compare(&server_signature, &expected_signature) {
            Ok(())
        } else {
            Err(ScramError::InvalidServerProof(
                "server signature verification failed".to_string(),
            ))
        }
    }
}

/// Parse server first message format: r=<nonce>,s=<salt>,i=<iterations>
fn parse_server_first(msg: &str) -> Result<(String, String, String), ScramError> {
    let mut nonce = String::new();
    let mut salt = String::new();
    let mut iterations = String::new();

    for part in msg.split(',') {
        if let Some(value) = part.strip_prefix("r=") {
            nonce = value.to_string();
        } else if let Some(value) = part.strip_prefix("s=") {
            salt = value.to_string();
        } else if let Some(value) = part.strip_prefix("i=") {
            iterations = value.to_string();
        }
    }

    if nonce.is_empty() || salt.is_empty() || iterations.is_empty() {
        return Err(ScramError::InvalidServerMessage(
            "missing required fields in server first message".to_string(),
        ));
    }

    Ok((nonce, salt, iterations))
}

/// Calculate SCRAM client proof
fn calculate_client_proof(
    password: &str,
    salt: &[u8],
    iterations: u32,
    auth_message: &[u8],
) -> Result<Vec<u8>, ScramError> {
    // SaltedPassword := PBKDF2(password, salt, iterations, HMAC-SHA256)
    let password_bytes = password.as_bytes();
    let mut salted_password = vec![0u8; 32]; // SHA256 produces 32 bytes
    let _ = pbkdf2::<HmacSha256>(password_bytes, salt, iterations, &mut salted_password);

    // ClientKey := HMAC(SaltedPassword, "Client Key")
    let mut client_key_hmac = HmacSha256::new_from_slice(&salted_password)
        .map_err(|_| ScramError::Utf8Error("HMAC key error".to_string()))?;
    client_key_hmac.update(b"Client Key");
    let client_key = client_key_hmac.finalize().into_bytes();

    // StoredKey := SHA256(ClientKey)
    let stored_key = Sha256::digest(client_key.to_vec().as_slice());

    // ClientSignature := HMAC(StoredKey, AuthMessage)
    let mut client_sig_hmac = HmacSha256::new_from_slice(&stored_key)
        .map_err(|_| ScramError::Utf8Error("HMAC key error".to_string()))?;
    client_sig_hmac.update(auth_message);
    let client_signature = client_sig_hmac.finalize().into_bytes();

    // ClientProof := ClientKey XOR ClientSignature
    let mut proof = client_key.to_vec();
    for (proof_byte, sig_byte) in proof.iter_mut().zip(client_signature.iter()) {
        *proof_byte ^= sig_byte;
    }

    Ok(proof.clone())
}

/// Calculate server key for server signature verification
fn calculate_server_key(
    password: &str,
    salt: &[u8],
    iterations: u32,
) -> Result<Vec<u8>, ScramError> {
    // SaltedPassword := PBKDF2(password, salt, iterations, HMAC-SHA256)
    let password_bytes = password.as_bytes();
    let mut salted_password = vec![0u8; 32];
    let _ = pbkdf2::<HmacSha256>(password_bytes, salt, iterations, &mut salted_password);

    // ServerKey := HMAC(SaltedPassword, "Server Key")
    let mut server_key_hmac = HmacSha256::new_from_slice(&salted_password)
        .map_err(|_| ScramError::Utf8Error("HMAC key error".to_string()))?;
    server_key_hmac.update(b"Server Key");

    Ok(server_key_hmac.finalize().into_bytes().to_vec())
}

/// Calculate server signature for verification
fn calculate_server_signature(
    server_key: &[u8],
    auth_message: &[u8],
) -> Result<Vec<u8>, ScramError> {
    let mut hmac = HmacSha256::new_from_slice(server_key)
        .map_err(|_| ScramError::Utf8Error("invalid HMAC key for server signature".to_string()))?;
    hmac.update(auth_message);
    Ok(hmac.finalize().into_bytes().to_vec())
}

/// Compare two byte strings for equality without a data-dependent early exit.
///
/// This avoids leaking, through timing, how many leading bytes of a forged signature
/// were correct. The work is delegated to `subtle::ConstantTimeEq`. Inputs of different
/// lengths compare unequal.
fn constant_time_compare(a: &[u8], b: &[u8]) -> bool {
    use subtle::ConstantTimeEq;
    bool::from(a.ct_eq(b))
}

#[cfg(test)]
mod tests {
    #![allow(clippy::unwrap_used)] // Reason: test code, panics are acceptable
    use super::*;

    // RFC 7677 §3 SCRAM-SHA-256 example exchange (user "user", password "pencil").
    // A known-answer vector pins the PBKDF2/HMAC/XOR pipeline to the standard.
    const RFC7677_CLIENT_NONCE: &str = "rOprNGfwEbeRWgbNEkqO";
    const RFC7677_SERVER_FIRST: &str = concat!(
        "r=rOprNGfwEbeRWgbNEkqO%hvYDpWUa2RaTCAfuxFIlj)hNlF$k0,",
        "s=W22ZaJ0SNY7soEsUEjb6gQ==,i=4096"
    );
    const RFC7677_CLIENT_FINAL: &str = concat!(
        "c=biws,r=rOprNGfwEbeRWgbNEkqO%hvYDpWUa2RaTCAfuxFIlj)hNlF$k0,",
        "p=dHzbZapWIk4jUhN+Ute9ytag9zjfMHgsqmmiz7AndVQ="
    );
    const RFC7677_SERVER_FINAL: &str = "v=6rriTRBi23WpRR/wtup+mMhUZUn/dB5nLTJRsjl95G4=";

    /// Build a client with the fixed RFC 7677 nonce instead of a random one.
    fn rfc7677_client() -> ScramClient {
        ScramClient {
            username: "user".to_string(),
            password: "pencil".to_string(),
            nonce: RFC7677_CLIENT_NONCE.to_string(),
        }
    }

    #[test]
    fn test_scram_client_creation() {
        let client = ScramClient::new("user".to_string(), "password".to_string());
        assert_eq!(client.username, "user");
        assert_eq!(client.password, "password");
        assert!(!client.nonce.is_empty());
    }

    #[test]
    fn test_client_first_message_format() {
        let client = ScramClient::new("alice".to_string(), "secret".to_string());
        let first = client.client_first();

        // RFC 5802 format: "n,,n=<username>,r=<nonce>"
        assert!(first.starts_with("n,,n=alice,r="));
        assert!(first.len() > 20);
    }

    #[test]
    fn test_parse_server_first_valid() {
        let server_first = "r=client_nonce_server_nonce,s=aW1hZ2luYXJ5c2FsdA==,i=4096";
        let (nonce, salt, iterations) = parse_server_first(server_first).unwrap();

        assert_eq!(nonce, "client_nonce_server_nonce");
        assert_eq!(salt, "aW1hZ2luYXJ5c2FsdA==");
        assert_eq!(iterations, "4096");
    }

    #[test]
    fn test_parse_server_first_invalid() {
        let server_first = "r=nonce,s=salt"; // missing iterations
        let result = parse_server_first(server_first);
        assert!(
            matches!(result, Err(ScramError::InvalidServerMessage(_))),
            "expected InvalidServerMessage error, got: {result:?}"
        );
    }

    #[test]
    fn test_constant_time_compare_equal() {
        let a = b"test_value";
        let b_arr = b"test_value";
        assert!(constant_time_compare(a, b_arr));
    }

    #[test]
    fn test_constant_time_compare_different() {
        let a = b"test_value";
        let b_arr = b"test_wrong";
        assert!(!constant_time_compare(a, b_arr));
    }

    #[test]
    fn test_constant_time_compare_different_length() {
        let a = b"test";
        let b_arr = b"test_longer";
        assert!(!constant_time_compare(a, b_arr));
    }

    #[test]
    fn test_scram_client_final_flow() {
        let mut client = ScramClient::new("user".to_string(), "password".to_string());
        let _client_first = client.client_first();

        // Simulate server response
        let server_nonce = format!("{}server_nonce_part", client.nonce);
        let server_first = format!("r={},s={},i=4096", server_nonce, BASE64.encode(b"salty"));

        // Should succeed with valid format
        let result = client.client_final(&server_first);
        let (client_final, state) = result.unwrap_or_else(|e| {
            panic!("expected Ok for client_final with valid server message: {e}")
        });
        assert!(client_final.starts_with("c="));
        assert!(!state.auth_message.is_empty());
    }

    #[test]
    fn test_scram_iteration_count_too_high_is_rejected() {
        // H3: A server-supplied i= value above MAX_SCRAM_ITERATIONS must be rejected
        // to prevent PBKDF2-based denial-of-service.
        let mut client = ScramClient::new("user".to_string(), "password".to_string());
        let _client_first = client.client_first();

        let server_nonce = format!("{}server_nonce_part", client.nonce);
        let excessive_iterations = MAX_SCRAM_ITERATIONS + 1;
        let server_first = format!(
            "r={},s={},i={}",
            server_nonce,
            BASE64.encode(b"salty"),
            excessive_iterations
        );

        let result = client.client_final(&server_first);
        assert!(
            matches!(result, Err(ScramError::InvalidServerMessage(_))),
            "expected InvalidServerMessage for excessive iterations, got: {result:?}"
        );
    }

    #[test]
    fn test_scram_iteration_count_at_limit_is_accepted() {
        // Exactly MAX_SCRAM_ITERATIONS must be accepted.
        let mut client = ScramClient::new("user".to_string(), "password".to_string());
        let _client_first = client.client_first();

        let server_nonce = format!("{}server_nonce_part", client.nonce);
        let server_first = format!(
            "r={},s={},i={}",
            server_nonce,
            BASE64.encode(b"salty"),
            MAX_SCRAM_ITERATIONS
        );

        // Should not fail on the iteration count check (may fail for other reasons if any)
        let result = client.client_final(&server_first);
        // We only care that it didn't fail with an iteration-count error
        if let Err(ScramError::InvalidServerMessage(msg)) = &result {
            assert!(
                !msg.contains("iteration count"),
                "unexpected iteration-count rejection at limit: {msg}"
            );
        }
    }

    #[test]
    fn test_scram_username_escaping_in_auth_message() {
        // H4: The auth message must use the RFC 5802-escaped username, not the raw one.
        // A username containing ',' or '=' must be escaped in client_first_bare.
        let mut client = ScramClient::new("user,admin=evil".to_string(), "password".to_string());
        let client_first = client.client_first();
        // client_first should have escaped username
        assert!(
            client_first.contains("user=2Cadmin=3Devil"),
            "client_first should escape ',' and '=' in username, got: {client_first}"
        );

        // client_final should use the same escaped username in the auth message
        let server_nonce = format!("{}server_nonce_part", client.nonce);
        let server_first = format!("r={},s={},i=4096", server_nonce, BASE64.encode(b"salty"));

        let result = client.client_final(&server_first);
        let (_client_final, state) =
            result.unwrap_or_else(|e| panic!("expected Ok for escaped-username client_final: {e}"));

        // The auth message must contain the escaped username, not the raw one
        let auth_message = String::from_utf8(state.auth_message).unwrap();
        assert!(
            auth_message.contains("user=2Cadmin=3Devil"),
            "auth_message should contain escaped username, got: {auth_message}"
        );
        assert!(
            !auth_message.contains("user,admin=evil"),
            "auth_message must NOT contain raw (unescaped) username, got: {auth_message}"
        );
    }

    #[test]
    fn test_rfc7677_known_answer_exchange() {
        // Full exchange against the published RFC 7677 vector: client proof and server
        // signature must match byte-for-byte.
        let mut client = rfc7677_client();
        assert_eq!(client.client_first(), "n,,n=user,r=rOprNGfwEbeRWgbNEkqO");

        let (client_final, state) = client.client_final(RFC7677_SERVER_FIRST).unwrap();
        assert_eq!(client_final, RFC7677_CLIENT_FINAL);

        let verified = client.verify_server_final(RFC7677_SERVER_FINAL, &state);
        assert!(
            verified.is_ok(),
            "RFC 7677 server signature must verify, got: {verified:?}"
        );
    }

    #[test]
    fn test_verify_server_final_rejects_wrong_signature() {
        let mut client = rfc7677_client();
        let (_client_final, state) = client.client_final(RFC7677_SERVER_FIRST).unwrap();

        let forged = format!("v={}", BASE64.encode([0u8; 32]));
        let result = client.verify_server_final(&forged, &state);
        assert!(
            matches!(result, Err(ScramError::InvalidServerProof(_))),
            "expected InvalidServerProof for forged signature, got: {result:?}"
        );
    }

    #[test]
    fn test_verify_server_final_requires_v_prefix() {
        let mut client = rfc7677_client();
        let (_client_final, state) = client.client_final(RFC7677_SERVER_FIRST).unwrap();

        let result = client.verify_server_final("e=other-error", &state);
        assert!(
            matches!(result, Err(ScramError::InvalidServerMessage(_))),
            "expected InvalidServerMessage for missing 'v=', got: {result:?}"
        );
    }
}