use session::{Session, SessionSecrets, SessionCommon};
use suites::{SupportedCipherSuite, ALL_CIPHERSUITES, KeyExchange};
use msgs::enums::ContentType;
use msgs::enums::{AlertDescription, HandshakeType};
use msgs::handshake::{SessionID, CertificatePayload, ASN1Cert};
use msgs::handshake::{ServerNameRequest, SupportedSignatureAlgorithms};
use msgs::handshake::{EllipticCurveList, ECPointFormatList};
use msgs::message::Message;
use hash_hs;
use server_hs;
use error::TLSError;
use rand;
use sign;
use verify;
use std::sync::Arc;
use std::io;
/// Storage for server-side session state, keyed by `SessionID`.
///
/// `ServerConfig::new` installs an implementation that hands out empty IDs
/// and stores nothing. The meaning of the `bool` results is left to the
/// implementation — presumably "the operation took effect"; confirm against
/// the callers in `server_hs`.
pub trait StoresSessions {
    /// Produces a fresh session identifier.
    fn generate(&self) -> SessionID;

    /// Records `sec` under `id`.
    fn store(&self, id: &SessionID, sec: &SessionSecrets) -> bool;

    /// Looks up the secrets stored under `id`, if there are any.
    fn find(&self, id: &SessionID) -> Option<SessionSecrets>;

    /// Removes whatever is stored under `id`.
    fn erase(&self, id: &SessionID) -> bool;
}
/// Chooses the certificate chain and signing key the server presents.
///
/// The arguments describe what the peer asked for: an optional server name,
/// signature algorithms, elliptic curves and point formats (presumably as
/// offered in the ClientHello — confirm in `server_hs`). `Err(())` means no
/// chain is available; `ServerConfig::new` installs a resolver that always
/// returns it.
pub trait ResolvesCert {
    /// Picks a certificate chain and a shared signer for its private key.
    fn resolve(&self,
               server_name: Option<&ServerNameRequest>,
               sigalgs: &SupportedSignatureAlgorithms,
               ec_curves: &EllipticCurveList,
               ec_pointfmts: &ECPointFormatList)
               -> Result<(CertificatePayload, Arc<Box<sign::Signer>>), ()>;
}
/// Configuration shared (via `Arc`) by every server session.
///
/// Build one with `ServerConfig::new` and adjust it with the setters or by
/// assigning the public fields directly.
pub struct ServerConfig {
    /// Ciphersuites the server may use; `new` fills in `ALL_CIPHERSUITES`.
    pub ciphersuites: Vec<&'static SupportedCipherSuite>,

    /// `false` by default. Presumably, when `true`, the server's
    /// `ciphersuites` order wins over the client's — confirm in `server_hs`.
    pub ignore_client_order: bool,

    /// Session store; `new` installs one that keeps nothing.
    pub session_storage: Box<StoresSessions>,

    /// Certificate/key selection; `new` installs one that always fails, and
    /// `set_single_cert` replaces it.
    pub cert_resolver: Box<ResolvesCert>,

    /// ALPN protocol names; empty by default, replaced by `set_protocols`.
    pub alpn_protocols: Vec<String>,

    /// Trust roots for client certificates; filled by `set_client_auth_roots`.
    pub client_auth_roots: verify::RootCertStore,

    /// Whether client authentication is offered. New sessions enable
    /// client-auth handling in their handshake transcript when this is set.
    pub client_auth_offer: bool,

    /// Whether client authentication is required (as opposed to optional).
    pub client_auth_mandatory: bool,
}
/// A `StoresSessions` that keeps nothing.
///
/// It hands out empty session IDs, refuses every store, never finds
/// anything and never erases anything.
struct NoSessionStorage {}

impl StoresSessions for NoSessionStorage {
    fn generate(&self) -> SessionID {
        let no_bytes = Vec::new();
        SessionID { bytes: no_bytes }
    }

    fn store(&self, _id: &SessionID, _sec: &SessionSecrets) -> bool {
        false
    }

    fn find(&self, _id: &SessionID) -> Option<SessionSecrets> {
        None
    }

