use std::collections::HashSet;
use std::panic::AssertUnwindSafe;
use std::sync::Arc;
use std::sync::Mutex;
use std::ffi::c_void;
use once_cell::sync::Lazy;
use widestring::{U16CString, U16CStr};
use windows::Win32::System::Diagnostics::Etw::EVENT_CONTROL_CODE_ENABLE_PROVIDER;
use windows::core::GUID;
use windows::core::PCWSTR;
use windows::Win32::Foundation::FILETIME;
use windows::Win32::System::Diagnostics::Etw;
use windows::Win32::System::Diagnostics::Etw::TRACE_QUERY_INFO_CLASS;
use windows::Win32::System::SystemInformation::GetSystemTimeAsFileTime;
use windows::Win32::Foundation::ERROR_SUCCESS;
use windows::Win32::Foundation::ERROR_ALREADY_EXISTS;
use windows::Win32::Foundation::ERROR_CTX_CLOSE_PENDING;
use super::etw_types::*;
use crate::provider::Provider;
use crate::provider::event_filter::EventFilterDescriptor;
use crate::native::etw_types::event_record::EventRecord;
use crate::trace::{TraceProperties, TraceTrait};
use crate::trace::callback_data::CallbackData;
/// Consumer-side handle: produced by `OpenTraceW` in [`open_trace`] and consumed by
/// [`process_trace`] and [`close_trace`].
pub type TraceHandle = Etw::PROCESSTRACE_HANDLE;
/// Controller-side handle: produced by `StartTraceW` in [`start_trace`] and used by
/// [`enable_provider`] and [`control_trace`].
pub type ControlHandle = Etw::CONTROLTRACE_HANDLE;
/// Errors returned by the thin wrappers over the native ETW API in this module.
#[derive(Debug)]
pub enum EvntraceNativeError {
/// A handle equal to one of the "invalid" sentinel values was passed in, or was handed back
/// by the OS (see `filter_invalid_trace_handles` / `filter_invalid_control_handle`).
InvalidHandle,
/// `StartTraceW` reported `ERROR_ALREADY_EXISTS`, or `open_trace` found the callback context
/// already registered.
AlreadyExist,
/// Any other OS-reported failure, wrapped as an `std::io::Error`.
IoError(std::io::Error),
}
/// Result alias used by every native wrapper in this module.
pub(crate) type EvntraceNativeResult<T> = Result<T, EvntraceNativeError>;
/// Process-wide registry of callback-context pointers.
///
/// `open_trace` registers a context, and `close_trace` unregisters it before calling `CloseTrace`.
/// `trace_callback_thunk` drops any event whose user context is not in this set.
static UNIQUE_VALID_CONTEXTS: UniqueValidContexts = UniqueValidContexts::new();
/// Set of context pointer addresses (stored as `u64`), lazily allocated and guarded by a `Mutex`.
struct UniqueValidContexts(Lazy<Mutex<HashSet<u64>>>);
/// Failure modes of `UniqueValidContexts::insert`.
enum ContextError{
/// The pointer was already present in the registry.
AlreadyExist
}
impl UniqueValidContexts {
    /// Builds an empty registry. This is a `const fn` so it can initialise a `static`;
    /// the `HashSet` itself is only allocated on first use.
    pub const fn new() -> Self {
        Self(Lazy::new(|| Mutex::new(HashSet::new())))
    }

    /// Registers `ctx_ptr`.
    ///
    /// # Errors
    /// Returns [`ContextError::AlreadyExist`] when the address is already registered.
    fn insert(&self, ctx_ptr: *const c_void) -> Result<(), ContextError> {
        let key = ctx_ptr as u64;
        let newly_added = self.0.lock().unwrap().insert(key);
        if newly_added {
            Ok(())
        } else {
            Err(ContextError::AlreadyExist)
        }
    }

    /// Unregisters `ctx_ptr`. Removing an address that was never registered is a no-op.
    fn remove(&self, ctx_ptr: *const c_void) {
        let key = ctx_ptr as u64;
        let mut registered = self.0.lock().unwrap();
        registered.remove(&key);
    }

