1use crate::*;
8use std::path::{Path, PathBuf};
9use bzip2::read::BzDecoder;
10use std::io::Read;
11use indicatif::{ProgressBar, ProgressStyle};
12use url::Url;
13use crate::node_classifier::{HybridExtractor, NodeClassifier};
14use crate::node_features::ExtractionParams;
15use crate::text_utils::TextUtils;
16
/// Article text chosen by one extraction strategy, plus its provenance.
///
/// Built by the strategy helpers in this module and unpacked into an
/// `ExtractedArticle` by the public entry points.
#[derive(Debug, Clone, PartialEq, Eq)]
struct ContentExtraction {
    /// Extracted article text; empty when the strategy found nothing usable.
    text: String,
    /// XPath reported alongside `text` by the strategy that produced it.
    xpath: String,
    /// Label of the producing strategy: `"rl"`, `"hybrid"`, `"classifier"`
    /// or `"baseline"`.
    method: &'static str,
}
23
24fn rl_agent_extract(
29 html: &str,
30 url: &str,
31 config: &Config,
32 agent: &dyn RLAgent,
33) -> Result<ContentExtraction> {
34 let baseline = BaselineExtractor::new(config.stopwords.clone());
35 let mut env = ArticleExtractionEnvironment::new(baseline, config.clone());
36 let mut state = env.reset(html, url.to_string(), None, None)?;
37
38 let mut best = ContentExtraction { text: String::new(), xpath: String::new(), method: "rl" };
39 let mut best_quality = f32::MIN;
40 let mut done = false;
41 let mut steps = 0;
42
43 while !done && steps < config.max_steps_per_episode {
44 let action = agent.select_action(&state, 0.0)?; let (next_state, _reward, is_done, info) = env.step(action)?;
46 if info.quality_score > best_quality && !info.text.trim().is_empty() {
47 best_quality = info.quality_score;
48 best.text = info.text.clone();
49 best.xpath = info.xpath.clone();
50 }
51 state = next_state;
52 done = is_done;
53 steps += 1;
54 }
55
56 Ok(best)
57}
58
59fn hybrid_extract(html: &str, config: &Config) -> Result<ContentExtraction> {
62 let extractor = HybridExtractor::heuristic(config.stopwords.clone());
63 match extractor.extract(html, config.num_candidate_nodes, &ExtractionParams::default())? {
64 Some(e) => Ok(ContentExtraction { text: e.text, xpath: e.xpath, method: "hybrid" }),
65 None => Ok(ContentExtraction { text: String::new(), xpath: String::new(), method: "hybrid" }),
66 }
67}
68
69pub fn extract_article(
85 html: &str,
86 url: &str,
87 config: &Config,
88 agent: Option<&dyn RLAgent>,
89) -> Result<ExtractedArticle> {
90 let baseline_extractor = BaselineExtractor::new(config.stopwords.clone());
91 let baseline_result = baseline_extractor.extract(html)?;
92 let content = extract_content(html, url, config, agent, None, &baseline_result)?;
93 let quality_score = TextUtils::calculate_text_quality(&content.text, &config.stopwords);
94
95 Ok(ExtractedArticle {
96 url: url.to_string(),
97 title: baseline_result.title,
98 date: baseline_result.date,
99 content: content.text,
100 quality_score,
101 method: content.method.to_string(),
102 xpath: Some(content.xpath),
103 })
104}
105
106pub fn extract_article_hybrid(
111 html: &str,
112 url: &str,
113 config: &Config,
114 hybrid: &HybridExtractor,
115) -> Result<ExtractedArticle> {
116 let baseline_extractor = BaselineExtractor::new(config.stopwords.clone());
117 let baseline_result = baseline_extractor.extract(html)?;
118
119 let (text, xpath, method) =
120 match hybrid.extract(html, config.num_candidate_nodes, &ExtractionParams::default())? {
121 Some(e) if !e.text.trim().is_empty() => (e.text, e.xpath, "classifier"),
122 _ => (
123 baseline_result.text.clone(),
124 baseline_result.xpath.clone(),
125 "baseline",
126 ),
127 };
128
129 let quality_score = TextUtils::calculate_text_quality(&text, &config.stopwords);
130
131 Ok(ExtractedArticle {
132 url: url.to_string(),
133 title: baseline_result.title,
134 date: baseline_result.date,
135 content: text,
136 quality_score,
137 method: method.to_string(),
138 xpath: Some(xpath),
139 })
140}
141
142fn extract_content(
146 html: &str,
147 url: &str,
148 config: &Config,
149 agent: Option<&dyn RLAgent>,
150 hybrid: Option<&HybridExtractor>,
151 baseline_result: &crate::site_profile::ExtractionResult,
152) -> Result<ContentExtraction> {
153 let primary = if let Some(agent) = agent {
154 rl_agent_extract(html, url, config, agent)?
155 } else if let Some(hybrid) = hybrid {
156 match hybrid.extract(html, config.num_candidate_nodes, &ExtractionParams::default())? {
157 Some(e) => ContentExtraction { text: e.text, xpath: e.xpath, method: "classifier" },
158 None => ContentExtraction { text: String::new(), xpath: String::new(), method: "classifier" },
159 }
160 } else {
161 hybrid_extract(html, config)?
162 };
163
164 if primary.text.trim().is_empty() {
165 Ok(ContentExtraction {
167 text: baseline_result.text.clone(),
168 xpath: baseline_result.xpath.clone(),
169 method: "baseline",
170 })
171 } else {
172 Ok(primary)
173 }
174}
175
176pub fn extract_single(
181 html_file: &Path,
182 url: String,
183 model_path: Option<&Path>,
184 classifier_path: Option<&Path>,
185 output: Option<&Path>,
186 config: &Config,
187) -> Result<ExtractedArticle> {
188 let html_content = read_html_file(html_file)?;
189
190 let article = if let Some(model_path) = model_path {
191 let device = get_device();
193 let agent = AgentFactory::load(
194 model_path,
195 config.state_dim,
196 config.num_discrete_actions,
197 config.num_continuous_params,
198 &device,
199 )?;
200 extract_article(&html_content, &url, config, Some(agent.as_ref()))?
201 } else if let Some(classifier_path) = classifier_path {
202 let device = get_device();
204 let classifier = NodeClassifier::load(classifier_path, &device, config.learning_rate)?;
205 let hybrid = HybridExtractor::with_classifier(classifier, config.stopwords.clone());
206 extract_article_hybrid(&html_content, &url, config, &hybrid)?
207 } else {
208 extract_article(&html_content, &url, config, None)?
210 };
211
212 if let Some(output_path) = output {
213 let batch_result = BatchExtractionResult {
214 articles: vec![article.clone()],
215 };
216 let json = serde_json::to_string_pretty(&batch_result)?;
217 std::fs::write(output_path, json)?;
218 }
219
220 Ok(article)
221}
222
223pub fn extract_batch(
225 archive_dir: &Path,
226 model_path: Option<&Path>,
227 classifier_path: Option<&Path>,
228 output_dir: &Path,
229 max_files: Option<usize>,
230 _batch_size: usize,
231 config: &Config,
232) -> Result<BatchExtractionResult> {
233 std::fs::create_dir_all(output_dir)?;
234
235 let file_pairs = load_html_files_recursive(archive_dir, max_files)?;
236
237 if file_pairs.is_empty() {
238 return Err(ExtractionError::ExtractionFailed(
239 "No HTML files found".to_string()
240 ));
241 }
242 let count_of_files: usize = file_pairs.len();
243
244 tracing::info!("Found {} HTML/JSON file pairs", count_of_files);
245
246 let baseline_extractor = BaselineExtractor::new(config.stopwords.clone());
247 let device = get_device();
248 let agent = if let Some(path) = model_path {
249 Some(AgentFactory::load(
250 path,
251 config.state_dim,
252 config.num_discrete_actions,
253 config.num_continuous_params,
254 &device,
255 )?)
256 } else {
257 None
258 };
259
260 let hybrid = if agent.is_none() {
262 if let Some(path) = classifier_path {
263 let classifier = NodeClassifier::load(path, &device, config.learning_rate)?;
264 Some(HybridExtractor::with_classifier(classifier, config.stopwords.clone()))
265 } else {
266 None
267 }
268 } else {
269 None
270 };
271
272 let mut site_memory = SiteProfileMemory::new(&config.site_profiles_dir)?;
274
275 let pb = ProgressBar::new(count_of_files as u64);
276 pb.set_style(
277 ProgressStyle::default_bar()
278 .template("[{elapsed_precise}] {bar:40.cyan/blue} {pos}/{len} {msg}")
279 .unwrap()
280 .progress_chars("=>-"),
281 );
282
283 let mut all_articles = Vec::new();
284 let mut failed = Vec::new();
285 let mut site_profile_used_count = 0;
286
287 for (html_path, json_path) in file_pairs {
288 let url = read_url_from_json(&json_path);
289 let domain = extract_domain_from_url(&url);
290
291 let html_content = match read_html_file(&html_path) {
292 Ok(content) => content,
293 Err(e) => {
294 failed.push((url, e.to_string()));
295 pb.inc(1);
296 continue;
297 }
298 };
299
300 let site_profile = site_memory.get_profile(&domain);
302 let has_profile = site_profile.extractions.len() > 5;
303
304 if has_profile {
305 site_profile_used_count += 1;
306 }
307
308 match baseline_extractor.extract(&html_content) {
309 Ok(result) => {
310 let content = match extract_content(
313 &html_content, &url, config, agent.as_deref(), hybrid.as_ref(), &result,
314 ) {
315 Ok(c) => c,
316 Err(e) => {
317 failed.push((url, e.to_string()));
318 pb.inc(1);
319 continue;
320 }
321 };
322
323 let method = if has_profile {
324 format!("{}+profile", content.method)
325 } else {
326 content.method.to_string()
327 };
328
329 let quality_score =
330 TextUtils::calculate_text_quality(&content.text, &config.stopwords);
331
332 let article = ExtractedArticle {
333 url: url.clone(),
334 title: result.title.clone(),
335 date: result.date.clone(),
336 content: content.text.clone(),
337 quality_score,
338 method,
339 xpath: Some(content.xpath.clone()),
340 };
341
342 let extraction_result = site_profile::ExtractionResult {
344 text: content.text,
345 xpath: content.xpath,
346 quality_score,
347 parameters: result.parameters,
348 title: result.title,
349 date: result.date,
350 };
351 site_profile.add_extraction(extraction_result);
352
353 all_articles.push(article);
354 }
355 Err(e) => {
356 failed.push((url, e.to_string()));
357 }
358 }
359 pb.inc(1);
360 }
361
362 pb.finish_with_message("Batch extraction complete");
363
364 site_memory.save_all()?;
366 tracing::info!("Site profiles saved ({} domains used profiles)", site_profile_used_count);
367
368 let timestamp = chrono::Utc::now().format("%Y%m%d_%H%M%S");
370 let results_path = output_dir.join(format!("batch_results_{}.json", timestamp));
371 let batch_result = BatchExtractionResult { articles: all_articles.clone() };
372 let json = serde_json::to_string_pretty(&batch_result)?;
373 std::fs::write(&results_path, json)?;
374
375 if !failed.is_empty() {
377 let failed_path = output_dir.join(format!("failed_{}.json", timestamp));
378 let failed_json = serde_json::to_string_pretty(&failed)?;
379 std::fs::write(&failed_path, failed_json)?;
380 tracing::warn!("Failed extractions saved to: {}", failed_path.display());
381 }
382
383 tracing::info!("Batch extraction: {}/{} successful, {} with site profiles",
384 all_articles.len(), count_of_files, site_profile_used_count);
385
386 Ok(batch_result)
387}
388
389pub fn extract_domain_from_url(url: &str) -> String {
391 match Url::parse(url) {
392 Ok(parsed_url) => {
393 parsed_url.host_str()
394 .map(|h| h.to_string())
395 .unwrap_or_else(|| "unknown".to_string())
396 }
397 Err(_) => {
398 let url = url.trim();
399 let without_protocol = url.strip_prefix("https://")
400 .or_else(|| url.strip_prefix("http://"))
401 .unwrap_or(url);
402
403 let host_part = without_protocol.split('/').next().unwrap_or("");
404 let domain = host_part.split(':').next().unwrap_or("");
405
406 if domain.is_empty() {
407 "unknown".to_string()
408 } else {
409 domain.to_string()
410 }
411 }
412 }
413}
414
415pub fn load_html_files_recursive(
417 dir: &Path,
418 max_files: Option<usize>,
419) -> Result<Vec<(PathBuf, PathBuf)>> {
420 use walkdir::WalkDir;
421
422 let mut files = Vec::new();
423
424 for entry in WalkDir::new(dir).into_iter().filter_map(|e| e.ok()) {
425 if let Some(max) = max_files {
426 if files.len() >= max {
427 break;
428 }
429 }
430
431 let path = entry.path();
432 if path.is_file() {
433 if let Some(ext) = path.extension() {
434 if ext == "bz2" && path.to_string_lossy().contains(".html.") {
435 let json_path = path.with_extension("").with_extension("json");
436 if json_path.exists() {
437 files.push((path.to_path_buf(), json_path));
438 }
439 } else if ext == "html" || ext == "htm" {
440 let json_path = path.with_extension("json");
441 if json_path.exists() {
442 files.push((path.to_path_buf(), json_path));
443 }
444 }
445 }
446 }
447 }
448
449 Ok(files)
450}
451
452pub fn read_html_file(path: &Path) -> Result<String> {
454 if path.extension().and_then(|s| s.to_str()) == Some("bz2") {
455 let file = std::fs::File::open(path)?;
456 let mut decoder = BzDecoder::new(file);
457 let mut bytes = Vec::new();
458 decoder.read_to_end(&mut bytes)?;
459 Ok(String::from_utf8_lossy(&bytes).into_owned())
460 } else {
461 let bytes = std::fs::read(path)?;
462 Ok(String::from_utf8_lossy(&bytes).into_owned())
463 }
464}
465
466pub fn read_url_from_json(json_path: &Path) -> String {
468 match std::fs::read_to_string(json_path) {
469 Ok(json_content) => {
470 if let Ok(json_value) = serde_json::from_str::<serde_json::Value>(&json_content) {
471 json_value.get("URL")
472 .and_then(|u| u.as_str())
473 .map(|s| s.to_string())
474 .unwrap_or_else(|| "https://example.com/unknown".to_string())
475 } else {
476 "https://example.com/invalid-json".to_string()
477 }
478 }
479 Err(_) => "https://example.com/no-json".to_string(),
480 }
481}
482
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// Scheme, port and path are stripped; only the host remains.
    #[test]
    fn test_extract_domain_from_url() {
        let cases = [
            ("https://www.example.com/article", "www.example.com"),
            (
                "http://subdomain.example.org:8080/path",
                "subdomain.example.org",
            ),
        ];

        for &(input, expected) in cases.iter() {
            assert_eq!(extract_domain_from_url(input), expected);
        }
    }

    /// The `"URL"` field of a well-formed sidecar file is returned verbatim.
    #[test]
    fn test_read_url_from_json() {
        let scratch = TempDir::new().unwrap();
        let sidecar = scratch.path().join("test.json");
        std::fs::write(&sidecar, r#"{"URL": "https://example.com/article"}"#).unwrap();

        assert_eq!(read_url_from_json(&sidecar), "https://example.com/article");
    }

    /// With neither a model nor a classifier, the heuristic hybrid extractor
    /// is used and navigation/footer boilerplate stays out of the content.
    #[test]
    fn test_extract_single_uses_hybrid_without_model() {
        let scratch = TempDir::new().unwrap();
        let page = scratch.path().join("page.html");
        let markup = r#"
            <html><head><title>Mission Update</title></head><body>
            <nav class="site-nav"><a href="/a">Home</a> <a href="/b">News</a> <a href="/c">More</a></nav>
            <article class="article-body">
            <p>The spacecraft entered orbit today after a long interplanetary cruise phase.</p>
            <p>Mission controllers confirmed every instrument survived the journey intact.</p>
            <p>Scientific observations of the planet will begin within the next two weeks.</p>
            </article>
            <div class="footer-links"><a href="/p">Privacy</a> <a href="/t">Terms</a></div>
            </body></html>
            "#;
        std::fs::write(&page, markup).unwrap();

        let config = Config::default();
        let extracted = extract_single(
            &page,
            "https://example.com/mission".to_string(),
            None,
            None,
            None,
            &config,
        )
        .unwrap();

        assert_eq!(extracted.method, "hybrid");
        assert!(extracted.content.contains("entered orbit"), "content: {}", extracted.content);
        assert!(!extracted.content.contains("Privacy"), "footer leaked: {}", extracted.content);
    }
}