Skip to main content

content_extractor_rl/
node_classifier.rs

1// ============================================================================
2// FILE: crates/content-extractor-rl/src/node_classifier.rs
3// ============================================================================
4//! Supervised content-node classifier (the "hybrid" half of the system).
5//!
6//! Content extraction is, at heart, a supervised node-classification problem:
7//! given the candidate DOM nodes of a page and labelled ground-truth article
8//! text, learn which candidate is the article body. This is far more
9//! sample-efficient and stable than asking RL to discover node selection from a
10//! sparse reward. The division of labour is therefore:
11//!
12//!   * this classifier picks **which node** is the content root (supervised);
13//!   * the RL policy tunes the **continuous extraction params** within it.
14//!
15//! Labels come for free from the data: the candidate whose extracted text has
16//! the highest token-F1 against the ground-truth article is the positive
17//! example, the rest are negatives ([`label_from_f1`]).
18
19use candle_core::{DType, Device, Tensor};
20use candle_nn::{linear, AdamW, Linear, Module, Optimizer, ParamsAdamW, VarBuilder, VarMap};
21use candle_nn::ops::sigmoid;
22
23use crate::html_parser::HtmlParser;
24use crate::node_features::{self, CandidateContent, ExtractionParams, NodeFeatures};
25use crate::text_utils::TextUtils;
26use crate::training::TrainingSample;
27use crate::{Config, ExtractionError, Result};
28use std::collections::HashSet;
29use std::path::Path;
30
/// A small MLP that maps [`NodeFeatures`] to a content-probability.
///
/// The network is `NodeFeatures::DIM -> 32 -> 16 -> 1` with ReLU activations
/// after the two hidden layers; the single output is a logit that is turned
/// into a probability with a sigmoid.
pub struct NodeClassifier {
    /// First hidden layer (`NodeFeatures::DIM` inputs, 32 outputs).
    fc1: Linear,
    /// Second hidden layer (32 inputs, 16 outputs).
    fc2: Linear,
    /// Output layer producing one logit per candidate.
    out: Linear,
    /// Holds the trainable variables; used for saving, loading and counting
    /// parameters.
    varmap: VarMap,
    /// AdamW optimizer built over every variable registered in `varmap`.
    optimizer: AdamW,
    /// Device on which input and target tensors are created.
    device: Device,
}
40
41impl NodeClassifier {
42    /// Create a new, randomly-initialized classifier.
43    pub fn new(device: &Device, lr: f64) -> Result<Self> {
44        let varmap = VarMap::new();
45        let vb = VarBuilder::from_varmap(&varmap, DType::F32, device);
46
47        let hidden1 = 32;
48        let hidden2 = 16;
49        let fc1 = linear(NodeFeatures::DIM, hidden1, vb.pp("fc1"))?;
50        let fc2 = linear(hidden1, hidden2, vb.pp("fc2"))?;
51        let out = linear(hidden2, 1, vb.pp("out"))?;
52
53        let params = ParamsAdamW { lr, beta1: 0.9, beta2: 0.999, eps: 1e-8, weight_decay: 1e-5 };
54        let optimizer = AdamW::new(varmap.all_vars(), params)
55            .map_err(|e| crate::ExtractionError::ModelError(e.to_string()))?;
56
57        Ok(Self { fc1, fc2, out, varmap, optimizer, device: device.clone() })
58    }
59
60    /// Forward pass returning raw logits of shape `[batch]`.
61    fn logits(&self, x: &Tensor) -> candle_core::Result<Tensor> {
62        let x = self.fc1.forward(x)?.relu()?;
63        let x = self.fc2.forward(&x)?.relu()?;
64        self.out.forward(&x)?.squeeze(1)
65    }
66
67    fn features_to_tensor(&self, features: &[NodeFeatures]) -> candle_core::Result<Tensor> {
68        let flat: Vec<f32> = features.iter().flat_map(|f| f.to_vec()).collect();
69        Tensor::from_vec(flat, &[features.len(), NodeFeatures::DIM], &self.device)
70    }
71
72    /// Content probability in [0, 1] for each candidate.
73    pub fn score_batch(&self, features: &[NodeFeatures]) -> Result<Vec<f32>> {
74        if features.is_empty() {
75            return Ok(Vec::new());
76        }
77        let x = self.features_to_tensor(features)
78            .map_err(|e| crate::ExtractionError::ModelError(e.to_string()))?;
79        let probs = sigmoid(&self.logits(&x).map_err(|e| crate::ExtractionError::ModelError(e.to_string()))?)
80            .map_err(|e| crate::ExtractionError::ModelError(e.to_string()))?;
81        probs.to_vec1::<f32>()
82            .map_err(|e| crate::ExtractionError::ModelError(e.to_string()))
83    }
84
85    /// Index of the candidate most likely to be the content root.
86    pub fn select_best(&self, features: &[NodeFeatures]) -> Result<Option<usize>> {
87        let scores = self.score_batch(features)?;
88        Ok(scores
89            .iter()
90            .enumerate()
91            .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
92            .map(|(i, _)| i))
93    }
94
95    /// One supervised gradient step over a batch of `(features, label)` pairs,
96    /// where `label` is 1.0 for the content node and 0.0 otherwise. Returns the
97    /// binary cross-entropy loss.
98    pub fn train_batch(&mut self, features: &[NodeFeatures], labels: &[f32]) -> Result<f32> {
99        assert_eq!(features.len(), labels.len());
100        if features.is_empty() {
101            return Ok(0.0);
102        }
103
104        let x = self.features_to_tensor(features)
105            .map_err(|e| crate::ExtractionError::ModelError(e.to_string()))?;
106        let targets = Tensor::from_vec(labels.to_vec(), &[labels.len()], &self.device)
107            .map_err(|e| crate::ExtractionError::ModelError(e.to_string()))?;
108
109        let loss = self
110            .bce_loss(&x, &targets)
111            .map_err(|e| crate::ExtractionError::ModelError(e.to_string()))?;
112        let loss_val = loss.to_scalar::<f32>()
113            .map_err(|e| crate::ExtractionError::ModelError(e.to_string()))?;
114        if loss_val.is_nan() || loss_val.is_infinite() {
115            return Ok(f32::NAN);
116        }
117
118        let grads = loss.backward()
119            .map_err(|e| crate::ExtractionError::ModelError(e.to_string()))?;
120        self.optimizer.step(&grads)
121            .map_err(|e| crate::ExtractionError::ModelError(e.to_string()))?;
122        Ok(loss_val)
123    }
124
125    fn bce_loss(&self, x: &Tensor, targets: &Tensor) -> candle_core::Result<Tensor> {
126        let p = sigmoid(&self.logits(x)?)?;
127        // Clamp to avoid log(0).
128        let p = p.clamp(1e-7f32, 1.0f32 - 1e-7f32)?;
129        let log_p = p.log()?;
130        let log_1mp = p.affine(-1.0, 1.0)?.log()?; // log(1 - p)
131        let pos = targets.mul(&log_p)?;
132        let neg = targets.affine(-1.0, 1.0)?.mul(&log_1mp)?; // (1 - y) * log(1 - p)
133        (pos + neg)?.neg()?.mean_all()
134    }
135
136    /// Number of trainable parameters (for logging / metadata).
137    pub fn num_parameters(&self) -> usize {
138        self.varmap
139            .all_vars()
140            .iter()
141            .map(|v| v.as_tensor().elem_count())
142            .sum()
143    }
144
145    /// Serialize the classifier weights to a `.safetensors` file.
146    pub fn save(&self, path: &Path) -> Result<()> {
147        self.varmap
148            .save(path)
149            .map_err(|e| ExtractionError::ModelError(format!("classifier save failed: {e}")))
150    }
151
152    /// Load a classifier previously written with [`Self::save`].
153    ///
154    /// The architecture is fixed (so `new` recreates the same variables) and the
155    /// stored tensors are copied into them; `lr` only affects subsequent
156    /// training and is irrelevant for inference.
157    pub fn load(path: &Path, device: &Device, lr: f64) -> Result<Self> {
158        let mut clf = Self::new(device, lr)?;
159        clf.varmap
160            .load(path)
161            .map_err(|e| ExtractionError::ModelError(format!("classifier load failed: {e}")))?;
162        Ok(clf)
163    }
164}
165
166/// Build a pointwise training set for the classifier from labelled samples.
167///
168/// For every sample that has ground-truth text, each candidate node becomes one
169/// `(features, label)` example, where `label = 1.0` for the best-F1 candidate
170/// (via [`label_from_f1`]) and `0.0` otherwise. Samples without ground truth, or
171/// where no candidate matches, are skipped.
172pub fn build_classifier_dataset(
173    samples: &[TrainingSample],
174    num_candidates: usize,
175    stopwords: &HashSet<String>,
176) -> (Vec<NodeFeatures>, Vec<f32>) {
177    let mut features = Vec::new();
178    let mut labels = Vec::new();
179
180    for sample in samples {
181        let Some(gt) = sample.ground_truth_text.as_deref() else { continue };
182        let Ok(document) = HtmlParser::clean_html(&sample.html) else { continue };
183        let candidates = HtmlParser::get_candidate_nodes(&document, num_candidates);
184        if candidates.is_empty() {
185            continue;
186        }
187        let contents: Vec<CandidateContent> =
188            candidates.iter().map(node_features::node_content).collect();
189        let Some(sample_labels) = label_from_f1(&contents, gt, stopwords) else { continue };
190
191        for (candidate, label) in candidates.iter().zip(sample_labels) {
192            features.push(node_features::extract_features(candidate, stopwords));
193            labels.push(label);
194        }
195    }
196
197    (features, labels)
198}
199
200/// Train a [`NodeClassifier`] on labelled samples for `epochs` full-batch steps.
201///
202/// Returns the trained classifier and the final BCE loss. Errors if no sample
203/// carried usable ground truth.
204pub fn train_classifier(
205    samples: &[TrainingSample],
206    config: &Config,
207    epochs: usize,
208    lr: f64,
209    device: &Device,
210) -> Result<(NodeClassifier, f32)> {
211    let (features, labels) =
212        build_classifier_dataset(samples, config.num_candidate_nodes, &config.stopwords);
213
214    if features.is_empty() {
215        return Err(ExtractionError::ModelError(
216            "no labelled training examples for the classifier (samples need ground-truth text)"
217                .to_string(),
218        ));
219    }
220
221    let mut classifier = NodeClassifier::new(device, lr)?;
222    let mut last_loss = f32::NAN;
223    for _ in 0..epochs.max(1) {
224        last_loss = classifier.train_batch(&features, &labels)?;
225        if last_loss.is_nan() {
226            return Err(ExtractionError::ModelError(
227                "classifier training diverged (NaN loss)".to_string(),
228            ));
229        }
230    }
231
232    Ok((classifier, last_loss))
233}
234
/// Result of a hybrid extraction, as returned by [`HybridExtractor::extract`].
#[derive(Debug, Clone)]
pub struct HybridExtraction {
    /// The extracted article text, produced from the selected node with the
    /// caller-supplied [`ExtractionParams`].
    pub text: String,
    /// XPath of the selected content node (from `HtmlParser::get_element_path`).
    pub xpath: String,
    /// Index of the selected candidate within the candidate list returned by
    /// `HtmlParser::get_candidate_nodes`.
    pub candidate_index: usize,
    /// Content score of the selected candidate: the classifier probability
    /// when a classifier is configured, otherwise the heuristic content score.
    pub score: f32,
}
247
/// End-to-end hybrid extractor: the classifier (supervised) picks the content
/// node, then the RL-tuned [`ExtractionParams`] drive block-level extraction
/// within it. When no trained classifier is supplied it falls back to the
/// Readability-style [`NodeFeatures::heuristic_content_score`], so it is useful
/// even before any training has happened.
pub struct HybridExtractor {
    /// Trained node classifier; `None` selects the heuristic scoring path.
    classifier: Option<NodeClassifier>,
    /// Stopword set passed to feature extraction for every candidate node.
    stopwords: HashSet<String>,
}
257
258impl HybridExtractor {
259    /// Heuristic-only extractor (no learned model).
260    pub fn heuristic(stopwords: HashSet<String>) -> Self {
261        Self { classifier: None, stopwords }
262    }
263
264    /// Extractor backed by a trained node classifier.
265    pub fn with_classifier(classifier: NodeClassifier, stopwords: HashSet<String>) -> Self {
266        Self { classifier: Some(classifier), stopwords }
267    }
268
269    /// Extract article content from a page. Returns `None` when the document
270    /// exposes no candidate nodes.
271    pub fn extract(
272        &self,
273        html: &str,
274        num_candidates: usize,
275        params: &ExtractionParams,
276    ) -> Result<Option<HybridExtraction>> {
277        let document = HtmlParser::clean_html(html)?;
278        let candidates = HtmlParser::get_candidate_nodes(&document, num_candidates);
279        if candidates.is_empty() {
280            return Ok(None);
281        }
282
283        let features: Vec<NodeFeatures> = candidates
284            .iter()
285            .map(|e| node_features::extract_features(e, &self.stopwords))
286            .collect();
287
288        let scores: Vec<f32> = match &self.classifier {
289            Some(c) => c.score_batch(&features)?,
290            None => features.iter().map(|f| f.heuristic_content_score()).collect(),
291        };
292
293        let best = scores
294            .iter()
295            .enumerate()
296            .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
297            .map(|(i, _)| i)
298            .unwrap_or(0);
299
300        let content = node_features::node_content(&candidates[best]);
301        Ok(Some(HybridExtraction {
302            text: content.extract(params),
303            xpath: HtmlParser::get_element_path(candidates[best]),
304            candidate_index: best,
305            score: scores[best],
306        }))
307    }
308}
309
310/// Derive the supervised label vector for one page: 1.0 for the candidate whose
311/// extracted text best matches the ground truth (token F1), 0.0 for the rest.
312///
313/// Returns `None` when there is no ground truth or no candidate scores above
314/// zero (nothing to learn from for this page).
315pub fn label_from_f1(
316    contents: &[CandidateContent],
317    ground_truth: &str,
318    stopwords: &HashSet<String>,
319) -> Option<Vec<f32>> {
320    if contents.is_empty() || ground_truth.trim().is_empty() {
321        return None;
322    }
323
324    let params = ExtractionParams::default();
325    let f1s: Vec<f32> = contents
326        .iter()
327        .map(|c| TextUtils::token_f1(&c.extract(&params), ground_truth, stopwords))
328        .collect();
329
330    let (best_idx, best_f1) = f1s
331        .iter()
332        .enumerate()
333        .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
334        .map(|(i, v)| (i, *v))?;
335
336    if best_f1 <= 0.0 {
337        return None;
338    }
339
340    Some((0..contents.len()).map(|i| if i == best_idx { 1.0 } else { 0.0 }).collect())
341}
342
#[cfg(test)]
mod tests {
    use super::*;

