wraith/manipulation/spoof/
gadget.rs

1//! Gadget finder for legitimate return addresses
2//!
3//! Scans system modules (ntdll, kernel32, etc.) for code gadgets that can be
4//! used as legitimate-looking return addresses for syscall spoofing.
5
6use crate::error::{Result, WraithError};
7use crate::navigation::ModuleQuery;
8use crate::structures::Peb;
9
10#[cfg(feature = "std")]
11use std::collections::HashMap;
12
13#[cfg(all(not(feature = "std"), feature = "alloc"))]
14use alloc::{collections::BTreeMap, format, string::String, vec::Vec};
15
16#[cfg(feature = "std")]
17use std::{format, string::String, vec::Vec};
18
19#[cfg(feature = "std")]
20use std::sync::OnceLock;
21
/// Process-wide gadget cache, populated lazily by the first accessor call.
///
/// The outcome of the first build is stored permanently — including a failure:
/// once `GadgetCache::build` has returned an error, every later accessor call
/// reports that same error without rebuilding.
#[cfg(feature = "std")]
static GADGET_CACHE: OnceLock<Result<GadgetCache>> = OnceLock::new();

/// Eagerly build the global gadget cache (no-op if it already exists).
///
/// # Errors
/// Returns `WraithError::SyscallEnumerationFailed` describing the stored build
/// error if the (first and only) build attempt failed.
#[cfg(feature = "std")]
pub fn init_global_cache() -> Result<()> {
    GADGET_CACHE
        .get_or_init(GadgetCache::build)
        .as_ref()
        .map(|_| ())
        .map_err(|cause| WraithError::SyscallEnumerationFailed {
            reason: format!("failed to build gadget cache: {}", cause),
        })
}

/// Borrow the global gadget cache, building it on first use.
///
/// # Errors
/// Returns `WraithError::SyscallEnumerationFailed` describing the stored build
/// error if the (first and only) build attempt failed.
#[cfg(feature = "std")]
pub fn get_global_cache() -> Result<&'static GadgetCache> {
    let stored: &'static Result<GadgetCache> = GADGET_CACHE.get_or_init(GadgetCache::build);
    stored
        .as_ref()
        .map_err(|cause| WraithError::SyscallEnumerationFailed {
            reason: format!("failed to get gadget cache: {}", cause),
        })
}
48
/// Kind of instruction sequence a [`Gadget`] consists of.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum GadgetType {
    /// `jmp rbx`
    JmpRbx,
    /// `jmp rax`
    JmpRax,
    /// `jmp rcx`
    JmpRcx,
    /// `jmp rdx`
    JmpRdx,
    /// `jmp r8`
    JmpR8,
    /// `jmp r9`
    JmpR9,
    /// `jmp [rbx]` — memory-indirect jump through the pointer in rbx
    JmpIndirectRbx,
    /// `jmp [rax]` — memory-indirect jump through the pointer in rax
    JmpIndirectRax,
    /// `call rbx`
    CallRbx,
    /// `call rax`
    CallRax,
    /// bare `ret`
    Ret,
    /// `add rsp, N; ret`; `offset` carries N when it fits in a byte
    AddRspRet { offset: u8 },
    /// `pop <reg>; ret`; `register` is the register number (0 = rax … 15 = r15)
    PopRet { register: u8 },
    /// `push rbx; ret` — effectively a jump to the address held in rbx
    PushRbxRet,
}

