//! Crate root: declares the crate's modules and re-exports its main items.
//!
//! Items from the private `base`, `evaluator` and `trainer` modules are
//! re-exported below, so they can be named directly from the crate root
//! (for example `crate::Env` or `crate::Trainer`).
#![warn(missing_docs)]
// Publicly visible submodules (defined out of line).
pub mod dummy;
pub mod error;
// Private module; `DefaultEvaluator` and `Evaluator` are re-exported from it below.
mod evaluator;
pub mod generic_replay_buffer;
pub mod record;
// Private module; only the names listed in the following `pub use` are exposed.
mod base;
pub use base::{
Act, Agent, Configurable, Env, ExperienceBufferBase, Info, NullReplayBuffer, Obs, Policy,
ReplayBufferBase, Step, StepProcessor, TransitionBatch,
};
// Private module; `Sampler`, `Trainer` and `TrainerConfig` are re-exported from it below.
mod trainer;
pub use evaluator::{DefaultEvaluator, Evaluator};
pub use trainer::{Sampler, Trainer, TrainerConfig};
pub mod test {
use serde::{Deserialize, Serialize};
/// Observation type of [`TestEnv`], wrapping a single `usize` value.
#[derive(Clone, Debug)]
pub struct TestObs {
// Value carried by the observation; `TestEnv` fills it from its state fields.
obs: usize,
}
impl crate::Obs for TestObs {
// Always reports a length of 1.
fn len(&self) -> usize {
1
}
}
/// Batch of observation values, stored as a `Vec<usize>`.
pub struct TestObsBatch {
// Stored observation values, addressed by index in `push` and `sample`.
obs: Vec<usize>,
}
// Index-addressed storage operations for a batch of observation values.
impl crate::generic_replay_buffer::BatchBase for TestObsBatch {
    /// Creates a batch holding `capacity` zero-valued slots.
    fn new(capacity: usize) -> Self {
        let obs = vec![0; capacity];
        Self { obs }
    }

    /// Copies the first entry of `data` into slot `i`; further entries of
    /// `data` are ignored.
    ///
    /// Panics if `data` holds no entries or if `i` is out of bounds.
    fn push(&mut self, i: usize, data: Self) {
        let value = data.obs[0];
        self.obs[i] = value;
    }

    /// Returns a new batch made of the slots named by `ixs`, in the order given.
    ///
    /// Panics if any index in `ixs` is out of bounds.
    fn sample(&self, ixs: &Vec<usize>) -> Self {
        let mut obs = Vec::with_capacity(ixs.len());
        for &ix in ixs {
            obs.push(self.obs[ix]);
        }
        Self { obs }
    }
}
impl From<TestObs> for TestObsBatch {
fn from(obs: TestObs) -> Self {
Self { obs: vec![obs.obs] }
}
}
/// Action type of [`TestEnv`], wrapping a single `usize` value.
#[derive(Clone, Debug)]
pub struct TestAct {
// Amount that `TestEnv` adds to its state when stepping with this action.
act: usize,
}
// Empty implementation: no items of `crate::Act` are provided here.
impl crate::Act for TestAct {}
/// Batch of action values, stored as a `Vec<usize>`.
pub struct TestActBatch {
// Stored action values, addressed by index in `push` and `sample`.
act: Vec<usize>,
}
impl From<TestAct> for TestActBatch {
fn from(act: TestAct) -> Self {
Self { act: vec![act.act] }
}
}
// Index-addressed storage operations for a batch of action values.
impl crate::generic_replay_buffer::BatchBase for TestActBatch {
    /// Creates a batch holding `capacity` zero-valued slots.
    fn new(capacity: usize) -> Self {
        let act = vec![0; capacity];
        Self { act }
    }

    /// Copies the first entry of `data` into slot `i`; further entries of
    /// `data` are ignored.
    ///
    /// Panics if `data` holds no entries or if `i` is out of bounds.
    fn push(&mut self, i: usize, data: Self) {
        let value = data.act[0];
        self.act[i] = value;
    }

    /// Returns a new batch made of the slots named by `ixs`, in the order given.
    ///
    /// Panics if any index in `ixs` is out of bounds.
    fn sample(&self, ixs: &Vec<usize>) -> Self {
        let mut act = Vec::with_capacity(ixs.len());
        for &ix in ixs {
            act.push(self.act[ix]);
        }
        Self { act }
    }
}
/// Empty step-information type used by [`TestEnv`].
pub struct TestInfo {}
// Empty implementation: no items of `crate::Info` are provided here.
impl crate::Info for TestInfo {}
/// Minimal environment whose state is a single `usize` that actions add to.
pub struct TestEnv {
// Value the state is set back to on `reset` / `reset_with_index`; taken from the config in `build`.
state_init: usize,
// Current state; every step adds the action's value to it.
state: usize,
}
// `Env` implementation for `TestEnv`: the state is a `usize` that every action
// adds to, and no step is ever flagged as terminated or truncated.
impl crate::Env for TestEnv {
    // The configuration is the initial state value.
    type Config = usize;
    type Obs = TestObs;
    type Act = TestAct;
    type Info = TestInfo;