    /// Reports whether `ctx_ptr` is currently registered.
    pub fn is_valid(&self, ctx_ptr: *const c_void) -> bool {
        let key = ctx_ptr as u64;
        let registered = self.0.lock().unwrap();
        registered.contains(&key)
    }
}
extern "system" fn trace_callback_thunk(p_record: *mut Etw::EVENT_RECORD) {
match std::panic::catch_unwind(AssertUnwindSafe(|| {
let record_from_ptr = unsafe {
EventRecord::from_ptr(p_record)
};
if let Some(event_record) = record_from_ptr {
let p_user_context = event_record.user_context();
if UNIQUE_VALID_CONTEXTS.is_valid(p_user_context) == false {
return;
}
let p_callback_data = p_user_context.cast::<Arc<CallbackData>>();
let callback_data = unsafe {
p_callback_data.as_ref()
};
if let Some(callback_data) = callback_data {
let cloned_arc = Arc::clone(callback_data);
cloned_arc.on_event(event_record);
}
}
})) {
Ok(_) => {}
Err(e) => {
log::error!("UNIMPLEMENTED PANIC: {e:?}");
std::process::exit(1);
}
}
}
/// Returns `Some(h)` unless `h` equals one of the sentinel values that mark an invalid
/// trace handle.
///
/// `u64::MAX` is the all-ones "invalid handle" value. The 32-bit all-ones value is rejected too.
/// NOTE(review): presumably this covers 32-bit processes, per the `OpenTrace` docs — confirm.
fn filter_invalid_trace_handles(h: TraceHandle) -> Option<TraceHandle> {
    const INVALID_64: u64 = u64::MAX;
    const INVALID_32: u64 = u32::MAX as u64;
    match h.0 {
        INVALID_64 | INVALID_32 => None,
        _ => Some(h),
    }
}
/// Returns `Some(h)` unless `h` is the null (zero) control handle.
fn filter_invalid_control_handle(h: ControlHandle) -> Option<ControlHandle> {
    Some(h).filter(|candidate| candidate.0 != 0)
}
/// Starts a new ETW session named `trace_name` with `StartTraceW`.
///
/// On success it returns the session's `EventTraceProperties` (callers can later hand them to
/// `control_trace`) together with the session's control handle.
///
/// # Errors
/// - [`EvntraceNativeError::AlreadyExist`] when `StartTraceW` reports `ERROR_ALREADY_EXISTS`.
/// - [`EvntraceNativeError::IoError`] for any other non-success status.
/// - [`EvntraceNativeError::InvalidHandle`] if the call succeeded but produced a null handle.
pub(crate) fn start_trace<T>(
    trace_name: &U16CStr,
    trace_properties: &TraceProperties,
    enable_flags: Etw::EVENT_TRACE_FLAG,
) -> EvntraceNativeResult<(EventTraceProperties, ControlHandle)>
where
    T: TraceTrait,
{
    let mut properties = EventTraceProperties::new::<T>(trace_name, trace_properties, enable_flags);
    let mut control_handle = ControlHandle::default();

    // SAFETY: `control_handle` and `properties` are locals we exclusively own, and both outlive
    // this synchronous call. The name pointer comes from `properties` itself.
    let status = unsafe {
        Etw::StartTraceW(
            &mut control_handle,
            PCWSTR::from_raw(properties.trace_name_array().as_ptr()),
            properties.as_mut_ptr(),
        )
    };

    match status {
        ERROR_SUCCESS => {}
        ERROR_ALREADY_EXISTS => return Err(EvntraceNativeError::AlreadyExist),
        failure => {
            return Err(EvntraceNativeError::IoError(
                std::io::Error::from_raw_os_error(failure.0 as i32),
            ))
        }
    }

    filter_invalid_control_handle(control_handle)
        .map(|handle| (properties, handle))
        .ok_or(EvntraceNativeError::InvalidHandle)
}
/// Opens a consumer handle on the session `trace_name` and registers `callback_data`'s context
/// so that `trace_callback_thunk` will dispatch events to it.
///
/// The context is unregistered again if `OpenTraceW` fails, so a later retry with the same
/// `callback_data` is not rejected as a duplicate.
///
/// # Errors
/// - [`EvntraceNativeError::AlreadyExist`] if this context is already registered (opened and
///   not yet closed).
/// - [`EvntraceNativeError::IoError`] carrying the OS error if `OpenTraceW` returns an invalid
///   handle.
// `&Box<_>` is deliberate: `close_trace` unregisters the Box's heap address, so callers must pass
// the very Box they keep alive. NOTE(review): presumably `EventTraceLogfile::create` stores that
// same address as the context — confirm.
#[allow(clippy::borrowed_box)]
pub(crate) fn open_trace(
    trace_name: U16CString,
    callback_data: &Box<Arc<CallbackData>>,
) -> EvntraceNativeResult<TraceHandle> {
    let mut log_file = EventTraceLogfile::create(callback_data, trace_name, trace_callback_thunk);
    let context_ptr = log_file.context_ptr();

    if let Err(ContextError::AlreadyExist) = UNIQUE_VALID_CONTEXTS.insert(context_ptr) {
        return Err(EvntraceNativeError::AlreadyExist);
    }

    // SAFETY: `log_file` is a local we exclusively own and it outlives this call. `OpenTraceW` may
    // write session information back into it, which is fine since nothing else references it.
    let trace_handle = unsafe { Etw::OpenTraceW(log_file.as_mut_ptr()) };

    match filter_invalid_trace_handles(trace_handle) {
        Some(handle) => Ok(handle),
        None => {
            // Capture the OS error before doing anything else, then undo the registration so a
            // failed open does not leave a stale entry behind.
            let os_error = std::io::Error::last_os_error();
            UNIQUE_VALID_CONTEXTS.remove(context_ptr);
            Err(EvntraceNativeError::IoError(os_error))
        }
    }
}
/// Enables `provider` on the session identified by `control_handle` via `EnableTraceEx2`.
///
/// The provider's filters are converted to event-filter descriptors. Any filter whose conversion
/// fails is skipped rather than reported.
///
/// # Errors
/// - [`EvntraceNativeError::InvalidHandle`] if `control_handle` is null.
/// - [`EvntraceNativeError::IoError`] if `EnableTraceEx2` returns a non-success status.
pub(crate) fn enable_provider(control_handle: ControlHandle, provider: &Provider) -> EvntraceNativeResult<()> {
    let handle = match filter_invalid_control_handle(control_handle) {
        Some(valid) => valid,
        None => return Err(EvntraceNativeError::InvalidHandle),
    };

    // Kept in a named local: `parameters` borrows these descriptors, and presumably points into
    // them, so they must outlive the native call below.
    let mut owned_event_filter_descriptors: Vec<EventFilterDescriptor> = Vec::new();
    for filter in provider.filters().iter() {
        if let Ok(descriptor) = filter.to_event_filter_descriptor() {
            owned_event_filter_descriptors.push(descriptor);
        }
    }

    let guid = provider.guid();
    let parameters =
        EnableTraceParameters::create(guid, provider.trace_flags(), &owned_event_filter_descriptors);

    // SAFETY: `guid`, `parameters` and the descriptors it borrows are all locals that outlive
    // this synchronous call.
    let status = unsafe {
        Etw::EnableTraceEx2(
            handle,
            &guid as *const GUID,
            EVENT_CONTROL_CODE_ENABLE_PROVIDER.0,
            provider.level(),
            provider.any(),
            provider.all(),
            // Timeout of 0: per the EnableTraceEx2 docs, enable asynchronously without waiting.
            0,
            Some(parameters.as_ptr()),
        )
    };

    match status {
        ERROR_SUCCESS => Ok(()),
        failure => Err(EvntraceNativeError::IoError(
            std::io::Error::from_raw_os_error(failure.0 as i32),
        )),
    }
}
/// Runs `ProcessTrace` on `trace_handle`, delivering events to the callback registered by
/// `open_trace`.
///
/// The start time is set to the current system time. Per the Win32 `ProcessTrace` docs, events
/// recorded before that instant are not delivered. The call does not return until processing
/// ends.
///
/// # Errors
/// - [`EvntraceNativeError::InvalidHandle`] if `trace_handle` is a sentinel invalid value.
/// - [`EvntraceNativeError::IoError`] if `ProcessTrace` returns a non-success status.
pub(crate) fn process_trace(trace_handle: TraceHandle) -> EvntraceNativeResult<()> {
    let handle = match filter_invalid_trace_handles(trace_handle) {
        Some(valid) => valid,
        None => return Err(EvntraceNativeError::InvalidHandle),
    };

    let mut start_time = FILETIME::default();
    // SAFETY: `start_time` and the one-element handle slice are locals that outlive the calls.
    let status = unsafe {
        GetSystemTimeAsFileTime(&mut start_time);
        Etw::ProcessTrace(&[handle], Some(&start_time), None)
    };

    if status != ERROR_SUCCESS {
        return Err(EvntraceNativeError::IoError(
            std::io::Error::from_raw_os_error(status.0 as i32),
        ));
    }
    Ok(())
}
/// Sends `control_code` (e.g. stop/query/update) to the session behind `control_handle`
/// via `ControlTraceW`, using `properties` as the in/out properties buffer.
///
/// # Errors
/// - [`EvntraceNativeError::InvalidHandle`] if `control_handle` is null.
/// - [`EvntraceNativeError::IoError`] if `ControlTraceW` returns a non-success status.
pub(crate) fn control_trace(
    properties: &mut EventTraceProperties,
    control_handle: ControlHandle,
    control_code: Etw::EVENT_TRACE_CONTROL,
) -> EvntraceNativeResult<()> {
    let handle =
        filter_invalid_control_handle(control_handle).ok_or(EvntraceNativeError::InvalidHandle)?;

    // SAFETY: `properties` is exclusively borrowed for the duration of this synchronous call.
    // The session is selected by handle, so no name is passed.
    let status = unsafe {
        Etw::ControlTraceW(handle, PCWSTR::null(), properties.as_mut_ptr(), control_code)
    };

    match status {
        ERROR_SUCCESS => Ok(()),
        failure => Err(EvntraceNativeError::IoError(
            std::io::Error::from_raw_os_error(failure.0 as i32),
        )),
    }
}
/// Sends `control_code` to the session named `trace_name` via `ControlTraceW`, using
/// `properties` as the in/out properties buffer.
///
/// This is useful when no control handle is available, e.g. for a session started elsewhere.
///
/// # Errors
/// [`EvntraceNativeError::IoError`] if `ControlTraceW` returns a non-success status.
pub(crate) fn control_trace_by_name(
    properties: &mut EventTraceProperties,
    trace_name: &U16CStr,
    control_code: Etw::EVENT_TRACE_CONTROL,
) -> EvntraceNativeResult<()> {
    // A null handle makes ControlTraceW identify the session by name (per its docs).
    let no_handle = Etw::CONTROLTRACE_HANDLE(0);

    // SAFETY: `trace_name` is a NUL-terminated wide string and `properties` is exclusively
    // borrowed. Both outlive this synchronous call.
    let status = unsafe {
        Etw::ControlTraceW(
            no_handle,
            PCWSTR::from_raw(trace_name.as_ptr()),
            properties.as_mut_ptr(),
            control_code,
        )
    };

    match status {
        ERROR_SUCCESS => Ok(()),
        failure => Err(EvntraceNativeError::IoError(
            std::io::Error::from_raw_os_error(failure.0 as i32),
        )),
    }
}
#[allow(clippy::borrowed_box)] pub(crate) fn close_trace(trace_handle: TraceHandle, callback_data: &Box<Arc<CallbackData>>) -> EvntraceNativeResult<bool> {
match filter_invalid_trace_handles(trace_handle) {
None => Err(EvntraceNativeError::InvalidHandle),
Some(handle) => {
UNIQUE_VALID_CONTEXTS.remove(callback_data.as_ref() as *const Arc<CallbackData> as *const c_void);
let status = unsafe {
Etw::CloseTrace(handle)
};
match status {
ERROR_SUCCESS => Ok(false),
ERROR_CTX_CLOSE_PENDING => Ok(true),
status => Err(EvntraceNativeError::IoError(
std::io::Error::from_raw_os_error(status.0 as i32),
))
}
},
}
}
/// Queries trace information of kind `class` into `buf` via `TraceQueryInformation`.
///
/// No session handle is passed (0).
/// NOTE(review): presumably only system-wide information classes are queried here — confirm.
///
/// # Errors
/// [`EvntraceNativeError::IoError`] if `TraceQueryInformation` returns a non-success status.
pub(crate) fn query_info(class: TraceInformation, buf: &mut [u8]) -> EvntraceNativeResult<()> {
    // The native API takes a 32-bit length. Saturate rather than truncate, so an oversized buffer
    // is advertised as `u32::MAX` bytes instead of a wrapped-around, arbitrarily small length.
    let buf_len = buf.len().min(u32::MAX as usize) as u32;

    // SAFETY: `buf` is exclusively borrowed and at least `buf_len` bytes long for the duration of
    // this synchronous call.
    let status = unsafe {
        Etw::TraceQueryInformation(
            Etw::CONTROLTRACE_HANDLE(0),
            TRACE_QUERY_INFO_CLASS(class as i32),
            buf.as_mut_ptr() as *mut c_void,
            buf_len,
            None,
        )
    };

    match status {
        ERROR_SUCCESS => Ok(()),
        e => Err(EvntraceNativeError::IoError(
            std::io::Error::from_raw_os_error(e.0 as i32),
        )),
    }
}