use std::io::{Read, Write};
use picky_asn1_x509::Certificate;
use rustls::{Connection, ProtocolVersion};
use crate::credssp::sspi_cred_ssp::cipher_block_size::get_cipher_block_size;
use crate::{
ConnectionCipher, ConnectionHash, ConnectionInfo, ConnectionKeyExchange, ConnectionProtocol, Error, ErrorKind,
Result, StreamSizes,
};
/// Length of a TLS record header: content type (1 byte) + protocol version (2 bytes)
/// + fragment length (2 bytes).
pub(super) const TLS_PACKET_HEADER_LEN: usize = 1 + 2 + 2 ;
/// Length of the 8-byte field that `TlsConnection::split_tls_traffic` treats as part of the
/// record header. NOTE(review): for TLS 1.2 AES-GCM records this is the explicit nonce, which
/// implementations usually fill with the sequence number; confirm for other suites/TLS 1.3.
const TLS_PACKET_SEQUENCE_NUMBER_LEN: usize = size_of::<u64>();
/// TLS record content type for `application_data` (23).
const TLS_APPLICATION_DATA_CONTENT_TYPE: u8 = 0x17;
pub(super) mod danger {
    //! Certificate-verification hooks that deliberately perform no checks.
    //!
    //! [`NoCertificateVerification`] accepts every server certificate and every handshake
    //! signature, so a connection configured with it does not authenticate its peer. Use it only
    //! where skipping validation is an explicit, intentional choice.

    use rustls::client::danger::{HandshakeSignatureValid, ServerCertVerified, ServerCertVerifier};
    use rustls::{DigitallySignedStruct, Error, SignatureScheme, pki_types};

    /// Signature schemes reported by [`NoCertificateVerification`], in the order they are returned.
    const ADVERTISED_SIGNATURE_SCHEMES: [SignatureScheme; 13] = [
        SignatureScheme::RSA_PKCS1_SHA1,
        SignatureScheme::ECDSA_SHA1_Legacy,
        SignatureScheme::RSA_PKCS1_SHA256,
        SignatureScheme::ECDSA_NISTP256_SHA256,
        SignatureScheme::RSA_PKCS1_SHA384,
        SignatureScheme::ECDSA_NISTP384_SHA384,
        SignatureScheme::RSA_PKCS1_SHA512,
        SignatureScheme::ECDSA_NISTP521_SHA512,
        SignatureScheme::RSA_PSS_SHA256,
        SignatureScheme::RSA_PSS_SHA384,
        SignatureScheme::RSA_PSS_SHA512,
        SignatureScheme::ED25519,
        SignatureScheme::ED448,
    ];

    /// A [`ServerCertVerifier`] that trusts any server certificate and signature unconditionally.
    #[derive(Debug)]
    pub(crate) struct NoCertificateVerification;

