#![forbid(unsafe_code)]
#![warn(missing_docs)]
use std::fmt;
use std::future::Future;
use std::io;
use std::pin::Pin;
use tokio::io::{AsyncRead, AsyncWrite};
pub mod os_rng;
pub mod keylog;
pub mod alert;
pub mod config;
pub mod stream_info;
pub use alert::AlertDescription;
pub use keylog::{KeyLog, KeyLogPolicy};
pub use os_rng::OsRng;
pub use stream_info::connection_info_from;
/// A TLS protocol version understood by this crate.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum TlsVersion {
    /// TLS 1.2 (RFC 5246).
    Tls12,
    /// TLS 1.3 (RFC 8446).
    Tls13,
}

impl TlsVersion {
    /// Every supported version, oldest first.
    pub const ALL: &'static [TlsVersion] = &[Self::Tls12, Self::Tls13];
}

impl fmt::Display for TlsVersion {
    /// Writes the human-readable version name, e.g. `TLS 1.3`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let label = match self {
            Self::Tls12 => "TLS 1.2",
            Self::Tls13 => "TLS 1.3",
        };
        f.write_str(label)
    }
}
impl std::str::FromStr for TlsVersion {
    type Err = TlsError;

    /// Parses a TLS version name.
    ///
    /// Surrounding whitespace is ignored and matching is ASCII case-insensitive.
    /// Accepted spellings, for `x` in {2, 3}: `TLS 1.x`, `TLS1.x`, `TLSv1.x`, and the
    /// bare `1.x`. This is a superset of the exact spellings accepted previously
    /// (`TLS 1.x`, `tls1.x`, `TLSv1.x`, `1.x`).
    ///
    /// # Errors
    ///
    /// Returns [`TlsError::Other`] for anything else, including TLS 1.0/1.1 and the
    /// empty string. The message echoes the original, untrimmed input.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        // One canonical spelling per style; case is normalised by the comparison.
        const SPELLINGS: &[(&str, TlsVersion)] = &[
            ("TLS 1.2", TlsVersion::Tls12),
            ("TLS1.2", TlsVersion::Tls12),
            ("TLSv1.2", TlsVersion::Tls12),
            ("1.2", TlsVersion::Tls12),
            ("TLS 1.3", TlsVersion::Tls13),
            ("TLS1.3", TlsVersion::Tls13),
            ("TLSv1.3", TlsVersion::Tls13),
            ("1.3", TlsVersion::Tls13),
        ];
        let wanted = s.trim();
        SPELLINGS
            .iter()
            .find(|(name, _)| name.eq_ignore_ascii_case(wanted))
            .map(|&(_, version)| version)
            .ok_or_else(|| TlsError::Other(format!("unknown TLS version: {s}")))
    }
}
/// A TLS cipher suite.
///
/// Covers the three TLS 1.3 suites and six TLS 1.2 ECDHE/AEAD suites;
/// [`CipherSuite::Unknown`] is a placeholder for anything outside that set.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[non_exhaustive]
pub enum CipherSuite {
    /// `TLS_AES_128_GCM_SHA256` (TLS 1.3, IANA 0x13,0x01).
    Tls13Aes128GcmSha256,
    /// `TLS_AES_256_GCM_SHA384` (TLS 1.3, IANA 0x13,0x02).
    Tls13Aes256GcmSha384,
    /// `TLS_CHACHA20_POLY1305_SHA256` (TLS 1.3, IANA 0x13,0x03).
    Tls13Chacha20Poly1305Sha256,
    /// `TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256` (TLS 1.2, IANA 0xC0,0x2B).
    Tls12EcdheEcdsaAes128GcmSha256,
    /// `TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384` (TLS 1.2, IANA 0xC0,0x2C).
    Tls12EcdheEcdsaAes256GcmSha384,
    /// `TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256` (TLS 1.2, IANA 0xC0,0x2F).
    Tls12EcdheRsaAes128GcmSha256,
    /// `TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384` (TLS 1.2, IANA 0xC0,0x30).
    Tls12EcdheRsaAes256GcmSha384,
    /// `TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256` (TLS 1.2, IANA 0xCC,0xA9).
    Tls12EcdheEcdsaChacha20Poly1305Sha256,
    /// `TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256` (TLS 1.2, IANA 0xCC,0xA8).
    Tls12EcdheRsaChacha20Poly1305Sha256,
    /// Placeholder for a suite this crate does not model.
    Unknown,
}

