1use crate::error::{Result, WraithError};
7use crate::navigation::ModuleQuery;
8use crate::structures::Peb;
9
10#[cfg(feature = "std")]
11use std::collections::HashMap;
12
13#[cfg(all(not(feature = "std"), feature = "alloc"))]
14use alloc::{collections::BTreeMap, format, string::String, vec::Vec};
15
16#[cfg(feature = "std")]
17use std::{format, string::String, vec::Vec};
18
19#[cfg(feature = "std")]
20use std::sync::OnceLock;
21
/// Process-wide gadget cache shared by `init_global_cache` and `get_global_cache`.
///
/// Built lazily by the first caller via `GadgetCache::build`. The outcome of
/// that first build — success *or* error — is stored for the life of the
/// process; a failed build is never retried.
#[cfg(feature = "std")]
static GADGET_CACHE: OnceLock<Result<GadgetCache>> = OnceLock::new();
24
/// Builds the process-wide gadget cache if no earlier call has done so.
///
/// # Errors
///
/// Returns `WraithError::SyscallEnumerationFailed` wrapping the build error when
/// the cache could not be built (on this call or on the earlier call whose
/// result was cached).
#[cfg(feature = "std")]
pub fn init_global_cache() -> Result<()> {
    if let Err(build_error) = GADGET_CACHE.get_or_init(GadgetCache::build) {
        return Err(WraithError::SyscallEnumerationFailed {
            reason: format!("failed to build gadget cache: {}", build_error),
        });
    }
    Ok(())
}
36
/// Returns the process-wide gadget cache, building it on first use.
///
/// # Errors
///
/// Returns `WraithError::SyscallEnumerationFailed` wrapping the build error when
/// the cache could not be built. The failure is cached, so every later call
/// fails the same way.
#[cfg(feature = "std")]
pub fn get_global_cache() -> Result<&'static GadgetCache> {
    GADGET_CACHE
        .get_or_init(GadgetCache::build)
        .as_ref()
        .map_err(|cause| WraithError::SyscallEnumerationFailed {
            reason: format!("failed to get gadget cache: {}", cause),
        })
}
48
/// Kind of code-reuse gadget a `Gadget` points at.
///
/// Most variants stand for one fixed x86-64 encoding (see [`GadgetType::bytes`]).
/// `AddRspRet` and `PopRet` describe families of encodings and carry the
/// varying operand instead.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum GadgetType {
    /// `jmp rbx` (`FF E3`).
    JmpRbx,
    /// `jmp rax` (`FF E0`).
    JmpRax,
    /// `jmp rcx` (`FF E1`).
    JmpRcx,
    /// `jmp rdx` (`FF E2`).
    JmpRdx,
    /// `jmp r8` (`41 FF E0`).
    JmpR8,
    /// `jmp r9` (`41 FF E1`).
    JmpR9,
    /// `jmp [rbx]` (`FF 23`).
    JmpIndirectRbx,
    /// `jmp [rax]` (`FF 20`).
    JmpIndirectRax,
    /// `call rbx` (`FF D3`).
    CallRbx,
    /// `call rax` (`FF D0`).
    CallRax,
    /// `ret` (`C3`).
    Ret,
    /// `add rsp, N; ret`; `offset` holds the 8-bit immediate `N`.
    AddRspRet { offset: u8 },
    /// `pop reg; ret`; `register` is the x86 register number (0-7 without REX.B, 8-15 with it).
    PopRet { register: u8 },
    /// `push rbx; ret` (`53 C3`).
    PushRbxRet,
}

