Skip to main content

content_extractor_rl/
node_features.rs

1// ============================================================================
2// FILE: crates/content-extractor-rl/src/node_features.rs
3// ============================================================================
4//! Real, content-aware features for candidate DOM nodes.
5//!
6//! These replace the placeholder constant state vector that previously made the
7//! RL agent blind to the document. Every feature is derived from the actual DOM
8//! subtree of a candidate node, so two different candidates produce two
9//! different feature vectors — a precondition for the agent to learn anything.
10//!
11//! The same features power the supervised node classifier (hybrid mode), so the
12//! representation is shared in one place.
13
14use scraper::{ElementRef, Selector};
15use std::collections::HashSet;
16
/// Continuous extraction parameters that the RL policy tunes. They actually
/// affect which text blocks are kept, so the policy's continuous head has a
/// real effect on the extracted text (and therefore on the reward).
#[derive(Debug, Clone, Copy)]
pub struct ExtractionParams {
    /// Minimum words for a block (<p>/text node) to be kept.
    pub min_block_words: usize,
    /// Drop blocks whose link density exceeds this threshold (0..=1).
    pub max_block_link_density: f32,
}

impl Default for ExtractionParams {
    /// Defaults: keep blocks of at least 5 words with link density up to 0.5.
    fn default() -> Self {
        Self { min_block_words: 5, max_block_link_density: 0.5 }
    }
}

impl ExtractionParams {
    /// Map a policy's normalized continuous params (each roughly in [-1, 1]) to
    /// concrete extraction settings. Only the first two params are used today;
    /// extra params are accepted and ignored so the action space can stay wide.
    ///
    /// * `params[0]` maps linearly onto `min_block_words` in `[1, 40]`.
    /// * `params[1]` maps linearly onto `max_block_link_density` in `[0.1, 0.9]`.
    ///
    /// Missing entries and NaN entries are both treated as `0.0` (the middle of
    /// the range). Out-of-range and infinite values are clamped to `[-1, 1]`.
    pub fn from_normalized(params: &[f32]) -> Self {
        // `f32::clamp` returns NaN for a NaN input. Left unchecked, a NaN would
        // become `min_block_words == 0` (NaN `as usize` saturates to 0) and a
        // NaN density threshold, against which every `<=` comparison is false.
        let unit = |idx: usize| -> f32 {
            match params.get(idx).copied() {
                Some(v) if !v.is_nan() => v.clamp(-1.0, 1.0),
                _ => 0.0,
            }
        };
        let p0 = unit(0);
        let p1 = unit(1);

        // min_block_words in [1, 40]
        let min_block_words = (1.0 + (p0 + 1.0) * 19.5).round().clamp(1.0, 40.0) as usize;
        // max_block_link_density in [0.1, 0.9]
        let max_block_link_density = (0.1 + (p1 + 1.0) * 0.4).clamp(0.1, 0.9);

        Self { min_block_words, max_block_link_density }
    }
}
48
/// Structural and textual features describing one candidate node.
///
/// Every field is already scaled to roughly [0, 1], so the vector can be
/// handed to the network as-is.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct NodeFeatures {
    /// Normalized word count of the node's text.
    pub word_count_norm: f32,
    /// Normalized character count of the node's text.
    pub char_count_norm: f32,
    /// Share of the node's text that sits inside links.
    pub link_density: f32,
    /// Share of words that are stopwords.
    pub stopword_ratio: f32,
    /// Normalized number of `<p>` elements.
    pub p_count_norm: f32,
    /// Normalized words-per-element ratio.
    pub text_tag_ratio: f32,
    /// Normalized depth of the node in the tree.
    pub depth_norm: f32,
    /// Normalized commas-per-word ratio.
    pub comma_density: f32,
    /// One-hot tag indicator: `<article>`.
    pub tag_article: f32,
    /// One-hot tag indicator: `<main>`.
    pub tag_main: f32,
    /// One-hot tag indicator: `<section>`.
    pub tag_section: f32,
    /// One-hot tag indicator: `<div>`.
    pub tag_div: f32,
    /// One-hot tag indicator: any other tag.
    pub tag_other: f32,
    /// Score for content-like class/id hints.
    pub class_positive: f32,
    /// Score for boilerplate-like class/id hints.
    pub class_negative: f32,
    /// Share of distinct words among all words.
    pub unique_word_ratio: f32,
}

