// content_extractor_rl/cli_utils.rs
1//! High-level command interface for CLI
2//! This module contains the main logic for each CLI command
3// ============================================================================
4// FILE: crates/content-extractor-rl/src/cli_utils.rs
5// ============================================================================
6
7use crate::*;
8use std::path::{Path, PathBuf};
9use bzip2::read::BzDecoder;
10use std::io::Read;
11use indicatif::{ProgressBar, ProgressStyle};
12use url::Url;
13use crate::agents::dqn_agent::DQNAgent;
14
15/// Extract article from single HTML file
16pub fn extract_single(
17    html_file: &Path,
18    url: String,
19    model_path: Option<&Path>,
20    output: Option<&Path>,
21    config: &Config,
22) -> Result<ExtractedArticle> {
23    let html_content = read_html_file(html_file)?;
24    let baseline_extractor = BaselineExtractor::new(config.stopwords.clone());
25
26    // Extract domain for site profile
27    let domain = extract_domain_from_url(&url);
28
29    // Try to load site profile
30    let mut site_memory = SiteProfileMemory::new(&config.site_profiles_dir)?;
31    let site_profile = site_memory.get_profile(&domain);
32
33    let result = if let Some(model_path) = model_path {
34        let device = get_device();
35        let _agent = DQNAgent::load_with_device(
36            model_path,
37            config.state_dim,
38            config.num_discrete_actions,
39            config.num_continuous_params,
40            &device,
41        )?;
42
43        // Use site profile if available for better extraction
44        if site_profile.extractions.len() > 5 {
45            tracing::debug!("Using site profile for {} (has {} past extractions)",
46                          domain, site_profile.extractions.len());
47        }
48
49        baseline_extractor.extract(&html_content)?
50    } else {
51        baseline_extractor.extract(&html_content)?
52    };
53
54    let article = ExtractedArticle {
55        url: url.clone(),
56        title: result.title,
57        date: result.date,
58        content: result.text,
59        quality_score: result.quality_score,
60        method: if model_path.is_some() { "rl" } else { "baseline" }.to_string(),
61        xpath: Some(result.xpath),
62    };
63
64    if let Some(output_path) = output {
65        let batch_result = BatchExtractionResult {
66            articles: vec![article.clone()],
67        };
68        let json = serde_json::to_string_pretty(&batch_result)?;
69        std::fs::write(output_path, json)?;
70    }
71
72    Ok(article)
73}
74
75/// Extract batch of HTML files with site profile support
76pub fn extract_batch(
77    archive_dir: &Path,
78    model_path: Option<&Path>,
79    output_dir: &Path,
80    max_files: Option<usize>,
81    _batch_size: usize,
82    config: &Config,
83) -> Result<BatchExtractionResult> {
84    std::fs::create_dir_all(output_dir)?;
85
86    let file_pairs = load_html_files_recursive(archive_dir, max_files)?;
87
88    if file_pairs.is_empty() {
89        return Err(ExtractionError::ExtractionFailed(
90            "No HTML files found".to_string()
91        ));
92    }
93    let count_of_files: usize = file_pairs.len();
94
95    tracing::info!("Found {} HTML/JSON file pairs", count_of_files);
96
97    let baseline_extractor = BaselineExtractor::new(config.stopwords.clone());
98    let device = get_device();
99    let agent = if let Some(path) = model_path {
100        Some(DQNAgent::load_with_device(
101            path,
102            config.state_dim,
103            config.num_discrete_actions,
104            config.num_continuous_params,
105            &device,
106        )?)
107    } else {
108        None
109    };
110
111    // Initialize site profile memory
112    let mut site_memory = SiteProfileMemory::new(&config.site_profiles_dir)?;
113
114    let pb = ProgressBar::new(count_of_files as u64);
115    pb.set_style(
116        ProgressStyle::default_bar()
117            .template("[{elapsed_precise}] {bar:40.cyan/blue} {pos}/{len} {msg}")
118            .unwrap()
119            .progress_chars("=>-"),
120    );
121
122    let mut all_articles = Vec::new();
123    let mut failed = Vec::new();
124    let mut site_profile_used_count = 0;
125
126    for (html_path, json_path) in file_pairs {
127        let url = read_url_from_json(&json_path);
128        let domain = extract_domain_from_url(&url);
129
130        let html_content = match read_html_file(&html_path) {
131            Ok(content) => content,
132            Err(e) => {
133                failed.push((url, e.to_string()));
134                pb.inc(1);
135                continue;
136            }
137        };
138
139        // Get site profile for this domain
140        let site_profile = site_memory.get_profile(&domain);
141        let has_profile = site_profile.extractions.len() > 5;
142
143        if has_profile {
144            site_profile_used_count += 1;
145        }
146
147        match baseline_extractor.extract(&html_content) {
148            Ok(result) => {
149                let method = if agent.is_some() {
150                    if has_profile { "rl+profile" } else { "rl" }
151                } else if has_profile { "baseline+profile" } else { "baseline" };
152
153                let article = ExtractedArticle {
154                    url: url.clone(),
155                    title: result.title.clone(),
156                    date: result.date.clone(),
157                    content: result.text.clone(),
158                    quality_score: result.quality_score,
159                    method: method.to_string(),
160                    xpath: Some(result.xpath.clone()),
161                };
162
163                // Update site profile with this extraction
164                let extraction_result = site_profile::ExtractionResult {
165                    text: result.text,
166                    xpath: result.xpath,
167                    quality_score: result.quality_score,
168                    parameters: result.parameters,
169                    title: result.title,
170                    date: result.date,
171                };
172                site_profile.add_extraction(extraction_result);
173
174                all_articles.push(article);
175            }
176            Err(e) => {
177                failed.push((url, e.to_string()));
178            }
179        }
180        pb.inc(1);
181    }
182
183    pb.finish_with_message("Batch extraction complete");
184
185    // Save site profiles
186    site_memory.save_all()?;
187    tracing::info!("Site profiles saved ({} domains used profiles)", site_profile_used_count);
188
189    // Save results
190    let timestamp = chrono::Utc::now().format("%Y%m%d_%H%M%S");
191    let results_path = output_dir.join(format!("batch_results_{}.json", timestamp));
192    let batch_result = BatchExtractionResult { articles: all_articles.clone() };
193    let json = serde_json::to_string_pretty(&batch_result)?;
194    std::fs::write(&results_path, json)?;
195
196    // Save failed extractions
197    if !failed.is_empty() {
198        let failed_path = output_dir.join(format!("failed_{}.json", timestamp));
199        let failed_json = serde_json::to_string_pretty(&failed)?;
200        std::fs::write(&failed_path, failed_json)?;
201        tracing::warn!("Failed extractions saved to: {}", failed_path.display());
202    }
203
204    tracing::info!("Batch extraction: {}/{} successful, {} with site profiles",
205                   all_articles.len(), count_of_files, site_profile_used_count);
206
207    Ok(batch_result)
208}
209
210/// Extract domain from URL
211pub fn extract_domain_from_url(url: &str) -> String {
212    match Url::parse(url) {
213        Ok(parsed_url) => {
214            parsed_url.host_str()
215                .map(|h| h.to_string())
216                .unwrap_or_else(|| "unknown".to_string())
217        }
218        Err(_) => {
219            let url = url.trim();
220            let without_protocol = url.strip_prefix("https://")
221                .or_else(|| url.strip_prefix("http://"))
222                .unwrap_or(url);
223
224            let host_part = without_protocol.split('/').next().unwrap_or("");
225            let domain = host_part.split(':').next().unwrap_or("");
226
227            if domain.is_empty() {
228                "unknown".to_string()
229            } else {
230                domain.to_string()
231            }
232        }
233    }
234}
235
236/// Load HTML files recursively
237pub fn load_html_files_recursive(
238    dir: &Path,
239    max_files: Option<usize>,
240) -> Result<Vec<(PathBuf, PathBuf)>> {
241    use walkdir::WalkDir;
242
243    let mut files = Vec::new();
244
245    for entry in WalkDir::new(dir).into_iter().filter_map(|e| e.ok()) {
246        if let Some(max) = max_files {
247            if files.len() >= max {
248                break;
249            }
250        }
251
252        let path = entry.path();
253        if path.is_file() {
254            if let Some(ext) = path.extension() {
255                if ext == "bz2" && path.to_string_lossy().contains(".html.") {
256                    let json_path = path.with_extension("").with_extension("json");
257                    if json_path.exists() {
258                        files.push((path.to_path_buf(), json_path));
259                    }
260                } else if ext == "html" || ext == "htm" {
261                    let json_path = path.with_extension("json");
262                    if json_path.exists() {
263                        files.push((path.to_path_buf(), json_path));
264                    }
265                }
266            }
267        }
268    }
269
270    Ok(files)
271}
272
273/// Read HTML file with UTF-8 error handling
274pub fn read_html_file(path: &Path) -> Result<String> {
275    if path.extension().and_then(|s| s.to_str()) == Some("bz2") {
276        let file = std::fs::File::open(path)?;
277        let mut decoder = BzDecoder::new(file);
278        let mut bytes = Vec::new();
279        decoder.read_to_end(&mut bytes)?;
280        Ok(String::from_utf8_lossy(&bytes).into_owned())
281    } else {
282        let bytes = std::fs::read(path)?;
283        Ok(String::from_utf8_lossy(&bytes).into_owned())
284    }
285}
286
287/// Read URL from JSON file
288pub fn read_url_from_json(json_path: &Path) -> String {
289    match std::fs::read_to_string(json_path) {
290        Ok(json_content) => {
291            if let Ok(json_value) = serde_json::from_str::<serde_json::Value>(&json_content) {
292                json_value.get("URL")
293                    .and_then(|u| u.as_str())
294                    .map(|s| s.to_string())
295                    .unwrap_or_else(|| "https://example.com/unknown".to_string())
296            } else {
297                "https://example.com/invalid-json".to_string()
298            }
299        }
300        Err(_) => "https://example.com/no-json".to_string(),
301    }
302}
303
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// Table-driven check of both the parsed and fallback host paths.
    #[test]
    fn test_extract_domain_from_url() {
        let cases = [
            ("https://www.example.com/article", "www.example.com"),
            ("http://subdomain.example.org:8080/path", "subdomain.example.org"),
        ];
        for (input, expected) in cases {
            assert_eq!(extract_domain_from_url(input), expected);
        }
    }

    /// Happy path: a well-formed JSON file with a `"URL"` field.
    #[test]
    fn test_read_url_from_json() {
        let temp_dir = TempDir::new().unwrap();
        let json_path = temp_dir.path().join("test.json");

        std::fs::write(&json_path, r#"{"URL": "https://example.com/article"}"#).unwrap();

        assert_eq!(read_url_from_json(&json_path), "https://example.com/article");
    }
}