// burn_train/learner/rl/components.rs

use std::marker::PhantomData;

use burn_core::tensor::backend::AutodiffBackend;
use burn_rl::{Batchable, Environment, EnvironmentInit, Policy, PolicyLearner, PolicyState};

use crate::{AgentEvaluationEvent, AsyncProcessorTraining, ItemLazy, RLEvent};
/// Bundle of associated types wiring together everything a reinforcement-learning
/// training setup needs: the autodiff backend, the environment, the policy being
/// trained, and the agent that trains it.
///
/// The equality constraints on `Env`, `Policy`, and `LearningAgent` tie the whole
/// graph together, so downstream code can name any participant type through a
/// single `RLC: RLComponentsTypes` parameter (see the aliases at the bottom of
/// this file).
pub trait RLComponentsTypes {
    /// Backend with autodiff support, used by both the policy and the learner.
    type Backend: AutodiffBackend;
    /// The environment; its `State`/`Action` are pinned to `Self::State`/`Self::Action`.
    type Env: Environment<State = Self::State, Action = Self::Action> + 'static;
    /// Initializer for `Env` — presumably a factory producing fresh environment
    /// instances (confirm against `burn_rl::EnvironmentInit`).
    type EnvInit: EnvironmentInit<Self::Env> + Send + 'static;
    /// Environment state; must convert into the policy's observation type so
    /// raw env states can be fed to the policy.
    type State: Into<<Self::Policy as Policy<Self::Backend>>::Observation> + Clone + Send + 'static;
    /// Environment action; must convert both ways with the policy's action type
    /// so policy outputs can drive the env and env actions can be replayed.
    type Action: From<<Self::Policy as Policy<Self::Backend>>::Action>
        + Into<<Self::Policy as Policy<Self::Backend>>::Action>
        + Clone
        + Send
        + 'static;

    /// The policy under training. All of its associated types are re-exported
    /// below (`PolicyObs`, `PolicyAD`, …) so users can bound them without the
    /// verbose `<Self::Policy as Policy<_>>::…` projection.
    type Policy: Policy<
            Self::Backend,
            Observation = Self::PolicyObs,
            ActionDistribution = Self::PolicyAD,
            Action = Self::PolicyAction,
            ActionContext = Self::ActionContext,
            PolicyState = Self::PolicyState,
        > + Send
        + 'static;
    /// Mirror of `Policy::Observation`; `Batchable` so observations can be stacked.
    type PolicyObs: Clone + Send + Batchable + 'static;
    /// Mirror of `Policy::ActionDistribution`.
    type PolicyAD: Clone + Send + Batchable;
    /// Mirror of `Policy::Action`.
    type PolicyAction: Clone + Send + Batchable;
    /// Mirror of `Policy::ActionContext`; `ItemLazy` so it can flow through the
    /// async event processor.
    type ActionContext: ItemLazy + Clone + Send + 'static;
    /// Mirror of `Policy::PolicyState`; carries the checkpointable `Record`
    /// (see `RLPolicyRecord` alias).
    type PolicyState: Clone + Send + PolicyState<Self::Backend> + 'static;

    /// The agent that trains `Self::Policy`; its training context is pinned to
    /// `Self::TrainingOutput` and its inner policy to `Self::Policy`.
    type LearningAgent: PolicyLearner<
            Self::Backend,
            TrainContext = Self::TrainingOutput,
            InnerPolicy = Self::Policy,
        > + Send
        + 'static;
    /// Output of a training step, emitted as part of `RLEvent`.
    type TrainingOutput: ItemLazy + Clone + Send;
}
56
/// Zero-sized marker implementing [`RLComponentsTypes`] for a concrete choice of
/// backend `B`, environment `E`, environment initializer `EI`, and learning
/// agent `A`. Never instantiated with data — it exists purely so the blanket
/// `impl` below can derive every associated type from the four parameters.
pub struct RLComponentsMarker<B, E, EI, A> {
    // PhantomData fields record the type parameters without storing values.
    _backend: PhantomData<B>,
    _env: PhantomData<E>,
    _env_init: PhantomData<EI>,
    _agent: PhantomData<A>,
}
64
/// Blanket implementation: given any compatible (backend, environment,
/// env-initializer, agent) quadruple, derive the full `RLComponentsTypes`
/// bundle. The where-clause restates every bound the trait requires, expressed
/// in terms of `A::InnerPolicy` (which becomes `Self::Policy`).
impl<B, E, EI, A> RLComponentsTypes for RLComponentsMarker<B, E, EI, A>
where
    B: AutodiffBackend,
    E: Environment + 'static,
    EI: EnvironmentInit<E> + Send + 'static,
    A: PolicyLearner<B> + Send + 'static,
    // Bounds the trait places on `TrainingOutput`.
    A::TrainContext: ItemLazy + Clone + Send,
    A::InnerPolicy: Policy<B> + Send,
    // Bounds the trait places on the policy's associated types (the
    // `PolicyObs`/`PolicyAD`/… mirrors).
    <A::InnerPolicy as Policy<B>>::Observation: Batchable + Clone + Send,
    <A::InnerPolicy as Policy<B>>::ActionDistribution: Batchable + Clone + Send,
    <A::InnerPolicy as Policy<B>>::Action: Batchable + Clone + Send,
    <A::InnerPolicy as Policy<B>>::ActionContext: ItemLazy + Clone + Send + 'static,
    <A::InnerPolicy as Policy<B>>::PolicyState: Clone + Send,
    // Env state/action must convert to/from the policy's observation/action.
    E::State: Into<<A::InnerPolicy as Policy<B>>::Observation> + Clone + Send + 'static,
    E::Action: From<<A::InnerPolicy as Policy<B>>::Action>
        + Into<<A::InnerPolicy as Policy<B>>::Action>
        + Clone
        + Send
        + 'static,
{
    type Backend = B;
    type Env = E;
    type EnvInit = EI;
    type LearningAgent = A;
    // The policy and all of its associated types come from the agent.
    type Policy = A::InnerPolicy;
    type PolicyObs = <A::InnerPolicy as Policy<B>>::Observation;
    type PolicyAD = <A::InnerPolicy as Policy<B>>::ActionDistribution;
    type PolicyAction = <A::InnerPolicy as Policy<B>>::Action;
    type ActionContext = <A::InnerPolicy as Policy<B>>::ActionContext;
    type PolicyState = <A::InnerPolicy as Policy<B>>::PolicyState;
    type TrainingOutput = A::TrainContext;
    type State = E::State;
    type Action = E::Action;
}
99
100pub(crate) type RlPolicy<RLC> = <<RLC as RLComponentsTypes>::LearningAgent as PolicyLearner<
101 <RLC as RLComponentsTypes>::Backend,
102>>::InnerPolicy;
103pub type RLEventProcessorType<RLC> = AsyncProcessorTraining<
105 RLEvent<<RLC as RLComponentsTypes>::TrainingOutput, <RLC as RLComponentsTypes>::ActionContext>,
106 AgentEvaluationEvent<<RLC as RLComponentsTypes>::ActionContext>,
107>;
108pub type RLPolicyRecord<RLC> = <<<RLC as RLComponentsTypes>::Policy as Policy<
110 <RLC as RLComponentsTypes>::Backend,
111>>::PolicyState as PolicyState<<RLC as RLComponentsTypes>::Backend>>::Record;
112pub type RLAgentRecord<RLC> = <<RLC as RLComponentsTypes>::LearningAgent as PolicyLearner<
114 <RLC as RLComponentsTypes>::Backend,
115>>::Record;