impl CipherSuite {
    /// Every concrete suite, TLS 1.3 first. [`CipherSuite::Unknown`] is deliberately
    /// absent.
    pub const ALL: &'static [CipherSuite] = &[
        Self::Tls13Aes128GcmSha256,
        Self::Tls13Aes256GcmSha384,
        Self::Tls13Chacha20Poly1305Sha256,
        Self::Tls12EcdheEcdsaAes128GcmSha256,
        Self::Tls12EcdheEcdsaAes256GcmSha384,
        Self::Tls12EcdheRsaAes128GcmSha256,
        Self::Tls12EcdheRsaAes256GcmSha384,
        Self::Tls12EcdheEcdsaChacha20Poly1305Sha256,
        Self::Tls12EcdheRsaChacha20Poly1305Sha256,
    ];

    /// Returns the two-byte IANA code point of this suite.
    ///
    /// [`CipherSuite::Unknown`] reports the sentinel `[0xFF, 0xFF]`, which
    /// [`CipherSuite::from_iana`] does not map back.
    pub fn iana_value(&self) -> [u8; 2] {
        match self {
            Self::Tls13Aes128GcmSha256 => [0x13, 0x01],
            Self::Tls13Aes256GcmSha384 => [0x13, 0x02],
            Self::Tls13Chacha20Poly1305Sha256 => [0x13, 0x03],
            Self::Tls12EcdheEcdsaAes128GcmSha256 => [0xC0, 0x2B],
            Self::Tls12EcdheEcdsaAes256GcmSha384 => [0xC0, 0x2C],
            Self::Tls12EcdheRsaAes128GcmSha256 => [0xC0, 0x2F],
            Self::Tls12EcdheRsaAes256GcmSha384 => [0xC0, 0x30],
            Self::Tls12EcdheEcdsaChacha20Poly1305Sha256 => [0xCC, 0xA9],
            Self::Tls12EcdheRsaChacha20Poly1305Sha256 => [0xCC, 0xA8],
            Self::Unknown => [0xFF, 0xFF],
        }
    }

    /// Finds the suite in [`CipherSuite::ALL`] whose IANA code point is `bytes`.
    ///
    /// Returns `None` for any other code point, including the `[0xFF, 0xFF]`
    /// sentinel of [`CipherSuite::Unknown`].
    pub fn from_iana(bytes: [u8; 2]) -> Option<Self> {
        // Deriving the reverse mapping from `iana_value` keeps the two directions
        // from drifting apart.
        Self::ALL
            .iter()
            .copied()
            .find(|suite| suite.iana_value() == bytes)
    }

    /// `true` for the TLS 1.3 suites.
    pub fn is_tls13(&self) -> bool {
        // All TLS 1.3 suites live under IANA prefix 0x13 (RFC 8446, appendix B.4);
        // the `Unknown` sentinel 0xFFFF never matches.
        self.iana_value()[0] == 0x13
    }

    /// `true` for the TLS 1.2 ECDHE suites.
    pub fn is_tls12(&self) -> bool {
        // Every modelled suite that is neither TLS 1.3 nor the placeholder is TLS 1.2.
        !(self.is_tls13() || self.is_unknown())
    }

    /// `true` only for [`CipherSuite::Unknown`].
    pub fn is_unknown(&self) -> bool {
        *self == Self::Unknown
    }
}

impl fmt::Display for CipherSuite {
    /// Writes the IANA name of the suite, or `UNKNOWN` for the placeholder.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(match self {
            Self::Tls13Aes128GcmSha256 => "TLS_AES_128_GCM_SHA256",
            Self::Tls13Aes256GcmSha384 => "TLS_AES_256_GCM_SHA384",
            Self::Tls13Chacha20Poly1305Sha256 => "TLS_CHACHA20_POLY1305_SHA256",
            Self::Tls12EcdheEcdsaAes128GcmSha256 => "TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256",
            Self::Tls12EcdheEcdsaAes256GcmSha384 => "TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384",
            Self::Tls12EcdheRsaAes128GcmSha256 => "TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256",
            Self::Tls12EcdheRsaAes256GcmSha384 => "TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384",
            Self::Tls12EcdheEcdsaChacha20Poly1305Sha256 => {
                "TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256"
            }
            Self::Tls12EcdheRsaChacha20Poly1305Sha256 => {
                "TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256"
            }
            Self::Unknown => "UNKNOWN",
        })
    }
}
impl std::str::FromStr for CipherSuite {
    type Err = TlsError;