    /// Collect string literals into an owned stopword set.
    fn stopword_set(words: &[&str]) -> HashSet<String> {
        words.iter().map(|word| word.to_string()).collect()
    }

    /// Feature vector shaped like an article body: long, paragraph-rich,
    /// few links, positive tag/class hints.
    fn content_features(seed: f32) -> NodeFeatures {
        let mut features = NodeFeatures::zeros();
        features.tag_article = 1.0;
        features.class_positive = 1.0;
        features.word_count_norm = 0.8 + 0.1 * seed;
        features.char_count_norm = 0.8;
        features.p_count_norm = 0.7;
        features.link_density = 0.05;
        features.stopword_ratio = 0.45;
        features.unique_word_ratio = 0.6;
        features
    }

    /// Feature vector shaped like navigation/footer boilerplate: short,
    /// link-heavy, negative class hint.
    fn boilerplate_features(seed: f32) -> NodeFeatures {
        let mut features = NodeFeatures::zeros();
        features.tag_div = 1.0;
        features.class_negative = 1.0;
        features.word_count_norm = 0.1 + 0.05 * seed;
        features.char_count_norm = 0.1;
        features.p_count_norm = 0.0;
        features.link_density = 0.9;
        features.stopword_ratio = 0.1;
        features.unique_word_ratio = 0.95;
        features
    }

