#![cfg(feature = "debug")]
#![allow(unsafe_code)]
use std::ffi::{c_char, CStr};
use std::fmt::{self, Display};
#[allow(unused_imports)] use std::os::raw::{c_int, c_uint};
use std::sync::{Arc, OnceLock};
/// Receiver for TLS 1.3 secrets, delivered as text lines for a Wireshark key log.
///
/// Implementors only provide [`Tls13SecretCallbacks::wireshark_keylog`]. The provided
/// [`Tls13SecretCallbacks::secrets`] method formats a raw secret into one line and forwards it.
pub trait Tls13SecretCallbacks {
    /// Receives one fully formatted key log line, including its trailing `'\n'`.
    fn wireshark_keylog(&self, _secret: String);

    /// Formats a secret as a key log line and hands it to [`Self::wireshark_keylog`].
    ///
    /// The produced line is `"{secret_type} {hex(random)} {hex(secret)}\n"`. `secret_type` is
    /// rendered through its `Display` impl. Every byte of `random` and `secret` is written as
    /// two lowercase hex digits.
    fn secrets(&self, secret_type: Tls13Secret, random: &[u8], secret: &[u8]) {
        /// Appends `bytes` to `out` as lowercase hex, two digits per byte.
        fn push_hex(out: &mut String, bytes: &[u8]) {
            const DIGITS: &[u8; 16] = b"0123456789abcdef";
            for &byte in bytes {
                out.push(char::from(DIGITS[usize::from(byte >> 4)]));
                out.push(char::from(DIGITS[usize::from(byte & 0x0f)]));
            }
        }

        let mut keylog = secret_type.to_string();
        // Two separating spaces, the trailing newline, and two hex digits per input byte.
        // Reserving once avoids regrowth; encoding in place avoids a temporary String per byte.
        keylog.reserve(2 * (random.len() + secret.len()) + 3);
        keylog.push(' ');
        push_hex(&mut keylog, random);
        keylog.push(' ');
        push_hex(&mut keylog, secret);
        keylog.push('\n');
        self.wireshark_keylog(keylog);
    }
}

/// Shared, thread-safe handle to a [`Tls13SecretCallbacks`] implementation.
pub type Tls13SecretCallbacksArg = Arc<dyn Tls13SecretCallbacks + Send + Sync>;

/// Length in bytes of a TLS hello random value (32, per RFC 8446 §4.1.2).
// NOTE(review): not referenced in this file; presumably used elsewhere in the crate to size
// the `random` passed to `secrets` — confirm.
pub(crate) const RANDOM_SIZE: usize = 32;
/// Integer type of a raw TLS 1.3 secret identifier as seen on the Rust side.
///
/// The per-platform split presumably mirrors how the generated bindings represent the C enum
/// (MSVC gives C enums a signed `int`) — confirm against `wolfssl_sys`.
#[cfg(not(windows))]
pub type Tls13SecretType = c_uint;
/// Integer type of a raw TLS 1.3 secret identifier as seen on the Rust side.
#[cfg(windows)]
pub type Tls13SecretType = c_int;

/// Kind of TLS 1.3 secret reported to the key-log machinery.
///
/// The [`Display`] impl renders the key log label for the secret, e.g.
/// `CLIENT_TRAFFIC_SECRET_0`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Tls13Secret {
    /// Rendered as `CLIENT_EARLY_TRAFFIC_SECRET`.
    ClientEarlyTrafficSecret,
    /// Rendered as `CLIENT_HANDSHAKE_TRAFFIC_SECRET`.
    ClientHandshakeTrafficSecret,
    /// Rendered as `SERVER_HANDSHAKE_TRAFFIC_SECRET`.
    ServerHandshakeTrafficSecret,
    /// Rendered as `CLIENT_TRAFFIC_SECRET_0`.
    ClientTrafficSecret,
    /// Rendered as `SERVER_TRAFFIC_SECRET_0`.
    ServerTrafficSecret,
    /// Rendered as `EARLY_EXPORTER_SECRET`.
    EarlyExporterSecret,
    /// Rendered as `EXPORTER_SECRET`.
    ExporterSecret,
    /// An identifier with no dedicated variant. The raw value is kept, but it is rendered
    /// simply as `UNKNOWN_SECRET`.
    UnknownSecret(Tls13SecretType),
}

