1use crate::util;
2use std::borrow::BorrowMut;
3
4use radix_trie::{Trie, TrieKey};
5use rand::Rng;
6use std::fmt::Write;
7use tl::NodeHandle;
8use tl::{HTMLTag, Node, Parser};
9
10#[cfg(feature = "serde")]
11use serde::{Deserialize, Serialize};
12
/// One step of a [`Selector`].
///
/// `Tag`, `Class` and `Id` look for a single matching element among the
/// descendants of the current node; `NthChild(n)` picks the `n`-th (1-based)
/// direct child element of the current node.
#[derive(Clone, Debug, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub enum SelectorPart {
    /// Matches elements whose tag name equals the string.
    Tag(String),
    /// Matches elements whose `class` attribute contains the string.
    Class(String),
    /// Matches elements whose `id` attribute equals the string.
    Id(String),
    /// Selects the `n`-th (1-based) direct child that is an element.
    NthChild(usize),
}
21
22impl ToString for SelectorPart {
23 fn to_string(&self) -> String {
24 let mut out = String::new();
25 match self {
26 SelectorPart::Tag(tag) => {
27 write!(&mut out, "{tag}")
28 }
29 SelectorPart::Class(class) => {
30 write!(&mut out, ".{class}")
31 }
32 SelectorPart::Id(id) => {
33 write!(&mut out, "#{id}")
34 }
35 SelectorPart::NthChild(n) => {
36 write!(&mut out, ":nth-child({n})")
37 }
38 }
39 .expect("write");
40 out
41 }
42}
43
44impl SelectorPart {
45 fn matches(&self, tag: &HTMLTag) -> bool {
47 match self {
48 SelectorPart::Tag(tagname) => tag.name() == tagname.as_str(),
49 SelectorPart::Class(class) => tag.attributes().is_class_member(class),
50 SelectorPart::Id(id) => tag
51 .attributes()
52 .id()
53 .map(|other_id| other_id == id.as_str())
54 .unwrap_or(false),
55 SelectorPart::NthChild(_) => {
56 panic!("cannot match :nth-child selector on its own!")
57 }
58 }
59 }
60
61 fn try_select(&self, node: NodeHandle, parser: &Parser) -> Option<NodeHandle> {
64 let tag = node.get(parser)?.as_tag()?;
65
66 if let SelectorPart::NthChild(n) = self {
68 debug_assert!(*n >= 1);
69 return tag
70 .children()
71 .top()
72 .iter()
73 .filter(|child| {
74 util::node_is_tag(child, parser)
76 })
77 .nth(*n - 1)
78 .cloned();
79 }
80
81 let results = tag
82 .children()
83 .all(parser)
84 .iter()
85 .enumerate()
86 .filter(|(_i, child)| matches!(child, Node::Tag(..)))
87 .filter(|(_i, child)| self.matches(child.as_tag().unwrap()))
88 .take(2)
89 .collect::<Vec<_>>();
90
91 if results.is_empty() || results.len() >= 2 {
92 None
93 } else {
94 results
95 .get(0)
96 .map(|(i, _child)| NodeHandle::new(*i as u32 + 1 + node.get_inner()))
97 }
98 }
99
100 fn score(&self) -> i32 {
102 match self {
103 SelectorPart::Tag(tag) => tag.len() as i32 + 1,
104 SelectorPart::Class(class) => class.len() as i32 + 1,
105 SelectorPart::Id(_) => 0,
106 SelectorPart::NthChild(n) => 13 + (*n as i32 / 2),
107 }
108 }
109}
110
111#[derive(Clone)]
112#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
113pub struct Selector {
114 parts: Vec<SelectorPart>,
115 pub string: String,
116 pub score: i32,
117}
118
119impl PartialEq for Selector {
120 fn eq(&self, other: &Selector) -> bool {
121 self.string.eq(&other.string)
122 }
123}
124
125impl Eq for Selector {}
126
127impl TrieKey for Selector {
128 fn encode_bytes(&self) -> Vec<u8> {
129 TrieKey::encode_bytes(&self.string)
130 }
131}
132
133impl Selector {
134 pub fn new_from_parts(parts: Vec<SelectorPart>) -> Self {
138 let string = parts
140 .iter()
141 .map(|part| part.to_string())
142 .collect::<Vec<String>>()
143 .join(" > ");
144 let score = parts.iter().map(|part| part.score()).sum();
145 Selector {
146 parts,
147 string,
148 score,
149 }
150 }
151
152 pub fn len(&self) -> usize {
153 self.parts.len()
154 }
155
156 pub fn try_select_with_skip(
157 &self,
158 handle: NodeHandle,
159 parser: &Parser,
160 skip: usize,
161 ) -> Option<NodeHandle> {
162 self.parts
163 .iter()
164 .skip(skip)
165 .fold(Some(handle), |acc, selector| {
166 acc.and_then(|node| selector.try_select(node, parser))
167 })
168 }
169
170 pub fn try_select_with_skip_path(
171 &self,
172 handle: NodeHandle,
173 parser: &Parser,
174 skip: usize,
175 max_len: usize,
176 ) -> Vec<Option<NodeHandle>> {
177 self.parts
178 .iter()
179 .skip(skip)
180 .fold(vec![], |mut path, selector| {
181 if path.len() >= max_len {
182 return path;
183 }
184
185 let last = if path.is_empty() {
187 Some(handle)
188 } else {
189 *path.last().unwrap()
190 };
191
192 if let Some(last_node) = last {
193 path.push(selector.try_select(last_node, parser));
194 } else {
195 path.push(None);
196 }
197
198 path
199 })
200 }
201
202 pub fn try_select(&self, handle: NodeHandle, parser: &Parser) -> Option<NodeHandle> {
205 self.try_select_with_skip(handle, parser, 0)
206 }
207
208 pub fn try_select_path(
209 &self,
210 handle: NodeHandle,
211 parser: &Parser,
212 max_len: usize,
213 ) -> Vec<Option<NodeHandle>> {
214 self.try_select_with_skip_path(handle, parser, 0, max_len)
215 }
216
217 pub(crate) fn score(&self) -> i32 {
218 self.score
219 }
220
221 fn append(&self, mut other: Selector) -> Self {
224 let mut selectors = Vec::with_capacity(other.parts.len() + self.parts.len());
225 selectors.append(&mut self.parts.clone());
226 selectors.append(&mut other.parts);
227
228 Selector::new_from_parts(selectors)
229 }
230
231 fn split_at(&self, depth: usize) -> (Self, Self) {
233 let mut cloned = self.parts.clone();
234 let tail = cloned.split_off(depth);
235 (
236 Selector::new_from_parts(cloned),
237 Selector::new_from_parts(tail),
238 )
239 }
240}
241
242impl std::fmt::Debug for Selector {
243 fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> Result<(), std::fmt::Error> {
244 write!(fmt, "{}", self.string)
245 }
246}
247
248impl From<SelectorPart> for Selector {
249 fn from(value: SelectorPart) -> Self {
250 Selector::new_from_parts(vec![value])
251 }
252}
253
254impl ToString for Selector {
255 fn to_string(&self) -> String {
256 self.string.clone()
257 }
258}
259
260pub(crate) struct SelectorCache {
261 selector_cache: Trie<Selector, (usize, Option<NodeHandle>)>,
262}
263
264impl SelectorCache {
265 const ENABLED: bool = true;
267
268 const ALWAYS_CACHE_LEAF: bool = true;
270
271 const AGGRESSIVE_ADD_MAX_DEPTH: usize = 4;
275
276 pub(crate) fn new() -> Self {
277 SelectorCache {
278 selector_cache: Default::default(),
279 }
280 }
281
282 pub(crate) fn try_select(
285 &mut self,
286 selector: &Selector,
287 root: NodeHandle,
288 parser: &Parser,
289 ) -> Option<NodeHandle> {
290 if let Some((ancestor_length, ancestor_handle)) =
291 self.selector_cache.get_ancestor_value(selector)
292 {
293 if ancestor_handle.is_some() && *ancestor_length < selector.len() {
294 let target = selector.try_select_with_skip(
295 ancestor_handle.unwrap(),
296 parser,
297 *ancestor_length,
298 );
299 if SelectorCache::ENABLED {
300 let len = *ancestor_length;
301 if SelectorCache::AGGRESSIVE_ADD_MAX_DEPTH > len {
302 selector
303 .try_select_with_skip_path(
304 ancestor_handle.unwrap(),
305 parser,
306 len,
307 SelectorCache::AGGRESSIVE_ADD_MAX_DEPTH - len,
308 )
309 .iter()
310 .enumerate()
311 .for_each(|(i, subnode)| {
312 self.selector_cache.insert(
313 selector.split_at(len + i + 1).0,
314 (len + i + 1, *subnode),
315 );
316 });
317 }
318 if SelectorCache::ALWAYS_CACHE_LEAF
319 && SelectorCache::AGGRESSIVE_ADD_MAX_DEPTH - len < selector.len()
320 {
321 self.selector_cache
322 .insert(selector.clone(), (selector.len(), target));
323 }
324 }
325 target
326 } else {
327 *ancestor_handle
328 }
329 } else {
330 let target = selector.try_select(root, parser);
331 if SelectorCache::ENABLED {
332 if SelectorCache::AGGRESSIVE_ADD_MAX_DEPTH > 0 {
333 selector
334 .try_select_path(root, parser, SelectorCache::AGGRESSIVE_ADD_MAX_DEPTH)
335 .iter()
336 .enumerate()
337 .for_each(|(i, subnode)| {
338 self.selector_cache
339 .insert(selector.split_at(i + 1).0, (i + 1, *subnode));
340 });
341 }
342 if SelectorCache::ALWAYS_CACHE_LEAF
343 && SelectorCache::AGGRESSIVE_ADD_MAX_DEPTH < selector.len()
344 {
345 self.selector_cache
346 .insert(selector.clone(), (selector.len(), target));
347 }
348 }
349 target
350 }
351 }
352}
353
354pub struct SelectorFuzzer {
355 root_selector_cache: SelectorCache,
356 pub(crate) retries_used: usize,
357}
358
359impl SelectorFuzzer {
360 pub fn new() -> Self {
361 SelectorFuzzer {
362 root_selector_cache: SelectorCache::new(),
363 retries_used: 0,
364 }
365 }
366
367 pub(crate) fn mutate_selector<R: Rng>(
371 &mut self,
372 selector: &Selector,
373 root: NodeHandle,
374 parser: &Parser,
375 retries: usize,
376 rng: &mut R,
377 ) -> Option<Selector> {
378 if selector.parts.len() <= 1 {
379 return None;
380 }
381
382 let random_index = rng.borrow_mut().gen_range(1..selector.parts.len());
383 let (left, right) = selector.split_at(random_index);
384 let left_node = self.root_selector_cache.try_select(&left, root, parser)?;
385 let new_left = self.random_selector_for_node(left_node, root, parser, retries, rng)?;
386 Some(new_left.append(right))
387 }
388
389 pub fn random_selector_for_node<R: Rng>(
392 &mut self,
393 handle: NodeHandle,
394 root: NodeHandle,
395 parser: &Parser,
396 retries: usize,
397 rng: &mut R,
398 ) -> Option<Selector> {
399 let tag = handle.get(parser)?.as_tag()?;
400
401 if let Some(id) = util::get_id(handle, parser) {
402 return Some(Selector::from(SelectorPart::Id(id.to_string())));
403 }
404
405 let parent = util::find_parent(handle, parser);
406 let has_parent = parent.is_some();
407 for tries in 0..retries {
408 self.retries_used += 1;
409 let typ = rng.gen_range(0..3);
410
411 let selector = match typ {
412 0 => Selector::from(SelectorPart::Tag(tag.name().as_utf8_str().to_string())),
413 1 => {
414 let classes = tag.attributes().class_iter()?.collect::<Vec<_>>();
415 if classes.is_empty() {
416 continue;
417 }
418 let random_index = rng.gen_range(0..classes.len());
419 Selector::from(SelectorPart::Class(classes[random_index].to_string()))
420 }
421 2 => {
422 if !has_parent {
423 continue;
424 }
425 let parent = parent.unwrap().get(parser).unwrap();
426 let index = parent
427 .children()
428 .unwrap()
429 .top()
430 .iter()
431 .filter(|child| util::node_is_tag(child, parser))
432 .position(|child| child.get_inner() == handle.get_inner())
433 .expect("child of parent should exists in parent.children()");
434 Selector::from(SelectorPart::NthChild(index + 1))
435 }
436 _ => unreachable!(),
437 };
438
439 let globally_unique = typ != 2
440 && matches!(self.root_selector_cache.try_select(&selector, root, parser), Some(h) if h == handle);
441 if globally_unique {
442 return Some(selector);
443 }
444 let locally_unique = has_parent
445 && matches!(selector.try_select(parent.unwrap(), parser), Some(h) if h == handle);
446 if locally_unique {
447 let parent_selector = self.random_selector_for_node(
448 parent.unwrap(),
449 root,
450 parser,
451 retries - tries,
452 rng,
453 );
454 if let Some(parent_selector) = parent_selector {
455 let combined_selector = parent_selector.append(selector);
456 return Some(combined_selector);
457 }
458 }
459 }
460 None
461 }
462}
463
#[cfg(test)]
mod tests {
    use crate::selectors::*;
    use crate::util;
    use rand::SeedableRng;
    use rand_chacha::ChaCha8Rng;

