use std::{
collections::HashMap,
error::Error,
fmt,
fs::File,
io::{BufRead, BufReader, Write},
path::Path,
sync::{Arc, Mutex},
};
use bevy::log::warn;
use futures::executor::block_on;
use rustls::pki_types::{CertificateDer, ServerName as RustlsServerName, UnixTime};
use tokio::sync::{mpsc, oneshot};
use crate::shared::{certificate::CertificateFingerprint, error::AsyncChannelError};
use super::{
CertificateInteractionError, ClientAsyncMessage, ConnectionLocalId, InvalidHostFile,
DEFAULT_KNOWN_HOSTS_FILE,
};
/// Fallback behaviour for any [`CertVerificationStatus`] that has no entry in a TOFU
/// verifier's behaviour map: immediately abort the connection.
pub const DEFAULT_CERT_VERIFIER_BEHAVIOUR: CertVerifierBehaviour =
CertVerifierBehaviour::ImmediateAction(CertVerifierAction::AbortConnection);
/// Bevy message asking the application to decide what to do with a server certificate.
///
/// The application answers through [`CertInteractionEvent::apply_cert_verifier_action`];
/// only the first answer is delivered, later calls are rejected.
#[derive(bevy::ecs::message::Message)]
pub struct CertInteractionEvent {
    /// Local id of the connection whose server certificate is being verified.
    pub connection_id: ConnectionLocalId,
    /// Result of comparing the presented certificate with the known hosts.
    pub status: CertVerificationStatus,
    /// Details about the presented (and, if any, previously known) certificate.
    pub info: CertVerificationInfo,
    /// One-shot channel back to the waiting verifier; emptied once an action has been sent.
    pub(crate) action_sender: Mutex<Option<oneshot::Sender<CertVerifierAction>>>,
}

