use ring;
use std::io::{Read, Write};
use msgs::message::{Message, MessagePayload};
use msgs::deframer::MessageDeframer;
use msgs::fragmenter::{MessageFragmenter, MAX_FRAGMENT_LEN};
use msgs::hsjoiner::HandshakeJoiner;
use msgs::base::Payload;
use msgs::codec::Codec;
use msgs::enums::{ContentType, ProtocolVersion, AlertDescription, AlertLevel};
use error::TLSError;
use suites::SupportedCipherSuite;
use cipher::MessageCipher;
use vecbuf::ChunkVecBuffer;
use prf;
use rand;
use std::io;
use std::collections::VecDeque;
/// Public interface of a TLS session.
///
/// As a `Read + Write` type, plaintext application data is read and written
/// through the standard I/O traits, while the methods below move raw TLS
/// bytes in and out. NOTE(review): the per-method descriptions below are
/// inferred from method names and from the matching `SessionCommon` helpers
/// in this file; confirm them against the concrete implementors.
pub trait Session : Read + Write {
/// Reads TLS-encoded bytes from `rd` into the session and returns the byte
/// count read (presumably buffered until `process_new_packets` runs).
fn read_tls(&mut self, rd: &mut Read) -> Result<usize, io::Error>;
/// Writes pending TLS-encoded bytes to `wr` and returns the byte count
/// written.
fn write_tls(&mut self, wr: &mut Write) -> Result<usize, io::Error>;
/// Processes any TLS data received so far; returns a `TLSError` on failure.
fn process_new_packets(&mut self) -> Result<(), TLSError>;
/// Presumably `true` when the session would like more input via `read_tls`.
fn wants_read(&self) -> bool;
/// Presumably `true` when the session has output pending for `write_tls`.
fn wants_write(&self) -> bool;
/// Presumably `true` while the TLS handshake has not yet completed.
fn is_handshaking(&self) -> bool;
/// Queues a close_notify alert to the peer.
fn send_close_notify(&mut self);
/// Returns the peer's certificates as raw byte vectors, if available
/// (encoding presumably DER — TODO confirm).
fn get_peer_certificates(&self) -> Option<Vec<Vec<u8>>>;
/// Returns the negotiated ALPN protocol, if any.
fn get_alpn_protocol(&self) -> Option<String>;
}
/// The client and server random values for one session.
///
/// Built by `SessionRandoms::for_client` / `SessionRandoms::for_server`,
/// which fill only our own side's random; the peer's value starts zeroed.
#[derive(Clone, Debug)]
pub struct SessionRandoms {
/// `true` when this side is the client (set by `for_client`).
pub we_are_client: bool,
/// 32-byte client random.
pub client: [u8; 32],
/// 32-byte server random.
pub server: [u8; 32]
}
impl SessionRandoms {
    /// Creates the randoms for the server side of a session.
    ///
    /// The `server` value is filled with fresh random bytes. The `client`
    /// value is left zeroed, to be filled in later, and `we_are_client`
    /// is `false`.
    pub fn for_server() -> SessionRandoms {
        let mut ours = [0u8; 32];
        rand::fill_random(&mut ours);

        SessionRandoms {
            we_are_client: false,
            client: [0u8; 32],
            server: ours,
        }
    }