    fn erase(&self, _id: &SessionID) -> bool {
        false
    }
}
/// A `ResolvesCert` that never has a certificate to offer.
///
/// `ServerConfig::new` installs it. Until `set_single_cert` (or a custom
/// resolver) replaces it, every resolution fails.
struct FailResolveChain {}

impl ResolvesCert for FailResolveChain {
    /// Always returns `Err(())`, whatever the inputs are.
    fn resolve(&self,
               _server_name: Option<&ServerNameRequest>,
               _sigalgs: &SupportedSignatureAlgorithms,
               _ec_curves: &EllipticCurveList,
               _ec_pointfmts: &ECPointFormatList)
               -> Result<(CertificatePayload, Arc<Box<sign::Signer>>), ()> {
        let nothing_available = Err(());
        nothing_available
    }
}
struct AlwaysResolvesChain {
chain: CertificatePayload,
key: Arc<Box<sign::Signer>>
}
impl AlwaysResolvesChain {
fn new_rsa(chain: Vec<Vec<u8>>, priv_key: &[u8]) -> AlwaysResolvesChain {
let key = sign::RSASigner::new(priv_key)
.expect("Invalid RSA private key");
let mut payload = Vec::new();
for cert in chain {
payload.push(ASN1Cert::new(cert));
}
AlwaysResolvesChain { chain: payload, key: Arc::new(Box::new(key)) }
}
}
impl ResolvesCert for AlwaysResolvesChain {
fn resolve(&self,
_server_name: Option<&ServerNameRequest>,
_sigalgs: &SupportedSignatureAlgorithms,
_ec_curves: &EllipticCurveList,
_ec_pointfmts: &ECPointFormatList) -> Result<(CertificatePayload, Arc<Box<sign::Signer>>), ()> {
Ok((self.chain.clone(), self.key.clone()))
}
}
impl ServerConfig {
    /// Makes a configuration with the following defaults:
    /// - every suite in `ALL_CIPHERSUITES`;
    /// - `ignore_client_order = false`;
    /// - a session store that keeps nothing;
    /// - no ALPN protocols;
    /// - an empty client-auth root store, with client auth neither offered
    ///   nor required;
    /// - a certificate resolver that always fails.
    ///
    /// Until a resolver is installed (for example with `set_single_cert`),
    /// certificate resolution fails for every connection.
    pub fn new() -> ServerConfig {
        ServerConfig {
            ciphersuites: ALL_CIPHERSUITES.to_vec(),
            ignore_client_order: false,
            session_storage: Box::new(NoSessionStorage {}),
            alpn_protocols: Vec::new(),
            cert_resolver: Box::new(FailResolveChain {}),
            client_auth_roots: verify::RootCertStore::empty(),
            client_auth_offer: false,
            client_auth_mandatory: false
        }
    }

    /// Serves the same certificate chain and RSA private key to every client.
    ///
    /// `cert_chain` holds the raw certificates (presumably DER, leaf first —
    /// confirm against `verify`). `key_der` holds the RSA private key.
    ///
    /// # Panics
    ///
    /// Panics if `sign::RSASigner` does not accept `key_der`.
    pub fn set_single_cert(&mut self, cert_chain: Vec<Vec<u8>>, key_der: Vec<u8>) {
        self.cert_resolver = Box::new(AlwaysResolvesChain::new_rsa(cert_chain, &key_der));
    }

    /// Replaces the ALPN protocol list with a copy of `protocols`.
    pub fn set_protocols(&mut self, protocols: &[String]) {
        // clear + extend reuses the existing allocation.
        self.alpn_protocols.clear();
        self.alpn_protocols.extend_from_slice(protocols);
    }

    /// Adds `certs` to the client-certificate trust roots and turns on client
    /// authentication. `mandatory` becomes `client_auth_mandatory`.
    ///
    /// Roots are appended to any already present; nothing is cleared.
    ///
    /// # Panics
    ///
    /// Panics if `RootCertStore::add` rejects any certificate in `certs`.
    /// Certificates before the rejected one will already have been added.
    pub fn set_client_auth_roots(&mut self, certs: Vec<Vec<u8>>, mandatory: bool) {
        for cert in certs {
            self.client_auth_roots.add(&cert)
                .expect("invalid client authentication root certificate");
        }
        self.client_auth_offer = true;
        self.client_auth_mandatory = mandatory;
    }
}
/// Mutable state belonging to a single in-progress server handshake.
pub struct ServerHandshakeData {
    /// Certificate chain for this handshake; `None` until the handshake
    /// code (not shown here) sets it.
    pub server_cert_chain: Option<CertificatePayload>,

