1use derive_new::new;
2
3use burn_core::{prelude::*, record::Record, tensor::backend::AutodiffBackend};
4
5use crate::TransitionBatch;
6
/// Pairs an action with auxiliary context captured when it was selected.
///
/// `derive(new)` generates `ActionContext::new(context, action)`; the field
/// order below fixes that generated constructor's argument order, so it must
/// not be reordered.
#[derive(Clone, new)]
pub struct ActionContext<A, C> {
    // Auxiliary data recorded at action-selection time — presumably the
    // distribution or log-prob the action came from; verify against callers.
    pub context: C,
    // The selected action itself.
    pub action: A,
}
15
/// A snapshot of a policy's learnable state, convertible to and from a
/// serializable `Record` for checkpointing on backend `B`.
pub trait PolicyState<B: Backend> {
    /// Serializable form of this state (burn's `Record` machinery).
    type Record: Record<B>;

    /// Consumes the state and produces its serializable record.
    fn into_record(self) -> Self::Record;
    /// Builds a state from `record`, returning a fresh `Self`.
    // NOTE(review): asymmetric with `into_record` — takes `&self` (presumably
    // as a template for device/config) and returns a new value rather than
    // mutating in place; confirm this is intentional.
    fn load_record(&self, record: Self::Record) -> Self;
}
26
/// A mapping from observations to actions on backend `B`.
///
/// `Clone` is required so independent copies of the policy can be handed out
/// (e.g. to actors and learners separately).
pub trait Policy<B: Backend>: Clone {
    /// Input the policy acts on.
    type Observation;
    /// Output of [`Policy::forward`]: the action distribution for an observation.
    type ActionDistribution;
    /// Concrete action type produced by [`Policy::action`].
    type Action;

    /// Auxiliary per-action data returned alongside selected actions.
    type ActionContext;
    /// Snapshot type used for state exchange and checkpointing.
    type PolicyState: PolicyState<B>;

    /// Computes the action distribution for `obs`.
    // NOTE(review): `&mut self` suggests forward may mutate internal state
    // (e.g. recurrent hidden state or cached buffers) — confirm.
    fn forward(&mut self, obs: Self::Observation) -> Self::ActionDistribution;
    /// Selects an action for `obs`. When `deterministic` is true,
    /// implementations are expected to choose greedily instead of sampling.
    // NOTE(review): returns one `Action` but a `Vec` of contexts —
    // presumably one entry per element of a batched observation; confirm.
    fn action(
        &mut self,
        obs: Self::Observation,
        deterministic: bool,
    ) -> (Self::Action, Vec<Self::ActionContext>);

    /// Replaces this policy's learnable state with `update`.
    fn update(&mut self, update: Self::PolicyState);
    /// Snapshots the current learnable state.
    fn state(&self) -> Self::PolicyState;

    /// Consumes the policy and returns one restored from `record`
    /// (checkpoint restore via the state's `Record` type).
    fn load_record(self, record: <Self::PolicyState as PolicyState<B>>::Record) -> Self;
}
58
/// Types whose values can be stacked into a single batched value and split
/// back into their individual elements.
// NOTE(review): the round-trip contract (`batch(v).unbatch() == v`) is
// implied but not enforced here — document it on implementors.
pub trait Batchable: Sized {
    /// Combines `value` into one batched instance.
    fn batch(value: Vec<Self>) -> Self;
    /// Splits this batched instance back into its elements.
    fn unbatch(self) -> Vec<Self>;
}
66
/// Result of a training step: the step's output item plus the updated policy
/// (instantiated with a `PolicyState` at the `PolicyLearner::train` call site).
pub struct RLTrainOutput<TO, P> {
    // Updated policy (or policy state) produced by the step.
    pub policy: P,
    // Training output produced by the step — presumably metrics/loss info;
    // verify against consumers.
    pub item: TO,
}
74
/// Shorthand for the transition batch consumed when training policy `P` on
/// backend `B`: observations and actions are taken from `P`'s associated types.
pub type LearnerTransitionBatch<B, P> =
    TransitionBatch<B, <P as Policy<B>>::Observation, <P as Policy<B>>::Action>;
78
/// Trains an inner [`Policy`] from batches of transitions.
///
/// Requires an autodiff backend (gradients are computed during `train`), and
/// the policy's observation/distribution/action types must be `Clone` and
/// [`Batchable`] so transition batches can be assembled and taken apart.
pub trait PolicyLearner<B>
where
    B: AutodiffBackend,
    <Self::InnerPolicy as Policy<B>>::Observation: Clone + Batchable,
    <Self::InnerPolicy as Policy<B>>::ActionDistribution: Clone + Batchable,
    <Self::InnerPolicy as Policy<B>>::Action: Clone + Batchable,
{
    /// Per-step training output, returned as `RLTrainOutput::item`.
    type TrainContext;
    /// The policy being optimized.
    type InnerPolicy: Policy<B>;
    /// Serializable record for checkpointing the learner itself —
    /// presumably optimizer state and the like; confirm with implementors.
    type Record: Record<B>;

    /// Runs one training step on `input`, returning the training context and
    /// the resulting policy state.
    fn train(
        &mut self,
        input: LearnerTransitionBatch<B, Self::InnerPolicy>,
    ) -> RLTrainOutput<Self::TrainContext, <Self::InnerPolicy as Policy<B>>::PolicyState>;
    /// Returns the current inner policy (an owned value from `&self` —
    /// typically a clone, since `Policy: Clone`).
    fn policy(&self) -> Self::InnerPolicy;
    /// Replaces the inner policy with `update`.
    fn update_policy(&mut self, update: Self::InnerPolicy);

    /// Snapshots the learner's serializable state.
    fn record(&self) -> Self::Record;
    /// Consumes the learner and returns one restored from `record`.
    fn load_record(self, record: Self::Record) -> Self;
}