azathoth_utils/
psearch.rs

1use crate::errors::{AzUtilErrorCode, AzUtilResult};
2
/// Something that can be located inside a byte slice by a [`Searcher`].
///
/// Implementors only decide whether a candidate window of bytes is a match;
/// the searcher takes care of sliding that window across the haystack.
pub trait Pattern {
    /// Reports whether `window` satisfies this pattern.
    ///
    /// When driven by a [`Searcher`], `window` is always exactly
    /// [`len`](Pattern::len) bytes long.
    fn matches(&self, window: &[u8]) -> bool;

    /// Number of bytes this pattern spans.
    fn len(&self) -> usize;

    /// Reports whether the pattern spans zero bytes.
    fn is_empty(&self) -> bool {
        matches!(self.len(), 0)
    }
}
16
/// A pattern that matches one exact sequence of bytes.
///
/// # Example
/// ```
/// use azathoth_utils::psearch::{BasePattern, Searcher};
///
/// fn main() {
///     let memory_region = b"deadbeef_and_more_deadbeef_and_the_final_deadbeef";
///     let bpattern = BasePattern::new(b"deadbeef");
///     let mut bsearcher = Searcher::new(bpattern).unwrap();
///     let offsets: Vec<_> = bsearcher.search_all(memory_region).collect();
///     println!("Found offsets: {:?}", offsets);
///     assert_eq!(offsets, [0, 18, 41]);
/// }
/// ```
#[derive(Clone, Copy, Debug)]
pub struct BasePattern<'a> {
    /// The bytes a window must equal to count as a match.
    pattern: &'a [u8],
}
35
36impl<'a> BasePattern<'a> {
37    /// Creates a new basic pattern object
38    pub fn new(pattern: &'a [u8]) -> Self {
39        Self { pattern }
40    }
41}
42impl<'a> Pattern for BasePattern<'a> {
43    #[inline(always)]
44    fn matches(&self, window: &[u8]) -> bool {
45        self.pattern == window
46    }
47
48    #[inline(always)]
49    fn len(&self) -> usize {
50        self.pattern.len()
51    }
52}
53
/// A byte pattern paired with a per-byte wildcard mask.
///
/// Each mask entry governs the pattern byte at the same index. A `0` makes
/// that position a wildcard that accepts any byte. Any non-zero value (such
/// as `1`) requires an exact match.
///
/// # Example
/// ```
/// use azathoth_utils::psearch::{MaskedPattern, Searcher};
///
/// fn main() {
///     let memory_region = b"deadbeef_and_more_deadbeef_and_the_final_deadbeef";
///     let mpattern = MaskedPattern::new(b"deadbeef", &[1, 0, 0, 1, 0, 0, 0, 1]).unwrap();
///     let mut msearcher = Searcher::new(mpattern).unwrap();
///     let offsets: Vec<_> = msearcher.search_all(memory_region).collect();
///     println!("Found offsets: {:?}", offsets);
///     assert_eq!(offsets, [0, 18, 41]);
/// }
/// ```
// Derives mirror `BasePattern` so `Searcher<MaskedPattern>` is `Debug` too.
#[derive(Debug, Clone, Copy)]
pub struct MaskedPattern<'a> {
    /// Bytes a window is compared against at non-wildcard positions.
    pattern: &'a [u8],
    /// Per-position flags where `0` marks a wildcard. It has the same length
    /// as `pattern`, which [`MaskedPattern::new`] enforces.
    mask: &'a [u8],
}
73impl<'a> MaskedPattern<'a> {
74    /// Creates a new masked pattern object
75    pub fn new(pattern: &'a [u8], mask: &'a [u8]) -> AzUtilResult<Self> {
76        if pattern.len() != mask.len() {
77            return Err(AzUtilErrorCode::HashError);
78        }
79        Ok(Self { pattern, mask })
80    }
81}
82
83impl<'a> Pattern for MaskedPattern<'a> {
84    fn matches(&self, window: &[u8]) -> bool {
85        self.pattern
86            .iter()
87            .zip(self.mask.iter())
88            .zip(window.iter())
89            .all(|((&p, &m), &w)| m == 0 || p == w)
90    }
91    fn len(&self) -> usize {
92        self.pattern.len()
93    }
94}
95
96/// A Generic searcher for any type that implements the [`Pattern`] trait
97#[derive(Default, Debug)]
98pub struct Searcher<P: Pattern> {
99    pattern: P,
100    last_result: Option<usize>,
101}
102
103impl<P: Pattern> Searcher<P> {
104    /// Creates a new pattern searcher object
105    pub fn new(pattern: P) -> AzUtilResult<Self> {
106        if pattern.is_empty() {
107            return Err(AzUtilErrorCode::HashError);
108        }
109        Ok(Self {
110            pattern,
111            last_result: None,
112        })
113    }
114
115    /// Searches a given memory region for the first occurance of the pattern
116    pub fn search(&mut self, region: &[u8]) -> Option<usize> {
117        if region.len() < self.pattern.len() {
118            return None;
119        }
120        let result = region
121            .windows(self.pattern.len())
122            .position(|window| self.pattern.matches(window));
123        self.last_result = result;
124        result
125    }
126
127    /// Returns an iterator over all non-overlapping matches of the pattern in the region.
128    pub fn search_all<'searcher, 'region>(&'searcher mut self, region: &'region [u8]) -> SearchAll<'searcher, 'region, P> {
129        SearchAll {
130            searcher: self,
131            region,
132            current_offset: 0
133        }
134    }
135
136    /// Gets the last found result, if any.
137    #[inline(always)]
138    pub fn result(&self) -> Option<usize> {
139        self.last_result
140    }
141
142    /// Resets the context result for reuse with the same pattern.
143    #[inline(always)]
144    pub fn reset(&mut self) {
145        self.last_result = None;
146    }
147
148    /// Returns a reference to the underlying pattern.
149    pub fn pattern(&self) -> &P {
150        &self.pattern
151    }
152
153    /// Resets the searcher and sets a new pattern
154    pub fn set_pattern(&mut self, pattern: P) {
155        self.pattern = pattern;
156        self.reset();
157    }
158}
159
160/// Iterator over all matches in a region. Produced by [`Searcher::search_all`]
161pub struct SearchAll<'searcher, 'region, P: Pattern> {
162    searcher: &'searcher mut Searcher<P>,
163    region: &'region [u8],
164    current_offset: usize
165}
166
167impl<'searcher, 'region, P: Pattern> Iterator for SearchAll<'searcher, 'region, P> {
168    type Item = usize;
169    fn next(&mut self) -> Option<Self::Item> {
170        if self.current_offset + self.searcher.pattern.len() > self.region.len() {
171            return None;
172        }
173
174        let remaining_region = &self.region[self.current_offset..];
175        match self.searcher.search(remaining_region) {
176            Some(found) => {
177                let pos = self.current_offset + found;
178                self.current_offset += found + 1;
179                Some(pos)
180            }
181            None => {
182                self.current_offset += self.region.len();
183                None
184            }
185        }
186    }
187}