    /// Negotiated ciphersuite. `start_handshake_hash` and
    /// `ServerSessionImpl::start_encryption` unwrap it, so it must be set
    /// before either is called.
    pub ciphersuite: Option<&'static SupportedCipherSuite>,

    /// Handshake secrets; starts as `SessionSecrets::for_server()`, and
    /// `generate_server_random` fills its `server_random`.
    pub secrets: SessionSecrets,

    /// Handshake transcript hash (`hash_hs::HandshakeHash`).
    pub transcript: hash_hs::HandshakeHash,

    /// Key-exchange state, once one has been started (set outside this file
    /// section).
    pub kx_data: Option<KeyExchange>,

    /// Whether this handshake is performing client authentication.
    pub doing_client_auth: bool,

    /// The client's certificate chain once validated; exposed through
    /// `get_peer_certificates`.
    pub valid_client_cert_chain: Option<Vec<ASN1Cert>>,
}
impl ServerHandshakeData {
    /// Creates blank handshake state: nothing negotiated or validated, fresh
    /// server secrets and a new transcript.
    fn new() -> ServerHandshakeData {
        ServerHandshakeData {
            server_cert_chain: None,
            ciphersuite: None,
            secrets: SessionSecrets::for_server(),
            transcript: hash_hs::HandshakeHash::new(),
            kx_data: None,
            doing_client_auth: false,
            valid_client_cert_chain: None,
        }
    }

    /// Overwrites `secrets.server_random` using `rand::fill_random`.
    pub fn generate_server_random(&mut self) {
        rand::fill_random(&mut self.secrets.server_random);
    }

    /// Starts the transcript with the hash of the negotiated ciphersuite.
    ///
    /// # Panics
    ///
    /// Panics if `ciphersuite` is still `None`.
    pub fn start_handshake_hash(&mut self) {
        let suite = self.ciphersuite.as_ref().unwrap();
        self.transcript.start_hash(suite.get_hash());
    }
}
/// The message a server session expects next.
///
/// `ServerSessionImpl::get_handler` maps each state to its handler in
/// `server_hs`. Entering `Traffic` makes `process_main_protocol` start
/// traffic on the session's `common` state.
#[derive(PartialEq)]
pub enum ConnState {
    /// Initial state: waiting for a ClientHello.
    ExpectClientHello,
    /// Waiting for a Certificate message.
    ExpectCertificate,
    /// Waiting for the client key exchange.
    ExpectClientKX,
    /// Waiting for a CertificateVerify message.
    ExpectCertificateVerify,
    /// Waiting for ChangeCipherSpec.
    ExpectCCS,
    /// Waiting for Finished.
    ExpectFinished,
    /// Handshake complete.
    Traffic,
}
/// The state machine and buffers behind a `ServerSession`.
pub struct ServerSessionImpl {
    /// Configuration shared with other sessions.
    pub config: Arc<ServerConfig>,
    /// State for the handshake in progress.
    pub handshake_data: ServerHandshakeData,
    /// Secrets handed to `common.start_encryption` by `start_encryption`.
    pub secrets_current: SessionSecrets,
    /// Record-layer state (`SessionCommon`) used for I/O, alerts and
    /// encryption.
    pub common: SessionCommon,
    /// Negotiated ALPN protocol, if any (set by handshake code not shown
    /// here).
    pub alpn_protocol: Option<String>,
    /// Where the session is in the handshake.
    pub state: ConnState,
}
impl ServerSessionImpl {
    /// Creates a session in `ExpectClientHello` state that shares
    /// `server_config`.
    ///
    /// When the configuration offers client authentication, the handshake
    /// transcript is told so up front via `set_client_auth_enabled`.
    pub fn new(server_config: &Arc<ServerConfig>) -> ServerSessionImpl {
        let mut sess = ServerSessionImpl {
            config: server_config.clone(),
            handshake_data: ServerHandshakeData::new(),
            secrets_current: SessionSecrets::for_server(),
            common: SessionCommon::new(None),
            alpn_protocol: None,
            state: ConnState::ExpectClientHello
        };

        if sess.config.client_auth_offer {
            sess.handshake_data.transcript.set_client_auth_enabled();
        }

        sess
    }

