use std::cmp::Ordering;
use std::collections::HashSet;
use std::path::Path;

use candle_core::{DType, Device, Tensor};
use candle_nn::ops::sigmoid;
use candle_nn::{linear, AdamW, Linear, Module, Optimizer, ParamsAdamW, VarBuilder, VarMap};

use crate::html_parser::HtmlParser;
use crate::node_features::{self, CandidateContent, ExtractionParams, NodeFeatures};
use crate::text_utils::TextUtils;
use crate::training::TrainingSample;
use crate::{Config, ExtractionError, Result};
30
31pub struct NodeClassifier {
33 fc1: Linear,
34 fc2: Linear,
35 out: Linear,
36 varmap: VarMap,
37 optimizer: AdamW,
38 device: Device,
39}
40
41impl NodeClassifier {
42 pub fn new(device: &Device, lr: f64) -> Result<Self> {
44 let varmap = VarMap::new();
45 let vb = VarBuilder::from_varmap(&varmap, DType::F32, device);
46
47 let hidden1 = 32;
48 let hidden2 = 16;
49 let fc1 = linear(NodeFeatures::DIM, hidden1, vb.pp("fc1"))?;
50 let fc2 = linear(hidden1, hidden2, vb.pp("fc2"))?;
51 let out = linear(hidden2, 1, vb.pp("out"))?;
52
53 let params = ParamsAdamW { lr, beta1: 0.9, beta2: 0.999, eps: 1e-8, weight_decay: 1e-5 };
54 let optimizer = AdamW::new(varmap.all_vars(), params)
55 .map_err(|e| crate::ExtractionError::ModelError(e.to_string()))?;
56
57 Ok(Self { fc1, fc2, out, varmap, optimizer, device: device.clone() })
58 }
59
60 fn logits(&self, x: &Tensor) -> candle_core::Result<Tensor> {
62 let x = self.fc1.forward(x)?.relu()?;
63 let x = self.fc2.forward(&x)?.relu()?;
64 self.out.forward(&x)?.squeeze(1)
65 }
66
67 fn features_to_tensor(&self, features: &[NodeFeatures]) -> candle_core::Result<Tensor> {
68 let flat: Vec<f32> = features.iter().flat_map(|f| f.to_vec()).collect();
69 Tensor::from_vec(flat, &[features.len(), NodeFeatures::DIM], &self.device)
70 }
71
72 pub fn score_batch(&self, features: &[NodeFeatures]) -> Result<Vec<f32>> {
74 if features.is_empty() {
75 return Ok(Vec::new());
76 }
77 let x = self.features_to_tensor(features)
78 .map_err(|e| crate::ExtractionError::ModelError(e.to_string()))?;
79 let probs = sigmoid(&self.logits(&x).map_err(|e| crate::ExtractionError::ModelError(e.to_string()))?)
80 .map_err(|e| crate::ExtractionError::ModelError(e.to_string()))?;
81 probs.to_vec1::<f32>()
82 .map_err(|e| crate::ExtractionError::ModelError(e.to_string()))
83 }
84
85 pub fn select_best(&self, features: &[NodeFeatures]) -> Result<Option<usize>> {
87 let scores = self.score_batch(features)?;
88 Ok(scores
89 .iter()
90 .enumerate()
91 .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
92 .map(|(i, _)| i))
93 }
94
95 pub fn train_batch(&mut self, features: &[NodeFeatures], labels: &[f32]) -> Result<f32> {
99 assert_eq!(features.len(), labels.len());
100 if features.is_empty() {
101 return Ok(0.0);
102 }
103
104 let x = self.features_to_tensor(features)
105 .map_err(|e| crate::ExtractionError::ModelError(e.to_string()))?;
106 let targets = Tensor::from_vec(labels.to_vec(), &[labels.len()], &self.device)
107 .map_err(|e| crate::ExtractionError::ModelError(e.to_string()))?;
108
109 let loss = self
110 .bce_loss(&x, &targets)
111 .map_err(|e| crate::ExtractionError::ModelError(e.to_string()))?;
112 let loss_val = loss.to_scalar::<f32>()
113 .map_err(|e| crate::ExtractionError::ModelError(e.to_string()))?;
114 if loss_val.is_nan() || loss_val.is_infinite() {
115 return Ok(f32::NAN);
116 }
117
118 let grads = loss.backward()
119 .map_err(|e| crate::ExtractionError::ModelError(e.to_string()))?;
120 self.optimizer.step(&grads)
121 .map_err(|e| crate::ExtractionError::ModelError(e.to_string()))?;
122 Ok(loss_val)
123 }
124
125 fn bce_loss(&self, x: &Tensor, targets: &Tensor) -> candle_core::Result<Tensor> {
126 let p = sigmoid(&self.logits(x)?)?;
127 let p = p.clamp(1e-7f32, 1.0f32 - 1e-7f32)?;
129 let log_p = p.log()?;
130 let log_1mp = p.affine(-1.0, 1.0)?.log()?; let pos = targets.mul(&log_p)?;
132 let neg = targets.affine(-1.0, 1.0)?.mul(&log_1mp)?; (pos + neg)?.neg()?.mean_all()
134 }
135
136 pub fn num_parameters(&self) -> usize {
138 self.varmap
139 .all_vars()
140 .iter()
141 .map(|v| v.as_tensor().elem_count())
142 .sum()
143 }
144
145 pub fn save(&self, path: &Path) -> Result<()> {
147 self.varmap
148 .save(path)
149 .map_err(|e| ExtractionError::ModelError(format!("classifier save failed: {e}")))
150 }
151
152 pub fn load(path: &Path, device: &Device, lr: f64) -> Result<Self> {
158 let mut clf = Self::new(device, lr)?;
159 clf.varmap
160 .load(path)
161 .map_err(|e| ExtractionError::ModelError(format!("classifier load failed: {e}")))?;
162 Ok(clf)
163 }
164}
165
166pub fn build_classifier_dataset(
173 samples: &[TrainingSample],
174 num_candidates: usize,
175 stopwords: &HashSet<String>,
176) -> (Vec<NodeFeatures>, Vec<f32>) {
177 let mut features = Vec::new();
178 let mut labels = Vec::new();
179
180 for sample in samples {
181 let Some(gt) = sample.ground_truth_text.as_deref() else { continue };
182 let Ok(document) = HtmlParser::clean_html(&sample.html) else { continue };
183 let candidates = HtmlParser::get_candidate_nodes(&document, num_candidates);
184 if candidates.is_empty() {
185 continue;
186 }
187 let contents: Vec<CandidateContent> =
188 candidates.iter().map(node_features::node_content).collect();
189 let Some(sample_labels) = label_from_f1(&contents, gt, stopwords) else { continue };
190
191 for (candidate, label) in candidates.iter().zip(sample_labels) {
192 features.push(node_features::extract_features(candidate, stopwords));
193 labels.push(label);
194 }
195 }
196
197 (features, labels)
198}
199
200pub fn train_classifier(
205 samples: &[TrainingSample],
206 config: &Config,
207 epochs: usize,
208 lr: f64,
209 device: &Device,
210) -> Result<(NodeClassifier, f32)> {
211 let (features, labels) =
212 build_classifier_dataset(samples, config.num_candidate_nodes, &config.stopwords);
213
214 if features.is_empty() {
215 return Err(ExtractionError::ModelError(
216 "no labelled training examples for the classifier (samples need ground-truth text)"
217 .to_string(),
218 ));
219 }
220
221 let mut classifier = NodeClassifier::new(device, lr)?;
222 let mut last_loss = f32::NAN;
223 for _ in 0..epochs.max(1) {
224 last_loss = classifier.train_batch(&features, &labels)?;
225 if last_loss.is_nan() {
226 return Err(ExtractionError::ModelError(
227 "classifier training diverged (NaN loss)".to_string(),
228 ));
229 }
230 }
231
232 Ok((classifier, last_loss))
233}
234
/// Outcome of [`HybridExtractor::extract`] for the winning candidate node.
#[derive(Debug, Clone)]
pub struct HybridExtraction {
    /// Text extracted from the winning candidate's content.
    pub text: String,
    /// Path of the winning element as returned by `HtmlParser::get_element_path`.
    pub xpath: String,
    /// Position of the winner within the candidate list.
    pub candidate_index: usize,
    /// Score of the winner: the classifier output when a classifier is
    /// configured, otherwise the heuristic content score.
    pub score: f32,
}
247
248pub struct HybridExtractor {
254 classifier: Option<NodeClassifier>,
255 stopwords: HashSet<String>,
256}
257
258impl HybridExtractor {
259 pub fn heuristic(stopwords: HashSet<String>) -> Self {
261 Self { classifier: None, stopwords }
262 }
263
264 pub fn with_classifier(classifier: NodeClassifier, stopwords: HashSet<String>) -> Self {
266 Self { classifier: Some(classifier), stopwords }
267 }
268
269 pub fn extract(
272 &self,
273 html: &str,
274 num_candidates: usize,
275 params: &ExtractionParams,
276 ) -> Result<Option<HybridExtraction>> {
277 let document = HtmlParser::clean_html(html)?;
278 let candidates = HtmlParser::get_candidate_nodes(&document, num_candidates);
279 if candidates.is_empty() {
280 return Ok(None);
281 }
282
283 let features: Vec<NodeFeatures> = candidates
284 .iter()
285 .map(|e| node_features::extract_features(e, &self.stopwords))
286 .collect();
287
288 let scores: Vec<f32> = match &self.classifier {
289 Some(c) => c.score_batch(&features)?,
290 None => features.iter().map(|f| f.heuristic_content_score()).collect(),
291 };
292
293 let best = scores
294 .iter()
295 .enumerate()
296 .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
297 .map(|(i, _)| i)
298 .unwrap_or(0);
299
300 let content = node_features::node_content(&candidates[best]);
301 Ok(Some(HybridExtraction {
302 text: content.extract(params),
303 xpath: HtmlParser::get_element_path(candidates[best]),
304 candidate_index: best,
305 score: scores[best],
306 }))
307 }
308}
309
310pub fn label_from_f1(
316 contents: &[CandidateContent],
317 ground_truth: &str,
318 stopwords: &HashSet<String>,
319) -> Option<Vec<f32>> {
320 if contents.is_empty() || ground_truth.trim().is_empty() {
321 return None;
322 }
323
324 let params = ExtractionParams::default();
325 let f1s: Vec<f32> = contents
326 .iter()
327 .map(|c| TextUtils::token_f1(&c.extract(¶ms), ground_truth, stopwords))
328 .collect();
329
330 let (best_idx, best_f1) = f1s
331 .iter()
332 .enumerate()
333 .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
334 .map(|(i, v)| (i, *v))?;
335
336 if best_f1 <= 0.0 {
337 return None;
338 }
339
340 Some((0..contents.len()).map(|i| if i == best_idx { 1.0 } else { 0.0 }).collect())
341}
342
#[cfg(test)]
mod tests {
    use super::*;