    /// Creates the randoms for the client side of a session.
    ///
    /// The `client` value is filled with fresh random bytes. The `server`
    /// value is left zeroed, to be filled in later, and `we_are_client`
    /// is `true`.
    pub fn for_client() -> SessionRandoms {
        let mut ours = [0u8; 32];
        rand::fill_random(&mut ours);

        SessionRandoms {
            we_are_client: true,
            client: ours,
            server: [0u8; 32],
        }
    }
}
/// Concatenates two 32-byte randoms into one 64-byte buffer (`first || second`).
///
/// `first` is copied starting at offset 0 and `second` starting at offset 32.
/// Each copy is clamped to the space available from its offset. An input
/// shorter than 32 bytes leaves the rest of its half zeroed. A longer `first`
/// spills past offset 32, but that spill is overwritten by `second`.
fn join_randoms(first: &[u8], second: &[u8]) -> [u8; 64] {
    const SECOND_AT: usize = 32;
    let mut joined = [0u8; 64];

    let head = ::std::cmp::min(first.len(), joined.len());
    joined[..head].copy_from_slice(&first[..head]);

    let tail = ::std::cmp::min(second.len(), joined.len() - SECOND_AT);
    joined[SECOND_AT..SECOND_AT + tail].copy_from_slice(&second[..tail]);

    joined
}
/// The master secret of a session together with the inputs used to derive keys from it.
pub struct SessionSecrets {
/// The client/server randoms; used as PRF seed material for the master
/// secret and the key block.
pub randoms: SessionRandoms,
/// Digest algorithm passed to every `prf::prf` call made by this type.
hash: &'static ring::digest::Algorithm,
/// 48-byte master secret, either derived by `new` or copied in by `new_resume`.
master_secret: [u8; 48]
}
impl SessionSecrets {
    /// Derives a fresh master secret from the premaster secret `pms`.
    ///
    /// Computes `master_secret = PRF(pms, "master secret",
    /// client_random || server_random)` with `hashalg`. The result is
    /// always 48 bytes. `randoms` is cloned into the returned value.
    pub fn new(randoms: &SessionRandoms,
               hashalg: &'static ring::digest::Algorithm,
               pms: &[u8]) -> SessionSecrets {
        let mut ret = SessionSecrets {
            randoms: randoms.clone(),
            hash: hashalg,
            master_secret: [0u8; 48]
        };

        // Master-secret derivation seeds with client random first.
        let seed = join_randoms(&ret.randoms.client, &ret.randoms.server);
        prf::prf(&mut ret.master_secret,
                 ret.hash,
                 pms,
                 b"master secret",
                 &seed);
        ret
    }

    /// Rebuilds session secrets for resumption from a saved master secret.
    ///
    /// The bytes of `master_secret` are copied into the fixed 48-byte buffer.
    /// NOTE: lengths are not checked. A shorter input leaves the remaining
    /// bytes zero, and a longer input is truncated to 48 bytes. Neither case
    /// is reported to the caller.
    pub fn new_resume(randoms: &SessionRandoms,
                      hashalg: &'static ring::digest::Algorithm,
                      master_secret: &[u8]) -> SessionSecrets {
        let mut ret = SessionSecrets {
            randoms: randoms.clone(),
            hash: hashalg,
            master_secret: [0u8; 48]
        };

        // `Write` for `&mut [u8]` copies as many bytes as fit and cannot fail.
        ret.master_secret.as_mut().write(master_secret).unwrap();
        ret
    }

    /// Expands the master secret into `len` bytes of key material.
    ///
    /// Computes `PRF(master_secret, "key expansion",
    /// server_random || client_random)`. The server random comes first
    /// here, the reverse of the order used when deriving the master secret.
    pub fn make_key_block(&self, len: usize) -> Vec<u8> {
        let mut out = vec![0u8; len];
        let seed = join_randoms(&self.randoms.server, &self.randoms.client);
        prf::prf(&mut out,
                 self.hash,
                 &self.master_secret,
                 b"key expansion",
                 &seed);
        out
    }

    /// Returns an owned copy of the 48-byte master secret.
    pub fn get_master_secret(&self) -> Vec<u8> {
        self.master_secret.to_vec()
    }

    /// Computes 12 bytes of Finished `verify_data`:
    /// `PRF(master_secret, label, handshake_hash)`.
    ///
    /// Accepts any byte slice. Existing callers passing `&Vec<u8>` still
    /// work through deref coercion.
    pub fn make_verify_data(&self, handshake_hash: &[u8], label: &[u8]) -> Vec<u8> {
        let mut out = vec![0u8; 12];
        prf::prf(&mut out,
                 self.hash,
                 &self.master_secret,
                 label,
                 handshake_hash);
        out
    }

    /// Finished `verify_data` for the client side (label `"client finished"`).
    pub fn client_verify_data(&self, handshake_hash: &[u8]) -> Vec<u8> {
        self.make_verify_data(handshake_hash, b"client finished")
    }

