mlscraper_rust/
selectors.rs

1use crate::util;
2use std::borrow::BorrowMut;
3
4use radix_trie::{Trie, TrieKey};
5use rand::Rng;
6use std::fmt::Write;
7use tl::NodeHandle;
8use tl::{HTMLTag, Node, Parser};
9
10#[cfg(feature = "serde")]
11use serde::{Deserialize, Serialize};
12
/// A single step of a [`Selector`].
#[derive(Clone, Debug, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub enum SelectorPart {
    /// Matches tags by tag name, rendered as `name`.
    Tag(String),
    /// Matches tags that are members of the given class, rendered as `.class`.
    Class(String),
    /// Matches tags whose `id` attribute equals the given value, rendered as `#id`.
    Id(String),
    /// Selects the n-th (1-based) direct child that is a tag, rendered as `:nth-child(n)`.
    NthChild(usize),
}
21
22impl ToString for SelectorPart {
23    fn to_string(&self) -> String {
24        let mut out = String::new();
25        match self {
26            SelectorPart::Tag(tag) => {
27                write!(&mut out, "{tag}")
28            }
29            SelectorPart::Class(class) => {
30                write!(&mut out, ".{class}")
31            }
32            SelectorPart::Id(id) => {
33                write!(&mut out, "#{id}")
34            }
35            SelectorPart::NthChild(n) => {
36                write!(&mut out, ":nth-child({n})")
37            }
38        }
39        .expect("write");
40        out
41    }
42}
43
44impl SelectorPart {
45    /// Returns true if this selector part matches the given HTML tag
46    fn matches(&self, tag: &HTMLTag) -> bool {
47        match self {
48            SelectorPart::Tag(tagname) => tag.name() == tagname.as_str(),
49            SelectorPart::Class(class) => tag.attributes().is_class_member(class),
50            SelectorPart::Id(id) => tag
51                .attributes()
52                .id()
53                .map(|other_id| other_id == id.as_str())
54                .unwrap_or(false),
55            SelectorPart::NthChild(_) => {
56                panic!("cannot match :nth-child selector on its own!")
57            }
58        }
59    }
60
61    /// Tries to find a node matching this SelectorPart by searching all children starting
62    /// from `node`. A value will be returned iff exactly one element matched.
63    fn try_select(&self, node: NodeHandle, parser: &Parser) -> Option<NodeHandle> {
64        let tag = node.get(parser)?.as_tag()?;
65
66        // Handle :nth-child selector
67        if let SelectorPart::NthChild(n) = self {
68            debug_assert!(*n >= 1);
69            return tag
70                .children()
71                .top()
72                .iter()
73                .filter(|child| {
74                    // Only consider children that are tags
75                    util::node_is_tag(child, parser)
76                })
77                .nth(*n - 1)
78                .cloned();
79        }
80
81        let results = tag
82            .children()
83            .all(parser)
84            .iter()
85            .enumerate()
86            .filter(|(_i, child)| matches!(child, Node::Tag(..)))
87            .filter(|(_i, child)| self.matches(child.as_tag().unwrap()))
88            .take(2)
89            .collect::<Vec<_>>();
90
91        if results.is_empty() || results.len() >= 2 {
92            None
93        } else {
94            results
95                .get(0)
96                .map(|(i, _child)| NodeHandle::new(*i as u32 + 1 + node.get_inner()))
97        }
98    }
99
100    /// Score of this SelectorPart (lower is better)
101    fn score(&self) -> i32 {
102        match self {
103            SelectorPart::Tag(tag) => tag.len() as i32 + 1,
104            SelectorPart::Class(class) => class.len() as i32 + 1,
105            SelectorPart::Id(_) => 0,
106            SelectorPart::NthChild(n) => 13 + (*n as i32 / 2),
107        }
108    }
109}
110
111#[derive(Clone)]
112#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
113pub struct Selector {
114    parts: Vec<SelectorPart>,
115    pub string: String,
116    pub score: i32,
117}
118
119impl PartialEq for Selector {
120    fn eq(&self, other: &Selector) -> bool {
121        self.string.eq(&other.string)
122    }
123}
124
125impl Eq for Selector {}
126
127impl TrieKey for Selector {
128    fn encode_bytes(&self) -> Vec<u8> {
129        TrieKey::encode_bytes(&self.string)
130    }
131}
132
133impl Selector {
134    /// Create a new selector from multiple SelectorParts.
135    /// The parts are interspersed with " > "; this means the element matched by each SelectorPart
136    /// must be the *direct parent* of the element matched by the nexted SelectorPart.
137    pub fn new_from_parts(parts: Vec<SelectorPart>) -> Self {
138        // TODO use intersperse once stabilized
139        let string = parts
140            .iter()
141            .map(|part| part.to_string())
142            .collect::<Vec<String>>()
143            .join(" > ");
144        let score = parts.iter().map(|part| part.score()).sum();
145        Selector {
146            parts,
147            string,
148            score,
149        }
150    }
151
152    pub fn len(&self) -> usize {
153        self.parts.len()
154    }
155
156    pub fn try_select_with_skip(
157        &self,
158        handle: NodeHandle,
159        parser: &Parser,
160        skip: usize,
161    ) -> Option<NodeHandle> {
162        self.parts
163            .iter()
164            .skip(skip)
165            .fold(Some(handle), |acc, selector| {
166                acc.and_then(|node| selector.try_select(node, parser))
167            })
168    }
169
170    pub fn try_select_with_skip_path(
171        &self,
172        handle: NodeHandle,
173        parser: &Parser,
174        skip: usize,
175        max_len: usize,
176    ) -> Vec<Option<NodeHandle>> {
177        self.parts
178            .iter()
179            .skip(skip)
180            .fold(vec![], |mut path, selector| {
181                if path.len() >= max_len {
182                    return path;
183                }
184
185                // Continue from last node or root node
186                let last = if path.is_empty() {
187                    Some(handle)
188                } else {
189                    *path.last().unwrap()
190                };
191
192                if let Some(last_node) = last {
193                    path.push(selector.try_select(last_node, parser));
194                } else {
195                    path.push(None);
196                }
197
198                path
199            })
200    }
201
202    /// Tries to find a node matching this Selector by searching all nodes below
203    /// `handle`. A result will be returned iff exactly one element matched.
204    pub fn try_select(&self, handle: NodeHandle, parser: &Parser) -> Option<NodeHandle> {
205        self.try_select_with_skip(handle, parser, 0)
206    }
207
208    pub fn try_select_path(
209        &self,
210        handle: NodeHandle,
211        parser: &Parser,
212        max_len: usize,
213    ) -> Vec<Option<NodeHandle>> {
214        self.try_select_with_skip_path(handle, parser, 0, max_len)
215    }
216
217    pub(crate) fn score(&self) -> i32 {
218        self.score
219    }
220
221    /// Creates a new selector which is the combination of this selector and `other`.
222    /// `other` will be the lower part of the selector.
223    fn append(&self, mut other: Selector) -> Self {
224        let mut selectors = Vec::with_capacity(other.parts.len() + self.parts.len());
225        selectors.append(&mut self.parts.clone());
226        selectors.append(&mut other.parts);
227
228        Selector::new_from_parts(selectors)
229    }
230
231    /// Creates two new selectors by splitting the parts of this selector at `depth`.
232    fn split_at(&self, depth: usize) -> (Self, Self) {
233        let mut cloned = self.parts.clone();
234        let tail = cloned.split_off(depth);
235        (
236            Selector::new_from_parts(cloned),
237            Selector::new_from_parts(tail),
238        )
239    }
240}
241
242impl std::fmt::Debug for Selector {
243    fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> Result<(), std::fmt::Error> {
244        write!(fmt, "{}", self.string)
245    }
246}
247
248impl From<SelectorPart> for Selector {
249    fn from(value: SelectorPart) -> Self {
250        Selector::new_from_parts(vec![value])
251    }
252}
253
254impl ToString for Selector {
255    fn to_string(&self) -> String {
256        self.string.clone()
257    }
258}
259
260pub(crate) struct SelectorCache {
261    selector_cache: Trie<Selector, (usize, Option<NodeHandle>)>,
262}
263
264impl SelectorCache {
265    /// Enable/disable caching
266    const ENABLED: bool = true;
267
268    /// Whether we always cache the "leaf node", i.e. the "deepest" result of the selector
269    const ALWAYS_CACHE_LEAF: bool = true;
270
271    /// If AGGRESSIVE_ADD_MAX_DEPTH is > 0, cache elements up to a depth of
272    /// AGGRESSIVE_ADD_MAX_DEPTH from the root node even if they have not
273    /// been explicitely requested.
274    const AGGRESSIVE_ADD_MAX_DEPTH: usize = 4;
275
276    pub(crate) fn new() -> Self {
277        SelectorCache {
278            selector_cache: Default::default(),
279        }
280    }
281
282    /// Tries to select a target node by applying the selector to root.
283    /// Uses a Trie to cache and reuse partial results.
284    pub(crate) fn try_select(
285        &mut self,
286        selector: &Selector,
287        root: NodeHandle,
288        parser: &Parser,
289    ) -> Option<NodeHandle> {
290        if let Some((ancestor_length, ancestor_handle)) =
291            self.selector_cache.get_ancestor_value(selector)
292        {
293            if ancestor_handle.is_some() && *ancestor_length < selector.len() {
294                let target = selector.try_select_with_skip(
295                    ancestor_handle.unwrap(),
296                    parser,
297                    *ancestor_length,
298                );
299                if SelectorCache::ENABLED {
300                    let len = *ancestor_length;
301                    if SelectorCache::AGGRESSIVE_ADD_MAX_DEPTH > len {
302                        selector
303                            .try_select_with_skip_path(
304                                ancestor_handle.unwrap(),
305                                parser,
306                                len,
307                                SelectorCache::AGGRESSIVE_ADD_MAX_DEPTH - len,
308                            )
309                            .iter()
310                            .enumerate()
311                            .for_each(|(i, subnode)| {
312                                self.selector_cache.insert(
313                                    selector.split_at(len + i + 1).0,
314                                    (len + i + 1, *subnode),
315                                );
316                            });
317                    }
318                    if SelectorCache::ALWAYS_CACHE_LEAF
319                        && SelectorCache::AGGRESSIVE_ADD_MAX_DEPTH - len < selector.len()
320                    {
321                        self.selector_cache
322                            .insert(selector.clone(), (selector.len(), target));
323                    }
324                }
325                target
326            } else {
327                *ancestor_handle
328            }
329        } else {
330            let target = selector.try_select(root, parser);
331            if SelectorCache::ENABLED {
332                if SelectorCache::AGGRESSIVE_ADD_MAX_DEPTH > 0 {
333                    selector
334                        .try_select_path(root, parser, SelectorCache::AGGRESSIVE_ADD_MAX_DEPTH)
335                        .iter()
336                        .enumerate()
337                        .for_each(|(i, subnode)| {
338                            self.selector_cache
339                                .insert(selector.split_at(i + 1).0, (i + 1, *subnode));
340                        });
341                }
342                if SelectorCache::ALWAYS_CACHE_LEAF
343                    && SelectorCache::AGGRESSIVE_ADD_MAX_DEPTH < selector.len()
344                {
345                    self.selector_cache
346                        .insert(selector.clone(), (selector.len(), target));
347                }
348            }
349            target
350        }
351    }
352}
353
354pub struct SelectorFuzzer {
355    root_selector_cache: SelectorCache,
356    pub(crate) retries_used: usize,
357}
358
359impl SelectorFuzzer {
360    pub fn new() -> Self {
361        SelectorFuzzer {
362            root_selector_cache: SelectorCache::new(),
363            retries_used: 0,
364        }
365    }
366
367    /// Attempt to create a new selector by mutating the given input selector.
368    /// Mutating in this case means we split the selector at a random point and create a new
369    /// random selector for the upper part of the selector.
370    pub(crate) fn mutate_selector<R: Rng>(
371        &mut self,
372        selector: &Selector,
373        root: NodeHandle,
374        parser: &Parser,
375        retries: usize,
376        rng: &mut R,
377    ) -> Option<Selector> {
378        if selector.parts.len() <= 1 {
379            return None;
380        }
381
382        let random_index = rng.borrow_mut().gen_range(1..selector.parts.len());
383        let (left, right) = selector.split_at(random_index);
384        let left_node = self.root_selector_cache.try_select(&left, root, parser)?;
385        let new_left = self.random_selector_for_node(left_node, root, parser, retries, rng)?;
386        Some(new_left.append(right))
387    }
388
389    /// Recursively generate a random selector for node `handle`. `root` is the root-node
390    /// of the subtree.
391    pub fn random_selector_for_node<R: Rng>(
392        &mut self,
393        handle: NodeHandle,
394        root: NodeHandle,
395        parser: &Parser,
396        retries: usize,
397        rng: &mut R,
398    ) -> Option<Selector> {
399        let tag = handle.get(parser)?.as_tag()?;
400
401        if let Some(id) = util::get_id(handle, parser) {
402            return Some(Selector::from(SelectorPart::Id(id.to_string())));
403        }
404
405        let parent = util::find_parent(handle, parser);
406        let has_parent = parent.is_some();
407        for tries in 0..retries {
408            self.retries_used += 1;
409            let typ = rng.gen_range(0..3);
410
411            let selector = match typ {
412                0 => Selector::from(SelectorPart::Tag(tag.name().as_utf8_str().to_string())),
413                1 => {
414                    let classes = tag.attributes().class_iter()?.collect::<Vec<_>>();
415                    if classes.is_empty() {
416                        continue;
417                    }
418                    let random_index = rng.gen_range(0..classes.len());
419                    Selector::from(SelectorPart::Class(classes[random_index].to_string()))
420                }
421                2 => {
422                    if !has_parent {
423                        continue;
424                    }
425                    let parent = parent.unwrap().get(parser).unwrap();
426                    let index = parent
427                        .children()
428                        .unwrap()
429                        .top()
430                        .iter()
431                        .filter(|child| util::node_is_tag(child, parser))
432                        .position(|child| child.get_inner() == handle.get_inner())
433                        .expect("child of parent should exists in parent.children()");
434                    Selector::from(SelectorPart::NthChild(index + 1))
435                }
436                _ => unreachable!(),
437            };
438
439            let globally_unique = typ != 2
440                && matches!(self.root_selector_cache.try_select(&selector, root, parser), Some(h) if h == handle);
441            if globally_unique {
442                return Some(selector);
443            }
444            let locally_unique = has_parent
445                && matches!(selector.try_select(parent.unwrap(), parser), Some(h) if h == handle);
446            if locally_unique {
447                let parent_selector = self.random_selector_for_node(
448                    parent.unwrap(),
449                    root,
450                    parser,
451                    retries - tries,
452                    rng,
453                );
454                if let Some(parent_selector) = parent_selector {
455                    let combined_selector = parent_selector.append(selector);
456                    return Some(combined_selector);
457                }
458            }
459        }
460        None
461    }
462}
463
#[cfg(test)]
mod tests {
    use crate::selectors::*;
    use crate::util;
    use rand::SeedableRng;
    use rand_chacha::ChaCha8Rng;

