use derive_new::new;
use burn_core::{prelude::*, record::Record, tensor::backend::AutodiffBackend};
use crate::TransitionBatch;
/// An action paired with the auxiliary context that was produced
/// alongside it (see [`Policy::action`], which returns contexts
/// per action).
///
/// `#[derive(new)]` generates `ActionContext::new(context, action)`;
/// the field order below fixes that constructor's argument order, so
/// it must not be changed.
#[derive(Clone, new)]
pub struct ActionContext<A, C> {
// Auxiliary data attached to the action (opaque to this module).
pub context: C,
// The action itself.
pub action: A,
}
/// A serializable snapshot of a policy's learnable state.
///
/// Mirrors burn's record pattern: the state converts into a
/// [`Record`] for saving and rebuilds itself from one for loading.
pub trait PolicyState<B: Backend> {
/// The burn record type this state (de)serializes through.
type Record: Record<B>;
/// Consumes the state and produces its serializable record.
fn into_record(self) -> Self::Record;
/// Builds a new state from `record`.
///
/// NOTE(review): takes `&self` and returns a fresh `Self`, unlike
/// burn's consuming `Module::load_record(self, ..)` — presumably the
/// receiver acts as a template (device/config source). Confirm
/// against implementors before relying on that reading.
fn load_record(&self, record: Self::Record) -> Self;
}
/// A policy mapping observations to actions on backend `B`.
///
/// `Clone` is required so callers can hold independent copies
/// (e.g. an actor copy alongside a learner copy).
pub trait Policy<B: Backend>: Clone {
/// Input observation type.
type Observation;
/// Distribution over actions, as produced by [`Policy::forward`].
type ActionDistribution;
/// Concrete action type emitted by [`Policy::action`].
type Action;
/// Per-action auxiliary data returned alongside chosen actions.
type ActionContext;
/// Snapshot of the policy's learnable state (save/load/update unit).
type PolicyState: PolicyState<B>;
/// Computes the action distribution for `obs`.
///
/// Takes `&mut self`, so implementations may update internal state
/// (e.g. recurrent hidden state) as a side effect.
fn forward(&mut self, obs: Self::Observation) -> Self::ActionDistribution;
/// Selects an action for `obs`, plus one context entry per action.
///
/// `deterministic` presumably toggles greedy selection vs sampling
/// from the distribution — TODO confirm with implementors. The `Vec`
/// of contexts suggests `obs`/`Action` may be batched; the exact
/// batching contract is not visible here.
fn action(
&mut self,
obs: Self::Observation,
deterministic: bool,
) -> (Self::Action, Vec<Self::ActionContext>);
/// Replaces this policy's learnable state with `update`.
fn update(&mut self, update: Self::PolicyState);
/// Returns a snapshot of the current learnable state.
fn state(&self) -> Self::PolicyState;
/// Consumes the policy and returns one rebuilt from `record`
/// (burn-style checkpoint loading).
fn load_record(self, record: <Self::PolicyState as PolicyState<B>>::Record) -> Self;
}
/// A value that can be stacked into a single batched value and split
/// back into its elements.
///
/// Implementations are presumably expected to make `unbatch` invert
/// `batch` (element count and order preserved) — not enforced here;
/// confirm per implementor.
pub trait Batchable: Sized {
/// Stacks `value` into one batched instance.
fn batch(value: Vec<Self>) -> Self;
/// Splits a batched instance back into its elements.
fn unbatch(self) -> Vec<Self>;
}
/// Output of a training step: an updated policy (or policy state —
/// `P` is [`Policy::PolicyState`] in [`PolicyLearner::train`]) plus
/// the step's training item `TO` (metrics/diagnostics).
pub struct RLTrainOutput<TO, P> {
// Updated policy or policy state produced by the step.
pub policy: P,
// Training item for this step (e.g. losses/metrics) — exact
// contents depend on the learner.
pub item: TO,
}
/// Shorthand for a [`TransitionBatch`] whose observation and action
/// types are taken from policy `P` on backend `B`.
pub type LearnerTransitionBatch<B, P> =
TransitionBatch<B, <P as Policy<B>>::Observation, <P as Policy<B>>::Action>;
/// Trains an inner [`Policy`] from batches of transitions.
///
/// Requires an [`AutodiffBackend`] (training needs gradients) and
/// `Clone + Batchable` on the policy's observation, distribution and
/// action types so transitions can be batched/unbatched around
/// training.
pub trait PolicyLearner<B>
where
B: AutodiffBackend,
<Self::InnerPolicy as Policy<B>>::Observation: Clone + Batchable,
<Self::InnerPolicy as Policy<B>>::ActionDistribution: Clone + Batchable,
<Self::InnerPolicy as Policy<B>>::Action: Clone + Batchable,
{
/// Per-step training output payload (losses/metrics); becomes the
/// `item` field of [`RLTrainOutput`].
type TrainContext;
/// The policy being trained.
type InnerPolicy: Policy<B>;
/// Burn record type used to checkpoint the learner itself
/// (optimizer state etc. — distinct from the policy's own record).
type Record: Record<B>;
/// Runs one training step on `input`.
///
/// Returns the updated policy state in `RLTrainOutput::policy` and
/// the step's `TrainContext` in `RLTrainOutput::item`.
fn train(
&mut self,
input: LearnerTransitionBatch<B, Self::InnerPolicy>,
) -> RLTrainOutput<Self::TrainContext, <Self::InnerPolicy as Policy<B>>::PolicyState>;
/// Returns a copy of the current inner policy
/// (`Policy: Clone`, so this can hand out an independent actor copy).
fn policy(&self) -> Self::InnerPolicy;
/// Replaces the inner policy with `update`.
fn update_policy(&mut self, update: Self::InnerPolicy);
/// Snapshots the learner into its checkpoint record.
fn record(&self) -> Self::Record;
/// Consumes the learner and returns one rebuilt from `record`.
fn load_record(self, record: Self::Record) -> Self;
}