1use serde::{Deserialize, Serialize};
37use wafrift_encoding::Strategy;
38
39#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
43pub struct EncodingChain {
44 pub strategies: Vec<Strategy>,
45}
46
47impl EncodingChain {
48 #[must_use]
52 pub fn to_chain_names(&self) -> Vec<String> {
53 self.strategies
54 .iter()
55 .map(|s| s.as_str().to_string())
56 .collect()
57 }
58
59 #[must_use]
61 pub fn depth(&self) -> usize {
62 self.strategies.len()
63 }
64}
65
66#[derive(Debug, Clone)]
71pub struct LatticeSearch {
72 pub strategies: Vec<Strategy>,
76 pub min_depth: usize,
79 pub max_depth: usize,
82 pub skip_consecutive_dupes: bool,
87 pub max_chains: usize,
91}
92
93impl LatticeSearch {
94 #[must_use]
96 pub fn new(strategies: Vec<Strategy>) -> Self {
97 Self {
98 strategies,
99 min_depth: 1,
100 max_depth: 3,
101 skip_consecutive_dupes: true,
102 max_chains: 0,
103 }
104 }
105
106 #[must_use]
108 pub fn with_max_depth(mut self, d: usize) -> Self {
109 self.max_depth = d;
110 self
111 }
112
113 #[must_use]
115 pub fn with_min_depth(mut self, d: usize) -> Self {
116 self.min_depth = d;
117 self
118 }
119
120 #[must_use]
122 pub fn with_max_chains(mut self, n: usize) -> Self {
123 self.max_chains = n;
124 self
125 }
126
127 #[must_use]
129 pub fn allowing_consecutive_dupes(mut self) -> Self {
130 self.skip_consecutive_dupes = false;
131 self
132 }
133
134 #[must_use]
138 pub fn enumerate_chains(&self) -> Vec<EncodingChain> {
139 let mut out = vec![];
140 if self.strategies.is_empty() || self.min_depth == 0 || self.max_depth < self.min_depth {
141 return out;
142 }
143 for depth in self.min_depth..=self.max_depth {
144 self.enumerate_at_depth(depth, &mut Vec::with_capacity(depth), &mut out);
145 if self.max_chains > 0 && out.len() >= self.max_chains {
146 out.truncate(self.max_chains);
147 return out;
148 }
149 }
150 out
151 }
152
153 fn enumerate_at_depth(
154 &self,
155 remaining: usize,
156 prefix: &mut Vec<Strategy>,
157 out: &mut Vec<EncodingChain>,
158 ) {
159 if self.max_chains > 0 && out.len() >= self.max_chains {
160 return;
161 }
162 if remaining == 0 {
163 out.push(EncodingChain {
164 strategies: prefix.clone(),
165 });
166 return;
167 }
168 for &s in &self.strategies {
169 if self.skip_consecutive_dupes
170 && let Some(last) = prefix.last()
171 && *last == s
172 {
173 continue;
174 }
175 prefix.push(s);
176 self.enumerate_at_depth(remaining - 1, prefix, out);
177 prefix.pop();
178 }
179 }
180
181 #[must_use]
184 pub fn estimated_chain_count(&self) -> usize {
185 if self.strategies.is_empty() || self.min_depth == 0 || self.max_depth < self.min_depth {
186 return 0;
187 }
188 let n = self.strategies.len();
189 let mut total = 0usize;
190 for depth in self.min_depth..=self.max_depth {
191 let count = if self.skip_consecutive_dupes && n >= 2 {
192 n * (n - 1).saturating_pow((depth - 1) as u32)
196 } else {
197 n.saturating_pow(depth as u32)
198 };
199 total = total.saturating_add(count);
200 if self.max_chains > 0 && total >= self.max_chains {
201 return self.max_chains;
202 }
203 }
204 total
205 }
206}
207
208pub fn apply_chain(payload: &[u8], chain: &EncodingChain) -> Result<String, ChainApplyError> {
212 let mut current: String = std::str::from_utf8(payload)
213 .map_err(|_| ChainApplyError::InvalidUtf8)?
214 .to_string();
215 for &strategy in &chain.strategies {
216 match wafrift_encoding::encode(¤t, strategy) {
217 Ok(encoded) => current = encoded,
218 Err(e) => {
219 return Err(ChainApplyError::EncoderRejected(format!(
220 "{strategy:?}: {e}"
221 )));
222 }
223 }
224 }
225 Ok(current)
226}
227
/// Reasons [`apply_chain`] can fail.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum ChainApplyError {
    /// The input payload was not valid UTF-8, so no encoder was run.
    InvalidUtf8,
    /// An encoder in the chain returned an error; the string names the
    /// offending strategy and the encoder's message.
    EncoderRejected(String),
}
237
238impl std::fmt::Display for ChainApplyError {
239 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
240 match self {
241 Self::InvalidUtf8 => f.write_str("input is not valid UTF-8"),
242 Self::EncoderRejected(s) => write!(f, "encoder rejected: {s}"),
243 }
244 }
245}
246
247impl std::error::Error for ChainApplyError {}
248
249#[must_use]
253pub fn shallow_lattice() -> LatticeSearch {
254 LatticeSearch::new(wafrift_encoding::all_strategies().to_vec()).with_max_depth(2)
255}
256
#[cfg(test)]
mod tests {
    use super::*;
    use wafrift_encoding::Strategy;

