Skip to main content

sqlmodel_postgres/auth/
scram.rs

//! SCRAM-SHA-256 client-side authentication (RFC 5802 / RFC 7677) for PostgreSQL connections.
2
3use base64::{Engine as _, engine::general_purpose::STANDARD as BASE64};
4use hmac::{Hmac, Mac};
5use rand::{Rng, distributions::Alphanumeric, rngs::OsRng};
6use sha2::{Digest, Sha256};
7use sqlmodel_core::Error;
8use sqlmodel_core::error::{ConnectionError, ConnectionErrorKind, ProtocolError};
9use subtle::ConstantTimeEq;
10
/// HMAC keyed with SHA-256: the PRF used both for PBKDF2 key stretching and
/// for the client/server key and signature computations in this module.
type HmacSha256 = Hmac<Sha256>;
12
/// Client-side state for a single SCRAM-SHA-256 authentication exchange.
///
/// The identity fields are set at construction; the `Option` fields stay `None`
/// until the server-first message has been processed successfully.
pub struct ScramClient {
    // --- Identity supplied by the caller / generated locally ---
    /// Login name carried in the `n=` attribute of the client-first message.
    username: String,
    /// Cleartext password, used as the PBKDF2 input.
    password: String,
    /// Random nonce generated by this client (`r=` in the client-first message).
    client_nonce: String,

    // --- Values parsed from the server-first message ---
    /// Decoded salt (`s=`).
    salt: Option<Vec<u8>>,
    /// PBKDF2 iteration count (`i=`).
    iterations: Option<u32>,
    /// Combined client + server nonce (`r=`).
    server_nonce: Option<String>,

