Skip to main content

windows_erg/etw/
session.rs

1//! ETW trace session management.
2
3use super::decode::{DecodedEvent, decode_from_record_parts};
4use super::schema::SchemaCache;
5use super::types::{CpuSample, StackTrace, SystemProvider, ThreadContext, TraceEvent};
6use crate::Result;
7use crate::error::{Error, EtwConsumeError, EtwError, EtwProviderError, EtwSessionError};
8use crate::types::ProcessId;
9use crate::utils::to_utf16_nul;
10use crate::wait::Wait;
11use std::borrow::Cow;
12use std::collections::HashSet;
13use std::sync::mpsc::{self, Receiver, SyncSender};
14use std::sync::{Arc, Mutex};
15use std::thread::JoinHandle;
16use std::time::Duration;
17use windows::Win32::Foundation::ERROR_SUCCESS;
18use windows::Win32::System::Diagnostics::Etw::*;
19use windows::Win32::System::SystemInformation::GetSystemTimeAsFileTime;
20use windows::core::{GUID, PWSTR};
21
/// Maximum accepted session name length; also sizes the logger-name area
/// appended after `EVENT_TRACE_PROPERTIES` (× 2 for UTF-16 code units).
const MAX_SESSION_NAME_LEN: usize = 1024;
/// Reserved session name Windows requires for kernel (system) providers.
const KERNEL_LOGGER_NAME: &str = "NT Kernel Logger";
/// Win32 `ERROR_ALREADY_EXISTS` — returned by `StartTraceW` when a session
/// with the same name is already running (e.g. stale after a crash).
const ERROR_ALREADY_EXISTS_CODE: u32 = 183;
25
/// Callback context shared between the `ProcessTrace` thread and the consumer.
struct CallbackContext {
    /// Bounded channel for raw events; `None` when the raw stream is disabled.
    raw_sender: Option<SyncSender<TraceEvent>>,
    /// Bounded channel for decoded events; `None` when decoding is disabled.
    decoded_sender: Option<SyncSender<DecodedEvent>>,
    /// Schema cache used to parse named event fields; locked per event.
    schema_cache: Option<Mutex<SchemaCache>>,
    /// When `Some`, only events from these process IDs are forwarded.
    process_filter: Option<HashSet<ProcessId>>,
    /// Attach `ThreadContext` metadata to raw events.
    include_thread_context: bool,
    /// Parse ETW stack-trace extended data into raw events.
    include_stack_traces: bool,
    /// Attach the logical processor number to raw events.
    include_cpu_samples: bool,
}
36
37fn normalize_process_filter(pids: Vec<ProcessId>) -> Option<HashSet<ProcessId>> {
38    if pids.is_empty() {
39        return None;
40    }
41    Some(pids.into_iter().collect())
42}
43
44fn extract_stack_trace(record: &EVENT_RECORD) -> Option<StackTrace> {
45    if record.ExtendedDataCount == 0 || record.ExtendedData.is_null() {
46        return None;
47    }
48
49    let items = unsafe {
50        std::slice::from_raw_parts(record.ExtendedData, record.ExtendedDataCount as usize)
51    };
52
53    for item in items {
54        let ext_type = item.ExtType;
55        let is_stack32 = ext_type == EVENT_HEADER_EXT_TYPE_STACK_TRACE32 as u16;
56        let is_stack64 = ext_type == EVENT_HEADER_EXT_TYPE_STACK_TRACE64 as u16;
57        if !is_stack32 && !is_stack64 {
58            continue;
59        }
60
61        if item.DataPtr == 0 || item.DataSize < 8 {
62            continue;
63        }
64
65        let raw = unsafe {
66            std::slice::from_raw_parts(item.DataPtr as *const u8, item.DataSize as usize)
67        };
68
69        if raw.len() < 8 {
70            continue;
71        }
72
73        let match_id = u64::from_le_bytes(raw[0..8].try_into().ok()?);
74        let frame_size = if is_stack32 { 4 } else { 8 };
75
76        let mut frames = Vec::new();
77        let mut offset = 8usize;
78        while offset + frame_size <= raw.len() {
79            let addr = if frame_size == 4 {
80                let bytes: [u8; 4] = raw[offset..offset + 4].try_into().ok()?;
81                u32::from_le_bytes(bytes) as u64
82            } else {
83                let bytes: [u8; 8] = raw[offset..offset + 8].try_into().ok()?;
84                u64::from_le_bytes(bytes)
85            };
86
87            if addr != 0 {
88                frames.push(addr);
89            }
90            offset += frame_size;
91        }
92
93        return Some(StackTrace::new(match_id, frames));
94    }
95
96    None
97}
98
99fn extract_cpu_sample(record: &EVENT_RECORD) -> CpuSample {
100    // ETW_BUFFER_CONTEXT starts with ProcessorNumber (u8).
101    let processor_number = unsafe { *(std::ptr::addr_of!(record.BufferContext) as *const u8) };
102    CpuSample::new(processor_number)
103}
104
/// Owns callback context storage and provides a stable user-context pointer.
///
/// `EVENT_TRACE_LOGFILEW::Context` stores a raw pointer that ETW passes back to
/// `trace_callback_fn` for every event. We keep the `Arc<CallbackContext>` in a
/// boxed allocation so the address stays stable for the full trace lifetime.
struct CallbackContextGuard {
    /// Heap slot whose address is handed to ETW; the double indirection
    /// (`Box<Arc<_>>`) is deliberate — it pins the `Arc` value's own address.
    #[allow(clippy::redundant_allocation)]
    boxed_ctx: Box<Arc<CallbackContext>>,
}
114
115impl CallbackContextGuard {
116    fn new(ctx: CallbackContext) -> Self {
117        Self {
118            boxed_ctx: Box::new(Arc::new(ctx)),
119        }
120    }
121
122    fn as_user_context_ptr(&self) -> *mut std::ffi::c_void {
123        self.boxed_ctx.as_ref() as *const Arc<CallbackContext> as *mut std::ffi::c_void
124    }
125}
126
/// Output stream strategy for ETW events.
///
/// `Raw` is the builder default; `EventTraceBuilder::with_decoded_stream` and
/// `with_both_streams` select the other modes.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum EventStreamMode {
    /// Emit only raw `TraceEvent` values.
    Raw,
    /// Emit only decoded `DecodedEvent` values.
    Decoded,
    /// Emit both raw and decoded streams.
    Both,
}
137
/// The `ProcessTrace` callback invoked by Windows for each event record.
///
/// # Safety
///
/// Called by the OS from the `ProcessTrace` background thread. The `UserContext`
/// field of the `EVENT_RECORD` must point to a valid `Arc<CallbackContext>` that
/// remains alive for the duration of the trace.
unsafe extern "system" fn trace_callback_fn(event_record: *mut EVENT_RECORD) {
    // Defensively ignore null records rather than dereferencing.
    let record = match unsafe { event_record.as_ref() } {
        Some(r) => r,
        None => return,
    };
    // Recover the shared context installed via `CallbackContextGuard`.
    let ctx_ptr = record.UserContext as *const Arc<CallbackContext>;
    let ctx = match unsafe { ctx_ptr.as_ref() } {
        Some(c) => c,
        None => return,
    };

    // Cheapest check first: discard events from non-matching processes before
    // doing any parsing or allocation.
    if let Some(filter) = &ctx.process_filter
        && !filter.contains(&ProcessId::new(record.EventHeader.ProcessId))
    {
        return;
    }

    // Parse named fields once; the result is shared by both output streams.
    // A poisoned schema-cache lock simply yields `None` (no fields).
    let fields = ctx
        .schema_cache
        .as_ref()
        .and_then(|cache| cache.lock().ok())
        .and_then(|mut cache| cache.parse_event_fields(record));

    // Borrow the raw payload bytes; valid only for the duration of this callback.
    let payload = if record.UserDataLength > 0 && !record.UserData.is_null() {
        unsafe {
            std::slice::from_raw_parts(record.UserData as *const u8, record.UserDataLength as usize)
        }
    } else {
        &[]
    };

    if let Some(sender) = &ctx.decoded_sender {
        let desc = record.EventHeader.EventDescriptor;
        let decoded = decode_from_record_parts(
            record.EventHeader.ProviderId,
            desc.Version,
            desc.Opcode,
            payload,
            fields.as_deref(),
        );
        // Drop decoded events when channel is full (bounded backpressure).
        let _ = sender.try_send(decoded);
    }

    if let Some(sender) = &ctx.raw_sender {
        let mut event = TraceEvent::from_event_record_with_fields(record, fields);
        if ctx.include_thread_context {
            event.thread_context = Some(ThreadContext::new(event.process_id, event.thread_id));
        }
        if ctx.include_stack_traces {
            event.stack_trace = extract_stack_trace(record);
        }
        if ctx.include_cpu_samples {
            event.cpu_sample = Some(extract_cpu_sample(record));
        }
        // Drop raw events when channel is full (bounded backpressure).
        let _ = sender.try_send(event);
    }
}
204
/// A running ETW trace session.
///
/// Created by [`EventTraceBuilder::start`]. Automatically stops the trace when
/// dropped (RAII): stopping also closes the trace handle and joins the
/// background `ProcessTrace` thread.
///
/// Use [`EventTrace::builder`] to configure and start a session.
pub struct EventTrace {
    /// Session name (`NT Kernel Logger` for kernel providers, custom for user-mode providers).
    name: String,

