use std::collections::HashSet;
use std::path::PathBuf;

use anyhow::{Context, Result};
use serde::{Deserialize, Serialize};

use crate::agents::AlgorithmType;
10
/// Configuration for the article extractor: filesystem locations, agent
/// training hyperparameters, model dimensions, and the stopword list.
///
/// Build one with [`Config::default`] (which reads `ARTICLE_EXTRACTOR_*`
/// environment variables for the path fields) or with one of the preset
/// constructors on `Config` (`with_algorithm`, `ppo_recommended`,
/// `dqn_optimized`, `gpu_optimized`).
///
/// Only `algorithm` carries `#[serde(default)]`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Config {
    // ----- Filesystem locations -------------------------------------------
    /// Optional model path. `Config::default` takes it from
    /// `ARTICLE_EXTRACTOR_MODEL_PATH` and leaves it `None` otherwise.
    pub model_path: Option<PathBuf>,
    /// Site-profile directory (default: `ARTICLE_EXTRACTOR_SITE_PROFILES`,
    /// else `./site_profiles`). Created by `setup_directories`.
    pub site_profiles_dir: PathBuf,
    /// Output directory (default: `ARTICLE_EXTRACTOR_OUTPUT_DIR`, else
    /// `./output`). Created by `setup_directories`.
    pub output_dir: PathBuf,
    /// Models directory (default: `ARTICLE_EXTRACTOR_MODELS_DIR`, else
    /// `./models`). Created by `setup_directories`.
    pub models_dir: PathBuf,
    /// Whether tuning should run on the CPU (default `false`).
    /// NOTE(review): exact effect depends on the consumer — confirm.
    pub use_cpu_for_tuning: bool,

    /// Agent algorithm to train (`Config::default` uses `DuelingDQN`).
    ///
    /// NOTE(review): when this field is missing from serialized input, serde
    /// fills it with `AlgorithmType::default()`, which is not necessarily
    /// `DuelingDQN` — confirm the two defaults agree.
    #[serde(default)]
    pub algorithm: AlgorithmType,

    // ----- Training loop --------------------------------------------------
    /// Number of training episodes (default 10000).
    pub num_episodes: usize,
    /// Minibatch size per training step (default 512).
    pub batch_size: usize,
    /// Optimizer learning rate (default 3e-4).
    pub learning_rate: f64,
    /// Presumably the reward discount factor (default 0.95) — confirm.
    pub gamma: f64,
    /// Initial exploration epsilon (default 1.0).
    pub epsilon_start: f64,
    /// Final/minimum exploration epsilon (default 0.05).
    pub epsilon_end: f64,
    /// Epsilon decay factor (default 0.995); presumably multiplicative per
    /// step or episode — confirm in the agent.
    pub epsilon_decay: f64,
    /// Interval between target-network updates (default 500).
    pub target_update_freq: usize,
    /// Step cap for a single episode (default 20).
    pub max_steps_per_episode: usize,

    // ----- Replay buffer --------------------------------------------------
    /// Replay buffer capacity (default 100000).
    pub replay_buffer_size: usize,
    /// Prioritized-replay alpha exponent (default 0.6).
    pub priority_alpha: f64,
    /// Prioritized-replay beta exponent (default 0.4).
    pub priority_beta: f64,

    // ----- Sampling / training cadence ------------------------------------
    /// Replay entries required before training starts (default 5000).
    pub min_replay_size: usize,
    /// Training frequency (default 4).
    pub train_freq: usize,
    /// Gradient steps taken per episode (default 4).
    pub num_train_steps_per_episode: usize,
    /// Upper bound on HTML samples used (default 5000).
    pub max_html_samples: usize,
    /// Samples loaded per batch from storage (default 1000).
    pub sample_batch_load_size: usize,
    /// Whether samples are prefetched (default `true`).
    pub prefetch_samples: bool,

    // ----- Logging / checkpointing ----------------------------------------
    /// Window size for rolling metrics (default 50).
    pub metrics_window: usize,
    /// Checkpoint interval (default 500).
    pub checkpoint_freq: usize,
    /// Logging interval (default 5).
    pub log_freq: usize,

    // ----- Model dimensions -----------------------------------------------
    /// Length of the state feature vector (default 300).
    pub state_dim: usize,
    /// Number of discrete actions (default 16).
    /// NOTE(review): the `ACTION_*` constants in this module span 0..=15;
    /// presumably the two must stay in sync.
    pub num_discrete_actions: usize,
    /// Number of continuous action parameters (default 6).
    pub num_continuous_params: usize,
    /// Number of candidate nodes (default 10).
    /// NOTE(review): matches `ACTION_SELECT_NODE_0..=ACTION_SELECT_NODE_9`;
    /// presumably must stay in sync.
    pub num_candidate_nodes: usize,

    // ----- PPO-specific ---------------------------------------------------
    /// PPO clipping epsilon (default 0.2).
    pub ppo_clip_epsilon: f32,
    /// PPO GAE lambda (default 0.95).
    pub ppo_gae_lambda: f32,
    /// Weight of the value loss term (default 0.5).
    pub ppo_value_loss_coef: f32,
    /// Weight of the entropy bonus term (default 0.01).
    pub ppo_entropy_coef: f32,
    /// Optimization epochs per PPO update (default 4).
    pub ppo_epochs: usize,

    /// Stopword set; the default is a small built-in list of common lowercase
    /// English words.
    pub stopwords: HashSet<String>,
}
70
71impl Default for Config {
72 fn default() -> Self {
73 Self {
74 model_path: std::env::var("ARTICLE_EXTRACTOR_MODEL_PATH")
75 .ok()
76 .map(PathBuf::from),
77 site_profiles_dir: std::env::var("ARTICLE_EXTRACTOR_SITE_PROFILES")
78 .ok()
79 .map(PathBuf::from)
80 .unwrap_or_else(|| PathBuf::from("./site_profiles")),
81 output_dir: std::env::var("ARTICLE_EXTRACTOR_OUTPUT_DIR")
82 .ok()
83 .map(PathBuf::from)
84 .unwrap_or_else(|| PathBuf::from("./output")),
85 models_dir: std::env::var("ARTICLE_EXTRACTOR_MODELS_DIR")
86 .ok()
87 .map(PathBuf::from)
88 .unwrap_or_else(|| PathBuf::from("./models")),
89 use_cpu_for_tuning: false,
90
91 algorithm: AlgorithmType::DuelingDQN,
93
94 num_episodes: 10000,
95
96 batch_size: 512,
98
99 learning_rate: 3e-4,
100 gamma: 0.95,
101 epsilon_start: 1.0,
102 epsilon_end: 0.05,
103 epsilon_decay: 0.995,
104
105 target_update_freq: 500, max_steps_per_episode: 20,
109
110 replay_buffer_size: 100000,
111 priority_alpha: 0.6,
112 priority_beta: 0.4,
113
114 min_replay_size: 5000, train_freq: 4, num_train_steps_per_episode: 4, max_html_samples: 5000, sample_batch_load_size: 1000, prefetch_samples: true, metrics_window: 50, checkpoint_freq: 500, log_freq: 5, state_dim: 300,
128 num_discrete_actions: 16,
129 num_continuous_params: 6,
130 num_candidate_nodes: 10,
131
132 ppo_clip_epsilon: 0.2,
134 ppo_gae_lambda: 0.95,
135 ppo_value_loss_coef: 0.5,
136 ppo_entropy_coef: 0.01,
137 ppo_epochs: 4,
138
139 stopwords: Self::default_stopwords(),
140 }
141 }
142}
143
144impl Config {
145 pub fn with_algorithm(algorithm: AlgorithmType) -> Self {
147 Self { algorithm, ..Self::default() }
148 }
149
150 pub fn ppo_recommended() -> Self {
152 let mut config = Self::with_algorithm(AlgorithmType::PPO);
153 config.batch_size = 2048;
154 config.learning_rate = 3e-4;
155 config.num_train_steps_per_episode = 8;
156 config.ppo_epochs = 10;
157 config.ppo_clip_epsilon = 0.2;
158 config.ppo_gae_lambda = 0.95;
159 config
160 }
161
162 pub fn dqn_optimized() -> Self {
164 let mut config = Self::with_algorithm(AlgorithmType::DuelingDQN);
165 config.batch_size = 2048;
166 config.learning_rate = 0.001;
167 config.target_update_freq = 500;
168 config
169 }
170
171 pub fn gpu_optimized() -> Self {
173 Self {
174 batch_size: 6144,
175 num_train_steps_per_episode: 32,
176 train_freq: 1,
177 replay_buffer_size: 500000,
178 min_replay_size: 20000,
179 max_html_samples: 10000,
180 sample_batch_load_size: 2000,
181 learning_rate: 0.00183,
182 target_update_freq: 100,
183 metrics_window: 50,
184 checkpoint_freq: 250,
185 ..Self::default()
186 }
187 }
188
189 pub fn from_env() -> Result<Self> {
191 Ok(Self::default())
192 }
193
194 pub fn setup_directories(&self) -> Result<()> {
196 std::fs::create_dir_all(&self.site_profiles_dir)?;
197 std::fs::create_dir_all(&self.output_dir)?;
198 std::fs::create_dir_all(&self.models_dir)?;
199 Ok(())
200 }
201
202 fn default_stopwords() -> HashSet<String> {
204 vec![
205 "the", "a", "an", "and", "or", "but", "in", "on", "at", "to", "for",
206 "of", "with", "by", "from", "as", "is", "was", "are", "been", "be",
207 "have", "has", "had", "do", "does", "did", "will", "would", "could",
208 "should", "may", "might", "can", "this", "that", "these", "those",
209 "i", "you", "he", "she", "it", "we", "they", "them", "their", "his",
210 "her", "its", "our", "your", "who", "what", "where", "when", "why",
211 "how", "which", "there", "here", "more", "most", "some", "any", "all",
212 ]
213 .into_iter()
214 .map(|s| s.to_string())
215 .collect()
216 }
217}
218
// Discrete action ids. Each id after the node-selection range is defined
// relative to its predecessor, so the numbering stays contiguous.

/// Selects candidate node 0 (first id of the node-selection range).
pub const ACTION_SELECT_NODE_0: usize = 0;
/// Selects candidate node 9 (last id of the node-selection range).
pub const ACTION_SELECT_NODE_9: usize = ACTION_SELECT_NODE_0 + 9;
/// Moves the selection to the parent element.
pub const ACTION_SELECT_PARENT: usize = ACTION_SELECT_NODE_9 + 1;
/// Moves the selection to the left sibling.
pub const ACTION_SELECT_SIBLING_LEFT: usize = ACTION_SELECT_PARENT + 1;
/// Moves the selection to the right sibling.
pub const ACTION_SELECT_SIBLING_RIGHT: usize = ACTION_SELECT_SIBLING_LEFT + 1;
/// Expands the current region.
pub const ACTION_EXPAND_REGION: usize = ACTION_SELECT_SIBLING_RIGHT + 1;
/// Contracts the current region.
pub const ACTION_CONTRACT_REGION: usize = ACTION_EXPAND_REGION + 1;
/// Ends the episode; the highest action id.
pub const ACTION_TERMINATE: usize = ACTION_CONTRACT_REGION + 1;