// content_extractor_rl/environment.rs

use scraper::{Html};
use crate::baseline_extractor::BaselineExtractor;
use crate::html_parser::HtmlParser;
use crate::text_utils::TextUtils;
use crate::site_profile::SiteProfile;
use crate::config::{Config, ACTION_SELECT_PARENT, ACTION_SELECT_SIBLING_LEFT, ACTION_SELECT_SIBLING_RIGHT, ACTION_TERMINATE};
use crate::Result;
use std::collections::HashMap;

14pub struct ArticleExtractionEnvironment {
16 baseline_extractor: BaselineExtractor,
17 document: Option<Html>,
18 current_node_idx: Option<usize>,
19 candidates: Vec<String>, url: String,
21 domain: String,
22 step_count: usize,
23 max_steps: usize,
24 config: Config,
25}
26
27impl ArticleExtractionEnvironment {
28 pub fn new(baseline_extractor: BaselineExtractor, config: Config) -> Self {
30 Self {
31 baseline_extractor,
32 document: None,
33 current_node_idx: None,
34 candidates: Vec::new(),
35 url: String::new(),
36 domain: String::new(),
37 step_count: 0,
38 max_steps: config.max_steps_per_episode,
39 config,
40 }
41 }
42
43 pub fn reset(&mut self, html: &str, url: String, _site_profile: Option<&SiteProfile>) -> Result<Vec<f32>> {
45 self.url = url.clone();
46 self.domain = Self::extract_domain(&url);
47 self.step_count = 0;
48
49 let document = HtmlParser::clean_html(html)?;
51 let candidates = HtmlParser::get_candidate_nodes(&document, self.config.num_candidate_nodes);
52
53 self.candidates = candidates.iter()
55 .map(|node| HtmlParser::get_element_path(*node))
56 .collect();
57
58 self.document = Some(document);
59 self.current_node_idx = if !self.candidates.is_empty() { Some(0) } else { None };
60
61 self.build_state()
63 }
64
65 pub fn step(&mut self, action: (usize, Vec<f32>)) -> Result<(Vec<f32>, f32, bool, StepInfo)> {
67 let (discrete_action, params) = action;
68 self.step_count += 1;
69
70 let mut done = false;
71 let mut info = StepInfo {
72 quality_score: 0.0,
73 text: String::new(),
74 xpath: String::new(),
75 parameters: HashMap::new(),
76 step_count: self.step_count,
77 };
78
79 match discrete_action {
81 0..=9 => {
82 let idx = discrete_action.min(self.candidates.len().saturating_sub(1));
84 self.current_node_idx = Some(idx);
85 }
86 ACTION_SELECT_PARENT => {
87 }
89 ACTION_SELECT_SIBLING_LEFT => {
90 }
92 ACTION_SELECT_SIBLING_RIGHT => {
93 }
95 ACTION_TERMINATE => {
96 done = true;
97 }
98 _ => {}
99 }
100
101 let extracted_text = self.extract_with_params(¶ms)?;
103
104 let quality_score = TextUtils::calculate_text_quality(&extracted_text, &self.config.stopwords);
106 let reward = quality_score * 2.0 - 1.0 - 0.01 * self.step_count as f32;
107
108 if self.step_count >= self.max_steps {
110 done = true;
111 }
112
113 let next_state = self.build_state()?;
115
116 info.quality_score = quality_score;
117 info.text = extracted_text;
118 info.xpath = self.current_node_idx
119 .and_then(|idx| self.candidates.get(idx))
120 .cloned()
121 .unwrap_or_default();
122 info.parameters = self.denormalize_params(¶ms);
123
124 Ok((next_state, reward, done, info))
125 }
126
127 fn extract_with_params(&self, _params: &[f32]) -> Result<String> {
129 if let Some(document) = &self.document {
131 if let Some(idx) = self.current_node_idx {
132 if let Some(_xpath) = self.candidates.get(idx) {
133 let result = self.baseline_extractor.extract(&document.html())?;
136 return Ok(result.text);
137 }
138 }
139 }
140
141 Ok(String::new())
142 }
143
144 fn denormalize_params(&self, params: &[f32]) -> HashMap<String, f64> {
146 let mut result = HashMap::new();
147
148 if params.len() >= 6 {
149 result.insert("min_word_threshold".to_string(), (2.0 + (params[0] + 1.0) * 4.0) as f64);
150 result.insert("stopword_weight".to_string(), (0.5 + (params[1] + 1.0) * 0.75) as f64);
151 result.insert("link_density_penalty".to_string(), ((params[2] + 1.0) * 1.0) as f64);
152 result.insert("paragraph_boost".to_string(), (1.0 + (params[3] + 1.0) * 0.5) as f64);
153 result.insert("sibling_extension".to_string(), ((params[4] + 1.0) * 0.5) as f64);
154 result.insert("depth_penalty".to_string(), ((params[5] + 1.0) * 0.25) as f64);
155 }
156
157 result
158 }
159
160 fn build_state(&self) -> Result<Vec<f32>> {
162 let mut state = Vec::with_capacity(self.config.state_dim);
163
164 if let Some(document) = &self.document {
166 let _all_text = document.root_element().text().collect::<String>();
167
168 state.push(0.5); state.push(0.5);
170 state.push(0.5);
171 state.push(0.5);
172 state.push(0.5);
173 state.push(0.5);
174 state.push(0.5);
175 state.push(0.5);
176 state.push(0.0);
177 state.push(0.0);
178 state.push(0.5);
179 state.push(Self::hash_domain_normalized(&self.domain));
180 } else {
181 state.extend(vec![0.0; 12]);
182 }
183
184 for _ in 0..self.config.num_candidate_nodes {
186 state.extend(vec![0.5; 20]); }
188
189 state.extend(vec![0.0; 8]);
191
192 state.push(self.step_count as f32 / self.max_steps as f32);
194 state.extend(vec![0.5; 5]);
195
196 state.truncate(self.config.state_dim);
198 while state.len() < self.config.state_dim {
199 state.push(0.0);
200 }
201
202 Ok(state)
203 }
204
205 fn extract_domain(url: &str) -> String {
207 url::Url::parse(url)
208 .ok()
209 .and_then(|u| u.host_str().map(|h| h.to_string()))
210 .unwrap_or_else(|| "unknown".to_string())
211 }
212
213 fn hash_domain_normalized(domain: &str) -> f32 {
215 use sha2::{Sha256, Digest};
216
217 let mut hasher = Sha256::new();
218 hasher.update(domain.as_bytes());
219 let result = hasher.finalize();
220
221 let hash_val = u32::from_be_bytes([result[0], result[1], result[2], result[3]]);
222 (hash_val % 10000) as f32 / 10000.0
223 }
224}
/// Per-step diagnostics returned alongside each environment transition.
#[derive(Debug, Clone)]
pub struct StepInfo {
    /// Quality score of the extracted text, as computed by
    /// `TextUtils::calculate_text_quality`.
    pub quality_score: f32,
    /// Text extracted on this step (may be empty).
    pub text: String,
    /// Element path of the currently selected candidate, or `""` when no
    /// candidate is selected.
    pub xpath: String,
    /// Denormalized extraction parameters, keyed by parameter name.
    pub parameters: HashMap<String, f64>,
    /// Number of steps taken in the episode so far (1 on the first step).
    pub step_count: usize,
}