    /// Handle used for [`ControlTraceW`] stop/flush operations.
    session_handle: CONTROLTRACE_HANDLE,

    /// Handle returned by `OpenTraceW`, used for `CloseTrace`.
    /// `u64::MAX` marks it as invalid / already closed.
    trace_handle: PROCESSTRACE_HANDLE,

    /// Optional bounded channel receiving raw events from the callback.
    event_rx: Option<Receiver<TraceEvent>>,

    /// Optional bounded channel receiving decoded events from the callback.
    decoded_rx: Option<Receiver<DecodedEvent>>,

    /// Running total of events delivered through [`next_batch`][Self::next_batch].
    events_processed: usize,

    /// `false` after [`stop`][Self::stop] or [`drop`][Drop::drop] to avoid double-stop.
    started: bool,

    /// Background thread running `ProcessTrace` (blocks until `CloseTrace`).
    process_thread: Option<JoinHandle<()>>,

    /// Internal stop signal that can be shared with external coordinators.
    stop_signal: Wait,

    /// Owns callback context memory for ETW callback user-data pointer.
    _callback_ctx_guard: CallbackContextGuard,
}
242
impl EventTrace {
    /// Create a builder to configure and start an ETW trace session.
    ///
    /// No validation happens here — all checks run in
    /// [`EventTraceBuilder::start`] so the builder can always be constructed
    /// without a `Result`.
    ///
    /// # Example
    ///
    /// ```no_run
    /// use windows_erg::etw::{EventTrace, SystemProvider};
    ///
    /// let mut trace = EventTrace::builder("ProcessMonitor")
    ///     .system_provider(SystemProvider::Process)
    ///     .start()?;
    /// # Ok::<(), windows_erg::Error>(())
    /// ```
    pub fn builder(name: impl Into<String>) -> EventTraceBuilder {
        // Defaults: 64 KB buffers, 2–20 buffers, 1 s flush, 10 000-event
        // channels, raw stream only, no enrichment, no PID filter.
        EventTraceBuilder {
            name: name.into(),
            system_providers: Vec::new(),
            user_providers: Vec::new(),
            buffer_size: 64,
            min_buffers: 2,
            max_buffers: 20,
            flush_interval: 1,
            channel_capacity: 10_000,
            stream_mode: EventStreamMode::Raw,
            stack_traces: false,
            thread_context: false,
            detailed_events: false,
            cpu_samples: false,
            process_filter: Vec::new(),
        }
    }

    /// The active ETW session name.
    ///
    /// Kernel sessions always use `NT Kernel Logger`; user-mode sessions use
    /// the builder name.
    pub fn name(&self) -> &str {
        &self.name
    }

    /// Total events delivered so far across all [`next_batch`][Self::next_batch] calls.
    pub fn events_processed(&self) -> usize {
        self.events_processed
    }

    /// Get a clone of the stop signal for external cancellation coordination.
    pub fn stop_handle(&self) -> Wait {
        self.stop_signal.clone()
    }

    /// Fetch the next batch of events into the output buffer.
    ///
    /// Clears `out_events` before filling it. Returns the number of events added.
    pub fn next_batch(&mut self, out_events: &mut Vec<TraceEvent>) -> Result<usize> {
        self.next_batch_with_filter(out_events, |_| true)
    }