impl GadgetType {
    /// Machine-code encoding of this gadget.
    ///
    /// Returns an empty slice for `AddRspRet` and `PopRet`, whose encoding
    /// depends on their operand.
    #[cfg(target_arch = "x86_64")]
    pub fn bytes(&self) -> &'static [u8] {
        let encoding: &'static [u8] = match *self {
            Self::JmpRbx => &[0xFF, 0xE3],
            Self::JmpRax => &[0xFF, 0xE0],
            Self::JmpRcx => &[0xFF, 0xE1],
            Self::JmpRdx => &[0xFF, 0xE2],
            Self::JmpR8 => &[0x41, 0xFF, 0xE0],
            Self::JmpR9 => &[0x41, 0xFF, 0xE1],
            Self::JmpIndirectRbx => &[0xFF, 0x23],
            Self::JmpIndirectRax => &[0xFF, 0x20],
            Self::CallRbx => &[0xFF, 0xD3],
            Self::CallRax => &[0xFF, 0xD0],
            Self::Ret => &[0xC3],
            Self::AddRspRet { .. } | Self::PopRet { .. } => &[],
            Self::PushRbxRet => &[0x53, 0xC3],
        };
        encoding
    }

    /// Human-readable assembly mnemonic for this gadget kind.
    pub fn name(&self) -> &'static str {
        let mnemonic = match *self {
            Self::JmpRbx => "jmp rbx",
            Self::JmpRax => "jmp rax",
            Self::JmpRcx => "jmp rcx",
            Self::JmpRdx => "jmp rdx",
            Self::JmpR8 => "jmp r8",
            Self::JmpR9 => "jmp r9",
            Self::JmpIndirectRbx => "jmp [rbx]",
            Self::JmpIndirectRax => "jmp [rax]",
            Self::CallRbx => "call rbx",
            Self::CallRax => "call rax",
            Self::Ret => "ret",
            Self::AddRspRet { .. } => "add rsp, N; ret",
            Self::PopRet { .. } => "pop reg; ret",
            Self::PushRbxRet => "push rbx; ret",
        };
        mnemonic
    }
}
124
125#[derive(Debug, Clone)]
127pub struct Gadget {
128 pub address: usize,
130 pub gadget_type: GadgetType,
132 pub module_name: String,
134 pub module_offset: usize,
136 pub is_system_module: bool,
138}
139
140impl Gadget {
141 pub fn is_valid(&self) -> bool {
143 let bytes = self.gadget_type.bytes();
144 if bytes.is_empty() {
145 return true; }
147
148 let actual = unsafe { core::slice::from_raw_parts(self.address as *const u8, bytes.len()) };
150 actual == bytes
151 }
152}
153
154#[derive(Debug, Clone)]
156pub struct JmpGadget {
157 pub gadget: Gadget,
158}
159
160impl JmpGadget {
161 pub fn address(&self) -> usize {
162 self.gadget.address
163 }
164}
165
166#[derive(Debug, Clone)]
168pub struct RetGadget {
169 pub gadget: Gadget,
170 pub stack_adjustment: usize,
172}
173
174impl RetGadget {
175 pub fn address(&self) -> usize {
176 self.gadget.address
177 }
178}
179
/// Gadgets pre-scanned from `ntdll.dll`, `kernel32.dll` and `kernelbase.dll`,
/// indexed by gadget type and by lower-cased module name.
#[cfg(feature = "std")]
#[derive(Debug)]
pub struct GadgetCache {
    /// Every cached gadget, grouped by type.
    by_type: HashMap<GadgetType, Vec<Gadget>>,
    /// Every cached gadget, grouped by lower-cased module name.
    by_module: HashMap<String, Vec<Gadget>>,
    /// First `jmp rbx` found in ntdll.dll, if any.
    preferred_jmp_rbx: Option<Gadget>,
    /// First `jmp rax` found in ntdll.dll, if any.
    preferred_jmp_rax: Option<Gadget>,
    /// First `ret` found in kernel32.dll, if any.
    preferred_ret: Option<Gadget>,
}