    /// True while `common` reports no readable plaintext.
    pub fn wants_read(&self) -> bool {
        !self.common.has_readable_plaintext()
    }

    /// True while `common.tls_queue` is non-empty.
    pub fn wants_write(&self) -> bool {
        !self.common.tls_queue.is_empty()
    }

    /// Processes one record taken from the deframer.
    ///
    /// The record is handled as follows:
    /// - If the peer is encrypting, the record is decrypted first.
    /// - Records the handshake joiner wants are handed to it, and then its
    ///   queued messages are processed.
    /// - Alerts go to `common.process_alert`.
    /// - Anything else goes to `process_main_protocol`.
    ///
    /// # Errors
    ///
    /// - Decryption errors.
    /// - `CorruptMessagePayload(Handshake)` when the joiner's `take_message`
    ///   returns `None`.
    /// - Any error from alert processing or the state handlers.
    pub fn process_msg(&mut self, msg: &mut Message) -> Result<(), TLSError> {
        if self.common.peer_encrypting {
            let decrypted = try!(self.common.decrypt_incoming(msg));
            *msg = decrypted;
        }

        if self.common.handshake_joiner.want_message(msg) {
            try!(
                self.common.handshake_joiner.take_message(msg)
                    .ok_or_else(|| TLSError::CorruptMessagePayload(ContentType::Handshake))
            );
            return self.process_new_handshake_messages();
        }

        msg.decode_payload();

        if msg.is_content_type(ContentType::Alert) {
            return self.common.process_alert(msg);
        }

        self.process_main_protocol(msg)
    }

    /// Drains the handshake joiner's `frames` queue through
    /// `process_main_protocol`. Stops at the first error, leaving any later
    /// messages queued.
    fn process_new_handshake_messages(&mut self) -> Result<(), TLSError> {
        while let Some(mut msg) = self.common.handshake_joiner.frames.pop_front() {
            try!(self.process_main_protocol(&mut msg));
        }
        Ok(())
    }

    /// Sends a fatal `UnexpectedMessage` alert.
    fn queue_unexpected_alert(&mut self) {
        self.common.send_fatal_alert(AlertDescription::UnexpectedMessage);
    }

    /// Feeds one decoded message to the handler for the current state.
    ///
    /// A ClientHello received in `Traffic` is answered with a
    /// `NoRenegotiation` warning alert and otherwise ignored. If the current
    /// handler's `expect` check rejects the message, a fatal
    /// `UnexpectedMessage` alert is sent and the check's error is returned.
    /// On success the session moves to the state the handler returns. The
    /// first transition into `Traffic` calls `common.start_traffic()`.
    pub fn process_main_protocol(&mut self, msg: &mut Message) -> Result<(), TLSError> {
        if self.state == ConnState::Traffic && msg.is_handshake_type(HandshakeType::ClientHello) {
            self.common.send_warning_alert(AlertDescription::NoRenegotiation);
            return Ok(());
        }

        let handler = self.get_handler();
        try!(handler.expect.check_message(msg)
             .map_err(|err| { self.queue_unexpected_alert(); err }));
        let new_state = try!((handler.handle)(self, msg));
        self.state = new_state;

        if self.state == ConnState::Traffic && !self.common.traffic {
            self.common.start_traffic();
        }

        Ok(())
    }