impl NodeFeatures {
    /// Number of features per node. [`Self::to_vec`] builds an array of
    /// exactly this length, so the two cannot drift apart silently.
    pub const DIM: usize = 16;

    /// All-zero feature set (used to pad unused candidate slots).
    pub fn zeros() -> Self {
        const Z: f32 = 0.0;
        Self {
            word_count_norm: Z,
            char_count_norm: Z,
            link_density: Z,
            stopword_ratio: Z,
            p_count_norm: Z,
            text_tag_ratio: Z,
            depth_norm: Z,
            comma_density: Z,
            tag_article: Z,
            tag_main: Z,
            tag_section: Z,
            tag_div: Z,
            tag_other: Z,
            class_positive: Z,
            class_negative: Z,
            unique_word_ratio: Z,
        }
    }

    /// Flatten into a vector of length [`Self::DIM`], in field-declaration order.
    pub fn to_vec(&self) -> Vec<f32> {
        // The explicit array type makes a DIM/field-count mismatch a compile error.
        let flat: [f32; Self::DIM] = [
            self.word_count_norm,
            self.char_count_norm,
            self.link_density,
            self.stopword_ratio,
            self.p_count_norm,
            self.text_tag_ratio,
            self.depth_norm,
            self.comma_density,
            self.tag_article,
            self.tag_main,
            self.tag_section,
            self.tag_div,
            self.tag_other,
            self.class_positive,
            self.class_negative,
            self.unique_word_ratio,
        ];
        flat.to_vec()
    }

