#[cfg(feature = "regex")]
use alloc::string::ToString;
use alloc::{borrow::ToOwned, rc::Rc, string::String, vec::Vec};
#[cfg(feature = "regex")]
use core::str::from_utf8;
use core::{
    cell::RefCell,
    hash::{BuildHasher, Hasher},
};

use ahash::RandomState;
use hashbrown::{hash_map::Entry, HashMap};
use libafl_bolts::{Error, HasLen};
#[cfg(feature = "regex")]
use regex::Regex;
use serde::{Deserialize, Serialize};

use crate::{corpus::CorpusId, inputs::Input};
24
/// Encodes raw bytes into an [`EncodedInput`], using a [`Tokenizer`] to split them into tokens.
pub trait InputEncoder<T>
where
    T: Tokenizer,
{
    /// Encode the given `bytes` into an [`EncodedInput`].
    fn encode(&mut self, bytes: &[u8], tokenizer: &mut T) -> Result<EncodedInput, Error>;
}
33
/// Decodes an [`EncodedInput`] back into raw bytes.
pub trait InputDecoder {
    /// Decode `input`, appending the resulting bytes to `bytes`.
    fn decode(&self, input: &EncodedInput, bytes: &mut Vec<u8>) -> Result<(), Error>;
}
39
/// Splits a byte slice into a list of string tokens.
pub trait Tokenizer {
    /// Tokenize the given `bytes` into a list of tokens.
    fn tokenize(&self, bytes: &[u8]) -> Result<Vec<String>, Error>;
}
45
/// An encoder/decoder that maps token strings to sequential numeric ids and back.
#[derive(Clone, Debug)]
pub struct TokenInputEncoderDecoder {
    // Maps each known token string to its numeric id.
    token_table: HashMap<String, u32>,
    // Reverse mapping from a numeric id back to its token string.
    id_table: HashMap<u32, String>,
    // The id that will be assigned to the next previously-unseen token.
    next_id: u32,
}
56
57impl<T> InputEncoder<T> for TokenInputEncoderDecoder
58where
59 T: Tokenizer,
60{
61 fn encode(&mut self, bytes: &[u8], tokenizer: &mut T) -> Result<EncodedInput, Error> {
62 let mut codes = vec![];
63 let tokens = tokenizer.tokenize(bytes)?;
64 for tok in tokens {
65 if let Some(id) = self.token_table.get(&tok) {
66 codes.push(*id);
67 } else {
68 self.token_table.insert(tok.clone(), self.next_id);
69 self.id_table.insert(self.next_id, tok.clone());
70 codes.push(self.next_id);
71 self.next_id += 1;
72 }
73 }
74 Ok(EncodedInput::new(codes))
75 }
76}
77
78impl InputDecoder for TokenInputEncoderDecoder {
79 fn decode(&self, input: &EncodedInput, bytes: &mut Vec<u8>) -> Result<(), Error> {
80 for id in input.codes() {
81 let tok = self
82 .id_table
83 .get(&(id % self.next_id))
84 .ok_or_else(|| Error::illegal_state(format!("Id {id} not in the decoder table")))?;
85 bytes.extend_from_slice(tok.as_bytes());
86 bytes.push(b' ');
87 }
88 Ok(())
89 }
90}
91
92impl TokenInputEncoderDecoder {
93 #[must_use]
95 pub fn new() -> Self {
96 Self {
97 token_table: HashMap::default(),
98 id_table: HashMap::default(),
99 next_id: 0,
100 }
101 }
102}
103
104impl Default for TokenInputEncoderDecoder {
105 fn default() -> Self {
106 Self::new()
107 }
108}
109
/// A naive regex-based tokenizer for C-like textual inputs.
#[cfg(feature = "regex")]
#[derive(Clone, Debug)]
pub struct NaiveTokenizer {
    // Regex matching identifier tokens.
    ident_re: Regex,
    // Regex matching comments, which are stripped before tokenization.
    comment_re: Regex,
    // Regex matching string literals, which are kept as single tokens.
    string_re: Regex,
}
121
#[cfg(feature = "regex")]
impl NaiveTokenizer {
    /// Create a tokenizer from custom identifier, comment, and
    /// string-literal regexes.
    #[must_use]
    pub fn new(ident_re: Regex, comment_re: Regex, string_re: Regex) -> Self {
        Self {
            ident_re,
            comment_re,
            string_re,
        }
    }
}
134
#[cfg(feature = "regex")]
impl Default for NaiveTokenizer {
    fn default() -> Self {
        Self {
            // C-like identifiers (alphanumerics, `_`, and `$`).
            ident_re: Regex::new("[A-Za-z0-9_$]+").unwrap(),
            // Block comments and line comments.
            // NOTE(review): the `//[^*]*` alternative stops at the first `*`,
            // not at end of line — confirm this is intended.
            comment_re: Regex::new(r"(/\*[^*]*\*/)|(//[^*]*)").unwrap(),
            // Single- or double-quoted string literals with backslash escapes.
            string_re: Regex::new("\"(\\\\|\\\\\"|[^\"])*\"|'(\\\\|\\\\'|[^'])*'").unwrap(),
        }
    }
}
148
#[cfg(feature = "regex")]
impl NaiveTokenizer {
    /// Split a single whitespace-free fragment into identifier matches and the
    /// punctuation runs between them, appending each piece to `tokens`.
    fn tokenize_fragment(&self, fragment: &str, tokens: &mut Vec<String>) {
        let mut prev = 0;
        for m in self.ident_re.find_iter(fragment) {
            if m.start() > prev {
                // Non-identifier characters between two identifier matches.
                tokens.push(fragment[prev..m.start()].to_owned());
            }
            tokens.push(m.as_str().to_owned());
            prev = m.end();
        }
        if prev < fragment.len() {
            // Trailing non-identifier characters.
            tokens.push(fragment[prev..].to_owned());
        }
    }

