// content_extractor_rl/cli_utils.rs
use crate::*;
use std::path::{Path, PathBuf};
use bzip2::read::BzDecoder;
use std::io::Read;
use indicatif::{ProgressBar, ProgressStyle};
use url::Url;
use crate::agents::dqn_agent::DQNAgent;

15pub fn extract_single(
17 html_file: &Path,
18 url: String,
19 model_path: Option<&Path>,
20 output: Option<&Path>,
21 config: &Config,
22) -> Result<ExtractedArticle> {
23 let html_content = read_html_file(html_file)?;
24 let baseline_extractor = BaselineExtractor::new(config.stopwords.clone());
25
26 let domain = extract_domain_from_url(&url);
28
29 let mut site_memory = SiteProfileMemory::new(&config.site_profiles_dir)?;
31 let site_profile = site_memory.get_profile(&domain);
32
33 let result = if let Some(model_path) = model_path {
34 let device = get_device();
35 let _agent = DQNAgent::load_with_device(
36 model_path,
37 config.state_dim,
38 config.num_discrete_actions,
39 config.num_continuous_params,
40 &device,
41 )?;
42
43 if site_profile.extractions.len() > 5 {
45 tracing::debug!("Using site profile for {} (has {} past extractions)",
46 domain, site_profile.extractions.len());
47 }
48
49 baseline_extractor.extract(&html_content)?
50 } else {
51 baseline_extractor.extract(&html_content)?
52 };
53
54 let article = ExtractedArticle {
55 url: url.clone(),
56 title: result.title,
57 date: result.date,
58 content: result.text,
59 quality_score: result.quality_score,
60 method: if model_path.is_some() { "rl" } else { "baseline" }.to_string(),
61 xpath: Some(result.xpath),
62 };
63
64 if let Some(output_path) = output {
65 let batch_result = BatchExtractionResult {
66 articles: vec![article.clone()],
67 };
68 let json = serde_json::to_string_pretty(&batch_result)?;
69 std::fs::write(output_path, json)?;
70 }
71
72 Ok(article)
73}
74
75pub fn extract_batch(
77 archive_dir: &Path,
78 model_path: Option<&Path>,
79 output_dir: &Path,
80 max_files: Option<usize>,
81 _batch_size: usize,
82 config: &Config,
83) -> Result<BatchExtractionResult> {
84 std::fs::create_dir_all(output_dir)?;
85
86 let file_pairs = load_html_files_recursive(archive_dir, max_files)?;
87
88 if file_pairs.is_empty() {
89 return Err(ExtractionError::ExtractionFailed(
90 "No HTML files found".to_string()
91 ));
92 }
93 let count_of_files: usize = file_pairs.len();
94
95 tracing::info!("Found {} HTML/JSON file pairs", count_of_files);
96
97 let baseline_extractor = BaselineExtractor::new(config.stopwords.clone());
98 let device = get_device();
99 let agent = if let Some(path) = model_path {
100 Some(DQNAgent::load_with_device(
101 path,
102 config.state_dim,
103 config.num_discrete_actions,
104 config.num_continuous_params,
105 &device,
106 )?)
107 } else {
108 None
109 };
110
111 let mut site_memory = SiteProfileMemory::new(&config.site_profiles_dir)?;
113
114 let pb = ProgressBar::new(count_of_files as u64);
115 pb.set_style(
116 ProgressStyle::default_bar()
117 .template("[{elapsed_precise}] {bar:40.cyan/blue} {pos}/{len} {msg}")
118 .unwrap()
119 .progress_chars("=>-"),
120 );
121
122 let mut all_articles = Vec::new();
123 let mut failed = Vec::new();
124 let mut site_profile_used_count = 0;
125
126 for (html_path, json_path) in file_pairs {
127 let url = read_url_from_json(&json_path);
128 let domain = extract_domain_from_url(&url);
129
130 let html_content = match read_html_file(&html_path) {
131 Ok(content) => content,
132 Err(e) => {
133 failed.push((url, e.to_string()));
134 pb.inc(1);
135 continue;
136 }
137 };
138
139 let site_profile = site_memory.get_profile(&domain);
141 let has_profile = site_profile.extractions.len() > 5;
142
143 if has_profile {
144 site_profile_used_count += 1;
145 }
146
147 match baseline_extractor.extract(&html_content) {
148 Ok(result) => {
149 let method = if agent.is_some() {
150 if has_profile { "rl+profile" } else { "rl" }
151 } else if has_profile { "baseline+profile" } else { "baseline" };
152
153 let article = ExtractedArticle {
154 url: url.clone(),
155 title: result.title.clone(),
156 date: result.date.clone(),
157 content: result.text.clone(),
158 quality_score: result.quality_score,
159 method: method.to_string(),
160 xpath: Some(result.xpath.clone()),
161 };
162
163 let extraction_result = site_profile::ExtractionResult {
165 text: result.text,
166 xpath: result.xpath,
167 quality_score: result.quality_score,
168 parameters: result.parameters,
169 title: result.title,
170 date: result.date,
171 };
172 site_profile.add_extraction(extraction_result);
173
174 all_articles.push(article);
175 }
176 Err(e) => {
177 failed.push((url, e.to_string()));
178 }
179 }
180 pb.inc(1);
181 }
182
183 pb.finish_with_message("Batch extraction complete");
184
185 site_memory.save_all()?;
187 tracing::info!("Site profiles saved ({} domains used profiles)", site_profile_used_count);
188
189 let timestamp = chrono::Utc::now().format("%Y%m%d_%H%M%S");
191 let results_path = output_dir.join(format!("batch_results_{}.json", timestamp));
192 let batch_result = BatchExtractionResult { articles: all_articles.clone() };
193 let json = serde_json::to_string_pretty(&batch_result)?;
194 std::fs::write(&results_path, json)?;
195
196 if !failed.is_empty() {
198 let failed_path = output_dir.join(format!("failed_{}.json", timestamp));
199 let failed_json = serde_json::to_string_pretty(&failed)?;
200 std::fs::write(&failed_path, failed_json)?;
201 tracing::warn!("Failed extractions saved to: {}", failed_path.display());
202 }
203
204 tracing::info!("Batch extraction: {}/{} successful, {} with site profiles",
205 all_articles.len(), count_of_files, site_profile_used_count);
206
207 Ok(batch_result)
208}
209
210pub fn extract_domain_from_url(url: &str) -> String {
212 match Url::parse(url) {
213 Ok(parsed_url) => {
214 parsed_url.host_str()
215 .map(|h| h.to_string())
216 .unwrap_or_else(|| "unknown".to_string())
217 }
218 Err(_) => {
219 let url = url.trim();
220 let without_protocol = url.strip_prefix("https://")
221 .or_else(|| url.strip_prefix("http://"))
222 .unwrap_or(url);
223
224 let host_part = without_protocol.split('/').next().unwrap_or("");
225 let domain = host_part.split(':').next().unwrap_or("");
226
227 if domain.is_empty() {
228 "unknown".to_string()
229 } else {
230 domain.to_string()
231 }
232 }
233 }
234}
235
236pub fn load_html_files_recursive(
238 dir: &Path,
239 max_files: Option<usize>,
240) -> Result<Vec<(PathBuf, PathBuf)>> {
241 use walkdir::WalkDir;
242
243 let mut files = Vec::new();
244
245 for entry in WalkDir::new(dir).into_iter().filter_map(|e| e.ok()) {
246 if let Some(max) = max_files {
247 if files.len() >= max {
248 break;
249 }
250 }
251
252 let path = entry.path();
253 if path.is_file() {
254 if let Some(ext) = path.extension() {
255 if ext == "bz2" && path.to_string_lossy().contains(".html.") {
256 let json_path = path.with_extension("").with_extension("json");
257 if json_path.exists() {
258 files.push((path.to_path_buf(), json_path));
259 }
260 } else if ext == "html" || ext == "htm" {
261 let json_path = path.with_extension("json");
262 if json_path.exists() {
263 files.push((path.to_path_buf(), json_path));
264 }
265 }
266 }
267 }
268 }
269
270 Ok(files)
271}
272
273pub fn read_html_file(path: &Path) -> Result<String> {
275 if path.extension().and_then(|s| s.to_str()) == Some("bz2") {
276 let file = std::fs::File::open(path)?;
277 let mut decoder = BzDecoder::new(file);
278 let mut bytes = Vec::new();
279 decoder.read_to_end(&mut bytes)?;
280 Ok(String::from_utf8_lossy(&bytes).into_owned())
281 } else {
282 let bytes = std::fs::read(path)?;
283 Ok(String::from_utf8_lossy(&bytes).into_owned())
284 }
285}
286
287pub fn read_url_from_json(json_path: &Path) -> String {
289 match std::fs::read_to_string(json_path) {
290 Ok(json_content) => {
291 if let Ok(json_value) = serde_json::from_str::<serde_json::Value>(&json_content) {
292 json_value.get("URL")
293 .and_then(|u| u.as_str())
294 .map(|s| s.to_string())
295 .unwrap_or_else(|| "https://example.com/unknown".to_string())
296 } else {
297 "https://example.com/invalid-json".to_string()
298 }
299 }
300 Err(_) => "https://example.com/no-json".to_string(),
301 }
302}
303
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// Hosts come back bare: scheme, path, and port are all stripped.
    #[test]
    fn test_extract_domain_from_url() {
        let cases = [
            ("https://www.example.com/article", "www.example.com"),
            ("http://subdomain.example.org:8080/path", "subdomain.example.org"),
        ];
        for (input, expected) in cases {
            assert_eq!(extract_domain_from_url(input), expected);
        }
    }

    /// A well-formed sidecar JSON file surfaces its "URL" field verbatim.
    #[test]
    fn test_read_url_from_json() {
        let temp_dir = TempDir::new().unwrap();
        let json_path = temp_dir.path().join("test.json");

        std::fs::write(&json_path, r#"{"URL": "https://example.com/article"}"#).unwrap();

        assert_eq!(read_url_from_json(&json_path), "https://example.com/article");
    }
}