    /// Fetch the next batch unless the session stop signal has been set.
    ///
    /// Returns `0` when stop was requested.
    pub fn next_batch_or_stopped(&mut self, out_events: &mut Vec<TraceEvent>) -> Result<usize> {
        if self.stop_signal.is_signaled()? {
            out_events.clear();
            return Ok(0);
        }
        self.next_batch(out_events)
    }

    /// Continuously drain batches until the stop signal is set.
    ///
    /// The output buffer is reused on each iteration.
    ///
    /// NOTE(review): each iteration overwrites `out_events` without exposing
    /// the drained batch to the caller — confirm this keep-the-channel-drained
    /// behavior is intentional.
    pub fn run_until_stopped(
        &mut self,
        out_events: &mut Vec<TraceEvent>,
        poll_interval: Duration,
    ) -> Result<()> {
        loop {
            if self.stop_signal.is_signaled()? {
                out_events.clear();
                return Ok(());
            }
            let _ = self.next_batch(out_events)?;
            std::thread::sleep(poll_interval);
        }
    }

    /// Fetch the next batch of events, keeping only those that pass `filter`.
    ///
    /// Clears `out_events` before filling it. Returns the number of events added.
    ///
    /// Filtering happens **during** enumeration, so rejected events are never
    /// pushed to the buffer.
    pub fn next_batch_with_filter<F>(
        &mut self,
        out_events: &mut Vec<TraceEvent>,
        filter: F,
    ) -> Result<usize>
    where
        F: Fn(&TraceEvent) -> bool,
    {
        let rx = self.event_rx.as_ref().ok_or_else(|| {
            Error::Etw(EtwError::ConsumeFailed(EtwConsumeError::new(
                Cow::Borrowed("Raw event stream is disabled for this session"),
            )))
        })?;

        out_events.clear();
        // `try_recv` never blocks: this drains only what the callback has
        // queued so far. Filtered-out events are dropped and not counted.
        while let Ok(event) = rx.try_recv() {
            if filter(&event) {
                out_events.push(event);
                self.events_processed += 1;
            }
        }
        Ok(out_events.len())
    }

    /// Fetch the next batch of decoded events into the output buffer.
    ///
    /// Clears `out_events` before filling it. Returns the number of events added.
    pub fn next_batch_decoded(&mut self, out_events: &mut Vec<DecodedEvent>) -> Result<usize> {
        self.next_batch_decoded_with_filter(out_events, |_| true)
    }

    /// Fetch the next batch of decoded events, keeping only those that pass `filter`.
    ///
    /// Clears `out_events` before filling it. Returns the number of events added.
    pub fn next_batch_decoded_with_filter<F>(
        &mut self,
        out_events: &mut Vec<DecodedEvent>,
        filter: F,
    ) -> Result<usize>
    where
        F: Fn(&DecodedEvent) -> bool,
    {
        let rx = self.decoded_rx.as_ref().ok_or_else(|| {
            Error::Etw(EtwError::ConsumeFailed(EtwConsumeError::new(
                Cow::Borrowed("Decoded event stream is disabled for this session"),
            )))
        })?;

        out_events.clear();
        // Non-blocking drain, mirroring `next_batch_with_filter`.
        while let Ok(event) = rx.try_recv() {
            if filter(&event) {
                out_events.push(event);
                self.events_processed += 1;
            }
        }
        Ok(out_events.len())
    }

    /// Stop the trace session explicitly.
    ///
    /// Also called automatically when `EventTrace` is dropped.
    pub fn stop(&mut self) -> Result<()> {
        if !self.started {
            return Ok(());
        }

        // Signal cooperating consumers first so polling loops exit promptly.
        let _ = self.stop_signal.set();

        // 1. Stop the ETW session via ControlTraceW.
        let name_wide = to_utf16_nul(&self.name);

        // The buffer reserves room after the struct for the logger name that
        // the API writes at `LoggerNameOffset`.
        let mut properties_buffer =
            vec![0u8; std::mem::size_of::<EVENT_TRACE_PROPERTIES>() + (MAX_SESSION_NAME_LEN * 2)];

        unsafe {
            let properties = &mut *(properties_buffer.as_mut_ptr() as *mut EVENT_TRACE_PROPERTIES);
            properties.Wnode.BufferSize = properties_buffer.len() as u32;
            properties.LoggerNameOffset = std::mem::size_of::<EVENT_TRACE_PROPERTIES>() as u32;

            // Best-effort: failure here must not prevent handle cleanup below.
            let _ = ControlTraceW(
                self.session_handle,
                PWSTR(name_wide.as_ptr() as *mut u16),
                properties,
                EVENT_TRACE_CONTROL_STOP,
            );
        }

        // 2. Close the trace handle — unblocks the ProcessTrace background thread.
        if self.trace_handle.Value != u64::MAX {
            unsafe {
                // ERROR_CTX_CLOSE_PENDING is expected while ProcessTrace finishes.
                let _ = CloseTrace(self.trace_handle);
            }
            // Mark as closed so a second stop() skips this step.
            self.trace_handle = PROCESSTRACE_HANDLE { Value: u64::MAX };
        }

        // 3. Wait for the ProcessTrace background thread to exit.
        if let Some(handle) = self.process_thread.take() {
            let _ = handle.join();
        }

        self.started = false;
        Ok(())
    }
}
444
445impl Drop for EventTrace {
446    fn drop(&mut self) {
447        let _ = self.stop();
448    }
449}
450
// SAFETY: CONTROLTRACE_HANDLE is an opaque integer; EventTrace owns it exclusively.
// NOTE(review): this impl also vouches for every other field (receivers, join
// handle, stop signal, callback guard) being safe to move across threads —
// re-verify if fields are added or changed.
unsafe impl Send for EventTrace {}
453
/// Builder for configuring an ETW trace session.
///
/// Obtained from [`EventTrace::builder`]. Chain configuration methods, then
/// call [`start`][Self::start] to begin tracing and receive an [`EventTrace`] handle.
pub struct EventTraceBuilder {
    /// Requested session name (kernel sessions override it with `NT Kernel Logger`).
    name: String,
    /// Kernel event sources; mutually exclusive with `user_providers`.
    system_providers: Vec<SystemProvider>,
    /// User-mode provider GUIDs; mutually exclusive with `system_providers`.
    user_providers: Vec<GUID>,
    /// Per-buffer size in kilobytes.
    buffer_size: u32,
    /// Minimum number of OS-allocated event buffers.
    min_buffers: u32,
    /// Maximum number of OS-allocated event buffers.
    max_buffers: u32,
    /// Buffer flush period in seconds.
    flush_interval: u32,
    /// Capacity of each bounded output channel.
    channel_capacity: usize,
    /// Which output streams (raw / decoded / both) to produce.
    stream_mode: EventStreamMode,

