1use std::borrow::Cow;
2
3use base64;
4use rand::distributions::{Distribution, Uniform};
5use rand::{rngs::OsRng, Rng};
6use ring::digest::SHA256_OUTPUT_LEN;
7use ring::hmac;
8
9use error::{Error, Field, Kind};
10use utils::find_proofs;
11use NONCE_LENGTH;
12
13pub struct ScramServer<P: AuthenticationProvider> {
16 provider: P,
18}
19
/// Stored credentials for one user, as returned by
/// `AuthenticationProvider::get_password_for`.
pub struct PasswordInfo {
    /// Stored password hash; fed to `find_proofs` to derive the expected client proof and
    /// server signature.
    hashed_password: Vec<u8>,
    /// Iteration count, advertised to the client in the `i=` attribute of server-first.
    iterations: u16,
    /// Raw salt bytes, advertised base64-encoded in the `s=` attribute of server-first.
    salt: Vec<u8>,
}
27
/// Outcome of a completed SCRAM exchange, as reported by `ServerFinal::server_final`.
///
/// `Eq` and `Hash` are derived alongside `PartialEq` so the status can be used as a key in
/// sets and maps (e.g. for counting outcomes).
#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
pub enum AuthenticationStatus {
    /// The client proof matched and, if the client requested an authzid, the provider
    /// authorized the user to act as it.
    Authenticated,
    /// The client proof did not match the stored credentials.
    NotAuthenticated,
    /// The client proof matched, but the provider refused the requested authzid.
    NotAuthorized,
}
39
40impl PasswordInfo {
41 pub fn new(hashed_password: Vec<u8>, iterations: u16, salt: Vec<u8>) -> Self {
44 PasswordInfo {
45 hashed_password,
46 iterations,
47 salt,
48 }
49 }
50}
51
52pub trait AuthenticationProvider {
59 fn get_password_for(&self, username: &str) -> Option<PasswordInfo>;
61
62 fn authorize(&self, authcid: &str, authzid: &str) -> bool {
66 authcid == authzid
67 }
68}
69
70fn parse_client_first(data: &str) -> Result<(&str, Option<&str>, &str), Error> {
73 let mut parts = data.split(',');
74
75 if let Some(part) = parts.next() {
77 if let Some(cb) = part.chars().next() {
78 if cb == 'p' {
79 return Err(Error::UnsupportedExtension);
80 }
81 if cb != 'n' && cb != 'y' || part.len() > 1 {
82 return Err(Error::Protocol(Kind::InvalidField(Field::ChannelBinding)));
83 }
84 } else {
85 return Err(Error::Protocol(Kind::ExpectedField(Field::ChannelBinding)));
86 }
87 } else {
88 return Err(Error::Protocol(Kind::ExpectedField(Field::ChannelBinding)));
89 }
90
91 let authzid = if let Some(part) = parts.next() {
93 if part.is_empty() {
94 None
95 } else if part.len() < 2 || &part.as_bytes()[..2] != b"a=" {
96 return Err(Error::Protocol(Kind::ExpectedField(Field::Authzid)));
97 } else {
98 Some(&part[2..])
99 }
100 } else {
101 return Err(Error::Protocol(Kind::ExpectedField(Field::Authzid)));
102 };
103
104 let authcid = parse_part!(parts, Authcid, b"n=");
106
107 let nonce = match parts.next() {
109 Some(part) if &part.as_bytes()[..2] == b"r=" => &part[2..],
110 _ => {
111 return Err(Error::Protocol(Kind::ExpectedField(Field::Nonce)));
112 }
113 };
114 Ok((authcid, authzid, nonce))
115}
116
117fn parse_client_final(data: &str) -> Result<(&str, &str, &str), Error> {
119 let mut parts = data.split(',');
121 let gs2header = parse_part!(parts, GS2Header, b"c=");
122 let nonce = parse_part!(parts, Nonce, b"r=");
123 let proof = parse_part!(parts, Proof, b"p=");
124 Ok((gs2header, nonce, proof))
125}
126
127impl<P: AuthenticationProvider> ScramServer<P> {
128 pub fn new(provider: P) -> Self {
130 ScramServer { provider }
131 }
132
133 pub fn handle_client_first<'a>(
137 &'a self,
138 client_first: &'a str,
139 ) -> Result<ServerFirst<'a, P>, Error> {
140 let (authcid, authzid, client_nonce) = parse_client_first(client_first)?;
141 let password_info = self
142 .provider
143 .get_password_for(authcid)
144 .ok_or_else(|| Error::InvalidUser(authcid.to_string()))?;
145 Ok(ServerFirst {
146 client_nonce,
147 authcid,
148 authzid,
149 provider: &self.provider,
150 password_info,
151 })
152 }
153}
154
155pub struct ServerFirst<'a, P: 'a + AuthenticationProvider> {
158 client_nonce: &'a str,
159 authcid: &'a str,
160 authzid: Option<&'a str>,
161 provider: &'a P,
162 password_info: PasswordInfo,
163}
164
165impl<'a, P: AuthenticationProvider> ServerFirst<'a, P> {
166 pub fn server_first(self) -> (ClientFinal<'a, P>, String) {
172 self.server_first_with_rng(&mut OsRng)
173 }
174
175 pub fn server_first_with_rng<R: Rng>(self, rng: &mut R) -> (ClientFinal<'a, P>, String) {
180 let mut nonce = String::with_capacity(self.client_nonce.len() + NONCE_LENGTH);
181 nonce.push_str(self.client_nonce);
182 nonce.extend(
183 Uniform::from(33..125)
184 .sample_iter(rng)
185 .map(|x: u8| if x > 43 { (x + 1) as char } else { x as char })
186 .take(NONCE_LENGTH),
187 );
188
189 let gs2header: Cow<'static, str> = match self.authzid {
190 Some(authzid) => format!("n,a={},", authzid).into(),
191 None => "n,,".into(),
192 };
193 let client_first_bare: Cow<'static, str> =
194 format!("n={},r={}", self.authcid, self.client_nonce).into();
195 let server_first: Cow<'static, str> = format!(
196 "r={},s={},i={}",
197 nonce,
198 base64::encode(self.password_info.salt.as_slice()),
199 self.password_info.iterations
200 )
201 .into();
202 (
203 ClientFinal {
204 hashed_password: self.password_info.hashed_password,
205 nonce,
206 gs2header,
207 client_first_bare,
208 server_first: server_first.clone(),
209 authcid: self.authcid,
210 authzid: self.authzid,
211 provider: self.provider,
212 },
213 server_first.into_owned(),
214 )
215 }
216}
217
218pub struct ClientFinal<'a, P: 'a + AuthenticationProvider> {
221 hashed_password: Vec<u8>,
222 nonce: String,
223 gs2header: Cow<'static, str>,
224 client_first_bare: Cow<'static, str>,
225 server_first: Cow<'static, str>,
226 authcid: &'a str,
227 authzid: Option<&'a str>,
228 provider: &'a P,
229}
230
231impl<'a, P: AuthenticationProvider> ClientFinal<'a, P> {
232 pub fn handle_client_final(self, client_final: &str) -> Result<ServerFinal, Error> {
239 let (gs2header_enc, nonce, proof) = parse_client_final(client_final)?;
240 if !self.verify_header(gs2header_enc) {
241 return Err(Error::Protocol(Kind::InvalidField(Field::GS2Header)));
242 }
243 if !self.verify_nonce(nonce) {
244 return Err(Error::Protocol(Kind::InvalidField(Field::Nonce)));
245 }
246 if let Some(signature) = self.verify_proof(proof)? {
247 if let Some(authzid) = self.authzid {
248 if self.provider.authorize(self.authcid, authzid) {
249 Ok(ServerFinal {
250 status: AuthenticationStatus::Authenticated,
251 signature,
252 })
253 } else {
254 Ok(ServerFinal {
255 status: AuthenticationStatus::NotAuthorized,
256 signature: format!(
257 "e=User '{}' not authorized to act as '{}'",
258 self.authcid, authzid
259 ),
260 })
261 }
262 } else {
263 Ok(ServerFinal {
264 status: AuthenticationStatus::Authenticated,
265 signature,
266 })
267 }
268 } else {
269 Ok(ServerFinal {
270 status: AuthenticationStatus::NotAuthenticated,
271 signature: "e=Invalid Password".to_string(),
272 })
273 }
274 }
275
276 fn verify_header(&self, gs2header: &str) -> bool {
278 let server_gs2header = base64::encode(self.gs2header.as_bytes());
279 server_gs2header == gs2header
280 }
281
282 fn verify_nonce(&self, nonce: &str) -> bool {
284 nonce == self.nonce
285 }
286
287 fn verify_proof(&self, proof: &str) -> Result<Option<String>, Error> {
289 let (client_proof, server_signature): ([u8; SHA256_OUTPUT_LEN], hmac::Tag) = find_proofs(
290 &self.gs2header,
291 &self.client_first_bare,
292 &self.server_first,
293 self.hashed_password.as_slice(),
294 &self.nonce,
295 );
296 let proof = if let Ok(proof) = base64::decode(proof.as_bytes()) {
297 proof
298 } else {
299 return Err(Error::Protocol(Kind::InvalidField(Field::Proof)));
300 };
301 if proof != client_proof {
302 return Ok(None);
303 }
304
305 let server_signature_string = format!("v={}", base64::encode(server_signature.as_ref()));
306 Ok(Some(server_signature_string))
307 }
308}
309
310pub struct ServerFinal {
313 status: AuthenticationStatus,
314 signature: String,
315}
316
317impl ServerFinal {
318 pub fn server_final(self) -> (AuthenticationStatus, String) {
321 (self.status, self.signature)
322 }
323}
324
#[cfg(test)]
mod tests {
    use super::super::{Error, Field, Kind};
    use super::{parse_client_final, parse_client_first};

