use ring::digest;
use std::mem;
use msgs::codec::Codec;
use msgs::message::{Message, MessagePayload};
/// Running hash over the encoded handshake messages, optionally paired with
/// a raw copy of those same bytes.
pub struct HandshakeHash {
    /// Digest state; `None` until a hash algorithm is supplied via `start_hash`.
    ctx: Option<digest::Context>,

    /// When set, raw message bytes keep being appended to `buffer` even after
    /// the digest has started.
    client_auth_enabled: bool,

    /// Raw message bytes: either waiting to seed `ctx`, or retained because
    /// `client_auth_enabled` is set.
    buffer: Vec<u8>,
}
impl HandshakeHash {
pub fn new() -> HandshakeHash {
HandshakeHash {
ctx: None,
client_auth_enabled: false,
buffer: Vec::new()
}
}
pub fn set_client_auth_enabled(&mut self) {
debug_assert!(self.ctx.is_none()); self.client_auth_enabled = true;
}
pub fn abandon_client_auth(&mut self) {
self.client_auth_enabled = false;
self.buffer.drain(..);
}
pub fn start_hash(&mut self, alg: &'static digest::Algorithm) {
debug_assert!(self.ctx.is_none());
let mut ctx = digest::Context::new(alg);
ctx.update(&self.buffer);
self.ctx = Some(ctx);
if !self.client_auth_enabled {
self.buffer.drain(..);
}
}
pub fn add_message(&mut self, m: &Message) -> &mut HandshakeHash {
match m.payload {
MessagePayload::Handshake(ref hs) => {
let buf = hs.get_encoding();
self.update_raw(&buf);
},
_ => unreachable!()
};
self
}
fn update_raw(&mut self, buf: &[u8]) -> &mut Self {
if self.ctx.is_some() {
self.ctx.as_mut().unwrap().update(buf);
}
if self.ctx.is_none() || self.client_auth_enabled {
self.buffer.extend_from_slice(buf);
}
self
}
pub fn get_current_hash(&self) -> Vec<u8> {
let h = self.ctx.as_ref().unwrap().clone().finish();
let mut ret = Vec::new();
ret.extend_from_slice(h.as_ref());
ret
}
pub fn take_handshake_buf(&mut self) -> Vec<u8> {
debug_assert!(self.client_auth_enabled);
mem::replace(&mut self.buffer, Vec::new())
}
}
#[cfg(test)]
mod test {
    use super::HandshakeHash;
    use ring::digest;

    /// Asserts that `got` is exactly SHA-256("helloworld").
    ///
    /// It is checked against an independently computed digest (full length
    /// and content) and against the known leading bytes of that value.
    fn assert_is_sha256_helloworld(got: &[u8]) {
        let want_digest = digest::digest(&digest::SHA256, b"helloworld");
        let want: &[u8] = want_digest.as_ref();
        assert_eq!(got, want);
        assert_eq!(&got[..4], &[0x93u8, 0x6a, 0x18, 0x5c][..]);
    }

    #[test]
    fn hashes_correctly() {
        let mut hh = HandshakeHash::new();
        hh.update_raw(b"hello");
        assert_eq!(hh.buffer.len(), 5);
        hh.start_hash(&digest::SHA256);
        assert_eq!(hh.buffer.len(), 0);
        hh.update_raw(b"world");
        let h = hh.get_current_hash();
        assert_is_sha256_helloworld(&h);
    }

    #[test]
    fn buffers_correctly() {
        let mut hh = HandshakeHash::new();
        hh.set_client_auth_enabled();
        hh.update_raw(b"hello");
        assert_eq!(hh.buffer.len(), 5);
        hh.start_hash(&digest::SHA256);
        assert_eq!(hh.buffer.len(), 5);
        hh.update_raw(b"world");
        assert_eq!(hh.buffer.len(), 10);
        let h = hh.get_current_hash();
        assert_is_sha256_helloworld(&h);
        let buf = hh.take_handshake_buf();
        assert_eq!(b"helloworld".to_vec(), buf);
    }

    #[test]
    fn abandon() {
        let mut hh = HandshakeHash::new();
        hh.set_client_auth_enabled();
        hh.update_raw(b"hello");
        assert_eq!(hh.buffer.len(), 5);
        hh.start_hash(&digest::SHA256);
        assert_eq!(hh.buffer.len(), 5);
        hh.abandon_client_auth();
        assert_eq!(hh.buffer.len(), 0);
        hh.update_raw(b"world");
        assert_eq!(hh.buffer.len(), 0);
        let h = hh.get_current_hash();
        assert_is_sha256_helloworld(&h);
    }
}