    /// Parses an IANA cipher-suite name, using the same spelling `Display` produces.
    ///
    /// Matching is exact and case-sensitive. The literal `UNKNOWN` yields
    /// [`CipherSuite::Unknown`].
    ///
    /// # Errors
    ///
    /// Returns [`TlsError::Other`] when `s` is not one of the recognised names.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        const BY_NAME: &[(&str, CipherSuite)] = &[
            ("TLS_AES_128_GCM_SHA256", CipherSuite::Tls13Aes128GcmSha256),
            ("TLS_AES_256_GCM_SHA384", CipherSuite::Tls13Aes256GcmSha384),
            ("TLS_CHACHA20_POLY1305_SHA256", CipherSuite::Tls13Chacha20Poly1305Sha256),
            (
                "TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256",
                CipherSuite::Tls12EcdheEcdsaAes128GcmSha256,
            ),
            (
                "TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384",
                CipherSuite::Tls12EcdheEcdsaAes256GcmSha384,
            ),
            (
                "TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256",
                CipherSuite::Tls12EcdheRsaAes128GcmSha256,
            ),
            (
                "TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384",
                CipherSuite::Tls12EcdheRsaAes256GcmSha384,
            ),
            (
                "TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256",
                CipherSuite::Tls12EcdheEcdsaChacha20Poly1305Sha256,
            ),
            (
                "TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256",
                CipherSuite::Tls12EcdheRsaChacha20Poly1305Sha256,
            ),
            ("UNKNOWN", CipherSuite::Unknown),
        ];
        BY_NAME
            .iter()
            .find(|(name, _)| *name == s)
            .map(|&(_, suite)| suite)
            .ok_or_else(|| TlsError::Other(format!("unknown cipher suite: {s}")))
    }
}
/// Metadata describing a TLS connection.
///
/// All fields are public. [`ConnectionInfo::new`] (and `Default`) leaves every
/// optional field `None` and the certificate list empty; the `with_*` methods fill
/// fields in builder style.
#[derive(Debug, Clone, Default)]
pub struct ConnectionInfo {
    /// TLS protocol version, when known.
    pub version: Option<TlsVersion>,
    /// Cipher suite, when known.
    pub cipher_suite: Option<CipherSuite>,
    /// ALPN protocol identifier as raw bytes, when one was recorded.
    pub alpn_protocol: Option<Vec<u8>>,
    /// Server Name Indication host name, when one was recorded.
    pub sni: Option<String>,
    /// Peer certificates as raw byte blobs (presumably DER — confirm with producers);
    /// empty when none were recorded.
    pub peer_certificates: Vec<Vec<u8>>,
}

impl ConnectionInfo {
    /// Creates an empty record with no version, suite, ALPN or SNI, and no certificates.
    pub fn new() -> Self {
        Self::default()
    }

    /// Returns `self` with [`version`](Self::version) set.
    pub fn with_version(self, version: TlsVersion) -> Self {
        Self { version: Some(version), ..self }
    }

    /// Returns `self` with [`cipher_suite`](Self::cipher_suite) set.
    pub fn with_cipher_suite(self, suite: CipherSuite) -> Self {
        Self { cipher_suite: Some(suite), ..self }
    }

    /// Returns `self` with [`alpn_protocol`](Self::alpn_protocol) set to `proto`.
    pub fn with_alpn_protocol(self, proto: Vec<u8>) -> Self {
        Self { alpn_protocol: Some(proto), ..self }
    }

    /// Returns `self` with [`sni`](Self::sni) set.
    pub fn with_sni(self, sni: String) -> Self {
        Self { sni: Some(sni), ..self }
    }