    /// Synthetic feature row resembling main article content; `seed` perturbs
    /// the normalized word count.
    fn content_features(seed: f32) -> NodeFeatures {
        let mut row = NodeFeatures::zeros();
        row.word_count_norm = 0.8 + 0.1 * seed;
        row.char_count_norm = 0.8;
        row.p_count_norm = 0.7;
        row.link_density = 0.05;
        row.stopword_ratio = 0.45;
        row.tag_article = 1.0;
        row.class_positive = 1.0;
        row.unique_word_ratio = 0.6;
        row
    }

    /// Synthetic feature row resembling link-heavy boilerplate; `seed`
    /// perturbs the normalized word count.
    fn boilerplate_features(seed: f32) -> NodeFeatures {
        let mut row = NodeFeatures::zeros();
        row.word_count_norm = 0.1 + 0.05 * seed;
        row.char_count_norm = 0.1;
        row.p_count_norm = 0.0;
        row.link_density = 0.9;
        row.stopword_ratio = 0.1;
        row.tag_div = 1.0;
        row.class_negative = 1.0;
        row.unique_word_ratio = 0.95;
        row
    }

    /// After training on separable synthetic data the loss must drop and the
    /// content row must outrank both boilerplate rows.
    #[test]
    fn classifier_learns_to_rank_content_above_boilerplate() {
        let device = Device::Cpu;
        let mut clf = NodeClassifier::new(&device, 1e-2).unwrap();

        // One positive followed by three negatives per round, 40 rounds.
        let mut features = Vec::new();
        let mut labels: Vec<f32> = Vec::new();
        for round in 0..40 {
            let seed = (round % 5) as f32 / 5.0;
            features.push(content_features(seed));
            features.extend((0..3).map(|_| boilerplate_features(seed)));
            labels.extend([1.0, 0.0, 0.0, 0.0]);
        }

        let first_loss = clf.train_batch(&features, &labels).unwrap();
        let mut last_loss = first_loss;
        for _ in 0..200 {
            last_loss = clf.train_batch(&features, &labels).unwrap();
            assert!(!last_loss.is_nan(), "loss went NaN");
        }
        assert!(last_loss < first_loss, "loss should decrease: {first_loss} -> {last_loss}");

        let page = vec![
            boilerplate_features(0.2),
            content_features(0.3),
            boilerplate_features(0.4),
        ];
        let scores = clf.score_batch(&page).unwrap();
        assert!(scores[1] > scores[0] && scores[1] > scores[2], "scores: {scores:?}");
        assert_eq!(clf.select_best(&page).unwrap(), Some(1));
    }

