use msgs::enums::CipherSuite;
use msgs::enums::{AlertDescription, HandshakeType, ExtensionType};
use session::{Session, SessionSecrets, SessionRandoms, SessionCommon};
use suites::{SupportedCipherSuite, ALL_CIPHERSUITES};
use msgs::handshake::{CertificatePayload, DigitallySignedStruct, SessionID};
use msgs::enums::SignatureScheme;
use msgs::enums::{ContentType, ProtocolVersion};
use msgs::message::Message;
use msgs::persist;
use client_hs;
use hash_hs;
use verify;
use anchors;
use sign;
use suites;
use error::TLSError;
use key;
use std::collections;
use std::sync::{Arc, Mutex};
use std::io;
/// Key/value storage for client session data, installed in
/// `ClientConfig::session_persistence`.
///
/// Keys and values are opaque byte strings. Implementations must be
/// `Send + Sync`.
pub trait StoresClientSessions : Send + Sync {
    /// Stores `value` under `key`, returning `true` if the store accepted it.
    fn put(&mut self, key: Vec<u8>, value: Vec<u8>) -> bool;

    /// Returns a copy of the value stored under `key`, if the store has one.
    fn get(&mut self, key: &[u8]) -> Option<Vec<u8>>;
}

/// A session store that keeps nothing: `put` always returns `false` and
/// `get` always returns `None`. `ClientConfig::new` uses it as the default.
struct NoSessionStorage {}

impl StoresClientSessions for NoSessionStorage {
    fn put(&mut self, _key: Vec<u8>, _value: Vec<u8>) -> bool {
        false
    }

    fn get(&mut self, _key: &[u8]) -> Option<Vec<u8>> {
        None
    }
}

/// An in-memory session store holding at most `max_entries` entries.
///
/// When the cache is full and a *new* key is stored, an arbitrary existing
/// entry (whichever `HashMap` iteration yields first) is evicted to make
/// room. The entry being written is never the one evicted (unless the
/// capacity is zero), and overwriting an existing key evicts nothing.
pub struct ClientSessionMemoryCache {
    cache: collections::HashMap<Vec<u8>, Vec<u8>>,
    max_entries: usize,
}

impl ClientSessionMemoryCache {
    /// Creates an empty cache that holds at most `size` entries.
    ///
    /// # Panics
    /// In debug builds, panics if `size` is zero. In release builds a zero
    /// size is accepted and the cache retains nothing.
    pub fn new(size: usize) -> Box<ClientSessionMemoryCache> {
        debug_assert!(size > 0);
        Box::new(ClientSessionMemoryCache {
            cache: collections::HashMap::new(),
            max_entries: size,
        })
    }

    /// Removes one arbitrary entry. Returns `false` if the cache was empty.
    fn evict_one(&mut self) -> bool {
        // Clone the key out first so the shared borrow of `cache` has ended
        // before `remove` needs a mutable one.
        let victim = self.cache.keys().next().cloned();
        match victim {
            Some(key) => {
                self.cache.remove(&key);
                true
            }
            None => false,
        }
    }

    /// Evicts entries until one more can be inserted without exceeding
    /// `max_entries` (or until the cache is empty).
    fn make_room_for_one(&mut self) {
        while self.cache.len() >= self.max_entries {
            if !self.evict_one() {
                break;
            }
        }
    }

    /// Evicts entries until the cache holds at most `max_entries`.
    fn limit_size(&mut self) {
        // `len > max_entries >= 0` implies the cache is non-empty, so each
        // iteration removes an entry and the loop terminates.
        while self.cache.len() > self.max_entries {
            self.evict_one();
        }
    }
}

impl StoresClientSessions for ClientSessionMemoryCache {
    /// Stores `value` under `key`, evicting another entry first if the
    /// cache is full. Always returns `true`.
    fn put(&mut self, key: Vec<u8>, value: Vec<u8>) -> bool {
        // Evict *before* inserting so the arbitrary victim can never be the
        // entry we are about to write. Overwriting an existing key needs no
        // extra room.
        if !self.cache.contains_key(&key) {
            self.make_room_for_one();
        }
        self.cache.insert(key, value);
        // Only has an effect when `max_entries` is zero.
        self.limit_size();
        true
    }

