wraith/manipulation/spoof/
gadget.rs

1//! Gadget finder for legitimate return addresses
2//!
3//! Scans system modules (ntdll, kernel32, etc.) for code gadgets that can be
4//! used as legitimate-looking return addresses for syscall spoofing.
5
6use crate::error::{Result, WraithError};
7use crate::navigation::ModuleQuery;
8use crate::structures::Peb;
9use std::collections::HashMap;
10use std::sync::OnceLock;
11
/// global gadget cache
///
/// stores the outcome (success or error) of the first `GadgetCache::build`
/// run; because the cell is written only once, a failed build is not retried
/// by later lookups
static GADGET_CACHE: OnceLock<Result<GadgetCache>> = OnceLock::new();
14
15/// initialize the global gadget cache
16pub fn init_global_cache() -> Result<()> {
17    let result = GADGET_CACHE.get_or_init(GadgetCache::build);
18    match result {
19        Ok(_) => Ok(()),
20        Err(e) => Err(WraithError::SyscallEnumerationFailed {
21            reason: format!("failed to build gadget cache: {}", e),
22        }),
23    }
24}
25
26/// get global gadget cache reference
27pub fn get_global_cache() -> Result<&'static GadgetCache> {
28    let result = GADGET_CACHE.get_or_init(GadgetCache::build);
29    match result {
30        Ok(cache) => Ok(cache),
31        Err(e) => Err(WraithError::SyscallEnumerationFailed {
32            reason: format!("failed to get gadget cache: {}", e),
33        }),
34    }
35}
36
/// type of gadget instruction sequence
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum GadgetType {
    /// jmp rbx - jump to address in rbx
    JmpRbx,
    /// jmp rax - jump to address in rax
    JmpRax,
    /// jmp rcx - jump to address in rcx
    JmpRcx,
    /// jmp rdx - jump to address in rdx
    JmpRdx,
    /// jmp r8 - jump to address in r8
    JmpR8,
    /// jmp r9 - jump to address in r9
    JmpR9,
    /// jmp [rbx] - indirect jump through rbx
    JmpIndirectRbx,
    /// jmp [rax] - indirect jump through rax
    JmpIndirectRax,
    /// call rbx - call address in rbx
    CallRbx,
    /// call rax - call address in rax
    CallRax,
    /// ret - simple return
    Ret,
    /// add rsp, N; ret - stack cleanup before return
    ///
    /// `offset` is the imm8 operand; the scanner stores 0 here for the imm32
    /// form, whose real value does not fit in a `u8`
    AddRspRet { offset: u8 },
    /// pop reg; ret - pop register then return
    ///
    /// `register` is the opcode index: 0..=7 for the plain `58+r` encodings,
    /// 8..=15 for the `41`-prefixed ones (r8..r15)
    PopRet { register: u8 },
    /// push rbx; ret - push rbx onto stack and return (for setting up returns)
    PushRbxRet,
}

