use content_extractor_rl::*;
use std::path::PathBuf;
use tempfile::TempDir;
use content_extractor_rl::curriculum::CurriculumManager;
use content_extractor_rl::html_parser::HtmlParser;
use content_extractor_rl::replay_buffer::PrioritizedReplayBuffer;
use content_extractor_rl::reward::ImprovedRewardCalculator;
use content_extractor_rl::text_utils::TextUtils;
#[test]
fn test_end_to_end_extraction() {
let html = r#"
<html>
<head><title>Test Article</title></head>
<body>
<nav>Navigation</nav>
<article>
<h1>Breaking News: Important Discovery</h1>
<p>Scientists have made a groundbreaking discovery in the field of
artificial intelligence that could revolutionize the way we interact
with technology.</p>
<p>The research team, led by experts from leading universities, developed
a novel approach that combines deep learning with reinforcement learning
to create more efficient and accurate models.</p>
<p>This breakthrough has significant implications for various industries
including healthcare, finance, and autonomous systems. Industry leaders
are already expressing interest in applying these findings.</p>
</article>
<footer>Copyright 2024</footer>
</body>
</html>
"#;
let config = Config::default();
let baseline_extractor = BaselineExtractor::new(config.stopwords.clone());
let result = baseline_extractor.extract(html).unwrap();
assert!(!result.text.is_empty(), "Should extract some text");
assert!(result.quality_score > 0.0, "Should have positive quality score");
assert!(result.text.contains("discovery"), "Should contain article content");
assert!(!result.text.contains("Navigation"), "Should not contain nav text");
assert!(!result.text.contains("Copyright"), "Should not contain footer text");
}
#[test]
fn test_baseline_extractor_quality_scoring() {
let good_html = r#"
<article>
<p>This is a well-written article with proper sentence structure.
It contains multiple paragraphs with substantial content and provides
detailed information on the topic at hand with appropriate vocabulary.</p>
<p>The article provides detailed information on the topic at hand.
Each paragraph contributes meaningfully to the overall narrative and
demonstrates good writing with proper punctuation and grammar.</p>
<p>Furthermore, the text maintains good lexical diversity and appropriate
punctuation throughout the entire piece. This ensures high quality content
that readers can appreciate and understand easily.</p>
</article>
"#;
let poor_html = "
<div>
<a href=\"\">Link</a>
<a href=\"#\">Link</a>
<a href=\"#\">Link</a>
Short text.
</div>
";
let config = Config::default();
let extractor = BaselineExtractor::new(config.stopwords);
let good_result = extractor.extract(good_html).unwrap();
let poor_result = extractor.extract(poor_html).unwrap();
println!("Good quality score: {}", good_result.quality_score);
println!("Poor quality score: {}", poor_result.quality_score);
assert!(
good_result.quality_score > poor_result.quality_score,
"Good article should score higher than poor one: {} vs {}",
good_result.quality_score, poor_result.quality_score
);
assert!(
good_result.quality_score > 0.3,
"Good article should score > 0.3, got {}",
good_result.quality_score
);
assert!(
poor_result.quality_score < 0.3,
"Poor article should score < 0.3, got {}",
poor_result.quality_score
);
}
#[test]
fn test_site_profile_persistence() {
let temp_dir = TempDir::new().unwrap();
let mut memory = SiteProfileMemory::new(temp_dir.path()).unwrap();
let profile = memory.get_profile("example.com");
let result = site_profile::ExtractionResult {
text: "Test article content".to_string(),
xpath: "//article[1]".to_string(),
quality_score: 0.85,
parameters: std::collections::HashMap::new(),
title: None,
date: None,
};
profile.add_extraction(result);
memory.save_profile("example.com").unwrap();
let mut new_memory = SiteProfileMemory::new(temp_dir.path()).unwrap();
let loaded_profile = new_memory.get_profile("example.com");
assert_eq!(loaded_profile.quality_scores.len(), 1);
assert_eq!(loaded_profile.quality_scores[0], 0.85);
}
#[test]
fn test_replay_buffer() {
let mut buffer = PrioritizedReplayBuffer::new(100, 0.6, 0.4);
for i in 0..50 {
let exp = replay_buffer::Experience {
state: vec![0.0; 300],
action: (i % 16, vec![0.0; 6]),
reward: (i as f32) / 50.0,
next_state: vec![0.0; 300],
done: false,
};
buffer.add(exp);
}
assert_eq!(buffer.len(), 50);
let batch = buffer.sample(32);
assert!(batch.is_some());
let batch = batch.unwrap();
assert_eq!(batch.experiences.len(), 32);
assert_eq!(batch.indices.len(), 32);
assert_eq!(batch.weights.len(), 32);
}
#[test]
fn test_curriculum_manager() {
let mut curriculum = CurriculumManager::new();
let initial_threshold = curriculum.get_threshold();
assert_eq!(initial_threshold, 0.3);
for episode in (0..1000).step_by(100) {
curriculum.update_threshold(episode);
}
let updated_threshold = curriculum.get_threshold();
assert!(updated_threshold > initial_threshold, "Threshold should increase");
assert!(updated_threshold <= 1.0, "Threshold should not exceed 1.0");
}
#[test]
fn test_reward_calculator() {
let config = Config::default();
let calculator = ImprovedRewardCalculator::new(config.stopwords);
let good_text = "This is an excellent article with proper structure and substantial content. \
It contains multiple well-formed sentences that provide valuable information \
to the reader. The text maintains good quality throughout with appropriate \
vocabulary and demonstrates clear communication. Furthermore, it includes \
diverse words and maintains coherent paragraphs with proper punctuation marks. \
Each sentence contributes meaningfully to the overall narrative and provides \
detailed explanations that help readers understand the topic thoroughly.";
let poor_text = "Short.";
let good_reward = calculator.calculate_reward(good_text, 0.0);
let poor_reward = calculator.calculate_reward(poor_text, 0.0);
println!("Good text reward: {}", good_reward);
println!("Poor text reward: {}", poor_reward);
assert!(
good_reward > poor_reward,
"Good text should have higher reward: {} vs {}",
good_reward, poor_reward
);
assert!(
good_reward > -0.5,
"Good text should have reward > -0.5, got {}",
good_reward
);
assert!(
poor_reward < 0.0,
"Poor text should have negative reward, got {}",
poor_reward
);
}
#[test]
fn test_model_checkpoint() {
let temp_dir = TempDir::new().unwrap();
let checkpoint = Checkpoint::new(
100,
5000,
0.5,
0.75,
0.8,
0.1,
PathBuf::from("model.onnx"),
);
let checkpoint_path = temp_dir.path().join("checkpoint.json");
checkpoint.save(&checkpoint_path).unwrap();
let loaded = Checkpoint::load(&checkpoint_path).unwrap();
assert_eq!(loaded.episode, 100);
assert_eq!(loaded.step_count, 5000);
assert_eq!(loaded.avg_reward, 0.5);
assert_eq!(loaded.avg_quality, 0.75);
assert_eq!(loaded.best_quality, 0.8);
}
#[test]
fn test_checkpoint_manager() {
let temp_dir = TempDir::new().unwrap();
let manager = CheckpointManager::new(temp_dir.path().to_path_buf(), 3).unwrap();
for i in 0..5 {
let checkpoint = Checkpoint::new(
i * 100,
i * 1000,
0.5 + (i as f32 * 0.05),
0.7 + (i as f32 * 0.02),
0.8,
0.1 - (i as f32 * 0.01),
PathBuf::from(format!("model_{}.onnx", i)),
);
manager.save_checkpoint(&checkpoint).unwrap();
}
let checkpoints = manager.list_checkpoints().unwrap();
assert!(checkpoints.len() <= 3, "Should keep max 3 checkpoints");
let latest = manager.load_latest().unwrap();
assert!(latest.is_some());
assert_eq!(latest.unwrap().episode, 400);
let best = manager.load_best().unwrap();
assert!(best.is_some());
}
#[test]
fn test_html_parser_xpath() {
let html = r#"
<html>
<body>
<div id="content">
<article>
<h1>Title</h1>
<p>Paragraph 1</p>
<p>Paragraph 2</p>
</article>
</div>
</body>
</html>
"#;
let document = HtmlParser::parse(html).unwrap();
let candidates = HtmlParser::get_candidate_nodes(&document, 5);
assert!(!candidates.is_empty(), "Should find candidate nodes");
if let Some(first) = candidates.first() {
let xpath = HtmlParser::get_element_path(*first);
assert!(!xpath.is_empty(), "Should generate XPath");
}
}
#[test]
fn test_text_quality_calculation() {
let config = Config::default();
let high_quality = "The quick brown fox jumps over the lazy dog with remarkable agility. \
This is a well-formed sentence with proper structure and demonstrates \
excellent writing quality. It contains appropriate vocabulary and maintains \
coherence throughout the entire passage. Furthermore, the text exhibits \
good lexical diversity with varied word choices and proper punctuation. \
Each sentence contributes meaningfully to the overall narrative while \
maintaining reader engagement through clear and concise communication.";
let low_quality = "a a a a a";
let high_score = TextUtils::calculate_text_quality(high_quality, &config.stopwords);
let low_score = TextUtils::calculate_text_quality(low_quality, &config.stopwords);
println!("High quality score: {}", high_score);
println!("Low quality score: {}", low_score);
assert!(
high_score > low_score,
"High quality text should score higher: {} vs {}",
high_score, low_score
);
assert!(
high_score > 0.3,
"High quality should score > 0.3, got {}",
high_score
);
}
#[test]
fn test_config_from_env() {
std::env::set_var("ARTICLE_EXTRACTOR_MODEL_PATH", "/tmp/model.onnx");
std::env::set_var("ARTICLE_EXTRACTOR_SITE_PROFILES", "/tmp/profiles");
std::env::set_var("ARTICLE_EXTRACTOR_OUTPUT_DIR", "/tmp/output");
let config = Config::from_env().unwrap();
assert_eq!(config.model_path, Some(PathBuf::from("/tmp/model.onnx")));
assert_eq!(config.site_profiles_dir, PathBuf::from("/tmp/profiles"));
assert_eq!(config.output_dir, PathBuf::from("/tmp/output"));
std::env::remove_var("ARTICLE_EXTRACTOR_MODEL_PATH");
std::env::remove_var("ARTICLE_EXTRACTOR_SITE_PROFILES");
std::env::remove_var("ARTICLE_EXTRACTOR_OUTPUT_DIR");
}
#[test]
fn test_batch_extraction_result() {
let articles = vec![
ExtractedArticle {
url: "https://example.com/1".to_string(),
title: Some("Article 1".to_string()),
date: None,
content: "Content 1".to_string(),
quality_score: 0.8,
method: "baseline".to_string(),
xpath: Some("//article[1]".to_string()),
},
ExtractedArticle {
url: "https://example.com/2".to_string(),
title: Some("Article 2".to_string()),
date: None,
content: "Content 2".to_string(),
quality_score: 0.9,
method: "rl".to_string(),
xpath: Some("//article[1]".to_string()),
},
];
let batch_result = BatchExtractionResult { articles };
let json = serde_json::to_string(&batch_result).unwrap();
let deserialized: BatchExtractionResult = serde_json::from_str(&json).unwrap();
assert_eq!(deserialized.articles.len(), 2);
assert_eq!(deserialized.articles[0].url, "https://example.com/1");
assert_eq!(deserialized.articles[1].quality_score, 0.9);
}