drasi_source_postgres/
scram.rs1use anyhow::{anyhow, Context, Result};
16use base64::{engine::general_purpose::STANDARD as BASE64, Engine};
17use rand::Rng;
18use sha2::{Digest, Sha256};
19use std::collections::HashMap;
20
21pub struct ScramClient {
22 username: String,
23 password: String,
24 client_nonce: String,
25 server_nonce: Option<String>,
26 salt: Option<Vec<u8>>,
27 iterations: Option<u32>,
28 auth_message: Option<String>,
29}
30
31impl ScramClient {
32 pub fn new(username: &str, password: &str) -> Self {
33 let client_nonce = generate_nonce();
34 Self {
35 username: username.to_string(),
36 password: password.to_string(),
37 client_nonce,
38 server_nonce: None,
39 salt: None,
40 iterations: None,
41 auth_message: None,
42 }
43 }
44
45 pub fn client_first_message(&self) -> String {
46 let gs2_header = "n,,"; let client_first_bare = format!("n={},r={}", saslprep(&self.username), self.client_nonce);
49 format!("{gs2_header}{client_first_bare}")
50 }
51
52 pub fn process_server_first_message(&mut self, message: &str) -> Result<()> {
53 let params = parse_scram_message(message)?;
55
56 let server_nonce = params
58 .get("r")
59 .ok_or_else(|| anyhow!("Missing nonce in server response"))?;
60 if !server_nonce.starts_with(&self.client_nonce) {
61 return Err(anyhow!("Server nonce doesn't include client nonce"));
62 }
63 self.server_nonce = Some(server_nonce.clone());
64
65 let salt_b64 = params
67 .get("s")
68 .ok_or_else(|| anyhow!("Missing salt in server response"))?;
69 self.salt = Some(BASE64.decode(salt_b64)?);
70
71 let iterations_str = params
73 .get("i")
74 .ok_or_else(|| anyhow!("Missing iteration count in server response"))?;
75 self.iterations = Some(iterations_str.parse()?);
76
77 Ok(())
78 }
79
80 pub fn client_final_message(&mut self) -> Result<String> {
81 let server_nonce = self
82 .server_nonce
83 .as_ref()
84 .ok_or_else(|| anyhow!("Server nonce not set"))?;
85 let salt = self.salt.as_ref().ok_or_else(|| anyhow!("Salt not set"))?;
86 let iterations = self
87 .iterations
88 .ok_or_else(|| anyhow!("Iterations not set"))?;
89
90 let channel_binding = "c=biws"; let client_final_without_proof = format!("{channel_binding},r={server_nonce}");
93
94 let client_first_bare = format!("n={},r={}", saslprep(&self.username), self.client_nonce);
96 let server_first = format!(
97 "r={},s={},i={}",
98 server_nonce,
99 BASE64.encode(salt),
100 iterations
101 );
102 let auth_message =
103 format!("{client_first_bare},{server_first},{client_final_without_proof}");
104 self.auth_message = Some(auth_message.clone());
105
106 let salted_password = pbkdf2_sha256(self.password.as_bytes(), salt, iterations)
108 .context("Failed to derive salted password with PBKDF2")?;
109 let client_key = hmac_sha256(&salted_password, b"Client Key")
110 .context("Failed to calculate client key")?;
111 let stored_key = sha256(&client_key);
112 let client_signature = hmac_sha256(&stored_key, auth_message.as_bytes())
113 .context("Failed to calculate client signature")?;
114 let client_proof = xor_bytes(&client_key, &client_signature);
115
116 Ok(format!(
118 "{},p={}",
119 client_final_without_proof,
120 BASE64.encode(client_proof)
121 ))
122 }
123
124 pub fn verify_server_final(&self, message: &str) -> Result<()> {
125 let params = parse_scram_message(message)?;
126
127 if let Some(error) = params.get("e") {
129 return Err(anyhow!("Server error: {error}"));
130 }
131
132 if let Some(server_sig_b64) = params.get("v") {
134 let auth_message = self
135 .auth_message
136 .as_ref()
137 .ok_or_else(|| anyhow!("Auth message not set"))?;
138 let salt = self.salt.as_ref().ok_or_else(|| anyhow!("Salt not set"))?;
139 let iterations = self
140 .iterations
141 .ok_or_else(|| anyhow!("Iterations not set"))?;
142
143 let salted_password = pbkdf2_sha256(self.password.as_bytes(), salt, iterations)
144 .context("Failed to derive salted password for verification")?;
145 let server_key = hmac_sha256(&salted_password, b"Server Key")
146 .context("Failed to calculate server key")?;
147 let expected_sig = hmac_sha256(&server_key, auth_message.as_bytes())
148 .context("Failed to calculate expected server signature")?;
149
150 let server_sig = BASE64.decode(server_sig_b64)?;
151 if server_sig != expected_sig {
152 return Err(anyhow!("Server signature verification failed"));
153 }
154 } else {
155 return Err(anyhow!("Missing server signature"));
156 }
157
158 Ok(())
159 }
160}
161
162fn generate_nonce() -> String {
163 let mut rng = rand::thread_rng();
164 let bytes: Vec<u8> = (0..18).map(|_| rng.gen()).collect();
165 BASE64.encode(bytes)
166}
167
/// Escapes a username for use as a SCRAM `saslname`: `=` becomes `=3D` and
/// `,` becomes `=2C`. Every other character is copied unchanged.
///
/// Despite the name, this performs no RFC 4013 SASLprep normalization; it only
/// applies the escaping required by the SCRAM attribute syntax.
fn saslprep(s: &str) -> String {
    let mut escaped = String::with_capacity(s.len());
    for ch in s.chars() {
        match ch {
            '=' => escaped.push_str("=3D"),
            ',' => escaped.push_str("=2C"),
            other => escaped.push(other),
        }
    }
    escaped
}
172
173fn parse_scram_message(message: &str) -> Result<HashMap<String, String>> {
174 let mut params = HashMap::new();
175 for part in message.split(',') {
176 if let Some(eq_pos) = part.find('=') {
177 let key = &part[..eq_pos];
178 let value = &part[eq_pos + 1..];
179 params.insert(key.to_string(), value.to_string());
180 }
181 }
182 Ok(params)
183}
184
185fn pbkdf2_sha256(password: &[u8], salt: &[u8], iterations: u32) -> Result<Vec<u8>> {
186 let mut result = vec![0u8; 32];
187 pbkdf2::pbkdf2::<hmac::Hmac<sha2::Sha256>>(password, salt, iterations, &mut result)
188 .map_err(|e| anyhow::anyhow!("PBKDF2 failed: {e:?}"))?;
189 Ok(result)
190}
191
192fn hmac_sha256(key: &[u8], data: &[u8]) -> Result<Vec<u8>> {
193 use hmac::{Hmac, Mac};
194 type HmacSha256 = Hmac<Sha256>;
195
196 let mut mac = HmacSha256::new_from_slice(key)
197 .map_err(|e| anyhow::anyhow!("Invalid HMAC key length: {e}"))?;
198 mac.update(data);
199 Ok(mac.finalize().into_bytes().to_vec())
200}
201
202fn sha256(data: &[u8]) -> Vec<u8> {
203 let mut hasher = Sha256::new();
204 hasher.update(data);
205 hasher.finalize().to_vec()
206}
207
/// XORs `a` and `b` byte by byte.
///
/// The result has the length of the shorter input; extra bytes in the longer
/// slice are ignored.
fn xor_bytes(a: &[u8], b: &[u8]) -> Vec<u8> {
    let mut out = Vec::with_capacity(a.len().min(b.len()));
    for (&lhs, &rhs) in a.iter().zip(b) {
        out.push(lhs ^ rhs);
    }
    out
}