impl CertInteractionEvent {
    /// Sends `action` to the verifier waiting on this certificate decision.
    ///
    /// # Errors
    /// - [`CertificateInteractionError::CertificateActionAlreadyApplied`] if an action was
    ///   already sent through this event.
    /// - An error converted from [`AsyncChannelError::InternalChannelClosed`] if the receiving
    ///   end of the one-shot channel has been dropped.
    /// - A lock error (converted via `?`) if the internal mutex is poisoned.
    pub fn apply_cert_verifier_action(
        &self,
        action: CertVerifierAction,
    ) -> Result<(), CertificateInteractionError> {
        let mut slot = self.action_sender.lock()?;
        let Some(sender) = slot.take() else {
            return Err(CertificateInteractionError::CertificateActionAlreadyApplied);
        };
        sender
            .send(action)
            .map_err(|_rejected_action| AsyncChannelError::InternalChannelClosed.into())
    }
}
/// Bevy message carrying a certificate-trust update for a connection.
///
/// NOTE(review): presumably emitted when the async side sends
/// `ClientAsyncMessage::CertificateTrustUpdate` (the `TrustAndStore` action) — confirm.
#[derive(bevy::ecs::message::Message)]
pub struct CertTrustUpdateEvent {
/// Local id of the connection that presented the certificate.
pub connection_id: ConnectionLocalId,
/// Details of the newly trusted certificate.
pub cert_info: CertVerificationInfo,
}
/// Bevy message reporting that a connection was aborted because of its server certificate.
///
/// NOTE(review): presumably emitted when the async side sends
/// `ClientAsyncMessage::CertificateConnectionAbort` (the `AbortConnection` action) — confirm.
#[derive(bevy::ecs::message::Message)]
pub struct CertConnectionAbortEvent {
/// Local id of the aborted connection.
pub connection_id: ConnectionLocalId,
/// Verification status that led to the abort.
pub status: CertVerificationStatus,
/// Details of the rejected certificate.
pub cert_info: CertVerificationInfo,
}
/// How the client verifies the server's certificate.
#[derive(Debug, Clone)]
pub enum CertificateVerificationMode {
/// Accept any certificate (insecure).
/// NOTE(review): presumably backed by `SkipServerVerification` — confirm.
SkipVerification,
/// NOTE(review): presumably standard CA-chain verification configured outside this file — confirm.
SignedByCertificateAuthority,
/// Trust-on-first-use against a known-hosts store.
TrustOnFirstUse(TrustOnFirstUseConfig),
}
/// Configuration for trust-on-first-use (TOFU) certificate verification.
#[derive(Debug, Clone)]
pub struct TrustOnFirstUseConfig {
/// Source of known server fingerprints (in-memory store or hosts file).
pub known_hosts: KnownHosts,
/// Behaviour to apply per verification status. In the TOFU verifier, statuses without an
/// entry fall back to [`DEFAULT_CERT_VERIFIER_BEHAVIOUR`].
pub verifier_behaviour: HashMap<CertVerificationStatus, CertVerifierBehaviour>,
}
impl Default for TrustOnFirstUseConfig {
    /// Default TOFU policy, reading known hosts from `DEFAULT_KNOWN_HOSTS_FILE`:
    /// - unknown certificate: trust it and store it;
    /// - certificate differing from the stored one: ask the client application;
    /// - certificate matching the stored one: trust it for this connection.
    fn default() -> Self {
        let mut verifier_behaviour = HashMap::new();
        verifier_behaviour.insert(
            CertVerificationStatus::UnknownCertificate,
            CertVerifierBehaviour::ImmediateAction(CertVerifierAction::TrustAndStore),
        );
        verifier_behaviour.insert(
            CertVerificationStatus::UntrustedCertificate,
            CertVerifierBehaviour::RequestClientAction,
        );
        verifier_behaviour.insert(
            CertVerificationStatus::TrustedCertificate,
            CertVerifierBehaviour::ImmediateAction(CertVerifierAction::TrustOnce),
        );

        Self {
            known_hosts: KnownHosts::HostsFile(DEFAULT_KNOWN_HOSTS_FILE.to_string()),
            verifier_behaviour,
        }
    }
}
/// Result of comparing a server certificate's fingerprint with the known-hosts store.
#[derive(Debug, Clone, Eq, PartialEq, Hash)]
pub enum CertVerificationStatus {
/// No fingerprint is stored for this server name.
UnknownCertificate,
/// A fingerprint is stored for this server name but differs from the presented one.
UntrustedCertificate,
/// The stored fingerprint matches the presented one.
TrustedCertificate,
}
/// Information about a presented server certificate, used when deciding what to do with it.
#[derive(Debug, Clone, Eq, PartialEq)]
pub struct CertVerificationInfo {
/// Server name the connection was made to.
pub server_name: ServerName,
/// Fingerprint of the presented end-entity certificate.
pub fingerprint: CertificateFingerprint,
/// Fingerprint previously stored for `server_name`, if any.
pub known_fingerprint: Option<CertificateFingerprint>,
/// Presented chain, owned: end-entity certificate first, followed by the intermediates.
pub certificate_chain: Vec<CertificateDer<'static>>,
}
/// Owned (`'static`) wrapper around a rustls server name, used as the [`CertStore`] key.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub struct ServerName(RustlsServerName<'static>);

impl fmt::Display for ServerName {
    /// Writes the string form returned by rustls' `ServerName::to_str`, honouring the
    /// formatter's width/fill/precision flags.
    #[inline]
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let text = self.0.to_str();
        f.pad(&text)
    }
}
/// What the TOFU verifier does for a given [`CertVerificationStatus`].
#[derive(Debug, Clone, Eq, PartialEq)]
pub enum CertVerifierBehaviour {
/// Forward the decision to the client application and block until it answers.
RequestClientAction,
/// Apply the given action without asking the client.
ImmediateAction(CertVerifierAction),
}
/// Decision applied to a presented server certificate.
#[derive(Debug, Clone, Eq, PartialEq)]
pub enum CertVerifierAction {
/// Notify the client of the abort and fail the TLS handshake.
AbortConnection,
/// Accept the certificate for this handshake without recording or signalling it.
TrustOnce,
/// Accept the certificate, rewrite the known-hosts file with it (when a file is configured)
/// and notify the client of the trust update.
TrustAndStore,
}
/// Known hosts: server name -> trusted certificate fingerprint.
pub type CertStore = HashMap<ServerName, CertificateFingerprint>;
/// Source of known hosts for trust-on-first-use verification.
#[derive(Debug, Clone)]
pub enum KnownHosts {
/// In-memory store; the loader pairs it with no hosts file, so new entries are not written to disk.
Store(CertStore),
/// Path to a known-hosts file with one `<server name> <base64 fingerprint>` entry per line.
HostsFile(String),
}
/// Certificate verifier that accepts every server certificate without checking it.
///
/// Handshake signatures are still verified with the `ring` provider's algorithms.
/// WARNING: provides no protection against man-in-the-middle attacks.
#[derive(Debug)]
pub(crate) struct SkipServerVerification(Arc<rustls::crypto::CryptoProvider>);

