Skip to main content

content_extractor_rl/
config.rs

1// ============================================================================
2// FILE: crates/content-extractor-rl/src/config.rs
3// ============================================================================
4
5use anyhow::Result;
6use serde::{Deserialize, Serialize};
7use std::collections::HashSet;
8use std::path::PathBuf;
9use crate::agents::AlgorithmType;
10
11/// Configuration for the content extractor rl
12#[derive(Debug, Clone, Serialize, Deserialize)]
13pub struct Config {
14    // Paths
15    pub model_path: Option<PathBuf>,
16    pub site_profiles_dir: PathBuf,
17    pub output_dir: PathBuf,
18    pub models_dir: PathBuf,
19    pub use_cpu_for_tuning: bool,
20
21    // Algorithm selection
22    #[serde(default)]
23    pub algorithm: AlgorithmType,
24
25    // Training hyperparameters
26    pub num_episodes: usize,
27    pub batch_size: usize,
28    pub learning_rate: f64,
29    pub gamma: f64,
30    pub epsilon_start: f64,
31    pub epsilon_end: f64,
32    pub epsilon_decay: f64,
33    pub target_update_freq: usize,
34    pub max_steps_per_episode: usize,
35
36    // Replay buffer
37    pub replay_buffer_size: usize,
38    pub priority_alpha: f64,
39    pub priority_beta: f64,
40
41    // Performance tuning
42    pub min_replay_size: usize,
43    pub train_freq: usize,
44    pub num_train_steps_per_episode: usize,
45    pub max_html_samples: usize,
46    pub sample_batch_load_size: usize,
47    pub prefetch_samples: bool,
48
49    // Metrics
50    pub metrics_window: usize,
51    pub checkpoint_freq: usize,
52    pub log_freq: usize,
53
54    // State/Action space
55    pub state_dim: usize,
56    pub num_discrete_actions: usize,
57    pub num_continuous_params: usize,
58    pub num_candidate_nodes: usize,
59
60    // PPO-specific hyperparameters
61    pub ppo_clip_epsilon: f32,
62    pub ppo_gae_lambda: f32,
63    pub ppo_value_loss_coef: f32,
64    pub ppo_entropy_coef: f32,
65    pub ppo_epochs: usize,
66
67    // Stopwords
68    pub stopwords: HashSet<String>,
69}
70
71impl Default for Config {
72    fn default() -> Self {
73        Self {
74            model_path: std::env::var("ARTICLE_EXTRACTOR_MODEL_PATH")
75                .ok()
76                .map(PathBuf::from),
77            site_profiles_dir: std::env::var("ARTICLE_EXTRACTOR_SITE_PROFILES")
78                .ok()
79                .map(PathBuf::from)
80                .unwrap_or_else(|| PathBuf::from("./site_profiles")),
81            output_dir: std::env::var("ARTICLE_EXTRACTOR_OUTPUT_DIR")
82                .ok()
83                .map(PathBuf::from)
84                .unwrap_or_else(|| PathBuf::from("./output")),
85            models_dir: std::env::var("ARTICLE_EXTRACTOR_MODELS_DIR")
86                .ok()
87                .map(PathBuf::from)
88                .unwrap_or_else(|| PathBuf::from("./models")),
89            use_cpu_for_tuning: false,
90
91            // Default algorithm
92            algorithm: AlgorithmType::DuelingDQN,
93
94            num_episodes: 10000,
95
96            // Smaller batch size for better GPU utilization and is better for gradient updates
97            batch_size: 512,
98
99            learning_rate: 3e-4,
100            gamma: 0.95,
101            epsilon_start: 1.0,
102            epsilon_end: 0.05,
103            epsilon_decay: 0.995,
104
105            // OPTIMIZED: More frequent target updates
106            target_update_freq: 500,  // Was 1000
107
108            max_steps_per_episode: 20,
109
110            replay_buffer_size: 100000,
111            priority_alpha: 0.6,
112            priority_beta: 0.4,
113
114            // NEW PERFORMANCE SETTINGS
115            min_replay_size: 5000,              // Start training after 5K experiences
116            train_freq: 4,                      // Train every 4 steps (more frequent)
117            num_train_steps_per_episode: 4,    // 4 gradient updates per episode
118            max_html_samples: 5000,             // CRITICAL: Limit to 5K samples
119            sample_batch_load_size: 1000,       // Load 1K at a time
120            prefetch_samples: true,             // Enable async loading
121
122            // OPTIMIZED METRICS
123            metrics_window: 50,                 // Down from 100
124            checkpoint_freq: 500,               // More frequent saves
125            log_freq: 5,                        // Update progress every 5 episodes
126
127            state_dim: 300,
128            num_discrete_actions: 16,
129            num_continuous_params: 6,
130            num_candidate_nodes: 10,
131
132            // NEW: PPO defaults
133            ppo_clip_epsilon: 0.2,
134            ppo_gae_lambda: 0.95,
135            ppo_value_loss_coef: 0.5,
136            ppo_entropy_coef: 0.01,
137            ppo_epochs: 4,
138
139            stopwords: Self::default_stopwords(),
140        }
141    }
142}
143
144impl Config {
145    /// Create config with specific algorithm
146    pub fn with_algorithm(algorithm: AlgorithmType) -> Self {
147        Self { algorithm, ..Self::default() }
148    }
149
150    /// PPO recommended configuration
151    pub fn ppo_recommended() -> Self {
152        let mut config = Self::with_algorithm(AlgorithmType::PPO);
153        config.batch_size = 2048;
154        config.learning_rate = 3e-4;
155        config.num_train_steps_per_episode = 8;
156        config.ppo_epochs = 10;
157        config.ppo_clip_epsilon = 0.2;
158        config.ppo_gae_lambda = 0.95;
159        config
160    }
161
162    /// DQN optimized configuration
163    pub fn dqn_optimized() -> Self {
164        let mut config = Self::with_algorithm(AlgorithmType::DuelingDQN);
165        config.batch_size = 2048;
166        config.learning_rate = 0.001;
167        config.target_update_freq = 500;
168        config
169    }
170
171    /// Create high-performance GPU config
172    pub fn gpu_optimized() -> Self {
173        Self {
174            batch_size: 6144,
175            num_train_steps_per_episode: 32,
176            train_freq: 1,
177            replay_buffer_size: 500000,
178            min_replay_size: 20000,
179            max_html_samples: 10000,
180            sample_batch_load_size: 2000,
181            learning_rate: 0.00183,
182            target_update_freq: 100,
183            metrics_window: 50,
184            checkpoint_freq: 250,
185            ..Self::default()
186        }
187    }
188
189    /// Create config from environment variables
190    pub fn from_env() -> Result<Self> {
191        Ok(Self::default())
192    }
193
194    /// Setup required directories
195    pub fn setup_directories(&self) -> Result<()> {
196        std::fs::create_dir_all(&self.site_profiles_dir)?;
197        std::fs::create_dir_all(&self.output_dir)?;
198        std::fs::create_dir_all(&self.models_dir)?;
199        Ok(())
200    }
201
202    /// Default English stopwords
203    fn default_stopwords() -> HashSet<String> {
204        vec![
205            "the", "a", "an", "and", "or", "but", "in", "on", "at", "to", "for",
206            "of", "with", "by", "from", "as", "is", "was", "are", "been", "be",
207            "have", "has", "had", "do", "does", "did", "will", "would", "could",
208            "should", "may", "might", "can", "this", "that", "these", "those",
209            "i", "you", "he", "she", "it", "we", "they", "them", "their", "his",
210            "her", "its", "our", "your", "who", "what", "where", "when", "why",
211            "how", "which", "there", "here", "more", "most", "some", "any", "all",
212        ]
213            .into_iter()
214            .map(|s| s.to_string())
215            .collect()
216    }
217}
218
// Discrete action identifiers.
//
// The ids listed here run from 0 through 15. Ids 1..=8 have no named constant
// in this file; presumably they are the remaining node-selection actions
// between `ACTION_SELECT_NODE_0` and `ACTION_SELECT_NODE_9` (TODO confirm).

/// Lowest node-selection action id.
pub const ACTION_SELECT_NODE_0: usize = 0;
/// Highest node-selection action id.
pub const ACTION_SELECT_NODE_9: usize = 9;
/// Action id for selecting the parent.
pub const ACTION_SELECT_PARENT: usize = 10;
/// Action id for selecting the left sibling.
pub const ACTION_SELECT_SIBLING_LEFT: usize = 11;
/// Action id for selecting the right sibling.
pub const ACTION_SELECT_SIBLING_RIGHT: usize = 12;
/// Action id for expanding the region.
pub const ACTION_EXPAND_REGION: usize = 13;
/// Action id for contracting the region.
pub const ACTION_CONTRACT_REGION: usize = 14;
/// Action id for terminating; the largest id defined here.
pub const ACTION_TERMINATE: usize = 15;