    fn three_strategies() -> Vec<Strategy> {
        vec![
            Strategy::UrlEncode,
            Strategy::UnicodeEncode,
            Strategy::CaseAlternation,
        ]
    }

    #[test]
    fn empty_palette_yields_no_chains() {
        let s = LatticeSearch::new(vec![]);
        assert_eq!(s.enumerate_chains().len(), 0);
        assert_eq!(s.estimated_chain_count(), 0);
    }

    #[test]
    fn depth_1_one_chain_per_strategy() {
        let s = LatticeSearch::new(three_strategies()).with_max_depth(1);
        let chains = s.enumerate_chains();
        assert_eq!(chains.len(), 3);
        for c in &chains {
            assert_eq!(c.depth(), 1);
        }
    }

    #[test]
    fn depth_2_no_consecutive_dupes_default() {
        let s = LatticeSearch::new(three_strategies()).with_max_depth(2);
        // 3 chains of depth 1 + 3 * 2 chains of depth 2.
        assert_eq!(s.estimated_chain_count(), 9);
        let chains = s.enumerate_chains();
        assert_eq!(chains.len(), 9);
        for c in &chains {
            // No strategy may appear in two adjacent slots.
            for w in c.strategies.windows(2) {
                assert_ne!(w[0], w[1]);
            }
        }
    }

    #[test]
    fn allowing_consecutive_dupes_gives_full_product() {
        let s = LatticeSearch::new(three_strategies())
            .with_max_depth(2)
            .allowing_consecutive_dupes();
        // 3 + 3^2.
        assert_eq!(s.estimated_chain_count(), 12);
        assert_eq!(s.enumerate_chains().len(), 12);
    }

    #[test]
    fn depth_3_count_correct_with_dedup() {
        let s = LatticeSearch::new(three_strategies())
            .with_min_depth(3)
            .with_max_depth(3);
        // 3 * 2 * 2.
        assert_eq!(s.estimated_chain_count(), 12);
        assert_eq!(s.enumerate_chains().len(), 12);
    }

    #[test]
    fn enumeration_is_deterministic() {
        let s = LatticeSearch::new(three_strategies()).with_max_depth(3);
        let a = s.enumerate_chains();
        let b = s.enumerate_chains();
        assert_eq!(a, b);
    }

    #[test]
    fn max_chains_caps_output() {
        let s = LatticeSearch::new(three_strategies())
            .with_max_depth(3)
            .with_max_chains(5);
        assert_eq!(s.enumerate_chains().len(), 5);
        assert_eq!(s.estimated_chain_count(), 5);
    }

    #[test]
    fn max_chains_zero_means_no_cap() {
        let s = LatticeSearch::new(three_strategies())
            .with_max_depth(2)
            .with_max_chains(0);
        assert_eq!(s.enumerate_chains().len(), 9);
    }

    #[test]
    fn min_greater_than_max_yields_empty() {
        let s = LatticeSearch::new(three_strategies())
            .with_min_depth(5)
            .with_max_depth(3);
        assert!(s.enumerate_chains().is_empty());
        assert_eq!(s.estimated_chain_count(), 0);
    }

    #[test]
    fn to_chain_names_round_trips() {
        let chain = EncodingChain {
            strategies: vec![Strategy::UrlEncode, Strategy::Base64Encode],
        };
        let names = chain.to_chain_names();
        assert_eq!(
            names,
            vec!["UrlEncode".to_string(), "Base64Encode".to_string()]
        );
    }

    #[test]
    fn apply_chain_url_then_case() {
        let chain = EncodingChain {
            strategies: vec![Strategy::UrlEncode, Strategy::CaseAlternation],
        };
        let out = apply_chain(b"SELECT", &chain).expect("apply");
        assert!(!out.is_empty());
    }