impl GadgetType {
    /// Fixed x86-64 encoding of this gadget.
    ///
    /// Returns an empty slice for `AddRspRet` and `PopRet`, whose encoding
    /// depends on their payload and which are located by dedicated scanners.
    #[cfg(target_arch = "x86_64")]
    pub fn bytes(&self) -> &'static [u8] {
        match *self {
            // FF /4, register-direct (ModRM mod=11).
            GadgetType::JmpRax => &[0xFF, 0xE0],
            GadgetType::JmpRcx => &[0xFF, 0xE1],
            GadgetType::JmpRdx => &[0xFF, 0xE2],
            GadgetType::JmpRbx => &[0xFF, 0xE3],
            // Same as above with a REX.B prefix selecting r8 / r9.
            GadgetType::JmpR8 => &[0x41, 0xFF, 0xE0],
            GadgetType::JmpR9 => &[0x41, 0xFF, 0xE1],
            // FF /4, memory-indirect (ModRM mod=00).
            GadgetType::JmpIndirectRax => &[0xFF, 0x20],
            GadgetType::JmpIndirectRbx => &[0xFF, 0x23],
            // FF /2, register-direct.
            GadgetType::CallRax => &[0xFF, 0xD0],
            GadgetType::CallRbx => &[0xFF, 0xD3],
            GadgetType::Ret => &[0xC3],
            GadgetType::PushRbxRet => &[0x53, 0xC3],
            GadgetType::AddRspRet { .. } | GadgetType::PopRet { .. } => &[],
        }
    }

    /// Human-readable mnemonic for this gadget kind.
    pub fn name(&self) -> &'static str {
        match *self {
            GadgetType::JmpRax => "jmp rax",
            GadgetType::JmpRbx => "jmp rbx",
            GadgetType::JmpRcx => "jmp rcx",
            GadgetType::JmpRdx => "jmp rdx",
            GadgetType::JmpR8 => "jmp r8",
            GadgetType::JmpR9 => "jmp r9",
            GadgetType::JmpIndirectRax => "jmp [rax]",
            GadgetType::JmpIndirectRbx => "jmp [rbx]",
            GadgetType::CallRax => "call rax",
            GadgetType::CallRbx => "call rbx",
            GadgetType::Ret => "ret",
            GadgetType::AddRspRet { .. } => "add rsp, N; ret",
            GadgetType::PopRet { .. } => "pop reg; ret",
            GadgetType::PushRbxRet => "push rbx; ret",
        }
    }
}
124
125/// a found gadget with its location and type
126#[derive(Debug, Clone)]
127pub struct Gadget {
128    /// absolute address of the gadget
129    pub address: usize,
130    /// type of gadget
131    pub gadget_type: GadgetType,
132    /// module containing this gadget
133    pub module_name: String,
134    /// offset within the module
135    pub module_offset: usize,
136    /// is this in a system module (more trustworthy)
137    pub is_system_module: bool,
138}
139
140impl Gadget {
141    /// check if gadget is still valid (bytes haven't changed)
142    pub fn is_valid(&self) -> bool {
143        let bytes = self.gadget_type.bytes();
144        if bytes.is_empty() {
145            return true; // variable-length gadgets need special handling
146        }
147
148        // SAFETY: we're reading from a previously-validated code address
149        let actual = unsafe { core::slice::from_raw_parts(self.address as *const u8, bytes.len()) };
150        actual == bytes
151    }
152}
153
154/// jmp-type gadget (for jumping to syscall stub)
155#[derive(Debug, Clone)]
156pub struct JmpGadget {
157    pub gadget: Gadget,
158}
159
160impl JmpGadget {
161    pub fn address(&self) -> usize {
162        self.gadget.address
163    }
164}
165
166/// ret-type gadget (for return address spoofing)
167#[derive(Debug, Clone)]
168pub struct RetGadget {
169    pub gadget: Gadget,
170    /// number of bytes the ret pops (for add rsp, N; ret patterns)
171    pub stack_adjustment: usize,
172}
173
174impl RetGadget {
175    pub fn address(&self) -> usize {
176        self.gadget.address
177    }
178}
179
/// Gadgets found in the core system modules, indexed by kind and by module.
#[cfg(feature = "std")]
#[derive(Debug)]
pub struct GadgetCache {
    /// All gadgets, grouped by their [`GadgetType`].
    by_type: HashMap<GadgetType, Vec<Gadget>>,
    /// All gadgets, grouped by lowercased module name.
    by_module: HashMap<String, Vec<Gadget>>,
    /// First `jmp rbx` found in ntdll.dll, if any.
    preferred_jmp_rbx: Option<Gadget>,
    /// First `jmp rax` found in ntdll.dll, if any.
    preferred_jmp_rax: Option<Gadget>,
    /// First bare `ret` found in kernel32.dll, if any.
    preferred_ret: Option<Gadget>,
}

