use std::collections::HashSet;
use std::ffi::c_void;
use std::panic::AssertUnwindSafe;
use std::sync::Arc;
use std::sync::Mutex;
use once_cell::sync::Lazy;
use widestring::U16CStr;
use windows::core::GUID;
use windows::core::PCWSTR;
use windows::Win32::Foundation::ERROR_ALREADY_EXISTS;
use windows::Win32::Foundation::ERROR_CTX_CLOSE_PENDING;
use windows::Win32::Foundation::ERROR_SUCCESS;
use windows::Win32::Foundation::FILETIME;
use windows::Win32::System::Diagnostics::Etw;
use windows::Win32::System::Diagnostics::Etw::EVENT_CONTROL_CODE_ENABLE_PROVIDER;
use windows::Win32::System::Diagnostics::Etw::TRACE_QUERY_INFO_CLASS;
use super::etw_types::*;
use crate::native::etw_types::event_record::EventRecord;
use crate::provider::event_filter::EventFilterDescriptor;
use crate::provider::Provider;
use crate::trace::callback_data::CallbackData;
use crate::trace::{RealTimeTraceTrait, TraceProperties};
/// Consumer-side trace handle: produced by `OpenTraceW`, consumed by `ProcessTrace` and `CloseTrace`.
pub type TraceHandle = Etw::PROCESSTRACE_HANDLE;
/// Controller-side session handle: produced by `StartTraceW`, consumed by `EnableTraceEx2` and
/// `ControlTraceW`.
pub type ControlHandle = Etw::CONTROLTRACE_HANDLE;
/// Errors returned by the thin wrappers around the native ETW (evntrace) API in this module.
#[derive(Debug)]
pub enum EvntraceNativeError {
/// A handle failed the `filter_invalid_*` validity checks (before a native call, or as the
/// result of `StartTraceW`).
InvalidHandle,
/// Either a session with the requested name already exists (`StartTraceW` reported
/// `ERROR_ALREADY_EXISTS`), or the callback context is already registered by `open_trace`.
AlreadyExist,
/// A native call failed; wraps the OS error it reported.
IoError(std::io::Error),
}
/// Result alias used by every wrapper in this module.
pub(crate) type EvntraceNativeResult<T> = Result<T, EvntraceNativeError>;
static UNIQUE_VALID_CONTEXTS: UniqueValidContexts = UniqueValidContexts::new();
struct UniqueValidContexts(Lazy<Mutex<HashSet<u64>>>);
enum ContextError {
AlreadyExist,
}
impl UniqueValidContexts {
pub const fn new() -> Self {
Self(Lazy::new(|| Mutex::new(HashSet::new())))
}
fn insert(&self, ctx_ptr: *const c_void) -> Result<(), ContextError> {
match self.0.lock().unwrap().insert(ctx_ptr as u64) {
true => Ok(()),
false => Err(ContextError::AlreadyExist),
}
}
fn remove(&self, ctx_ptr: *const c_void) {
self.0.lock().unwrap().remove(&(ctx_ptr as u64));
}
pub fn is_valid(&self, ctx_ptr: *const c_void) -> bool {
self.0.lock().unwrap().contains(&(ctx_ptr as u64))
}
}
extern "system" fn trace_callback_thunk(p_record: *mut Etw::EVENT_RECORD) {
match std::panic::catch_unwind(AssertUnwindSafe(|| {
let record_from_ptr = unsafe {
EventRecord::from_ptr(p_record)
};
if let Some(event_record) = record_from_ptr {
let p_user_context = event_record.user_context();
if !UNIQUE_VALID_CONTEXTS.is_valid(p_user_context) {
return;
}
let p_callback_data = p_user_context.cast::<Arc<CallbackData>>();
let callback_data = unsafe {
p_callback_data.as_ref()
};
if let Some(callback_data) = callback_data {
let cloned_arc = Arc::clone(callback_data);
cloned_arc.on_event(event_record);
}
}
})) {
Ok(_) => {}
Err(e) => {
log::error!("UNIMPLEMENTED PANIC: {e:?}");
std::process::exit(1);
}
}
}
/// Passes `h` through unless it holds one of the sentinel values that mark an invalid
/// `PROCESSTRACE_HANDLE`.
///
/// Both the all-ones 64-bit value and the zero-extended all-ones 32-bit value are rejected.
fn filter_invalid_trace_handles(h: TraceHandle) -> Option<TraceHandle> {
    // The 32-bit sentinel, widened to the 64-bit `Value` field.
    const INVALID_32_BIT: u64 = u32::MAX as u64;
    match h.Value {
        u64::MAX | INVALID_32_BIT => None,
        _ => Some(h),
    }
}
/// Passes `h` through unless it is the null (zero) control handle.
fn filter_invalid_control_handle(h: ControlHandle) -> Option<ControlHandle> {
    Some(h).filter(|candidate| candidate.Value != 0)
}
/// Starts a new ETW trace session named `trace_name` by calling `StartTraceW`.
///
/// The session properties are built by `EventTraceProperties::new::<T>` from `trace_name`,
/// `etl_dump_file`, `trace_properties` and `enable_flags`. On success the properties are returned
/// together with the control handle.
///
/// # Errors
///
/// * [`EvntraceNativeError::AlreadyExist`] if `StartTraceW` reports `ERROR_ALREADY_EXISTS`.
/// * [`EvntraceNativeError::IoError`] carrying the Win32 error code for any other failure.
/// * [`EvntraceNativeError::InvalidHandle`] if the call succeeded but produced a null handle.
pub(crate) fn start_trace<T>(
    trace_name: &U16CStr,
    etl_dump_file: Option<(&U16CStr, DumpFileLoggingMode, Option<u32>)>,
    trace_properties: &TraceProperties,
    enable_flags: Etw::EVENT_TRACE_FLAG,
) -> EvntraceNativeResult<(EventTraceProperties, ControlHandle)>
where
    T: RealTimeTraceTrait,
{
    let mut properties =
        EventTraceProperties::new::<T>(trace_name, etl_dump_file, trace_properties, enable_flags);
    let mut control_handle = ControlHandle::default();
    // SAFETY: `control_handle` is a valid out-pointer. The name and properties pointers come from
    // `properties`, which outlives the call (presumably it owns both buffers; confirm in
    // `EventTraceProperties`).
    let status = unsafe {
        Etw::StartTraceW(
            &mut control_handle,
            PCWSTR::from_raw(properties.trace_name_array().as_ptr()),
            properties.as_mut_ptr(),
        )
    };
    if status == ERROR_ALREADY_EXISTS {
        return Err(EvntraceNativeError::AlreadyExist);
    }
    if status != ERROR_SUCCESS {
        // `io::Error` expects the raw Win32 code. An HRESULT here would break `kind()` and
        // `raw_os_error()`.
        return Err(EvntraceNativeError::IoError(
            std::io::Error::from_raw_os_error(status.0 as i32),
        ));
    }
    filter_invalid_control_handle(control_handle)
        .map(|handle| (properties, handle))
        .ok_or(EvntraceNativeError::InvalidHandle)
}
/// Opens an ETW trace for consumption by calling `OpenTraceW`, with [`trace_callback_thunk`] as the
/// event callback.
///
/// The log file's context pointer is registered in `UNIQUE_VALID_CONTEXTS` first, so the callback
/// may dereference it. The registration is normally undone by `close_trace`. It is undone here if
/// `OpenTraceW` fails, because `close_trace` rejects the resulting invalid handle before it would
/// unregister anything.
///
/// # Errors
///
/// * [`EvntraceNativeError::AlreadyExist`] if the context pointer is already registered.
/// * [`EvntraceNativeError::IoError`] with the OS error if `OpenTraceW` returns an invalid handle.
// `&Box<_>` rather than `&Arc<_>`: the boxed `Arc`'s heap address is (presumably, via
// `EventTraceLogfile::create`) the context pointer, which `close_trace` recomputes the same way.
#[allow(clippy::borrowed_box)]
pub(crate) fn open_trace(
    subscription_source: SubscriptionSource,
    callback_data: &Box<Arc<CallbackData>>,
) -> EvntraceNativeResult<TraceHandle> {
    let mut log_file =
        EventTraceLogfile::create(callback_data, subscription_source, trace_callback_thunk);
    let context_ptr = log_file.context_ptr();
    if let Err(ContextError::AlreadyExist) = UNIQUE_VALID_CONTEXTS.insert(context_ptr) {
        return Err(EvntraceNativeError::AlreadyExist);
    }
    // SAFETY: `log_file` is a fully initialised logfile description that outlives the call.
    let trace_handle = unsafe { Etw::OpenTraceW(log_file.as_mut_ptr()) };
    match filter_invalid_trace_handles(trace_handle) {
        Some(handle) => Ok(handle),
        None => {
            // Read the OS error before any cleanup can overwrite the thread's last-error value.
            let os_error = std::io::Error::last_os_error();
            // Nothing will ever unregister this context otherwise. A stale entry would wrongly
            // reject a later registration at the same address, and would mark a soon-to-be-freed
            // pointer as valid.
            UNIQUE_VALID_CONTEXTS.remove(context_ptr);
            Err(EvntraceNativeError::IoError(os_error))
        }
    }
}
pub(crate) fn enable_provider(
control_handle: ControlHandle,
provider: &Provider,
) -> EvntraceNativeResult<()> {
match filter_invalid_control_handle(control_handle) {
None => Err(EvntraceNativeError::InvalidHandle),
Some(handle) => {
let owned_event_filter_descriptors: Vec<EventFilterDescriptor> = provider
.filters()
.iter()
.filter_map(|filter| filter.to_event_filter_descriptor().ok()) .collect();
let parameters = EnableTraceParameters::create(
provider.guid(),
provider.trace_flags(),
&owned_event_filter_descriptors,
);
let res = unsafe {
Etw::EnableTraceEx2(
handle,
&provider.guid() as *const GUID,
EVENT_CONTROL_CODE_ENABLE_PROVIDER.0,
provider.level(),
provider.any(),
provider.all(),
0,
Some(parameters.as_ptr()),
)
}
.ok();
res.map_err(|err| {
EvntraceNativeError::IoError(std::io::Error::from_raw_os_error(err.code().0))
})
}
}
}
pub(crate) fn process_trace(trace_handle: TraceHandle) -> EvntraceNativeResult<()> {
if filter_invalid_trace_handles(trace_handle).is_none() {
Err(EvntraceNativeError::InvalidHandle)
} else {
let result = unsafe {
let mut start = FILETIME::default();
Etw::ProcessTrace(&[trace_handle], Some(&mut start as *mut FILETIME), None)
}
.ok();
result.map_err(|err| {
EvntraceNativeError::IoError(std::io::Error::from_raw_os_error(err.code().0))
})
}
}
pub(crate) fn control_trace(
properties: &mut EventTraceProperties,
control_handle: ControlHandle,
control_code: Etw::EVENT_TRACE_CONTROL,
) -> EvntraceNativeResult<()> {
match filter_invalid_control_handle(control_handle) {
None => Err(EvntraceNativeError::InvalidHandle),
Some(handle) => {
let result = unsafe {
Etw::ControlTraceW(
handle,
PCWSTR::null(),
properties.as_mut_ptr(),
control_code,
)
}
.ok();
result.map_err(|err| {
EvntraceNativeError::IoError(std::io::Error::from_raw_os_error(err.code().0))
})
}
}
}
pub(crate) fn control_trace_by_name(
properties: &mut EventTraceProperties,
trace_name: &U16CStr,
control_code: Etw::EVENT_TRACE_CONTROL,
) -> EvntraceNativeResult<()> {
let result = unsafe {
Etw::ControlTraceW(
Etw::CONTROLTRACE_HANDLE { Value: 0 },
PCWSTR::from_raw(trace_name.as_ptr()),
properties.as_mut_ptr(),
control_code,
)
}
.ok();
result.map_err(|err| {
EvntraceNativeError::IoError(std::io::Error::from_raw_os_error(err.code().0))
})
}
#[allow(clippy::borrowed_box)] pub(crate) fn close_trace(
trace_handle: TraceHandle,
callback_data: &Box<Arc<CallbackData>>,
) -> EvntraceNativeResult<bool> {
match filter_invalid_trace_handles(trace_handle) {
None => Err(EvntraceNativeError::InvalidHandle),
Some(handle) => {
UNIQUE_VALID_CONTEXTS
.remove(callback_data.as_ref() as *const Arc<CallbackData> as *const c_void);
let status = unsafe { Etw::CloseTrace(handle) }.ok();
match status {
Ok(()) => Ok(false),
Err(err) if err.code() == ERROR_CTX_CLOSE_PENDING.to_hresult() => Ok(true),
Err(err) => Err(EvntraceNativeError::IoError(
std::io::Error::from_raw_os_error(err.code().0),
)),
}
}
}
}
pub(crate) fn query_info(class: TraceInformation, buf: &mut [u8]) -> EvntraceNativeResult<()> {
let result = unsafe {
Etw::TraceQueryInformation(
Etw::CONTROLTRACE_HANDLE { Value: 0 },
TRACE_QUERY_INFO_CLASS(class as i32),
buf.as_mut_ptr().cast(),
buf.len() as u32,
None,
)
}
.ok();
result.map_err(|err| {
EvntraceNativeError::IoError(std::io::Error::from_raw_os_error(err.code().0))
})
}