#[cfg(feature = "std")]
impl GadgetCache {
    /// Scans the core system modules and indexes every gadget found.
    ///
    /// Modules that fail to scan are skipped silently.
    ///
    /// # Errors
    ///
    /// Fails only if a `GadgetFinder` cannot be created.
    pub fn build() -> Result<Self> {
        let finder = GadgetFinder::new()?;
        let modules = ["ntdll.dll", "kernel32.dll", "kernelbase.dll"];

        let mut by_type: HashMap<GadgetType, Vec<Gadget>> = HashMap::new();
        let mut by_module: HashMap<String, Vec<Gadget>> = HashMap::new();

        // Best-effort: a module that cannot be scanned contributes nothing.
        let scanned = modules
            .iter()
            .filter_map(|module| finder.scan_module_all(module).ok());
        for gadget in scanned.flatten() {
            by_type
                .entry(gadget.gadget_type)
                .or_default()
                .push(gadget.clone());
            by_module
                .entry(gadget.module_name.to_lowercase())
                .or_default()
                .push(gadget);
        }

        // First gadget of `kind` whose module name matches `module` (ASCII case-insensitive).
        let first_in = |kind: GadgetType, module: &str| -> Option<Gadget> {
            by_type
                .get(&kind)?
                .iter()
                .find(|candidate| candidate.module_name.eq_ignore_ascii_case(module))
                .cloned()
        };
        let preferred_jmp_rbx = first_in(GadgetType::JmpRbx, "ntdll.dll");
        let preferred_jmp_rax = first_in(GadgetType::JmpRax, "ntdll.dll");
        let preferred_ret = first_in(GadgetType::Ret, "kernel32.dll");

        Ok(Self {
            by_type,
            by_module,
            preferred_jmp_rbx,
            preferred_jmp_rax,
            preferred_ret,
        })
    }

    /// Preferred `jmp rbx` gadget (from ntdll.dll), if one was found.
    pub fn jmp_rbx(&self) -> Option<&Gadget> {
        self.preferred_jmp_rbx.as_ref()
    }

    /// Preferred `jmp rax` gadget (from ntdll.dll), if one was found.
    pub fn jmp_rax(&self) -> Option<&Gadget> {
        self.preferred_jmp_rax.as_ref()
    }

    /// Preferred `ret` gadget (from kernel32.dll), if one was found.
    pub fn ret_gadget(&self) -> Option<&Gadget> {
        self.preferred_ret.as_ref()
    }

    /// All cached gadgets of `gadget_type`; empty if none were found.
    pub fn get_by_type(&self, gadget_type: GadgetType) -> &[Gadget] {
        self.by_type
            .get(&gadget_type)
            .map(Vec::as_slice)
            .unwrap_or_default()
    }

    /// All cached gadgets from `module_name` (case-insensitive); empty if none.
    pub fn get_by_module(&self, module_name: &str) -> &[Gadget] {
        let key = module_name.to_lowercase();
        self.by_module
            .get(&key)
            .map(Vec::as_slice)
            .unwrap_or_default()
    }