    /// Sets the state back to its initial value and returns it as an observation.
    ///
    /// `_is_done` is ignored. This implementation never returns an error.
    fn reset(&mut self, _is_done: Option<&Vec<i8>>) -> anyhow::Result<Self::Obs> {
        self.state = self.state_init;
        Ok(TestObs { obs: self.state })
    }

    /// Behaves exactly like `reset`; `_ix` is ignored.
    fn reset_with_index(&mut self, _ix: usize) -> anyhow::Result<Self::Obs> {
        // Fully qualified so the call does not depend on the trait being imported.
        <Self as crate::Env>::reset(self, None)
    }

    /// Behaves exactly like `step`.
    ///
    /// Steps of this environment are never flagged as terminated or truncated,
    /// so the state is never reset here.
    fn step_with_reset(&mut self, a: &Self::Act) -> (crate::Step<Self>, crate::record::Record)
    where
        Self: Sized,
    {
        <Self as crate::Env>::step(self, a)
    }

    /// Adds the action's value to the state and reports the new state.
    ///
    /// The returned step carries the new state as `obs`, a clone of `a`, a
    /// single reward of `0.0`, zeroed termination/truncation flags, and the
    /// initial state as `init_obs`. The accompanying record is empty.
    fn step(&mut self, a: &Self::Act) -> (crate::Step<Self>, crate::record::Record)
    where
        Self: Sized,
    {
        self.state += a.act;
        let step = crate::Step {
            obs: TestObs { obs: self.state },
            act: a.clone(),
            reward: vec![0.0],
            is_terminated: vec![0],
            is_truncated: vec![0],
            info: TestInfo {},
            init_obs: Some(TestObs {
                obs: self.state_init,
            }),
        };
        (step, crate::record::Record::empty())
    }

    /// Builds an environment whose initial state value is `*config`.
    ///
    /// `_seed` is ignored. The current state starts at 0 and only takes the
    /// configured initial value once `reset` or `reset_with_index` is called.
    fn build(config: &Self::Config, _seed: i64) -> anyhow::Result<Self>
    where
        Self: Sized,
    {
        Ok(Self {
            state_init: *config,
            state: 0,
        })
    }
}
// Replay buffer type used by the `Agent` impl for `TestAgent`: a `SimpleReplayBuffer`
// parameterized with the test observation and action batch types.
type ReplayBuffer =
crate::generic_replay_buffer::SimpleReplayBuffer<TestObsBatch, TestActBatch>;
/// Field-less agent used together with [`TestEnv`].
pub struct TestAgent {}
/// Configuration type for [`TestAgent`]; it carries no data.
#[derive(Clone, Deserialize, Serialize)]
pub struct TestAgentConfig;
// Stub `Agent` implementation: every operation is a no-op or yields an empty value.
impl crate::Agent<TestEnv, ReplayBuffer> for TestAgent {
    /// No-op.
    fn train(&mut self) {}

    /// Always returns `false`.
    fn is_train(&self) -> bool {
        false
    }

    /// No-op.
    fn eval(&mut self) {}

    /// Returns an empty record without touching `_buffer`.
    fn opt_with_record(&mut self, _buffer: &mut ReplayBuffer) -> crate::record::Record {
        crate::record::Record::empty()
    }

    /// Writes nothing and returns an empty list of paths; `_path` is ignored.
    fn save_params(&self, _path: &std::path::Path) -> anyhow::Result<Vec<std::path::PathBuf>> {
        let written = Vec::new();
        Ok(written)
    }

    /// Reads nothing and always succeeds; `_path` is ignored.
    fn load_params(&mut self, _path: &std::path::Path) -> anyhow::Result<()> {
        Ok(())
    }

    /// Exposes the agent as `&dyn Any`.
    fn as_any_ref(&self) -> &dyn std::any::Any {
        self as &dyn std::any::Any
    }

    /// Exposes the agent as `&mut dyn Any`.
    fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
        self as &mut dyn std::any::Any
    }
}
// Fixed policy: the chosen action does not depend on the observation.
impl crate::Policy<TestEnv> for TestAgent {
    /// Returns an action with value 1; `_obs` is ignored.
    fn sample(&mut self, _obs: &TestObs) -> TestAct {
        let act = 1;
        TestAct { act }
    }
}
impl crate::Configurable for TestAgent {
type Config = TestAgentConfig;
fn build(_config: Self::Config) -> Self {
Self {}
}
}
}