    #[test]
    fn classifier_learns_to_rank_content_above_boilerplate() {
        const PAGES: usize = 40;
        const NEGATIVES_PER_PAGE: usize = 3;

        let device = Device::Cpu;
        let mut clf = NodeClassifier::new(&device, 1e-2).unwrap();

        // Synthetic dataset: each page is one content node followed by three
        // boilerplate nodes.
        let mut features = Vec::new();
        let mut labels: Vec<f32> = Vec::new();
        for page in 0..PAGES {
            let seed = (page % 5) as f32 / 5.0;
            features.push(content_features(seed));
            labels.push(1.0);
            for _ in 0..NEGATIVES_PER_PAGE {
                features.push(boilerplate_features(seed));
                labels.push(0.0);
            }
        }

        let first_loss = clf.train_batch(&features, &labels).unwrap();
        let mut last_loss = first_loss;
        for _ in 0..200 {
            last_loss = clf.train_batch(&features, &labels).unwrap();
            assert!(!last_loss.is_nan(), "loss went NaN");
        }
        assert!(last_loss < first_loss, "loss should decrease: {first_loss} -> {last_loss}");

        // On an unseen page the content node (index 1) must outscore both
        // boilerplate nodes and be the one `select_best` returns.
        let page = vec![
            boilerplate_features(0.2),
            content_features(0.3),
            boilerplate_features(0.4),
        ];
        let scores = clf.score_batch(&page).unwrap();
        let (before, article, after) = (scores[0], scores[1], scores[2]);
        assert!(article > before && article > after, "scores: {scores:?}");
        assert_eq!(clf.select_best(&page).unwrap(), Some(1));
    }