    // Optional enrichment and filtering flags.
    /// Parse ETW stack-trace extended data into raw events.
    stack_traces: bool,
    /// Attach `ThreadContext` metadata to raw events.
    thread_context: bool,
    /// Enable schema-based field parsing (also enabled implicitly by decoded streams).
    detailed_events: bool,
    /// Attach the logical processor number to raw events.
    cpu_samples: bool,
    /// Restrict forwarding to these PIDs; empty means no filtering.
    process_filter: Vec<ProcessId>,
}
476
477impl EventTraceBuilder {
478    /// Add a kernel event source to this trace session.
479    ///
480    /// Can be called multiple times to monitor several sources at once.
481    /// At least one provider must be added before calling [`start`][Self::start].
482    ///
483    /// # Example
484    ///
485    /// ```no_run
486    /// use windows_erg::etw::{EventTrace, SystemProvider};
487    ///
488    /// let trace = EventTrace::builder("SecurityMonitor")
489    ///     .system_provider(SystemProvider::Process)
490    ///     .system_provider(SystemProvider::Registry)
491    ///     .start()?;
492    /// # Ok::<(), windows_erg::Error>(())
493    /// ```
494    pub fn system_provider(mut self, provider: SystemProvider) -> Self {
495        self.system_providers.push(provider);
496        self
497    }
498
499    /// Add a user-mode ETW provider by GUID.
500    ///
501    /// This enables events from providers registered with `EventRegister`
502    /// (for example application or service providers).
503    ///
504    /// User-mode providers cannot be mixed with kernel [`SystemProvider`]s in
505    /// a single session.
506    pub fn user_provider(mut self, provider_guid: GUID) -> Self {
507        self.user_providers.push(provider_guid);
508        self
509    }
510
511    /// Set buffer size in kilobytes (default: 64 KB).
512    ///
513    /// Larger buffers reduce the chance of losing events at the cost of memory.
514    pub fn buffer_size(mut self, size_kb: u32) -> Self {
515        self.buffer_size = size_kb;
516        self
517    }
518
519    /// Set the minimum number of event buffers pre-allocated by the OS (default: 2).
520    pub fn min_buffers(mut self, count: u32) -> Self {
521        self.min_buffers = count;
522        self
523    }
524
525    /// Set the maximum number of event buffers the OS may allocate (default: 20).
526    pub fn max_buffers(mut self, count: u32) -> Self {
527        self.max_buffers = count;
528        self
529    }
530
531    /// Set how often the OS flushes filled buffers, in seconds (default: 1).
532    pub fn flush_interval(mut self, seconds: u32) -> Self {
533        self.flush_interval = seconds;
534        self
535    }
536
537    /// Set the internal event channel capacity (default: 10 000).
538    ///
539    /// Bounds memory usage during high-volume tracing. Events beyond this
540    /// limit are dropped when the consumer falls behind.
541    pub fn channel_capacity(mut self, capacity: usize) -> Self {
542        self.channel_capacity = capacity;
543        self
544    }
545
546    /// Emit only decoded events to avoid raw event allocation overhead.
547    pub fn with_decoded_stream(mut self) -> Self {
548        self.stream_mode = EventStreamMode::Decoded;
549        self
550    }
551
552    /// Emit both raw and decoded events.
553    pub fn with_both_streams(mut self) -> Self {
554        self.stream_mode = EventStreamMode::Both;
555        self
556    }
557
558    // -------------------------------------------------------------------------
559    // Optional with_* enrichment features
560    // -------------------------------------------------------------------------
561
562    /// Capture stack trace metadata for events when ETW provides it.
563    ///
564    /// When enabled, raw [`TraceEvent`] values may include `stack_trace`
565    /// parsed from event extended data items.
566    pub fn with_stack_traces(mut self) -> Self {
567        self.stack_traces = true;
568        self
569    }
570
571    /// Include thread context metadata in each event.
572    ///
573    /// When enabled, raw [`TraceEvent`] values include `thread_context` metadata
574    /// populated from the ETW event header (`ProcessId` and `ThreadId`).
575    pub fn with_thread_context(mut self) -> Self {
576        self.thread_context = true;
577        self
578    }
579
580    /// Parse event payloads into named fields using the provider schema *(planned feature)*.
581    ///
582    /// When implemented, the raw `data` bytes in each [`TraceEvent`] will be
583    /// pre-decoded into structured fields based on the provider's event schema.
584    pub fn with_detailed_events(mut self) -> Self {
585        self.detailed_events = true;
586        self
587    }
588
589    /// Attach basic CPU sampling metadata to each raw event.
590    ///
591    /// When enabled, raw [`TraceEvent`] values include `cpu_sample`
592    /// with the logical processor number from ETW buffer context.
593    pub fn with_cpu_samples(mut self) -> Self {
594        self.cpu_samples = true;
595        self
596    }
597
598    /// Restrict event collection to specific process IDs.
599    ///
600    /// When non-empty, only events whose `ProcessId` matches one of `pids`
601    /// are forwarded from the ETW callback to the output channels.
602    ///
603    /// # Example
604    ///
605    /// ```no_run
606    /// use windows_erg::etw::{EventTrace, SystemProvider};
607    ///
608    /// let trace = EventTrace::builder("TargetedMonitor")
609    ///     .system_provider(SystemProvider::FileIo)
610    ///     .with_process_filter(vec![1234, 5678])
611    ///     .start()?;
612    /// # Ok::<(), windows_erg::Error>(())
613    /// ```
614    pub fn with_process_filter<I, P>(mut self, pids: I) -> Self
615    where
616        I: IntoIterator<Item = P>,
617        P: Into<ProcessId>,
618    {
619        self.process_filter = pids.into_iter().map(Into::into).collect();
620        self
621    }
622
623    /// Start the trace session and return an [`EventTrace`] handle.
624    ///
625    /// # Errors
626    ///
627    /// | Condition | Error |
628    /// |-----------|-------|
629    /// | Empty session name | [`EtwError::SessionStartFailed`] |
630    /// | Name longer than 1024 chars | [`EtwError::SessionStartFailed`] |
631    /// | No providers specified | [`EtwError::SessionStartFailed`] |
632    /// | Mixed kernel + user providers | [`EtwError::SessionStartFailed`] |
633    /// | `min_buffers` > `max_buffers` | [`EtwError::SessionStartFailed`] |
634    /// | `NT Kernel Logger` already running | `SessionStartFailed` with `ERROR_ALREADY_EXISTS` |
635    /// | Windows API failure | [`EtwError::SessionStartFailed`] with OS error code |
636    pub fn start(self) -> Result<EventTrace> {
637        // ----- Validate -----
638
639        if self.name.is_empty() {
640            return Err(Error::Etw(EtwError::SessionStartFailed(
641                EtwSessionError::new(
642                    Cow::Borrowed(""),
643                    Cow::Borrowed("Session name cannot be empty"),
644                ),
645            )));
646        }
647
648        if self.name.len() > MAX_SESSION_NAME_LEN {
649            return Err(Error::Etw(EtwError::SessionStartFailed(
650                EtwSessionError::new(
651                    Cow::Owned(self.name.clone()),
652                    Cow::Borrowed("Session name exceeds 1024 characters"),
653                ),
654            )));
655        }
656
657        if self.system_providers.is_empty() && self.user_providers.is_empty() {
658            return Err(Error::Etw(EtwError::SessionStartFailed(
659                EtwSessionError::new(
660                    Cow::Owned(self.name.clone()),
661                    Cow::Borrowed(
662                        "At least one system provider or user provider GUID must be specified",
663                    ),
664                ),
665            )));
666        }
667
668        if !self.system_providers.is_empty() && !self.user_providers.is_empty() {
669            return Err(Error::Etw(EtwError::SessionStartFailed(
670                EtwSessionError::invalid_config(
671                    Cow::Owned(self.name.clone()),
672                    "providers",
673                    Cow::Borrowed(
674                        "Cannot mix kernel system providers with user-mode provider GUIDs in one session",
675                    ),
676                ),
677            )));
678        }
679
680        if self.min_buffers > self.max_buffers {
681            return Err(Error::Etw(EtwError::SessionStartFailed(
682                EtwSessionError::new(
683                    Cow::Owned(self.name.clone()),
684                    Cow::Owned(format!(
685                        "min_buffers ({}) cannot exceed max_buffers ({})",
686                        self.min_buffers, self.max_buffers
687                    )),
688                ),
689            )));
690        }
691
692        // ----- Build EVENT_TRACE_PROPERTIES -----
693
694        let is_kernel_session = !self.system_providers.is_empty();
695
696        // Kernel providers require the reserved "NT Kernel Logger" name.
697        let session_name = if is_kernel_session {
698            KERNEL_LOGGER_NAME.to_string()
699        } else {
700            self.name.clone()
701        };
702        let name_wide: Vec<u16> = session_name
703            .encode_utf16()
704            .chain(std::iter::once(0))
705            .collect();
706
707        let properties_size =
708            std::mem::size_of::<EVENT_TRACE_PROPERTIES>() + (MAX_SESSION_NAME_LEN * 2);
709        let mut properties_buffer = vec![0u8; properties_size];
710
711        let properties =
712            unsafe { &mut *(properties_buffer.as_mut_ptr() as *mut EVENT_TRACE_PROPERTIES) };
713
714        // Kernel sessions combine EVENT_TRACE_FLAG values from enabled providers.
715        let enable_flags: u32 = if is_kernel_session {
716            self.system_providers
717                .iter()
718                .fold(0u32, |acc, p| acc | p.trace_flags())
719        } else {
720            0
721        };
722
723        properties.Wnode.BufferSize = properties_buffer.len() as u32;
724        properties.Wnode.Flags = WNODE_FLAG_TRACED_GUID;
725        properties.Wnode.ClientContext = 1; // QPC clock resolution
726        properties.Wnode.Guid = GUID::zeroed();
727        properties.BufferSize = self.buffer_size;
728        properties.MinimumBuffers = self.min_buffers;
729        properties.MaximumBuffers = self.max_buffers;
730        properties.FlushTimer = self.flush_interval;
731        properties.LogFileMode = EVENT_TRACE_REAL_TIME_MODE;
732        properties.EnableFlags = EVENT_TRACE_FLAG(enable_flags);
733        properties.LoggerNameOffset = std::mem::size_of::<EVENT_TRACE_PROPERTIES>() as u32;
734
735        // ----- StartTraceW -----
736
737        let mut session_handle = CONTROLTRACE_HANDLE::default();
738
739        let start_result = unsafe {
740            StartTraceW(
741                &mut session_handle,
742                PWSTR(name_wide.as_ptr() as *mut u16),
743                properties,
744            )
745        };
746
747        if start_result.0 == ERROR_ALREADY_EXISTS_CODE && is_kernel_session {
748            // Stale session from a previous crash — stop it and retry.
749            let stop_buf_size =
750                std::mem::size_of::<EVENT_TRACE_PROPERTIES>() + (MAX_SESSION_NAME_LEN * 2);
751            let mut stop_buf = vec![0u8; stop_buf_size];
752            unsafe {
753                let stop_props = &mut *(stop_buf.as_mut_ptr() as *mut EVENT_TRACE_PROPERTIES);
754                stop_props.Wnode.BufferSize = stop_buf.len() as u32;
755                stop_props.LoggerNameOffset = std::mem::size_of::<EVENT_TRACE_PROPERTIES>() as u32;
756                let _ = ControlTraceW(
757                    CONTROLTRACE_HANDLE::default(),
758                    PWSTR(name_wide.as_ptr() as *mut u16),
759                    stop_props,
760                    EVENT_TRACE_CONTROL_STOP,
761                );
762            }
763
764            // Re-build properties and retry (StartTraceW may have modified the buffer).
765            let mut retry_buf = vec![0u8; properties_size];
766            let retry_result = unsafe {
767                let props = &mut *(retry_buf.as_mut_ptr() as *mut EVENT_TRACE_PROPERTIES);
768                props.Wnode.BufferSize = retry_buf.len() as u32;
769                props.Wnode.Flags = WNODE_FLAG_TRACED_GUID;
770                props.Wnode.ClientContext = 1;
771                props.Wnode.Guid = GUID::zeroed();
772                props.BufferSize = self.buffer_size;
773                props.MinimumBuffers = self.min_buffers;
774                props.MaximumBuffers = self.max_buffers;
775                props.FlushTimer = self.flush_interval;
776                props.LogFileMode = EVENT_TRACE_REAL_TIME_MODE;
777                props.EnableFlags = EVENT_TRACE_FLAG(enable_flags);
778                props.LoggerNameOffset = std::mem::size_of::<EVENT_TRACE_PROPERTIES>() as u32;
779                StartTraceW(
780                    &mut session_handle,
781                    PWSTR(name_wide.as_ptr() as *mut u16),
782                    props,
783                )
784            };
785
786            if retry_result != ERROR_SUCCESS {
787                return Err(Error::Etw(EtwError::SessionStartFailed(
788                    EtwSessionError::with_code(
789                        Cow::Owned(session_name),
790                        Cow::Borrowed("Failed to start trace after stopping stale session"),
791                        retry_result.0 as i32,
792                    ),
793                )));
794            }
795        } else if start_result != ERROR_SUCCESS {
796            return Err(Error::Etw(EtwError::SessionStartFailed(
797                EtwSessionError::with_code(
798                    Cow::Owned(session_name),
799                    Cow::Borrowed("Failed to start trace session"),
800                    start_result.0 as i32,
801                ),
802            )));
803        }
804
805        if !is_kernel_session {
806            for provider_guid in &self.user_providers {
807                let enable_result = unsafe {
808                    EnableTraceEx2(
809                        session_handle,
810                        provider_guid as *const GUID,
811                        EVENT_CONTROL_CODE_ENABLE_PROVIDER.0,
812                        TRACE_LEVEL_VERBOSE as u8,
813                        u64::MAX,
814                        0,
815                        0,
816                        None,
817                    )
818                };
819
820                if enable_result != ERROR_SUCCESS {
821                    let mut stop_buf = vec![
822                        0u8;
823                        std::mem::size_of::<EVENT_TRACE_PROPERTIES>()
824                            + (MAX_SESSION_NAME_LEN * 2)
825                    ];
826                    unsafe {
827                        let stop_props =
828                            &mut *(stop_buf.as_mut_ptr() as *mut EVENT_TRACE_PROPERTIES);
829                        stop_props.Wnode.BufferSize = stop_buf.len() as u32;
830                        stop_props.LoggerNameOffset =
831                            std::mem::size_of::<EVENT_TRACE_PROPERTIES>() as u32;
832                        let _ = ControlTraceW(
833                            session_handle,
834                            PWSTR(name_wide.as_ptr() as *mut u16),
835                            stop_props,
836                            EVENT_TRACE_CONTROL_STOP,
837                        );
838                    }
839
840                    return Err(Error::Etw(EtwError::ProviderEnableFailed(
841                        EtwProviderError::with_code(
842                            Cow::Owned(format!("{provider_guid:?}")),
843                            Cow::Borrowed("Failed to enable user-mode ETW provider"),
844                            enable_result.0 as i32,
845                        ),
846                    )));
847                }
848            }
849        }
850
851        // ----- Event consumption pipeline -----
852
853        let (raw_tx, event_rx) = match self.stream_mode {
854            EventStreamMode::Raw | EventStreamMode::Both => {
855                let (tx, rx) = mpsc::sync_channel(self.channel_capacity);
856                (Some(tx), Some(rx))
857            }
858            EventStreamMode::Decoded => (None, None),
859        };
860
861        let (decoded_tx, decoded_rx) = match self.stream_mode {
862            EventStreamMode::Decoded | EventStreamMode::Both => {
863                let (tx, rx) = mpsc::sync_channel(self.channel_capacity);
864                (Some(tx), Some(rx))
865            }
866            EventStreamMode::Raw => (None, None),
867        };
868
869        let schema_cache = if self.detailed_events || decoded_tx.is_some() {
870            Some(Mutex::new(SchemaCache::new()))
871        } else {
872            None
873        };
874
875        let callback_ctx_guard = CallbackContextGuard::new(CallbackContext {
876            raw_sender: raw_tx,
877            decoded_sender: decoded_tx,
878            schema_cache,
879            process_filter: normalize_process_filter(self.process_filter),
880            include_thread_context: self.thread_context,
881            include_stack_traces: self.stack_traces,
882            include_cpu_samples: self.cpu_samples,
883        });
884        let ctx_ptr = callback_ctx_guard.as_user_context_ptr();
885
886        // Configure real-time trace consumption via OpenTraceW.
887        let mut log_file = EVENT_TRACE_LOGFILEW {
888            LoggerName: PWSTR(name_wide.as_ptr() as *mut u16),
889            Anonymous1: EVENT_TRACE_LOGFILEW_0 {
890                ProcessTraceMode: PROCESS_TRACE_MODE_EVENT_RECORD | PROCESS_TRACE_MODE_REAL_TIME,
891            },
892            Anonymous2: EVENT_TRACE_LOGFILEW_1 {
893                EventRecordCallback: Some(trace_callback_fn),
894            },
895            Context: ctx_ptr,
896            ..Default::default()
897        };
898
899        let trace_handle = unsafe { OpenTraceW(&mut log_file) };
900        if trace_handle.Value == u64::MAX {
901            // OpenTraceW failed — clean up the started session.
902            let mut stop_buf = vec![
903                0u8;
904                std::mem::size_of::<EVENT_TRACE_PROPERTIES>()
905                    + (MAX_SESSION_NAME_LEN * 2)
906            ];
907            unsafe {
908                let stop_props = &mut *(stop_buf.as_mut_ptr() as *mut EVENT_TRACE_PROPERTIES);
909                stop_props.Wnode.BufferSize = stop_buf.len() as u32;
910                stop_props.LoggerNameOffset = std::mem::size_of::<EVENT_TRACE_PROPERTIES>() as u32;
911                let _ = ControlTraceW(
912                    session_handle,
913                    PWSTR(name_wide.as_ptr() as *mut u16),
914                    stop_props,
915                    EVENT_TRACE_CONTROL_STOP,
916                );
917            }
918            return Err(Error::Etw(EtwError::ConsumeFailed(EtwConsumeError::new(
919                Cow::Borrowed("OpenTraceW failed"),
920            ))));
921        }
922
923        // Spawn background thread — ProcessTrace blocks until CloseTrace is called.
924        let process_trace_handle = trace_handle;
925        let process_thread = std::thread::spawn(move || unsafe {
926            let handles = [process_trace_handle];
927            let now = GetSystemTimeAsFileTime();
928            let _ = ProcessTrace(&handles, Some(&now as *const _), None);
929        });
930
931        Ok(EventTrace {
932            name: session_name,
933            session_handle,
934            trace_handle,
935            event_rx,
936            decoded_rx,
937            events_processed: 0,
938            started: true,
939            process_thread: Some(process_thread),
940            stop_signal: Wait::manual_reset(false)?,
941            _callback_ctx_guard: callback_ctx_guard,
942        })
943    }
944}
945
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a minimal `TraceEvent` with the given event id and process id.
    ///
    /// The backing `EVENT_RECORD` is otherwise defaulted: nil provider GUID,
    /// thread id 1, no user payload, and a timestamp pinned to the Unix
    /// epoch so the resulting event is fully deterministic across runs.
    fn make_trace_event(id: u16, process_id: u32) -> TraceEvent {
        // Windows FILETIME (100 ns ticks since 1601-01-01) value for
        // 1970-01-01T00:00:00Z.
        const FILETIME_UNIX_EPOCH: i64 = 116_444_736_000_000_000;

        let mut record = EVENT_RECORD::default();
        record.EventHeader.EventDescriptor.Id = id;
        record.EventHeader.ProviderId = GUID::zeroed();
        record.EventHeader.ProcessId = process_id;
        record.EventHeader.ThreadId = 1;
        record.EventHeader.TimeStamp = FILETIME_UNIX_EPOCH;
        record.UserDataLength = 0;
        record.UserData = std::ptr::null_mut();
        TraceEvent::from_event_record_with_fields(&record, None)
    }

    /// Constructs an `EventTrace` that is not backed by a live ETW session.
    ///
    /// No OS handles are open (`trace_handle` is the invalid sentinel
    /// `u64::MAX`), `started` is false, and there is no `ProcessTrace`
    /// thread. Only the supplied channel receivers are wired in, which lets
    /// the batch-draining APIs be exercised without touching the OS.
    fn inert_trace(
        event_rx: Option<Receiver<TraceEvent>>,
        decoded_rx: Option<Receiver<DecodedEvent>>,
    ) -> EventTrace {
        EventTrace {
            name: "TestTrace".to_string(),
            session_handle: CONTROLTRACE_HANDLE::default(),
            trace_handle: PROCESSTRACE_HANDLE { Value: u64::MAX },
            event_rx,
            decoded_rx,
            events_processed: 0,
            started: false,
            process_thread: None,
            stop_signal: Wait::manual_reset(false).expect("wait handle create"),
            // Inert callback context: no senders, no schema cache, and every
            // optional feature disabled.
            _callback_ctx_guard: CallbackContextGuard::new(CallbackContext {
                raw_sender: None,
                decoded_sender: None,
                schema_cache: None,
                process_filter: None,
                include_thread_context: false,
                include_stack_traces: false,
                include_cpu_samples: false,
            }),
        }
    }

    #[test]
    fn test_builder_requires_provider() {
        // No provider selection → start() must fail.
        let result = EventTrace::builder("TestSession").start();
        assert!(result.is_err());
    }

    #[test]
    fn test_start_fails_when_mixing_kernel_and_user_providers() {
        // Kernel system providers and user-mode provider GUIDs cannot share
        // one session; the builder must reject the combination up front,
        // before attempting any OS calls.
        let result = EventTrace::builder("TestSession")
            .system_provider(SystemProvider::Process)
            .user_provider(GUID::zeroed())
            .start();

        match result {
            Err(Error::Etw(EtwError::SessionStartFailed(e))) => {
                assert!(e.reason.contains("Cannot mix kernel system providers"));
            }
            _ => panic!("expected SessionStartFailed"),
        }
    }

    #[test]
    fn test_empty_name_fails() {
        // An empty session name is invalid.
        let result = EventTrace::builder("").start();
        assert!(result.is_err());
    }

    #[test]
    fn test_name_too_long_fails() {
        // One character past MAX_SESSION_NAME_LEN must be rejected with a
        // length-specific error message.
        let long_name = "x".repeat(MAX_SESSION_NAME_LEN + 1);
        let result = EventTrace::builder(long_name).start();

        match result {
            Err(Error::Etw(EtwError::SessionStartFailed(e))) => {
                assert!(e.reason.contains("exceeds 1024"));
            }
            _ => panic!("expected SessionStartFailed"),
        }
    }

    #[test]
    fn test_max_name_length_passes_length_validation() {
        // Exactly MAX_SESSION_NAME_LEN characters must pass the length
        // check; failure should then come from the later provider check.
        let max_name = "x".repeat(MAX_SESSION_NAME_LEN);
        let result = EventTrace::builder(max_name).start();

        match result {
            Err(Error::Etw(EtwError::SessionStartFailed(e))) => {
                // If max length is accepted, validation proceeds to provider requirement.
                assert!(
                    e.reason
                        .contains("At least one system provider or user provider GUID")
                );
            }
            _ => panic!("expected SessionStartFailed"),
        }
    }

    #[test]
    fn test_buffer_constraint_fails() {
        // min_buffers > max_buffers is an inconsistent configuration.
        let result = EventTrace::builder("Test")
            .system_provider(SystemProvider::Process)
            .min_buffers(10)
            .max_buffers(5) // invalid: min > max
            .start();
        assert!(result.is_err());
    }

    #[test]
    fn test_normalize_process_filter_empty_is_none() {
        // An empty pid list means "no filtering" and is represented as None.
        let filter = normalize_process_filter(Vec::new());
        assert!(filter.is_none());
    }

    #[test]
    fn test_normalize_process_filter_deduplicates() {
        // Collecting into a HashSet drops the duplicate pid 100.
        let filter = normalize_process_filter(vec![
            ProcessId::new(100),
            ProcessId::new(200),
            ProcessId::new(100),
        ])
        .expect("expected filter set");
        assert_eq!(filter.len(), 2);
        assert!(filter.contains(&ProcessId::new(100)));
        assert!(filter.contains(&ProcessId::new(200)));
    }

    #[test]
    fn test_extract_stack_trace_none_without_extended_data() {
        // extract_stack_trace bails out early when the record carries no
        // extended data items (count 0 / null pointer).
        let mut record: EVENT_RECORD = unsafe { std::mem::zeroed() };
        record.ExtendedDataCount = 0;
        record.ExtendedData = std::ptr::null_mut();

        assert!(extract_stack_trace(&record).is_none());
    }

    #[test]
    fn test_extract_stack_trace_64bit_payload() {
        // A STACK_TRACE64 payload is a u64 match id followed by u64 frames.
        // NOTE: `payload` and `ext` are reached through raw pointers below,
        // so both must stay alive (and unmoved) until extract_stack_trace
        // returns — do not reorder or scope these locals.
        let mut payload = Vec::new();
        payload.extend_from_slice(&0x1122_3344_5566_7788u64.to_le_bytes());
        payload.extend_from_slice(&0x0000_0000_0000_1111u64.to_le_bytes());
        payload.extend_from_slice(&0x0000_0000_0000_2222u64.to_le_bytes());

        let mut ext: EVENT_HEADER_EXTENDED_DATA_ITEM = unsafe { std::mem::zeroed() };
        ext.ExtType = EVENT_HEADER_EXT_TYPE_STACK_TRACE64 as u16;
        ext.DataSize = payload.len() as u16;
        ext.DataPtr = payload.as_ptr() as u64;

        let mut record: EVENT_RECORD = unsafe { std::mem::zeroed() };
        record.ExtendedDataCount = 1;
        record.ExtendedData = &mut ext;

        let parsed = extract_stack_trace(&record).expect("stack should parse");
        // First u64 is the match id; the remaining u64s are frame addresses.
        assert_eq!(parsed.match_id, 0x1122_3344_5566_7788u64);
        assert_eq!(parsed.frames, vec![0x1111, 0x2222]);
    }

    #[test]
    fn test_extract_cpu_sample_reads_processor_number() {
        let mut record: EVENT_RECORD = unsafe { std::mem::zeroed() };
        unsafe {
            // SAFETY: writes the first byte of BufferContext in place; the
            // struct was zero-initialized above and is a plain FFI type.
            // This byte overlays the ProcessorNumber field of the
            // ETW_BUFFER_CONTEXT union.
            *(std::ptr::addr_of_mut!(record.BufferContext) as *mut u8) = 13;
        }

        let sample = extract_cpu_sample(&record);
        assert_eq!(sample.processor_number, 13);
    }

    #[test]
    fn test_next_batch_fails_when_raw_stream_disabled() {
        // With no raw receiver wired in, next_batch must report the raw
        // stream as disabled rather than returning an empty batch.
        let mut trace = inert_trace(None, None);
        let mut out = Vec::new();

        let result = trace.next_batch(&mut out);
        match result {
            Err(Error::Etw(EtwError::ConsumeFailed(e))) => {
                assert!(e.reason.contains("Raw event stream is disabled"));
            }
            _ => panic!("expected ConsumeFailed"),
        }
    }

    #[test]
    fn test_next_batch_decoded_fails_when_decoded_stream_disabled() {
        // Mirror of the raw-stream case for the decoded channel.
        let mut trace = inert_trace(None, None);
        let mut out = Vec::new();

        let result = trace.next_batch_decoded(&mut out);
        match result {
            Err(Error::Etw(EtwError::ConsumeFailed(e))) => {
                assert!(e.reason.contains("Decoded event stream is disabled"));
            }
            _ => panic!("expected ConsumeFailed"),
        }
    }

    #[test]
    fn test_next_batch_drains_raw_stream_and_updates_counter() {
        // Pre-load the channel, then drop the sender so the drain terminates.
        let (tx, rx) = mpsc::sync_channel(8);
        tx.send(make_trace_event(1, 100)).expect("send event 1");
        tx.send(make_trace_event(2, 200)).expect("send event 2");
        drop(tx);

        let mut trace = inert_trace(Some(rx), None);
        let mut out = Vec::new();

        let count = trace
            .next_batch(&mut out)
            .expect("next_batch should succeed");
        assert_eq!(count, 2);
        assert_eq!(out.len(), 2);
        // The cumulative counter reflects every event drained so far.
        assert_eq!(trace.events_processed(), 2);
    }

    #[test]
    fn test_next_batch_with_filter_filters_during_drain() {
        let (tx, rx) = mpsc::sync_channel(8);
        tx.send(make_trace_event(1, 111)).expect("send event 1");
        tx.send(make_trace_event(2, 222)).expect("send event 2");
        tx.send(make_trace_event(3, 333)).expect("send event 3");
        drop(tx);

        let mut trace = inert_trace(Some(rx), None);
        let mut out = Vec::new();

        // Reject the pid-222 event during the drain itself.
        let count = trace
            .next_batch_with_filter(&mut out, |e| e.process_id != 222)
            .expect("next_batch_with_filter should succeed");

        // Both the returned count and events_processed() count only the
        // events that passed the filter.
        assert_eq!(count, 2);
        assert_eq!(out.len(), 2);
        assert!(out.iter().all(|e| e.process_id != 222));
        assert_eq!(trace.events_processed(), 2);
    }

    #[test]
    fn test_next_batch_decoded_drains_stream_and_updates_counter() {
        let (tx, rx) = mpsc::sync_channel(8);
        tx.send(DecodedEvent::Unknown)
            .expect("send decoded event 1");
        tx.send(DecodedEvent::Unknown)
            .expect("send decoded event 2");
        drop(tx);

        let mut trace = inert_trace(None, Some(rx));
        let mut out = Vec::new();

        let count = trace
            .next_batch_decoded(&mut out)
            .expect("next_batch_decoded should succeed");

        assert_eq!(count, 2);
        assert_eq!(out.len(), 2);
        // Decoded events feed the same cumulative counter as raw events.
        assert_eq!(trace.events_processed(), 2);
    }
}