Skip to main content

content_extractor_rl/
environment.rs

1// ============================================================================
2// FILE: crates/content-extractor-rl/src/environment.rs
3// ============================================================================
4//! RL environment for article extraction.
5//!
6//! This is a *real* MDP: the agent's discrete action chooses which DOM
7//! candidate becomes the content root, its continuous params tune block-level
8//! filtering, and the resulting extracted text is scored against the
9//! ground-truth article (token F1). Different actions therefore produce
10//! different rewards — the precondition for learning that the previous
11//! placeholder environment lacked.
12
13use crate::baseline_extractor::BaselineExtractor;
14use crate::html_parser::HtmlParser;
15use crate::node_features::{self, CandidateContent, ExtractionParams, NodeFeatures};
16use crate::text_utils::TextUtils;
17use crate::site_profile::SiteProfile;
18use crate::config::{
19    Config, ACTION_SELECT_PARENT, ACTION_SELECT_SIBLING_LEFT, ACTION_SELECT_SIBLING_RIGHT,
20    ACTION_EXPAND_REGION, ACTION_CONTRACT_REGION, ACTION_TERMINATE,
21};
22use crate::Result;
23use std::collections::HashMap;
24
25/// A candidate content node, captured as owned data so the environment does not
26/// need to hold borrowed `ElementRef`s across `step` calls.
27struct CandidateNode {
28    xpath: String,
29    features: NodeFeatures,
30    content: CandidateContent,
31}
32
33/// RL environment for article extraction
34pub struct ArticleExtractionEnvironment {
35    baseline_extractor: BaselineExtractor,
36    candidates: Vec<CandidateNode>,
37    current_node_idx: Option<usize>,
38    /// Coarse block-word-threshold adjustment driven by EXPAND/CONTRACT actions.
39    word_threshold_adjust: i32,
40    terminated: bool,
41    url: String,
42    domain: String,
43    ground_truth_text: String,
44    /// Cached baseline fallback text (used only when the DOM has no candidates).
45    baseline_fallback: String,
46    step_count: usize,
47    max_steps: usize,
48    config: Config,
49}
50
51impl ArticleExtractionEnvironment {
52    /// Create new environment
53    pub fn new(baseline_extractor: BaselineExtractor, config: Config) -> Self {
54        Self {
55            baseline_extractor,
56            candidates: Vec::new(),
57            current_node_idx: None,
58            word_threshold_adjust: 0,
59            terminated: false,
60            url: String::new(),
61            domain: String::new(),
62            ground_truth_text: String::new(),
63            baseline_fallback: String::new(),
64            step_count: 0,
65            max_steps: config.max_steps_per_episode,
66            config,
67        }
68    }
69
70    /// Reset environment with new HTML and (optionally) the ground-truth article
71    /// text used to compute the reward.
72    pub fn reset(
73        &mut self,
74        html: &str,
75        url: String,
76        ground_truth_text: Option<&str>,
77        _site_profile: Option<&SiteProfile>,
78    ) -> Result<Vec<f32>> {
79        self.url = url.clone();
80        self.domain = Self::extract_domain(&url);
81        self.step_count = 0;
82        self.word_threshold_adjust = 0;
83        self.terminated = false;
84        self.ground_truth_text = ground_truth_text.unwrap_or("").to_string();
85
86        // Parse and clean HTML, then snapshot the candidate nodes' features and
87        // extractable content as owned data.
88        let document = HtmlParser::clean_html(html)?;
89        let candidate_refs =
90            HtmlParser::get_candidate_nodes(&document, self.config.num_candidate_nodes);
91
92        self.candidates = candidate_refs
93            .iter()
94            .map(|node| CandidateNode {
95                xpath: HtmlParser::get_element_path(*node),
96                features: node_features::extract_features(node, &self.config.stopwords),
97                content: node_features::node_content(node),
98            })
99            .collect();
100
101        // Fallback only matters when the DOM exposed no candidate nodes at all.
102        self.baseline_fallback = if self.candidates.is_empty() {
103            self.baseline_extractor
104                .extract(&document.html())
105                .map(|r| r.text)
106                .unwrap_or_default()
107        } else {
108            String::new()
109        };
110
111        self.current_node_idx = if self.candidates.is_empty() { None } else { Some(0) };
112
113        self.build_state()
114    }
115
116    /// Execute action and return next state, reward, done, info
117    pub fn step(&mut self, action: (usize, Vec<f32>)) -> Result<(Vec<f32>, f32, bool, StepInfo)> {
118        let (discrete_action, params) = action;
119        self.step_count += 1;
120
121        let n = self.candidates.len();
122
123        // Apply the discrete action. Node-select actions pick a candidate;
124        // navigation actions move the selection or adjust block filtering;
125        // TERMINATE ends the episode. Every branch has a real effect.
126        match discrete_action {
127            d if d < self.config.num_candidate_nodes => {
128                if n > 0 {
129                    self.current_node_idx = Some(d.min(n - 1));
130                }
131            }
132            ACTION_SELECT_PARENT => self.select_parent(),
133            ACTION_SELECT_SIBLING_LEFT => {
134                if let Some(idx) = self.current_node_idx {
135                    self.current_node_idx = Some(idx.saturating_sub(1));
136                }
137            }
138            ACTION_SELECT_SIBLING_RIGHT => {
139                if let (Some(idx), true) = (self.current_node_idx, n > 0) {
140                    self.current_node_idx = Some((idx + 1).min(n - 1));
141                }
142            }
143            // EXPAND keeps more text (lower the word threshold); CONTRACT is stricter.
144            ACTION_EXPAND_REGION => {
145                self.word_threshold_adjust = (self.word_threshold_adjust - 2).max(-20);
146            }
147            ACTION_CONTRACT_REGION => {
148                self.word_threshold_adjust = (self.word_threshold_adjust + 2).min(40);
149            }
150            ACTION_TERMINATE => self.terminated = true,
151            _ => {}
152        }
153
154        // Extract text using the *selected* node and the effective params.
155        let effective_params = self.effective_params(&params);
156        let extracted_text = self.extract_selected(&effective_params);
157
158        // Reward: token F1 against ground truth when available, otherwise a
159        // self-supervised text-quality proxy. Mapped to [-1, 1] with a small
160        // per-step cost to encourage decisive episodes.
161        let score = if self.ground_truth_text.is_empty() {
162            TextUtils::calculate_text_quality(&extracted_text, &self.config.stopwords)
163        } else {
164            TextUtils::token_f1(&extracted_text, &self.ground_truth_text, &self.config.stopwords)
165        };
166        let reward = (score * 2.0 - 1.0 - 0.01 * self.step_count as f32).clamp(-1.0, 1.0);
167
168        let done = self.terminated || self.step_count >= self.max_steps;
169
170        let next_state = self.build_state()?;
171
172        let info = StepInfo {
173            quality_score: score,
174            text: extracted_text,
175            xpath: self
176                .current_node_idx
177                .and_then(|idx| self.candidates.get(idx))
178                .map(|c| c.xpath.clone())
179                .unwrap_or_default(),
180            parameters: self.denormalize_params(&params),
181            step_count: self.step_count,
182        };
183
184        Ok((next_state, reward, done, info))
185    }
186
187    /// Move selection to the candidate that is the nearest DOM ancestor (longest
188    /// proper xpath prefix) of the current selection, if any.
189    fn select_parent(&mut self) {
190        let Some(idx) = self.current_node_idx else { return };
191        let current_path = self.candidates[idx].xpath.clone();
192
193        let mut best: Option<(usize, usize)> = None; // (candidate idx, prefix len)
194        for (j, cand) in self.candidates.iter().enumerate() {
195            if j == idx {
196                continue;
197            }
198            if current_path.starts_with(&cand.xpath) && cand.xpath.len() < current_path.len() {
199                let better = best.map(|(_, len)| cand.xpath.len() > len).unwrap_or(true);
200                if better {
201                    best = Some((j, cand.xpath.len()));
202                }
203            }
204        }
205        if let Some((j, _)) = best {
206            self.current_node_idx = Some(j);
207        }
208    }
209
210    /// Combine the policy's continuous params with the EXPAND/CONTRACT offset.
211    fn effective_params(&self, params: &[f32]) -> ExtractionParams {
212        let mut p = ExtractionParams::from_normalized(params);
213        let adjusted = p.min_block_words as i32 + self.word_threshold_adjust;
214        p.min_block_words = adjusted.clamp(1, 60) as usize;
215        p
216    }
217
218    /// Extract text from the currently selected candidate.
219    fn extract_selected(&self, params: &ExtractionParams) -> String {
220        match self.current_node_idx.and_then(|idx| self.candidates.get(idx)) {
221            Some(cand) => cand.content.extract(params),
222            None => self.baseline_fallback.clone(),
223        }
224    }
225
226    /// Denormalize parameters from [-1, 1] to actual ranges (for site profiles).
227    fn denormalize_params(&self, params: &[f32]) -> HashMap<String, f64> {
228        let mut result = HashMap::new();
229
230        if params.len() >= 6 {
231            result.insert("min_word_threshold".to_string(), (2.0 + (params[0] + 1.0) * 4.0) as f64);
232            result.insert("stopword_weight".to_string(), (0.5 + (params[1] + 1.0) * 0.75) as f64);
233            result.insert("link_density_penalty".to_string(), ((params[2] + 1.0) * 1.0) as f64);
234            result.insert("paragraph_boost".to_string(), (1.0 + (params[3] + 1.0) * 0.5) as f64);
235            result.insert("sibling_extension".to_string(), ((params[4] + 1.0) * 0.5) as f64);
236            result.insert("depth_penalty".to_string(), ((params[5] + 1.0) * 0.25) as f64);
237        }
238
239        result
240    }
241
242    /// Build state vector from real DOM features.
243    ///
244    /// Layout (before padding to `config.state_dim`):
245    ///   - `num_candidate_nodes` × `NodeFeatures::DIM` per-candidate features
246    ///   - 8 global document features
247    ///   - selection one-hot (`num_candidate_nodes`) + step fraction
248    ///     + threshold-adjust + terminated flag
249    fn build_state(&self) -> Result<Vec<f32>> {
250        let mut state = Vec::with_capacity(self.config.state_dim);
251
252        // Per-candidate features (padded slots get zeros).
253        for slot in 0..self.config.num_candidate_nodes {
254            match self.candidates.get(slot) {
255                Some(c) => state.extend(c.features.to_vec()),
256                None => state.extend(NodeFeatures::zeros().to_vec()),
257            }
258        }
259
260        // Global document features.
261        let n = self.candidates.len();
262        let num_candidates_norm = (n as f32 / self.config.num_candidate_nodes as f32).clamp(0.0, 1.0);
263        let max_word = self
264            .candidates
265            .iter()
266            .map(|c| c.features.word_count_norm)
267            .fold(0.0_f32, f32::max);
268        let mean_link_density = if n == 0 {
269            0.0
270        } else {
271            self.candidates.iter().map(|c| c.features.link_density).sum::<f32>() / n as f32
272        };
273        let has_article = if self.candidates.iter().any(|c| c.features.tag_article > 0.5) { 1.0 } else { 0.0 };
274        let has_main = if self.candidates.iter().any(|c| c.features.tag_main > 0.5) { 1.0 } else { 0.0 };
275        let mean_stopword = if n == 0 {
276            0.0
277        } else {
278            self.candidates.iter().map(|c| c.features.stopword_ratio).sum::<f32>() / n as f32
279        };
280
281        state.push(num_candidates_norm);
282        state.push(max_word);
283        state.push(mean_link_density);
284        state.push(mean_stopword);
285        state.push(has_article);
286        state.push(has_main);
287        state.push(Self::hash_domain_normalized(&self.domain));
288        state.push(self.ground_truth_text.is_empty() as i32 as f32);
289
290        // Selection state.
291        for slot in 0..self.config.num_candidate_nodes {
292            state.push(if self.current_node_idx == Some(slot) { 1.0 } else { 0.0 });
293        }
294        state.push(self.step_count as f32 / self.max_steps.max(1) as f32);
295        state.push((self.word_threshold_adjust as f32 / 40.0).clamp(-1.0, 1.0));
296        state.push(self.terminated as i32 as f32);
297
298        // Pad or truncate to exact STATE_DIM.
299        state.truncate(self.config.state_dim);
300        while state.len() < self.config.state_dim {
301            state.push(0.0);
302        }
303
304        Ok(state)
305    }
306
307    /// Extract domain from URL
308    fn extract_domain(url: &str) -> String {
309        url::Url::parse(url)
310            .ok()
311            .and_then(|u| u.host_str().map(|h| h.to_string()))
312            .unwrap_or_else(|| "unknown".to_string())
313    }
314
315    /// Hash domain to normalized value
316    fn hash_domain_normalized(domain: &str) -> f32 {
317        use sha2::{Sha256, Digest};
318
319        let mut hasher = Sha256::new();
320        hasher.update(domain.as_bytes());
321        let result = hasher.finalize();
322
323        let hash_val = u32::from_be_bytes([result[0], result[1], result[2], result[3]]);
324        (hash_val % 10000) as f32 / 10000.0
325    }
326}
327
/// Diagnostics returned alongside the reward from every `step` call.
#[derive(Debug, Clone)]
pub struct StepInfo {
    /// Raw score of the extracted text, before it is mapped to the reward.
    pub quality_score: f32,
    /// Text extracted during this step.
    pub text: String,
    /// Element path of the selected candidate; empty when nothing is selected.
    pub xpath: String,
    /// Denormalized continuous parameters, keyed by parameter name.
    pub parameters: HashMap<String, f64>,
    /// Number of steps taken in the episode, including this one.
    pub step_count: usize,
}
337
#[cfg(test)]
mod tests {
    use super::*;