impl SkipServerVerification {
    /// Builds a shared verifier backed by rustls' default `ring` crypto provider.
    pub(crate) fn new() -> Arc<Self> {
        let ring_provider = rustls::crypto::ring::default_provider();
        Arc::new(Self(Arc::new(ring_provider)))
    }
}

impl rustls::client::danger::ServerCertVerifier for SkipServerVerification {
    /// Accepts the certificate chain unconditionally; every argument is ignored.
    fn verify_server_cert(
        &self,
        _end_entity: &CertificateDer<'_>,
        _intermediates: &[CertificateDer<'_>],
        _server_name: &RustlsServerName<'_>,
        _ocsp: &[u8],
        _now: UnixTime,
    ) -> Result<rustls::client::danger::ServerCertVerified, rustls::Error> {
        let accepted = rustls::client::danger::ServerCertVerified::assertion();
        Ok(accepted)
    }

    /// Verifies a TLS 1.2 handshake signature with the provider's algorithms.
    fn verify_tls12_signature(
        &self,
        message: &[u8],
        cert: &CertificateDer<'_>,
        dss: &rustls::DigitallySignedStruct,
    ) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
        let algorithms = &self.0.signature_verification_algorithms;
        rustls::crypto::verify_tls12_signature(message, cert, dss, algorithms)
    }

