Skip to main content

spider_util/
selector.rs

//! Cached CSS selector helpers.
//!
//! HTML-heavy crawls often reuse the same selectors across thousands of pages.
//! This module caches compiled selectors so that repeated parsing work stays low.
5
6use crate::error::SpiderError;
7use ego_tree::NodeId;
8use once_cell::sync::Lazy;
9use parking_lot::RwLock;
10use scraper::{ElementRef, Html, Selector};
11use std::collections::{HashMap, HashSet};
12use std::sync::Arc;
13
14// Global selector cache to avoid repeated compilation
15static SELECTOR_CACHE: Lazy<RwLock<HashMap<String, Selector>>> =
16    Lazy::new(|| RwLock::new(HashMap::new()));
17static COMPILED_SELECTOR_CACHE: Lazy<RwLock<HashMap<String, CompiledSelector>>> =
18    Lazy::new(|| RwLock::new(HashMap::new()));
19
/// What a selection yields when its value is extracted.
///
/// Chosen by the suffix of a query: none, `::text`, or `::attr(name)`.
#[derive(Clone, Debug, Eq, PartialEq)]
pub(crate) enum ExtractionKind {
    /// The element's serialized HTML.
    Element,
    /// The element's concatenated text.
    Text,
    /// The value of the named attribute, when present.
    Attr(String),
}
26
27#[derive(Debug, Clone)]
28pub(crate) struct CompiledSelector {
29    selector: Selector,
30    extraction: ExtractionKind,
31}
32
33impl CompiledSelector {
34    pub(crate) fn selector(&self) -> &Selector {
35        &self.selector
36    }
37
38    pub(crate) fn extraction(&self) -> &ExtractionKind {
39        &self.extraction
40    }
41}
42
43/// A node selected from an HTML document using the builtin CSS selector API.
44#[derive(Debug, Clone)]
45pub struct SelectorNode {
46    document: Arc<Html>,
47    node_id: NodeId,
48    extraction: ExtractionKind,
49}
50
51/// A Scrapy-like selection result list.
52#[derive(Debug, Clone)]
53pub struct SelectorList {
54    document: Arc<Html>,
55    node_ids: Vec<NodeId>,
56    extraction: ExtractionKind,
57}
58
59impl SelectorNode {
60    pub(crate) fn new(document: Arc<Html>, node_id: NodeId, extraction: ExtractionKind) -> Self {
61        Self {
62            document,
63            node_id,
64            extraction,
65        }
66    }
67
68    /// Applies a CSS selector relative to this node.
69    ///
70    /// # Errors
71    ///
72    /// Returns [`SpiderError::HtmlParseError`] when the selector is invalid or
73    /// when chaining from a text/attribute extraction.
74    pub fn css(&self, query: &str) -> Result<SelectorList, SpiderError> {
75        if self.extraction != ExtractionKind::Element {
76            return Err(SpiderError::HtmlParseError(
77                "css() can only be chained from element selections".to_string(),
78            ));
79        }
80
81        let compiled = get_cached_compiled_selector(query)?;
82        let Some(scope) = self.element_ref() else {
83            return Ok(SelectorList::empty(
84                self.document.clone(),
85                compiled.extraction().clone(),
86            ));
87        };
88
89        let node_ids = scope
90            .select(compiled.selector())
91            .map(|element| element.id())
92            .collect();
93
94        Ok(SelectorList::new(
95            self.document.clone(),
96            node_ids,
97            compiled.extraction().clone(),
98        ))
99    }
100
101    /// Returns the extracted value for this node, if present.
102    pub fn get(&self) -> Option<String> {
103        self.element_ref()
104            .and_then(|element| extract_element_value(element, &self.extraction))
105    }
106
107    /// Returns this node's extracted value as a single-element vector or an empty one.
108    pub fn get_all(&self) -> Vec<String> {
109        self.get().into_iter().collect()
110    }
111
112    /// Returns the named attribute from the selected element.
113    pub fn attrib(&self, name: &str) -> Option<String> {
114        self.element_ref()
115            .and_then(|element| element.attr(name).map(ToOwned::to_owned))
116    }
117
118    /// Returns the concatenated text content of the selected element.
119    pub fn text_content(&self) -> Option<String> {
120        self.element_ref()
121            .map(|element| element.text().collect::<String>())
122    }
123
124    /// Returns `true` when this element has any descendant matching `query`.
125    ///
126    /// # Errors
127    ///
128    /// Returns [`SpiderError::HtmlParseError`] when the selector is invalid or
129    /// when called on a text/attribute extraction.
130    pub fn has_css(&self, query: &str) -> Result<bool, SpiderError> {
131        Ok(!self.css(query)?.is_empty())
132    }
133
134    /// Returns `true` when any ancestor of this element matches `query`.
135    ///
136    /// # Errors
137    ///
138    /// Returns [`SpiderError::HtmlParseError`] when the selector is invalid or
139    /// when called on a text/attribute extraction.
140    pub fn has_ancestor(&self, query: &str) -> Result<bool, SpiderError> {
141        let selector =
142            Selector::parse(query).map_err(|e| SpiderError::HtmlParseError(e.to_string()))?;
143        let Some(element) = self.element_ref() else {
144            return Ok(false);
145        };
146
147        Ok(element
148            .ancestors()
149            .filter_map(ElementRef::wrap)
150            .any(|ancestor| selector.matches(&ancestor)))
151    }
152
153    fn element_ref(&self) -> Option<ElementRef<'_>> {
154        element_ref_by_id(&self.document, self.node_id)
155    }
156}
157
158impl SelectorList {
159    pub(crate) fn new(
160        document: Arc<Html>,
161        node_ids: Vec<NodeId>,
162        extraction: ExtractionKind,
163    ) -> Self {
164        Self {
165            document,
166            node_ids,
167            extraction,
168        }
169    }
170
171    pub(crate) fn from_document_query(
172        document: Arc<Html>,
173        query: &str,
174    ) -> Result<Self, SpiderError> {
175        let compiled = get_cached_compiled_selector(query)?;
176        let node_ids = document
177            .select(compiled.selector())
178            .map(|element| element.id())
179            .collect();
180
181        Ok(Self::new(document, node_ids, compiled.extraction().clone()))
182    }
183
184    pub(crate) fn empty(document: Arc<Html>, extraction: ExtractionKind) -> Self {
185        Self::new(document, Vec::new(), extraction)
186    }
187
188    /// Applies a CSS selector relative to every node in the list.
189    ///
190    /// # Errors
191    ///
192    /// Returns [`SpiderError::HtmlParseError`] when the selector is invalid or
193    /// when chaining from a text/attribute extraction.
194    pub fn css(&self, query: &str) -> Result<Self, SpiderError> {
195        if self.extraction != ExtractionKind::Element {
196            return Err(SpiderError::HtmlParseError(
197                "css() can only be chained from element selections".to_string(),
198            ));
199        }
200
201        let compiled = get_cached_compiled_selector(query)?;
202        let mut seen = HashSet::new();
203        let mut node_ids = Vec::new();
204
205        for node_id in &self.node_ids {
206            let Some(scope) = element_ref_by_id(&self.document, *node_id) else {
207                continue;
208            };
209
210            for element in scope.select(compiled.selector()) {
211                let id = element.id();
212                if seen.insert(id) {
213                    node_ids.push(id);
214                }
215            }
216        }
217
218        Ok(Self::new(
219            self.document.clone(),
220            node_ids,
221            compiled.extraction().clone(),
222        ))
223    }
224
225    /// Returns the first extracted value in the selection.
226    pub fn get(&self) -> Option<String> {
227        self.first().and_then(|node| node.get())
228    }
229
230    /// Returns all extracted values in the selection.
231    pub fn get_all(&self) -> Vec<String> {
232        self.node_ids
233            .iter()
234            .filter_map(|node_id| {
235                element_ref_by_id(&self.document, *node_id)
236                    .and_then(|element| extract_element_value(element, &self.extraction))
237            })
238            .collect()
239    }
240
241    /// Returns the named attribute from the first selected element.
242    pub fn attrib(&self, name: &str) -> Option<String> {
243        self.first().and_then(|node| node.attrib(name))
244    }
245
246    /// Returns the first selected node.
247    pub fn first(&self) -> Option<SelectorNode> {
248        self.node_ids.first().copied().map(|node_id| {
249            SelectorNode::new(self.document.clone(), node_id, self.extraction.clone())
250        })
251    }
252
253    /// Returns the number of matched nodes.
254    pub fn len(&self) -> usize {
255        self.node_ids.len()
256    }
257
258    /// Returns `true` when the selection has no matched nodes.
259    pub fn is_empty(&self) -> bool {
260        self.node_ids.is_empty()
261    }
262}
263
264impl IntoIterator for SelectorList {
265    type Item = SelectorNode;
266    type IntoIter = std::vec::IntoIter<SelectorNode>;
267
268    fn into_iter(self) -> Self::IntoIter {
269        self.node_ids
270            .into_iter()
271            .map(|node_id| {
272                SelectorNode::new(self.document.clone(), node_id, self.extraction.clone())
273            })
274            .collect::<Vec<_>>()
275            .into_iter()
276    }
277}
278
279/// Returns a compiled selector from the cache, compiling it on first use.
280pub fn get_cached_selector(selector_str: &str) -> Option<Selector> {
281    {
282        let cache = SELECTOR_CACHE.read();
283        if let Some(cached) = cache.get(selector_str) {
284            return Some(cached.clone());
285        }
286    }
287
288    match Selector::parse(selector_str) {
289        Ok(selector) => {
290            {
291                let mut cache = SELECTOR_CACHE.write();
292                if let Some(cached) = cache.get(selector_str) {
293                    return Some(cached.clone());
294                }
295                cache.insert(selector_str.to_string(), selector.clone());
296            }
297            Some(selector)
298        }
299        Err(_) => None,
300    }
301}
302
303pub(crate) fn get_cached_compiled_selector(query: &str) -> Result<CompiledSelector, SpiderError> {
304    {
305        let cache = COMPILED_SELECTOR_CACHE.read();
306        if let Some(cached) = cache.get(query) {
307            return Ok(cached.clone());
308        }
309    }
310
311    let compiled = parse_compiled_selector(query)?;
312
313    {
314        let mut cache = COMPILED_SELECTOR_CACHE.write();
315        if let Some(cached) = cache.get(query) {
316            return Ok(cached.clone());
317        }
318        cache.insert(query.to_string(), compiled.clone());
319    }
320
321    Ok(compiled)
322}
323
324/// Pre-warms the selector cache with a small set of common selectors.
325pub fn prewarm_cache() {
326    let common_selectors = vec![
327        "a[href]",
328        "link[href]",
329        "script[src]",
330        "img[src]",
331        "audio[src]",
332        "video[src]",
333        "source[src]",
334        "form[action]",
335        "iframe[src]",
336        "frame[src]",
337        "embed[src]",
338        "object[data]",
339    ];
340
341    for selector_str in common_selectors {
342        get_cached_selector(selector_str);
343        let _ = get_cached_compiled_selector(selector_str);
344    }
345}
346
347fn parse_compiled_selector(query: &str) -> Result<CompiledSelector, SpiderError> {
348    let query = query.trim();
349    if query.is_empty() {
350        return Err(SpiderError::HtmlParseError(
351            "selector query cannot be empty".to_string(),
352        ));
353    }
354
355    let (selector_str, extraction) = parse_selector_parts(query)?;
356    let selector =
357        Selector::parse(selector_str).map_err(|e| SpiderError::HtmlParseError(e.to_string()))?;
358
359    Ok(CompiledSelector {
360        selector,
361        extraction,
362    })
363}
364
365fn parse_selector_parts(query: &str) -> Result<(&str, ExtractionKind), SpiderError> {
366    if let Some(selector) = query.strip_suffix("::text") {
367        let selector = selector.trim_end();
368        if selector.is_empty() {
369            return Err(SpiderError::HtmlParseError(
370                "selector cannot be empty before ::text".to_string(),
371            ));
372        }
373        return Ok((selector, ExtractionKind::Text));
374    }
375
376    if let Some(start) = query.rfind("::attr(")
377        && query.ends_with(')')
378    {
379        let selector = query[..start].trim_end();
380        let attr = query[start + "::attr(".len()..query.len() - 1].trim();
381        if selector.is_empty() {
382            return Err(SpiderError::HtmlParseError(
383                "selector cannot be empty before ::attr(...)".to_string(),
384            ));
385        }
386        if attr.is_empty() {
387            return Err(SpiderError::HtmlParseError(
388                "attribute name cannot be empty in ::attr(...)".to_string(),
389            ));
390        }
391
392        return Ok((selector, ExtractionKind::Attr(attr.to_string())));
393    }
394
395    Ok((query, ExtractionKind::Element))
396}
397
398fn element_ref_by_id(document: &Html, node_id: NodeId) -> Option<ElementRef<'_>> {
399    document.tree.get(node_id).and_then(ElementRef::wrap)
400}
401
402fn extract_element_value(element: ElementRef<'_>, extraction: &ExtractionKind) -> Option<String> {
403    match extraction {
404        ExtractionKind::Element => Some(element.html()),
405        ExtractionKind::Text => Some(element.text().collect::<String>()),
406        ExtractionKind::Attr(attr) => element.attr(attr).map(ToOwned::to_owned),
407    }
408}