1use crate::baseline_extractor::BaselineExtractor;
14use crate::html_parser::HtmlParser;
15use crate::node_features::{self, CandidateContent, ExtractionParams, NodeFeatures};
16use crate::text_utils::TextUtils;
17use crate::site_profile::SiteProfile;
18use crate::config::{
19 Config, ACTION_SELECT_PARENT, ACTION_SELECT_SIBLING_LEFT, ACTION_SELECT_SIBLING_RIGHT,
20 ACTION_EXPAND_REGION, ACTION_CONTRACT_REGION, ACTION_TERMINATE,
21};
22use crate::Result;
23use std::collections::HashMap;
24
/// One candidate content node captured from the cleaned document during `reset`.
///
/// Holds values derived from the DOM node (its path, features and content)
/// rather than a reference to the node itself.
struct CandidateNode {
    /// Element path from `HtmlParser::get_element_path`; compared by prefix in
    /// `select_parent` and reported as `StepInfo::xpath`.
    xpath: String,
    /// Features from `node_features::extract_features`; written into the state
    /// vector by `build_state`.
    features: NodeFeatures,
    /// Content from `node_features::node_content`; turned into text via
    /// `extract(&ExtractionParams)` when this candidate is the current selection.
    content: CandidateContent,
}
32
/// Episode-based environment for choosing which candidate node (and which
/// extraction parameters) to use when extracting an article's text.
///
/// Call `reset` with a page to start an episode, then `step` with
/// `(discrete_action, continuous_params)` until it reports `done`.
pub struct ArticleExtractionEnvironment {
    /// Used only in `reset`, to produce fallback text when no candidates were found.
    baseline_extractor: BaselineExtractor,
    /// Candidates for the current page, in the order returned by
    /// `HtmlParser::get_candidate_nodes` (requested with `config.num_candidate_nodes`).
    candidates: Vec<CandidateNode>,
    /// Index into `candidates` of the current selection; `None` when there are no candidates.
    current_node_idx: Option<usize>,
    /// Offset added to `min_block_words`; moved by 2 per EXPAND/CONTRACT action
    /// and kept within [-20, 40].
    word_threshold_adjust: i32,
    /// Set once the TERMINATE action has been taken in this episode.
    terminated: bool,
    /// URL passed to the most recent `reset`.
    url: String,
    /// Host of `url` (or "unknown" if it does not parse); hashed into the state.
    domain: String,
    /// Reference text for the episode; empty when none was supplied, which
    /// switches the reward to the unsupervised quality heuristic.
    ground_truth_text: String,
    /// Baseline extraction used when there are no candidates; empty otherwise.
    baseline_fallback: String,
    /// Number of `step` calls since the last `reset`.
    step_count: usize,
    /// Episode length limit, copied from `config.max_steps_per_episode`.
    max_steps: usize,
    config: Config,
}
50
51impl ArticleExtractionEnvironment {
52 pub fn new(baseline_extractor: BaselineExtractor, config: Config) -> Self {
54 Self {
55 baseline_extractor,
56 candidates: Vec::new(),
57 current_node_idx: None,
58 word_threshold_adjust: 0,
59 terminated: false,
60 url: String::new(),
61 domain: String::new(),
62 ground_truth_text: String::new(),
63 baseline_fallback: String::new(),
64 step_count: 0,
65 max_steps: config.max_steps_per_episode,
66 config,
67 }
68 }
69
70 pub fn reset(
73 &mut self,
74 html: &str,
75 url: String,
76 ground_truth_text: Option<&str>,
77 _site_profile: Option<&SiteProfile>,
78 ) -> Result<Vec<f32>> {
79 self.url = url.clone();
80 self.domain = Self::extract_domain(&url);
81 self.step_count = 0;
82 self.word_threshold_adjust = 0;
83 self.terminated = false;
84 self.ground_truth_text = ground_truth_text.unwrap_or("").to_string();
85
86 let document = HtmlParser::clean_html(html)?;
89 let candidate_refs =
90 HtmlParser::get_candidate_nodes(&document, self.config.num_candidate_nodes);
91
92 self.candidates = candidate_refs
93 .iter()
94 .map(|node| CandidateNode {
95 xpath: HtmlParser::get_element_path(*node),
96 features: node_features::extract_features(node, &self.config.stopwords),
97 content: node_features::node_content(node),
98 })
99 .collect();
100
101 self.baseline_fallback = if self.candidates.is_empty() {
103 self.baseline_extractor
104 .extract(&document.html())
105 .map(|r| r.text)
106 .unwrap_or_default()
107 } else {
108 String::new()
109 };
110
111 self.current_node_idx = if self.candidates.is_empty() { None } else { Some(0) };
112
113 self.build_state()
114 }
115
116 pub fn step(&mut self, action: (usize, Vec<f32>)) -> Result<(Vec<f32>, f32, bool, StepInfo)> {
118 let (discrete_action, params) = action;
119 self.step_count += 1;
120
121 let n = self.candidates.len();
122
123 match discrete_action {
127 d if d < self.config.num_candidate_nodes => {
128 if n > 0 {
129 self.current_node_idx = Some(d.min(n - 1));
130 }
131 }
132 ACTION_SELECT_PARENT => self.select_parent(),
133 ACTION_SELECT_SIBLING_LEFT => {
134 if let Some(idx) = self.current_node_idx {
135 self.current_node_idx = Some(idx.saturating_sub(1));
136 }
137 }
138 ACTION_SELECT_SIBLING_RIGHT => {
139 if let (Some(idx), true) = (self.current_node_idx, n > 0) {
140 self.current_node_idx = Some((idx + 1).min(n - 1));
141 }
142 }
143 ACTION_EXPAND_REGION => {
145 self.word_threshold_adjust = (self.word_threshold_adjust - 2).max(-20);
146 }
147 ACTION_CONTRACT_REGION => {
148 self.word_threshold_adjust = (self.word_threshold_adjust + 2).min(40);
149 }
150 ACTION_TERMINATE => self.terminated = true,
151 _ => {}
152 }
153
154 let effective_params = self.effective_params(¶ms);
156 let extracted_text = self.extract_selected(&effective_params);
157
158 let score = if self.ground_truth_text.is_empty() {
162 TextUtils::calculate_text_quality(&extracted_text, &self.config.stopwords)
163 } else {
164 TextUtils::token_f1(&extracted_text, &self.ground_truth_text, &self.config.stopwords)
165 };
166 let reward = (score * 2.0 - 1.0 - 0.01 * self.step_count as f32).clamp(-1.0, 1.0);
167
168 let done = self.terminated || self.step_count >= self.max_steps;
169
170 let next_state = self.build_state()?;
171
172 let info = StepInfo {
173 quality_score: score,
174 text: extracted_text,
175 xpath: self
176 .current_node_idx
177 .and_then(|idx| self.candidates.get(idx))
178 .map(|c| c.xpath.clone())
179 .unwrap_or_default(),
180 parameters: self.denormalize_params(¶ms),
181 step_count: self.step_count,
182 };
183
184 Ok((next_state, reward, done, info))
185 }
186
187 fn select_parent(&mut self) {
190 let Some(idx) = self.current_node_idx else { return };
191 let current_path = self.candidates[idx].xpath.clone();
192
193 let mut best: Option<(usize, usize)> = None; for (j, cand) in self.candidates.iter().enumerate() {
195 if j == idx {
196 continue;
197 }
198 if current_path.starts_with(&cand.xpath) && cand.xpath.len() < current_path.len() {
199 let better = best.map(|(_, len)| cand.xpath.len() > len).unwrap_or(true);
200 if better {
201 best = Some((j, cand.xpath.len()));
202 }
203 }
204 }
205 if let Some((j, _)) = best {
206 self.current_node_idx = Some(j);
207 }
208 }
209
210 fn effective_params(&self, params: &[f32]) -> ExtractionParams {
212 let mut p = ExtractionParams::from_normalized(params);
213 let adjusted = p.min_block_words as i32 + self.word_threshold_adjust;
214 p.min_block_words = adjusted.clamp(1, 60) as usize;
215 p
216 }
217
218 fn extract_selected(&self, params: &ExtractionParams) -> String {
220 match self.current_node_idx.and_then(|idx| self.candidates.get(idx)) {
221 Some(cand) => cand.content.extract(params),
222 None => self.baseline_fallback.clone(),
223 }
224 }
225
226 fn denormalize_params(&self, params: &[f32]) -> HashMap<String, f64> {
228 let mut result = HashMap::new();
229
230 if params.len() >= 6 {
231 result.insert("min_word_threshold".to_string(), (2.0 + (params[0] + 1.0) * 4.0) as f64);
232 result.insert("stopword_weight".to_string(), (0.5 + (params[1] + 1.0) * 0.75) as f64);
233 result.insert("link_density_penalty".to_string(), ((params[2] + 1.0) * 1.0) as f64);
234 result.insert("paragraph_boost".to_string(), (1.0 + (params[3] + 1.0) * 0.5) as f64);
235 result.insert("sibling_extension".to_string(), ((params[4] + 1.0) * 0.5) as f64);
236 result.insert("depth_penalty".to_string(), ((params[5] + 1.0) * 0.25) as f64);
237 }
238
239 result
240 }
241
242 fn build_state(&self) -> Result<Vec<f32>> {
250 let mut state = Vec::with_capacity(self.config.state_dim);
251
252 for slot in 0..self.config.num_candidate_nodes {
254 match self.candidates.get(slot) {
255 Some(c) => state.extend(c.features.to_vec()),
256 None => state.extend(NodeFeatures::zeros().to_vec()),
257 }
258 }
259
260 let n = self.candidates.len();
262 let num_candidates_norm = (n as f32 / self.config.num_candidate_nodes as f32).clamp(0.0, 1.0);
263 let max_word = self
264 .candidates
265 .iter()
266 .map(|c| c.features.word_count_norm)
267 .fold(0.0_f32, f32::max);
268 let mean_link_density = if n == 0 {
269 0.0
270 } else {
271 self.candidates.iter().map(|c| c.features.link_density).sum::<f32>() / n as f32
272 };
273 let has_article = if self.candidates.iter().any(|c| c.features.tag_article > 0.5) { 1.0 } else { 0.0 };
274 let has_main = if self.candidates.iter().any(|c| c.features.tag_main > 0.5) { 1.0 } else { 0.0 };
275 let mean_stopword = if n == 0 {
276 0.0
277 } else {
278 self.candidates.iter().map(|c| c.features.stopword_ratio).sum::<f32>() / n as f32
279 };
280
281 state.push(num_candidates_norm);
282 state.push(max_word);
283 state.push(mean_link_density);
284 state.push(mean_stopword);
285 state.push(has_article);
286 state.push(has_main);
287 state.push(Self::hash_domain_normalized(&self.domain));
288 state.push(self.ground_truth_text.is_empty() as i32 as f32);
289
290 for slot in 0..self.config.num_candidate_nodes {
292 state.push(if self.current_node_idx == Some(slot) { 1.0 } else { 0.0 });
293 }
294 state.push(self.step_count as f32 / self.max_steps.max(1) as f32);
295 state.push((self.word_threshold_adjust as f32 / 40.0).clamp(-1.0, 1.0));
296 state.push(self.terminated as i32 as f32);
297
298 state.truncate(self.config.state_dim);
300 while state.len() < self.config.state_dim {
301 state.push(0.0);
302 }
303
304 Ok(state)
305 }
306
307 fn extract_domain(url: &str) -> String {
309 url::Url::parse(url)
310 .ok()
311 .and_then(|u| u.host_str().map(|h| h.to_string()))
312 .unwrap_or_else(|| "unknown".to_string())
313 }
314
315 fn hash_domain_normalized(domain: &str) -> f32 {
317 use sha2::{Sha256, Digest};
318
319 let mut hasher = Sha256::new();
320 hasher.update(domain.as_bytes());
321 let result = hasher.finalize();
322
323 let hash_val = u32::from_be_bytes([result[0], result[1], result[2], result[3]]);
324 (hash_val % 10000) as f32 / 10000.0
325 }
326}
327
/// Diagnostics returned alongside the state/reward from each `step`.
#[derive(Debug, Clone)]
pub struct StepInfo {
    /// Score before reward shaping: token F1 against the ground truth if one was
    /// supplied to `reset`, otherwise the text-quality heuristic.
    pub quality_score: f32,
    /// Text extracted this step (the baseline fallback when no candidate is selected).
    pub text: String,
    /// Path of the selected candidate; empty when nothing is selected.
    pub xpath: String,
    /// Raw action parameters mapped to named values. Empty when fewer than six
    /// parameters were supplied; does not include the EXPAND/CONTRACT offset.
    pub parameters: HashMap<String, f64>,
    /// Steps taken in the episode so far, including this one.
    pub step_count: usize,
}
337
#[cfg(test)]
mod tests {
    use super::*;