    /// Returns `self` with [`peer_certificates`](Self::peer_certificates) replaced by `certs`.
    pub fn with_peer_certificates(self, certs: Vec<Vec<u8>>) -> Self {
        Self { peer_certificates: certs, ..self }
    }

    /// Returns the ALPN protocol as text if it is set and valid UTF-8, otherwise `None`.
    pub fn alpn_protocol_str(&self) -> Option<&str> {
        let raw = self.alpn_protocol.as_deref()?;
        std::str::from_utf8(raw).ok()
    }
}
#[derive(Debug, Default)]
pub struct ConnectionInfoBuilder {
inner: ConnectionInfo,
}
impl ConnectionInfoBuilder {
pub fn new() -> Self {
Self::default()
}
pub fn version(mut self, version: TlsVersion) -> Self {
self.inner.version = Some(version);
self
}
pub fn cipher_suite(mut self, suite: CipherSuite) -> Self {
self.inner.cipher_suite = Some(suite);
self
}
pub fn alpn_protocol(mut self, proto: Vec<u8>) -> Self {
self.inner.alpn_protocol = Some(proto);
self
}
pub fn sni(mut self, sni: String) -> Self {
self.inner.sni = Some(sni);
self
}
pub fn peer_certificates(mut self, certs: Vec<Vec<u8>>) -> Self {
self.inner.peer_certificates = certs;
self
}
pub fn build(self) -> ConnectionInfo {
self.inner
}
}
/// Error type for TLS operations in this crate.
///
/// Most variants carry a human-readable detail string that `Display` appends after
/// a category prefix. The enum is `#[non_exhaustive]`, so code outside this crate
/// must include a wildcard arm when matching on it.
#[derive(Debug, Clone, PartialEq)]
#[non_exhaustive]
pub enum TlsError {
    /// Underlying I/O failure. Only the [`io::ErrorKind`] is stored; `io::Error`
    /// itself is neither `Clone` nor `PartialEq`, which is presumably why.
    Io(io::ErrorKind),
    /// Handshake failure, with a detail message.
    Handshake(String),
    /// A certificate was rejected as bad, with a detail message.
    BadCert(String),
    /// Invalid configuration, with a detail message.
    InvalidConfig(String),
    /// A certificate was reported as revoked, with a detail message.
    CertRevoked(String),
    /// A certificate was found invalid (e.g. produced for rustls certificate
    /// errors), with a detail message.
    CertInvalid(String),
    /// The peer was incompatible or misbehaved at the protocol level.
    ProtocolViolation(String),
    /// A TLS alert was received; carries the alert description.
    AlertReceived(AlertDescription),
    /// Any error not covered by the other variants, e.g. parse failures of
    /// version or cipher-suite names.
    Other(String),
}
impl TlsError {
    /// Returns `true` for [`TlsError::Handshake`].
    ///
    /// Received alerts are reported separately; see [`TlsError::is_alert`].
    pub fn is_handshake(&self) -> bool {
        matches!(self, TlsError::Handshake(_))
    }

    /// Returns `true` for [`TlsError::Io`].
    pub fn is_io(&self) -> bool {
        matches!(self, TlsError::Io(_))
    }

    /// Returns `true` for any certificate problem: [`TlsError::BadCert`],
    /// [`TlsError::CertRevoked`] or [`TlsError::CertInvalid`].
    pub fn is_cert(&self) -> bool {
        matches!(
            self,
            TlsError::BadCert(_) | TlsError::CertRevoked(_) | TlsError::CertInvalid(_)
        )
    }

    /// Returns `true` for [`TlsError::InvalidConfig`].
    pub fn is_config(&self) -> bool {
        matches!(self, TlsError::InvalidConfig(_))
    }

    /// Returns `true` for [`TlsError::ProtocolViolation`].
    pub fn is_protocol_violation(&self) -> bool {
        matches!(self, TlsError::ProtocolViolation(_))
    }

