use crate::agents::{buffers::HistoryDataBound, ActorMode};
use serde::{Deserialize, Serialize};
/// Schedule for the exploration rate as a function of the global step count.
///
/// Evaluated by [`ExplorationRateSchedule::exploration_rate`], which returns `0.0`
/// in evaluation mode and consults the schedule only in training mode.
#[derive(Debug, Copy, Clone, PartialEq, Serialize, Deserialize)]
pub enum ExplorationRateSchedule {
/// The same rate at every step.
Constant(f64),
/// Linear interpolation from `start` to `end` over the first `period` global steps;
/// after that it stays at `end` (the interpolation fraction is clamped to 1).
LinearAnnealed {
/// Rate at global step 0.
start: f64,
/// Rate reached once `period` global steps have elapsed.
end: f64,
/// Number of global steps over which the rate moves from `start` to `end`.
period: u64,
},
}
impl Default for ExplorationRateSchedule {
    /// Linear annealing from a rate of `1.0` down to `0.1` over 10 million global steps.
    fn default() -> Self {
        let (start, end, period) = (1.0, 0.1, 10_000_000);
        Self::LinearAnnealed { start, end, period }
    }
}
impl ExplorationRateSchedule {
    /// Returns the exploration rate to use at `global_steps` in the given actor `mode`.
    ///
    /// * In [`ActorMode::Evaluation`] the rate is always `0.0`, whatever the schedule.
    /// * [`ExplorationRateSchedule::Constant`] returns its stored rate.
    /// * [`ExplorationRateSchedule::LinearAnnealed`] interpolates linearly from `start`
    ///   (at step 0) towards `end` over `period` steps. Once `global_steps >= period`,
    ///   it returns exactly `end`, so a `period` of zero yields `end` immediately.
    #[must_use]
    pub fn exploration_rate(&self, global_steps: u64, mode: ActorMode) -> f64 {
        use ExplorationRateSchedule::{Constant, LinearAnnealed};
        match (mode, *self) {
            (ActorMode::Evaluation, _) => 0.0,
            (ActorMode::Training, Constant(rate)) => rate,
            (ActorMode::Training, LinearAnnealed { start, end, period }) => {
                if global_steps >= period {
                    // Return `end` itself: `start + 1.0 * (end - start)` can be off by a
                    // rounding error (1.0 -> 0.1 would settle at 0.09999999999999998).
                    // This branch also avoids dividing by a zero `period`.
                    end
                } else {
                    let progress = global_steps as f64 / period as f64;
                    start + progress * (end - start)
                }
            }
        }
    }
}
/// Schedule for the step count that [`DataCollectionSchedule::update_size`] passes as
/// `min_steps` to [`HistoryDataBound::with_default_slack`].
#[derive(Debug, Copy, Clone, PartialEq, Serialize, Deserialize)]
pub enum DataCollectionSchedule {
/// The same step count at every global step.
Constant(usize),
/// `first` steps while the global step count is below `first`, then `rest` steps.
FirstRest { first: usize, rest: usize },
}
impl DataCollectionSchedule {
#[must_use]
pub fn update_size(&self, global_steps: u64) -> HistoryDataBound {
use DataCollectionSchedule::*;
let min_steps = match self {
Constant(value) => *value,
FirstRest { first, rest: _ } if global_steps < *first as u64 => *first,
FirstRest { first: _, rest } => *rest,
};
HistoryDataBound::with_default_slack(min_steps)
}
}