1use base64::Engine;
6use base64::engine::general_purpose::STANDARD as B64;
7use hmac::{Hmac, KeyInit, Mac};
8use ring::rand::{SecureRandom, SystemRandom};
9use sha2::{Digest, Sha256, Sha512};
10use subtle::ConstantTimeEq;
11
12use crate::{AuthError, SaslMechanism};
13
/// Position of a [`ScramClientExchange`] in the SCRAM message sequence.
///
/// `Debug` is implemented by hand (below) so that the secret `server_key`
/// never appears in log output.
enum State {
    /// Nothing sent yet; `client_first` is the only call that succeeds.
    Initial,
    /// client-first-message sent; waiting for the server-first-message.
    AwaitingServerFirst {
        /// client-first-message-bare; kept because it is the first part of AuthMessage.
        client_first_bare: String,
        /// Nonce generated by this client; the server's combined nonce must start with it.
        client_nonce: String,
    },
    /// client-final-message sent; waiting for the server signature.
    AwaitingServerFinal {
        /// AuthMessage the server signature is checked against.
        auth_message: String,
        /// ServerKey derived from the salted password. Secret: redacted in `Debug`.
        server_key: Vec<u8>,
    },
    /// Exchange completed or aborted; every further call fails.
    Finished,
}

impl std::fmt::Debug for State {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Initial => f.write_str("Initial"),
            Self::AwaitingServerFirst {
                client_first_bare,
                client_nonce,
            } => f
                .debug_struct("AwaitingServerFirst")
                .field("client_first_bare", client_first_bare)
                .field("client_nonce", client_nonce)
                .finish(),
            Self::AwaitingServerFinal { auth_message, .. } => f
                .debug_struct("AwaitingServerFinal")
                .field("auth_message", auth_message)
                // Never print key material.
                .field("server_key", &"<redacted>")
                .finish(),
            Self::Finished => f.write_str("Finished"),
        }
    }
}
27
28#[derive(Debug)]
29pub struct ScramClientExchange {
30 username: String,
31 password: Vec<u8>,
32 mechanism: SaslMechanism,
33 state: State,
34}
35
36impl ScramClientExchange {
37 #[must_use]
38 pub fn new(username: String, password: Vec<u8>, mechanism: SaslMechanism) -> Self {
39 assert!(
40 mechanism.is_scram(),
41 "ScramClientExchange::new called with non-SCRAM mechanism {mechanism:?}"
42 );
43 Self {
44 username,
45 password,
46 mechanism,
47 state: State::Initial,
48 }
49 }
50
51 pub fn client_first(&mut self) -> Result<Vec<u8>, AuthError> {
52 if !matches!(self.state, State::Initial) {
53 return Err(AuthError::MalformedMessage);
54 }
55 let mut nonce_bytes = [0u8; 18];
56 SystemRandom::new()
57 .fill(&mut nonce_bytes)
58 .map_err(|_| AuthError::MalformedMessage)?;
59 let client_nonce = B64.encode(nonce_bytes);
60 let bare = format!("n={},r={}", self.username, client_nonce);
61 let msg = format!("n,,{bare}");
62 self.state = State::AwaitingServerFirst {
63 client_first_bare: bare,
64 client_nonce,
65 };
66 Ok(msg.into_bytes())
67 }
68
69 pub fn step(&mut self, server_bytes: &[u8]) -> Result<Vec<u8>, AuthError> {
70 let State::AwaitingServerFirst {
71 client_first_bare,
72 client_nonce,
73 } = std::mem::replace(&mut self.state, State::Finished)
74 else {
75 return Err(AuthError::MalformedMessage);
76 };
77 let s = std::str::from_utf8(server_bytes).map_err(|_| AuthError::MalformedMessage)?;
78 let mut nonce = None;
79 let mut salt = None;
80 let mut iterations = None;
81 for attr in s.split(',') {
82 if let Some(v) = attr.strip_prefix("r=") {
83 nonce = Some(v.to_string());
84 } else if let Some(v) = attr.strip_prefix("s=") {
85 salt = Some(B64.decode(v).map_err(|_| AuthError::MalformedMessage)?);
86 } else if let Some(v) = attr.strip_prefix("i=") {
87 iterations = Some(v.parse::<u32>().map_err(|_| AuthError::MalformedMessage)?);
88 }
89 }
90 let (Some(combined_nonce), Some(salt), Some(iters)) = (nonce, salt, iterations) else {
91 return Err(AuthError::MalformedMessage);
92 };
93 if !combined_nonce.starts_with(&client_nonce) {
94 return Err(AuthError::BadProof);
95 }
96
97 let channel_binding = B64.encode(b"n,,");
98 let client_final_no_proof = format!("c={channel_binding},r={combined_nonce}");
99 let auth_message = format!("{client_first_bare},{s},{client_final_no_proof}");
100
101 let (proof, server_key) = match self.mechanism {
102 SaslMechanism::ScramSha512 => {
103 compute_proof_sha512(&self.password, &salt, iters, auth_message.as_bytes())?
104 }
105 SaslMechanism::ScramSha256 => {
106 compute_proof_sha256(&self.password, &salt, iters, auth_message.as_bytes())?
107 }
108 SaslMechanism::Plain | SaslMechanism::OAuthBearer | SaslMechanism::Gssapi => {
109 return Err(AuthError::MalformedMessage);
110 }
111 };
112
113 let client_final = format!("{client_final_no_proof},p={}", B64.encode(&proof));
114 self.state = State::AwaitingServerFinal {
115 auth_message,
116 server_key,
117 };
118 Ok(client_final.into_bytes())
119 }
120
121 pub fn verify_server_final(&mut self, server_bytes: &[u8]) -> Result<(), AuthError> {
122 let State::AwaitingServerFinal {
123 auth_message,
124 server_key,
125 } = std::mem::replace(&mut self.state, State::Finished)
126 else {
127 return Err(AuthError::MalformedMessage);
128 };
129 let s = std::str::from_utf8(server_bytes).map_err(|_| AuthError::MalformedMessage)?;
130 let v_b64 = s.strip_prefix("v=").ok_or(AuthError::MalformedMessage)?;
131 let v = B64.decode(v_b64).map_err(|_| AuthError::MalformedMessage)?;
132 let expected: Vec<u8> = match self.mechanism {
133 SaslMechanism::ScramSha512 => {
134 let mut mac = <Hmac<Sha512>>::new_from_slice(&server_key)
135 .map_err(|_| AuthError::MalformedMessage)?;
136 mac.update(auth_message.as_bytes());
137 mac.finalize().into_bytes().to_vec()
138 }
139 SaslMechanism::ScramSha256 => {
140 let mut mac = <Hmac<Sha256>>::new_from_slice(&server_key)
141 .map_err(|_| AuthError::MalformedMessage)?;
142 mac.update(auth_message.as_bytes());
143 mac.finalize().into_bytes().to_vec()
144 }
145 SaslMechanism::Plain | SaslMechanism::OAuthBearer | SaslMechanism::Gssapi => {
146 return Err(AuthError::MalformedMessage);
147 }
148 };
149 if expected.ct_eq(&v).unwrap_u8() != 1 {
150 return Err(AuthError::BadProof);
151 }
152 Ok(())
153 }
154}
155
156fn compute_proof_sha512(
157 password: &[u8],
158 salt: &[u8],
159 iters: u32,
160 auth_message: &[u8],
161) -> Result<(Vec<u8>, Vec<u8>), AuthError> {
162 let salted: [u8; 64] = pbkdf2::pbkdf2_hmac_array::<Sha512, 64>(password, salt, iters);
163 let mut client_key_mac =
164 <Hmac<Sha512>>::new_from_slice(&salted).map_err(|_| AuthError::MalformedMessage)?;
165 client_key_mac.update(b"Client Key");
166 let client_key = client_key_mac.finalize().into_bytes();
167 let stored_key = Sha512::digest(client_key);
168 let mut server_key_mac =
169 <Hmac<Sha512>>::new_from_slice(&salted).map_err(|_| AuthError::MalformedMessage)?;
170 server_key_mac.update(b"Server Key");
171 let server_key = server_key_mac.finalize().into_bytes().to_vec();
172
173 let mut client_sig_mac =
174 <Hmac<Sha512>>::new_from_slice(&stored_key).map_err(|_| AuthError::MalformedMessage)?;
175 client_sig_mac.update(auth_message);
176 let client_signature = client_sig_mac.finalize().into_bytes();
177 let proof: Vec<u8> = client_key
178 .iter()
179 .zip(client_signature.iter())
180 .map(|(a, b)| a ^ b)
181 .collect();
182 Ok((proof, server_key))
183}
184
185fn compute_proof_sha256(
186 password: &[u8],
187 salt: &[u8],
188 iters: u32,
189 auth_message: &[u8],
190) -> Result<(Vec<u8>, Vec<u8>), AuthError> {
191 let salted: [u8; 32] = pbkdf2::pbkdf2_hmac_array::<Sha256, 32>(password, salt, iters);
192 let mut client_key_mac =
193 <Hmac<Sha256>>::new_from_slice(&salted).map_err(|_| AuthError::MalformedMessage)?;
194 client_key_mac.update(b"Client Key");
195 let client_key = client_key_mac.finalize().into_bytes();
196 let stored_key = Sha256::digest(client_key);
197 let mut server_key_mac =
198 <Hmac<Sha256>>::new_from_slice(&salted).map_err(|_| AuthError::MalformedMessage)?;
199 server_key_mac.update(b"Server Key");
200 let server_key = server_key_mac.finalize().into_bytes().to_vec();
201
202 let mut client_sig_mac =
203 <Hmac<Sha256>>::new_from_slice(&stored_key).map_err(|_| AuthError::MalformedMessage)?;
204 client_sig_mac.update(auth_message);
205 let client_signature = client_sig_mac.finalize().into_bytes();
206 let proof: Vec<u8> = client_key
207 .iter()
208 .zip(client_signature.iter())
209 .map(|(a, b)| a ^ b)
210 .collect();
211 Ok((proof, server_key))
212}