Skip to main content

content_extractor_rl/
cli_utils.rs

1//! High-level command interface for CLI
2//! This module contains the main logic for each CLI command
3// ============================================================================
4// FILE: crates/content-extractor-rl/src/cli_utils.rs
5// ============================================================================
6
7use crate::*;
8use std::path::{Path, PathBuf};
9use bzip2::read::BzDecoder;
10use std::io::Read;
11use indicatif::{ProgressBar, ProgressStyle};
12use url::Url;
13use crate::node_classifier::{HybridExtractor, NodeClassifier};
14use crate::node_features::ExtractionParams;
15use crate::text_utils::TextUtils;
16
/// Result of a body-only extraction: the text plus where it came from.
struct ContentExtraction {
    /// Extracted body text; empty when the selector produced nothing.
    text: String,
    /// XPath of the node the text was taken from.
    xpath: String,
    /// Label of the strategy that produced this result
    /// (`"rl"`, `"hybrid"`, `"classifier"` or `"baseline"` in this module).
    method: &'static str,
}
23
24/// Run a trained RL agent greedily through the environment for one page and
25/// return the highest-quality extraction it produced. This is the real RL
26/// inference path: the agent observes the page's DOM features, picks a content
27/// node and tunes the extraction params, and we keep the best step.
28fn rl_agent_extract(
29    html: &str,
30    url: &str,
31    config: &Config,
32    agent: &dyn RLAgent,
33) -> Result<ContentExtraction> {
34    let baseline = BaselineExtractor::new(config.stopwords.clone());
35    let mut env = ArticleExtractionEnvironment::new(baseline, config.clone());
36    let mut state = env.reset(html, url.to_string(), None, None)?;
37
38    let mut best = ContentExtraction { text: String::new(), xpath: String::new(), method: "rl" };
39    let mut best_quality = f32::MIN;
40    let mut done = false;
41    let mut steps = 0;
42
43    while !done && steps < config.max_steps_per_episode {
44        let action = agent.select_action(&state, 0.0)?; // greedy (epsilon = 0)
45        let (next_state, _reward, is_done, info) = env.step(action)?;
46        if info.quality_score > best_quality && !info.text.trim().is_empty() {
47            best_quality = info.quality_score;
48            best.text = info.text.clone();
49            best.xpath = info.xpath.clone();
50        }
51        state = next_state;
52        done = is_done;
53        steps += 1;
54    }
55
56    Ok(best)
57}
58
59/// Pick the content node with the supervised/heuristic [`HybridExtractor`] and
60/// extract its text. Used when no trained RL model is supplied.
61fn hybrid_extract(html: &str, config: &Config) -> Result<ContentExtraction> {
62    let extractor = HybridExtractor::heuristic(config.stopwords.clone());
63    match extractor.extract(html, config.num_candidate_nodes, &ExtractionParams::default())? {
64        Some(e) => Ok(ContentExtraction { text: e.text, xpath: e.xpath, method: "hybrid" }),
65        None => Ok(ContentExtraction { text: String::new(), xpath: String::new(), method: "hybrid" }),
66    }
67}
68
69/// Extract a complete article (title/date metadata + body) from one page.
70///
71/// Body selection prefers, in order: the trained RL `agent` if supplied, then
72/// the hybrid/heuristic node selector, then the plain baseline. Title and date
73/// always come from the baseline metadata extractor. This is the single shared
74/// entry point used by the CLI and the Python bindings.
75///
76/// ```no_run
77/// use content_extractor_rl::{Config, extract_article};
78/// let config = Config::default();
79/// let html = std::fs::read_to_string("page.html").unwrap();
80/// // No model -> hybrid heuristic selection (no training required):
81/// let article = extract_article(&html, "https://example.com/post", &config, None).unwrap();
82/// println!("{}", article.content);
83/// ```
84pub fn extract_article(
85    html: &str,
86    url: &str,
87    config: &Config,
88    agent: Option<&dyn RLAgent>,
89) -> Result<ExtractedArticle> {
90    let baseline_extractor = BaselineExtractor::new(config.stopwords.clone());
91    let baseline_result = baseline_extractor.extract(html)?;
92    let content = extract_content(html, url, config, agent, None, &baseline_result)?;
93    let quality_score = TextUtils::calculate_text_quality(&content.text, &config.stopwords);
94
95    Ok(ExtractedArticle {
96        url: url.to_string(),
97        title: baseline_result.title,
98        date: baseline_result.date,
99        content: content.text,
100        quality_score,
101        method: content.method.to_string(),
102        xpath: Some(content.xpath),
103    })
104}
105
106/// Extract a complete article using a prepared [`HybridExtractor`] (which may be
107/// backed by a trained `NodeClassifier` or the heuristic). Title/date come from
108/// the baseline metadata extractor; falls back to the baseline body if the
109/// hybrid selector returns nothing.
110pub fn extract_article_hybrid(
111    html: &str,
112    url: &str,
113    config: &Config,
114    hybrid: &HybridExtractor,
115) -> Result<ExtractedArticle> {
116    let baseline_extractor = BaselineExtractor::new(config.stopwords.clone());
117    let baseline_result = baseline_extractor.extract(html)?;
118
119    let (text, xpath, method) =
120        match hybrid.extract(html, config.num_candidate_nodes, &ExtractionParams::default())? {
121            Some(e) if !e.text.trim().is_empty() => (e.text, e.xpath, "classifier"),
122            _ => (
123                baseline_result.text.clone(),
124                baseline_result.xpath.clone(),
125                "baseline",
126            ),
127        };
128
129    let quality_score = TextUtils::calculate_text_quality(&text, &config.stopwords);
130
131    Ok(ExtractedArticle {
132        url: url.to_string(),
133        title: baseline_result.title,
134        date: baseline_result.date,
135        content: text,
136        quality_score,
137        method: method.to_string(),
138        xpath: Some(xpath),
139    })
140}
141
142/// Extract the article body for one page, preferring (in order): the RL agent if
143/// supplied, then the hybrid/heuristic node selector, then the plain baseline.
144/// Title/date always come from the baseline metadata extractor.
145fn extract_content(
146    html: &str,
147    url: &str,
148    config: &Config,
149    agent: Option<&dyn RLAgent>,
150    hybrid: Option<&HybridExtractor>,
151    baseline_result: &crate::site_profile::ExtractionResult,
152) -> Result<ContentExtraction> {
153    let primary = if let Some(agent) = agent {
154        rl_agent_extract(html, url, config, agent)?
155    } else if let Some(hybrid) = hybrid {
156        match hybrid.extract(html, config.num_candidate_nodes, &ExtractionParams::default())? {
157            Some(e) => ContentExtraction { text: e.text, xpath: e.xpath, method: "classifier" },
158            None => ContentExtraction { text: String::new(), xpath: String::new(), method: "classifier" },
159        }
160    } else {
161        hybrid_extract(html, config)?
162    };
163
164    if primary.text.trim().is_empty() {
165        // Fall back to the baseline body so we never return nothing.
166        Ok(ContentExtraction {
167            text: baseline_result.text.clone(),
168            xpath: baseline_result.xpath.clone(),
169            method: "baseline",
170        })
171    } else {
172        Ok(primary)
173    }
174}
175
176/// Extract article from single HTML file.
177///
178/// Selection priority: RL `model_path` (if any) → trained `classifier_path`
179/// (if any) → hybrid heuristic.
180pub fn extract_single(
181    html_file: &Path,
182    url: String,
183    model_path: Option<&Path>,
184    classifier_path: Option<&Path>,
185    output: Option<&Path>,
186    config: &Config,
187) -> Result<ExtractedArticle> {
188    let html_content = read_html_file(html_file)?;
189
190    let article = if let Some(model_path) = model_path {
191        // RL agent (any algorithm — auto-detected).
192        let device = get_device();
193        let agent = AgentFactory::load(
194            model_path,
195            config.state_dim,
196            config.num_discrete_actions,
197            config.num_continuous_params,
198            &device,
199        )?;
200        extract_article(&html_content, &url, config, Some(agent.as_ref()))?
201    } else if let Some(classifier_path) = classifier_path {
202        // Supervised node classifier.
203        let device = get_device();
204        let classifier = NodeClassifier::load(classifier_path, &device, config.learning_rate)?;
205        let hybrid = HybridExtractor::with_classifier(classifier, config.stopwords.clone());
206        extract_article_hybrid(&html_content, &url, config, &hybrid)?
207    } else {
208        // Hybrid heuristic (no model).
209        extract_article(&html_content, &url, config, None)?
210    };
211
212    if let Some(output_path) = output {
213        let batch_result = BatchExtractionResult {
214            articles: vec![article.clone()],
215        };
216        let json = serde_json::to_string_pretty(&batch_result)?;
217        std::fs::write(output_path, json)?;
218    }
219
220    Ok(article)
221}
222
223/// Extract batch of HTML files with site profile support
224pub fn extract_batch(
225    archive_dir: &Path,
226    model_path: Option<&Path>,
227    classifier_path: Option<&Path>,
228    output_dir: &Path,
229    max_files: Option<usize>,
230    _batch_size: usize,
231    config: &Config,
232) -> Result<BatchExtractionResult> {
233    std::fs::create_dir_all(output_dir)?;
234
235    let file_pairs = load_html_files_recursive(archive_dir, max_files)?;
236
237    if file_pairs.is_empty() {
238        return Err(ExtractionError::ExtractionFailed(
239            "No HTML files found".to_string()
240        ));
241    }
242    let count_of_files: usize = file_pairs.len();
243
244    tracing::info!("Found {} HTML/JSON file pairs", count_of_files);
245
246    let baseline_extractor = BaselineExtractor::new(config.stopwords.clone());
247    let device = get_device();
248    let agent = if let Some(path) = model_path {
249        Some(AgentFactory::load(
250            path,
251            config.state_dim,
252            config.num_discrete_actions,
253            config.num_continuous_params,
254            &device,
255        )?)
256    } else {
257        None
258    };
259
260    // A trained classifier is used only when no RL model is supplied.
261    let hybrid = if agent.is_none() {
262        if let Some(path) = classifier_path {
263            let classifier = NodeClassifier::load(path, &device, config.learning_rate)?;
264            Some(HybridExtractor::with_classifier(classifier, config.stopwords.clone()))
265        } else {
266            None
267        }
268    } else {
269        None
270    };
271
272    // Initialize site profile memory
273    let mut site_memory = SiteProfileMemory::new(&config.site_profiles_dir)?;
274
275    let pb = ProgressBar::new(count_of_files as u64);
276    pb.set_style(
277        ProgressStyle::default_bar()
278            .template("[{elapsed_precise}] {bar:40.cyan/blue} {pos}/{len} {msg}")
279            .unwrap()
280            .progress_chars("=>-"),
281    );
282
283    let mut all_articles = Vec::new();
284    let mut failed = Vec::new();
285    let mut site_profile_used_count = 0;
286
287    for (html_path, json_path) in file_pairs {
288        let url = read_url_from_json(&json_path);
289        let domain = extract_domain_from_url(&url);
290
291        let html_content = match read_html_file(&html_path) {
292            Ok(content) => content,
293            Err(e) => {
294                failed.push((url, e.to_string()));
295                pb.inc(1);
296                continue;
297            }
298        };
299
300        // Get site profile for this domain
301        let site_profile = site_memory.get_profile(&domain);
302        let has_profile = site_profile.extractions.len() > 5;
303
304        if has_profile {
305            site_profile_used_count += 1;
306        }
307
308        match baseline_extractor.extract(&html_content) {
309            Ok(result) => {
310                // Select the content node with the RL agent (if loaded) or the
311                // hybrid heuristic; metadata comes from the baseline result.
312                let content = match extract_content(
313                    &html_content, &url, config, agent.as_deref(), hybrid.as_ref(), &result,
314                ) {
315                    Ok(c) => c,
316                    Err(e) => {
317                        failed.push((url, e.to_string()));
318                        pb.inc(1);
319                        continue;
320                    }
321                };
322
323                let method = if has_profile {
324                    format!("{}+profile", content.method)
325                } else {
326                    content.method.to_string()
327                };
328
329                let quality_score =
330                    TextUtils::calculate_text_quality(&content.text, &config.stopwords);
331
332                let article = ExtractedArticle {
333                    url: url.clone(),
334                    title: result.title.clone(),
335                    date: result.date.clone(),
336                    content: content.text.clone(),
337                    quality_score,
338                    method,
339                    xpath: Some(content.xpath.clone()),
340                };
341
342                // Update site profile with this extraction
343                let extraction_result = site_profile::ExtractionResult {
344                    text: content.text,
345                    xpath: content.xpath,
346                    quality_score,
347                    parameters: result.parameters,
348                    title: result.title,
349                    date: result.date,
350                };
351                site_profile.add_extraction(extraction_result);
352
353                all_articles.push(article);
354            }
355            Err(e) => {
356                failed.push((url, e.to_string()));
357            }
358        }
359        pb.inc(1);
360    }
361
362    pb.finish_with_message("Batch extraction complete");
363
364    // Save site profiles
365    site_memory.save_all()?;
366    tracing::info!("Site profiles saved ({} domains used profiles)", site_profile_used_count);
367
368    // Save results
369    let timestamp = chrono::Utc::now().format("%Y%m%d_%H%M%S");
370    let results_path = output_dir.join(format!("batch_results_{}.json", timestamp));
371    let batch_result = BatchExtractionResult { articles: all_articles.clone() };
372    let json = serde_json::to_string_pretty(&batch_result)?;
373    std::fs::write(&results_path, json)?;
374
375    // Save failed extractions
376    if !failed.is_empty() {
377        let failed_path = output_dir.join(format!("failed_{}.json", timestamp));
378        let failed_json = serde_json::to_string_pretty(&failed)?;
379        std::fs::write(&failed_path, failed_json)?;
380        tracing::warn!("Failed extractions saved to: {}", failed_path.display());
381    }
382
383    tracing::info!("Batch extraction: {}/{} successful, {} with site profiles",
384                   all_articles.len(), count_of_files, site_profile_used_count);
385
386    Ok(batch_result)
387}
388
389/// Extract domain from URL
390pub fn extract_domain_from_url(url: &str) -> String {
391    match Url::parse(url) {
392        Ok(parsed_url) => {
393            parsed_url.host_str()
394                .map(|h| h.to_string())
395                .unwrap_or_else(|| "unknown".to_string())
396        }
397        Err(_) => {
398            let url = url.trim();
399            let without_protocol = url.strip_prefix("https://")
400                .or_else(|| url.strip_prefix("http://"))
401                .unwrap_or(url);
402
403            let host_part = without_protocol.split('/').next().unwrap_or("");
404            let domain = host_part.split(':').next().unwrap_or("");
405
406            if domain.is_empty() {
407                "unknown".to_string()
408            } else {
409                domain.to_string()
410            }
411        }
412    }
413}
414
415/// Load HTML files recursively
416pub fn load_html_files_recursive(
417    dir: &Path,
418    max_files: Option<usize>,
419) -> Result<Vec<(PathBuf, PathBuf)>> {
420    use walkdir::WalkDir;
421
422    let mut files = Vec::new();
423
424    for entry in WalkDir::new(dir).into_iter().filter_map(|e| e.ok()) {
425        if let Some(max) = max_files {
426            if files.len() >= max {
427                break;
428            }
429        }
430
431        let path = entry.path();
432        if path.is_file() {
433            if let Some(ext) = path.extension() {
434                if ext == "bz2" && path.to_string_lossy().contains(".html.") {
435                    let json_path = path.with_extension("").with_extension("json");
436                    if json_path.exists() {
437                        files.push((path.to_path_buf(), json_path));
438                    }
439                } else if ext == "html" || ext == "htm" {
440                    let json_path = path.with_extension("json");
441                    if json_path.exists() {
442                        files.push((path.to_path_buf(), json_path));
443                    }
444                }
445            }
446        }
447    }
448
449    Ok(files)
450}
451
452/// Read HTML file with UTF-8 error handling
453pub fn read_html_file(path: &Path) -> Result<String> {
454    if path.extension().and_then(|s| s.to_str()) == Some("bz2") {
455        let file = std::fs::File::open(path)?;
456        let mut decoder = BzDecoder::new(file);
457        let mut bytes = Vec::new();
458        decoder.read_to_end(&mut bytes)?;
459        Ok(String::from_utf8_lossy(&bytes).into_owned())
460    } else {
461        let bytes = std::fs::read(path)?;
462        Ok(String::from_utf8_lossy(&bytes).into_owned())
463    }
464}
465
466/// Read URL from JSON file
467pub fn read_url_from_json(json_path: &Path) -> String {
468    match std::fs::read_to_string(json_path) {
469        Ok(json_content) => {
470            if let Ok(json_value) = serde_json::from_str::<serde_json::Value>(&json_content) {
471                json_value.get("URL")
472                    .and_then(|u| u.as_str())
473                    .map(|s| s.to_string())
474                    .unwrap_or_else(|| "https://example.com/unknown".to_string())
475            } else {
476                "https://example.com/invalid-json".to_string()
477            }
478        }
479        Err(_) => "https://example.com/no-json".to_string(),
480    }
481}
482
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// Hosts are extracted from absolute URLs, with path and port removed.
    #[test]
    fn test_extract_domain_from_url() {
        let cases = [
            ("https://www.example.com/article", "www.example.com"),
            ("http://subdomain.example.org:8080/path", "subdomain.example.org"),
        ];

        for &(input, expected) in cases.iter() {
            assert_eq!(extract_domain_from_url(input), expected);
        }
    }

    /// The `"URL"` field of a well-formed side-car file is returned verbatim.
    #[test]
    fn test_read_url_from_json() {
        let workspace = TempDir::new().unwrap();
        let sidecar = workspace.path().join("test.json");
        std::fs::write(&sidecar, r#"{"URL": "https://example.com/article"}"#).unwrap();

        assert_eq!(read_url_from_json(&sidecar), "https://example.com/article");
    }

    /// With neither an RL model nor a classifier, `extract_single` must go
    /// through the hybrid heuristic and pick the article body.
    #[test]
    fn test_extract_single_uses_hybrid_without_model() {
        let page = r#"
            <html><head><title>Mission Update</title></head><body>
                <nav class="site-nav"><a href="/a">Home</a> <a href="/b">News</a> <a href="/c">More</a></nav>
                <article class="article-body">
                    <p>The spacecraft entered orbit today after a long interplanetary cruise phase.</p>
                    <p>Mission controllers confirmed every instrument survived the journey intact.</p>
                    <p>Scientific observations of the planet will begin within the next two weeks.</p>
                </article>
                <div class="footer-links"><a href="/p">Privacy</a> <a href="/t">Terms</a></div>
            </body></html>
            "#;

        let workspace = TempDir::new().unwrap();
        let html_path = workspace.path().join("page.html");
        std::fs::write(&html_path, page).unwrap();

        let config = Config::default();
        let no_model: Option<&Path> = None;
        let no_classifier: Option<&Path> = None;
        let no_output: Option<&Path> = None;
        let article = extract_single(
            &html_path,
            "https://example.com/mission".to_string(),
            no_model,
            no_classifier, // -> hybrid heuristic path
            no_output,
            &config,
        )
        .unwrap();

        // The hybrid selector must return the article body, not nav/footer noise.
        assert_eq!(article.method, "hybrid");
        assert!(article.content.contains("entered orbit"), "content: {}", article.content);
        assert!(!article.content.contains("Privacy"), "footer leaked: {}", article.content);
    }
}