    impl ServerCertVerifier for NoCertificateVerification {
        /// Accepts the presented certificate chain without inspecting it.
        fn verify_server_cert(
            &self,
            _end_entity: &pki_types::CertificateDer<'_>,
            _intermediates: &[pki_types::CertificateDer<'_>],
            _server_name: &pki_types::ServerName<'_>,
            _ocsp_response: &[u8],
            _now: pki_types::UnixTime,
        ) -> Result<ServerCertVerified, Error> {
            Ok(ServerCertVerified::assertion())
        }

        /// Accepts any TLS 1.2 handshake signature without verifying it.
        fn verify_tls12_signature(
            &self,
            _message: &[u8],
            _cert: &pki_types::CertificateDer<'_>,
            _dss: &DigitallySignedStruct,
        ) -> Result<HandshakeSignatureValid, Error> {
            Ok(HandshakeSignatureValid::assertion())
        }

        /// Accepts any TLS 1.3 handshake signature without verifying it.
        fn verify_tls13_signature(
            &self,
            _message: &[u8],
            _cert: &pki_types::CertificateDer<'_>,
            _dss: &DigitallySignedStruct,
        ) -> Result<HandshakeSignatureValid, Error> {
            Ok(HandshakeSignatureValid::assertion())
        }

        /// Returns the fixed list in [`ADVERTISED_SIGNATURE_SCHEMES`].
        fn supported_verify_schemes(&self) -> Vec<SignatureScheme> {
            ADVERTISED_SIGNATURE_SCHEMES.to_vec()
        }
    }
}
/// Disjoint mutable views into a buffer holding one TLS record, as produced by
/// `TlsConnection::split_tls_traffic`.
#[derive(Debug)]
struct TlsTrafficParts<'data> {
/// The record header plus the 8 bytes that follow it
/// (`TLS_PACKET_HEADER_LEN + TLS_PACKET_SEQUENCE_NUMBER_LEN` bytes).
header: &'data mut [u8],
/// The remainder of the record's fragment after `header`; `decrypt_tls` copies the decrypted
/// plaintext into the start of this slice.
application_data: &'data mut [u8],
/// Any bytes in the buffer that follow the end of the record.
extra: &'data mut [u8],
}
/// Buffers returned by a successful `TlsConnection::decrypt_tls`, all borrowed from the
/// caller's input buffer.
#[derive(Debug)]
pub(super) struct DecryptionResultBuffers<'data> {
/// The record header plus the 8 bytes that follow it.
pub header: &'data mut [u8],
/// The decrypted plaintext, written over the start of the record's encrypted fragment.
pub decrypted: &'data mut [u8],
/// Bytes after the end of the decrypted record (presumably the start of a following record).
pub extra: &'data mut [u8],
}
/// Outcome of `TlsConnection::decrypt_tls`.
#[derive(Debug)]
pub(super) enum DecryptionResult<'data> {
/// A complete TLS record was decrypted in place.
Success(DecryptionResultBuffers<'data>),
/// The input does not yet contain a complete TLS record; carries the number of bytes the
/// caller is asked to supply before retrying.
IncompleteMessage(usize),
}
/// TLS session backing the encrypted channel. Only a rustls-backed variant exists.
#[derive(Debug)]
pub(super) enum TlsConnection {
/// A rustls connection, acting either as client or as server.
Rustls(Connection),
}
/// Result of locating the first TLS application-data record in an input buffer.
enum FindTlsPacketResult<'data> {
/// The complete record (header and fragment), excluding any trailing bytes.
TlsPacket(&'data mut [u8]),
/// The buffer is too short to hold the record; carries how many more bytes are requested.
Missing(usize),
}
impl TlsConnection {
    /// Encrypts `plain_data` as TLS application data and returns the produced TLS records.
    ///
    /// # Errors
    ///
    /// Fails if rustls reports an I/O error while buffering or serializing records, or if it
    /// stops accepting plaintext before all of `plain_data` has been consumed.
    pub(super) fn encrypt_tls(&mut self, plain_data: &[u8]) -> Result<Vec<u8>> {
        match self {
            TlsConnection::Rustls(tls_connection) => {
                let mut tls_buffer = Vec::new();
                let mut pending_plaintext = plain_data;

                // `Writer::write` may accept only part of its input (e.g. when the rustls send
                // buffer limit is reached). Flush the produced records and retry until the whole
                // input is consumed instead of silently dropping the tail of the plaintext.
                loop {
                    let accepted = tls_connection.writer().write(pending_plaintext)?;
                    pending_plaintext = &pending_plaintext[accepted..];
                    let flushed = Self::drain_tls_output(tls_connection, &mut tls_buffer)?;

                    if pending_plaintext.is_empty() {
                        break;
                    }
                    if accepted == 0 && flushed == 0 {
                        // No progress is possible any more; fail instead of looping forever.
                        return Err(Error::new(
                            ErrorKind::EncryptFailure,
                            "TLS connection does not accept more plaintext data",
                        ));
                    }
                }

                Ok(tls_buffer)
            }
        }
    }

    /// Appends every TLS record that rustls has queued for sending to `output`.
    ///
    /// A single `write_tls` call is not guaranteed to flush the whole outgoing queue, so it is
    /// repeated while the connection reports pending output. Returns the number of bytes appended.
    fn drain_tls_output(connection: &mut Connection, output: &mut Vec<u8>) -> Result<usize> {
        let mut total_written = 0;
        while connection.wants_write() {
            let written = connection.write_tls(output)?;
            if written == 0 {
                // Nothing could be written; stop rather than spin.
                break;
            }
            total_written += written;
        }
        Ok(total_written)
    }