#[cfg(feature = "std")]
impl GadgetCache {
    /// Scan ntdll.dll, kernel32.dll and kernelbase.dll and index every gadget.
    ///
    /// Modules that fail to scan are skipped silently, so the result may be
    /// partially (or entirely) empty.
    ///
    /// # Errors
    /// Fails only if the underlying [`GadgetFinder`] cannot be created.
    pub fn build() -> Result<Self> {
        let finder = GadgetFinder::new()?;

        let mut by_type: HashMap<GadgetType, Vec<Gadget>> = HashMap::new();
        let mut by_module: HashMap<String, Vec<Gadget>> = HashMap::new();

        let discovered = ["ntdll.dll", "kernel32.dll", "kernelbase.dll"]
            .into_iter()
            .filter_map(|module| finder.scan_module_all(module).ok())
            .flatten();

        for gadget in discovered {
            by_type
                .entry(gadget.gadget_type)
                .or_default()
                .push(gadget.clone());
            by_module
                .entry(gadget.module_name.to_lowercase())
                .or_default()
                .push(gadget);
        }

        // First gadget of `kind` whose module name matches `module` (ASCII case-insensitive).
        let first_in = |kind: GadgetType, module: &str| {
            by_type
                .get(&kind)
                .and_then(|list| {
                    list.iter()
                        .find(|g| g.module_name.eq_ignore_ascii_case(module))
                })
                .cloned()
        };

        let preferred_jmp_rbx = first_in(GadgetType::JmpRbx, "ntdll.dll");
        let preferred_jmp_rax = first_in(GadgetType::JmpRax, "ntdll.dll");
        let preferred_ret = first_in(GadgetType::Ret, "kernel32.dll");

        Ok(Self {
            by_type,
            by_module,
            preferred_jmp_rbx,
            preferred_jmp_rax,
            preferred_ret,
        })
    }

    /// Preferred `jmp rbx` gadget (ntdll.dll), if one was found.
    pub fn jmp_rbx(&self) -> Option<&Gadget> {
        self.preferred_jmp_rbx.as_ref()
    }

    /// Preferred `jmp rax` gadget (ntdll.dll), if one was found.
    pub fn jmp_rax(&self) -> Option<&Gadget> {
        self.preferred_jmp_rax.as_ref()
    }

    /// Preferred bare `ret` gadget (kernel32.dll), if one was found.
    pub fn ret_gadget(&self) -> Option<&Gadget> {
        self.preferred_ret.as_ref()
    }

    /// Every cached gadget of `gadget_type` (empty slice if none).
    pub fn get_by_type(&self, gadget_type: GadgetType) -> &[Gadget] {
        self.by_type
            .get(&gadget_type)
            .map(Vec::as_slice)
            .unwrap_or_default()
    }

    /// Every cached gadget in `module_name`, matched case-insensitively
    /// (empty slice if none).
    pub fn get_by_module(&self, module_name: &str) -> &[Gadget] {
        let key = module_name.to_lowercase();
        self.by_module
            .get(&key)
            .map(Vec::as_slice)
            .unwrap_or_default()
    }

