content_extractor_rl/agents/
mod.rs1pub mod dqn_agent;
5pub mod ppo_agent;
6pub mod sac_agent;
7
8use crate::{Result, replay_buffer::PrioritizedReplayBuffer};
9use candle_core::Device;
10use std::path::Path;
11use candle_nn::VarMap;
12use serde::{Serialize, Deserialize};
13use crate::models::NetworkConfig;
14
15#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
17#[derive(Default)]
18pub enum AlgorithmType {
19 #[default]
20 DuelingDQN,
21 PPO,
22 SAC,
23 TD3,
24 Rainbow,
25}
26impl std::str::FromStr for AlgorithmType {
27 type Err = String;
28 fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
29 match s.to_lowercase().as_str() {
30 "dqn" | "dueling_dqn" | "duelingdqn" => Ok(AlgorithmType::DuelingDQN),
31 "ppo" => Ok(AlgorithmType::PPO),
32 "sac" => Ok(AlgorithmType::SAC),
33 "td3" => Ok(AlgorithmType::TD3),
34 "rainbow" => Ok(AlgorithmType::Rainbow),
35 _ => Err(format!("Unknown algorithm type: {}. Supported: dqn, ppo, sac, td3, rainbow", s))
36 }
37 }
38}
39impl std::fmt::Display for AlgorithmType {
40 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
41 match self {
42 AlgorithmType::DuelingDQN => write!(f, "DuelingDQN"),
43 AlgorithmType::PPO => write!(f, "PPO"),
44 AlgorithmType::SAC => write!(f, "SAC"),
45 AlgorithmType::TD3 => write!(f, "TD3"),
46 AlgorithmType::Rainbow => write!(f, "Rainbow"),
47 }
48 }
49}
50
51pub trait RLAgent: Send + Sync {
53 fn select_action(&self, state: &[f32], epsilon: f32) -> Result<(usize, Vec<f32>)>;
56
57 fn save_with_metadata(&self, path: &Path, training_episodes: usize, hyperparameters: std::collections::HashMap<String, f64>) -> Result<()>;
59
60 fn save(&self, path: &Path) -> Result<()>;
62
63 fn train_step(&mut self, replay_buffer: &mut PrioritizedReplayBuffer, batch_size: usize) -> Result<f32>;
66
67 fn update_target_network(&mut self);
69
70 fn get_step_count(&self) -> usize;
72
73 fn algorithm_type(&self) -> AlgorithmType;
75
76 fn get_info(&self) -> AgentInfo;
78
79}
80
81#[derive(Debug, Clone, Serialize, Deserialize)]
83pub struct AgentInfo {
84 pub algorithm: AlgorithmType,
85 pub num_parameters: usize,
86 pub state_dim: usize,
87 pub num_actions: usize,
88 pub continuous_params: usize,
89 pub version: String,
90 pub features: Vec<String>,
91}
/// Stateless entry point for creating and loading boxed `RLAgent` implementations.
///
/// The unit struct carries no data; all functionality lives in its associated functions.
pub struct AgentFactory;
95impl AgentFactory {
96 pub fn create(
98 algorithm: AlgorithmType,
99 state_dim: usize,
100 num_actions: usize,
101 num_params: usize,
102 gamma: f32,
103 lr: f64,
104 device: &Device,
105 ) -> Result<Box<dyn RLAgent>> {
106 match algorithm {
107 AlgorithmType::DuelingDQN => {
108 let network_config = NetworkConfig {
109 state_dim,
110 num_actions,
111 num_params,
112 hidden_layers: vec![512, 256, 128],
113 use_layer_norm: true,
114 dropout: 0.1,
115 value_hidden: 64,
116 advantage_hidden: 64,
117 };
118
119 let varmap = VarMap::new();
121
122 let agent = dqn_agent::DQNAgent::new(
123 network_config, gamma, lr, device, varmap
124 )?;
125 Ok(Box::new(agent))
126 }
127 AlgorithmType::PPO => {
128 let varmap = candle_nn::VarMap::new();
129 let agent = ppo_agent::PPOAgent::new(
130 state_dim, num_actions, num_params, gamma, lr, device, varmap
131 )?;
132 Ok(Box::new(agent))
133 }
134 AlgorithmType::SAC => {
135 let actor_varmap = candle_nn::VarMap::new();
136 let critic_varmap = candle_nn::VarMap::new();
137 let agent = sac_agent::SACAgent::new(
138 state_dim, num_actions, num_params, gamma, lr, device,
139 actor_varmap, critic_varmap
140 )?;
141 Ok(Box::new(agent))
142 }
143 _ => Err(crate::ExtractionError::ModelError(
144 format!("Algorithm {} not yet implemented", algorithm)
145 ))
146 }
147 }
148
149 pub fn load(
151 path: &Path,
152 state_dim: usize,
153 num_actions: usize,
154 num_params: usize,
155 device: &Device,
156 ) -> Result<Box<dyn RLAgent>> {
157 let algorithm = Self::detect_algorithm(path)?;
158
159 match algorithm {
160 AlgorithmType::DuelingDQN => {
161 let agent = dqn_agent::DQNAgent::load_with_device(
162 path, state_dim, num_actions, num_params, device
163 )?;
164 Ok(Box::new(agent))
165 }
166 AlgorithmType::PPO => {
167 let agent = ppo_agent::PPOAgent::load_with_device(
168 path, state_dim, num_actions, num_params, device
169 )?;
170 Ok(Box::new(agent))
171 }
172 AlgorithmType::SAC => {
173 let agent = sac_agent::SACAgent::load_with_device(
174 path, state_dim, num_actions, num_params, device
175 )?;
176 Ok(Box::new(agent))
177 }
178 _ => Err(crate::ExtractionError::ModelError(
179 format!("Algorithm {} loading not implemented", algorithm)
180 ))
181 }
182 }
183
184 fn detect_algorithm(path: &Path) -> Result<AlgorithmType> {
186 use std::fs::File;
187 use std::io::Read;
188
189 let mut file = File::open(path)?;
190 let mut metadata_len_bytes = [0u8; 8];
191 file.read_exact(&mut metadata_len_bytes)?;
192 let metadata_len = u64::from_le_bytes(metadata_len_bytes) as usize;
193
194 let mut metadata_bytes = vec![0u8; metadata_len];
195 file.read_exact(&mut metadata_bytes)?;
196
197 let metadata_json = String::from_utf8(metadata_bytes)
198 .map_err(|e| crate::ExtractionError::ParseError(e.to_string()))?;
199
200 #[derive(Deserialize)]
201 struct Metadata {
202 architecture: String,
203 }
204
205 let metadata: Metadata = serde_json::from_str(&metadata_json)
206 .map_err(|e| crate::ExtractionError::ParseError(e.to_string()))?;
207
208 metadata.architecture.parse()
209 .map_err(|e: String| crate::ExtractionError::ParseError(e))
210 }
211
212}