    /// Without a classifier the extractor must still pick the article body
    /// over the navigation and footer link blocks.
    #[test]
    fn hybrid_extractor_heuristic_picks_article() {
        let stopwords: HashSet<String> = ["the", "a", "is", "to", "of", "and", "for", "in"]
            .iter()
            .map(|word| word.to_string())
            .collect();

        let html = r#"
            <html><body>
            <nav class="site-nav"><a href="/a">Home</a> <a href="/b">Sports</a> <a href="/c">World</a></nav>
            <article class="article-body">
            <p>The mission successfully entered orbit after a seven month journey through deep space.</p>
            <p>Engineers celebrated as telemetry confirmed every subsystem performed within tolerances.</p>
            <p>The spacecraft will now begin its primary science campaign mapping the surface below.</p>
            </article>
            <div class="footer-links"><a href="/x">Privacy</a> <a href="/y">Terms</a> <a href="/z">Contact</a></div>
            </body></html>
        "#;

        let extractor = HybridExtractor::heuristic(stopwords);
        let outcome = extractor
            .extract(html, 10, &ExtractionParams::default())
            .unwrap()
            .expect("should extract something");

        assert!(outcome.text.contains("entered orbit"), "got: {}", outcome.text);
        assert!(!outcome.text.contains("Privacy"));
        assert!(outcome.score > 0.0);
    }