    /// Hands the next chunk of `input` to rustls and advances `input` past the consumed bytes.
    ///
    /// `read_tls` may consume only part of `input`, so callers loop until `input` is empty.
    ///
    /// # Errors
    ///
    /// Fails on an I/O error reported by rustls. It also fails when rustls reads zero bytes from a
    /// non-empty input (it does so after a close_notify was received), because otherwise the
    /// caller's loop would never terminate.
    fn read_tls_chunk(connection: &mut Connection, input: &mut &[u8]) -> Result<()> {
        let bytes_read = connection.read_tls(input)?;
        if bytes_read == 0 && !input.is_empty() {
            return Err(Error::new(
                ErrorKind::InternalError,
                "TLS connection does not accept more incoming data",
            ));
        }
        Ok(())
    }

    /// Validates the record header at the start of `payload` and returns the declared length of
    /// the record's encrypted fragment.
    ///
    /// The first three bytes must be the application-data content type followed by the
    /// record-layer version expected for the negotiated protocol. TLS 1.3 records carry the
    /// legacy TLS 1.2 version in their header (RFC 8446, section 5.1).
    ///
    /// `payload` must hold at least `TLS_PACKET_HEADER_LEN` bytes; both callers check this first.
    fn parse_application_data_header(connection: &Connection, payload: &[u8]) -> Result<usize> {
        debug_assert!(payload.len() >= TLS_PACKET_HEADER_LEN);

        let negotiated_version = connection
            .protocol_version()
            .ok_or_else(|| Error::new(ErrorKind::InternalError, "can not query negotiated TLS version"))?;
        let record_version = if negotiated_version == ProtocolVersion::TLSv1_3 {
            ProtocolVersion::TLSv1_2
        } else {
            negotiated_version
        };
        let [version_high, version_low] = u16::from(record_version).to_be_bytes();
        let expected_prefix = [TLS_APPLICATION_DATA_CONTENT_TYPE, version_high, version_low];

        if payload[..3] != expected_prefix {
            return Err(Error::new(
                ErrorKind::InvalidToken,
                format!(
                    "invalid TLS packet header: expected {:?} but got {:?}",
                    expected_prefix,
                    &payload[..3]
                ),
            ));
        }

        Ok(usize::from(u16::from_be_bytes([payload[3], payload[4]])))
    }