impl GadgetType {
    /// machine-code bytes for this gadget type
    ///
    /// variants whose encoding depends on an operand (`AddRspRet`, `PopRet`)
    /// have no single pattern and yield an empty slice
    #[cfg(target_arch = "x86_64")]
    pub fn bytes(&self) -> &'static [u8] {
        match *self {
            // register-direct jumps
            GadgetType::JmpRbx => &[0xFF, 0xE3],
            GadgetType::JmpRax => &[0xFF, 0xE0],
            GadgetType::JmpRcx => &[0xFF, 0xE1],
            GadgetType::JmpRdx => &[0xFF, 0xE2],
            GadgetType::JmpR8 => &[0x41, 0xFF, 0xE0],
            GadgetType::JmpR9 => &[0x41, 0xFF, 0xE1],
            // memory-indirect jumps
            GadgetType::JmpIndirectRbx => &[0xFF, 0x23],
            GadgetType::JmpIndirectRax => &[0xFF, 0x20],
            // register-direct calls
            GadgetType::CallRbx => &[0xFF, 0xD3],
            GadgetType::CallRax => &[0xFF, 0xD0],
            // returns
            GadgetType::Ret => &[0xC3],
            GadgetType::PushRbxRet => &[0x53, 0xC3],
            // operand-dependent encodings are matched by dedicated scanners
            GadgetType::AddRspRet { .. } | GadgetType::PopRet { .. } => &[],
        }
    }

    /// human-readable mnemonic for this gadget type
    ///
    /// operand-carrying variants use a generic placeholder (`N`, `reg`)
    pub fn name(&self) -> &'static str {
        match *self {
            GadgetType::JmpRbx => "jmp rbx",
            GadgetType::JmpRax => "jmp rax",
            GadgetType::JmpRcx => "jmp rcx",
            GadgetType::JmpRdx => "jmp rdx",
            GadgetType::JmpR8 => "jmp r8",
            GadgetType::JmpR9 => "jmp r9",
            GadgetType::JmpIndirectRbx => "jmp [rbx]",
            GadgetType::JmpIndirectRax => "jmp [rax]",
            GadgetType::CallRbx => "call rbx",
            GadgetType::CallRax => "call rax",
            GadgetType::Ret => "ret",
            GadgetType::AddRspRet { .. } => "add rsp, N; ret",
            GadgetType::PopRet { .. } => "pop reg; ret",
            GadgetType::PushRbxRet => "push rbx; ret",
        }
    }
}
112
113/// a found gadget with its location and type
114#[derive(Debug, Clone)]
115pub struct Gadget {
116    /// absolute address of the gadget
117    pub address: usize,
118    /// type of gadget
119    pub gadget_type: GadgetType,
120    /// module containing this gadget
121    pub module_name: String,
122    /// offset within the module
123    pub module_offset: usize,
124    /// is this in a system module (more trustworthy)
125    pub is_system_module: bool,
126}
127
128impl Gadget {
129    /// check if gadget is still valid (bytes haven't changed)
130    pub fn is_valid(&self) -> bool {
131        let bytes = self.gadget_type.bytes();
132        if bytes.is_empty() {
133            return true; // variable-length gadgets need special handling
134        }
135
136        // SAFETY: we're reading from a previously-validated code address
137        let actual = unsafe { std::slice::from_raw_parts(self.address as *const u8, bytes.len()) };
138        actual == bytes
139    }
140}
141
142/// jmp-type gadget (for jumping to syscall stub)
143#[derive(Debug, Clone)]
144pub struct JmpGadget {
145    pub gadget: Gadget,
146}
147
148impl JmpGadget {
149    pub fn address(&self) -> usize {
150        self.gadget.address
151    }
152}
153
154/// ret-type gadget (for return address spoofing)
155#[derive(Debug, Clone)]
156pub struct RetGadget {
157    pub gadget: Gadget,
158    /// number of bytes the ret pops (for add rsp, N; ret patterns)
159    pub stack_adjustment: usize,
160}
161
162impl RetGadget {
163    pub fn address(&self) -> usize {
164        self.gadget.address
165    }
166}
167
168/// cache of found gadgets organized by type and module
169#[derive(Debug)]
170pub struct GadgetCache {
171    /// gadgets indexed by type
172    by_type: HashMap<GadgetType, Vec<Gadget>>,
173    /// gadgets indexed by module name (lowercase)
174    by_module: HashMap<String, Vec<Gadget>>,
175    /// preferred jmp rbx gadget in ntdll
176    preferred_jmp_rbx: Option<Gadget>,
177    /// preferred jmp rax gadget in ntdll
178    preferred_jmp_rax: Option<Gadget>,
179    /// preferred ret gadget in kernel32
180    preferred_ret: Option<Gadget>,
181}
182
183impl GadgetCache {
184    /// build gadget cache by scanning system modules
185    pub fn build() -> Result<Self> {
186        let finder = GadgetFinder::new()?;
187
188        let mut by_type: HashMap<GadgetType, Vec<Gadget>> = HashMap::new();
189        let mut by_module: HashMap<String, Vec<Gadget>> = HashMap::new();
190
191        // scan key system modules
192        let modules = ["ntdll.dll", "kernel32.dll", "kernelbase.dll"];
193
194        for module_name in modules {
195            if let Ok(gadgets) = finder.scan_module_all(module_name) {
196                for gadget in gadgets {
197                    let module_lower = gadget.module_name.to_lowercase();
198
199                    by_type
200                        .entry(gadget.gadget_type)
201                        .or_default()
202                        .push(gadget.clone());
203
204                    by_module.entry(module_lower).or_default().push(gadget);
205                }
206            }
207        }
208
209        // find preferred gadgets
210        let preferred_jmp_rbx = by_type
211            .get(&GadgetType::JmpRbx)
212            .and_then(|v| v.iter().find(|g| g.module_name.eq_ignore_ascii_case("ntdll.dll")))
213            .cloned();
214
215        let preferred_jmp_rax = by_type
216            .get(&GadgetType::JmpRax)
217            .and_then(|v| v.iter().find(|g| g.module_name.eq_ignore_ascii_case("ntdll.dll")))
218            .cloned();
219
220        let preferred_ret = by_type
221            .get(&GadgetType::Ret)
222            .and_then(|v| {
223                v.iter()
224                    .find(|g| g.module_name.eq_ignore_ascii_case("kernel32.dll"))
225            })
226            .cloned();
227
228        Ok(Self {
229            by_type,
230            by_module,
231            preferred_jmp_rbx,
232            preferred_jmp_rax,
233            preferred_ret,
234        })
235    }
236
237    /// get preferred jmp rbx gadget (in ntdll)
238    pub fn jmp_rbx(&self) -> Option<&Gadget> {
239        self.preferred_jmp_rbx.as_ref()
240    }
241
242    /// get preferred jmp rax gadget (in ntdll)
243    pub fn jmp_rax(&self) -> Option<&Gadget> {
244        self.preferred_jmp_rax.as_ref()
245    }
246
247    /// get preferred ret gadget (in kernel32)
248    pub fn ret_gadget(&self) -> Option<&Gadget> {
249        self.preferred_ret.as_ref()
250    }
251
252    /// get all gadgets of a specific type
253    pub fn get_by_type(&self, gadget_type: GadgetType) -> &[Gadget] {
254        self.by_type.get(&gadget_type).map(|v| v.as_slice()).unwrap_or(&[])
255    }
256
257    /// get all gadgets in a specific module
258    pub fn get_by_module(&self, module_name: &str) -> &[Gadget] {
259        self.by_module
260            .get(&module_name.to_lowercase())
261            .map(|v| v.as_slice())
262            .unwrap_or(&[])
263    }
264
265    /// get first available jmp gadget (tries rbx, then rax)
266    pub fn any_jmp_gadget(&self) -> Option<&Gadget> {
267        self.preferred_jmp_rbx
268            .as_ref()
269            .or(self.preferred_jmp_rax.as_ref())
270            .or_else(|| {
271                self.by_type
272                    .get(&GadgetType::JmpRbx)
273                    .and_then(|v| v.first())
274            })
275            .or_else(|| {
276                self.by_type
277                    .get(&GadgetType::JmpRax)
278                    .and_then(|v| v.first())
279            })
280    }
281}
282
/// scanner for finding gadgets in loaded modules
///
/// every lookup resolves the module through a `ModuleQuery` built over the
/// stored `Peb`
pub struct GadgetFinder {
    /// PEB handle obtained from `Peb::current()` when the finder is created
    peb: Peb,
}
287
288impl GadgetFinder {
289    /// create new gadget finder
290    pub fn new() -> Result<Self> {
291        Ok(Self {
292            peb: Peb::current()?,
293        })
294    }
295
296    /// find all jmp rbx gadgets in a module
297    pub fn find_jmp_rbx(&self, module_name: &str) -> Result<Vec<JmpGadget>> {
298        self.find_gadgets_of_type(module_name, GadgetType::JmpRbx)
299            .map(|gadgets| gadgets.into_iter().map(|g| JmpGadget { gadget: g }).collect())
300    }
301
302    /// find all jmp rax gadgets in a module
303    pub fn find_jmp_rax(&self, module_name: &str) -> Result<Vec<JmpGadget>> {
304        self.find_gadgets_of_type(module_name, GadgetType::JmpRax)
305            .map(|gadgets| gadgets.into_iter().map(|g| JmpGadget { gadget: g }).collect())
306    }
307
308    /// find all ret gadgets in a module
309    pub fn find_ret(&self, module_name: &str) -> Result<Vec<RetGadget>> {
310        self.find_gadgets_of_type(module_name, GadgetType::Ret)
311            .map(|gadgets| {
312                gadgets
313                    .into_iter()
314                    .map(|g| RetGadget {
315                        gadget: g,
316                        stack_adjustment: 0,
317                    })
318                    .collect()
319            })
320    }
321
322    /// find gadgets of a specific type in a module
323    pub fn find_gadgets_of_type(
324        &self,
325        module_name: &str,
326        gadget_type: GadgetType,
327    ) -> Result<Vec<Gadget>> {
328        let query = ModuleQuery::new(&self.peb);
329        let module = query.find_by_name(module_name)?;
330
331        let bytes = gadget_type.bytes();
332        if bytes.is_empty() {
333            return Ok(Vec::new());
334        }
335
336        let base = module.base();
337        let size = module.size();
338        let name = module.name();
339        let is_system = is_system_module(&name);
340
341        // scan for the byte pattern
342        // SAFETY: module memory is mapped and readable
343        let data = unsafe { std::slice::from_raw_parts(base as *const u8, size) };
344
345        let mut gadgets = Vec::new();
346        let pattern_len = bytes.len();
347
348        // scan for gadget bytes
349        for offset in 0..=(size.saturating_sub(pattern_len)) {
350            if &data[offset..offset + pattern_len] == bytes {
351                gadgets.push(Gadget {
352                    address: base + offset,
353                    gadget_type,
354                    module_name: name.clone(),
355                    module_offset: offset,
356                    is_system_module: is_system,
357                });
358            }
359        }
360
361        Ok(gadgets)
362    }
363
364    /// find add rsp, N; ret gadgets (for stack cleanup)
365    pub fn find_add_rsp_ret(&self, module_name: &str) -> Result<Vec<RetGadget>> {
366        let query = ModuleQuery::new(&self.peb);
367        let module = query.find_by_name(module_name)?;
368
369        let base = module.base();
370        let size = module.size();
371        let name = module.name();
372        let is_system = is_system_module(&name);
373
374        // SAFETY: module memory is mapped and readable
375        let data = unsafe { std::slice::from_raw_parts(base as *const u8, size) };
376
377        let mut gadgets = Vec::new();
378
379        // patterns for add rsp, imm8; ret
380        // 48 83 C4 XX C3 = add rsp, XX; ret (5 bytes)
381        for offset in 0..=(size.saturating_sub(5)) {
382            if data[offset] == 0x48
383                && data[offset + 1] == 0x83
384                && data[offset + 2] == 0xC4
385                && data[offset + 4] == 0xC3
386            {
387                let stack_adj = data[offset + 3] as usize;
388                gadgets.push(RetGadget {
389                    gadget: Gadget {
390                        address: base + offset,
391                        gadget_type: GadgetType::AddRspRet {
392                            offset: data[offset + 3],
393                        },
394                        module_name: name.clone(),
395                        module_offset: offset,
396                        is_system_module: is_system,
397                    },
398                    stack_adjustment: stack_adj,
399                });
400            }
401        }
402
403        // also look for add rsp, imm32; ret
404        // 48 81 C4 XX XX XX XX C3 = add rsp, XXXXXXXX; ret (8 bytes)
405        for offset in 0..=(size.saturating_sub(8)) {
406            if data[offset] == 0x48
407                && data[offset + 1] == 0x81
408                && data[offset + 2] == 0xC4
409                && data[offset + 7] == 0xC3
410            {
411                let stack_adj = u32::from_le_bytes([
412                    data[offset + 3],
413                    data[offset + 4],
414                    data[offset + 5],
415                    data[offset + 6],
416                ]) as usize;
417
418                gadgets.push(RetGadget {
419                    gadget: Gadget {
420                        address: base + offset,
421                        gadget_type: GadgetType::AddRspRet {
422                            offset: 0, // too large for u8
423                        },
424                        module_name: name.clone(),
425                        module_offset: offset,
426                        is_system_module: is_system,
427                    },
428                    stack_adjustment: stack_adj,
429                });
430            }
431        }
432
433        Ok(gadgets)
434    }
435
436    /// find pop reg; ret gadgets
437    pub fn find_pop_ret(&self, module_name: &str) -> Result<Vec<RetGadget>> {
438        let query = ModuleQuery::new(&self.peb);
439        let module = query.find_by_name(module_name)?;
440
441        let base = module.base();
442        let size = module.size();
443        let name = module.name();
444        let is_system = is_system_module(&name);
445
446        // SAFETY: module memory is mapped and readable
447        let data = unsafe { std::slice::from_raw_parts(base as *const u8, size) };
448
449        let mut gadgets = Vec::new();
450
451        // pop rax; ret = 58 C3
452        // pop rcx; ret = 59 C3
453        // pop rdx; ret = 5A C3
454        // pop rbx; ret = 5B C3
455        // pop rsp; ret = 5C C3 (dangerous, skip)
456        // pop rbp; ret = 5D C3
457        // pop rsi; ret = 5E C3
458        // pop rdi; ret = 5F C3
459        for offset in 0..=(size.saturating_sub(2)) {
460            let first = data[offset];
461            if (0x58..=0x5F).contains(&first) && first != 0x5C && data[offset + 1] == 0xC3 {
462                gadgets.push(RetGadget {
463                    gadget: Gadget {
464                        address: base + offset,
465                        gadget_type: GadgetType::PopRet {
466                            register: first - 0x58,
467                        },
468                        module_name: name.clone(),
469                        module_offset: offset,
470                        is_system_module: is_system,
471                    },
472                    stack_adjustment: 8, // one pop
473                });
474            }
475        }
476
477        // also check for REX.W pop; ret (pop r8-r15)
478        // 41 58 C3 = pop r8; ret
479        // 41 59 C3 = pop r9; ret
480        // etc.
481        for offset in 0..=(size.saturating_sub(3)) {
482            if data[offset] == 0x41
483                && (0x58..=0x5F).contains(&data[offset + 1])
484                && data[offset + 1] != 0x5C
485                && data[offset + 2] == 0xC3
486            {
487                gadgets.push(RetGadget {
488                    gadget: Gadget {
489                        address: base + offset,
490                        gadget_type: GadgetType::PopRet {
491                            register: data[offset + 1] - 0x58 + 8,
492                        },
493                        module_name: name.clone(),
494                        module_offset: offset,
495                        is_system_module: is_system,
496                    },
497                    stack_adjustment: 8,
498                });
499            }
500        }
501
502        Ok(gadgets)
503    }
504
505    /// scan a module for all gadget types
506    pub fn scan_module_all(&self, module_name: &str) -> Result<Vec<Gadget>> {
507        let mut all_gadgets = Vec::new();
508
509        // basic jmp gadgets
510        for gadget_type in [
511            GadgetType::JmpRbx,
512            GadgetType::JmpRax,
513            GadgetType::JmpRcx,
514            GadgetType::JmpRdx,
515            GadgetType::CallRbx,
516            GadgetType::CallRax,
517            GadgetType::Ret,
518            GadgetType::PushRbxRet,
519        ] {
520            if let Ok(gadgets) = self.find_gadgets_of_type(module_name, gadget_type) {
521                all_gadgets.extend(gadgets);
522            }
523        }
524
525        // add rsp, N; ret gadgets
526        if let Ok(ret_gadgets) = self.find_add_rsp_ret(module_name) {
527            all_gadgets.extend(ret_gadgets.into_iter().map(|r| r.gadget));
528        }
529
530        // pop; ret gadgets
531        if let Ok(pop_gadgets) = self.find_pop_ret(module_name) {
532            all_gadgets.extend(pop_gadgets.into_iter().map(|r| r.gadget));
533        }
534
535        Ok(all_gadgets)
536    }
537
538    /// find the best jmp gadget for syscall spoofing
539    /// prefers ntdll > kernelbase > kernel32
540    pub fn find_best_jmp_gadget(&self) -> Result<JmpGadget> {
541        // try ntdll first (most legitimate for syscalls)
542        if let Ok(gadgets) = self.find_jmp_rbx("ntdll.dll") {
543            if let Some(g) = gadgets.into_iter().next() {
544                return Ok(g);
545            }
546        }
547
548        if let Ok(gadgets) = self.find_jmp_rax("ntdll.dll") {
549            if let Some(g) = gadgets.into_iter().next() {
550                return Ok(g);
551            }
552        }
553
554        // try kernelbase
555        if let Ok(gadgets) = self.find_jmp_rbx("kernelbase.dll") {
556            if let Some(g) = gadgets.into_iter().next() {
557                return Ok(g);
558            }
559        }
560
561        // try kernel32
562        if let Ok(gadgets) = self.find_jmp_rbx("kernel32.dll") {
563            if let Some(g) = gadgets.into_iter().next() {
564                return Ok(g);
565            }
566        }
567
568        Err(WraithError::SyscallEnumerationFailed {
569            reason: "no suitable jmp gadget found".into(),
570        })
571    }
572
573    /// find a ret gadget that looks legitimate
574    pub fn find_best_ret_gadget(&self) -> Result<RetGadget> {
575        // prefer kernel32 ret gadgets (look like normal API returns)
576        if let Ok(gadgets) = self.find_ret("kernel32.dll") {
577            if let Some(g) = gadgets.into_iter().next() {
578                return Ok(g);
579            }
580        }
581
582        if let Ok(gadgets) = self.find_ret("kernelbase.dll") {
583            if let Some(g) = gadgets.into_iter().next() {
584                return Ok(g);
585            }
586        }
587
588        if let Ok(gadgets) = self.find_ret("ntdll.dll") {
589            if let Some(g) = gadgets.into_iter().next() {
590                return Ok(g);
591            }
592        }
593
594        Err(WraithError::SyscallEnumerationFailed {
595            reason: "no suitable ret gadget found".into(),
596        })
597    }
598}
599
/// check if a module is a system module
///
/// the name is lowercased and compared against a fixed allow-list of DLL
/// file names; paths or names without the `.dll` suffix do not match
fn is_system_module(name: &str) -> bool {
    matches!(
        name.to_lowercase().as_str(),
        "ntdll.dll"
            | "kernel32.dll"
            | "kernelbase.dll"
            | "user32.dll"
            | "gdi32.dll"
            | "advapi32.dll"
            | "msvcrt.dll"
            | "ws2_32.dll"
            | "ole32.dll"
            | "combase.dll"
    )
}
614
#[cfg(test)]
mod tests {
    use super::*;