    /// Asserts that `parse_client_first` rejects every input with exactly the paired error.
    fn expect_client_first_errors(cases: &[(&str, Error)]) {
        for &(input, ref expected) in cases.iter() {
            let actual = parse_client_first(input).unwrap_err();
            assert_eq!(&actual, expected, "client-first input {:?}", input);
        }
    }

    /// Asserts that `parse_client_final` rejects every input with exactly the paired error.
    fn expect_client_final_errors(cases: &[(&str, Error)]) {
        for &(input, ref expected) in cases.iter() {
            let actual = parse_client_final(input).unwrap_err();
            assert_eq!(&actual, expected, "client-final input {:?}", input);
        }
    }

    #[test]
    fn test_parse_client_first_success() {
        let cases = [
            ("n,,n=user,r=abcdefghijk", ("user", None, "abcdefghijk")),
            (
                "y,a=other user,n=user,r=abcdef=hijk",
                ("user", Some("other user"), "abcdef=hijk"),
            ),
            ("n,,n=,r=", ("", None, "")),
        ];
        for &(input, expected) in cases.iter() {
            let parsed = parse_client_first(input).unwrap();
            assert_eq!(parsed, expected, "client-first input {:?}", input);
        }
    }

    #[test]
    fn test_parse_client_first_missing_fields() {
        expect_client_first_errors(&[
            ("n,,n=user", Error::Protocol(Kind::ExpectedField(Field::Nonce))),
            ("n,,r=user", Error::Protocol(Kind::ExpectedField(Field::Authcid))),
            ("n,n=user,r=abc", Error::Protocol(Kind::ExpectedField(Field::Authzid))),
            (",,n=user,r=abc", Error::Protocol(Kind::ExpectedField(Field::ChannelBinding))),
            ("", Error::Protocol(Kind::ExpectedField(Field::ChannelBinding))),
            (",,,", Error::Protocol(Kind::ExpectedField(Field::ChannelBinding))),
        ]);
    }