    /// Tokenize a region known to contain no string literals: split on
    /// whitespace, then on identifier boundaries.
    fn tokenize_unquoted(&self, region: &str, tokens: &mut Vec<String>) {
        for fragment in region.split_whitespace() {
            self.tokenize_fragment(fragment, tokens);
        }
    }
}

#[cfg(feature = "regex")]
impl Tokenizer for NaiveTokenizer {
    /// Tokenize UTF-8 `bytes`: strip comments, keep every string literal as a
    /// single token, and split everything else on whitespace and identifier
    /// boundaries.
    ///
    /// # Errors
    /// Returns an error if `bytes` is not valid UTF-8.
    fn tokenize(&self, bytes: &[u8]) -> Result<Vec<String>, Error> {
        let mut tokens = vec![];
        let string =
            from_utf8(bytes).map_err(|_| Error::illegal_argument("Invalid UTF-8".to_owned()))?;
        // Remove comments before any other processing.
        let string = self.comment_re.replace_all(string, "").to_string();
        let mut str_prev = 0;
        for str_match in self.string_re.find_iter(&string) {
            // Tokenize the region before this string literal...
            if str_match.start() > str_prev {
                self.tokenize_unquoted(&string[str_prev..str_match.start()], &mut tokens);
            }
            // ...then keep the literal itself as one token.
            tokens.push(str_match.as_str().to_owned());
            str_prev = str_match.end();
        }
        // Tokenize the tail after the last string literal.
        if str_prev < string.len() {
            self.tokenize_unquoted(&string[str_prev..], &mut tokens);
        }
        Ok(tokens)
    }
}
194
/// An input composed of a sequence of numeric token ids (codes).
#[derive(Serialize, Deserialize, Clone, Debug, Default, PartialEq, Eq, Hash)]
pub struct EncodedInput {
    // The token ids that make up this input.
    codes: Vec<u32>,
}
201
202impl Input for EncodedInput {
203 fn generate_name(&self, _id: Option<CorpusId>) -> String {
205 let mut hasher = RandomState::with_seeds(0, 0, 0, 0).build_hasher();
206 for code in &self.codes {
207 hasher.write(&code.to_le_bytes());
208 }
209 format!("{:016x}", hasher.finish())
210 }
211}
212
213impl From<EncodedInput> for Rc<RefCell<EncodedInput>> {
215 fn from(input: EncodedInput) -> Self {
216 Rc::new(RefCell::new(input))
217 }
218}
219
impl HasLen for EncodedInput {
    /// The length of this input, measured in codes (not bytes).
    #[inline]
    fn len(&self) -> usize {
        self.codes.len()
    }
}
226
227impl From<Vec<u32>> for EncodedInput {
228 fn from(codes: Vec<u32>) -> Self {
229 Self::new(codes)
230 }
231}
232
233impl From<&[u32]> for EncodedInput {
234 fn from(codes: &[u32]) -> Self {
235 Self::new(codes.to_owned())
236 }
237}
238
239impl EncodedInput {
240 #[must_use]
242 pub fn new(codes: Vec<u32>) -> Self {
243 Self { codes }
244 }
245
246 #[must_use]
248 pub fn codes(&self) -> &[u32] {
249 &self.codes
250 }
251
252 #[must_use]
254 pub fn codes_mut(&mut self) -> &mut Vec<u32> {
255 &mut self.codes
256 }
257}
258
#[cfg(feature = "regex")]
#[cfg(test)]
mod tests {
    use alloc::borrow::ToOwned;
    use core::str::from_utf8;

    use crate::inputs::encoded::{
        InputDecoder, InputEncoder, NaiveTokenizer, TokenInputEncoderDecoder,
    };

    // Round-trip test: encode a small C-like snippet, decode it back, and
    // verify that the comment is stripped while the string literal and the
    // identifiers survive, re-joined with single spaces.
    #[test]
    #[cfg_attr(all(miri, target_arch = "aarch64", target_vendor = "apple"), ignore)] fn test_input() {
        let mut t = NaiveTokenizer::default();
        let mut ed = TokenInputEncoderDecoder::new();
        let input = ed
            .encode("/* test */a = 'pippo baudo'; b=c+a\n".as_bytes(), &mut t)
            .unwrap();
        let mut bytes = vec![];
        ed.decode(&input, &mut bytes).unwrap();
        // The decoder appends a trailing space after every token, including the last.
        assert_eq!(
            from_utf8(&bytes).unwrap(),
            "a = 'pippo baudo' ; b = c + a ".to_owned()
        );
    }
}