Skip to main content

RLAgent

Trait RLAgent 

Source
/// Common trait implemented by all RL agents.
///
/// `Send + Sync` bounds allow agent objects to be shared across threads
/// (e.g. behind `Arc<dyn RLAgent>`).
pub trait RLAgent: Send + Sync {
    // Required methods

    /// Select an action for `state` using exploration parameter `epsilon`.
    ///
    /// Returns `(discrete_action_index, continuous_params)`.
    /// NOTE(review): the rendered docs elsewhere mention an `optional_log_prob`,
    /// but this signature returns only a 2-tuple — confirm which is current.
    fn select_action(
        &self,
        state: &[f32],
        epsilon: f32,
    ) -> Result<(usize, Vec<f32>)>;

    /// Save the model to `path` together with training metadata
    /// (episode count and hyperparameter map).
    fn save_with_metadata(
        &self,
        path: &Path,
        training_episodes: usize,
        hyperparameters: HashMap<String, f64>,
    ) -> Result<()>;

    /// Save the model to `path` using default metadata.
    fn save(&self, path: &Path) -> Result<()>;

    /// Train on a batch of `batch_size` experiences sampled from
    /// `replay_buffer`. Returns the loss value for the step.
    fn train_step(
        &mut self,
        replay_buffer: &mut PrioritizedReplayBuffer,
        batch_size: usize,
    ) -> Result<f32>;

    /// Update the target network. A no-op for on-policy methods that
    /// have no target network.
    fn update_target_network(&mut self);

    /// Number of training steps performed so far.
    fn get_step_count(&self) -> usize;

    /// Which RL algorithm this agent implements.
    fn algorithm_type(&self) -> AlgorithmType;

    /// Algorithm-specific info for logging.
    fn get_info(&self) -> AgentInfo;
}
Expand description

Common trait for all RL agents

Required Methods§

Source

fn select_action( &self, state: &[f32], epsilon: f32, ) -> Result<(usize, Vec<f32>)>

Select an action given the current state and an exploration parameter. Returns: (discrete_action, continuous_params). (Note: the signature returns a 2-tuple — the previously documented optional_log_prob is not part of the return type.)

Source

fn save_with_metadata( &self, path: &Path, training_episodes: usize, hyperparameters: HashMap<String, f64>, ) -> Result<()>

Save model with metadata

Source

fn save(&self, path: &Path) -> Result<()>

Save model to disk (uses default metadata)

Source

fn train_step( &mut self, replay_buffer: &mut PrioritizedReplayBuffer, batch_size: usize, ) -> Result<f32>

Train on a batch of experiences. Returns: the loss value for this training step.

Source

fn update_target_network(&mut self)

Update target network (if applicable, no-op for on-policy methods)

Source

fn get_step_count(&self) -> usize

Get training step count

Source

fn algorithm_type(&self) -> AlgorithmType

Get algorithm type

Source

fn get_info(&self) -> AgentInfo

Get algorithm-specific info for logging

Implementors§