    use tl::VDom;

    const HTML: &str = r#"
        <div class="div_class">
            <div id="div_id">
                <p class="p_class">TARGET</p>
                <p class="other_class">...</p>
            </div>
        </div>
        "#;

    /// Parses `HTML`. The selector tests use handle 1 (the outer `div.div_class`) as their
    /// root node.
    fn get_simple_example() -> VDom<'static> {
        tl::parse(HTML, tl::ParserOptions::default()).unwrap()
    }

    #[test]
    fn test_find_node_with_text() {
        let dom = get_simple_example();
        let parser = dom.parser();
        let node = util::find_node_with_text(&dom, "TARGET").unwrap();
        assert_eq!(util::get_classes(node, parser).unwrap(), "p_class")
    }

    #[test]
    fn test_find_parent() {
        let dom = get_simple_example();
        let parser = dom.parser();
        let element: NodeHandle = dom.query_selector("p").unwrap().next().unwrap();
        assert_eq!(util::get_classes(element, parser).unwrap(), "p_class");
        let parent = util::find_parent(element, parser).unwrap();
        assert_eq!(util::get_id(parent, parser).unwrap(), "div_id");
        let parent_parent = util::find_parent(parent, parser).unwrap();
        assert_eq!(
            util::get_classes(parent_parent, parser).unwrap(),
            "div_class"
        );
        let parent_parent_parent = util::find_parent(parent_parent, parser);
        assert_eq!(parent_parent_parent, None)
    }

    #[test]
    fn test_selector() {
        /// Checks the rendered form of `selector`, then applies it to the root node (handle 1)
        /// of the simple example using the caller's `parser`.
        fn check_selector(
            selector: Selector,
            expected_str: &str,
            parser: &Parser,
        ) -> Option<NodeHandle> {
            assert_eq!(selector.to_string(), expected_str);
            selector.try_select(/* root node = */ NodeHandle::new(1), parser)
        }

        let dom = get_simple_example();
        let parser = dom.parser();

        let target = check_selector(
            Selector::new_from_parts(vec![SelectorPart::Id("div_id".into())]),
            "#div_id",
            parser,
        );
        assert_eq!(util::get_id(target.unwrap(), parser).unwrap(), "div_id");

        let target = check_selector(
            Selector::new_from_parts(vec![
                SelectorPart::Id("div_id".into()),
                SelectorPart::NthChild(1),
            ]),
            "#div_id > :nth-child(1)",
            parser,
        );
        assert_eq!(
            util::get_classes(target.unwrap(), parser).unwrap(),
            "p_class"
        );

        let target = check_selector(
            Selector::new_from_parts(vec![
                SelectorPart::Id("div_id".into()),
                SelectorPart::NthChild(2),
            ]),
            "#div_id > :nth-child(2)",
            parser,
        );
        assert_eq!(
            util::get_classes(target.unwrap(), parser).unwrap(),
            "other_class"
        );

        let target = check_selector(
            Selector::new_from_parts(vec![
                SelectorPart::NthChild(1),
                SelectorPart::Class("p_class".into()),
            ]),
            ":nth-child(1) > .p_class",
            parser,
        );
        assert_eq!(
            util::get_classes(target.unwrap(), parser).unwrap(),
            "p_class"
        );

        let target = check_selector(
            Selector::new_from_parts(vec![SelectorPart::Class("p_class".into())]),
            ".p_class",
            parser,
        );
        assert_eq!(
            util::get_classes(target.unwrap(), parser).unwrap(),
            "p_class"
        );

        let target = check_selector(
            Selector::new_from_parts(vec![SelectorPart::Tag("div".into())]),
            "div",
            parser,
        ); // Only one div because the outer div is the root node
        assert_eq!(util::get_id(target.unwrap(), parser).unwrap(), "div_id");

        let target = check_selector(
            Selector::new_from_parts(vec![SelectorPart::Tag("p".into())]),
            "p",
            parser,
        );
        assert_eq!(target, None);

        let target = check_selector(
            Selector::new_from_parts(vec![
                SelectorPart::Tag("div".into()),
                SelectorPart::Tag("p".into()),
            ]),
            "div > p",
            parser,
        );
        assert_eq!(target, None);
    }

    #[test]
    fn test_random_selector() {
        const HTML: &str = r#"
        <root>
        <div class="div_class">
            <div class="div_id_class">
                <p class="p_class">TARGET</p>
                <p class="p_class">...</p>
            </div>
        </div>
        </root>
        "#;

        let dom = tl::parse(HTML, tl::ParserOptions::default()).unwrap();
        let parser = dom.parser();
        let root = util::find_root(&dom).unwrap();
        let target = dom.query_selector(".p_class").unwrap().next().unwrap();
        let mut rng = ChaCha8Rng::seed_from_u64(1337);
        let mut searcher = SelectorFuzzer::new();
        let selector = searcher.random_selector_for_node(target, *root, parser, 10, &mut rng);
        println!("{:?}", selector.as_ref().map(|sel| sel.to_string()));

        // Whatever the fuzzer produces must select exactly the target when applied from root.
        if let Some(selector) = selector {
            assert_eq!(selector.try_select(*root, parser), Some(target));
        }
    }
}