    /// Locates the first complete TLS application-data record in `payload`.
    ///
    /// Returns `Missing(n)` with the number of additional bytes still needed when `payload` holds
    /// only part of the header or only part of the record. Otherwise returns exactly the record's
    /// bytes, without any trailing data.
    ///
    /// # Errors
    ///
    /// Fails if the record header does not match the negotiated protocol (see
    /// `parse_application_data_header`).
    fn find_tls_data_to_decrypt<'data>(
        connection: &Connection,
        payload: &'data mut [u8],
    ) -> Result<FindTlsPacketResult<'data>> {
        if payload.len() < TLS_PACKET_HEADER_LEN {
            // Report only the missing bytes, consistent with the incomplete-record case below.
            return Ok(FindTlsPacketResult::Missing(TLS_PACKET_HEADER_LEN - payload.len()));
        }

        let encrypted_application_data_len = Self::parse_application_data_header(connection, payload)?;
        let tls_packet_len = TLS_PACKET_HEADER_LEN + encrypted_application_data_len;

        if payload.len() < tls_packet_len {
            return Ok(FindTlsPacketResult::Missing(tls_packet_len - payload.len()));
        }

        Ok(FindTlsPacketResult::TlsPacket(&mut payload[..tls_packet_len]))
    }

    /// Splits `payload`, which must start with a complete TLS record, into the record prefix,
    /// the rest of the encrypted fragment, and the bytes after the record.
    ///
    /// # Errors
    ///
    /// Fails with `InvalidToken` in three cases: the buffer is shorter than the record prefix or
    /// than the whole record; the header does not match the negotiated protocol; or the fragment
    /// is too short to contain the 8-byte field that is placed in `header`.
    fn split_tls_traffic<'a>(connection: &Connection, payload: &'a mut [u8]) -> Result<TlsTrafficParts<'a>> {
        const TLS_PACKET_PREFIX_LEN: usize = TLS_PACKET_HEADER_LEN + TLS_PACKET_SEQUENCE_NUMBER_LEN;

        if payload.len() < TLS_PACKET_PREFIX_LEN {
            return Err(Error::new(ErrorKind::InvalidToken, "Input TLS buffer is too short."));
        }

        let encrypted_application_data_len = Self::parse_application_data_header(connection, payload)?;

        // Guard the subtraction below: a shorter fragment would underflow and then panic in
        // `split_at_mut`.
        if encrypted_application_data_len < TLS_PACKET_SEQUENCE_NUMBER_LEN {
            return Err(Error::new(ErrorKind::InvalidToken, "TLS record is too short."));
        }
        if payload.len() < TLS_PACKET_HEADER_LEN + encrypted_application_data_len {
            return Err(Error::new(ErrorKind::InvalidToken, "Input TLS buffer is too short."));
        }

        let (header, rest) = payload.split_at_mut(TLS_PACKET_PREFIX_LEN);
        let (application_data, extra) =
            rest.split_at_mut(encrypted_application_data_len - TLS_PACKET_SEQUENCE_NUMBER_LEN);

        Ok(TlsTrafficParts {
            header,
            application_data,
            extra,
        })
    }

    /// Decrypts the first TLS application-data record in `payload` in place.
    ///
    /// On success, the plaintext is copied over the start of the record's encrypted fragment and
    /// returned together with the record prefix and any bytes following the record. If `payload`
    /// does not yet contain a whole record, `IncompleteMessage` is returned instead.
    ///
    /// # Errors
    ///
    /// Fails if the header is invalid, if rustls rejects the record (`DecryptFailure`), or if
    /// the plaintext would not fit into the encrypted fragment.
    pub(super) fn decrypt_tls<'a>(&mut self, payload: &'a mut [u8]) -> Result<DecryptionResult<'a>> {
        match self {
            TlsConnection::Rustls(tls_connection) => {
                let mut tls_packet: &[u8] = match Self::find_tls_data_to_decrypt(tls_connection, payload)? {
                    FindTlsPacketResult::TlsPacket(data) => &*data,
                    FindTlsPacketResult::Missing(needed_bytes_amount) => {
                        return Ok(DecryptionResult::IncompleteMessage(needed_bytes_amount));
                    }
                };

                let mut plain_data = Vec::with_capacity(tls_packet.len());
                // rustls may consume the record over several `read_tls` calls; process the
                // packets and collect the available plaintext after each chunk.
                while !tls_packet.is_empty() {
                    Self::read_tls_chunk(tls_connection, &mut tls_packet)?;
                    let tls_state = tls_connection
                        .process_new_packets()
                        .map_err(|err| Error::new(ErrorKind::DecryptFailure, err.to_string()))?;

                    let decrypted_so_far = plain_data.len();
                    plain_data.resize(decrypted_so_far + tls_state.plaintext_bytes_to_read(), 0);
                    let bytes_read = tls_connection.reader().read(&mut plain_data[decrypted_so_far..])?;
                    // Never keep zero filler that the reader did not overwrite.
                    plain_data.truncate(decrypted_so_far + bytes_read);
                }

                let TlsTrafficParts {
                    header,
                    application_data,
                    extra,
                } = Self::split_tls_traffic(tls_connection, payload)?;

                if application_data.len() < plain_data.len() {
                    return Err(Error::new(
                        ErrorKind::DecryptFailure,
                        "Decrypted data can not be larger then encrypted one.",
                    ));
                }

                let decrypted = &mut application_data[..plain_data.len()];
                decrypted.copy_from_slice(&plain_data);

                Ok(DecryptionResult::Success(DecryptionResultBuffers {
                    header,
                    decrypted,
                    extra,
                }))
            }
        }
    }

    /// Returns the raw (DER) bytes of the certificates presented by the peer.
    ///
    /// # Errors
    ///
    /// Fails with `CertificateUnknown` when the peer has not presented any certificates.
    pub(super) fn peer_certificates(&self) -> Result<Vec<&[u8]>> {
        match self {
            TlsConnection::Rustls(tls_connection) => tls_connection
                .peer_certificates()
                .map(|certificates| certificates.iter().map(|cert| cert.as_ref()).collect())
                .ok_or_else(|| Error::new(ErrorKind::CertificateUnknown, "The server certificate is not present")),
        }
    }

    /// Feeds `input_token` (possibly empty) into the TLS state machine and returns the outgoing
    /// TLS data it produced, together with that data's length.
    ///
    /// # Errors
    ///
    /// Fails if rustls cannot read the token or rejects its contents (mapped to `EncryptFailure`),
    /// or if serializing the outgoing records fails.
    pub(super) fn process_tls_packets(&mut self, mut input_token: &[u8]) -> Result<(usize, Vec<u8>)> {
        match self {
            TlsConnection::Rustls(tls_connection) => {
                // `read_tls` may consume only a bounded chunk per call, so alternate reading and
                // processing until the whole token has been handed over; a single call would
                // silently drop the rest of a large token. An empty token still results in one
                // `process_new_packets` call.
                loop {
                    if !input_token.is_empty() {
                        Self::read_tls_chunk(tls_connection, &mut input_token)?;
                    }
                    let _io_status = tls_connection
                        .process_new_packets()
                        .map_err(|err| Error::new(ErrorKind::EncryptFailure, err.to_string()))?;
                    if input_token.is_empty() {
                        break;
                    }
                }

                let mut tls_buffer = Vec::new();
                let bytes_written = Self::drain_tls_output(tls_connection, &mut tls_buffer)?;

                Ok((bytes_written, tls_buffer))
            }
        }
    }

    /// Reports the stream sizes for the negotiated connection.
    ///
    /// The header size is the TLS record header length, and the block size comes from the
    /// negotiated cipher suite. NOTE(review): `trailer` (0x2c) and `max_message` (0x4000) are
    /// fixed values. 0x4000 equals the TLS maximum plaintext fragment size (2^14); the origin of
    /// 0x2c is not visible here.
    ///
    /// # Errors
    ///
    /// Fails if no cipher suite has been negotiated yet or its block size is unknown.
    pub(super) fn stream_sizes(&self) -> Result<StreamSizes> {
        match self {
            TlsConnection::Rustls(tls_connection) => {
                let connection_cipher = tls_connection
                    .negotiated_cipher_suite()
                    .ok_or_else(|| Error::new(ErrorKind::InternalError, "connection cipher is not negotiated"))?;
                let suite = match connection_cipher {
                    rustls::SupportedCipherSuite::Tls12(cipher_suite) => cipher_suite.common.suite,
                    rustls::SupportedCipherSuite::Tls13(cipher_suite) => cipher_suite.common.suite,
                };

                Ok(StreamSizes {
                    header: TLS_PACKET_HEADER_LEN as u32,
                    trailer: 0x2c,
                    max_message: 0x4000,
                    buffers: 4,
                    block_size: get_cipher_block_size(suite)?,
                })
            }
        }
    }

    /// Describes the negotiated connection: protocol version and role, cipher and hash, and
    /// key-exchange strength.
    ///
    /// Only AES-128-GCM and AES-256-GCM suites are recognized. The hash algorithm is always
    /// reported as `CalgSha` and the key exchange as `CalgRsaKeyx`. All strengths are in bits.
    ///
    /// # Errors
    ///
    /// Fails if the protocol version or cipher suite is not available or not supported, or if
    /// the peer's public key cannot be extracted.
    pub(super) fn connection_info(&self) -> Result<ConnectionInfo> {
        match self {
            TlsConnection::Rustls(tls_connection) => {
                let protocol_version = tls_connection.protocol_version().ok_or_else(|| {
                    Error::new(
                        ErrorKind::InvalidParameter,
                        "Can not acquire connection protocol version",
                    )
                })?;

                let protocol = match tls_connection {
                    Connection::Client(_) => match protocol_version {
                        ProtocolVersion::SSLv2 => ConnectionProtocol::SpProtSsl2Client,
                        ProtocolVersion::TLSv1_0 => ConnectionProtocol::SpProtTls1Client,
                        ProtocolVersion::TLSv1_1 => ConnectionProtocol::SpProtTls1_1Client,
                        ProtocolVersion::TLSv1_2 => ConnectionProtocol::SpProtTls1_2Client,
                        ProtocolVersion::TLSv1_3 => ConnectionProtocol::SpProtTls1_3Client,
                        version => {
                            return Err(Error::new(
                                ErrorKind::InternalError,
                                format!("Unsupported connection protocol was used: {version:?}"),
                            ));
                        }
                    },
                    Connection::Server(_) => match protocol_version {
                        ProtocolVersion::SSLv2 => ConnectionProtocol::SpProtSsl2Server,
                        ProtocolVersion::TLSv1_0 => ConnectionProtocol::SpProtTls1Server,
                        ProtocolVersion::TLSv1_1 => ConnectionProtocol::SpProtTls1_1Server,
                        ProtocolVersion::TLSv1_2 => ConnectionProtocol::SpProtTls1_2Server,
                        ProtocolVersion::TLSv1_3 => ConnectionProtocol::SpProtTls1_3Server,
                        version => {
                            return Err(Error::new(
                                ErrorKind::InternalError,
                                format!("Unsupported connection protocol was used: {version:?}"),
                            ));
                        }
                    },
                };

                let connection_cipher = tls_connection
                    .negotiated_cipher_suite()
                    .ok_or_else(|| Error::new(ErrorKind::InternalError, "Connection cipher is not negotiated"))?;
                let common = match connection_cipher {
                    rustls::SupportedCipherSuite::Tls12(cipher_suite) => &cipher_suite.common,
                    rustls::SupportedCipherSuite::Tls13(cipher_suite) => &cipher_suite.common,
                };

                let (cipher, cipher_strength) = match common.suite.as_str() {
                    Some(name) if name.contains("AES_128_GCM") => (ConnectionCipher::CalgAes128, 128),
                    Some(name) if name.contains("AES_256_GCM") => (ConnectionCipher::CalgAes256, 256),
                    _ => {
                        return Err(Error::new(
                            ErrorKind::UnsupportedFunction,
                            format!("alg_id for {:?} is not known", common.suite),
                        ));
                    }
                };

                // The hash provider reports its digest size in bytes, while the other strengths
                // here are in bits (128/256 for the cipher, `len * 8` for the key exchange), so
                // convert for consistency.
                let hash_strength = (common.hash_provider.output_len() * 8).try_into()?;

                Ok(ConnectionInfo {
                    protocol,
                    cipher,
                    cipher_strength,
                    hash: ConnectionHash::CalgSha,
                    hash_strength,
                    key_exchange: ConnectionKeyExchange::CalgRsaKeyx,
                    // NOTE(review): this is the bit length of the DER-encoded public key value
                    // returned by `raw_peer_public_key`, which for RSA appears to include more
                    // than the modulus; confirm that this is the intended measure.
                    exchange_strength: (self.raw_peer_public_key()?.len() * 8).try_into()?,
                })
            }
        }
    }

    /// Extracts the subject public key from the first peer certificate and returns it
    /// DER-encoded.
    ///
    /// For RSA keys the inner RSA public key value is encoded; EC, Ed and ML-DSA keys are encoded
    /// as stored in the certificate.
    ///
    /// # Errors
    ///
    /// Fails if no peer certificate is available or if the certificate cannot be parsed or the
    /// key re-encoded.
    pub(super) fn raw_peer_public_key(&self) -> Result<Vec<u8>> {
        let certificates = self.peer_certificates()?;
        let peer_certificate = certificates
            .first()
            .ok_or_else(|| Error::new(ErrorKind::CertificateUnknown, "cannot acquire server certificate"))?;

        let peer_certificate: Certificate = picky_asn1_der::from_bytes(peer_certificate)?;

        let raw_public_key = match peer_certificate
            .tbs_certificate
            .subject_public_key_info
            .subject_public_key
        {
            picky_asn1_x509::PublicKey::Rsa(rsa_pk) => picky_asn1_der::to_vec(&rsa_pk.0)?,
            picky_asn1_x509::PublicKey::Ec(ec) => picky_asn1_der::to_vec(&ec)?,
            picky_asn1_x509::PublicKey::Ed(ed) => picky_asn1_der::to_vec(&ed)?,
            picky_asn1_x509::PublicKey::Mldsa(mldsa) => picky_asn1_der::to_vec(&mldsa)?,
        };

        Ok(raw_public_key)
    }
}