    /// Finished `verify_data` for the server side (label `"server finished"`).
    pub fn server_verify_data(&self, handshake_hash: &[u8]) -> Vec<u8> {
        self.make_verify_data(handshake_hash, b"server finished")
    }
}
/// Record sequence number at which a close_notify is sent, so the
/// connection is wound down well before the 64-bit counter is exhausted.
const SEQ_SOFT_LIMIT: u64 = 0xffff_ffff_ffff_0000;
/// Record sequence number at or beyond which no further records are
/// encrypted, so the counter can never wrap.
const SEQ_HARD_LIMIT: u64 = 0xffff_ffff_ffff_fffe;
/// State and record-layer machinery for a TLS session; the methods in
/// `impl SessionCommon` do not depend on a client or server role.
pub struct SessionCommon {
/// Record protection in use. Starts as `MessageCipher::invalid()` and is
/// replaced by `start_encryption`.
message_cipher: Box<MessageCipher + Send + Sync>,
/// Sequence number for the next outgoing record; incremented by `encrypt_outgoing`.
write_seq: u64,
/// Sequence number for the next incoming record; incremented by `decrypt_incoming`.
read_seq: u64,
/// Set once the peer's close_notify alert has been processed.
peer_eof: bool,
/// Set by `peer_now_encrypting`.
pub peer_encrypting: bool,
/// Set by `we_now_encrypting`; decides whether alerts we send are encrypted.
pub we_encrypting: bool,
/// Set by `start_traffic`. While `false`, `send_plain` buffers data in
/// `sendable_plaintext` instead of sending it.
pub traffic: bool,
/// Incoming raw bytes from `read_tls` are fed here.
pub message_deframer: MessageDeframer,
/// NOTE(review): only constructed in this file; presumably reassembles
/// handshake messages elsewhere — confirm.
pub handshake_joiner: HandshakeJoiner,
/// Splits outgoing messages into record-sized fragments.
pub message_fragmenter: MessageFragmenter,
/// Decrypted application data waiting for `read`.
received_plaintext: ChunkVecBuffer,
/// Application data queued by `send_plain` before traffic started.
sendable_plaintext: ChunkVecBuffer,
/// Encoded records waiting to be drained by `write_tls`.
pub sendable_tls: ChunkVecBuffer
}
impl SessionCommon {
    /// Creates the shared session state.
    ///
    /// The cipher starts as `MessageCipher::invalid()`, both sequence numbers
    /// start at zero, and every buffer starts empty. `mtu`, when given, is
    /// the maximum fragment length handed to the `MessageFragmenter`;
    /// otherwise `MAX_FRAGMENT_LEN` is used.
    pub fn new(mtu: Option<usize>) -> SessionCommon {
        SessionCommon {
            message_cipher: MessageCipher::invalid(),
            write_seq: 0,
            read_seq: 0,
            peer_eof: false,
            peer_encrypting: false,
            we_encrypting: false,
            traffic: false,
            message_deframer: MessageDeframer::new(),
            handshake_joiner: HandshakeJoiner::new(),
            message_fragmenter: MessageFragmenter::new(mtu.unwrap_or(MAX_FRAGMENT_LEN)),
            received_plaintext: ChunkVecBuffer::new(),
            sendable_plaintext: ChunkVecBuffer::new(),
            sendable_tls: ChunkVecBuffer::new(),
        }
    }

    /// Returns `true` if decrypted application data is waiting to be `read`.
    pub fn has_readable_plaintext(&self) -> bool {
        !self.received_plaintext.is_empty()
    }

    /// Encrypts `plain` with the current cipher, consuming one write
    /// sequence number.
    ///
    /// # Panics
    ///
    /// Panics if the cipher returns an error.
    pub fn encrypt_outgoing(&mut self, plain: Message) -> Message {
        let seq = self.write_seq;
        self.write_seq += 1;
        self.message_cipher.encrypt(plain, seq).unwrap()
    }

