1use base64::Engine;
6use base64::engine::general_purpose::STANDARD as B64;
7use hmac::{Hmac, KeyInit, Mac};
8use ring::rand::{SecureRandom, SystemRandom};
9use sha2::{Digest, Sha256, Sha512};
10use subtle::ConstantTimeEq;
11
12use super::{ScramCredential, scram_hash_len};
13use crate::{AuthError, AuthMethod, Principal, SaslMechanism};
14
/// Position of a server-side SCRAM exchange in its message sequence.
#[derive(Debug)]
enum State {
    /// Nothing has been processed yet; the next input must be the
    /// `client-first-message`.
    AwaitingClientFirst,
    /// The `server-first-message` has been produced. Both first messages are
    /// retained because they are the leading parts of the `AuthMessage` that
    /// the client proof is computed over.
    AwaitingClientFinal {
        /// The client's first message with its `n,,` header stripped.
        client_first_bare: String,
        /// The exact `server-first-message` text returned to the client.
        server_first: String,
    },
    /// Terminal state, reached on success or on any failure; further input
    /// is rejected.
    Finished,
}
24
25#[derive(Debug)]
26pub struct ScramServerExchange {
27 username: String,
28 credential: ScramCredential,
29 state: State,
30 principal_override: Option<Principal>,
38}
39
40#[derive(Debug)]
41pub enum StepResult {
42 Continue(Vec<u8>),
43 Done(Principal, Vec<u8>),
44 Failed(AuthError),
45}
46
47impl ScramServerExchange {
48 #[must_use]
49 pub fn new(username: String, credential: ScramCredential) -> Self {
50 Self {
51 username,
52 credential,
53 state: State::AwaitingClientFirst,
54 principal_override: None,
55 }
56 }
57
58 #[must_use]
65 pub fn new_with_principal(
66 username: String,
67 credential: ScramCredential,
68 override_principal: Principal,
69 ) -> Self {
70 Self {
71 username,
72 credential,
73 state: State::AwaitingClientFirst,
74 principal_override: Some(override_principal),
75 }
76 }
77
78 pub fn step(&mut self, client_bytes: &[u8]) -> StepResult {
79 match std::mem::replace(&mut self.state, State::Finished) {
80 State::AwaitingClientFirst => self.step_first(client_bytes),
81 State::AwaitingClientFinal {
82 client_first_bare,
83 server_first,
84 } => self.step_final(client_bytes, &client_first_bare, &server_first),
85 State::Finished => StepResult::Failed(AuthError::MalformedMessage),
86 }
87 }
88
89 fn step_first(&mut self, client_bytes: &[u8]) -> StepResult {
90 let Ok(s) = std::str::from_utf8(client_bytes) else {
91 return StepResult::Failed(AuthError::MalformedMessage);
92 };
93 let Some(bare) = s.strip_prefix("n,,") else {
95 return StepResult::Failed(AuthError::MalformedMessage);
96 };
97 let mut user = None;
98 let mut nonce = None;
99 for attr in bare.split(',') {
100 if let Some(v) = attr.strip_prefix("n=") {
101 user = Some(v.to_string());
102 } else if let Some(v) = attr.strip_prefix("r=") {
103 nonce = Some(v.to_string());
104 }
105 }
106 let (Some(u), Some(c_nonce)) = (user, nonce) else {
107 return StepResult::Failed(AuthError::MalformedMessage);
108 };
109 if u != self.username {
110 return StepResult::Failed(AuthError::UnknownUser);
111 }
112 let mut server_nonce_bytes = [0u8; 18];
113 SystemRandom::new()
114 .fill(&mut server_nonce_bytes)
115 .expect("rng");
116 let server_nonce = B64.encode(server_nonce_bytes);
117 let combined_nonce = format!("{c_nonce}{server_nonce}");
118 let server_first = format!(
119 "r={},s={},i={}",
120 combined_nonce,
121 B64.encode(&self.credential.salt),
122 self.credential.iterations,
123 );
124 let response = server_first.clone().into_bytes();
125 self.state = State::AwaitingClientFinal {
126 client_first_bare: bare.to_string(),
127 server_first,
128 };
129 StepResult::Continue(response)
130 }
131
132 fn step_final(
133 &mut self,
134 client_bytes: &[u8],
135 client_first_bare: &str,
136 server_first: &str,
137 ) -> StepResult {
138 let Ok(s) = std::str::from_utf8(client_bytes) else {
139 return StepResult::Failed(AuthError::MalformedMessage);
140 };
141 let mut channel_binding = None;
142 let mut nonce = None;
143 let mut proof_b64 = None;
144 for attr in s.split(',') {
145 if let Some(v) = attr.strip_prefix("c=") {
146 channel_binding = Some(v);
147 } else if let Some(v) = attr.strip_prefix("r=") {
148 nonce = Some(v);
149 } else if let Some(v) = attr.strip_prefix("p=") {
150 proof_b64 = Some(v);
151 }
152 }
153 let (Some(cb), Some(nonce), Some(proof_b64)) = (channel_binding, nonce, proof_b64) else {
154 return StepResult::Failed(AuthError::MalformedMessage);
155 };
156
157 let expected_nonce = server_first
163 .strip_prefix("r=")
164 .and_then(|rest| rest.split(',').next())
165 .unwrap_or_default();
166 if nonce != expected_nonce {
167 return StepResult::Failed(AuthError::MalformedMessage);
168 }
169
170 if cb != B64.encode(b"n,,") {
173 return StepResult::Failed(AuthError::MalformedMessage);
174 }
175
176 let expected_proof_len = scram_hash_len(self.credential.mechanism);
177 let proof = match B64.decode(proof_b64) {
178 Ok(b) if b.len() == expected_proof_len => b,
179 _ => return StepResult::Failed(AuthError::MalformedMessage),
180 };
181
182 let Some(cf_no_proof_end) = s.rfind(",p=") else {
184 return StepResult::Failed(AuthError::MalformedMessage);
185 };
186 let client_final_no_proof = &s[..cf_no_proof_end];
187
188 let auth_message = format!("{client_first_bare},{server_first},{client_final_no_proof}");
189
190 let (computed_stored, server_signature) = match self.credential.mechanism {
191 SaslMechanism::ScramSha512 => verify_and_sign_sha512(
192 &self.credential.stored_key,
193 &self.credential.server_key,
194 &proof,
195 auth_message.as_bytes(),
196 ),
197 SaslMechanism::ScramSha256 => verify_and_sign_sha256(
198 &self.credential.stored_key,
199 &self.credential.server_key,
200 &proof,
201 auth_message.as_bytes(),
202 ),
203 SaslMechanism::Plain | SaslMechanism::OAuthBearer | SaslMechanism::Gssapi => {
204 return StepResult::Failed(AuthError::MalformedMessage);
205 }
206 };
207
208 if computed_stored
209 .ct_eq(self.credential.stored_key.as_slice())
210 .unwrap_u8()
211 != 1
212 {
213 return StepResult::Failed(AuthError::BadProof);
214 }
215 let server_final = format!("v={}", B64.encode(&server_signature));
216 let principal = self
221 .principal_override
222 .clone()
223 .unwrap_or_else(|| Principal {
224 name: self.username.clone(),
225 auth_method: AuthMethod::from_sasl(self.credential.mechanism),
226 groups: vec![],
227 });
228 StepResult::Done(principal, server_final.into_bytes())
229 }
230}
231
232fn verify_and_sign_sha512(
233 stored_key: &[u8],
234 server_key: &[u8],
235 proof: &[u8],
236 auth_message: &[u8],
237) -> (Vec<u8>, Vec<u8>) {
238 let mut mac = <Hmac<Sha512>>::new_from_slice(stored_key).expect("hmac");
239 mac.update(auth_message);
240 let client_signature = mac.finalize().into_bytes();
241 let client_key: Vec<u8> = client_signature
242 .iter()
243 .zip(proof.iter())
244 .map(|(a, b)| a ^ b)
245 .collect();
246 let computed_stored = Sha512::digest(&client_key).to_vec();
247 let mut server_mac = <Hmac<Sha512>>::new_from_slice(server_key).expect("hmac");
248 server_mac.update(auth_message);
249 let server_signature = server_mac.finalize().into_bytes().to_vec();
250 (computed_stored, server_signature)
251}
252
253fn verify_and_sign_sha256(
254 stored_key: &[u8],
255 server_key: &[u8],
256 proof: &[u8],
257 auth_message: &[u8],
258) -> (Vec<u8>, Vec<u8>) {
259 let mut mac = <Hmac<Sha256>>::new_from_slice(stored_key).expect("hmac");
260 mac.update(auth_message);
261 let client_signature = mac.finalize().into_bytes();
262 let client_key: Vec<u8> = client_signature
263 .iter()
264 .zip(proof.iter())
265 .map(|(a, b)| a ^ b)
266 .collect();
267 let computed_stored = Sha256::digest(&client_key).to_vec();
268 let mut server_mac = <Hmac<Sha256>>::new_from_slice(server_key).expect("hmac");
269 server_mac.update(auth_message);
270 let server_signature = server_mac.finalize().into_bytes().to_vec();
271 (computed_stored, server_signature)
272}