    /// Returns `true` for [`TlsError::AlertReceived`].
    ///
    /// This completes the predicate set. Without it, callers outside the crate
    /// must match on the `#[non_exhaustive]` enum directly to detect alerts.
    pub fn is_alert(&self) -> bool {
        matches!(self, TlsError::AlertReceived(_))
    }
}
impl fmt::Display for TlsError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
TlsError::Io(k) => write!(f, "I/O error: {k:?}"),
TlsError::Handshake(s) => write!(f, "handshake error: {s}"),
TlsError::BadCert(s) => write!(f, "bad certificate: {s}"),
TlsError::InvalidConfig(s) => write!(f, "invalid config: {s}"),
TlsError::CertRevoked(s) => write!(f, "certificate revoked: {s}"),
TlsError::CertInvalid(s) => write!(f, "invalid certificate: {s}"),
TlsError::ProtocolViolation(s) => write!(f, "protocol violation: {s}"),
TlsError::AlertReceived(d) => write!(f, "TLS alert received: {d}"),
TlsError::Other(s) => write!(f, "TLS error: {s}"),
}
}
}
impl std::error::Error for TlsError {}
impl From<io::Error> for TlsError {
fn from(e: io::Error) -> Self {
TlsError::Io(e.kind())
}
}
impl From<TlsError> for io::Error {
fn from(e: TlsError) -> Self {
match e {
TlsError::Io(kind) => io::Error::new(kind, "TLS I/O error"),
TlsError::Handshake(s) => io::Error::new(io::ErrorKind::ConnectionAborted, s),
TlsError::BadCert(s) => io::Error::new(io::ErrorKind::InvalidData, s),
TlsError::InvalidConfig(s) => io::Error::new(io::ErrorKind::InvalidInput, s),
TlsError::CertRevoked(s) => io::Error::new(io::ErrorKind::PermissionDenied, s),
TlsError::CertInvalid(s) => io::Error::new(io::ErrorKind::InvalidData, s),
TlsError::ProtocolViolation(s) => io::Error::new(io::ErrorKind::InvalidData, s),
TlsError::AlertReceived(d) => {
io::Error::new(io::ErrorKind::ConnectionAborted, format!("TLS alert: {d}"))
}
TlsError::Other(s) => io::Error::other(s),
}
}
}
impl From<rustls::Error> for TlsError {
fn from(e: rustls::Error) -> Self {
match &e {
rustls::Error::NoCertificatesPresented => {
TlsError::CertInvalid("no certificates presented".to_string())
}
rustls::Error::UnsupportedNameType => {
TlsError::CertInvalid("unsupported name type".to_string())
}
rustls::Error::InvalidCertificate(reason) => {
TlsError::CertInvalid(format!("{reason:?}"))
}
rustls::Error::PeerIncompatible(reason) => {
TlsError::ProtocolViolation(format!("{reason:?}"))
}
rustls::Error::PeerMisbehaved(reason) => {
TlsError::ProtocolViolation(format!("{reason:?}"))
}
rustls::Error::AlertReceived(alert) => TlsError::Handshake(format!("alert: {alert:?}")),
rustls::Error::BadMaxFragmentSize => {
TlsError::InvalidConfig("bad max fragment size".to_string())
}
rustls::Error::General(s) => TlsError::Other(s.clone()),
_ => TlsError::Other(e.to_string()),
}
}
}
/// Object-safe bundle of the bounds a transport needs to be wrapped in a
/// [`TlsStream`]: [`AsyncRead`] + [`AsyncWrite`] + `Send` + `Sync` + `Unpin`.
///
/// Implemented automatically for every type satisfying those bounds.
pub trait TlsStreamTrait: AsyncRead + AsyncWrite + Send + Sync + Unpin {}

impl<T> TlsStreamTrait for T where T: AsyncRead + AsyncWrite + Send + Sync + Unpin {}