    /// Scores produced after `save` + `load` must match the scores produced
    /// by the original in-memory model.
    #[test]
    fn save_load_round_trip_preserves_scores() {
        let device = Device::Cpu;
        let mut clf = NodeClassifier::new(&device, 1e-2).unwrap();

        let feats = vec![content_features(0.1), boilerplate_features(0.2)];
        let labels = vec![1.0, 0.0];
        for _ in 0..50 {
            clf.train_batch(&feats, &labels).unwrap();
        }

        let page = vec![boilerplate_features(0.3), content_features(0.4)];
        let before = clf.score_batch(&page).unwrap();

        let dir = tempfile::TempDir::new().unwrap();
        let path = dir.path().join("clf.safetensors");
        clf.save(&path).unwrap();

        let reloaded = NodeClassifier::load(&path, &device, 1e-2).unwrap();
        let after = reloaded.score_batch(&page).unwrap();

        for (b, a) in before.iter().zip(after.iter()) {
            assert!((b - a).abs() < 1e-5, "score changed after reload: {b} vs {a}");
        }
    }

    /// End-to-end: training on one labelled page must fit it and select the
    /// same candidate that `label_from_f1` marks as positive.
    #[test]
    fn train_classifier_fits_labelled_pages() {
        use crate::Config;

        let page = r#"
            <html><body>
            <nav class="site-nav"><a href="/a">Home</a> <a href="/b">News</a> <a href="/c">Sport</a></nav>
            <article class="article-body">
            <p>The committee approved the new budget after a lengthy debate on public spending.</p>
            <p>Officials said the additional funds would be directed toward infrastructure projects.</p>
            <p>Opposition members requested further review of the long term fiscal projections.</p>
            </article>
            <div class="footer-links"><a href="/p">Privacy</a> <a href="/t">Terms</a> <a href="/s">Subscribe</a></div>
            </body></html>
        "#;
        let gt = "The committee approved the new budget after a lengthy debate on public spending. \
                  Officials said the additional funds would be directed toward infrastructure projects. \
                  Opposition members requested further review of the long term fiscal projections.";

        let config = Config::default();
        let device = Device::Cpu;

        let samples = vec![TrainingSample::with_ground_truth(
            page.to_string(),
            "https://example.com/budget".to_string(),
            gt.to_string(),
        )];

        let (clf, loss) = train_classifier(&samples, &config, 150, 1e-2, &device).unwrap();
        assert!(loss.is_finite() && loss < 0.5, "classifier did not fit (loss {loss})");

        // Recompute the expected winner the same way the dataset builder does.
        let document = HtmlParser::clean_html(page).unwrap();
        let candidates = HtmlParser::get_candidate_nodes(&document, config.num_candidate_nodes);
        let contents: Vec<_> = candidates.iter().map(node_features::node_content).collect();
        let labels = label_from_f1(&contents, gt, &config.stopwords).unwrap();
        let expected = labels.iter().position(|&l| l == 1.0).unwrap();

        let features: Vec<_> = candidates
            .iter()
            .map(|c| node_features::extract_features(c, &config.stopwords))
            .collect();
        assert_eq!(clf.select_best(&features).unwrap(), Some(expected));
    }

