//! Core policy and learner abstractions.
//!
//! Source file: `burn_rl/policy/base.rs`.

use derive_new::new;

use burn_core::{prelude::*, record::Record, tensor::backend::AutodiffBackend};

use crate::TransitionBatch;

/// An action paired with additional context about how the decision was made.
///
/// Bundles a chosen action `A` with policy-specific metadata `C`
/// (presumably e.g. the distribution it was sampled from — confirm
/// against implementations of [`Policy`]).
#[derive(Clone, new)]
pub struct ActionContext<A, C> {
    /// Policy-specific context attached to this decision.
    pub context: C,
    /// The action that was selected.
    pub action: A,
}

/// The state (parameterization) of a policy.
///
/// Provides round-tripping between the live state and a serializable
/// [`Record`], e.g. for checkpointing.
pub trait PolicyState<B: Backend> {
    /// The serializable record type for this state.
    type Record: Record<B>;

    /// Convert the state into its record, consuming `self`.
    fn into_record(self) -> Self::Record;
    /// Build a new state from `record`, using `&self` as a template.
    ///
    /// NOTE(review): unlike `into_record`, this borrows `self` and returns a
    /// fresh value rather than consuming it — confirm this asymmetry is
    /// intentional.
    fn load_record(&self, record: Self::Record) -> Self;
}

/// Trait for an RL policy: maps (batched) observations to actions.
pub trait Policy<B: Backend>: Clone {
    /// The observation given as input to the policy.
    type Observation;
    /// The action distribution parameters defining how the action will be sampled.
    type ActionDistribution;
    /// The action produced by the policy.
    type Action;

    /// Additional context on the policy's decision.
    // NOTE(review): presumably instantiated with [`ActionContext`] — confirm.
    type ActionContext;
    /// The current parameterization of the policy.
    type PolicyState: PolicyState<B>;

    /// Produces the action distribution from a batch of observations.
    ///
    /// Takes `&mut self`, so implementations may mutate internal state
    /// during the forward pass.
    fn forward(&mut self, obs: Self::Observation) -> Self::ActionDistribution;
    /// Gives the action from a batch of observations.
    ///
    /// When `deterministic` is `true`, implementations presumably select
    /// the most likely action instead of sampling — confirm against
    /// implementations. Returns the batched action together with one
    /// context entry per decision.
    fn action(
        &mut self,
        obs: Self::Observation,
        deterministic: bool,
    ) -> (Self::Action, Vec<Self::ActionContext>);

    /// Update the policy's parameters from a new state.
    fn update(&mut self, update: Self::PolicyState);
    /// Returns the current parameterization.
    fn state(&self) -> Self::PolicyState;

    /// Loads the policy parameters from a record, consuming and returning `self`.
    fn load_record(self, record: <Self::PolicyState as PolicyState<B>>::Record) -> Self;
}

/// Trait for a type that can be batched and unbatched (split).
pub trait Batchable: Sized {
    /// Merge a list of items into a single batched item.
    fn batch(value: Vec<Self>) -> Self;
    /// Split a batched item back into a list of individual items.
    fn unbatch(self) -> Vec<Self>;
}

/// A training output: a training item together with the resulting policy.
pub struct RLTrainOutput<TO, P> {
    /// The policy produced by the step.
    // NOTE(review): at the `PolicyLearner::train` call site `P` is a
    // `PolicyState`, not a full `Policy` — the field name is broader than
    // that usage.
    pub policy: P,
    /// The training item (e.g. step context/metrics).
    pub item: TO,
}

/// Batched transitions for a [`PolicyLearner`]: a [`TransitionBatch`] whose
/// observation and action types are taken from the policy `P`.
pub type LearnerTransitionBatch<B, P> =
    TransitionBatch<B, <P as Policy<B>>::Observation, <P as Policy<B>>::Action>;

/// Learner for a policy: consumes transition batches and produces
/// updated policy parameterizations.
///
/// Requires an [`AutodiffBackend`] (training needs gradients) and that the
/// policy's observation/distribution/action types are [`Batchable`].
pub trait PolicyLearner<B>
where
    B: AutodiffBackend,
    <Self::InnerPolicy as Policy<B>>::Observation: Clone + Batchable,
    <Self::InnerPolicy as Policy<B>>::ActionDistribution: Clone + Batchable,
    <Self::InnerPolicy as Policy<B>>::Action: Clone + Batchable,
{
    /// Additional context of a training step.
    type TrainContext;
    /// The policy to train.
    type InnerPolicy: Policy<B>;
    /// The record of the learner, e.g. for checkpointing.
    type Record: Record<B>;

    /// Execute a training step on the policy.
    ///
    /// Returns the updated policy state together with the step's
    /// training context.
    fn train(
        &mut self,
        input: LearnerTransitionBatch<B, Self::InnerPolicy>,
    ) -> RLTrainOutput<Self::TrainContext, <Self::InnerPolicy as Policy<B>>::PolicyState>;
    /// Returns the learner's current policy.
    // NOTE(review): returns by value; given `Policy: Clone` this is
    // presumably a clone of the internal policy — confirm.
    fn policy(&self) -> Self::InnerPolicy;
    /// Update the learner's policy.
    fn update_policy(&mut self, update: Self::InnerPolicy);

    /// Convert the learner's state into a record.
    fn record(&self) -> Self::Record;
    /// Load the learner's state from a record, consuming and returning `self`.
    fn load_record(self, record: Self::Record) -> Self;
}