mlscraper_rust/
search.rs

1use std::collections::HashMap;
2use std::ops::Deref;
3
4use anyhow::{anyhow, Result};
5use log::{info, trace, warn};
6use rand::Rng;
7use std::cell::RefCell;
8use std::collections::HashSet;
9use std::time::Instant;
10
11use crate::selectors::*;
12use crate::util;
13use crate::util::{find_root, TextRetrievalOption, TextRetrievalOptions};
14use tl::{NodeHandle, VDom};
15
16#[cfg(feature = "serde")]
17use serde::{Deserialize, Serialize};
18
/// Strategy for dealing with missing data (expected attribute value is `None`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub enum MissingDataStrategy {
    /// If an expected attribute value is `None`, the selector is not required to match
    /// a node. A selector that matches no node, or a node without any text value, is
    /// accepted.
    AllowMissingNode,

    /// If an expected attribute value is `None`, the node must still exist, and its text value
    /// (see [util::get_node_text]) must be `None` (empty text).
    NodeMustExist,
}
30
/// Strategy for dealing with multiple nodes matching the expected attribute value.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub enum MultipleMatchesStrategy {
    /// Choose the node which results in the best selector. All matching nodes are
    /// equally likely to be picked as the target of a randomly generated selector.
    PrioritizeBestSelector,

    /// Prefer the first matching node. It is picked as the target with a weight of 90%,
    /// the remaining matches share the other 10%.
    PrioritizeFirstMatch,
}
41
42/// Utility struct for constructing Attributes.
43pub struct AttributeBuilder<'a> {
44    name: String,
45    values: Option<Vec<Option<String>>>,
46    filter: Option<&'a dyn Fn(&Selector) -> bool>,
47}
48
49impl<'a> AttributeBuilder<'a> {
50    /// Create a new attribute with the given name. The name must be unique.
51    pub fn new<S: Into<String>>(name: S) -> Self {
52        AttributeBuilder {
53            name: name.into(),
54            values: None,
55            filter: None,
56        }
57    }
58
59    /// Set the values of this attribute. The order of the values must be consistent
60    /// with the order of the documents as passed to [Training].
61    pub fn values(mut self, values: &[Option<&str>]) -> Self {
62        self.values = Some(
63            values
64                .iter()
65                .map(|option| option.map(|string| string.to_string()))
66                .collect(),
67        );
68        self
69    }
70
71    /// Set the filter function of this attribute. A selector for this attribute
72    /// is only considered valid if the filter function returns true on it.
73    pub fn filter(mut self, function: &'a dyn Fn(&Selector) -> bool) -> Self {
74        self.filter = Some(function);
75        self
76    }
77
78    /// Build the attribute.
79    pub fn build(self) -> Attribute<'a> {
80        Attribute {
81            name: self.name,
82            values: self.values.unwrap_or(vec![]),
83            filter: self.filter,
84        }
85    }
86}
87
88/// An attribute of a (multiple) web page(s) that is to be scraped.
89/// It is recommended to use [AttributeBuilder] to construct [Attribute]s.
90///
91/// Must have a user-defined arbitrary but unique name for identification
92/// purposes.
93///
94/// Must have one value per document (although this value may be `None` if
95/// no value is expected on a particular document). The order of the values
96/// must be consistent with the order of the documents as passed to [Training].
97///
98/// Additionally, a filter function may be defined. A selector for this attribute
99/// is only considered valid if the filter function returns true on it.
100pub struct Attribute<'a> {
101    pub(crate) name: String,
102    pub(crate) values: Vec<Option<String>>,
103    pub(crate) filter: Option<&'a dyn Fn(&Selector) -> bool>,
104}
105
106/// A selector for a single attribute of a single web page.
107#[derive(Clone)]
108struct CheckedSelector {
109    selector: Selector,
110    checked_on_all_documents: bool,
111}
112
113impl CheckedSelector {
114    fn new(selector: Selector) -> Self {
115        Self {
116            selector,
117            checked_on_all_documents: false,
118        }
119    }
120
121    fn new_checked(selector: Selector) -> Self {
122        Self {
123            selector,
124            checked_on_all_documents: true,
125        }
126    }
127}
128
129impl Deref for CheckedSelector {
130    type Target = Selector;
131
132    fn deref(&self) -> &Self::Target {
133        &self.selector
134    }
135}
136
137/// Settings for the fuzzing algorithm.
138///
139/// The default settings should be good for many applications,
140/// but you might want to adjust how missing data is treated
141/// ([`FuzzerSettings::missing_data_strategy`]), how duplicate matches are
142/// handled ([`FuzzerSettings::multiple_matches_strategy`]), and what parts of
143/// the documents are considered text ([`FuzzerSettings::text_retrieval_options`]).
144///
145/// If you encounter performance problems or are not satisfied with the results,
146/// you can experiment with the random generation/mutation settings.
147#[derive(Debug)]
148#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
149pub struct FuzzerSettings {
150    /// Strategy for dealing with missing data (expected attribute value is `None`)
151    pub missing_data_strategy: MissingDataStrategy,
152
153    /// Strategy for dealing with multiple nodes matching the expected attribute value
154    pub multiple_matches_strategy: MultipleMatchesStrategy,
155
156    /// Options for retrieving text from nodes
157    pub text_retrieval_options: util::TextRetrievalOptions,
158
159    /// Number of random selectors to generate per attribute
160    pub random_generation_count: usize,
161    /// Number of times to retry generating a random selector before giving up
162    pub random_generation_retries: usize,
163    /// Number of selectors to keep per attribute after random generation
164    pub survivor_count: usize,
165    /// Number of random mutations to generate after random generation
166    pub random_mutation_count: usize,
167}
168
169impl Default for FuzzerSettings {
170    /// Default settings for the fuzzer.
171    fn default() -> Self {
172        let mut default_text_retrieval_options = TextRetrievalOptions::new();
173        default_text_retrieval_options.push(TextRetrievalOption::InnerText);
174        default_text_retrieval_options.push(TextRetrievalOption::Attribute("title".into()));
175        default_text_retrieval_options.push(TextRetrievalOption::Attribute("alt".into()));
176
177        FuzzerSettings {
178            missing_data_strategy: MissingDataStrategy::NodeMustExist,
179            multiple_matches_strategy: MultipleMatchesStrategy::PrioritizeFirstMatch,
180            random_generation_count: 100,
181            random_generation_retries: 100,
182            survivor_count: 10,
183            random_mutation_count: 20,
184            text_retrieval_options: default_text_retrieval_options,
185        }
186    }
187}
188
189/// The result of "training" the fuzzer on a set of web pages.
190/// Contains the selectors for each attribute, as well as the original settings used.
191/// If training for a particular attribute failed, the attribute/selector pair will be not present in this object.
192///
193/// This result can also be used to extract data from previously unseen documents,
194/// for example:
195///
196/// ```ignore
197/// let mut dom = result.parse(&new_page).expect("parse");
198/// let attribute_1 = result.get_value(&dom, "attribute_1_name").expect("get_value");
199/// let attribute_2 = result.get_value(&dom, "attribute_2_name").expect("get_value");
200/// // ...
201/// ```
202///
203/// Enable the "serde" feature to enable serialization/deserialization using 
204/// serde. This can be useful for reusing previously computed training results.
205#[derive(Debug)]
206#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
207pub struct TrainingResult {
208    selectors: HashMap<String, Selector>,
209    settings: FuzzerSettings,
210}
211
212impl TrainingResult {
213    pub fn selectors(&self) -> &HashMap<String, Selector> {
214        &self.selectors
215    }
216
217    pub fn attributes<'a>(&'a self) -> Box<dyn Iterator<Item = &'a str> + 'a> {
218        Box::new(self.selectors.keys().map(|s| s.as_ref()))
219    }
220
221    /// Parse a document and return the DOM object.
222    /// Calling this and reusing the DOM object is more efficient than calling [`TrainingResult::parse_and_get_value`] multiple times.
223    pub fn parse<'s>(&self, document: &'s str) -> Result<VDom<'s>> {
224        tl::parse(document, tl::ParserOptions::default())
225            .map_err(|_| anyhow!("Failed to parse document!"))
226    }
227
228    /// Parse a document and return the value of the given attribute.
229    /// This is equivalent to calling [`TrainingResult::parse`] and then [`TrainingResult::get_value`].
230    pub fn parse_and_get_value(
231        &self,
232        document: &str,
233        attribute_name: &str,
234    ) -> Result<Option<String>> {
235        let dom = tl::parse(document, tl::ParserOptions::default())?;
236        self.get_value(&dom, attribute_name)
237    }
238
239    /// Get the value of the given attribute from the given DOM object.
240    pub fn get_value<'a>(&self, dom: &'a VDom<'a>, attribute_name: &str) -> Result<Option<String>> {
241        if !self.selectors.contains_key(attribute_name) {
242            return Err(anyhow!("Attribute {:?} not found!", attribute_name));
243        }
244
245        let root = find_root(&dom).ok_or(anyhow!("Could not find root node in document!"))?;
246        Ok(self
247            .selectors
248            .get(attribute_name)
249            .unwrap()
250            .try_select(*root, &dom.parser())
251            .and_then(|node| {
252                util::get_node_text(&dom, node, &self.settings.text_retrieval_options)
253            }))
254    }
255
256    /// Get the best selector for the given attribute.
257    pub fn get_selector<'a>(&'a self, attribute_name: &str) -> Option<&'a str> {
258        self.selectors
259            .get(attribute_name)
260            .map(|selector| selector.string.as_ref())
261    }
262
263    /// Highlight the selected elements for the given attribute in the given DOM object
264    /// by adding a red border around them.
265    ///
266    /// This will both alter the input DOM *and* return the resulting HTML as String,
267    /// which, as I realize writing this, may be a poor design choice. TODO.
268    ///
269    /// Example:
270    /// ´´´
271    /// let out_html = training_result.highlight_selections_with_red_border(&mut dom);
272    /// fs::write("out.html", out_html).expect("write");
273    /// ´´´
274    pub fn highlight_selections_with_red_border(&self, dom: &mut VDom<'_>) -> String {
275        self.selectors().values().for_each(|selector| {
276            util::style_selected_element(selector, dom);
277        });
278        dom.outer_html()
279    }
280}
281
282/// Represents a training process.
283/// Contains the documents and attributes that are used for training.
284pub struct Training<'a> {
285    documents: Vec<VDom<'a>>,
286    document_roots: Vec<NodeHandle>,
287    document_selector_caches: Vec<RefCell<SelectorCache>>,
288    attributes: Vec<Attribute<'a>>,
289    selector_pool: HashMap<String, Vec<CheckedSelector>>,
290    settings: FuzzerSettings,
291}
292
293impl<'a> Training<'a> {
294    /// The documents that are used for training.
295    pub fn documents<'l>(&'l self) -> &'l Vec<VDom<'a>> {
296        &self.documents
297    }
298
299    pub fn documents_mut<'l>(&'l mut self) -> &'l mut Vec<VDom<'a>> {
300        &mut self.documents
301    }
302
303    /// The attributes that are used for training.
304    pub fn attributes<'l>(&'l self) -> &'l Vec<Attribute<'a>> {
305        &self.attributes
306    }
307
308    pub fn new(documents: Vec<VDom<'a>>, attributes: Vec<Attribute<'a>>) -> Result<Self> {
309        Self::with_settings(documents, attributes, Default::default())
310    }
311
312    pub fn with_settings(
313        documents: Vec<VDom<'a>>,
314        attributes: Vec<Attribute<'a>>,
315        settings: FuzzerSettings,
316    ) -> Result<Self> {
317        let document_roots = documents
318            .iter()
319            .filter_map(find_root)
320            .copied()
321            .collect::<Vec<_>>();
322        if document_roots.len() != documents.len() {
323            return Err(anyhow!(
324                "Failed to find root node in at least one input document!"
325            ));
326        }
327
328        if attributes
329            .iter()
330            .any(|attr| attr.values.len() != documents.len())
331        {
332            return Err(anyhow!(
333                "At least one attribute has an incorrect number of values!"
334            ));
335        }
336
337        let mut unique = HashSet::new();
338        if let Some(duplicate) = attributes
339            .iter()
340            .find(|attr| !unique.insert(attr.name.clone()))
341        {
342            return Err(anyhow!("Duplicate attribute {:?}!", duplicate.name));
343        }
344
345        let document_selector_caches = documents
346            .iter()
347            .map(|_| RefCell::new(SelectorCache::new()))
348            .collect();
349
350        let training = Training {
351            documents,
352            document_roots,
353            document_selector_caches,
354            attributes,
355            selector_pool: Default::default(),
356            settings,
357        };
358
359        Ok(training)
360    }
361
362    /// Find all nodes in the given document that contain the given text as defined
363    /// by the [TextRetrievalOptions] in [FuzzerSettings].
364    fn find_all_nodes_with_text(&self, vdom: &VDom, text: &str) -> Vec<NodeHandle> {
365        let trim = text.trim();
366        vdom.nodes()
367            .iter()
368            .enumerate()
369            .map(|(i, _)| NodeHandle::new(i as u32))
370            .filter(|node| {
371                matches!(
372                    util::get_node_text(vdom, *node, &self.settings.text_retrieval_options),
373                    Some(text) if trim == text
374                )
375            })
376            .collect()
377    }
378
379    /// Check whether `selector` successfully selects `attribute` on all documents.
380    fn check_selector(
381        &self,
382        selector: &Selector,
383        attribute: &Attribute,
384        ignore_document: Option<usize>,
385    ) -> Result<(), usize> {
386        for i in 0..self.documents.len() {
387            if matches!(ignore_document, Some(d) if d == i) {
388                continue;
389            }
390            let node = self.document_selector_caches[i].borrow_mut().try_select(
391                selector,
392                self.document_roots[i],
393                self.documents[i].parser(),
394            );
395            let node_text_value = node.and_then(|node| {
396                util::get_node_text(
397                    &self.documents[i],
398                    node,
399                    &self.settings.text_retrieval_options,
400                )
401            });
402            let expected = attribute.values[i].as_ref();
403
404            if expected.is_none() {
405                match self.settings.missing_data_strategy {
406                    MissingDataStrategy::AllowMissingNode => {
407                        let ok = node.is_none() || node_text_value.is_none();
408                        if !ok {
409                            return Err(i);
410                        }
411                    }
412                    MissingDataStrategy::NodeMustExist => {
413                        let ok = node.is_some() && node_text_value.is_none();
414                        if !ok {
415                            return Err(i);
416                        }
417                    }
418                }
419            } else if node_text_value.is_none() || &node_text_value.unwrap() != expected.unwrap() {
420                return Err(i);
421            }
422        }
423        Ok(())
424    }
425
426    /// Perform one round of generation, mutation, and sorting for every attribute.
427    pub fn do_one_fuzzing_round<R: Rng>(&mut self, rng: &mut R) {
428        for attribute in &self.attributes {
429            let mut error_vote = vec![0; self.documents.len()];
430
431            // Generate selectors for each document
432            let mut document_selectors = self
433                .documents
434                .iter()
435                .enumerate()
436                .filter_map(|(i, vdom)| {
437                    // Determine target node by looking for node with text matching expected attribute
438                    // value.
439                    let target_nodes =
440                        self.find_all_nodes_with_text(vdom, attribute.values[i].as_ref()?.as_str());
441                    let mut random_target_weights = Vec::new();
442                    if target_nodes.is_empty() {
443                        warn!(
444                            "No matching target nodes for attribute {:?} in document {}",
445                            attribute.name, i
446                        );
447                        return None;
448                    }
449                    if target_nodes.len() > 1 {
450                        match self.settings.multiple_matches_strategy {
451                            MultipleMatchesStrategy::PrioritizeBestSelector => {
452                                // Randomly select between all target_nodes
453                                random_target_weights =
454                                    vec![1f32 / target_nodes.len() as f32; target_nodes.len()];
455                            }
456                            MultipleMatchesStrategy::PrioritizeFirstMatch => {
457                                // Prioritize first with 90%
458                                random_target_weights = vec![
459                                    0.1f32
460                                        / ((target_nodes.len() - 1) as f32);
461                                    target_nodes.len()
462                                ];
463                                random_target_weights[0] = 0.9f32;
464                            }
465                        }
466                    }
467
468                    let mut searcher = SelectorFuzzer::new();
469                    // TODO remove clone
470                    let mut selector_pool = self
471                        .selector_pool
472                        .get(&attribute.name)
473                        .cloned()
474                        .unwrap_or(Vec::new());
475                    trace!(
476                        "We have {} selectors for attribute {:?} from the previous iteration",
477                        selector_pool.len(),
478                        attribute.name
479                    );
480                    selector_pool.reserve(self.settings.random_generation_count);
481                    let start_time = Instant::now();
482                    (0..self.settings.random_generation_count)
483                        .filter_map(|_| {
484                            // Choose random target node based on weights
485                            let index = if target_nodes.len() == 1 {
486                                0
487                            } else {
488                                util::random_index_weighted(rng, &random_target_weights)
489                            };
490                            // Generate random selector for target node
491                            searcher
492                                .random_selector_for_node(
493                                    target_nodes[index],
494                                    self.document_roots[i],
495                                    vdom.parser(),
496                                    self.settings.random_generation_retries,
497                                    rng,
498                                )
499                                .map(CheckedSelector::new)
500                        })
501                        .for_each(|selector| {
502                            selector_pool.push(selector);
503                        });
504                    let elapsed_ms = start_time.elapsed().as_millis();
505
506                    trace!(
507                        "Generation: {} total selectors for attribute {:?} in document {}",
508                        selector_pool.len(),
509                        attribute.name,
510                        i
511                    );
512                    trace!(
513                        "Generation rate: {} in {}ms, {:.2}/s",
514                        self.settings.random_generation_count,
515                        elapsed_ms,
516                        self.settings.random_generation_count as f32 / elapsed_ms as f32 * 1000.
517                    );
518                    trace!(
519                        "Generation retries avg. {:.2} ({} total)",
520                        searcher.retries_used as f32 / self.settings.random_generation_count as f32,
521                        searcher.retries_used
522                    );
523                    selector_pool.dedup_by_key(|selector| selector.to_string());
524                    trace!("De-dup: {} selectors left", selector_pool.len());
525                    if let Some(filter) = attribute.filter {
526                        selector_pool.retain(|selector| filter(selector));
527                        trace!(
528                            "User-defined filter: {} selectors left",
529                            selector_pool.len()
530                        );
531                    }
532                    let start_time = Instant::now();
533                    selector_pool.retain_mut(|mut selector| {
534                        if selector.checked_on_all_documents {
535                            return true;
536                        }
537
538                        if let Err(index) = self.check_selector(selector, attribute, Some(i)) {
539                            error_vote[index] += 1;
540                            false
541                        } else {
542                            selector.checked_on_all_documents = true;
543                            true
544                        }
545                    });
546                    let elapsed_ms = start_time.elapsed().as_millis();
547                    trace!(
548                        "Matching all documents: {} selectors left (check_selector took {}ms)",
549                        selector_pool.len(),
550                        elapsed_ms
551                    );
552                    if selector_pool.is_empty() {
553                        return None;
554                    }
555                    selector_pool.sort_by_key(|selector| selector.score());
556                    if selector_pool.len() > self.settings.survivor_count {
557                        selector_pool.resize_with(self.settings.survivor_count, || unreachable!());
558                    }
559                    trace!("Survivors: {} selectors left", selector_pool.len());
560                    let start_time = Instant::now();
561                    for j in 0..usize::min(selector_pool.len(), self.settings.random_mutation_count)
562                    {
563                        let mutated = searcher.mutate_selector(
564                            &selector_pool[j],
565                            self.document_roots[i],
566                            self.documents[i].parser(),
567                            self.settings.random_generation_retries,
568                            rng,
569                        );
570                        if let Some(mutated) = mutated {
571                            if let Err(index) = self.check_selector(&mutated, attribute, Some(i)) {
572                                error_vote[index] += 1;
573                            } else {
574                                selector_pool.push(CheckedSelector::new_checked(mutated));
575                            }
576                        }
577                    }
578                    let elapsed_ms = start_time.elapsed().as_millis();
579                    trace!(
580                        "After mutation: {} selectors for attribute {:?} in document {}",
581                        selector_pool.len(),
582                        attribute.name,
583                        i
584                    );
585                    trace!(
586                        "Mutation rate: {} in {}ms, {:.2}/s",
587                        self.settings.random_mutation_count,
588                        elapsed_ms,
589                        self.settings.random_mutation_count as f32 / elapsed_ms as f32 * 1000.
590                    );
591                    selector_pool.dedup_by_key(|selector| selector.to_string());
592                    trace!(
593                        "De-dup after mutation: {} selectors left",
594                        selector_pool.len()
595                    );
596                    if let Some(filter) = attribute.filter {
597                        selector_pool.retain(|selector| filter(selector));
598                        trace!(
599                            "User-defined filter: {} selectors left",
600                            selector_pool.len()
601                        );
602                    }
603
604                    Some(selector_pool)
605                })
606                .flatten()
607                .collect::<Vec<_>>();
608
609            if document_selectors.is_empty() {
610                warn!(
611                    "No selectors for attribute {}! Likely problematic attribute value: {:?}",
612                    attribute.name,
613                    error_vote
614                        .iter()
615                        .zip(attribute.values.iter())
616                        .max_by_key(|(votes, _)| **votes)
617                        .map(|(_, name)| name)
618                        .unwrap()
619                );
620                continue;
621            }
622
623            document_selectors.dedup_by_key(|selector| selector.to_string());
624            document_selectors.sort_by_key(|selector| selector.score());
625            if document_selectors.len() > self.settings.survivor_count {
626                document_selectors.resize_with(self.settings.survivor_count, || unreachable!());
627            }
628
629            info!(
630                "Selector with best score for attribute {:?}:\n{}",
631                attribute.name,
632                document_selectors[0].to_string()
633            );
634            self.selector_pool
635                .insert(attribute.name.clone(), document_selectors);
636        }
637    }
638
639    /// Returns the best selector for the given attribute, if any.
640    pub fn get_best_selector_for(&self, attribute: &Attribute) -> Option<Selector> {
641        self.selector_pool
642            .get(&attribute.name)
643            .and_then(|selectors| selectors.get(0))
644            .map(|selector| selector.selector.clone())
645    }
646
647    /// Turns this training into a [`TrainingResult`], consuming the training.
648    pub fn to_result(self) -> TrainingResult {
649        let selectors = self
650            .attributes
651            .iter()
652            .filter_map(|attr| {
653                if let Some(selector) = self.get_best_selector_for(&attr) {
654                    Some((attr.name.clone(), selector))
655                } else {
656                    None
657                }
658            })
659            .collect();
660
661        TrainingResult {
662            selectors,
663            settings: self.settings,
664        }
665    }
666}
667
668#[cfg(test)]
669mod tests {
670    use crate::search::{Attribute, MissingDataStrategy, Training};
671    use crate::selectors::*;
672    use crate::*;
673    use rand::SeedableRng;
674    use rand_chacha::ChaCha8Rng;
675    use tl::VDom;
676
677    const HTML: [&'static str; 2] = [
678        r#"
679        <div id="root" class="root">
680            <img id="1" alt="blubb" />
681            <img id="2" title="glogg" />
682            <p id="3">plapp_before</p>
683        </div>
684        "#,
685        r#"
686        <div id="root" class="root">
687            <img id="1" alt="" />
688            <p id="3">plapp_after</p>
689        </div>
690        "#,
691    ];
692
693    fn get_simple_example() -> (VDom<'static>, VDom<'static>) {
694        (
695            tl::parse(HTML[0], tl::ParserOptions::default()).unwrap(),
696            tl::parse(HTML[1], tl::ParserOptions::default()).unwrap(),
697        )
698    }
699
700    #[test]
701    fn node_text() {
702        let (dom0, dom1) = get_simple_example();
703        let training = Training::new(vec![dom0, dom1], vec![]).unwrap();
704        assert_eq!(
705            util::get_node_text(
706                &training.documents[0],
707                training.documents[0].get_element_by_id("1").unwrap(),
708                &training.settings.text_retrieval_options
709            ),
710            Some("blubb".into())
711        );
712        assert_eq!(
713            util::get_node_text(
714                &training.documents[0],
715                training.documents[0].get_element_by_id("2").unwrap(),
716                &training.settings.text_retrieval_options
717            ),
718            Some("glogg".into())
719        );
720        assert_eq!(
721            util::get_node_text(
722                &training.documents[0],
723                training.documents[0].get_element_by_id("3").unwrap(),
724                &training.settings.text_retrieval_options
725            ),
726            Some("plapp_before".into())
727        );
728        assert_eq!(
729            util::get_node_text(
730                &training.documents[1],
731                training.documents[1].get_element_by_id("1").unwrap(),
732                &training.settings.text_retrieval_options
733            ),
734            None
735        );
736        assert_eq!(training.documents[1].get_element_by_id("2"), None);
737        assert_eq!(
738            util::get_node_text(
739                &training.documents[1],
740                training.documents[1].get_element_by_id("3").unwrap(),
741                &training.settings.text_retrieval_options
742            ),
743            Some("plapp_after".into())
744        );
745        assert_eq!(
746            util::get_node_text(
747                &training.documents[0],
748                training.documents[0].get_element_by_id("root").unwrap(),
749                &training.settings.text_retrieval_options
750            ),
751            None
752        );
753    }
754
755    #[test]
756    fn find_nodes() {
757        let (dom0, dom1) = get_simple_example();
758        let training = Training::new(vec![dom0, dom1], vec![]).unwrap();
759
760        assert_eq!(
761            training.find_all_nodes_with_text(&training.documents[0], "blubb"),
762            vec![training.documents[0].get_element_by_id("1").unwrap()]
763        );
764        assert_eq!(
765            training.find_all_nodes_with_text(&training.documents[0], "glogg"),
766            vec![training.documents[0].get_element_by_id("2").unwrap()]
767        );
768        assert_eq!(
769            training.find_all_nodes_with_text(&training.documents[0], "plapp_before"),
770            vec![training.documents[0].get_element_by_id("3").unwrap()]
771        );
772        assert_eq!(
773            training.find_all_nodes_with_text(&training.documents[1], "plapp_after"),
774            vec![training.documents[1].get_element_by_id("3").unwrap()]
775        );
776    }
777
778    #[test]
779    fn node_matching() {
780        let (dom0, dom1) = get_simple_example();
781        let mut training = Training::new(
782            vec![dom0, dom1],
783            vec![
784                Attribute {
785                    name: "attr1".to_string(),
786                    values: vec![Some("blubb".into()), None],
787                    filter: None,
788                },
789                Attribute {
790                    name: "attr2".to_string(),
791                    values: vec![Some("glogg".into()), None],
792                    filter: None,
793                },
794                Attribute {
795                    name: "attr3".to_string(),
796                    values: vec![Some("plapp_before".into()), Some("plapp_after".into())],
797                    filter: None,
798                },
799                Attribute {
800                    name: "failing_attr1".to_string(),
801                    values: vec![Some("blubb".into()), Some("wrong".into())],
802                    filter: None,
803                },
804            ],
805        )
806        .unwrap();
807
808        let sel1 = Selector::new_from_parts(vec![SelectorPart::Id("1".into())]);
809        let sel2 = Selector::new_from_parts(vec![SelectorPart::Id("2".into())]);
810        let sel3 = Selector::new_from_parts(vec![SelectorPart::Id("3".into())]);
811
812        training.settings.missing_data_strategy = MissingDataStrategy::AllowMissingNode;
813
814        // Correct values in both documents
815        assert!(training
816            .check_selector(&sel1, &training.attributes[0], None)
817            .is_ok());
818        assert!(training
819            .check_selector(&sel2, &training.attributes[1], None)
820            .is_ok());
821        assert!(training
822            .check_selector(&sel3, &training.attributes[2], None)
823            .is_ok());
824
825        // Correct values in both documents, but node 2 is missing => should error
826        training.settings.missing_data_strategy = MissingDataStrategy::NodeMustExist;
827        assert!(training
828            .check_selector(&sel1, &training.attributes[0], None)
829            .is_ok());
830        assert!(training
831            .check_selector(&sel2, &training.attributes[1], None)
832            .is_err());
833        assert!(training
834            .check_selector(&sel3, &training.attributes[2], None)
835            .is_ok());
836
837        // Wrong value in second document
838        assert!(training
839            .check_selector(&sel1, &training.attributes[3], None)
840            .is_err());
841        assert!(training
842            .check_selector(&sel1, &training.attributes[3], Some(1))
843            .is_ok());
844    }
845
846    #[test]
847    fn fuzzing() {
848        let (dom0, dom1) = get_simple_example();
849        let mut training = Training::new(
850            vec![dom0, dom1],
851            vec![
852                Attribute {
853                    name: "attr1".to_string(),
854                    values: vec![Some("blubb".into()), None],
855                    filter: None,
856                },
857                Attribute {
858                    name: "attr2".to_string(),
859                    values: vec![Some("glogg".into()), None],
860                    filter: None,
861                },
862                Attribute {
863                    name: "attr3".to_string(),
864                    values: vec![Some("plapp_before".into()), Some("plapp_after".into())],
865                    filter: None,
866                },
867            ],
868        )
869        .unwrap();
870
871        training.settings.missing_data_strategy = MissingDataStrategy::AllowMissingNode;
872
873        let mut rng = ChaCha8Rng::seed_from_u64(1337);
874        training.do_one_fuzzing_round(&mut rng);
875
876        let result = training.to_result();
877        let (dom0, dom1) = get_simple_example();
878        assert_eq!(
879            result.get_value(&dom0, "attr1").unwrap_or(None),
880            Some("blubb".into())
881        );
882        assert_eq!(
883            result.get_value(&dom0, "attr2").unwrap_or(None),
884            Some("glogg".into())
885        );
886        assert_eq!(
887            result.get_value(&dom0, "attr3").unwrap_or(None),
888            Some("plapp_before".into())
889        );
890        assert_eq!(result.get_value(&dom1, "attr1").unwrap_or(None), None);
891        assert_eq!(result.get_value(&dom1, "attr2").unwrap_or(None), None);
892        assert_eq!(
893            result.get_value(&dom1, "attr3").unwrap_or(None),
894            Some("plapp_after".into())
895        );
896    }
897
898    #[cfg(feature = "serde")]
899    #[test]
900    fn serialization() {
901        let (dom0, dom1) = get_simple_example();
902        let mut training = Training::new(
903            vec![dom0],
904            vec![Attribute {
905                name: "attr3".to_string(),
906                values: vec![Some("plapp_before".into())],
907                filter: None,
908            }],
909        )
910        .unwrap();
911
912        let mut rng = ChaCha8Rng::seed_from_u64(1337);
913        training.do_one_fuzzing_round(&mut rng);
914        let result = training.to_result();
915        assert_eq!(result.get_selector("attr3"), Some("#3"));
916        assert_eq!(
917            result.get_value(&dom1, "attr3").unwrap_or(None),
918            Some("plapp_after".into())
919        );
920
921        let ser = serde_json::to_string(&result).unwrap();
922        assert!(ser.starts_with(r#"{"selectors":{"attr3":{"parts":[{"Id":"3"}]"#));
923
924        let de: TrainingResult = serde_json::from_str(ser.as_ref()).unwrap();
925        assert_eq!(de.get_selector("attr3"), Some("#3"));
926        assert_eq!(
927            de.get_value(&dom1, "attr3").unwrap_or(None),
928            Some("plapp_after".into())
929        );
930    }
931}