    /// Any usable register-jump gadget.
    ///
    /// The preferred ntdll `jmp rbx` / `jmp rax` are tried first. After that,
    /// the first cached `jmp rbx`, then `jmp rax`, from any module.
    pub fn any_jmp_gadget(&self) -> Option<&Gadget> {
        if let Some(preferred) = self
            .preferred_jmp_rbx
            .as_ref()
            .or(self.preferred_jmp_rax.as_ref())
        {
            return Some(preferred);
        }
        [GadgetType::JmpRbx, GadgetType::JmpRax]
            .iter()
            .find_map(|kind| self.by_type.get(kind).and_then(|list| list.first()))
    }
}
296
297pub struct GadgetFinder {
299 peb: Peb,
300}
301
302impl GadgetFinder {
303 pub fn new() -> Result<Self> {
305 Ok(Self {
306 peb: Peb::current()?,
307 })
308 }
309
310 pub fn find_jmp_rbx(&self, module_name: &str) -> Result<Vec<JmpGadget>> {
312 self.find_gadgets_of_type(module_name, GadgetType::JmpRbx)
313 .map(|gadgets| gadgets.into_iter().map(|g| JmpGadget { gadget: g }).collect())
314 }
315
316 pub fn find_jmp_rax(&self, module_name: &str) -> Result<Vec<JmpGadget>> {
318 self.find_gadgets_of_type(module_name, GadgetType::JmpRax)
319 .map(|gadgets| gadgets.into_iter().map(|g| JmpGadget { gadget: g }).collect())
320 }
321
322 pub fn find_ret(&self, module_name: &str) -> Result<Vec<RetGadget>> {
324 self.find_gadgets_of_type(module_name, GadgetType::Ret)
325 .map(|gadgets| {
326 gadgets
327 .into_iter()
328 .map(|g| RetGadget {
329 gadget: g,
330 stack_adjustment: 0,
331 })
332 .collect()
333 })
334 }
335
336 pub fn find_gadgets_of_type(
338 &self,
339 module_name: &str,
340 gadget_type: GadgetType,
341 ) -> Result<Vec<Gadget>> {
342 let query = ModuleQuery::new(&self.peb);
343 let module = query.find_by_name(module_name)?;
344
345 let bytes = gadget_type.bytes();
346 if bytes.is_empty() {
347 return Ok(Vec::new());
348 }
349
350 let base = module.base();
351 let size = module.size();
352 let name = module.name();
353 let is_system = is_system_module(&name);
354
355 let data = unsafe { core::slice::from_raw_parts(base as *const u8, size) };
358
359 let mut gadgets = Vec::new();
360 let pattern_len = bytes.len();
361
362 for offset in 0..=(size.saturating_sub(pattern_len)) {
364 if &data[offset..offset + pattern_len] == bytes {
365 gadgets.push(Gadget {
366 address: base + offset,
367 gadget_type,
368 module_name: name.clone(),
369 module_offset: offset,
370 is_system_module: is_system,
371 });
372 }
373 }
374
375 Ok(gadgets)
376 }
377
378 pub fn find_add_rsp_ret(&self, module_name: &str) -> Result<Vec<RetGadget>> {
380 let query = ModuleQuery::new(&self.peb);
381 let module = query.find_by_name(module_name)?;
382
383 let base = module.base();
384 let size = module.size();
385 let name = module.name();
386 let is_system = is_system_module(&name);
387
388 let data = unsafe { core::slice::from_raw_parts(base as *const u8, size) };
390
391 let mut gadgets = Vec::new();
392
393 for offset in 0..=(size.saturating_sub(5)) {
396 if data[offset] == 0x48
397 && data[offset + 1] == 0x83
398 && data[offset + 2] == 0xC4
399 && data[offset + 4] == 0xC3
400 {
401 let stack_adj = data[offset + 3] as usize;
402 gadgets.push(RetGadget {
403 gadget: Gadget {
404 address: base + offset,
405 gadget_type: GadgetType::AddRspRet {
406 offset: data[offset + 3],
407 },
408 module_name: name.clone(),
409 module_offset: offset,
410 is_system_module: is_system,
411 },
412 stack_adjustment: stack_adj,
413 });
414 }
415 }
416
417 for offset in 0..=(size.saturating_sub(8)) {
420 if data[offset] == 0x48
421 && data[offset + 1] == 0x81
422 && data[offset + 2] == 0xC4
423 && data[offset + 7] == 0xC3
424 {
425 let stack_adj = u32::from_le_bytes([
426 data[offset + 3],
427 data[offset + 4],
428 data[offset + 5],
429 data[offset + 6],
430 ]) as usize;
431
432 gadgets.push(RetGadget {
433 gadget: Gadget {
434 address: base + offset,
435 gadget_type: GadgetType::AddRspRet {
436 offset: 0, },
438 module_name: name.clone(),
439 module_offset: offset,
440 is_system_module: is_system,
441 },
442 stack_adjustment: stack_adj,
443 });
444 }
445 }
446
447 Ok(gadgets)
448 }
449
450 pub fn find_pop_ret(&self, module_name: &str) -> Result<Vec<RetGadget>> {
452 let query = ModuleQuery::new(&self.peb);
453 let module = query.find_by_name(module_name)?;
454
455 let base = module.base();
456 let size = module.size();
457 let name = module.name();
458 let is_system = is_system_module(&name);
459
460 let data = unsafe { core::slice::from_raw_parts(base as *const u8, size) };
462
463 let mut gadgets = Vec::new();
464
465 for offset in 0..=(size.saturating_sub(2)) {
474 let first = data[offset];
475 if (0x58..=0x5F).contains(&first) && first != 0x5C && data[offset + 1] == 0xC3 {
476 gadgets.push(RetGadget {
477 gadget: Gadget {
478 address: base + offset,
479 gadget_type: GadgetType::PopRet {
480 register: first - 0x58,
481 },
482 module_name: name.clone(),
483 module_offset: offset,
484 is_system_module: is_system,
485 },
486 stack_adjustment: 8, });
488 }
489 }
490
491 for offset in 0..=(size.saturating_sub(3)) {
496 if data[offset] == 0x41
497 && (0x58..=0x5F).contains(&data[offset + 1])
498 && data[offset + 1] != 0x5C
499 && data[offset + 2] == 0xC3
500 {
501 gadgets.push(RetGadget {
502 gadget: Gadget {
503 address: base + offset,
504 gadget_type: GadgetType::PopRet {
505 register: data[offset + 1] - 0x58 + 8,
506 },
507 module_name: name.clone(),
508 module_offset: offset,
509 is_system_module: is_system,
510 },
511 stack_adjustment: 8,
512 });
513 }
514 }
515
516 Ok(gadgets)
517 }
518
519 pub fn scan_module_all(&self, module_name: &str) -> Result<Vec<Gadget>> {
521 let mut all_gadgets = Vec::new();
522
523 for gadget_type in [
525 GadgetType::JmpRbx,
526 GadgetType::JmpRax,
527 GadgetType::JmpRcx,
528 GadgetType::JmpRdx,
529 GadgetType::CallRbx,
530 GadgetType::CallRax,
531 GadgetType::Ret,
532 GadgetType::PushRbxRet,
533 ] {
534 if let Ok(gadgets) = self.find_gadgets_of_type(module_name, gadget_type) {
535 all_gadgets.extend(gadgets);
536 }
537 }
538
539 if let Ok(ret_gadgets) = self.find_add_rsp_ret(module_name) {
541 all_gadgets.extend(ret_gadgets.into_iter().map(|r| r.gadget));
542 }
543
544 if let Ok(pop_gadgets) = self.find_pop_ret(module_name) {
546 all_gadgets.extend(pop_gadgets.into_iter().map(|r| r.gadget));
547 }
548
549 Ok(all_gadgets)
550 }
551
552 pub fn find_best_jmp_gadget(&self) -> Result<JmpGadget> {
555 if let Ok(gadgets) = self.find_jmp_rbx("ntdll.dll") {
557 if let Some(g) = gadgets.into_iter().next() {
558 return Ok(g);
559 }
560 }
561
562 if let Ok(gadgets) = self.find_jmp_rax("ntdll.dll") {
563 if let Some(g) = gadgets.into_iter().next() {
564 return Ok(g);
565 }
566 }
567
568 if let Ok(gadgets) = self.find_jmp_rbx("kernelbase.dll") {
570 if let Some(g) = gadgets.into_iter().next() {
571 return Ok(g);
572 }
573 }
574
575 if let Ok(gadgets) = self.find_jmp_rbx("kernel32.dll") {
577 if let Some(g) = gadgets.into_iter().next() {
578 return Ok(g);
579 }
580 }
581
582 Err(WraithError::SyscallEnumerationFailed {
583 reason: "no suitable jmp gadget found".into(),
584 })
585 }
586
587 pub fn find_best_ret_gadget(&self) -> Result<RetGadget> {
589 if let Ok(gadgets) = self.find_ret("kernel32.dll") {
591 if let Some(g) = gadgets.into_iter().next() {
592 return Ok(g);
593 }
594 }
595
596 if let Ok(gadgets) = self.find_ret("kernelbase.dll") {
597 if let Some(g) = gadgets.into_iter().next() {
598 return Ok(g);
599 }
600 }
601
602 if let Ok(gadgets) = self.find_ret("ntdll.dll") {
603 if let Some(g) = gadgets.into_iter().next() {
604 return Ok(g);
605 }
606 }
607
608 Err(WraithError::SyscallEnumerationFailed {
609 reason: "no suitable ret gadget found".into(),
610 })
611 }
612}
613
/// Returns `true` if `name`, lower-cased, is one of a fixed set of core
/// Windows system DLL names.
fn is_system_module(name: &str) -> bool {
    const SYSTEM_MODULES: [&str; 10] = [
        "ntdll.dll",
        "kernel32.dll",
        "kernelbase.dll",
        "user32.dll",
        "gdi32.dll",
        "advapi32.dll",
        "msvcrt.dll",
        "ws2_32.dll",
        "ole32.dll",
        "combase.dll",
    ];
    let lowered = name.to_lowercase();
    SYSTEM_MODULES.iter().any(|&known| known == lowered)
}
628
#[cfg(test)]
mod tests {
    //! These tests scan modules loaded into the test process itself, so they
    //! require `ntdll.dll` and `kernel32.dll` to be loaded.
    use super::*;