    // --- Derived material kept for verifying the server-final message ---
    /// PBKDF2-derived `SaltedPassword`.
    salted_password: Option<[u8; 32]>,
    /// `AuthMessage` over which the client proof and server signature are computed.
    auth_message: Option<String>,
}
27
28impl ScramClient {
29    pub fn new(username: &str, password: &str) -> Self {
30        // Use OsRng for cryptographically secure nonce generation.
31        // 32 characters of alphanumeric provides ~190 bits of entropy.
32        let client_nonce: String = OsRng
33            .sample_iter(&Alphanumeric)
34            .take(32)
35            .map(char::from)
36            .collect();
37
38        Self {
39            username: username.to_string(),
40            password: password.to_string(),
41            client_nonce,
42            server_nonce: None,
43            salt: None,
44            iterations: None,
45            salted_password: None,
46            auth_message: None,
47        }
48    }
49
50    /// Generate client-first message
51    pub fn client_first(&self) -> Vec<u8> {
52        // gs2-header: "n,," (no channel binding, no authzid)
53        // client-first-message-bare: "n=<user>,r=<nonce>"
54        // Note: SCRAM requires strict handling of "," in usernames but Postgres usually forbids it or requires escaping.
55        // For now we assume standard username.
56        format!("n,,n={},r={}", self.username, self.client_nonce).into_bytes()
57    }
58
59    /// Process server-first message and generate client-final
60    #[allow(clippy::result_large_err)]
61    pub fn process_server_first(&mut self, data: &[u8]) -> Result<Vec<u8>, Error> {
62        let msg = std::str::from_utf8(data)
63            .map_err(|e| protocol_error(format!("Invalid UTF-8 in SASL continue: {}", e)))?;
64
65        // Parse server-first: r=<nonce>,s=<salt>,i=<iterations>
66        let mut combined_nonce = None;
67        let mut salt = None;
68        let mut iterations = None;
69
70        for part in msg.split(',') {
71            if let Some(value) = part.strip_prefix("r=") {
72                combined_nonce = Some(value.to_string());
73            } else if let Some(value) = part.strip_prefix("s=") {
74                salt = Some(
75                    BASE64
76                        .decode(value)
77                        .map_err(|e| protocol_error(format!("Invalid base64 salt: {}", e)))?,
78                );
79            } else if let Some(value) = part.strip_prefix("i=") {
80                iterations = Some(
81                    value
82                        .parse()
83                        .map_err(|e| protocol_error(format!("Invalid iterations: {}", e)))?,
84                );
85            }
86        }
87
88        let combined_nonce = combined_nonce.ok_or_else(|| protocol_error("Missing nonce"))?;
89        let salt = salt.ok_or_else(|| protocol_error("Missing salt"))?;
90        let iterations = iterations.ok_or_else(|| protocol_error("Missing iterations"))?;
91
92        // Verify nonce starts with our client nonce
93        if !combined_nonce.starts_with(&self.client_nonce) {
94            return Err(protocol_error("Invalid server nonce"));
95        }
96
97        // Derive salted password using PBKDF2
98        let mut salted_password = [0u8; 32];
99        pbkdf2::pbkdf2::<HmacSha256>(
100            self.password.as_bytes(),
101            &salt,
102            iterations,
103            &mut salted_password,
104        )
105        .map_err(|e| protocol_error(format!("PBKDF2 failed: {}", e)))?;
106
107        // Build auth message
108        let client_first_bare = format!("n={},r={}", self.username, self.client_nonce);
109        let client_final_without_proof = format!("c=biws,r={}", combined_nonce); // biws = base64("n,,")
110        let auth_message = format!(
111            "{},{},{}",
112            client_first_bare, msg, client_final_without_proof
113        );
114
115        // Calculate client proof
116        let client_key = hmac_sha256(&salted_password, b"Client Key")?;
117        let stored_key = sha256(&client_key);
118        let client_signature = hmac_sha256(&stored_key, auth_message.as_bytes())?;
119
120        let client_proof: Vec<u8> = client_key
121            .iter()
122            .zip(client_signature.iter())
123            .map(|(a, b)| a ^ b)
124            .collect();
125
126        // Store for verification
127        self.server_nonce = Some(combined_nonce.clone());
128        self.salted_password = Some(salted_password);
129        self.auth_message = Some(auth_message);
130        self.salt = Some(salt);
131        self.iterations = Some(iterations);
132
133        // Build client-final message
134        let client_final = format!(
135            "c=biws,r={},p={}",
136            combined_nonce,
137            BASE64.encode(&client_proof)
138        );
139
140        Ok(client_final.into_bytes())
141    }
142
143    /// Verify server-final message
144    #[allow(clippy::result_large_err)]
145    pub fn verify_server_final(&self, data: &[u8]) -> Result<(), Error> {
146        let msg = std::str::from_utf8(data)
147            .map_err(|e| protocol_error(format!("Invalid UTF-8 in SASL final: {}", e)))?;
148
149        let server_signature_b64 = msg
150            .strip_prefix("v=")
151            .ok_or_else(|| protocol_error("Invalid server-final format"))?;
152
153        let server_signature = BASE64
154            .decode(server_signature_b64)
155            .map_err(|e| protocol_error(format!("Invalid base64 server signature: {}", e)))?;
156
157        // Calculate expected server signature
158        let salted_password = self
159            .salted_password
160            .as_ref()
161            .ok_or_else(|| protocol_error("Missing salted password state"))?;
162        let auth_message = self
163            .auth_message
164            .as_ref()
165            .ok_or_else(|| protocol_error("Missing auth message state"))?;
166
167        let server_key = hmac_sha256(salted_password, b"Server Key")?;
168        let expected_signature = hmac_sha256(&server_key, auth_message.as_bytes())?;
169
170        // Use constant-time comparison to prevent timing attacks.
171        // An attacker observing response times could otherwise recover
172        // the expected signature byte-by-byte.
173        if server_signature.ct_eq(&expected_signature).into() {
174            Ok(())
175        } else {
176            Err(auth_error("Server signature mismatch"))
177        }
178    }
179}
180
181// Helpers
182
183fn protocol_error(msg: impl Into<String>) -> Error {
184    Error::Protocol(ProtocolError {
185        message: msg.into(),
186        raw_data: None,
187        source: None,
188    })
189}
190
191fn auth_error(msg: impl Into<String>) -> Error {
192    Error::Connection(ConnectionError {
193        kind: ConnectionErrorKind::Authentication,
194        message: msg.into(),
195        source: None,
196    })
197}
198
199#[allow(clippy::result_large_err)]
200fn hmac_sha256(key: &[u8], data: &[u8]) -> Result<[u8; 32], Error> {
201    let mut mac = HmacSha256::new_from_slice(key)
202        .map_err(|e| protocol_error(format!("HMAC init failed: {}", e)))?;
203    mac.update(data);
204    let result = mac.finalize();
205    let bytes = result.into_bytes();
206    Ok(bytes.into())
207}
208
209fn sha256(data: &[u8]) -> [u8; 32] {
210    let mut hasher = Sha256::new();
211    hasher.update(data);
212    hasher.finalize().into()
213}