    /// Small page with one article body surrounded by link-heavy boilerplate
    /// (navigation and an ads sidebar).
    fn test_html() -> &'static str {
        r#"
        <html><body>
        <nav class="navigation"><a href="/a">Home</a> <a href="/b">About</a> <a href="/c">Contact</a></nav>
        <article class="article-content">
        <p>Quantum researchers reported a significant breakthrough in error correction this week.</p>
        <p>The new technique stabilizes qubits for longer durations enabling deeper computations.</p>
        <p>Independent laboratories confirmed the reproducible measurements across several runs.</p>
        </article>
        <div class="sidebar-ads"><a href="/x">Buy now</a> <a href="/y">Subscribe today</a></div>
        </body></html>
        "#
    }

    /// Environment built from the default configuration.
    fn env() -> ArticleExtractionEnvironment {
        let config = Config::default();
        let baseline = BaselineExtractor::new(config.stopwords.clone());
        ArticleExtractionEnvironment::new(baseline, config)
    }

    /// The initial state has the configured length and is not a constant vector.
    #[test]
    fn reset_produces_correct_state_dim_and_varies() {
        let mut env = env();
        let state = env
            .reset(test_html(), "https://example.com/post".to_string(), None, None)
            .unwrap();
        assert_eq!(state.len(), env.config.state_dim);
        // Compare bit patterns, since f32 is not Hash.
        let distinct: std::collections::HashSet<u32> =
            state.iter().map(|f| f.to_bits()).collect();
        assert!(distinct.len() > 5, "state should contain varied real features");
    }

    /// Selecting different candidates gives different rewards, and the best
    /// selection closely matches the article body.
    #[test]
    fn action_choice_changes_reward() {
        let gt = "Quantum researchers reported a significant breakthrough in error correction \
                  this week. The new technique stabilizes qubits for longer durations enabling \
                  deeper computations. Independent laboratories confirmed the reproducible \
                  measurements across several runs.";
        let mut env = env();
        // Initial reset populates `candidates` so the loop knows how many there are.
        env.reset(test_html(), "https://example.com/post".to_string(), Some(gt), None)
            .unwrap();

        let mut rewards = Vec::new();
        // Each action id below the candidate count selects that candidate slot;
        // reset before each try so every selection starts from step 0.
        for action in 0..env.candidates.len() {
            env.reset(test_html(), "https://example.com/post".to_string(), Some(gt), None)
                .unwrap();
            let (_s, reward, _d, info) = env
                .step((action, vec![-1.0, 0.0, 0.0, 0.0, 0.0, 0.0]))
                .unwrap();
            rewards.push((action, reward, info.quality_score));
        }

        // Entries are (action, reward, quality_score); pick max and min by reward.
        let best = rewards.iter().cloned().fold((0usize, f32::MIN, 0.0), |acc, x| {
            if x.1 > acc.1 { x } else { acc }
        });
        let worst = rewards.iter().cloned().fold((0usize, f32::MAX, 0.0), |acc, x| {
            if x.1 < acc.1 { x } else { acc }
        });

        assert!(
            (best.1 - worst.1).abs() > 1e-3,
            "different node selections must yield different rewards: {rewards:?}"
        );
        assert!(best.2 > 0.5, "best F1 should be high, got {}", best.2);
    }

    /// A single TERMINATE action reports `done`.
    #[test]
    fn terminate_action_ends_episode() {
        let mut env = env();
        env.reset(test_html(), "https://example.com/post".to_string(), None, None)
            .unwrap();
        let (_s, _r, done, _info) = env.step((ACTION_TERMINATE, vec![0.0; 6])).unwrap();
        assert!(done, "TERMINATE must end the episode");
    }

    /// Without TERMINATE, the episode ends no later than `max_steps_per_episode`
    /// (the loop allows 5 extra iterations so a missing cutoff would be caught).
    #[test]
    fn episode_force_terminates_at_max_steps() {
        let mut env = env();
        env.reset(test_html(), "https://example.com/post".to_string(), None, None)
            .unwrap();
        let mut done = false;
        let mut steps = 0;
        while !done && steps < env.config.max_steps_per_episode + 5 {
            let (_s, _r, d, _i) = env.step((ACTION_SELECT_SIBLING_RIGHT, vec![0.0; 6])).unwrap();
            done = d;
            steps += 1;
        }
        assert!(done);
        assert!(steps <= env.config.max_steps_per_episode);
    }

    /// Raising the first continuous parameter (from -1.0 to 1.0) must not
    /// lengthen the extracted text. Presumably it raises the minimum block size
    /// via `ExtractionParams::from_normalized`; confirm there.
    #[test]
    fn continuous_params_affect_extraction() {
        let mut env = env();
        env.reset(test_html(), "https://example.com/post".to_string(), None, None)
            .unwrap();
        let lenient = env.step((1, vec![-1.0, 1.0, 0.0, 0.0, 0.0, 0.0])).unwrap().3.text;
        env.reset(test_html(), "https://example.com/post".to_string(), None, None)
            .unwrap();
        let strict = env.step((1, vec![1.0, 1.0, 0.0, 0.0, 0.0, 0.0])).unwrap().3.text;
        assert!(strict.len() <= lenient.len());
    }
}