#![allow(clippy::use_self)]
mod bandits;
pub mod buffers;
pub mod finite;
mod meta;
mod pair;
mod random;
mod serial;
mod tabular;
#[cfg(test)]
pub mod testing;
pub use bandits::{
BetaThompsonSamplingAgent, BetaThompsonSamplingAgentConfig, UCB1Agent, UCB1AgentConfig,
};
pub use buffers::{
HistoryDataBound, WriteExperience, WriteExperienceError, WriteExperienceIncremental,
};
pub use meta::{ResettingMetaAgent, ResettingMetaAgentConfig};
pub use pair::AgentPair;
pub use random::{RandomAgent, RandomAgentConfig};
pub use serial::SerialActorAgent;
pub use tabular::{TabularQLearningAgent, TabularQLearningAgentConfig};
use crate::envs::EnvStructure;
use crate::logging::StatsLogger;
use crate::spaces::Space;
use crate::Prng;
use serde::{Deserialize, Serialize};
use std::sync::Arc;
use tch::TchError;
use thiserror::Error;
/// A learning agent: a factory for [`Actor`]s over observations `O` and actions `A`.
///
/// Learning itself is expressed separately (see [`BatchUpdate`]); this trait only
/// covers producing an actor for a given [`ActorMode`].
pub trait Agent<O, A> {
    /// The type of actor this agent produces.
    type Actor: Actor<O, A>;

    /// Create a new actor that selects actions in the given mode.
    ///
    /// NOTE(review): presumably the actor captures the agent's policy at creation
    /// time (takes `&self`, no lifetime tie in the signature) — confirm whether
    /// later updates to the agent affect previously created actors.
    fn actor(&self, mode: ActorMode) -> Self::Actor;
}
/// Implement [`Agent`] for a wrapper type (`$wrapper`) around an inner agent `T`
/// by delegating every item to `T`.
macro_rules! impl_wrapped_agent {
    ($wrapper:ty) => {
        impl<T, O, A> Agent<O, A> for $wrapper
        where
            T: Agent<O, A> + ?Sized,
        {
            type Actor = T::Actor;
            fn actor(&self, mode: ActorMode) -> Self::Actor {
                // Delegate to the inner agent; `self` auto-derefs to `&T`.
                T::actor(self, mode)
            }
        }
    };
}
// NOTE(review): `Agent` is implemented for `&mut T`/`Box<T>` while `Actor` below
// covers `&T`/`Box<T>`/`Arc<T>` — the asymmetry looks deliberate (agents are
// typically paired with `BatchUpdate`, which needs `&mut`); confirm.
impl_wrapped_agent!(&'_ mut T);
impl_wrapped_agent!(Box<T>);
/// Selects actions `A` in response to observations `O` over the course of an episode.
///
/// The actor itself is immutable during an episode (`act` takes `&self`); any
/// per-episode mutable state lives in [`Self::EpisodeState`].
pub trait Actor<O: ?Sized, A> {
    /// State maintained for the duration of a single episode.
    type EpisodeState;

    /// Create the state used at the start of a new episode.
    fn initial_state(&self, rng: &mut Prng) -> Self::EpisodeState;

    /// Select an action for the given observation, possibly updating the episode state.
    fn act(&self, episode_state: &mut Self::EpisodeState, observation: &O, rng: &mut Prng) -> A;
}
/// Implement [`Actor`] for a wrapper type (`$wrapper`) around an inner actor `T`
/// by delegating every method to `T`.
macro_rules! impl_wrapped_actor {
    ($wrapper:ty) => {
        impl<T, O, A> Actor<O, A> for $wrapper
        where
            T: Actor<O, A> + ?Sized,
            O: ?Sized,
        {
            type EpisodeState = T::EpisodeState;
            fn initial_state(&self, rng: &mut Prng) -> Self::EpisodeState {
                // Delegate to the inner actor; `self` auto-derefs to `&T`.
                T::initial_state(self, rng)
            }
            fn act(
                &self,
                episode_state: &mut Self::EpisodeState,
                observation: &O,
                rng: &mut Prng,
            ) -> A {
                T::act(self, episode_state, observation, rng)
            }
        }
    };
}
// `Actor` only needs `&self`, so shared wrappers (`&T`, `Arc<T>`) are covered too.
impl_wrapped_actor!(&'_ T);
impl_wrapped_actor!(Box<T>);
impl_wrapped_actor!(Arc<T>);
/// The mode in which an [`Actor`] is run, chosen when the actor is created
/// via [`Agent::actor`].
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum ActorMode {
    /// Mode used while the agent is being trained.
    // NOTE(review): presumably permits exploration / stochastic action selection —
    // confirm against the individual agent implementations.
    Training,
    /// Mode used when evaluating the agent's performance.
    Evaluation,
}
/// Update from batches of recorded experience (observations `O`, actions `A`,
/// and per-step [`Self::Feedback`]).
pub trait BatchUpdate<O, A> {
    /// Feedback value stored alongside each step of experience.
    // NOTE(review): presumably a reward-like signal — confirm against the
    // `FeedbackSpace` used by `BuildAgent` / `EnvStructure`.
    type Feedback;

    /// Buffer type into which experience is written for later batch updates.
    type HistoryBuffer: WriteExperience<O, A, Self::Feedback>;

    /// Create a history buffer suitable for collecting experience for this agent.
    fn buffer(&self) -> Self::HistoryBuffer;

    /// The minimum amount of history data this agent expects for an update.
    fn min_update_size(&self) -> HistoryDataBound;

    /// Perform a batch update using the experience stored in `buffers`,
    /// logging statistics to `logger`.
    ///
    /// Buffers are taken by `&mut`, so the update may drain or modify them.
    fn batch_update<'a, I>(&mut self, buffers: I, logger: &mut dyn StatsLogger)
    where
        I: IntoIterator<Item = &'a mut Self::HistoryBuffer>,
        Self::HistoryBuffer: 'a;
}
/// Build an agent for environments with observation space `OS`, action space `AS`,
/// and feedback space `FS`.
pub trait BuildAgent<OS: Space, AS: Space, FS: Space> {
    /// The type of agent built; supports both acting and batch updates over the
    /// environment's observation/action element types.
    type Agent: Agent<OS::Element, AS::Element> + BatchUpdate<OS::Element, AS::Element>;

    /// Build an agent for the given environment structure.
    ///
    /// # Errors
    /// Returns a [`BuildAgentError`] if the environment's spaces or bounds are
    /// unsupported by this agent, or on an underlying Torch (`tch`) error.
    fn build_agent(
        &self,
        env: &dyn EnvStructure<ObservationSpace = OS, ActionSpace = AS, FeedbackSpace = FS>,
        rng: &mut Prng,
    ) -> Result<Self::Agent, BuildAgentError>;
}
/// Error constructing an agent (see [`BuildAgent::build_agent`]).
#[derive(Error, Debug)]
pub enum BuildAgentError {
    /// The environment's space bounds are too loose for this agent to handle.
    #[error("space bound(s) are too loose for this agent")]
    InvalidSpaceBounds,
    /// The environment's reward range is unbounded, which this agent cannot handle.
    #[error("reward range must not be unbounded")]
    UnboundedReward,
    /// An error from the Torch (`tch`) backend, forwarded transparently.
    #[error(transparent)]
    TorchError(#[from] TchError),
}