    #[test]
    fn apply_chain_invalid_utf8_errors() {
        let chain = EncodingChain {
            strategies: vec![Strategy::CaseAlternation],
        };
        let invalid: &[u8] = &[0xFF, 0xFE, 0xFD];
        // UTF-8 validation runs before any encoder, so the error must be
        // `InvalidUtf8` specifically; an encoder rejection here would mean the
        // validation step had been skipped.
        assert_eq!(
            apply_chain(invalid, &chain),
            Err(ChainApplyError::InvalidUtf8)
        );
    }

    #[test]
    fn apply_empty_chain_returns_input() {
        let chain = EncodingChain { strategies: vec![] };
        let out = apply_chain(b"hello", &chain).expect("apply");
        assert_eq!(out, "hello");
    }

    #[test]
    fn shallow_lattice_uses_full_palette() {
        let s = shallow_lattice();
        assert_eq!(s.max_depth, 2);
        assert!(!s.strategies.is_empty());
    }

    #[test]
    fn chain_serializes_round_trip() {
        let chain = EncodingChain {
            strategies: vec![Strategy::UrlEncode, Strategy::HtmlEntityEncode],
        };
        let json = serde_json::to_string(&chain).expect("ser");
        let back: EncodingChain = serde_json::from_str(&json).expect("de");
        assert_eq!(chain, back);
    }

    #[test]
    fn chain_depth_reports_correct_length() {
        let chain = EncodingChain {
            strategies: vec![
                Strategy::UrlEncode,
                Strategy::CaseAlternation,
                Strategy::Base64Encode,
            ],
        };
        assert_eq!(chain.depth(), 3);
    }

    #[test]
    fn lex_order_first_chain_is_first_strategy() {
        let s = LatticeSearch::new(three_strategies()).with_max_depth(1);
        let chains = s.enumerate_chains();
        assert_eq!(chains[0].strategies, vec![Strategy::UrlEncode]);
    }

    #[test]
    fn estimated_count_matches_actual() {
        // The closed-form estimate must agree with real enumeration for every
        // depth bound, both with and without duplicate-skipping.
        for allow_dupes in [false, true] {
            for max_depth in 1..=4 {
                let mut s = LatticeSearch::new(three_strategies()).with_max_depth(max_depth);
                if allow_dupes {
                    s = s.allowing_consecutive_dupes();
                }
                let actual = s.enumerate_chains().len();
                let estimated = s.estimated_chain_count();
                assert_eq!(
                    actual, estimated,
                    "depth {max_depth}, dupes {allow_dupes}: actual {actual} vs estimated {estimated}"
                );
            }
        }
    }

    #[test]
    fn skip_consecutive_dupes_pins_no_aa() {
        let s = LatticeSearch::new(three_strategies()).with_max_depth(2);
        let chains = s.enumerate_chains();
        for c in &chains {
            // Only depth-2 chains have an adjacent pair to compare.
            if c.depth() == 2 {
                assert_ne!(c.strategies[0], c.strategies[1]);
            }
        }
    }

    #[test]
    fn adversarial_huge_max_depth_capped_by_max_chains() {
        // Without the cap this would enumerate an exponential number of chains.
        let pal = wafrift_encoding::all_strategies().to_vec();
        let s = LatticeSearch::new(pal)
            .with_max_depth(6)
            .with_max_chains(100);
        let chains = s.enumerate_chains();
        assert!(chains.len() <= 100);
    }

    #[test]
    fn estimated_count_saturates_on_huge_palette() {
        let pal = wafrift_encoding::all_strategies().to_vec();
        let s = LatticeSearch::new(pal)
            .with_max_depth(10)
            .with_max_chains(50);
        assert_eq!(s.estimated_chain_count(), 50);
    }

    #[test]
    fn min_depth_zero_yields_empty() {
        let s = LatticeSearch::new(three_strategies())
            .with_min_depth(0)
            .with_max_depth(0);
        assert!(s.enumerate_chains().is_empty());
    }

    #[test]
    fn single_strategy_palette() {
        // With one strategy and duplicate-skipping, every chain deeper than 1
        // would repeat it, so only the depth-1 chain survives.
        let s = LatticeSearch::new(vec![Strategy::UrlEncode]).with_max_depth(3);
        assert_eq!(s.enumerate_chains().len(), 1);
    }

    #[test]
    fn single_strategy_palette_allowing_dupes() {
        // [A], [A, A], [A, A, A].
        let s = LatticeSearch::new(vec![Strategy::UrlEncode])
            .with_max_depth(3)
            .allowing_consecutive_dupes();
        assert_eq!(s.enumerate_chains().len(), 3);
    }
}