    #[test]
    fn test_find_jmp_rbx_ntdll() {
        let finder = GadgetFinder::new().expect("should create finder");
        let gadgets = finder.find_jmp_rbx("ntdll.dll").expect("should find gadgets");

        assert!(!gadgets.is_empty(), "should find jmp rbx gadgets in ntdll");

        // The bytes at the reported address must still match the encoding.
        let first = &gadgets[0];
        assert!(first.gadget.is_valid(), "gadget should be valid");
        assert!(first.gadget.is_system_module, "should be system module");
    }

    #[test]
    fn test_find_ret_gadgets() {
        let finder = GadgetFinder::new().expect("should create finder");
        let gadgets = finder.find_ret("kernel32.dll").expect("should find gadgets");

        assert!(!gadgets.is_empty(), "should find ret gadgets in kernel32");

        let first = &gadgets[0];
        assert!(first.gadget.is_valid(), "gadget should be valid");
    }

    #[test]
    fn test_find_add_rsp_ret() {
        let finder = GadgetFinder::new().expect("should create finder");

        // Best-effort: the module may legitimately contain no such gadgets.
        if let Ok(gadgets) = finder.find_add_rsp_ret("ntdll.dll") {
            for g in gadgets.iter().take(5) {
                assert!(g.stack_adjustment > 0, "should have stack adjustment");
            }
        }
    }

    #[test]
    fn test_find_pop_ret() {
        let finder = GadgetFinder::new().expect("should create finder");

        if let Ok(gadgets) = finder.find_pop_ret("ntdll.dll") {
            for g in gadgets.iter().take(5) {
                assert_eq!(g.stack_adjustment, 8, "pop reg; ret pops exactly one qword");
            }
        }
    }

    // `GadgetCache` is only defined with the `std` feature; without this gate the
    // test module fails to compile in `alloc`-only builds.
    #[cfg(feature = "std")]
    #[test]
    fn test_gadget_cache() {
        let cache = GadgetCache::build().expect("should build cache");

        assert!(cache.jmp_rbx().is_some() || cache.jmp_rax().is_some());
    }

    #[test]
    fn test_best_jmp_gadget() {
        let finder = GadgetFinder::new().expect("should create finder");
        let gadget = finder.find_best_jmp_gadget().expect("should find gadget");

        assert!(gadget.gadget.is_valid(), "best gadget should be valid");
    }
}
687}