Skip to main content

drasi_source_postgres/
scram.rs

// Copyright 2025 The Drasi Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
14
15use anyhow::{anyhow, Context, Result};
16use base64::{engine::general_purpose::STANDARD as BASE64, Engine};
17use rand::Rng;
18use sha2::{Digest, Sha256};
19use std::collections::HashMap;
20
21pub struct ScramClient {
22    username: String,
23    password: String,
24    client_nonce: String,
25    server_nonce: Option<String>,
26    salt: Option<Vec<u8>>,
27    iterations: Option<u32>,
28    auth_message: Option<String>,
29}
30
31impl ScramClient {
32    pub fn new(username: &str, password: &str) -> Self {
33        let client_nonce = generate_nonce();
34        Self {
35            username: username.to_string(),
36            password: password.to_string(),
37            client_nonce,
38            server_nonce: None,
39            salt: None,
40            iterations: None,
41            auth_message: None,
42        }
43    }
44
45    pub fn client_first_message(&self) -> String {
46        // SCRAM client-first-message format: n,,n=<username>,r=<nonce>
47        let gs2_header = "n,,"; // No channel binding
48        let client_first_bare = format!("n={},r={}", saslprep(&self.username), self.client_nonce);
49        format!("{gs2_header}{client_first_bare}")
50    }
51
52    pub fn process_server_first_message(&mut self, message: &str) -> Result<()> {
53        // Parse server-first-message: r=<nonce>,s=<salt>,i=<iteration-count>
54        let params = parse_scram_message(message)?;
55
56        // Verify server nonce starts with client nonce
57        let server_nonce = params
58            .get("r")
59            .ok_or_else(|| anyhow!("Missing nonce in server response"))?;
60        if !server_nonce.starts_with(&self.client_nonce) {
61            return Err(anyhow!("Server nonce doesn't include client nonce"));
62        }
63        self.server_nonce = Some(server_nonce.clone());
64
65        // Parse salt
66        let salt_b64 = params
67            .get("s")
68            .ok_or_else(|| anyhow!("Missing salt in server response"))?;
69        self.salt = Some(BASE64.decode(salt_b64)?);
70
71        // Parse iteration count
72        let iterations_str = params
73            .get("i")
74            .ok_or_else(|| anyhow!("Missing iteration count in server response"))?;
75        self.iterations = Some(iterations_str.parse()?);
76
77        Ok(())
78    }
79
80    pub fn client_final_message(&mut self) -> Result<String> {
81        let server_nonce = self
82            .server_nonce
83            .as_ref()
84            .ok_or_else(|| anyhow!("Server nonce not set"))?;
85        let salt = self.salt.as_ref().ok_or_else(|| anyhow!("Salt not set"))?;
86        let iterations = self
87            .iterations
88            .ok_or_else(|| anyhow!("Iterations not set"))?;
89
90        // Build client-final-message-without-proof
91        let channel_binding = "c=biws"; // base64("n,,")
92        let client_final_without_proof = format!("{channel_binding},r={server_nonce}");
93
94        // Build auth message
95        let client_first_bare = format!("n={},r={}", saslprep(&self.username), self.client_nonce);
96        let server_first = format!(
97            "r={},s={},i={}",
98            server_nonce,
99            BASE64.encode(salt),
100            iterations
101        );
102        let auth_message =
103            format!("{client_first_bare},{server_first},{client_final_without_proof}");
104        self.auth_message = Some(auth_message.clone());
105
106        // Calculate proof
107        let salted_password = pbkdf2_sha256(self.password.as_bytes(), salt, iterations)
108            .context("Failed to derive salted password with PBKDF2")?;
109        let client_key = hmac_sha256(&salted_password, b"Client Key")
110            .context("Failed to calculate client key")?;
111        let stored_key = sha256(&client_key);
112        let client_signature = hmac_sha256(&stored_key, auth_message.as_bytes())
113            .context("Failed to calculate client signature")?;
114        let client_proof = xor_bytes(&client_key, &client_signature);
115
116        // Build final message
117        Ok(format!(
118            "{},p={}",
119            client_final_without_proof,
120            BASE64.encode(client_proof)
121        ))
122    }
123
124    pub fn verify_server_final(&self, message: &str) -> Result<()> {
125        let params = parse_scram_message(message)?;
126
127        // Check for error
128        if let Some(error) = params.get("e") {
129            return Err(anyhow!("Server error: {error}"));
130        }
131
132        // Verify server signature
133        if let Some(server_sig_b64) = params.get("v") {
134            let auth_message = self
135                .auth_message
136                .as_ref()
137                .ok_or_else(|| anyhow!("Auth message not set"))?;
138            let salt = self.salt.as_ref().ok_or_else(|| anyhow!("Salt not set"))?;
139            let iterations = self
140                .iterations
141                .ok_or_else(|| anyhow!("Iterations not set"))?;
142
143            let salted_password = pbkdf2_sha256(self.password.as_bytes(), salt, iterations)
144                .context("Failed to derive salted password for verification")?;
145            let server_key = hmac_sha256(&salted_password, b"Server Key")
146                .context("Failed to calculate server key")?;
147            let expected_sig = hmac_sha256(&server_key, auth_message.as_bytes())
148                .context("Failed to calculate expected server signature")?;
149
150            let server_sig = BASE64.decode(server_sig_b64)?;
151            if server_sig != expected_sig {
152                return Err(anyhow!("Server signature verification failed"));
153            }
154        } else {
155            return Err(anyhow!("Missing server signature"));
156        }
157
158        Ok(())
159    }
160}
161
162fn generate_nonce() -> String {
163    let mut rng = rand::thread_rng();
164    let bytes: Vec<u8> = (0..18).map(|_| rng.gen()).collect();
165    BASE64.encode(bytes)
166}
167
/// Escapes a username for the SCRAM `n=` attribute.
///
/// Despite its name, this function does not perform RFC 4013 SASLprep normalization. It
/// only applies the RFC 5802 `saslname` escaping: `=` becomes `=3D` and `,` becomes
/// `=2C`. All other characters pass through unchanged.
fn saslprep(s: &str) -> String {
    let mut escaped = String::with_capacity(s.len());
    for ch in s.chars() {
        match ch {
            '=' => escaped.push_str("=3D"),
            ',' => escaped.push_str("=2C"),
            other => escaped.push(other),
        }
    }
    escaped
}
172
173fn parse_scram_message(message: &str) -> Result<HashMap<String, String>> {
174    let mut params = HashMap::new();
175    for part in message.split(',') {
176        if let Some(eq_pos) = part.find('=') {
177            let key = &part[..eq_pos];
178            let value = &part[eq_pos + 1..];
179            params.insert(key.to_string(), value.to_string());
180        }
181    }
182    Ok(params)
183}
184
185fn pbkdf2_sha256(password: &[u8], salt: &[u8], iterations: u32) -> Result<Vec<u8>> {
186    let mut result = vec![0u8; 32];
187    pbkdf2::pbkdf2::<hmac::Hmac<sha2::Sha256>>(password, salt, iterations, &mut result)
188        .map_err(|e| anyhow::anyhow!("PBKDF2 failed: {e:?}"))?;
189    Ok(result)
190}
191
192fn hmac_sha256(key: &[u8], data: &[u8]) -> Result<Vec<u8>> {
193    use hmac::{Hmac, Mac};
194    type HmacSha256 = Hmac<Sha256>;
195
196    let mut mac = HmacSha256::new_from_slice(key)
197        .map_err(|e| anyhow::anyhow!("Invalid HMAC key length: {e}"))?;
198    mac.update(data);
199    Ok(mac.finalize().into_bytes().to_vec())
200}
201
202fn sha256(data: &[u8]) -> Vec<u8> {
203    let mut hasher = Sha256::new();
204    hasher.update(data);
205    hasher.finalize().to_vec()
206}
207
/// XORs two byte slices element-wise.
///
/// The output has the length of the shorter input; extra trailing bytes of the longer
/// slice are ignored.
fn xor_bytes(a: &[u8], b: &[u8]) -> Vec<u8> {
    std::iter::zip(a, b)
        .map(|(&lhs, &rhs)| lhs ^ rhs)
        .collect()
}