    /// Decrypts `encr` with the current cipher, consuming one read
    /// sequence number.
    ///
    /// When the read sequence number reaches `SEQ_SOFT_LIMIT`, a close_notify
    /// is queued first. This winds the connection down long before the
    /// counter could wrap.
    ///
    /// # Errors
    ///
    /// Returns whatever error the cipher reports for a bad record.
    pub fn decrypt_incoming(&mut self, encr: Message) -> Result<Message, TLSError> {
        if self.read_seq == SEQ_SOFT_LIMIT {
            self.send_close_notify();
        }
        let seq = self.read_seq;
        self.read_seq += 1;
        self.message_cipher.decrypt(encr, seq)
    }

    /// Handles a received alert message.
    ///
    /// - close_notify (at any level) marks the peer as at EOF and returns `Ok`.
    /// - Any other warning-level alert is logged and ignored.
    /// - Anything else is logged and returned as `TLSError::AlertReceived`.
    ///
    /// # Errors
    ///
    /// Also returns `TLSError::CorruptMessagePayload(ContentType::Alert)` if
    /// `msg` does not carry an alert payload.
    pub fn process_alert(&mut self, msg: Message) -> Result<(), TLSError> {
        if let MessagePayload::Alert(ref alert) = msg.payload {
            // The peer will send no more data after close_notify.
            if alert.description == AlertDescription::CloseNotify {
                self.peer_eof = true;
                return Ok(());
            }

            // Warnings are not fatal; carry on.
            if alert.level == AlertLevel::Warning {
                warn!("TLS alert warning received: {:#?}", msg);
                return Ok(());
            }

            error!("TLS alert received: {:#?}", msg);
            Err(TLSError::AlertReceived(alert.description))
        } else {
            Err(TLSError::CorruptMessagePayload(ContentType::Alert))
        }
    }

    /// Fragments `m`, encrypts each fragment and queues it for sending.
    ///
    /// When the write sequence number reaches `SEQ_SOFT_LIMIT`, a
    /// close_notify is queued before the next fragment. Once it reaches
    /// `SEQ_HARD_LIMIT`, the remaining fragments are dropped so the counter
    /// never wraps.
    pub fn send_msg_encrypt(&mut self, m: Message) {
        let mut plain_messages = VecDeque::new();
        self.message_fragmenter.fragment(m, &mut plain_messages);

        for fragment in plain_messages {
            // Do not route this through `send_close_notify`. That path
            // re-enters this function while `write_seq` is still
            // `SEQ_SOFT_LIMIT`, so the check would fire again and recurse
            // without bound. Encrypting the alert directly consumes the
            // sequence number and ends the cycle.
            if self.write_seq == SEQ_SOFT_LIMIT {
                self.queue_soft_limit_close_notify();
            }

            // Refuse to wrap the sequence counter at all costs.
            if self.write_seq >= SEQ_HARD_LIMIT {
                return;
            }

            let em = self.encrypt_outgoing(fragment);
            self.queue_tls_message(em);
        }
    }

    /// Encrypts and queues a warning-level close_notify, consuming one write
    /// sequence number. Only called from `send_msg_encrypt` at the soft limit.
    fn queue_soft_limit_close_notify(&mut self) {
        warn!("Sending warning alert {:?}", AlertDescription::CloseNotify);
        let alert = Message::build_alert(AlertLevel::Warning, AlertDescription::CloseNotify);
        let em = self.encrypt_outgoing(alert);
        self.queue_tls_message(em);
    }

    /// Returns `true` once the peer sent close_notify and the deframer holds
    /// no leftover data.
    pub fn connection_at_eof(&self) -> bool {
        self.peer_eof && !self.message_deframer.has_pending()
    }

    /// Reads raw TLS bytes from `rd` into the message deframer.
    pub fn read_tls(&mut self, rd: &mut io::Read) -> io::Result<usize> {
        self.message_deframer.read(rd)
    }

    /// Writes queued TLS records to `wr`.
    pub fn write_tls(&mut self, wr: &mut io::Write) -> io::Result<usize> {
        self.sendable_tls.write_to(wr)
    }

