wraith/navigation/
thread_iter.rs

//! Thread enumeration for a process (via Toolhelp32 snapshots) and TEB-based information about the current thread.
2
3#[cfg(all(not(feature = "std"), feature = "alloc"))]
4use alloc::vec::Vec;
5
6#[cfg(feature = "std")]
7use std::vec::Vec;
8
9use crate::arch::segment;
10use crate::error::{Result, WraithError};
11
/// Identity and stack bounds of a thread, as read from its TEB by `ThreadInfo::current`.
///
/// This is plain integer data, so it is cheap to copy and compare.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct ThreadInfo {
    /// Thread identifier.
    pub thread_id: u32,
    /// Address of the thread environment block the values were read from.
    pub teb_address: usize,
    /// Upper (exclusive) bound of the thread's stack; the stack grows down from here.
    pub stack_base: usize,
    /// Lower (inclusive) bound of the thread's stack.
    pub stack_limit: usize,
}
20
21impl ThreadInfo {
22    /// get info for current thread from TEB
23    pub fn current() -> Result<Self> {
24        // SAFETY: get_teb always returns valid TEB for current thread
25        let teb = unsafe { segment::get_teb() };
26        if teb.is_null() {
27            return Err(WraithError::InvalidTebAccess);
28        }
29
30        #[cfg(target_arch = "x86_64")]
31        let (tid, stack_base, stack_limit) = {
32            // SAFETY: TEB offsets are well-known for x64
33            let tid = unsafe { segment::get_current_tid() };
34            let stack_base = unsafe { *(teb.add(0x08) as *const u64) } as usize;
35            let stack_limit = unsafe { *(teb.add(0x10) as *const u64) } as usize;
36            (tid, stack_base, stack_limit)
37        };
38
39        #[cfg(target_arch = "x86")]
40        let (tid, stack_base, stack_limit) = {
41            // SAFETY: TEB offsets are well-known for x86
42            let tid = unsafe { segment::get_current_tid() };
43            let stack_base = unsafe { *(teb.add(0x04) as *const u32) } as usize;
44            let stack_limit = unsafe { *(teb.add(0x08) as *const u32) } as usize;
45            (tid, stack_base, stack_limit)
46        };
47
48        Ok(Self {
49            thread_id: tid,
50            teb_address: teb as usize,
51            stack_base,
52            stack_limit,
53        })
54    }
55
56    /// check if an address is on this thread's stack
57    pub fn is_on_stack(&self, address: usize) -> bool {
58        // stack grows downward: limit < address < base
59        address >= self.stack_limit && address < self.stack_base
60    }
61}
62
/// Iterator over the threads owned by a single process.
///
/// Created with `ThreadIterator::new` (current process) or
/// `ThreadIterator::for_process` (any PID).
#[must_use = "iterators are lazy and do nothing unless consumed"]
pub struct ThreadIterator {
    /// Toolhelp32 snapshot handle; closed when the iterator is dropped.
    snapshot: *mut core::ffi::c_void,
    /// `true` until the first entry is requested; selects `Thread32First` over `Thread32Next`.
    first: bool,
    /// Only threads whose owner PID equals this value are yielded.
    target_pid: u32,
}
69
70impl ThreadIterator {
71    /// create new thread iterator for current process
72    pub fn new() -> Result<Self> {
73        // SAFETY: get_current_pid always returns valid PID
74        let current_pid = unsafe { segment::get_current_pid() };
75        Self::for_process(current_pid)
76    }
77
78    /// enumerate threads for a specific process
79    pub fn for_process(pid: u32) -> Result<Self> {
80        // SAFETY: CreateToolhelp32Snapshot is safe to call with valid flags
81        let snapshot = unsafe { CreateToolhelp32Snapshot(TH32CS_SNAPTHREAD, 0) };
82
83        if snapshot == INVALID_HANDLE_VALUE {
84            return Err(WraithError::from_last_error("CreateToolhelp32Snapshot"));
85        }
86
87        Ok(Self {
88            snapshot,
89            first: true,
90            target_pid: pid,
91        })
92    }
93}
94
95impl Iterator for ThreadIterator {
96    type Item = ThreadEntry;
97
98    fn next(&mut self) -> Option<Self::Item> {
99        let mut entry = ThreadEntry32 {
100            size: core::mem::size_of::<ThreadEntry32>() as u32,
101            ..Default::default()
102        };
103
104        loop {
105            // SAFETY: snapshot is valid, entry is correctly sized
106            let success = if self.first {
107                self.first = false;
108                unsafe { Thread32First(self.snapshot, &mut entry) }
109            } else {
110                unsafe { Thread32Next(self.snapshot, &mut entry) }
111            };
112
113            if success == 0 {
114                return None;
115            }
116
117            // filter to target process
118            if entry.owner_process_id == self.target_pid {
119                return Some(ThreadEntry {
120                    thread_id: entry.thread_id,
121                    owner_process_id: entry.owner_process_id,
122                    base_priority: entry.base_priority,
123                });
124            }
125        }
126    }
127}
128
129impl Drop for ThreadIterator {
130    fn drop(&mut self) {
131        if self.snapshot != INVALID_HANDLE_VALUE {
132            // SAFETY: snapshot is valid handle
133            unsafe {
134                CloseHandle(self.snapshot);
135            }
136        }
137    }
138}
139
/// A thread yielded by `ThreadIterator`.
///
/// This is plain integer data, so it is cheap to copy, compare, and hash
/// (e.g. for set-based diffing of thread lists).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct ThreadEntry {
    /// Thread identifier.
    pub thread_id: u32,
    /// Identifier of the process that owns the thread.
    pub owner_process_id: u32,
    /// Base priority reported by Toolhelp (`tpBasePri`).
    pub base_priority: i32,
}
147
// Internal mirror of the Win32 `THREADENTRY32` structure used by the Toolhelp32
// API. Field order and types must match the C layout exactly.
#[repr(C)]
#[derive(Default)]
struct ThreadEntry32 {
    size: u32,           // dwSize: must be set before calling Thread32First
    usage: u32,          // cntUsage
    thread_id: u32,      // th32ThreadID
    owner_process_id: u32, // th32OwnerProcessID
    base_priority: i32,  // tpBasePri
    delta_priority: i32, // tpDeltaPri
    flags: u32,          // dwFlags
}