    fn get(&mut self, key: &[u8]) -> Option<Vec<u8>> {
        self.cache.get(key).cloned()
    }
}
/// Chooses the client certificate chain and signer to present for client
/// authentication.
pub trait ResolvesClientCert : Send + Sync {
    /// Returns the certificate chain and signer to use, or `None` if there
    /// is nothing suitable.
    ///
    /// `acceptable_issuers` and `sigschemes` describe what the server will
    /// accept (presumably taken from its CertificateRequest — confirm in
    /// `client_hs`).
    fn resolve(
        &self,
        acceptable_issuers: &[&[u8]],
        sigschemes: &[SignatureScheme],
    ) -> Option<sign::CertChainAndSigner>;

    /// Reports whether this resolver can ever supply a certificate.
    /// `ClientSessionImpl::new` enables client auth on the handshake
    /// transcript when this returns `true`.
    fn has_certs(&self) -> bool;
}
/// Client-certificate resolver that never supplies a certificate.
/// `ClientConfig::new` installs it as the default.
struct FailResolveClientCert {}

impl ResolvesClientCert for FailResolveClientCert {
    /// Always `None`, whatever the server says it accepts.
    fn resolve(
        &self,
        _acceptable_issuers: &[&[u8]],
        _sigschemes: &[SignatureScheme],
    ) -> Option<sign::CertChainAndSigner> {
        None
    }

    /// Always `false`: this resolver has nothing to offer.
    fn has_certs(&self) -> bool {
        false
    }
}
/// Client-certificate resolver that always offers one fixed chain and key.
struct AlwaysResolvesClientCert {
    /// Certificate chain handed out by every `resolve` call.
    chain: Vec<key::Certificate>,
    /// Signer for the chain's private key; shared by reference counting with
    /// every tuple `resolve` returns.
    key: Arc<Box<sign::Signer>>,
}

impl AlwaysResolvesClientCert {
    /// Builds a resolver from a certificate chain and an RSA private key.
    ///
    /// # Panics
    /// Panics ("Invalid RSA private key") if `sign::RSASigner::new` rejects
    /// `priv_key`.
    fn new_rsa(chain: Vec<key::Certificate>,
               priv_key: &key::PrivateKey)
               -> AlwaysResolvesClientCert {
        let signer = sign::RSASigner::new(priv_key).expect("Invalid RSA private key");
        AlwaysResolvesClientCert {
            key: Arc::new(Box::new(signer)),
            chain: chain,
        }
    }
}

impl ResolvesClientCert for AlwaysResolvesClientCert {
    /// Returns a clone of the stored chain and another handle to the shared
    /// signer.
    ///
    /// NOTE(review): both the acceptable issuers and the server's signature
    /// schemes are ignored, so an RSA key is offered even if the server did
    /// not list an RSA scheme.
    fn resolve(
        &self,
        _acceptable_issuers: &[&[u8]],
        _sigschemes: &[SignatureScheme],
    ) -> Option<sign::CertChainAndSigner> {
        let chain = self.chain.clone();
        let signer = self.key.clone();
        Some((chain, signer))
    }