    #[test]
    fn hybrid_extractor_heuristic_picks_article() {
        let stopwords = stopword_set(&["the", "a", "is", "to", "of", "and", "for", "in"]);

        let html = r#"
            <html><body>
                <nav class="site-nav"><a href="/a">Home</a> <a href="/b">Sports</a> <a href="/c">World</a></nav>
                <article class="article-body">
                    <p>The mission successfully entered orbit after a seven month journey through deep space.</p>
                    <p>Engineers celebrated as telemetry confirmed every subsystem performed within tolerances.</p>
                    <p>The spacecraft will now begin its primary science campaign mapping the surface below.</p>
                </article>
                <div class="footer-links"><a href="/x">Privacy</a> <a href="/y">Terms</a> <a href="/z">Contact</a></div>
            </body></html>
        "#;

        let params = ExtractionParams::default();
        let result = HybridExtractor::heuristic(stopwords)
            .extract(html, 10, &params)
            .unwrap()
            .expect("should extract something");

        // The heuristic must land on the article body rather than nav/footer.
        assert!(result.text.contains("entered orbit"), "got: {}", result.text);
        assert!(!result.text.contains("Privacy"));
        assert!(result.score > 0.0);
    }

    #[test]
    fn save_load_round_trip_preserves_scores() {
        let device = Device::Cpu;
        let mut clf = NodeClassifier::new(&device, 1e-2).unwrap();

        // A few training steps so the saved weights differ from a fresh init.
        let train_features = vec![content_features(0.1), boilerplate_features(0.2)];
        let train_labels = vec![1.0, 0.0];
        for _ in 0..50 {
            clf.train_batch(&train_features, &train_labels).unwrap();
        }

        let page = vec![boilerplate_features(0.3), content_features(0.4)];
        let before = clf.score_batch(&page).unwrap();

        let dir = tempfile::TempDir::new().unwrap();
        let path = dir.path().join("clf.safetensors");
        clf.save(&path).unwrap();

        let reloaded = NodeClassifier::load(&path, &device, 1e-2).unwrap();
        let after = reloaded.score_batch(&page).unwrap();

        for (b, a) in before.iter().zip(after.iter()) {
            assert!((b - a).abs() < 1e-5, "score changed after reload: {b} vs {a}");
        }
    }

