sqlmodel_postgres/auth/
scram.rs1use base64::{Engine as _, engine::general_purpose::STANDARD as BASE64};
4use hmac::{Hmac, Mac};
5use rand::{Rng, distributions::Alphanumeric, rngs::OsRng};
6use sha2::{Digest, Sha256};
7use sqlmodel_core::Error;
8use sqlmodel_core::error::{ConnectionError, ConnectionErrorKind, ProtocolError};
9use subtle::ConstantTimeEq;
10
11type HmacSha256 = Hmac<Sha256>;
12
/// Client-side state for one SCRAM-SHA-256 authentication exchange.
///
/// The first three fields are fixed at construction time. The remaining
/// fields start out as `None` and are filled in by
/// `process_server_first` once the server's first message has been parsed.
pub struct ScramClient {
    /// User name placed in the `n=` attribute of the client-first message.
    username: String,
    /// Password fed into PBKDF2 to derive the salted password.
    password: String,
    /// Random nonce generated by this client in `new`.
    client_nonce: String,

    /// Full `r=` value received from the server (it must begin with
    /// `client_nonce`).
    server_nonce: Option<String>,
    /// Salt decoded from the server's base64 `s=` attribute.
    salt: Option<Vec<u8>>,
    /// PBKDF2 iteration count taken from the server's `i=` attribute.
    iterations: Option<u32>,

    /// PBKDF2-HMAC-SHA-256 output, kept so the server signature can be
    /// verified later.
    salted_password: Option<[u8; 32]>,
    /// The SCRAM auth message (client-first-bare, server-first and
    /// client-final-without-proof joined by commas).
    auth_message: Option<String>,
}
27
28impl ScramClient {
29 pub fn new(username: &str, password: &str) -> Self {
30 let client_nonce: String = OsRng
33 .sample_iter(&Alphanumeric)
34 .take(32)
35 .map(char::from)
36 .collect();
37
38 Self {
39 username: username.to_string(),
40 password: password.to_string(),
41 client_nonce,
42 server_nonce: None,
43 salt: None,
44 iterations: None,
45 salted_password: None,
46 auth_message: None,
47 }
48 }
49
50 pub fn client_first(&self) -> Vec<u8> {
52 format!("n,,n={},r={}", self.username, self.client_nonce).into_bytes()
57 }
58
59 #[allow(clippy::result_large_err)]
61 pub fn process_server_first(&mut self, data: &[u8]) -> Result<Vec<u8>, Error> {
62 let msg = std::str::from_utf8(data)
63 .map_err(|e| protocol_error(format!("Invalid UTF-8 in SASL continue: {}", e)))?;
64
65 let mut combined_nonce = None;
67 let mut salt = None;
68 let mut iterations = None;
69
70 for part in msg.split(',') {
71 if let Some(value) = part.strip_prefix("r=") {
72 combined_nonce = Some(value.to_string());
73 } else if let Some(value) = part.strip_prefix("s=") {
74 salt = Some(
75 BASE64
76 .decode(value)
77 .map_err(|e| protocol_error(format!("Invalid base64 salt: {}", e)))?,
78 );
79 } else if let Some(value) = part.strip_prefix("i=") {
80 iterations = Some(
81 value
82 .parse()
83 .map_err(|e| protocol_error(format!("Invalid iterations: {}", e)))?,
84 );
85 }
86 }
87
88 let combined_nonce = combined_nonce.ok_or_else(|| protocol_error("Missing nonce"))?;
89 let salt = salt.ok_or_else(|| protocol_error("Missing salt"))?;
90 let iterations = iterations.ok_or_else(|| protocol_error("Missing iterations"))?;
91
92 if !combined_nonce.starts_with(&self.client_nonce) {
94 return Err(protocol_error("Invalid server nonce"));
95 }
96
97 let mut salted_password = [0u8; 32];
99 pbkdf2::pbkdf2::<HmacSha256>(
100 self.password.as_bytes(),
101 &salt,
102 iterations,
103 &mut salted_password,
104 )
105 .map_err(|e| protocol_error(format!("PBKDF2 failed: {}", e)))?;
106
107 let client_first_bare = format!("n={},r={}", self.username, self.client_nonce);
109 let client_final_without_proof = format!("c=biws,r={}", combined_nonce); let auth_message = format!(
111 "{},{},{}",
112 client_first_bare, msg, client_final_without_proof
113 );
114
115 let client_key = hmac_sha256(&salted_password, b"Client Key")?;
117 let stored_key = sha256(&client_key);
118 let client_signature = hmac_sha256(&stored_key, auth_message.as_bytes())?;
119
120 let client_proof: Vec<u8> = client_key
121 .iter()
122 .zip(client_signature.iter())
123 .map(|(a, b)| a ^ b)
124 .collect();
125
126 self.server_nonce = Some(combined_nonce.clone());
128 self.salted_password = Some(salted_password);
129 self.auth_message = Some(auth_message);
130 self.salt = Some(salt);
131 self.iterations = Some(iterations);
132
133 let client_final = format!(
135 "c=biws,r={},p={}",
136 combined_nonce,
137 BASE64.encode(&client_proof)
138 );
139
140 Ok(client_final.into_bytes())
141 }
142
143 #[allow(clippy::result_large_err)]
145 pub fn verify_server_final(&self, data: &[u8]) -> Result<(), Error> {
146 let msg = std::str::from_utf8(data)
147 .map_err(|e| protocol_error(format!("Invalid UTF-8 in SASL final: {}", e)))?;
148
149 let server_signature_b64 = msg
150 .strip_prefix("v=")
151 .ok_or_else(|| protocol_error("Invalid server-final format"))?;
152
153 let server_signature = BASE64
154 .decode(server_signature_b64)
155 .map_err(|e| protocol_error(format!("Invalid base64 server signature: {}", e)))?;
156
157 let salted_password = self
159 .salted_password
160 .as_ref()
161 .ok_or_else(|| protocol_error("Missing salted password state"))?;
162 let auth_message = self
163 .auth_message
164 .as_ref()
165 .ok_or_else(|| protocol_error("Missing auth message state"))?;
166
167 let server_key = hmac_sha256(salted_password, b"Server Key")?;
168 let expected_signature = hmac_sha256(&server_key, auth_message.as_bytes())?;
169
170 if server_signature.ct_eq(&expected_signature).into() {
174 Ok(())
175 } else {
176 Err(auth_error("Server signature mismatch"))
177 }
178 }
179}
180
181fn protocol_error(msg: impl Into<String>) -> Error {
184 Error::Protocol(ProtocolError {
185 message: msg.into(),
186 raw_data: None,
187 source: None,
188 })
189}
190
191fn auth_error(msg: impl Into<String>) -> Error {
192 Error::Connection(ConnectionError {
193 kind: ConnectionErrorKind::Authentication,
194 message: msg.into(),
195 source: None,
196 })
197}
198
199#[allow(clippy::result_large_err)]
200fn hmac_sha256(key: &[u8], data: &[u8]) -> Result<[u8; 32], Error> {
201 let mut mac = HmacSha256::new_from_slice(key)
202 .map_err(|e| protocol_error(format!("HMAC init failed: {}", e)))?;
203 mac.update(data);
204 let result = mac.finalize();
205 let bytes = result.into_bytes();
206 Ok(bytes.into())
207}
208
209fn sha256(data: &[u8]) -> [u8; 32] {
210 let mut hasher = Sha256::new();
211 hasher.update(data);
212 hasher.finalize().into()
213}