    /// Always `true`: a certificate is always available.
    fn has_certs(&self) -> bool {
        true
    }
}
/// Configuration shared by client sessions (held as `Arc<ClientConfig>`
/// by `ClientSessionImpl`). Build one with `ClientConfig::new` and adjust it
/// with the `set_*` methods.
pub struct ClientConfig {
    /// Supported cipher suites; `ClientSessionImpl::get_cipher_suites`
    /// lists them in this order. Defaults to all of `ALL_CIPHERSUITES`.
    pub ciphersuites: Vec<&'static SupportedCipherSuite>,
    /// Trust anchors; empty by default.
    /// NOTE(review): presumably consulted by the server certificate verifier.
    pub root_store: anchors::RootCertStore,
    /// ALPN protocol names to offer; replaced wholesale by `set_protocols`.
    pub alpn_protocols: Vec<String>,
    /// Session store behind a mutex; defaults to `NoSessionStorage`, which
    /// keeps nothing. Replaced by `set_persistence`.
    pub session_persistence: Mutex<Box<StoresClientSessions>>,
    /// Fragment payload limit set by `set_mtu` (the MTU minus
    /// `fragmenter::PACKET_OVERHEAD`); `None` by default. Passed to
    /// `SessionCommon::new` when a session is created.
    pub mtu: Option<usize>,
    /// Chooses the client certificate, if any; defaults to
    /// `FailResolveClientCert`, which never supplies one.
    pub client_auth_cert_resolver: Box<ResolvesClientCert>,
    /// Defaults to `true`.
    /// NOTE(review): presumably controls whether session tickets are requested.
    pub enable_tickets: bool,
    /// Protocol versions to use; defaults to TLS 1.3 then TLS 1.2.
    pub versions: Vec<ProtocolVersion>,
    /// Server certificate verifier; defaults to `verify::WebPKIVerifier`.
    /// Private: replaceable only through `dangerous()` (feature-gated).
    verifier: Box<verify::ServerCertVerifier>,
}
impl ClientConfig {
    /// Makes a configuration with default settings:
    ///
    /// - every cipher suite in `ALL_CIPHERSUITES`;
    /// - an empty root certificate store;
    /// - no ALPN protocols;
    /// - no session persistence (`NoSessionStorage`);
    /// - `mtu` unset (`None`);
    /// - no client certificate (`FailResolveClientCert`);
    /// - `enable_tickets` set to `true`;
    /// - protocol versions TLS 1.3 then TLS 1.2;
    /// - the `WebPKIVerifier` server certificate verifier.
    pub fn new() -> ClientConfig {
        ClientConfig {
            ciphersuites: ALL_CIPHERSUITES.to_vec(),
            root_store: anchors::RootCertStore::empty(),
            alpn_protocols: Vec::new(),
            session_persistence: Mutex::new(Box::new(NoSessionStorage {})),
            mtu: None,
            client_auth_cert_resolver: Box::new(FailResolveClientCert {}),
            enable_tickets: true,
            versions: vec![ProtocolVersion::TLSv1_3, ProtocolVersion::TLSv1_2],
            verifier: Box::new(verify::WebPKIVerifier {}),
        }
    }

    /// Returns the server certificate verifier in use.
    #[doc(hidden)]
    pub fn get_verifier(&self) -> &verify::ServerCertVerifier {
        self.verifier.as_ref()
    }

    /// Replaces the ALPN protocol list with a copy of `protocols`.
    pub fn set_protocols(&mut self, protocols: &[String]) {
        self.alpn_protocols.clear();
        self.alpn_protocols.extend_from_slice(protocols);
    }

    /// Replaces the session store with `persist`.
    pub fn set_persistence(&mut self, persist: Box<StoresClientSessions>) {
        self.session_persistence = Mutex::new(persist);
    }

    /// Sets (or, with `None`, clears) the maximum transmission unit.
    ///
    /// `mtu` is the full size including `fragmenter::PACKET_OVERHEAD`. The
    /// stored `self.mtu` is the payload size left after subtracting that
    /// overhead.
    ///
    /// # Panics
    /// Panics if `mtu` is `Some(x)` with `x <= fragmenter::PACKET_OVERHEAD`.
    pub fn set_mtu(&mut self, mtu: &Option<usize>) {
        use msgs::fragmenter;
        self.mtu = match *mtu {
            Some(x) => {
                // A hard assert rather than `debug_assert!`: in release
                // builds (overflow checks off) the subtraction below would
                // otherwise silently wrap to an enormous limit, or leave no
                // room for any payload.
                assert!(x > fragmenter::PACKET_OVERHEAD,
                        "MTU must be larger than the TLS record overhead");
                Some(x - fragmenter::PACKET_OVERHEAD)
            }
            None => None,
        };
    }

    /// Installs a single client certificate chain and RSA private key.
    ///
    /// The installed resolver returns this chain for every request,
    /// ignoring the acceptable issuers and signature schemes.
    ///
    /// # Panics
    /// Panics if `key_der` is not accepted by `sign::RSASigner::new`.
    pub fn set_single_client_cert(&mut self,
                                  cert_chain: Vec<key::Certificate>,
                                  key_der: key::PrivateKey) {
        self.client_auth_cert_resolver = Box::new(AlwaysResolvesClientCert::new_rsa(cert_chain,
                                                                                     &key_der));
    }

