Skip to main content

content_extractor_rl/agents/
mod.rs

1// ============================================================================
2// FILE: crates/content-extractor-rl/src/agents/mod.rs
3// ============================================================================
4pub mod dqn_agent;
5pub mod ppo_agent;
6pub mod sac_agent;
7
8use crate::{Result, replay_buffer::PrioritizedReplayBuffer};
9use candle_core::Device;
10use std::path::Path;
11use candle_nn::VarMap;
12use serde::{Serialize, Deserialize};
13use crate::models::NetworkConfig;
14
15/// Algorithm type selection
16#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
17#[derive(Default)]
18pub enum AlgorithmType {
19    #[default]
20    DuelingDQN,
21    PPO,
22    SAC,
23    TD3,
24    Rainbow,
25}
26impl std::str::FromStr for AlgorithmType {
27    type Err = String;
28    fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
29        match s.to_lowercase().as_str() {
30            "dqn" | "dueling_dqn" | "duelingdqn" => Ok(AlgorithmType::DuelingDQN),
31            "ppo" => Ok(AlgorithmType::PPO),
32            "sac" => Ok(AlgorithmType::SAC),
33            "td3" => Ok(AlgorithmType::TD3),
34            "rainbow" => Ok(AlgorithmType::Rainbow),
35            _ => Err(format!("Unknown algorithm type: {}. Supported: dqn, ppo, sac, td3, rainbow", s))
36        }
37    }
38}
39impl std::fmt::Display for AlgorithmType {
40    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
41        match self {
42            AlgorithmType::DuelingDQN => write!(f, "DuelingDQN"),
43            AlgorithmType::PPO => write!(f, "PPO"),
44            AlgorithmType::SAC => write!(f, "SAC"),
45            AlgorithmType::TD3 => write!(f, "TD3"),
46            AlgorithmType::Rainbow => write!(f, "Rainbow"),
47        }
48    }
49}
50
51/// Common trait for all RL agents
52pub trait RLAgent: Send + Sync {
53    /// Select action given state and exploration parameter
54    /// Returns: (discrete_action, continuous_params, optional_log_prob)
55    fn select_action(&self, state: &[f32], epsilon: f32) -> Result<(usize, Vec<f32>)>;
56
57    /// Save model with metadata
58    fn save_with_metadata(&self, path: &Path, training_episodes: usize, hyperparameters: std::collections::HashMap<String, f64>) -> Result<()>;
59
60    /// Save model to disk (uses default metadata)
61    fn save(&self, path: &Path) -> Result<()>;
62
63    /// Train on a batch of experiences
64    /// Returns: loss value
65    fn train_step(&mut self, replay_buffer: &mut PrioritizedReplayBuffer, batch_size: usize) -> Result<f32>;
66
67    /// Update target network (if applicable, no-op for on-policy methods)
68    fn update_target_network(&mut self);
69
70    /// Get training step count
71    fn get_step_count(&self) -> usize;
72
73    /// Get algorithm type
74    fn algorithm_type(&self) -> AlgorithmType;
75
76    /// Get algorithm-specific info for logging
77    fn get_info(&self) -> AgentInfo;
78
79}
80
81/// Agent information for logging and tracking
82#[derive(Debug, Clone, Serialize, Deserialize)]
83pub struct AgentInfo {
84    pub algorithm: AlgorithmType,
85    pub num_parameters: usize,
86    pub state_dim: usize,
87    pub num_actions: usize,
88    pub continuous_params: usize,
89    pub version: String,
90    pub features: Vec<String>,
91}
92/// Factory for creating RL agents
93pub struct AgentFactory;
94
95impl AgentFactory {
96    /// Create agent from configuration
97    pub fn create(
98        algorithm: AlgorithmType,
99        state_dim: usize,
100        num_actions: usize,
101        num_params: usize,
102        gamma: f32,
103        lr: f64,
104        device: &Device,
105    ) -> Result<Box<dyn RLAgent>> {
106        match algorithm {
107            AlgorithmType::DuelingDQN => {
108                let network_config = NetworkConfig {
109                    state_dim,
110                    num_actions,
111                    num_params,
112                    hidden_layers: vec![512, 256, 128],
113                    use_layer_norm: true,
114                    dropout: 0.1,
115                    value_hidden: 64,
116                    advantage_hidden: 64,
117                };
118
119                // Create varmap for this agent
120                let varmap = VarMap::new();
121
122                let agent = dqn_agent::DQNAgent::new(
123                    network_config, gamma, lr, device, varmap
124                )?;
125                Ok(Box::new(agent))
126            }
127            AlgorithmType::PPO => {
128                let varmap = candle_nn::VarMap::new();
129                let agent = ppo_agent::PPOAgent::new(
130                    state_dim, num_actions, num_params, gamma, lr, device, varmap
131                )?;
132                Ok(Box::new(agent))
133            }
134            AlgorithmType::SAC => {
135                let actor_varmap = candle_nn::VarMap::new();
136                let critic_varmap = candle_nn::VarMap::new();
137                let agent = sac_agent::SACAgent::new(
138                    state_dim, num_actions, num_params, gamma, lr, device,
139                    actor_varmap, critic_varmap
140                )?;
141                Ok(Box::new(agent))
142            }
143            _ => Err(crate::ExtractionError::ModelError(
144                format!("Algorithm {} not yet implemented", algorithm)
145            ))
146        }
147    }
148
149    /// Load agent from saved model
150    pub fn load(
151        path: &Path,
152        state_dim: usize,
153        num_actions: usize,
154        num_params: usize,
155        device: &Device,
156    ) -> Result<Box<dyn RLAgent>> {
157        let algorithm = Self::detect_algorithm(path)?;
158
159        match algorithm {
160            AlgorithmType::DuelingDQN => {
161                let agent = dqn_agent::DQNAgent::load_with_device(
162                    path, state_dim, num_actions, num_params, device
163                )?;
164                Ok(Box::new(agent))
165            }
166            AlgorithmType::PPO => {
167                let agent = ppo_agent::PPOAgent::load_with_device(
168                    path, state_dim, num_actions, num_params, device
169                )?;
170                Ok(Box::new(agent))
171            }
172            AlgorithmType::SAC => {
173                let agent = sac_agent::SACAgent::load_with_device(
174                    path, state_dim, num_actions, num_params, device
175                )?;
176                Ok(Box::new(agent))
177            }
178            _ => Err(crate::ExtractionError::ModelError(
179                format!("Algorithm {} loading not implemented", algorithm)
180            ))
181        }
182    }
183
184    /// Detect algorithm type from saved model
185    fn detect_algorithm(path: &Path) -> Result<AlgorithmType> {
186        use std::fs::File;
187        use std::io::Read;
188
189        let mut file = File::open(path)?;
190        let mut metadata_len_bytes = [0u8; 8];
191        file.read_exact(&mut metadata_len_bytes)?;
192        let metadata_len = u64::from_le_bytes(metadata_len_bytes) as usize;
193
194        let mut metadata_bytes = vec![0u8; metadata_len];
195        file.read_exact(&mut metadata_bytes)?;
196
197        let metadata_json = String::from_utf8(metadata_bytes)
198            .map_err(|e| crate::ExtractionError::ParseError(e.to_string()))?;
199
200        #[derive(Deserialize)]
201        struct Metadata {
202            architecture: String,
203        }
204
205        let metadata: Metadata = serde_json::from_str(&metadata_json)
206            .map_err(|e| crate::ExtractionError::ParseError(e.to_string()))?;
207
208        metadata.architecture.parse()
209            .map_err(|e: String| crate::ExtractionError::ParseError(e))
210    }
211    
212}