    #[test]
    fn train_classifier_fits_labelled_pages() {
        use crate::Config;

        let page = r#"
            <html><body>
                <nav class="site-nav"><a href="/a">Home</a> <a href="/b">News</a> <a href="/c">Sport</a></nav>
                <article class="article-body">
                    <p>The committee approved the new budget after a lengthy debate on public spending.</p>
                    <p>Officials said the additional funds would be directed toward infrastructure projects.</p>
                    <p>Opposition members requested further review of the long term fiscal projections.</p>
                </article>
                <div class="footer-links"><a href="/p">Privacy</a> <a href="/t">Terms</a> <a href="/s">Subscribe</a></div>
            </body></html>
        "#;
        let gt = "The committee approved the new budget after a lengthy debate on public spending. \
                  Officials said the additional funds would be directed toward infrastructure projects. \
                  Opposition members requested further review of the long term fiscal projections.";

        let config = Config::default();
        let device = Device::Cpu;

        let sample = TrainingSample::with_ground_truth(
            page.to_string(),
            "https://example.com/budget".to_string(),
            gt.to_string(),
        );
        let samples = vec![sample];

        let (clf, loss) = train_classifier(&samples, &config, 150, 1e-2, &device).unwrap();
        assert!(loss.is_finite() && loss < 0.5, "classifier did not fit (loss {loss})");

        // Recompute the labels for the same page; the trained classifier must
        // select the candidate those labels mark as content.
        let document = HtmlParser::clean_html(page).unwrap();
        let candidates =
            HtmlParser::get_candidate_nodes(&document, config.num_candidate_nodes);

        let contents: Vec<_> =
            candidates.iter().map(node_features::node_content).collect();
        let labels = label_from_f1(&contents, gt, &config.stopwords).unwrap();
        let expected = labels.iter().position(|&label| label == 1.0).unwrap();

        let features: Vec<_> = candidates
            .iter()
            .map(|node| node_features::extract_features(node, &config.stopwords))
            .collect();
        assert_eq!(clf.select_best(&features).unwrap(), Some(expected));
    }