impl Display for Tls13Secret {
    /// Writes the key log label for this secret.
    ///
    /// Uses `Formatter::pad`, so width, fill and alignment flags are honoured exactly as for a
    /// plain `&str`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let label = match self {
            Self::ClientEarlyTrafficSecret => "CLIENT_EARLY_TRAFFIC_SECRET",
            Self::ClientHandshakeTrafficSecret => "CLIENT_HANDSHAKE_TRAFFIC_SECRET",
            Self::ServerHandshakeTrafficSecret => "SERVER_HANDSHAKE_TRAFFIC_SECRET",
            Self::ClientTrafficSecret => "CLIENT_TRAFFIC_SECRET_0",
            Self::ServerTrafficSecret => "SERVER_TRAFFIC_SECRET_0",
            Self::EarlyExporterSecret => "EARLY_EXPORTER_SECRET",
            Self::ExporterSecret => "EXPORTER_SECRET",
            // All unrecognised identifiers share one label; the raw value stays in the variant.
            Self::UnknownSecret(_) => "UNKNOWN_SECRET",
        };
        f.pad(label)
    }
}
/// Maps a raw wolfSSL secret identifier onto [`Tls13Secret`].
///
/// Identifiers that match none of the `wolfssl_sys::Tls13Secret_*` constants become
/// [`Tls13Secret::UnknownSecret`], carrying the converted raw value.
impl From<c_int> for Tls13Secret {
    fn from(value: c_int) -> Self {
        // The constants are compared as `Tls13SecretType` (unsigned on non-Windows targets).
        // `c_int` and `c_uint` share a width, so this cast only reinterprets the bits.
        let raw = value as Tls13SecretType;
        match raw {
            wolfssl_sys::Tls13Secret_CLIENT_EARLY_TRAFFIC_SECRET => Self::ClientEarlyTrafficSecret,
            wolfssl_sys::Tls13Secret_CLIENT_HANDSHAKE_TRAFFIC_SECRET => {
                Self::ClientHandshakeTrafficSecret
            }
            wolfssl_sys::Tls13Secret_SERVER_HANDSHAKE_TRAFFIC_SECRET => {
                Self::ServerHandshakeTrafficSecret
            }
            wolfssl_sys::Tls13Secret_CLIENT_TRAFFIC_SECRET => Self::ClientTrafficSecret,
            wolfssl_sys::Tls13Secret_SERVER_TRAFFIC_SECRET => Self::ServerTrafficSecret,
            wolfssl_sys::Tls13Secret_EARLY_EXPORTER_SECRET => Self::EarlyExporterSecret,
            wolfssl_sys::Tls13Secret_EXPORTER_SECRET => Self::ExporterSecret,
            unrecognised => Self::UnknownSecret(unrecognised),
        }
    }
}
/// Signature of a user-supplied handler for wolfSSL debug log messages.
pub type LoggingCallback = fn(message: &str);

/// The handler installed via `install_logging_callback`. A `OnceLock`, so it is set at most once.
static LOGGING_CALLBACK: OnceLock<LoggingCallback> = OnceLock::new();

/// C-ABI shim that forwards a wolfSSL log message to [`LOGGING_CALLBACK`].
///
/// - The log level argument is ignored.
/// - Messages that are not valid UTF-8 are decoded lossily: invalid sequences become U+FFFD
///   rather than the whole message being dropped.
/// - A panic in the Rust callback is caught, so it cannot unwind across the FFI boundary. It is
///   reported with `log::warn!` instead.
///
/// # Safety
///
/// `msg` must be null or point to a NUL-terminated C string that remains valid for the duration
/// of this call.
unsafe extern "C" fn logging_trampoline(_level: c_int, msg: *const c_char) {
    if msg.is_null() {
        return;
    }
    let Some(cb) = LOGGING_CALLBACK.get() else {
        // No handler installed: skip decoding entirely.
        return;
    };
    // SAFETY: `msg` is non-null (checked above). Per this function's contract, it points to a
    // NUL-terminated string that stays valid for this call. The borrow does not escape the call.
    let c_str = unsafe { CStr::from_ptr(msg) };
    let decoded = c_str.to_string_lossy();
    let message: &str = &decoded;
    if std::panic::catch_unwind(|| cb(message)).is_err() {
        log::warn!("Panic in logging callback");
    }
}
pub fn install_logging_callback(cb: LoggingCallback) {
let _ = LOGGING_CALLBACK.set(cb);
super::enable_debugging(true);
super::set_logging_callback(Some(logging_trampoline));
}