/// Boxed, type-erased async stream accepted and returned by [`TlsConnector`] and
/// [`TlsAcceptor`].
pub type TlsStream = Box<dyn TlsStreamTrait>;
/// Connecting-side TLS setup over a type-erased [`TlsStream`].
///
/// Implementors must be `Send + Sync + 'static`.
pub trait TlsConnector: Send + Sync + 'static {
    /// Wraps `stream` in TLS for the peer named by `server_name`.
    ///
    /// Returns a boxed `Send` future that borrows `self`. The future resolves to the
    /// TLS-wrapped stream, or to a [`TlsError`] on failure. Which variants are
    /// produced is up to the implementation.
    fn connect(
        &self,
        stream: TlsStream,
        server_name: rustls::pki_types::ServerName<'static>,
    ) -> Pin<Box<dyn Future<Output = Result<TlsStream, TlsError>> + Send + '_>>;
}
/// Accepting-side (presumably server) counterpart of [`TlsConnector`], over a
/// type-erased [`TlsStream`].
///
/// Implementors must be `Send + Sync + 'static`.
pub trait TlsAcceptor: Send + Sync + 'static {
    /// Wraps an incoming `stream` in TLS.
    ///
    /// Returns a boxed `Send` future that borrows `self`. The future resolves to the
    /// TLS-wrapped stream, or to a [`TlsError`] on failure.
    fn accept(
        &self,
        stream: TlsStream,
    ) -> Pin<Box<dyn Future<Output = Result<TlsStream, TlsError>> + Send + '_>>;
}
/// Optional access to the [`ConnectionInfo`] of a stream.
///
/// The provided default returns `None`. Implementors override it only when they
/// can report connection details.
pub trait TlsStreamInfo {
    /// Returns the connection metadata, or `None` if this stream exposes none.
    fn connection_info(&self) -> Option<&ConnectionInfo> {
        None
    }
}
/// Boxed, `Send` future resolving to `Result<T, TlsError>` and borrowing for `'a`.
///
/// Return type of the generic connector/acceptor traits. Only compiled with the
/// `generic-transport` feature.
#[cfg(feature = "generic-transport")]
pub type GenericTlsFuture<'a, T> = Pin<Box<dyn Future<Output = Result<T, TlsError>> + Send + 'a>>;
/// Generic counterpart of [`TlsConnector`] that keeps the concrete transport type
/// `S` instead of boxing it. Only compiled with the `generic-transport` feature.
#[cfg(feature = "generic-transport")]
pub trait GenericTlsConnector: Send + Sync + 'static {
    /// TLS stream produced over a transport `S`. It must itself be an async
    /// read/write stream and implement [`TlsStreamInfo`].
    type Stream<S>: AsyncRead + AsyncWrite + Unpin + Send + TlsStreamInfo
    where
        S: AsyncRead + AsyncWrite + Unpin + Send + 'static;
    /// Wraps `stream` in TLS for the peer named by `server_name`.
    ///
    /// The returned future resolves to `Self::Stream<S>` or a [`TlsError`].
    fn connect<S>(
        &self,
        stream: S,
        server_name: rustls::pki_types::ServerName<'static>,
    ) -> GenericTlsFuture<'_, Self::Stream<S>>
    where
        S: AsyncRead + AsyncWrite + Unpin + Send + 'static;
}
/// Generic counterpart of [`TlsAcceptor`] that keeps the concrete transport type
/// `S` instead of boxing it. Only compiled with the `generic-transport` feature.
#[cfg(feature = "generic-transport")]
pub trait GenericTlsAcceptor: Send + Sync + 'static {
    /// TLS stream produced over a transport `S`. It must itself be an async
    /// read/write stream and implement [`TlsStreamInfo`].
    type Stream<S>: AsyncRead + AsyncWrite + Unpin + Send + TlsStreamInfo
    where
        S: AsyncRead + AsyncWrite + Unpin + Send + 'static;
    /// Wraps an incoming `stream` in TLS.
    ///
    /// The returned future resolves to `Self::Stream<S>` or a [`TlsError`].
    fn accept<S>(&self, stream: S) -> GenericTlsFuture<'_, Self::Stream<S>>
    where
        S: AsyncRead + AsyncWrite + Unpin + Send + 'static;
}
#[cfg(test)]
mod tests {
    // Unit tests for the value types and error conversions defined in this module.
    use super::*;

    #[test]
    fn tls_version_display_roundtrip() {
        assert_eq!(TlsVersion::Tls12.to_string(), "TLS 1.2");
        assert_eq!(TlsVersion::Tls13.to_string(), "TLS 1.3");
        // Every advertised version must survive a Display -> FromStr round trip.
        for &version in TlsVersion::ALL {
            let text = version.to_string();
            let parsed: TlsVersion = text.parse().expect("Display output should parse");
            assert_eq!(parsed, version);
        }
    }