    /// URL used for every episode in these tests.
    const POST_URL: &str = "https://example.com/post";

    /// Small page with a navigation bar, a three-paragraph article and an ad
    /// sidebar.
    fn test_html() -> &'static str {
        r#"
        <html><body>
            <nav class="navigation"><a href="/a">Home</a> <a href="/b">About</a> <a href="/c">Contact</a></nav>
            <article class="article-content">
                <p>Quantum researchers reported a significant breakthrough in error correction this week.</p>
                <p>The new technique stabilizes qubits for longer durations enabling deeper computations.</p>
                <p>Independent laboratories confirmed the reproducible measurements across several runs.</p>
            </article>
            <div class="sidebar-ads"><a href="/x">Buy now</a> <a href="/y">Subscribe today</a></div>
        </body></html>
        "#
    }

    /// Environment built from the default configuration.
    fn env() -> ArticleExtractionEnvironment {
        let config = Config::default();
        let baseline = BaselineExtractor::new(config.stopwords.clone());
        ArticleExtractionEnvironment::new(baseline, config)
    }

    #[test]
    fn reset_produces_correct_state_dim_and_varies() {
        let mut env = env();
        let state = env
            .reset(test_html(), POST_URL.to_string(), None, None)
            .unwrap();

        assert_eq!(state.len(), env.config.state_dim);

        // A constant placeholder state would collapse to a single bit pattern.
        let distinct: std::collections::HashSet<u32> =
            state.iter().map(|value| value.to_bits()).collect();
        assert!(distinct.len() > 5, "state should contain varied real features");
    }

    #[test]
    fn action_choice_changes_reward() {
        let gt = "Quantum researchers reported a significant breakthrough in error correction \
                  this week. The new technique stabilizes qubits for longer durations enabling \
                  deeper computations. Independent laboratories confirmed the reproducible \
                  measurements across several runs.";
        let mut env = env();
        env.reset(test_html(), POST_URL.to_string(), Some(gt), None)
            .unwrap();

        // One fresh single-step episode per candidate: (action, reward, score).
        let candidate_count = env.candidates.len();
        let rewards: Vec<(usize, f32, f32)> = (0..candidate_count)
            .map(|action| {
                env.reset(test_html(), POST_URL.to_string(), Some(gt), None)
                    .unwrap();
                let (_state, reward, _done, info) = env
                    .step((action, vec![-1.0, 0.0, 0.0, 0.0, 0.0, 0.0]))
                    .unwrap();
                (action, reward, info.quality_score)
            })
            .collect();

        // Track the extremes by reward; the first entry wins on ties.
        let mut best = (0usize, f32::MIN, 0.0_f32);
        let mut worst = (0usize, f32::MAX, 0.0_f32);
        for &entry in &rewards {
            if entry.1 > best.1 {
                best = entry;
            }
            if entry.1 < worst.1 {
                worst = entry;
            }
        }

        // If actions had no effect, every reward would be equal.
        assert!(
            (best.1 - worst.1).abs() > 1e-3,
            "different node selections must yield different rewards: {rewards:?}"
        );
        // The best-scoring selection should recover most of the ground truth.
        assert!(best.2 > 0.5, "best F1 should be high, got {}", best.2);
    }

    #[test]
    fn terminate_action_ends_episode() {
        let mut env = env();
        env.reset(test_html(), POST_URL.to_string(), None, None)
            .unwrap();

        let (_state, _reward, done, _info) =
            env.step((ACTION_TERMINATE, vec![0.0; 6])).unwrap();
        assert!(done, "TERMINATE must end the episode");
    }

    #[test]
    fn episode_force_terminates_at_max_steps() {
        let mut env = env();
        env.reset(test_html(), POST_URL.to_string(), None, None)
            .unwrap();

        // Allow a few extra iterations so a missing cut-off shows up as a
        // failed assertion rather than an endless loop.
        let iteration_cap = env.config.max_steps_per_episode + 5;
        let mut steps = 0;
        let mut done = false;
        while !done && steps < iteration_cap {
            // A navigation action never terminates the episode by itself.
            let (_state, _reward, finished, _info) = env
                .step((ACTION_SELECT_SIBLING_RIGHT, vec![0.0; 6]))
                .unwrap();
            done = finished;
            steps += 1;
        }

        assert!(done);
        assert!(steps <= env.config.max_steps_per_episode);
    }

    #[test]
    fn continuous_params_affect_extraction() {
        let mut env = env();
        env.reset(test_html(), POST_URL.to_string(), None, None)
            .unwrap();

        // Each measurement starts from a fresh episode and selects candidate 1
        // with the given first (word-threshold) parameter.
        let mut text_for = |min_words: f32| {
            env.reset(test_html(), POST_URL.to_string(), None, None)
                .unwrap();
            env.step((1, vec![min_words, 1.0, 0.0, 0.0, 0.0, 0.0]))
                .unwrap()
                .3
                .text
        };
        let lenient = text_for(-1.0);
        let strict = text_for(1.0);

        // A very high min-word threshold should not produce *more* text than lenient.
        assert!(strict.len() <= lenient.len());
    }
}