    #[test]
    fn test_parse_client_first_invalid_data() {
        expect_client_first_errors(&[
            ("a,,n=user,r=abc", Error::Protocol(Kind::InvalidField(Field::ChannelBinding))),
            ("p,,n=user,r=abc", Error::UnsupportedExtension),
            ("nn,,n=user,r=abc", Error::Protocol(Kind::InvalidField(Field::ChannelBinding))),
            ("n,,n,r=abc", Error::Protocol(Kind::ExpectedField(Field::Authcid))),
        ]);
    }

    #[test]
    fn test_parse_client_final_success() {
        let cases = [
            ("c=abc,r=abcefg,p=783232", ("abc", "abcefg", "783232")),
            ("c=,r=,p=", ("", "", "")),
        ];
        for &(input, expected) in cases.iter() {
            let parsed = parse_client_final(input).unwrap();
            assert_eq!(parsed, expected, "client-final input {:?}", input);
        }
    }

    #[test]
    fn test_parse_client_final_missing_fields() {
        expect_client_final_errors(&[
            ("c=whatever,r=something", Error::Protocol(Kind::ExpectedField(Field::Proof))),
            ("c=whatever,p=words", Error::Protocol(Kind::ExpectedField(Field::Nonce))),
            ("c=whatever", Error::Protocol(Kind::ExpectedField(Field::Nonce))),
            ("c=", Error::Protocol(Kind::ExpectedField(Field::Nonce))),
            ("", Error::Protocol(Kind::ExpectedField(Field::GS2Header))),
            ("r=anonce", Error::Protocol(Kind::ExpectedField(Field::GS2Header))),
        ]);
    }
}