    #[test]
    fn tls_version_parse_variants() {
        assert_eq!("tls1.3".parse::<TlsVersion>().ok(), Some(TlsVersion::Tls13));
        assert_eq!("TLSv1.2".parse::<TlsVersion>().ok(), Some(TlsVersion::Tls12));
        assert_eq!("1.3".parse::<TlsVersion>().ok(), Some(TlsVersion::Tls13));
        assert_eq!("1.2".parse::<TlsVersion>().ok(), Some(TlsVersion::Tls12));
        // Unsupported or empty input is an error rather than a default.
        assert!("TLS 1.0".parse::<TlsVersion>().is_err());
        assert!("TLS 1.1".parse::<TlsVersion>().is_err());
        assert!("".parse::<TlsVersion>().is_err());
    }

    #[test]
    fn cipher_suite_display_roundtrip() {
        // Iterating `ALL` means newly added suites are covered automatically.
        for &suite in CipherSuite::ALL {
            let parsed: CipherSuite = suite
                .to_string()
                .parse()
                .expect("Display output should parse");
            assert_eq!(parsed, suite);
        }
        assert_eq!(CipherSuite::Unknown.to_string(), "UNKNOWN");
        assert_eq!("UNKNOWN".parse::<CipherSuite>().ok(), Some(CipherSuite::Unknown));
        assert!("TLS_RSA_WITH_NULL_MD5".parse::<CipherSuite>().is_err());
    }

    #[test]
    fn cipher_suite_iana_roundtrip() {
        for &suite in CipherSuite::ALL {
            assert_eq!(CipherSuite::from_iana(suite.iana_value()), Some(suite), "{suite}");
        }
        // `Unknown` reports a sentinel code point that intentionally does not map back.
        assert_eq!(CipherSuite::Unknown.iana_value(), [0xFF, 0xFF]);
        assert_eq!(CipherSuite::from_iana([0xFF, 0xFF]), None);
        assert_eq!(CipherSuite::from_iana([0x00, 0x00]), None);
    }

    #[test]
    fn cipher_suite_version_classification() {
        // Each concrete suite belongs to exactly one protocol version.
        for &suite in CipherSuite::ALL {
            assert_ne!(suite.is_tls12(), suite.is_tls13(), "{suite}");
            assert!(!suite.is_unknown(), "{suite}");
        }
        assert!(CipherSuite::Tls13Aes128GcmSha256.is_tls13());
        assert!(CipherSuite::Tls12EcdheRsaAes128GcmSha256.is_tls12());
        assert!(CipherSuite::Unknown.is_unknown());
        assert!(!CipherSuite::Unknown.is_tls12());
        assert!(!CipherSuite::Unknown.is_tls13());
    }

    #[test]
    fn connection_info_builder() {
        let info = ConnectionInfo::new()
            .with_version(TlsVersion::Tls13)
            .with_cipher_suite(CipherSuite::Tls13Aes256GcmSha384)
            .with_alpn_protocol(b"h2".to_vec())
            .with_sni("example.com".to_string());
        assert_eq!(info.version, Some(TlsVersion::Tls13));
        assert_eq!(info.cipher_suite, Some(CipherSuite::Tls13Aes256GcmSha384));
        assert_eq!(info.alpn_protocol_str(), Some("h2"));
        assert_eq!(info.sni.as_deref(), Some("example.com"));
        assert!(info.peer_certificates.is_empty());
    }

    #[test]
    fn connection_info_builder_type() {
        let certs = vec![vec![0x30_u8, 0x82], vec![0x30]];
        let info = ConnectionInfoBuilder::new()
            .version(TlsVersion::Tls12)
            .cipher_suite(CipherSuite::Tls12EcdheRsaAes128GcmSha256)
            .alpn_protocol(b"http/1.1".to_vec())
            .sni("example.org".to_string())
            .peer_certificates(certs.clone())
            .build();
        assert_eq!(info.version, Some(TlsVersion::Tls12));
        assert_eq!(info.cipher_suite, Some(CipherSuite::Tls12EcdheRsaAes128GcmSha256));
        assert_eq!(info.alpn_protocol_str(), Some("http/1.1"));
        assert_eq!(info.sni.as_deref(), Some("example.org"));
        assert_eq!(info.peer_certificates, certs);
    }