    /// Hand-tuned "is this the article body?" score in [0, 1], computed from
    /// the features alone. Serves as a warm-start prior and as a baseline for
    /// the supervised classifier's tests. Nothing here is learned: it is a
    /// fixed weighted sum of Readability-style signals, clamped to [0, 1].
    pub fn heuristic_content_score(&self) -> f32 {
        // (weight, signal) pairs, accumulated left to right. Link density
        // counts twice on purpose: as a reward for its absence and as a penalty.
        let terms: [(f32, f32); 8] = [
            (0.40, self.word_count_norm),
            (0.20, self.p_count_norm),
            (0.15, 1.0 - self.link_density),
            (0.10, self.class_positive),
            (0.10, self.tag_article),
            (0.05, self.tag_main),
            (-0.40, self.class_negative),
            (-0.20, self.link_density),
        ];
        let raw = terms
            .iter()
            .fold(0.0_f32, |acc, &(weight, signal)| acc + weight * signal);
        raw.clamp(0.0, 1.0)
    }
}
138
/// Lowercase substrings that, when found in a node's class or id, suggest the
/// node holds real content (Readability/arc90-style hints).
const POSITIVE_HINTS: &[&str] = &[
    "article",
    "content",
    "post",
    "story",
    "body",
    "entry",
    "main",
    "text",
    "blog",
    "page",
];
144
/// Lowercase substrings that, when found in a node's class or id, suggest the
/// node is boilerplate rather than content.
const NEGATIVE_HINTS: &[&str] = &[
    "comment",
    "sidebar",
    "footer",
    "header",
    "nav",
    "menu",
    "ad",
    "advert",
    "promo",
    "share",
    "social",
    "related",
    "widget",
    "banner",
    "popup",
    "cookie",
    "newsletter",
    "subscribe",
    "breadcrumb",
];
151
152fn p_selector() -> Selector {
153    Selector::parse("p").unwrap()
154}
155
156fn a_selector() -> Selector {
157    Selector::parse("a").unwrap()
158}
159
160/// Concatenate the text content of an element, collapsing whitespace.
161pub fn node_text(el: &ElementRef) -> String {
162    let raw: String = el.text().collect::<Vec<_>>().join(" ");
163    raw.split_whitespace().collect::<Vec<_>>().join(" ")
164}
165
/// One paragraph-level block of text (a single `<p>`), bundled with the
/// statistics used to decide whether to keep it.
#[derive(Debug, Clone)]
pub struct TextBlock {
    /// Whitespace-collapsed text of the block.
    pub text: String,
    /// Number of whitespace-separated words in `text`.
    pub words: usize,
    /// Link density of the block.
    pub link_density: f32,
}
173
174/// Self-contained, owned snapshot of a candidate node's extractable text.
175///
176/// It holds the per-paragraph blocks plus a whole-node fallback, so the
177/// environment can re-extract under *different* policy params on later steps
178/// without keeping borrowed `ElementRef`s across calls (which Rust's borrow
179/// checker forbids when the document is owned by the same struct).
180#[derive(Debug, Clone)]
181pub struct CandidateContent {
182    pub blocks: Vec<TextBlock>,
183    pub full_text: String,
184    pub full_link_density: f32,
185}
186
187impl CandidateContent {
188    /// Apply extraction params to produce the article text. Blocks are kept
189    /// only if they meet the minimum word count and stay below the link-density
190    /// threshold; if none survive we fall back to the whole-node text when it is
191    /// not link-dominated.
192    pub fn extract(&self, params: &ExtractionParams) -> String {
193        let kept: Vec<&str> = self
194            .blocks
195            .iter()
196            .filter(|b| b.words >= params.min_block_words && b.link_density <= params.max_block_link_density)
197            .map(|b| b.text.as_str())
198            .collect();
199
200        if kept.is_empty() {
201            if self.full_link_density <= params.max_block_link_density
202                && self.full_text.split_whitespace().count() >= params.min_block_words
203            {
204                return self.full_text.clone();
205            }
206            return String::new();
207        }
208
209        kept.join("\n\n")
210    }
211}
212
213/// Build the owned [`CandidateContent`] snapshot for a node.
214pub fn node_content(el: &ElementRef) -> CandidateContent {
215    let p_sel = p_selector();
216    let blocks: Vec<TextBlock> = el
217        .select(&p_sel)
218        .map(|p| {
219            let text = node_text(&p);
220            let words = text.split_whitespace().count();
221            TextBlock { words, link_density: link_density(&p), text }
222        })
223        .collect();
224
225    CandidateContent {
226        blocks,
227        full_text: node_text(el),
228        full_link_density: link_density(el),
229    }
230}
231
232/// Extract article text from a node, honoring the policy's extraction params.
233///
234/// Convenience wrapper over [`node_content`] + [`CandidateContent::extract`].
235pub fn extract_node_text(el: &ElementRef, params: &ExtractionParams) -> String {
236    node_content(el).extract(params)
237}
238
239/// Fraction of characters inside `<a>` descendants relative to total text.
240pub fn link_density(el: &ElementRef) -> f32 {
241    let total = node_text(el).chars().count();
242    if total == 0 {
243        return 0.0;
244    }
245    let a_sel = a_selector();
246    let link_chars: usize = el
247        .select(&a_sel)
248        .map(|a| node_text(&a).chars().count())
249        .sum();
250    (link_chars as f32 / total as f32).clamp(0.0, 1.0)
251}
252
253fn class_id_hint_scores(el: &ElementRef) -> (f32, f32) {
254    let mut haystack = String::new();
255    if let Some(c) = el.value().attr("class") {
256        haystack.push_str(&c.to_lowercase());
257        haystack.push(' ');
258    }
259    if let Some(id) = el.value().attr("id") {
260        haystack.push_str(&id.to_lowercase());
261    }
262    if haystack.is_empty() {
263        return (0.0, 0.0);
264    }
265    let pos = POSITIVE_HINTS.iter().filter(|h| haystack.contains(**h)).count();
266    let neg = NEGATIVE_HINTS.iter().filter(|h| haystack.contains(**h)).count();
267    // Squash counts into [0, 1] — presence matters more than exact count.
268    let pos_score = (pos as f32 / 2.0).clamp(0.0, 1.0);
269    let neg_score = (neg as f32 / 2.0).clamp(0.0, 1.0);
270    (pos_score, neg_score)
271}
272
273fn node_depth(el: &ElementRef) -> usize {
274    let mut depth = 0;
275    let mut current = Some(*el);
276    while let Some(e) = current {
277        depth += 1;
278        current = e.parent().and_then(ElementRef::wrap);
279    }
280    depth
281}
282
283/// Compute the full feature set for a candidate node.
284pub fn extract_features(el: &ElementRef, stopwords: &HashSet<String>) -> NodeFeatures {
285    let text = node_text(el);
286    let tokens: Vec<&str> = text.split_whitespace().collect();
287    let word_count = tokens.len();
288    let char_count = text.chars().count();
289
290    let word_count_norm = if word_count == 0 {
291        0.0
292    } else {
293        ((word_count as f32 + 1.0).ln() / (5000f32).ln()).clamp(0.0, 1.0)
294    };
295    let char_count_norm = ((char_count as f32 + 1.0).ln() / (40000f32).ln()).clamp(0.0, 1.0);
296
297    let link_density = link_density(el);
298
299    let stopword_ratio = if word_count == 0 {
300        0.0
301    } else {
302        let sw = tokens
303            .iter()
304            .filter(|t| stopwords.contains(&t.to_lowercase()))
305            .count();
306        (sw as f32 / word_count as f32).clamp(0.0, 1.0)
307    };
308
309    let p_count = el.select(&p_selector()).count();
310    let p_count_norm = (p_count as f32 / 30.0).clamp(0.0, 1.0);
311
312    let element_count = el
313        .descendants()
314        .filter(|n| n.value().is_element())
315        .count()
316        .max(1);
317    let text_tag_ratio = ((word_count as f32 / element_count as f32) / 50.0).clamp(0.0, 1.0);
318
319    let depth_norm = (node_depth(el) as f32 / 30.0).clamp(0.0, 1.0);
320
321    let comma_count = text.chars().filter(|c| *c == ',').count();
322    let comma_density = if word_count == 0 {
323        0.0
324    } else {
325        ((comma_count as f32 / word_count as f32) / 0.2).clamp(0.0, 1.0)
326    };
327
328    let tag = el.value().name().to_lowercase();
329    let (tag_article, tag_main, tag_section, tag_div, tag_other) = match tag.as_str() {
330        "article" => (1.0, 0.0, 0.0, 0.0, 0.0),
331        "main" => (0.0, 1.0, 0.0, 0.0, 0.0),
332        "section" => (0.0, 0.0, 1.0, 0.0, 0.0),
333        "div" => (0.0, 0.0, 0.0, 1.0, 0.0),
334        _ => (0.0, 0.0, 0.0, 0.0, 1.0),
335    };
336
337    let (class_positive, class_negative) = class_id_hint_scores(el);
338
339    let unique_word_ratio = if word_count == 0 {
340        0.0
341    } else {
342        let unique: HashSet<String> = tokens.iter().map(|t| t.to_lowercase()).collect();
343        (unique.len() as f32 / word_count as f32).clamp(0.0, 1.0)
344    };
345
346    NodeFeatures {
347        word_count_norm,
348        char_count_norm,
349        link_density,
350        stopword_ratio,
351        p_count_norm,
352        text_tag_ratio,
353        depth_norm,
354        comma_density,
355        tag_article,
356        tag_main,
357        tag_section,
358        tag_div,
359        tag_other,
360        class_positive,
361        class_negative,
362        unique_word_ratio,
363    }
364}
365
#[cfg(test)]
mod tests {
    use super::*;
    use scraper::Html;

