// burn_train/learner/rl/components.rs
1use std::marker::PhantomData;
2
3use burn_core::tensor::backend::AutodiffBackend;
4use burn_rl::{Batchable, Environment, EnvironmentInit, Policy, PolicyLearner, PolicyState};
5
6use crate::{AgentEvaluationEvent, AsyncProcessorTraining, ItemLazy, RLEvent};
7
8/// All components used by the reinforcement learning paradigm, grouped in one trait.
9pub trait RLComponentsTypes {
10    /// The backend used for training.
11    type Backend: AutodiffBackend;
12    /// The learning environment.
13    type Env: Environment<State = Self::State, Action = Self::Action> + 'static;
14    /// Specifies how to initialize the environment.
15    type EnvInit: EnvironmentInit<Self::Env> + Send + 'static;
16    /// The type of the environment state.
17    type State: Into<<Self::Policy as Policy<Self::Backend>>::Observation> + Clone + Send + 'static;
18    /// The type of the environment action.
19    type Action: From<<Self::Policy as Policy<Self::Backend>>::Action>
20        + Into<<Self::Policy as Policy<Self::Backend>>::Action>
21        + Clone
22        + Send
23        + 'static;
24
25    /// The policy used to take actions in the environment.
26    type Policy: Policy<
27            Self::Backend,
28            Observation = Self::PolicyObs,
29            ActionDistribution = Self::PolicyAD,
30            Action = Self::PolicyAction,
31            ActionContext = Self::ActionContext,
32            PolicyState = Self::PolicyState,
33        > + Send
34        + 'static;
35    /// The policy's observation type.
36    type PolicyObs: Clone + Send + Batchable + 'static;
37    /// The policy's action distribution type.
38    type PolicyAD: Clone + Send + Batchable;
39    /// The policy's action type.
40    type PolicyAction: Clone + Send + Batchable;
41    /// Additional data as context for an agent's action.
42    type ActionContext: ItemLazy + Clone + Send + 'static;
43    /// The state of the parameterized policy.
44    type PolicyState: Clone + Send + PolicyState<Self::Backend> + 'static;
45
46    /// The learning agent.
47    type LearningAgent: PolicyLearner<
48            Self::Backend,
49            TrainContext = Self::TrainingOutput,
50            InnerPolicy = Self::Policy,
51        > + Send
52        + 'static;
53    /// The output data of a training step.
54    type TrainingOutput: ItemLazy + Clone + Send;
55}
56
/// Concrete type that implements the [RLComponentsTypes](RLComponentsTypes) trait.
///
/// Zero-sized marker: the generic parameters are carried only through
/// [`PhantomData`], so values of this type hold no runtime data.
pub struct RLComponentsMarker<B, E, EI, A> {
    // Training backend.
    _backend: PhantomData<B>,
    // Environment.
    _env: PhantomData<E>,
    // Environment initializer.
    _env_init: PhantomData<EI>,
    // Learning agent.
    _agent: PhantomData<A>,
}
64
65impl<B, E, EI, A> RLComponentsTypes for RLComponentsMarker<B, E, EI, A>
66where
67    B: AutodiffBackend,
68    E: Environment + 'static,
69    EI: EnvironmentInit<E> + Send + 'static,
70    A: PolicyLearner<B> + Send + 'static,
71    A::TrainContext: ItemLazy + Clone + Send,
72    A::InnerPolicy: Policy<B> + Send,
73    <A::InnerPolicy as Policy<B>>::Observation: Batchable + Clone + Send,
74    <A::InnerPolicy as Policy<B>>::ActionDistribution: Batchable + Clone + Send,
75    <A::InnerPolicy as Policy<B>>::Action: Batchable + Clone + Send,
76    <A::InnerPolicy as Policy<B>>::ActionContext: ItemLazy + Clone + Send + 'static,
77    <A::InnerPolicy as Policy<B>>::PolicyState: Clone + Send,
78    E::State: Into<<A::InnerPolicy as Policy<B>>::Observation> + Clone + Send + 'static,
79    E::Action: From<<A::InnerPolicy as Policy<B>>::Action>
80        + Into<<A::InnerPolicy as Policy<B>>::Action>
81        + Clone
82        + Send
83        + 'static,
84{
85    type Backend = B;
86    type Env = E;
87    type EnvInit = EI;
88    type LearningAgent = A;
89    type Policy = A::InnerPolicy;
90    type PolicyObs = <A::InnerPolicy as Policy<B>>::Observation;
91    type PolicyAD = <A::InnerPolicy as Policy<B>>::ActionDistribution;
92    type PolicyAction = <A::InnerPolicy as Policy<B>>::Action;
93    type ActionContext = <A::InnerPolicy as Policy<B>>::ActionContext;
94    type PolicyState = <A::InnerPolicy as Policy<B>>::PolicyState;
95    type TrainingOutput = A::TrainContext;
96    type State = E::State;
97    type Action = E::Action;
98}
99
100pub(crate) type RlPolicy<RLC> = <<RLC as RLComponentsTypes>::LearningAgent as PolicyLearner<
101    <RLC as RLComponentsTypes>::Backend,
102>>::InnerPolicy;
/// The event processor type for reinforcement learning.
///
/// An [`AsyncProcessorTraining`] whose training events are [`RLEvent`]s
/// (parameterized by the training-step output and action context) and whose
/// evaluation events are [`AgentEvaluationEvent`]s over the same context.
pub type RLEventProcessorType<RLC> = AsyncProcessorTraining<
    RLEvent<<RLC as RLComponentsTypes>::TrainingOutput, <RLC as RLComponentsTypes>::ActionContext>,
    AgentEvaluationEvent<<RLC as RLComponentsTypes>::ActionContext>,
>;
108/// The record of the policy.
109pub type RLPolicyRecord<RLC> = <<<RLC as RLComponentsTypes>::Policy as Policy<
110    <RLC as RLComponentsTypes>::Backend,
111>>::PolicyState as PolicyState<<RLC as RLComponentsTypes>::Backend>>::Record;
/// The record of the learning agent.
///
/// Projects the `Record` associated type of the [`PolicyLearner`]
/// implementation of the components' learning agent.
pub type RLAgentRecord<RLC> = <<RLC as RLComponentsTypes>::LearningAgent as PolicyLearner<
    <RLC as RLComponentsTypes>::Backend,
>>::Record;