    #[test]
    fn train_classifier_errors_without_ground_truth() {
        use crate::Config;

        let config = Config::default();
        let unlabelled = TrainingSample::from((
            "<html><body><article><p>no ground truth here</p></article></body></html>".to_string(),
            "https://example.com/x".to_string(),
        ));
        let samples = vec![unlabelled];

        let outcome = train_classifier(&samples, &config, 10, 1e-2, &Device::Cpu);
        assert!(outcome.is_err());
    }

    #[test]
    fn label_from_f1_picks_best_matching_candidate() {
        use crate::node_features::node_content;
        use scraper::{Html, Selector};

        let html = r#"
            <html><body>
                <div class="nav"><a href="/a">Home</a> <a href="/b">News</a></div>
                <article>
                    <p>The central bank raised interest rates today citing persistent inflation pressures.</p>
                    <p>Economists expect further tightening over the coming quarters as growth slows.</p>
                </article>
            </body></html>
        "#;
        let doc = Html::parse_document(html);
        let gt = "The central bank raised interest rates today citing persistent inflation \
                  pressures. Economists expect further tightening over the coming quarters as \
                  growth slows.";

        let stopwords = stopword_set(&["the", "a", "as", "today", "over"]);

        let div_selector = Selector::parse("div").unwrap();
        let article_selector = Selector::parse("article").unwrap();
        let nav = doc.select(&div_selector).next().unwrap();
        let article = doc.select(&article_selector).next().unwrap();
        let contents = vec![node_content(&nav), node_content(&article)];

        // The article sits at index 1 and must be the only positive label.
        let labels = label_from_f1(&contents, gt, &stopwords).unwrap();
        assert_eq!(labels, vec![0.0, 1.0]);

        // Blank ground truth yields no labels at all.
        assert!(label_from_f1(&contents, "", &stopwords).is_none());
    }
}