    /// Maps the current `state` to its `server_hs` handler.
    fn get_handler(&self) -> &'static server_hs::Handler {
        match self.state {
            ConnState::ExpectClientHello => &server_hs::EXPECT_CLIENT_HELLO,
            ConnState::ExpectCertificate => &server_hs::EXPECT_CERTIFICATE,
            ConnState::ExpectClientKX => &server_hs::EXPECT_CLIENT_KX,
            ConnState::ExpectCertificateVerify => &server_hs::EXPECT_CERTIFICATE_VERIFY,
            ConnState::ExpectCCS => &server_hs::EXPECT_CCS,
            ConnState::ExpectFinished => &server_hs::EXPECT_FINISHED,
            ConnState::Traffic => &server_hs::TRAFFIC
        }
    }

    /// Processes every record the deframer has queued.
    ///
    /// # Errors
    ///
    /// Returns `CorruptMessage` if the deframer is desynced. Otherwise
    /// returns the first error from `process_msg`, leaving the remaining
    /// records queued.
    pub fn process_new_packets(&mut self) -> Result<(), TLSError> {
        if self.common.message_deframer.desynced {
            return Err(TLSError::CorruptMessage);
        }

        while let Some(mut msg) = self.common.message_deframer.frames.pop_front() {
            try!(self.process_msg(&mut msg));
        }

        Ok(())
    }

    /// Calls `common.start_encryption` with the negotiated ciphersuite and
    /// `secrets_current`.
    ///
    /// # Panics
    ///
    /// Panics if `handshake_data.ciphersuite` has not been set.
    pub fn start_encryption(&mut self) {
        let scs = self.handshake_data.ciphersuite
            .as_ref()
            .expect("start_encryption called before a ciphersuite was negotiated");
        self.common.start_encryption(scs, &self.secrets_current);
    }

    /// Sends a `CloseNotify` warning alert.
    pub fn send_close_notify(&mut self) {
        self.common.send_warning_alert(AlertDescription::CloseNotify)
    }

    /// Returns copies of the validated client certificates' bytes, or `None`
    /// if no client chain has been validated.
    pub fn get_peer_certificates(&self) -> Option<Vec<Vec<u8>>> {
        self.handshake_data.valid_client_cert_chain
            .as_ref()
            .map(|chain| chain.iter().map(|cert| cert.0.clone()).collect())
    }
}
/// A server-side TLS session. It is a thin public wrapper that delegates to
/// `ServerSessionImpl`.
pub struct ServerSession {
    imp: ServerSessionImpl
}

impl ServerSession {
    /// Starts a new session using the shared `config`.
    pub fn new(config: &Arc<ServerConfig>) -> ServerSession {
        let inner = ServerSessionImpl::new(config);
        ServerSession { imp: inner }
    }
}
/// Every `Session` operation is forwarded to the inner `ServerSessionImpl`
/// or to its `common` state.
impl Session for ServerSession {
    /// Delegates to `SessionCommon::read_tls`.
    fn read_tls(&mut self, source: &mut io::Read) -> io::Result<usize> {
        let common = &mut self.imp.common;
        common.read_tls(source)
    }

    /// Delegates to `SessionCommon::write_tls`.
    fn write_tls(&mut self, sink: &mut io::Write) -> io::Result<()> {
        let common = &mut self.imp.common;
        common.write_tls(sink)
    }

    /// Runs the state machine over all buffered records.
    fn process_new_packets(&mut self) -> Result<(), TLSError> {
        self.imp.process_new_packets()
    }

    fn wants_read(&self) -> bool {
        self.imp.wants_read()
    }

    fn wants_write(&self) -> bool {
        self.imp.wants_write()
    }

    fn send_close_notify(&mut self) {
        self.imp.send_close_notify()
    }

    /// Returns the validated client certificate chain, if any.
    fn get_peer_certificates(&self) -> Option<Vec<Vec<u8>>> {
        self.imp.get_peer_certificates()
    }
}
/// Reading from a session reads through `SessionCommon::read`.
impl io::Read for ServerSession {
    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
        let common = &mut self.imp.common;
        common.read(buf)
    }
}
/// Writing to a session hands plaintext to `SessionCommon`.
impl io::Write for ServerSession {
    /// Passes all of `buf` to `send_plain` and reports the whole buffer as
    /// written; this never fails.
    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
        let accepted = buf.len();
        self.imp.common.send_plain(buf);
        Ok(accepted)
    }

    /// Calls `flush_plaintext`; this never fails.
    fn flush(&mut self) -> io::Result<()> {
        self.imp.common.flush_plaintext();
        Ok(())
    }
}