1use crate::error::SpiderError;
7use ego_tree::NodeId;
8use once_cell::sync::Lazy;
9use parking_lot::RwLock;
10use scraper::{ElementRef, Html, Selector};
11use std::collections::{HashMap, HashSet};
12use std::sync::Arc;
13
14static SELECTOR_CACHE: Lazy<RwLock<HashMap<String, Selector>>> =
16 Lazy::new(|| RwLock::new(HashMap::new()));
17static COMPILED_SELECTOR_CACHE: Lazy<RwLock<HashMap<String, CompiledSelector>>> =
18 Lazy::new(|| RwLock::new(HashMap::new()));
19
/// What a selection yields when its value is read.
#[derive(Clone, Debug, Eq, PartialEq)]
pub(crate) enum ExtractionKind {
    /// Yield the matched element serialized as HTML.
    Element,
    /// Yield the element's text, requested with a trailing `::text`.
    Text,
    /// Yield the value of the named attribute, requested with a trailing `::attr(name)`.
    Attr(String),
}
26
27#[derive(Debug, Clone)]
28pub(crate) struct CompiledSelector {
29 selector: Selector,
30 extraction: ExtractionKind,
31}
32
33impl CompiledSelector {
34 pub(crate) fn selector(&self) -> &Selector {
35 &self.selector
36 }
37
38 pub(crate) fn extraction(&self) -> &ExtractionKind {
39 &self.extraction
40 }
41}
42
43#[derive(Debug, Clone)]
45pub struct SelectorNode {
46 document: Arc<Html>,
47 node_id: NodeId,
48 extraction: ExtractionKind,
49}
50
51#[derive(Debug, Clone)]
53pub struct SelectorList {
54 document: Arc<Html>,
55 node_ids: Vec<NodeId>,
56 extraction: ExtractionKind,
57}
58
59impl SelectorNode {
60 pub(crate) fn new(document: Arc<Html>, node_id: NodeId, extraction: ExtractionKind) -> Self {
61 Self {
62 document,
63 node_id,
64 extraction,
65 }
66 }
67
68 pub fn css(&self, query: &str) -> Result<SelectorList, SpiderError> {
75 if self.extraction != ExtractionKind::Element {
76 return Err(SpiderError::HtmlParseError(
77 "css() can only be chained from element selections".to_string(),
78 ));
79 }
80
81 let compiled = get_cached_compiled_selector(query)?;
82 let Some(scope) = self.element_ref() else {
83 return Ok(SelectorList::empty(
84 self.document.clone(),
85 compiled.extraction().clone(),
86 ));
87 };
88
89 let node_ids = scope
90 .select(compiled.selector())
91 .map(|element| element.id())
92 .collect();
93
94 Ok(SelectorList::new(
95 self.document.clone(),
96 node_ids,
97 compiled.extraction().clone(),
98 ))
99 }
100
101 pub fn get(&self) -> Option<String> {
103 self.element_ref()
104 .and_then(|element| extract_element_value(element, &self.extraction))
105 }
106
107 pub fn get_all(&self) -> Vec<String> {
109 self.get().into_iter().collect()
110 }
111
112 pub fn attrib(&self, name: &str) -> Option<String> {
114 self.element_ref()
115 .and_then(|element| element.attr(name).map(ToOwned::to_owned))
116 }
117
118 pub fn text_content(&self) -> Option<String> {
120 self.element_ref()
121 .map(|element| element.text().collect::<String>())
122 }
123
124 pub fn has_css(&self, query: &str) -> Result<bool, SpiderError> {
131 Ok(!self.css(query)?.is_empty())
132 }
133
134 pub fn has_ancestor(&self, query: &str) -> Result<bool, SpiderError> {
141 let selector =
142 Selector::parse(query).map_err(|e| SpiderError::HtmlParseError(e.to_string()))?;
143 let Some(element) = self.element_ref() else {
144 return Ok(false);
145 };
146
147 Ok(element
148 .ancestors()
149 .filter_map(ElementRef::wrap)
150 .any(|ancestor| selector.matches(&ancestor)))
151 }
152
153 fn element_ref(&self) -> Option<ElementRef<'_>> {
154 element_ref_by_id(&self.document, self.node_id)
155 }
156}
157
158impl SelectorList {
159 pub(crate) fn new(
160 document: Arc<Html>,
161 node_ids: Vec<NodeId>,
162 extraction: ExtractionKind,
163 ) -> Self {
164 Self {
165 document,
166 node_ids,
167 extraction,
168 }
169 }
170
171 pub(crate) fn from_document_query(
172 document: Arc<Html>,
173 query: &str,
174 ) -> Result<Self, SpiderError> {
175 let compiled = get_cached_compiled_selector(query)?;
176 let node_ids = document
177 .select(compiled.selector())
178 .map(|element| element.id())
179 .collect();
180
181 Ok(Self::new(document, node_ids, compiled.extraction().clone()))
182 }
183
184 pub(crate) fn empty(document: Arc<Html>, extraction: ExtractionKind) -> Self {
185 Self::new(document, Vec::new(), extraction)
186 }
187
188 pub fn css(&self, query: &str) -> Result<Self, SpiderError> {
195 if self.extraction != ExtractionKind::Element {
196 return Err(SpiderError::HtmlParseError(
197 "css() can only be chained from element selections".to_string(),
198 ));
199 }
200
201 let compiled = get_cached_compiled_selector(query)?;
202 let mut seen = HashSet::new();
203 let mut node_ids = Vec::new();
204
205 for node_id in &self.node_ids {
206 let Some(scope) = element_ref_by_id(&self.document, *node_id) else {
207 continue;
208 };
209
210 for element in scope.select(compiled.selector()) {
211 let id = element.id();
212 if seen.insert(id) {
213 node_ids.push(id);
214 }
215 }
216 }
217
218 Ok(Self::new(
219 self.document.clone(),
220 node_ids,
221 compiled.extraction().clone(),
222 ))
223 }
224
225 pub fn get(&self) -> Option<String> {
227 self.first().and_then(|node| node.get())
228 }
229
230 pub fn get_all(&self) -> Vec<String> {
232 self.node_ids
233 .iter()
234 .filter_map(|node_id| {
235 element_ref_by_id(&self.document, *node_id)
236 .and_then(|element| extract_element_value(element, &self.extraction))
237 })
238 .collect()
239 }
240
241 pub fn attrib(&self, name: &str) -> Option<String> {
243 self.first().and_then(|node| node.attrib(name))
244 }
245
246 pub fn first(&self) -> Option<SelectorNode> {
248 self.node_ids.first().copied().map(|node_id| {
249 SelectorNode::new(self.document.clone(), node_id, self.extraction.clone())
250 })
251 }
252
253 pub fn len(&self) -> usize {
255 self.node_ids.len()
256 }
257
258 pub fn is_empty(&self) -> bool {
260 self.node_ids.is_empty()
261 }
262}
263
264impl IntoIterator for SelectorList {
265 type Item = SelectorNode;
266 type IntoIter = std::vec::IntoIter<SelectorNode>;
267
268 fn into_iter(self) -> Self::IntoIter {
269 self.node_ids
270 .into_iter()
271 .map(|node_id| {
272 SelectorNode::new(self.document.clone(), node_id, self.extraction.clone())
273 })
274 .collect::<Vec<_>>()
275 .into_iter()
276 }
277}
278
279pub fn get_cached_selector(selector_str: &str) -> Option<Selector> {
281 {
282 let cache = SELECTOR_CACHE.read();
283 if let Some(cached) = cache.get(selector_str) {
284 return Some(cached.clone());
285 }
286 }
287
288 match Selector::parse(selector_str) {
289 Ok(selector) => {
290 {
291 let mut cache = SELECTOR_CACHE.write();
292 if let Some(cached) = cache.get(selector_str) {
293 return Some(cached.clone());
294 }
295 cache.insert(selector_str.to_string(), selector.clone());
296 }
297 Some(selector)
298 }
299 Err(_) => None,
300 }
301}
302
303pub(crate) fn get_cached_compiled_selector(query: &str) -> Result<CompiledSelector, SpiderError> {
304 {
305 let cache = COMPILED_SELECTOR_CACHE.read();
306 if let Some(cached) = cache.get(query) {
307 return Ok(cached.clone());
308 }
309 }
310
311 let compiled = parse_compiled_selector(query)?;
312
313 {
314 let mut cache = COMPILED_SELECTOR_CACHE.write();
315 if let Some(cached) = cache.get(query) {
316 return Ok(cached.clone());
317 }
318 cache.insert(query.to_string(), compiled.clone());
319 }
320
321 Ok(compiled)
322}
323
324pub fn prewarm_cache() {
326 let common_selectors = vec![
327 "a[href]",
328 "link[href]",
329 "script[src]",
330 "img[src]",
331 "audio[src]",
332 "video[src]",
333 "source[src]",
334 "form[action]",
335 "iframe[src]",
336 "frame[src]",
337 "embed[src]",
338 "object[data]",
339 ];
340
341 for selector_str in common_selectors {
342 get_cached_selector(selector_str);
343 let _ = get_cached_compiled_selector(selector_str);
344 }
345}
346
347fn parse_compiled_selector(query: &str) -> Result<CompiledSelector, SpiderError> {
348 let query = query.trim();
349 if query.is_empty() {
350 return Err(SpiderError::HtmlParseError(
351 "selector query cannot be empty".to_string(),
352 ));
353 }
354
355 let (selector_str, extraction) = parse_selector_parts(query)?;
356 let selector =
357 Selector::parse(selector_str).map_err(|e| SpiderError::HtmlParseError(e.to_string()))?;
358
359 Ok(CompiledSelector {
360 selector,
361 extraction,
362 })
363}
364
365fn parse_selector_parts(query: &str) -> Result<(&str, ExtractionKind), SpiderError> {
366 if let Some(selector) = query.strip_suffix("::text") {
367 let selector = selector.trim_end();
368 if selector.is_empty() {
369 return Err(SpiderError::HtmlParseError(
370 "selector cannot be empty before ::text".to_string(),
371 ));
372 }
373 return Ok((selector, ExtractionKind::Text));
374 }
375
376 if let Some(start) = query.rfind("::attr(")
377 && query.ends_with(')')
378 {
379 let selector = query[..start].trim_end();
380 let attr = query[start + "::attr(".len()..query.len() - 1].trim();
381 if selector.is_empty() {
382 return Err(SpiderError::HtmlParseError(
383 "selector cannot be empty before ::attr(...)".to_string(),
384 ));
385 }
386 if attr.is_empty() {
387 return Err(SpiderError::HtmlParseError(
388 "attribute name cannot be empty in ::attr(...)".to_string(),
389 ));
390 }
391
392 return Ok((selector, ExtractionKind::Attr(attr.to_string())));
393 }
394
395 Ok((query, ExtractionKind::Element))
396}
397
398fn element_ref_by_id(document: &Html, node_id: NodeId) -> Option<ElementRef<'_>> {
399 document.tree.get(node_id).and_then(ElementRef::wrap)
400}
401
402fn extract_element_value(element: ElementRef<'_>, extraction: &ExtractionKind) -> Option<String> {
403 match extraction {
404 ExtractionKind::Element => Some(element.html()),
405 ExtractionKind::Text => Some(element.text().collect::<String>()),
406 ExtractionKind::Attr(attr) => element.attr(attr).map(ToOwned::to_owned),
407 }
408}