    use tl::VDom;

    const HTML: &str = r#"
 <div class="div_class">
 <div id="div_id">
 <p class="p_class">TARGET</p>
 <p class="other_class">...</p>
 </div>
 </div>
 "#;

    fn get_simple_example() -> VDom<'static> {
        tl::parse(HTML, tl::ParserOptions::default()).unwrap()
    }

    #[test]
    fn test_find_node_with_text() {
        let dom = get_simple_example();
        let parser = dom.parser();
        let node = util::find_node_with_text(&dom, "TARGET").unwrap();
        assert_eq!(util::get_classes(node, parser).unwrap(), "p_class")
    }

    #[test]
    fn test_find_parent() {
        let dom = get_simple_example();
        let parser = dom.parser();
        let element: NodeHandle = dom.query_selector("p").unwrap().next().unwrap();
        assert_eq!(util::get_classes(element, parser).unwrap(), "p_class");
        let parent = util::find_parent(element, parser).unwrap();
        assert_eq!(util::get_id(parent, parser).unwrap(), "div_id");
        let parent_parent = util::find_parent(parent, parser).unwrap();
        assert_eq!(
            util::get_classes(parent_parent, parser).unwrap(),
            "div_class"
        );
        let parent_parent_parent = util::find_parent(parent_parent, parser);
        assert_eq!(parent_parent_parent, None)
    }

    #[test]
    fn test_selector() {
        /// Checks the rendered form of `selector`, then evaluates it with
        /// `parser` starting from `NodeHandle::new(1)`, which the assertions
        /// below treat as the outer `div.div_class`.
        fn test_selector(
            selector: Selector,
            expected_str: &str,
            parser: &Parser,
        ) -> Option<NodeHandle> {
            assert_eq!(selector.to_string(), expected_str);
            selector.try_select(NodeHandle::new(1), parser)
        }

        let dom = get_simple_example();
        let parser = dom.parser();

        let target = test_selector(
            Selector::new_from_parts(vec![SelectorPart::Id("div_id".into())]),
            "#div_id",
            parser,
        );
        assert_eq!(util::get_id(target.unwrap(), parser).unwrap(), "div_id");

        let target = test_selector(
            Selector::new_from_parts(vec![
                SelectorPart::Id("div_id".into()),
                SelectorPart::NthChild(1),
            ]),
            "#div_id > :nth-child(1)",
            parser,
        );
        assert_eq!(
            util::get_classes(target.unwrap(), parser).unwrap(),
            "p_class"
        );

        let target = test_selector(
            Selector::new_from_parts(vec![
                SelectorPart::Id("div_id".into()),
                SelectorPart::NthChild(2),
            ]),
            "#div_id > :nth-child(2)",
            parser,
        );
        assert_eq!(
            util::get_classes(target.unwrap(), parser).unwrap(),
            "other_class"
        );

        let target = test_selector(
            Selector::new_from_parts(vec![
                SelectorPart::NthChild(1),
                SelectorPart::Class("p_class".into()),
            ]),
            ":nth-child(1) > .p_class",
            parser,
        );
        assert_eq!(
            util::get_classes(target.unwrap(), parser).unwrap(),
            "p_class"
        );

        let target = test_selector(
            Selector::new_from_parts(vec![SelectorPart::Class("p_class".into())]),
            ".p_class",
            parser,
        );
        assert_eq!(
            util::get_classes(target.unwrap(), parser).unwrap(),
            "p_class"
        );

        let target = test_selector(
            Selector::new_from_parts(vec![SelectorPart::Tag("div".into())]),
            "div",
            parser,
        );
        assert_eq!(util::get_id(target.unwrap(), parser).unwrap(), "div_id");

        // Two `p` descendants: the part is ambiguous, so nothing is selected.
        let target = test_selector(
            Selector::new_from_parts(vec![SelectorPart::Tag("p".into())]),
            "p",
            parser,
        );
        assert_eq!(target, None);

        let target = test_selector(
            Selector::new_from_parts(vec![
                SelectorPart::Tag("div".into()),
                SelectorPart::Tag("p".into()),
            ]),
            "div > p",
            parser,
        );
        assert_eq!(target, None);
    }

    #[test]
    fn test_random_selector() {
        const HTML: &str = r#"
 <root>
 <div class="div_class">
 <div class="div_id_class">
 <p class="p_class">TARGET</p>
 <p class="p_class">...</p>
 </div>
 </div>
 </root>
 "#;

        let dom = tl::parse(HTML, tl::ParserOptions::default()).unwrap();
        let parser = dom.parser();
        let root = util::find_root(&dom).unwrap();
        let target = dom.query_selector(".p_class").unwrap().next().unwrap();
        let mut rng = ChaCha8Rng::seed_from_u64(1337);
        let mut searcher = SelectorFuzzer::new();
        let selector = searcher.random_selector_for_node(target, *root, parser, 10, &mut rng);
        println!("{:?}", selector.as_ref().map(|sel| sel.to_string()));

        // Whatever the fuzzer returns must actually select the target node.
        if let Some(selector) = selector {
            assert_eq!(selector.try_select(*root, parser), Some(target));
        }
    }
}