    /// Samples without ground-truth text yield no training rows, which must
    /// surface as an error.
    #[test]
    fn train_classifier_errors_without_ground_truth() {
        use crate::Config;

        let config = Config::default();
        let samples = vec![TrainingSample::from((
            "<html><body><article><p>no ground truth here</p></article></body></html>".to_string(),
            "https://example.com/x".to_string(),
        ))];
        let outcome = train_classifier(&samples, &config, 10, 1e-2, &Device::Cpu);
        assert!(outcome.is_err());
    }

    /// The article must receive the positive label, and a blank ground truth
    /// must produce no labels at all.
    #[test]
    fn label_from_f1_picks_best_matching_candidate() {
        use crate::node_features::node_content;
        use scraper::{Html, Selector};

        let html = r#"
            <html><body>
            <div class="nav"><a href="/a">Home</a> <a href="/b">News</a></div>
            <article>
            <p>The central bank raised interest rates today citing persistent inflation pressures.</p>
            <p>Economists expect further tightening over the coming quarters as growth slows.</p>
            </article>
            </body></html>
        "#;
        let doc = Html::parse_document(html);
        let gt = "The central bank raised interest rates today citing persistent inflation \
                  pressures. Economists expect further tightening over the coming quarters as \
                  growth slows.";

        let stopwords: HashSet<String> = ["the", "a", "as", "today", "over"]
            .iter()
            .map(|word| word.to_string())
            .collect();

        let nav = doc.select(&Selector::parse("div").unwrap()).next().unwrap();
        let article = doc.select(&Selector::parse("article").unwrap()).next().unwrap();
        let contents = vec![node_content(&nav), node_content(&article)];

        let labels = label_from_f1(&contents, gt, &stopwords).unwrap();
        assert_eq!(labels, vec![0.0, 1.0]);

        assert!(label_from_f1(&contents, "", &stopwords).is_none());
    }
}