    /// Sends application data.
    ///
    /// Before `start_traffic`, the data is buffered and sent later by
    /// `flush_plaintext`. Afterwards it is sent immediately as encrypted
    /// `ApplicationData`; empty data is ignored at that point.
    pub fn send_plain(&mut self, data: Vec<u8>) {
        if !self.traffic {
            self.sendable_plaintext.append(data);
            return;
        }

        debug_assert!(self.we_encrypting);

        if data.is_empty() {
            return;
        }

        let m = Message {
            typ: ContentType::ApplicationData,
            version: ProtocolVersion::TLSv1_2,
            payload: MessagePayload::opaque(data)
        };

        self.send_msg_encrypt(m);
    }

    /// Marks the session ready for application data and sends anything
    /// that was buffered before this point.
    pub fn start_traffic(&mut self) {
        self.traffic = true;
        self.flush_plaintext();
    }

    /// Sends all buffered application data. Does nothing before `start_traffic`.
    pub fn flush_plaintext(&mut self) {
        if !self.traffic {
            return;
        }

        while !self.sendable_plaintext.is_empty() {
            let buf = self.sendable_plaintext.take_one();
            self.send_plain(buf);
        }
    }

    /// Encodes `m` and appends it to the outgoing TLS buffer.
    fn queue_tls_message(&mut self, m: Message) {
        self.sendable_tls.append(m.get_encoding());
    }

    /// Sends `m`, encrypted when `must_encrypt` is set and otherwise
    /// fragmented and queued as plaintext records.
    pub fn send_msg(&mut self, m: Message, must_encrypt: bool) {
        if !must_encrypt {
            let mut to_send = VecDeque::new();
            self.message_fragmenter.fragment(m, &mut to_send);
            for mm in to_send {
                self.queue_tls_message(mm);
            }
        } else {
            self.send_msg_encrypt(m);
        }
    }

    /// Appends decrypted application data to the receive buffer.
    pub fn take_received_plaintext(&mut self, bytes: Payload) {
        self.received_plaintext.append(bytes.0);
    }

    /// Reads decrypted application data into `buf`.
    ///
    /// # Errors
    ///
    /// Returns `ConnectionAborted` when nothing was read, the buffer is
    /// empty, and `connection_at_eof()` holds, meaning the peer sent
    /// close_notify.
    pub fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
        let len = try!(self.received_plaintext.read(buf));

        if len == 0 && self.connection_at_eof() && self.received_plaintext.is_empty() {
            return Err(io::Error::new(io::ErrorKind::ConnectionAborted, "CloseNotify alert received"));
        }

        Ok(len)
    }

    /// Installs the record cipher for `suite`, keyed from `secrets`.
    pub fn start_encryption(&mut self, suite: &'static SupportedCipherSuite, secrets: &SessionSecrets) {
        self.message_cipher = MessageCipher::new(suite, secrets);
    }

    /// Records that the peer's records are now encrypted.
    pub fn peer_now_encrypting(&mut self) {
        self.peer_encrypting = true;
    }

    /// Records that our outgoing records are now encrypted.
    pub fn we_now_encrypting(&mut self) {
        self.we_encrypting = true;
    }

    /// Sends a warning-level alert, encrypted if `we_encrypting` is set.
    pub fn send_warning_alert(&mut self, desc: AlertDescription) {
        warn!("Sending warning alert {:?}", desc);
        let m = Message::build_alert(AlertLevel::Warning, desc);
        let enc = self.we_encrypting;
        self.send_msg(m, enc);
    }

    /// Sends a fatal-level alert, encrypted if `we_encrypting` is set.
    pub fn send_fatal_alert(&mut self, desc: AlertDescription) {
        warn!("Sending fatal alert {:?}", desc);
        let m = Message::build_alert(AlertLevel::Fatal, desc);
        let enc = self.we_encrypting;
        self.send_msg(m, enc);
    }

    /// Sends a warning-level close_notify alert.
    pub fn send_close_notify(&mut self) {
        self.send_warning_alert(AlertDescription::CloseNotify)
    }
}