    #[test]
    fn alpn_protocol_str_rejects_invalid_utf8() {
        let info = ConnectionInfo::new().with_alpn_protocol(vec![0xFF, 0xFE]);
        assert_eq!(info.alpn_protocol_str(), None);
        // The raw bytes are still available even though they are not text.
        assert_eq!(info.alpn_protocol.as_deref(), Some(&[0xFF_u8, 0xFE][..]));
    }

    #[test]
    fn connection_info_default() {
        let info = ConnectionInfo::default();
        assert_eq!(info.version, None);
        assert_eq!(info.cipher_suite, None);
        assert_eq!(info.alpn_protocol, None);
        assert_eq!(info.sni, None);
        assert!(info.peer_certificates.is_empty());
    }

    #[test]
    fn tls_error_display_all_variants() {
        // `AlertReceived` is omitted: constructing an `AlertDescription` depends on the
        // `alert` module's API, which is covered by that module's own tests.
        let cases = [
            (TlsError::Io(io::ErrorKind::BrokenPipe), "I/O error:"),
            (TlsError::Handshake("test".into()), "handshake error:"),
            (TlsError::BadCert("test".into()), "bad certificate:"),
            (TlsError::InvalidConfig("test".into()), "invalid config:"),
            (TlsError::CertRevoked("test".into()), "certificate revoked:"),
            (TlsError::CertInvalid("test".into()), "invalid certificate:"),
            (TlsError::ProtocolViolation("test".into()), "protocol violation:"),
            (TlsError::Other("test".into()), "TLS error:"),
        ];
        for (err, prefix) in &cases {
            assert!(
                err.to_string().starts_with(prefix),
                "{err} should start with {prefix}"
            );
        }
    }

    #[test]
    fn tls_error_predicates() {
        assert!(TlsError::Handshake("x".into()).is_handshake());
        assert!(!TlsError::Handshake("x".into()).is_io());
        assert!(TlsError::Io(io::ErrorKind::Other).is_io());
        assert!(TlsError::BadCert("x".into()).is_cert());
        assert!(TlsError::CertRevoked("x".into()).is_cert());
        assert!(TlsError::CertInvalid("x".into()).is_cert());
        assert!(!TlsError::Other("x".into()).is_cert());
        assert!(TlsError::InvalidConfig("x".into()).is_config());
        assert!(TlsError::ProtocolViolation("x".into()).is_protocol_violation());
        assert!(!TlsError::ProtocolViolation("x".into()).is_handshake());
    }

    #[test]
    fn tls_error_from_io_error() {
        let io_err = io::Error::new(io::ErrorKind::ConnectionRefused, "refused");
        // Only the kind survives the conversion.
        assert_eq!(
            TlsError::from(io_err),
            TlsError::Io(io::ErrorKind::ConnectionRefused)
        );
    }

    #[test]
    fn tls_error_into_io_error() {
        // Pin the exact kind chosen for each variant; the previous version only
        // called `kind()` and asserted nothing.
        let cases = [
            (TlsError::Io(io::ErrorKind::BrokenPipe), io::ErrorKind::BrokenPipe),
            (TlsError::Handshake("hs".into()), io::ErrorKind::ConnectionAborted),
            (TlsError::BadCert("bc".into()), io::ErrorKind::InvalidData),
            (TlsError::InvalidConfig("ic".into()), io::ErrorKind::InvalidInput),
            (TlsError::CertRevoked("cr".into()), io::ErrorKind::PermissionDenied),
            (TlsError::CertInvalid("ci".into()), io::ErrorKind::InvalidData),
            (TlsError::ProtocolViolation("pv".into()), io::ErrorKind::InvalidData),
            (TlsError::Other("ot".into()), io::ErrorKind::Other),
        ];
        for (tls_err, expected) in cases {
            let label = tls_err.to_string();
            let io_err: io::Error = tls_err.into();
            assert_eq!(io_err.kind(), expected, "{label}");
        }
        // String-carrying variants keep their detail text as the io::Error message.
        let io_err: io::Error = TlsError::Handshake("hs".into()).into();
        assert_eq!(io_err.to_string(), "hs");
    }
}