    /// shared setup: every test needs a finder over the current process
    fn finder() -> GadgetFinder {
        GadgetFinder::new().expect("should create finder")
    }

    /// ntdll is expected to contain at least one valid jmp rbx gadget
    #[test]
    fn test_find_jmp_rbx_ntdll() {
        let gadgets = finder()
            .find_jmp_rbx("ntdll.dll")
            .expect("should find gadgets");

        assert!(!gadgets.is_empty(), "should find jmp rbx gadgets in ntdll");

        // spot-check the first hit
        let head = &gadgets[0].gadget;
        assert!(head.is_valid(), "gadget should be valid");
        assert!(head.is_system_module, "should be system module");
    }

    /// kernel32 is expected to contain at least one valid ret gadget
    #[test]
    fn test_find_ret_gadgets() {
        let gadgets = finder()
            .find_ret("kernel32.dll")
            .expect("should find gadgets");

        assert!(!gadgets.is_empty(), "should find ret gadgets in kernel32");

        // spot-check the first hit
        let head = &gadgets[0].gadget;
        assert!(head.is_valid(), "gadget should be valid");
    }

    /// smoke test: the add rsp scan must not crash; a failed scan is tolerated
    #[test]
    fn test_find_add_rsp_ret() {
        let Ok(gadgets) = finder().find_add_rsp_ret("ntdll.dll") else {
            return;
        };

        for candidate in gadgets.iter().take(5) {
            assert!(candidate.stack_adjustment > 0, "should have stack adjustment");
        }
    }

    /// a freshly built cache must expose at least one preferred jmp gadget
    #[test]
    fn test_gadget_cache() {
        let cache = GadgetCache::build().expect("should build cache");

        let has_preferred_jmp = cache.jmp_rbx().is_some() || cache.jmp_rax().is_some();
        assert!(has_preferred_jmp);
    }

    /// the best jmp gadget must still match its byte pattern
    #[test]
    fn test_best_jmp_gadget() {
        let best = finder()
            .find_best_jmp_gadget()
            .expect("should find gadget");

        assert!(best.gadget.is_valid(), "best gadget should be valid");
    }
}