    /// Verifies a TLS 1.3 handshake signature with the provider's algorithms.
    fn verify_tls13_signature(
        &self,
        message: &[u8],
        cert: &CertificateDer<'_>,
        dss: &rustls::DigitallySignedStruct,
    ) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
        let algorithms = &self.0.signature_verification_algorithms;
        rustls::crypto::verify_tls13_signature(message, cert, dss, algorithms)
    }

    /// Signature schemes supported by the provider's verification algorithms.
    fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
        let algorithms = &self.0.signature_verification_algorithms;
        algorithms.supported_schemes()
    }
}
#[derive(Debug)]
pub(crate) struct TofuServerVerification {
store: CertStore,
verifier_behaviour: HashMap<CertVerificationStatus, CertVerifierBehaviour>,
to_sync_client: mpsc::Sender<ClientAsyncMessage>,
hosts_file: Option<String>,
provider: Arc<rustls::crypto::CryptoProvider>,
}
impl TofuServerVerification {
pub(crate) fn new(
store: CertStore,
verifier_behaviour: HashMap<CertVerificationStatus, CertVerifierBehaviour>,
to_sync_client: mpsc::Sender<ClientAsyncMessage>,
hosts_file: Option<String>,
provider: Arc<rustls::crypto::CryptoProvider>,
) -> Arc<Self> {
Arc::new(Self {
store,
verifier_behaviour,
to_sync_client,
hosts_file,
provider,
})
}
fn apply_verifier_behaviour_for_status(
&self,
status: CertVerificationStatus,
cert_info: CertVerificationInfo,
) -> Result<rustls::client::danger::ServerCertVerified, rustls::Error> {
let behaviour = self
.verifier_behaviour
.get(&status)
.unwrap_or(&DEFAULT_CERT_VERIFIER_BEHAVIOUR);
match behaviour {
CertVerifierBehaviour::ImmediateAction(action) => {
self.apply_verifier_immediate_action(action, status, cert_info)
}
CertVerifierBehaviour::RequestClientAction => {
let (action_sender, cert_action_recv) = oneshot::channel::<CertVerifierAction>();
self.to_sync_client
.try_send(ClientAsyncMessage::CertificateInteractionRequest {
status: status.clone(),
info: cert_info.clone(),
action_sender,
})
.unwrap();
match block_on(cert_action_recv) {
Ok(action) => self.apply_verifier_immediate_action(&action, status, cert_info),
Err(err) => Err(rustls::Error::General(format!(
"Failed to receive CertVerifierAction from client: {}",
err
))),
}
}
}
}
fn apply_verifier_immediate_action(
&self,
action: &CertVerifierAction,
status: CertVerificationStatus,
cert_info: CertVerificationInfo,
) -> Result<rustls::client::danger::ServerCertVerified, rustls::Error> {
match action {
CertVerifierAction::AbortConnection => {
match self
.to_sync_client
.try_send(ClientAsyncMessage::CertificateConnectionAbort { status, cert_info })
{
Ok(_) => Err(rustls::Error::General(
"CertVerifierAction requested to abort the connection".to_owned(),
)),
Err(_) => Err(rustls::Error::General(
"Failed to signal CertificateConnectionAbort".to_owned(),
)),
}
}
CertVerifierAction::TrustOnce => {
Ok(rustls::client::danger::ServerCertVerified::assertion())
}
CertVerifierAction::TrustAndStore => {
if let Some(file) = &self.hosts_file {
let mut store_clone = self.store.clone();
store_clone
.insert(cert_info.server_name.clone(), cert_info.fingerprint.clone());
if let Err(store_error) = store_known_hosts_to_file(file, &store_clone) {
return Err(rustls::Error::General(format!(
"Failed to store new certificate entry: {}",
store_error
)));
}
}
match self
.to_sync_client
.try_send(ClientAsyncMessage::CertificateTrustUpdate(cert_info))
{
Ok(_) => Ok(rustls::client::danger::ServerCertVerified::assertion()),
Err(_) => Err(rustls::Error::General(
"Failed to signal new trusted certificate entry".to_owned(),
)),
}
}
}
}
}
impl rustls::client::danger::ServerCertVerifier for TofuServerVerification {
    /// Classifies the presented certificate against the known-hosts store and applies the
    /// behaviour configured for the resulting status. OCSP data and the time are ignored.
    fn verify_server_cert(
        &self,
        end_entity: &CertificateDer<'_>,
        intermediates: &[CertificateDer<'_>],
        server_name: &RustlsServerName<'_>,
        _ocsp: &[u8],
        _now: UnixTime,
    ) -> Result<rustls::client::danger::ServerCertVerified, rustls::Error> {
        let server_name = ServerName(server_name.to_owned());
        let known_fingerprint = self.store.get(&server_name).cloned();

        // End-entity certificate first, then the intermediates, all as owned copies.
        let mut certificate_chain = Vec::with_capacity(1 + intermediates.len());
        certificate_chain.push(end_entity.clone().into_owned());
        certificate_chain.extend(intermediates.iter().map(|cert| cert.clone().into_owned()));

        let cert_info = CertVerificationInfo {
            server_name,
            fingerprint: CertificateFingerprint::from(end_entity),
            known_fingerprint,
            certificate_chain,
        };

        let status = match &cert_info.known_fingerprint {
            None => CertVerificationStatus::UnknownCertificate,
            Some(known) if *known == cert_info.fingerprint => {
                CertVerificationStatus::TrustedCertificate
            }
            Some(_) => CertVerificationStatus::UntrustedCertificate,
        };

        self.apply_verifier_behaviour_for_status(status, cert_info)
    }

    /// Verifies a TLS 1.2 handshake signature with the provider's algorithms.
    fn verify_tls12_signature(
        &self,
        message: &[u8],
        cert: &CertificateDer<'_>,
        dss: &rustls::DigitallySignedStruct,
    ) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
        let algorithms = &self.provider.signature_verification_algorithms;
        rustls::crypto::verify_tls12_signature(message, cert, dss, algorithms)
    }

