1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
mod actor;
mod ppo;
mod reinforce;
mod trpo;
pub use actor::PolicyActor;
pub use ppo::{Ppo, PpoConfig};
pub use reinforce::{Reinforce, ReinforceConfig};
pub use trpo::{Trpo, TrpoConfig};
use super::features::HistoryFeatures;
use super::WithCpuCopy;
use crate::logging::StatsLogger;
use crate::spaces::{NonEmptyFeatures, ParameterizedDistributionSpace};
use crate::torch::modules::{AsModule, Module, SeqIterative, SeqPacked};
use crate::torch::packed::PackedTensor;
use tch::{Device, Tensor};
/// A trainable policy whose underlying network implements [`Module`],
/// [`SeqPacked`] and [`SeqIterative`].
///
/// The network is exposed through the [`AsModule`] supertrait, whose `Module`
/// associated type is required to be the same type as [`Policy::PolicyModule`].
pub trait Policy: AsModule<Module = Self::PolicyModule> {
    /// Concrete network type backing this policy.
    type PolicyModule: Module + SeqPacked + SeqIterative;

    /// Performs an update of the policy.
    ///
    /// # Arguments
    /// * `features` - history features (trait object implementing `HistoryFeatures`).
    /// * `advantages` - advantage values as a [`PackedTensor`]; presumably laid
    ///   out to match `features` — confirm against callers.
    /// * `action_space` - action space implementing
    ///   `ParameterizedDistributionSpace<Tensor>`; may be unsized.
    /// * `logger` - statistics logger available to the update.
    fn update<AS>(
        &mut self,
        features: &dyn HistoryFeatures,
        advantages: PackedTensor,
        action_space: &AS,
        logger: &mut dyn StatsLogger,
    ) where
        AS: ParameterizedDistributionSpace<Tensor> + ?Sized;

    /// Creates a [`PolicyActor`] for the given observation and action spaces.
    ///
    /// The default implementation gives the actor a `shallow_clone` of this
    /// policy's module (presumably sharing parameter storage rather than
    /// deep-copying it).
    fn actor<OS, AS>(
        &self,
        observation_space: NonEmptyFeatures<OS>,
        action_space: AS,
    ) -> PolicyActor<OS, AS, Self::Module> {
        let module = self.as_module().shallow_clone();
        PolicyActor::new(observation_space, action_space, module)
    }
}
/// Factory that constructs values implementing the `Policy` trait.
pub trait BuildPolicy {
    /// The policy type produced by `build_policy`.
    type Policy: Policy;

    /// Constructs a new policy placed on `device`.
    ///
    /// `in_dim` and `out_dim` are the policy's input and output sizes
    /// (presumably observation-feature size and action-distribution-parameter
    /// size — confirm against callers).
    fn build_policy(
        &self,
        in_dim: usize,
        out_dim: usize,
        device: Device,
    ) -> Self::Policy;
}
/// Forwards `Policy` to the wrapped policy `P`, except that actors are built
/// from `shallow_clone_module_cpu` (presumably a CPU-resident copy of the
/// module) rather than from the policy's own module.
impl<P> Policy for WithCpuCopy<P>
where
    P: Policy,
{
    type PolicyModule = P::PolicyModule;

    /// Delegates the update to the wrapped policy obtained via `as_inner_mut`.
    ///
    /// NOTE(review): only the inner policy is modified here; confirm that
    /// `shallow_clone_module_cpu` reflects these updates (e.g. re-syncs the
    /// CPU copy) so actors created afterwards do not see stale parameters.
    fn update<AS>(
        &mut self,
        features: &dyn HistoryFeatures,
        advantages: PackedTensor,
        action_space: &AS,
        logger: &mut dyn StatsLogger,
    ) where
        AS: ParameterizedDistributionSpace<Tensor> + ?Sized,
    {
        let inner = self.as_inner_mut();
        inner.update(features, advantages, action_space, logger);
    }

    /// Creates a [`PolicyActor`] whose module comes from
    /// `shallow_clone_module_cpu` instead of `as_module().shallow_clone()`.
    fn actor<OS, AS>(
        &self,
        observation_space: NonEmptyFeatures<OS>,
        action_space: AS,
    ) -> PolicyActor<OS, AS, Self::Module> {
        let cpu_module = self.shallow_clone_module_cpu();
        PolicyActor::new(observation_space, action_space, cpu_module)
    }
}