    /// Best available jump gadget.
    ///
    /// Order: preferred `jmp rbx`, preferred `jmp rax`, then the first cached
    /// `jmp rbx` and finally the first cached `jmp rax` from any module.
    pub fn any_jmp_gadget(&self) -> Option<&Gadget> {
        self.preferred_jmp_rbx
            .as_ref()
            .or(self.preferred_jmp_rax.as_ref())
            .or_else(|| {
                [GadgetType::JmpRbx, GadgetType::JmpRax]
                    .iter()
                    .find_map(|kind| self.by_type.get(kind).and_then(|list| list.first()))
            })
    }
}
296
297/// scanner for finding gadgets in loaded modules
298pub struct GadgetFinder {
299    peb: Peb,
300}
301
302impl GadgetFinder {
303    /// create new gadget finder
304    pub fn new() -> Result<Self> {
305        Ok(Self {
306            peb: Peb::current()?,
307        })
308    }
309
310    /// find all jmp rbx gadgets in a module
311    pub fn find_jmp_rbx(&self, module_name: &str) -> Result<Vec<JmpGadget>> {
312        self.find_gadgets_of_type(module_name, GadgetType::JmpRbx)
313            .map(|gadgets| gadgets.into_iter().map(|g| JmpGadget { gadget: g }).collect())
314    }
315
316    /// find all jmp rax gadgets in a module
317    pub fn find_jmp_rax(&self, module_name: &str) -> Result<Vec<JmpGadget>> {
318        self.find_gadgets_of_type(module_name, GadgetType::JmpRax)
319            .map(|gadgets| gadgets.into_iter().map(|g| JmpGadget { gadget: g }).collect())
320    }
321
322    /// find all ret gadgets in a module
323    pub fn find_ret(&self, module_name: &str) -> Result<Vec<RetGadget>> {
324        self.find_gadgets_of_type(module_name, GadgetType::Ret)
325            .map(|gadgets| {
326                gadgets
327                    .into_iter()
328                    .map(|g| RetGadget {
329                        gadget: g,
330                        stack_adjustment: 0,
331                    })
332                    .collect()
333            })
334    }
335
336    /// find gadgets of a specific type in a module
337    pub fn find_gadgets_of_type(
338        &self,
339        module_name: &str,
340        gadget_type: GadgetType,
341    ) -> Result<Vec<Gadget>> {
342        let query = ModuleQuery::new(&self.peb);
343        let module = query.find_by_name(module_name)?;
344
345        let bytes = gadget_type.bytes();
346        if bytes.is_empty() {
347            return Ok(Vec::new());
348        }
349
350        let base = module.base();
351        let size = module.size();
352        let name = module.name();
353        let is_system = is_system_module(&name);
354
355        // scan for the byte pattern
356        // SAFETY: module memory is mapped and readable
357        let data = unsafe { core::slice::from_raw_parts(base as *const u8, size) };
358
359        let mut gadgets = Vec::new();
360        let pattern_len = bytes.len();
361
362        // scan for gadget bytes
363        for offset in 0..=(size.saturating_sub(pattern_len)) {
364            if &data[offset..offset + pattern_len] == bytes {
365                gadgets.push(Gadget {
366                    address: base + offset,
367                    gadget_type,
368                    module_name: name.clone(),
369                    module_offset: offset,
370                    is_system_module: is_system,
371                });
372            }
373        }
374
375        Ok(gadgets)
376    }
377
378    /// find add rsp, N; ret gadgets (for stack cleanup)
379    pub fn find_add_rsp_ret(&self, module_name: &str) -> Result<Vec<RetGadget>> {
380        let query = ModuleQuery::new(&self.peb);
381        let module = query.find_by_name(module_name)?;
382
383        let base = module.base();
384        let size = module.size();
385        let name = module.name();
386        let is_system = is_system_module(&name);
387
388        // SAFETY: module memory is mapped and readable
389        let data = unsafe { core::slice::from_raw_parts(base as *const u8, size) };
390
391        let mut gadgets = Vec::new();
392
393        // patterns for add rsp, imm8; ret
394        // 48 83 C4 XX C3 = add rsp, XX; ret (5 bytes)
395        for offset in 0..=(size.saturating_sub(5)) {
396            if data[offset] == 0x48
397                && data[offset + 1] == 0x83
398                && data[offset + 2] == 0xC4
399                && data[offset + 4] == 0xC3
400            {
401                let stack_adj = data[offset + 3] as usize;
402                gadgets.push(RetGadget {
403                    gadget: Gadget {
404                        address: base + offset,
405                        gadget_type: GadgetType::AddRspRet {
406                            offset: data[offset + 3],
407                        },
408                        module_name: name.clone(),
409                        module_offset: offset,
410                        is_system_module: is_system,
411                    },
412                    stack_adjustment: stack_adj,
413                });
414            }
415        }
416
417        // also look for add rsp, imm32; ret
418        // 48 81 C4 XX XX XX XX C3 = add rsp, XXXXXXXX; ret (8 bytes)
419        for offset in 0..=(size.saturating_sub(8)) {
420            if data[offset] == 0x48
421                && data[offset + 1] == 0x81
422                && data[offset + 2] == 0xC4
423                && data[offset + 7] == 0xC3
424            {
425                let stack_adj = u32::from_le_bytes([
426                    data[offset + 3],
427                    data[offset + 4],
428                    data[offset + 5],
429                    data[offset + 6],
430                ]) as usize;
431
432                gadgets.push(RetGadget {
433                    gadget: Gadget {
434                        address: base + offset,
435                        gadget_type: GadgetType::AddRspRet {
436                            offset: 0, // too large for u8
437                        },
438                        module_name: name.clone(),
439                        module_offset: offset,
440                        is_system_module: is_system,
441                    },
442                    stack_adjustment: stack_adj,
443                });
444            }
445        }
446
447        Ok(gadgets)
448    }
449
450    /// find pop reg; ret gadgets
451    pub fn find_pop_ret(&self, module_name: &str) -> Result<Vec<RetGadget>> {
452        let query = ModuleQuery::new(&self.peb);
453        let module = query.find_by_name(module_name)?;
454
455        let base = module.base();
456        let size = module.size();
457        let name = module.name();
458        let is_system = is_system_module(&name);
459
460        // SAFETY: module memory is mapped and readable
461        let data = unsafe { core::slice::from_raw_parts(base as *const u8, size) };
462
463        let mut gadgets = Vec::new();
464
465        // pop rax; ret = 58 C3
466        // pop rcx; ret = 59 C3
467        // pop rdx; ret = 5A C3
468        // pop rbx; ret = 5B C3
469        // pop rsp; ret = 5C C3 (dangerous, skip)
470        // pop rbp; ret = 5D C3
471        // pop rsi; ret = 5E C3
472        // pop rdi; ret = 5F C3
473        for offset in 0..=(size.saturating_sub(2)) {
474            let first = data[offset];
475            if (0x58..=0x5F).contains(&first) && first != 0x5C && data[offset + 1] == 0xC3 {
476                gadgets.push(RetGadget {
477                    gadget: Gadget {
478                        address: base + offset,
479                        gadget_type: GadgetType::PopRet {
480                            register: first - 0x58,
481                        },
482                        module_name: name.clone(),
483                        module_offset: offset,
484                        is_system_module: is_system,
485                    },
486                    stack_adjustment: 8, // one pop
487                });
488            }
489        }
490
491        // also check for REX.W pop; ret (pop r8-r15)
492        // 41 58 C3 = pop r8; ret
493        // 41 59 C3 = pop r9; ret
494        // etc.
495        for offset in 0..=(size.saturating_sub(3)) {
496            if data[offset] == 0x41
497                && (0x58..=0x5F).contains(&data[offset + 1])
498                && data[offset + 1] != 0x5C
499                && data[offset + 2] == 0xC3
500            {
501                gadgets.push(RetGadget {
502                    gadget: Gadget {
503                        address: base + offset,
504                        gadget_type: GadgetType::PopRet {
505                            register: data[offset + 1] - 0x58 + 8,
506                        },
507                        module_name: name.clone(),
508                        module_offset: offset,
509                        is_system_module: is_system,
510                    },
511                    stack_adjustment: 8,
512                });
513            }
514        }
515
516        Ok(gadgets)
517    }
518
519    /// scan a module for all gadget types
520    pub fn scan_module_all(&self, module_name: &str) -> Result<Vec<Gadget>> {
521        let mut all_gadgets = Vec::new();
522
523        // basic jmp gadgets
524        for gadget_type in [
525            GadgetType::JmpRbx,
526            GadgetType::JmpRax,
527            GadgetType::JmpRcx,
528            GadgetType::JmpRdx,
529            GadgetType::CallRbx,
530            GadgetType::CallRax,
531            GadgetType::Ret,
532            GadgetType::PushRbxRet,
533        ] {
534            if let Ok(gadgets) = self.find_gadgets_of_type(module_name, gadget_type) {
535                all_gadgets.extend(gadgets);
536            }
537        }
538
539        // add rsp, N; ret gadgets
540        if let Ok(ret_gadgets) = self.find_add_rsp_ret(module_name) {
541            all_gadgets.extend(ret_gadgets.into_iter().map(|r| r.gadget));
542        }
543
544        // pop; ret gadgets
545        if let Ok(pop_gadgets) = self.find_pop_ret(module_name) {
546            all_gadgets.extend(pop_gadgets.into_iter().map(|r| r.gadget));
547        }
548
549        Ok(all_gadgets)
550    }
551
552    /// find the best jmp gadget for syscall spoofing
553    /// prefers ntdll > kernelbase > kernel32
554    pub fn find_best_jmp_gadget(&self) -> Result<JmpGadget> {
555        // try ntdll first (most legitimate for syscalls)
556        if let Ok(gadgets) = self.find_jmp_rbx("ntdll.dll") {
557            if let Some(g) = gadgets.into_iter().next() {
558                return Ok(g);
559            }
560        }
561
562        if let Ok(gadgets) = self.find_jmp_rax("ntdll.dll") {
563            if let Some(g) = gadgets.into_iter().next() {
564                return Ok(g);
565            }
566        }
567
568        // try kernelbase
569        if let Ok(gadgets) = self.find_jmp_rbx("kernelbase.dll") {
570            if let Some(g) = gadgets.into_iter().next() {
571                return Ok(g);
572            }
573        }
574
575        // try kernel32
576        if let Ok(gadgets) = self.find_jmp_rbx("kernel32.dll") {
577            if let Some(g) = gadgets.into_iter().next() {
578                return Ok(g);
579            }
580        }
581
582        Err(WraithError::SyscallEnumerationFailed {
583            reason: "no suitable jmp gadget found".into(),
584        })
585    }
586
587    /// find a ret gadget that looks legitimate
588    pub fn find_best_ret_gadget(&self) -> Result<RetGadget> {
589        // prefer kernel32 ret gadgets (look like normal API returns)
590        if let Ok(gadgets) = self.find_ret("kernel32.dll") {
591            if let Some(g) = gadgets.into_iter().next() {
592                return Ok(g);
593            }
594        }
595
596        if let Ok(gadgets) = self.find_ret("kernelbase.dll") {
597            if let Some(g) = gadgets.into_iter().next() {
598                return Ok(g);
599            }
600        }
601
602        if let Ok(gadgets) = self.find_ret("ntdll.dll") {
603            if let Some(g) = gadgets.into_iter().next() {
604                return Ok(g);
605            }
606        }
607
608        Err(WraithError::SyscallEnumerationFailed {
609            reason: "no suitable ret gadget found".into(),
610        })
611    }
612}
613
/// Return `true` if `name` is one of a fixed set of core Windows DLLs.
///
/// `name` is lowercased before the lookup, so the match is case-insensitive.
fn is_system_module(name: &str) -> bool {
    const KNOWN_SYSTEM_DLLS: [&str; 10] = [
        "ntdll.dll",
        "kernel32.dll",
        "kernelbase.dll",
        "user32.dll",
        "gdi32.dll",
        "advapi32.dll",
        "msvcrt.dll",
        "ws2_32.dll",
        "ole32.dll",
        "combase.dll",
    ];

    let lowered = name.to_lowercase();
    KNOWN_SYSTEM_DLLS.contains(&lowered.as_str())
}
628
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a finder for the current process or abort the test.
    fn new_finder() -> GadgetFinder {
        GadgetFinder::new().expect("should create finder")
    }

    #[test]
    fn test_find_jmp_rbx_ntdll() {
        let gadgets = new_finder()
            .find_jmp_rbx("ntdll.dll")
            .expect("should find gadgets");

        // ntdll is expected to contain at least one `jmp rbx` byte sequence.
        assert!(!gadgets.is_empty(), "should find jmp rbx gadgets in ntdll");

        let first = &gadgets[0];
        assert!(first.gadget.is_valid(), "gadget should be valid");
        assert!(first.gadget.is_system_module, "should be system module");
    }

    #[test]
    fn test_find_ret_gadgets() {
        let gadgets = new_finder()
            .find_ret("kernel32.dll")
            .expect("should find gadgets");

        // kernel32 is expected to contain many `ret` bytes.
        assert!(!gadgets.is_empty(), "should find ret gadgets in kernel32");

        let first = &gadgets[0];
        assert!(first.gadget.is_valid(), "gadget should be valid");
    }

    #[test]
    fn test_find_add_rsp_ret() {
        let finder = new_finder();

        // Only checks that scanning works and that the first few results carry
        // a non-zero adjustment; a lookup failure is tolerated.
        let Ok(gadgets) = finder.find_add_rsp_ret("ntdll.dll") else {
            return;
        };
        for g in gadgets.iter().take(5) {
            assert!(g.stack_adjustment > 0, "should have stack adjustment");
        }
    }

    #[test]
    fn test_gadget_cache() {
        let cache = GadgetCache::build().expect("should build cache");

        // At least one preferred ntdll jump gadget should exist.
        let has_jmp = cache.jmp_rbx().is_some() || cache.jmp_rax().is_some();
        assert!(has_jmp);
    }

    #[test]
    fn test_best_jmp_gadget() {
        let best = new_finder()
            .find_best_jmp_gadget()
            .expect("should find gadget");

        assert!(best.gadget.is_valid(), "best gadget should be valid");
    }
}