    /// Verifies a TLS 1.3 handshake signature with the provider's algorithms.
    fn verify_tls13_signature(
        &self,
        message: &[u8],
        cert: &CertificateDer<'_>,
        dss: &rustls::DigitallySignedStruct,
    ) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
        let algorithms = &self.provider.signature_verification_algorithms;
        rustls::crypto::verify_tls13_signature(message, cert, dss, algorithms)
    }

    /// Signature schemes supported by the provider's verification algorithms.
    fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
        let algorithms = &self.provider.signature_verification_algorithms;
        algorithms.supported_schemes()
    }
}
fn store_known_hosts_to_file(file: &String, store: &CertStore) -> Result<(), Box<dyn Error>> {
let path = std::path::Path::new(file);
if let Some(prefix) = path.parent() {
std::fs::create_dir_all(prefix)?;
}
let mut store_file = File::create(path)?;
for entry in store {
writeln!(store_file, "{} {}", entry.0, entry.1)?;
}
Ok(())
}
/// Parses one known-hosts line of the form `<server name> <base64 fingerprint>`.
///
/// Tokens are whitespace-separated; anything after the second token is ignored.
///
/// # Errors
/// Returns [`InvalidHostFile`] when a token is missing or the decoded fingerprint has the
/// wrong length; server-name and base64 decoding errors are propagated as-is.
fn parse_known_host_line(
    line: String,
) -> Result<(ServerName, CertificateFingerprint), Box<dyn Error>> {
    let mut tokens = line.split_whitespace();

    let host = tokens.next().ok_or(InvalidHostFile)?;
    let parsed_name = RustlsServerName::try_from(host)?;
    let server_name = ServerName(parsed_name.to_owned());

    let encoded_fingerprint = tokens.next().ok_or(InvalidHostFile)?;
    let raw_fingerprint = base64::decode(encoded_fingerprint)?;

    let digest = match raw_fingerprint.try_into() {
        Ok(digest) => digest,
        Err(_) => return Err(Box::new(InvalidHostFile)),
    };
    Ok((server_name, CertificateFingerprint::new(digest)))
}
/// Loads a known-hosts file into a [`CertStore`], returning it with the file path.
///
/// Each non-blank line must hold `<server name> <base64 fingerprint>`. Blank or
/// whitespace-only lines are skipped. If a server name appears more than once, the last entry wins.
///
/// # Errors
/// Returns an error if the file cannot be opened or read, or if any non-blank line fails to parse.
fn load_known_hosts_from_file(
    file_path: String,
) -> Result<(CertStore, Option<String>), Box<dyn Error>> {
    let reader = BufReader::new(File::open(&file_path)?);
    let mut store = HashMap::new();
    for line in reader.lines() {
        let line = line?;
        // A stray empty line (e.g. from a manual edit) holds no entry; previously it made
        // parsing fail and rejected the whole file.
        if line.trim().is_empty() {
            continue;
        }
        let (server_name, fingerprint) = parse_known_host_line(line)?;
        store.insert(server_name, fingerprint);
    }
    Ok((store, Some(file_path)))
}
/// Resolves a [`KnownHosts`] configuration into a certificate store plus an optional hosts-file path.
///
/// - `Store`: returned as-is, with no path.
/// - `HostsFile` pointing at an existing file: loaded from disk.
/// - `HostsFile` pointing at a missing file: logs a warning and returns an empty store,
///   still paired with the path. NOTE(review): presumably so newly trusted hosts can be
///   written there later — confirm with the caller.
///
/// # Errors
/// Propagates I/O and parse errors from loading an existing hosts file.
pub(crate) fn load_known_hosts_store_from_config(
    known_host_config: KnownHosts,
) -> Result<(CertStore, Option<String>), Box<dyn Error>> {
    match known_host_config {
        KnownHosts::Store(store) => Ok((store, None)),
        KnownHosts::HostsFile(file) if Path::new(&file).exists() => {
            load_known_hosts_from_file(file)
        }
        KnownHosts::HostsFile(file) => {
            warn!(
                "Known hosts file `{}` not found, no known hosts loaded",
                file
            );
            Ok((HashMap::new(), Some(file)))
        }
    }
}