wraith/navigation/thread_iter.rs

1//! Thread enumeration
2
3use crate::arch::segment;
4use crate::error::{Result, WraithError};
5
/// Snapshot of the calling thread's identity and stack bounds, read from its TEB.
///
/// Built by [`ThreadInfo::current`]; the values are captured once and never refreshed.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct ThreadInfo {
    /// Id of the thread that captured this snapshot.
    pub thread_id: u32,
    /// Address of that thread's TEB.
    pub teb_address: usize,
    /// Exclusive upper end of the stack (the stack grows downward from here).
    pub stack_base: usize,
    /// Inclusive lower bound of the stack as recorded at capture time.
    pub stack_limit: usize,
}
14
15impl ThreadInfo {
16    /// get info for current thread from TEB
17    pub fn current() -> Result<Self> {
18        // SAFETY: get_teb always returns valid TEB for current thread
19        let teb = unsafe { segment::get_teb() };
20        if teb.is_null() {
21            return Err(WraithError::InvalidTebAccess);
22        }
23
24        #[cfg(target_arch = "x86_64")]
25        let (tid, stack_base, stack_limit) = {
26            // SAFETY: TEB offsets are well-known for x64
27            let tid = unsafe { segment::get_current_tid() };
28            let stack_base = unsafe { *(teb.add(0x08) as *const u64) } as usize;
29            let stack_limit = unsafe { *(teb.add(0x10) as *const u64) } as usize;
30            (tid, stack_base, stack_limit)
31        };
32
33        #[cfg(target_arch = "x86")]
34        let (tid, stack_base, stack_limit) = {
35            // SAFETY: TEB offsets are well-known for x86
36            let tid = unsafe { segment::get_current_tid() };
37            let stack_base = unsafe { *(teb.add(0x04) as *const u32) } as usize;
38            let stack_limit = unsafe { *(teb.add(0x08) as *const u32) } as usize;
39            (tid, stack_base, stack_limit)
40        };
41
42        Ok(Self {
43            thread_id: tid,
44            teb_address: teb as usize,
45            stack_base,
46            stack_limit,
47        })
48    }
49
50    /// check if an address is on this thread's stack
51    pub fn is_on_stack(&self, address: usize) -> bool {
52        // stack grows downward: limit < address < base
53        address >= self.stack_limit && address < self.stack_base
54    }
55}
56
/// Iterator over the threads owned by one process, backed by a Toolhelp thread snapshot.
///
/// Entries belonging to other processes are skipped; the snapshot handle is closed on drop.
pub struct ThreadIterator {
    /// Process id whose threads are yielded.
    target_pid: u32,
    /// Set until the first `next` call, which must use `Thread32First` instead of `Thread32Next`.
    first: bool,
    /// Handle returned by `CreateToolhelp32Snapshot`.
    snapshot: *mut core::ffi::c_void,
}
63
64impl ThreadIterator {
65    /// create new thread iterator for current process
66    pub fn new() -> Result<Self> {
67        // SAFETY: get_current_pid always returns valid PID
68        let current_pid = unsafe { segment::get_current_pid() };
69        Self::for_process(current_pid)
70    }
71
72    /// enumerate threads for a specific process
73    pub fn for_process(pid: u32) -> Result<Self> {
74        // SAFETY: CreateToolhelp32Snapshot is safe to call with valid flags
75        let snapshot = unsafe { CreateToolhelp32Snapshot(TH32CS_SNAPTHREAD, 0) };
76
77        if snapshot == INVALID_HANDLE_VALUE {
78            return Err(WraithError::from_last_error("CreateToolhelp32Snapshot"));
79        }
80
81        Ok(Self {
82            snapshot,
83            first: true,
84            target_pid: pid,
85        })
86    }
87}
88
89impl Iterator for ThreadIterator {
90    type Item = ThreadEntry;
91
92    fn next(&mut self) -> Option<Self::Item> {
93        let mut entry = ThreadEntry32 {
94            size: core::mem::size_of::<ThreadEntry32>() as u32,
95            ..Default::default()
96        };
97
98        loop {
99            // SAFETY: snapshot is valid, entry is correctly sized
100            let success = if self.first {
101                self.first = false;
102                unsafe { Thread32First(self.snapshot, &mut entry) }
103            } else {
104                unsafe { Thread32Next(self.snapshot, &mut entry) }
105            };
106
107            if success == 0 {
108                return None;
109            }
110
111            // filter to target process
112            if entry.owner_process_id == self.target_pid {
113                return Some(ThreadEntry {
114                    thread_id: entry.thread_id,
115                    owner_process_id: entry.owner_process_id,
116                    base_priority: entry.base_priority,
117                });
118            }
119        }
120    }
121}
122
123impl Drop for ThreadIterator {
124    fn drop(&mut self) {
125        if self.snapshot != INVALID_HANDLE_VALUE {
126            // SAFETY: snapshot is valid handle
127            unsafe {
128                CloseHandle(self.snapshot);
129            }
130        }
131    }
132}
133
/// One thread reported by [`ThreadIterator`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct ThreadEntry {
    /// Thread id (`th32ThreadID`).
    pub thread_id: u32,
    /// Id of the process that owns the thread (`th32OwnerProcessID`).
    pub owner_process_id: u32,
    /// Kernel base priority assigned to the thread (`tpBasePri`).
    pub base_priority: i32,
}
141
/// FFI mirror of the Win32 `THREADENTRY32` record filled in by `Thread32First`/`Thread32Next`.
///
/// Field order and widths must match the C layout exactly; do not reorder or resize fields.
#[derive(Default)]
#[repr(C)]
struct ThreadEntry32 {
    /// `dwSize`: size of this struct in bytes; the caller sets it before the API call.
    size: u32,
    /// `cntUsage`
    usage: u32,
    /// `th32ThreadID`
    thread_id: u32,
    /// `th32OwnerProcessID`
    owner_process_id: u32,
    /// `tpBasePri`
    base_priority: i32,
    /// `tpDeltaPri`
    delta_priority: i32,
    /// `dwFlags`
    flags: u32,
}
154
/// `CreateToolhelp32Snapshot` flag that includes every thread in the system.
const TH32CS_SNAPTHREAD: u32 = 0x4;
/// Failure sentinel handle (all bits set, i.e. `(HANDLE)-1`).
const INVALID_HANDLE_VALUE: *mut core::ffi::c_void = usize::MAX as *mut core::ffi::c_void;
157
158#[link(name = "kernel32")]
159extern "system" {
160    fn CreateToolhelp32Snapshot(flags: u32, process_id: u32) -> *mut core::ffi::c_void;
161    fn Thread32First(snapshot: *mut core::ffi::c_void, entry: *mut ThreadEntry32) -> i32;
162    fn Thread32Next(snapshot: *mut core::ffi::c_void, entry: *mut ThreadEntry32) -> i32;
163    fn CloseHandle(handle: *mut core::ffi::c_void) -> i32;
164}
165
166/// get list of all thread IDs in current process
167pub fn get_thread_ids() -> Result<Vec<u32>> {
168    Ok(ThreadIterator::new()?.map(|t| t.thread_id).collect())
169}
170
171/// count threads in current process
172pub fn thread_count() -> Result<usize> {
173    Ok(ThreadIterator::new()?.count())
174}
175
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_current_thread() {
        let info = ThreadInfo::current().expect("should get thread info");
        assert!(info.thread_id > 0);
        assert!(info.teb_address != 0);
        assert!(info.stack_base > info.stack_limit);
    }

    #[test]
    fn test_is_on_stack() {
        // This local lives in our frame; calling current() touches stack below it, so it
        // must lie within the captured [stack_limit, stack_base) range.
        let local = 0u64;
        let local_addr = &local as *const u64 as usize;
        let info = ThreadInfo::current().expect("should get thread info");

        assert!(info.is_on_stack(local_addr));
        // lower bound is inclusive, upper bound exclusive
        assert!(info.is_on_stack(info.stack_limit));
        assert!(!info.is_on_stack(info.stack_base));
        assert!(!info.is_on_stack(info.stack_limit - 1));
    }

    #[test]
    fn test_thread_iterator() {
        let threads: Vec<_> = ThreadIterator::new()
            .expect("should create iterator")
            .collect();

        // should have at least one thread (ourselves)
        assert!(!threads.is_empty());
    }

    #[test]
    fn test_iterator_only_yields_current_process() {
        let pid = std::process::id();
        for entry in ThreadIterator::new().expect("should create iterator") {
            assert_eq!(entry.owner_process_id, pid);
        }
    }

    #[test]
    fn test_get_thread_ids() {
        let ids = get_thread_ids().expect("should get thread ids");
        assert!(!ids.is_empty());

        // current thread should be in the list
        let current = ThreadInfo::current().expect("should get current thread");
        assert!(ids.contains(&current.thread_id));
    }

    #[test]
    fn test_thread_count() {
        let count = thread_count().expect("should get thread count");
        assert!(count >= 1);
    }
}