pub trait RLAgent: Send + Sync {
// Required methods
fn select_action(
&self,
state: &[f32],
epsilon: f32,
) -> Result<(usize, Vec<f32>)>;
fn save_with_metadata(
&self,
path: &Path,
training_episodes: usize,
hyperparameters: HashMap<String, f64>,
) -> Result<()>;
fn save(&self, path: &Path) -> Result<()>;
fn train_step(
&mut self,
replay_buffer: &mut PrioritizedReplayBuffer,
batch_size: usize,
) -> Result<f32>;
fn update_target_network(&mut self);
fn get_step_count(&self) -> usize;
fn algorithm_type(&self) -> AlgorithmType;
fn get_info(&self) -> AgentInfo;
}Expand description
Common trait for all RL agents
Required Methods§
Source
fn select_action(
    &self,
    state: &[f32],
    epsilon: f32,
) -> Result<(usize, Vec<f32>)>
Select action given state and exploration parameter. Returns: (discrete_action, continuous_params). Note: the signature returns a two-element tuple; the "optional_log_prob" previously mentioned here is not part of the return value.
Source
fn save_with_metadata(
    &self,
    path: &Path,
    training_episodes: usize,
    hyperparameters: HashMap<String, f64>,
) -> Result<()>
Save the model with metadata (training episode count and hyperparameters).
Source
fn train_step(
    &mut self,
    replay_buffer: &mut PrioritizedReplayBuffer,
    batch_size: usize,
) -> Result<f32>
Train on a batch of experiences. Returns: the loss value.
Source
fn update_target_network(&mut self)
Update the target network (if applicable; a no-op for on-policy methods).
Source
fn get_step_count(&self) -> usize
Get the training step count.
Source
fn algorithm_type(&self) -> AlgorithmType
Get the algorithm type.