    /// Returns a handle for changing settings that can weaken security,
    /// such as the server certificate verifier. Only available with the
    /// `dangerous_configuration` feature.
    #[cfg(feature = "dangerous_configuration")]
    pub fn dangerous(&mut self) -> danger::DangerousClientConfig {
        danger::DangerousClientConfig { cfg: self }
    }
}
#[cfg(feature = "dangerous_configuration")]
pub mod danger {
    //! Client configuration settings that can undermine TLS security
    //! guarantees. Compiled only with the `dangerous_configuration` feature.

    use super::ClientConfig;
    use super::verify::ServerCertVerifier;

    /// Mutable access to the security-sensitive parts of a `ClientConfig`.
    /// Obtained from `ClientConfig::dangerous`.
    pub struct DangerousClientConfig<'cfg> {
        /// The configuration being modified.
        pub cfg: &'cfg mut ClientConfig,
    }

    impl<'cfg> DangerousClientConfig<'cfg> {
        /// Replaces the server certificate verifier with `verifier`.
        pub fn set_certificate_verifier(&mut self, verifier: Box<ServerCertVerifier>) {
            let config = &mut *self.cfg;
            config.verifier = verifier;
        }
    }
}
/// Per-connection state accumulated while the client handshake runs.
///
/// Created by `ClientHandshakeData::new`. Most fields are presumably
/// populated by the handshake state machine in `client_hs` (not visible
/// here). Comments marked "presumably" are inferred from names and should
/// be confirmed there.
pub struct ClientHandshakeData {
    /// Certificate chain received from the server. `get_peer_certificates`
    /// returns a copy of it, or `None` while it is empty.
    pub server_cert_chain: CertificatePayload,
    /// Host name being connected to, copied from the constructor argument.
    pub dns_name: String,
    /// Session ID for this handshake; starts empty.
    pub session_id: SessionID,
    /// Presumably the extension types sent in the ClientHello.
    pub sent_extensions: Vec<ExtensionType>,
    /// Presumably the server's key-exchange parameters (TLS 1.2).
    pub server_kx_params: Vec<u8>,
    /// Presumably the server's signature over its key-exchange parameters.
    pub server_kx_sig: Option<DigitallySignedStruct>,
    /// Running handshake hash. `ClientSessionImpl::new` enables client auth
    /// on it when the configured resolver has certificates.
    pub transcript: hash_hs::HandshakeHash,
    /// Presumably the stored session being resumed, if any.
    pub resuming_session: Option<persist::ClientSessionValue>,
    /// Client/server randoms; initialised with `SessionRandoms::for_client()`.
    pub randoms: SessionRandoms,
    /// Starts `false`; presumably set when a new ticket must be stored.
    pub must_issue_new_ticket: bool,
    /// Starts `false`; presumably whether extended master secret is in use.
    pub using_ems: bool,
    /// Starts empty; presumably a ticket received from the server.
    pub new_ticket: Vec<u8>,
    /// Starts at 0; presumably the lifetime of `new_ticket`.
    pub new_ticket_lifetime: u32,
    /// Starts `false`; presumably whether the server requested client auth.
    pub doing_client_auth: bool,
    /// Signature scheme chosen for client auth, if any.
    pub client_auth_sigscheme: Option<SignatureScheme>,
    /// Certificate chain chosen for client auth, if any.
    pub client_auth_cert: Option<CertificatePayload>,
    /// Signer chosen for client auth, if any.
    pub client_auth_key: Option<Arc<Box<sign::Signer>>>,
    /// Presumably the TLS 1.3 certificate request context.
    pub client_auth_context: Option<Vec<u8>>,
    /// Presumably the key shares offered in the ClientHello.
    pub offered_key_shares: Vec<suites::KeyExchange>,
}
impl ClientHandshakeData {
    /// Returns fresh handshake state for a connection to `host_name`.
    ///
    /// Everything starts empty, `None`, `false` or zero, except:
    /// - `dns_name`, an owned copy of `host_name`;
    /// - `session_id`, `SessionID::empty()`;
    /// - `transcript`, `HandshakeHash::new()`;
    /// - `randoms`, `SessionRandoms::for_client()`.
    fn new(host_name: &str) -> ClientHandshakeData {
        ClientHandshakeData {
            server_cert_chain: vec![],
            dns_name: host_name.to_owned(),
            session_id: SessionID::empty(),
            sent_extensions: vec![],
            server_kx_params: vec![],
            server_kx_sig: None,
            transcript: hash_hs::HandshakeHash::new(),
            resuming_session: None,
            randoms: SessionRandoms::for_client(),
            must_issue_new_ticket: false,
            using_ems: false,
            new_ticket: vec![],
            new_ticket_lifetime: 0,
            doing_client_auth: false,
            client_auth_sigscheme: None,
            client_auth_cert: None,
            client_auth_key: None,
            client_auth_context: None,
            offered_key_shares: vec![],
        }
    }
}
/// Internal state of one client TLS session; wrapped by `ClientSession`.
pub struct ClientSessionImpl {
    /// Shared client configuration.
    pub config: Arc<ClientConfig>,
    /// State accumulated during the handshake.
    pub handshake_data: ClientHandshakeData,
    /// Session secrets; `None` at construction. `start_encryption_tls12`
    /// unwraps this, so it must have been set before that is called.
    pub secrets: Option<SessionSecrets>,
    /// Negotiated ALPN protocol, if any; returned by `get_alpn_protocol`.
    pub alpn_protocol: Option<String>,
    /// Shared session machinery: record I/O, buffers, encryption state and
    /// alert sending.
    pub common: SessionCommon,
    /// First error hit by `process_new_packets`. Once set, every later call
    /// returns a clone of it without processing more data.
    pub error: Option<TLSError>,
    /// Current handshake state. `process_main_protocol` checks each message
    /// against `state.expect` and replaces this with the state returned by
    /// `state.handle`.
    pub state: &'static client_hs::State,
}
impl ClientSessionImpl {
pub fn new(config: &Arc<ClientConfig>, hostname: &str) -> ClientSessionImpl {
let mut cs = ClientSessionImpl {
config: config.clone(),
handshake_data: ClientHandshakeData::new(hostname),
secrets: None,
alpn_protocol: None,
common: SessionCommon::new(config.mtu, true),
error: None,
state: &client_hs::EXPECT_SERVER_HELLO,
};
if cs.config.client_auth_cert_resolver.has_certs() {
cs.handshake_data.transcript.set_client_auth_enabled();
}
cs.state = client_hs::emit_client_hello(&mut cs);
cs
}
pub fn get_cipher_suites(&self) -> Vec<CipherSuite> {
let mut ret = Vec::new();
for cs in &self.config.ciphersuites {
ret.push(cs.suite);
}
ret.push(CipherSuite::TLS_EMPTY_RENEGOTIATION_INFO_SCSV);
ret
}
pub fn start_encryption_tls12(&mut self) {
self.common.start_encryption_tls12(self.secrets.as_ref().unwrap());
}
pub fn find_cipher_suite(&self, suite: CipherSuite) -> Option<&'static SupportedCipherSuite> {
for scs in &self.config.ciphersuites {
if scs.suite == suite {
return Some(scs);
}
}
None
}
pub fn wants_read(&self) -> bool {
!self.common.has_readable_plaintext()
}
pub fn wants_write(&self) -> bool {
!self.common.sendable_tls.is_empty()
}
pub fn is_handshaking(&self) -> bool {
!self.common.traffic
}
pub fn set_buffer_limit(&mut self, len: usize) {
self.common.set_buffer_limit(len)
}
pub fn process_msg(&mut self, mut msg: Message) -> Result<(), TLSError> {
if self.common.peer_encrypting {
let dm = self.common.decrypt_incoming(msg)?;
msg = dm;
}
if self.common.handshake_joiner.want_message(&msg) {
self.common
.handshake_joiner
.take_message(msg)
.ok_or_else(|| {
self.common.send_fatal_alert(AlertDescription::DecodeError);
TLSError::CorruptMessagePayload(ContentType::Handshake)
})?;
return self.process_new_handshake_messages();
}
if !msg.decode_payload() {
return Err(TLSError::CorruptMessagePayload(msg.typ));
}
if msg.is_content_type(ContentType::Alert) {
return self.common.process_alert(msg);
}
self.process_main_protocol(msg)
}
fn process_new_handshake_messages(&mut self) -> Result<(), TLSError> {
while let Some(msg) = self.common.handshake_joiner.frames.pop_front() {
self.process_main_protocol(msg)?;
}
Ok(())
}
fn queue_unexpected_alert(&mut self) {
self.common.send_fatal_alert(AlertDescription::UnexpectedMessage);
}
fn process_hello_req(&mut self) {
if !self.is_handshaking() {
self.common.send_warning_alert(AlertDescription::NoRenegotiation);
}
}
fn process_main_protocol(&mut self, msg: Message) -> Result<(), TLSError> {
if msg.is_handshake_type(HandshakeType::HelloRequest) && !self.common.is_tls13() {
self.process_hello_req();
return Ok(());
}
self.state.expect
.check_message(&msg)
.map_err(|err| {
self.queue_unexpected_alert();
err
})?;
let new_state = (self.state.handle)(self, msg)?;
self.state = new_state;
Ok(())
}
pub fn process_new_packets(&mut self) -> Result<(), TLSError> {
if let Some(ref err) = self.error {
return Err(err.clone());
}
if self.common.message_deframer.desynced {
return Err(TLSError::CorruptMessage);
}
while let Some(msg) = self.common.message_deframer.frames.pop_front() {
match self.process_msg(msg) {
Ok(_) => {}
Err(err) => {
self.error = Some(err.clone());
return Err(err);
}
}
}
Ok(())
}
pub fn get_peer_certificates(&self) -> Option<Vec<key::Certificate>> {
if self.handshake_data.server_cert_chain.is_empty() {
return None;
}
let mut r = Vec::new();
for cert in &self.handshake_data.server_cert_chain {
r.push(cert.clone());
}
Some(r)
}
pub fn get_alpn_protocol(&self) -> Option<String> {
self.alpn_protocol.clone()
}
pub fn get_protocol_version(&self) -> Option<ProtocolVersion> {
self.common.negotiated_version
}
}
/// A TLS client connection.
///
/// Thin public wrapper around `ClientSessionImpl`. `Session`, `io::Read` and
/// `io::Write` are implemented by delegating to it.
pub struct ClientSession {
    imp: ClientSessionImpl,
}
impl ClientSession {
    /// Starts a client connection to `hostname` using `config`.
    ///
    /// Construction already emits the ClientHello (through
    /// `ClientSessionImpl::new`).
    pub fn new(config: &Arc<ClientConfig>, hostname: &str) -> ClientSession {
        let imp = ClientSessionImpl::new(config, hostname);
        ClientSession { imp: imp }
    }
}
impl Session for ClientSession {
fn read_tls(&mut self, rd: &mut io::Read) -> io::Result<usize> {
self.imp.common.read_tls(rd)
}
fn write_tls(&mut self, wr: &mut io::Write) -> io::Result<usize> {
self.imp.common.write_tls(wr)
}
fn process_new_packets(&mut self) -> Result<(), TLSError> {
self.imp.process_new_packets()
}
fn wants_read(&self) -> bool {
self.imp.wants_read()
}
fn wants_write(&self) -> bool {
self.imp.wants_write()
}
fn is_handshaking(&self) -> bool {
self.imp.is_handshaking()
}
fn set_buffer_limit(&mut self, len: usize) {
self.imp.set_buffer_limit(len)
}
fn send_close_notify(&mut self) {
self.imp.common.send_close_notify()
}
fn get_peer_certificates(&self) -> Option<Vec<key::Certificate>> {
self.imp.get_peer_certificates()
}
fn get_alpn_protocol(&self) -> Option<String> {
self.imp.get_alpn_protocol()
}
fn get_protocol_version(&self) -> Option<ProtocolVersion> {
self.imp.get_protocol_version()
}
}
/// Reads plaintext buffered by the session; delegates to `SessionCommon::read`.
impl io::Read for ClientSession {
    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
        let imp = &mut self.imp;
        imp.common.read(buf)
    }
}
/// Plaintext writes for the application; delegates to the inner `SessionCommon`.
impl io::Write for ClientSession {
    /// Hands `buf` to `SessionCommon::send_some_plaintext` and returns its result.
    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
        let imp = &mut self.imp;
        imp.common.send_some_plaintext(buf)
    }

    /// Calls `SessionCommon::flush_plaintext` and always reports success.
    fn flush(&mut self) -> io::Result<()> {
        let imp = &mut self.imp;
        imp.common.flush_plaintext();
        Ok(())
    }
}