Skip to main content

content_extractor_rl/
environment.rs

1// ============================================================================
2// FILE: crates/content-extractor-rl/src/environment.rs
3// ============================================================================
4
5use scraper::{Html};
6use crate::baseline_extractor::BaselineExtractor;
7use crate::html_parser::HtmlParser;
8use crate::text_utils::TextUtils;
9use crate::site_profile::SiteProfile;
10use crate::config::{Config, ACTION_SELECT_PARENT, ACTION_SELECT_SIBLING_LEFT, ACTION_SELECT_SIBLING_RIGHT, ACTION_TERMINATE};
11use crate::Result;
12use std::collections::HashMap;
13
14/// RL environment for article extraction
15pub struct ArticleExtractionEnvironment {
16    baseline_extractor: BaselineExtractor,
17    document: Option<Html>,
18    current_node_idx: Option<usize>,
19    candidates: Vec<String>, // Store node identifiers
20    url: String,
21    domain: String,
22    step_count: usize,
23    max_steps: usize,
24    config: Config,
25}
26
27impl ArticleExtractionEnvironment {
28    /// Create new environment
29    pub fn new(baseline_extractor: BaselineExtractor, config: Config) -> Self {
30        Self {
31            baseline_extractor,
32            document: None,
33            current_node_idx: None,
34            candidates: Vec::new(),
35            url: String::new(),
36            domain: String::new(),
37            step_count: 0,
38            max_steps: config.max_steps_per_episode,
39            config,
40        }
41    }
42
43    /// Reset environment with new HTML
44    pub fn reset(&mut self, html: &str, url: String, _site_profile: Option<&SiteProfile>) -> Result<Vec<f32>> {
45        self.url = url.clone();
46        self.domain = Self::extract_domain(&url);
47        self.step_count = 0;
48
49        // Parse and clean HTML
50        let document = HtmlParser::clean_html(html)?;
51        let candidates = HtmlParser::get_candidate_nodes(&document, self.config.num_candidate_nodes);
52
53        // Store candidate identifiers
54        self.candidates = candidates.iter()
55            .map(|node| HtmlParser::get_element_path(*node))
56            .collect();
57
58        self.document = Some(document);
59        self.current_node_idx = if !self.candidates.is_empty() { Some(0) } else { None };
60
61        // Build initial state
62        self.build_state()
63    }
64
65    /// Execute action and return next state, reward, done, info
66    pub fn step(&mut self, action: (usize, Vec<f32>)) -> Result<(Vec<f32>, f32, bool, StepInfo)> {
67        let (discrete_action, params) = action;
68        self.step_count += 1;
69
70        let mut done = false;
71        let mut info = StepInfo {
72            quality_score: 0.0,
73            text: String::new(),
74            xpath: String::new(),
75            parameters: HashMap::new(),
76            step_count: self.step_count,
77        };
78
79        // Execute discrete action
80        match discrete_action {
81            0..=9 => {
82                // Select candidate node
83                let idx = discrete_action.min(self.candidates.len().saturating_sub(1));
84                self.current_node_idx = Some(idx);
85            }
86            ACTION_SELECT_PARENT => {
87                // Move to parent (simplified)
88            }
89            ACTION_SELECT_SIBLING_LEFT => {
90                // Move to left sibling (simplified)
91            }
92            ACTION_SELECT_SIBLING_RIGHT => {
93                // Move to right sibling (simplified)
94            }
95            ACTION_TERMINATE => {
96                done = true;
97            }
98            _ => {}
99        }
100
101        // Extract text with parameters
102        let extracted_text = self.extract_with_params(&params)?;
103
104        // Calculate reward
105        let quality_score = TextUtils::calculate_text_quality(&extracted_text, &self.config.stopwords);
106        let reward = quality_score * 2.0 - 1.0 - 0.01 * self.step_count as f32;
107
108        // Force termination
109        if self.step_count >= self.max_steps {
110            done = true;
111        }
112
113        // Build next state
114        let next_state = self.build_state()?;
115
116        info.quality_score = quality_score;
117        info.text = extracted_text;
118        info.xpath = self.current_node_idx
119            .and_then(|idx| self.candidates.get(idx))
120            .cloned()
121            .unwrap_or_default();
122        info.parameters = self.denormalize_params(&params);
123
124        Ok((next_state, reward, done, info))
125    }
126
127    /// Extract text using parameters
128    fn extract_with_params(&self, _params: &[f32]) -> Result<String> {
129        // Simplified extraction using parameters
130        if let Some(document) = &self.document {
131            if let Some(idx) = self.current_node_idx {
132                if let Some(_xpath) = self.candidates.get(idx) {
133                    // In a real implementation, we would use the parameters
134                    // to customize the extraction
135                    let result = self.baseline_extractor.extract(&document.html())?;
136                    return Ok(result.text);
137                }
138            }
139        }
140
141        Ok(String::new())
142    }
143
144    /// Denormalize parameters from [-1, 1] to actual ranges
145    fn denormalize_params(&self, params: &[f32]) -> HashMap<String, f64> {
146        let mut result = HashMap::new();
147
148        if params.len() >= 6 {
149            result.insert("min_word_threshold".to_string(), (2.0 + (params[0] + 1.0) * 4.0) as f64);
150            result.insert("stopword_weight".to_string(), (0.5 + (params[1] + 1.0) * 0.75) as f64);
151            result.insert("link_density_penalty".to_string(), ((params[2] + 1.0) * 1.0) as f64);
152            result.insert("paragraph_boost".to_string(), (1.0 + (params[3] + 1.0) * 0.5) as f64);
153            result.insert("sibling_extension".to_string(), ((params[4] + 1.0) * 0.5) as f64);
154            result.insert("depth_penalty".to_string(), ((params[5] + 1.0) * 0.25) as f64);
155        }
156
157        result
158    }
159
160    /// Build state vector
161    fn build_state(&self) -> Result<Vec<f32>> {
162        let mut state = Vec::with_capacity(self.config.state_dim);
163
164        // Global document features (12 dims)
165        if let Some(document) = &self.document {
166            let _all_text = document.root_element().text().collect::<String>();
167
168            state.push(0.5); // Normalized features
169            state.push(0.5);
170            state.push(0.5);
171            state.push(0.5);
172            state.push(0.5);
173            state.push(0.5);
174            state.push(0.5);
175            state.push(0.5);
176            state.push(0.0);
177            state.push(0.0);
178            state.push(0.5);
179            state.push(Self::hash_domain_normalized(&self.domain));
180        } else {
181            state.extend(vec![0.0; 12]);
182        }
183
184        // Candidate node features (20 dims * 10 nodes = 200 dims)
185        for _ in 0..self.config.num_candidate_nodes {
186            state.extend(vec![0.5; 20]); // Simplified features
187        }
188
189        // Historical features (8 dims)
190        state.extend(vec![0.0; 8]);
191
192        // Current extraction state (6 dims)
193        state.push(self.step_count as f32 / self.max_steps as f32);
194        state.extend(vec![0.5; 5]);
195
196        // Pad or truncate to exact STATE_DIM
197        state.truncate(self.config.state_dim);
198        while state.len() < self.config.state_dim {
199            state.push(0.0);
200        }
201
202        Ok(state)
203    }
204
205    /// Extract domain from URL
206    fn extract_domain(url: &str) -> String {
207        url::Url::parse(url)
208            .ok()
209            .and_then(|u| u.host_str().map(|h| h.to_string()))
210            .unwrap_or_else(|| "unknown".to_string())
211    }
212
213    /// Hash domain to normalized value
214    fn hash_domain_normalized(domain: &str) -> f32 {
215        use sha2::{Sha256, Digest};
216
217        let mut hasher = Sha256::new();
218        hasher.update(domain.as_bytes());
219        let result = hasher.finalize();
220
221        let hash_val = u32::from_be_bytes([result[0], result[1], result[2], result[3]]);
222        (hash_val % 10000) as f32 / 10000.0
223    }
224}
/// Diagnostic information returned alongside each environment step.
///
/// `Default` yields an empty record (zero score, empty strings, no
/// parameters, step 0).
#[derive(Debug, Clone, Default, PartialEq)]
pub struct StepInfo {
    /// Text-quality score of the text extracted on this step.
    pub quality_score: f32,
    /// Text extracted on this step; empty when nothing could be extracted.
    pub text: String,
    /// Element path of the currently selected candidate, or empty if none.
    pub xpath: String,
    /// Denormalized continuous parameters, keyed by parameter name.
    pub parameters: HashMap<String, f64>,
    /// Number of steps taken in the episode, including this one.
    pub step_count: usize,
}