// Compile-time layout guard: THREADENTRY32 is seven 32-bit fields (28 bytes) on
// both x86 and x64. A mismatch would let Thread32First/Thread32Next write past
// the buffer or misreport fields.
const _: () = assert!(core::mem::size_of::<ThreadEntry32>() == 28);
160
/// Win32 `TH32CS_SNAPTHREAD`: include all threads in the system in the snapshot.
const TH32CS_SNAPTHREAD: u32 = 0x0000_0004;
/// Win32 `INVALID_HANDLE_VALUE`, i.e. `(HANDLE)-1` — every bit set.
const INVALID_HANDLE_VALUE: *mut core::ffi::c_void = usize::MAX as *mut core::ffi::c_void;
163
164#[link(name = "kernel32")]
165extern "system" {
166    fn CreateToolhelp32Snapshot(flags: u32, process_id: u32) -> *mut core::ffi::c_void;
167    fn Thread32First(snapshot: *mut core::ffi::c_void, entry: *mut ThreadEntry32) -> i32;
168    fn Thread32Next(snapshot: *mut core::ffi::c_void, entry: *mut ThreadEntry32) -> i32;
169    fn CloseHandle(handle: *mut core::ffi::c_void) -> i32;
170}
171
172/// get list of all thread IDs in current process
173pub fn get_thread_ids() -> Result<Vec<u32>> {
174    Ok(ThreadIterator::new()?.map(|t| t.thread_id).collect())
175}
176
177/// count threads in current process
178pub fn thread_count() -> Result<usize> {
179    Ok(ThreadIterator::new()?.count())
180}
181
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_current_thread() {
        let info = ThreadInfo::current().expect("should get thread info");
        assert!(info.thread_id > 0);
        assert!(info.stack_base > info.stack_limit);
    }

    #[test]
    fn test_is_on_stack() {
        let info = ThreadInfo::current().expect("should get thread info");

        // A local whose address is taken lives in this thread's stack.
        let marker = 0u8;
        let addr = &marker as *const u8 as usize;
        assert!(info.is_on_stack(addr));

        // The range is half-open: `stack_base` itself is excluded.
        assert!(!info.is_on_stack(info.stack_base));
        assert!(!info.is_on_stack(0));
    }

    #[test]
    fn test_thread_iterator() {
        let threads: Vec<_> = ThreadIterator::new()
            .expect("should create iterator")
            .collect();

        // should have at least one thread (ourselves)
        assert!(!threads.is_empty());
    }

    #[test]
    fn test_iterator_yields_single_owner() {
        let threads: Vec<_> = ThreadIterator::new()
            .expect("should create iterator")
            .collect();

        // Filtering by target PID means every entry shares one owner.
        let owner = threads
            .first()
            .expect("at least the current thread")
            .owner_process_id;
        assert!(threads.iter().all(|t| t.owner_process_id == owner));
    }

    #[test]
    fn test_get_thread_ids() {
        let ids = get_thread_ids().expect("should get thread ids");
        assert!(!ids.is_empty());

        // current thread should be in the list
        let current = ThreadInfo::current().expect("should get current thread");
        assert!(ids.contains(&current.thread_id));
    }

    #[test]
    fn test_thread_count() {
        let count = thread_count().expect("should get thread count");
        assert!(count >= 1);
    }
}