    /// Small fixed stopword set shared by the tests.
    fn stopwords() -> HashSet<String> {
        const WORDS: [&str; 10] =
            ["the", "a", "is", "of", "and", "to", "in", "with", "this", "for"];
        WORDS.iter().map(|w| w.to_string()).collect()
    }

    /// First element in `doc` matching the CSS selector `sel`; panics if the
    /// selector is invalid or nothing matches.
    fn first_matching<'a>(doc: &'a Html, sel: &str) -> ElementRef<'a> {
        let selector = Selector::parse(sel).unwrap();
        let found = doc.select(&selector).next();
        found.unwrap()
    }

    #[test]
    fn to_vec_len_matches_dim() {
        let flat = NodeFeatures::zeros().to_vec();
        assert_eq!(flat.len(), NodeFeatures::DIM);
    }

    #[test]
    fn content_node_scores_higher_than_boilerplate() {
        let html = r#"
            <html><body>
                <article class="post-content">
                    <p>This is the real article body with a lot of meaningful text content.</p>
                    <p>It has multiple paragraphs describing important information in detail here.</p>
                    <p>Readers expect substantial prose and varied vocabulary throughout the piece.</p>
                </article>
                <div class="sidebar-ads">
                    <a href="/1">Link one</a> <a href="/2">Link two</a> <a href="/3">Link three</a>
                </div>
            </body></html>
        "#;
        let doc = Html::parse_document(html);
        let sw = stopwords();

        let art_feat = extract_features(&first_matching(&doc, "article"), &sw);
        let side_feat = extract_features(&first_matching(&doc, "div"), &sw);

        // The article is real prose; the sidebar is link-dominated boilerplate.
        assert!(art_feat.link_density < side_feat.link_density);
        assert!(art_feat.word_count_norm > side_feat.word_count_norm);
        assert!(art_feat.class_positive > 0.0);
        assert!(side_feat.class_negative > 0.0);

        let art_score = art_feat.heuristic_content_score();
        let side_score = side_feat.heuristic_content_score();
        assert!(
            art_score > side_score,
            "article {} should beat sidebar {}",
            art_score,
            side_score
        );
    }

    #[test]
    fn different_nodes_yield_different_features() {
        let html = r#"
            <html><body>
                <article><p>Alpha beta gamma delta epsilon zeta eta theta iota kappa.</p></article>
                <div class="nav"><a href="/x">x</a></div>
            </body></html>
        "#;
        let doc = Html::parse_document(html);
        let sw = stopwords();

        let article_vec = extract_features(&first_matching(&doc, "article"), &sw).to_vec();
        let div_vec = extract_features(&first_matching(&doc, "div"), &sw).to_vec();
        assert_ne!(article_vec, div_vec, "distinct nodes must have distinct features");
    }

    #[test]
    fn extraction_params_change_extracted_text() {
        let html = r#"
            <html><body>
                <article>
                    <p>Short one.</p>
                    <p>This paragraph is clearly long enough to survive a high minimum word threshold filter.</p>
                </article>
            </body></html>
        "#;
        let doc = Html::parse_document(html);
        let article = first_matching(&doc, "article");

        let lenient = ExtractionParams { min_block_words: 1, max_block_link_density: 0.9 };
        let strict = ExtractionParams { min_block_words: 6, max_block_link_density: 0.9 };

        let lenient_text = extract_node_text(&article, &lenient);
        let strict_text = extract_node_text(&article, &strict);

        // The strict filter must drop the short paragraph, changing the output.
        assert!(lenient_text.contains("Short one"));
        assert!(!strict_text.contains("Short one"));
        assert_ne!(lenient_text, strict_text);
    }

    #[test]
    fn normalized_params_map_into_range() {
        let lo = ExtractionParams::from_normalized(&[-1.0, -1.0]);
        let hi = ExtractionParams::from_normalized(&[1.0, 1.0]);

        assert!(lo.min_block_words >= 1);
        assert!(hi.min_block_words <= 40);
        assert!(lo.max_block_link_density >= 0.1);
        assert!(hi.max_block_link_density <= 0.9);
        assert!(hi.min_block_words > lo.min_block_words);
    }
}
473}