1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
use serde::{Deserialize, Serialize};
/// Top-level training configuration, (de)serialized with serde.
///
/// Groups the environment, network, optimizer, PPO, rollout and optional
/// evaluation sections together with run-level settings. No field carries
/// `#[serde(default)]`, so every non-`Option` field must be present in the
/// input. `Option` fields may be omitted and then deserialize as `None`.
/// Fields are (de)serialized under their Rust names, in declaration order.
#[derive(Serialize, Deserialize)]
#[derive(Debug)]
pub struct TrainConfig {
    pub version: u32,
    pub env: EnvConfig,
    pub net: RogueNetConfig,
    pub optim: OptimizerConfig,
    pub ppo: PPOConfig,
    pub rollout: RolloutConfig,
    /// Optional evaluation section; `None` when absent from the input.
    pub eval: Option<EvalConfig>,
    /// Optional second network config, presumably for a separate value
    /// function (inferred from the name — confirm against the trainer).
    pub vf_net: Option<RogueNetConfig>,
    pub name: String,
    pub seed: u64,
    pub total_timesteps: u64,
    /// NOTE(review): units are not visible here — confirm (seconds?).
    pub max_train_time: Option<u64>,
    pub torch_deterministic: bool,
    pub cuda: bool,
    pub track: bool,
    pub wandb_project_name: String,
    pub wandb_entity: String,
    /// NOTE(review): typed `Option<u64>` here but `Option<String>` in
    /// `EvalConfig::capture_samples`; confirm which type the config producer
    /// actually emits.
    pub capture_samples: Option<u64>,
    pub capture_logits: bool,
    pub capture_samples_subsample: u64,
    pub trial: Option<String>,
    pub data_dir: String,
    pub cuda_empty_cache: bool,
}
/// Environment selection, used both for training (`TrainConfig::env`) and
/// evaluation (`EvalConfig::env`).
///
/// Derives `Clone` so it matches the other config structs in this file that
/// already do (`RogueNetConfig`, `TranslationConfig`, `RelposEncodingConfig`).
/// All fields are required when deserializing (no `#[serde(default)]`).
#[derive(Debug, Deserialize, Serialize, Clone)]
pub struct EnvConfig {
    /// Extra arguments kept as one raw string rather than a map. Presumably
    /// an encoded mapping (e.g. JSON), to be confirmed with the consumer.
    pub kwargs: String,
    pub id: String,
    pub validate: bool,
}
/// Network architecture settings; used for `TrainConfig::net` and the
/// optional `TrainConfig::vf_net`.
///
/// `pooling`, `relpos_encoding` and `translation` are optional and
/// deserialize as `None` when omitted. Every other field is required.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct RogueNetConfig {
    // NOTE(review): the three `*_pdrop` values look like dropout
    // probabilities (inferred from the names) — confirm the expected range.
    pub embd_pdrop: f64,
    pub resid_pdrop: f64,
    pub attn_pdrop: f64,
    pub n_layer: u32,
    pub n_head: u32,
    pub d_model: u32,
    /// Free-form string; the accepted values are not constrained here.
    pub pooling: Option<String>,
    pub relpos_encoding: Option<RelposEncodingConfig>,
    pub d_qk: u32,
    pub translation: Option<TranslationConfig>,
}
/// Optional translation settings nested under `RogueNetConfig::translation`.
///
/// `rotation_vec_features` and `rotation_angle_feature` may be omitted
/// (deserialized as `None`). The remaining fields are required.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct TranslationConfig {
    /// Name of an entity type, stored as a plain string.
    pub reference_entity: String,
    /// Feature names, stored as plain strings.
    pub position_features: Vec<String>,
    pub rotation_vec_features: Option<Vec<String>>,
    pub rotation_angle_feature: Option<String>,
    pub add_dist_feature: bool,
}
/// Optimizer settings for `TrainConfig::optim`.
///
/// Derives `Clone` for consistency with the other config structs in this file
/// that already do. Only `micro_bs` is optional (deserialized as `None` when
/// omitted); all other fields are required.
#[derive(Debug, Deserialize, Serialize, Clone)]
pub struct OptimizerConfig {
    pub lr: f64,
    pub bs: u32,
    pub weight_decay: f64,
    /// NOTE(review): presumably a micro-batch size used to split `bs`
    /// (inferred from the name) — confirm.
    pub micro_bs: Option<u32>,
    pub anneal_lr: bool,
    pub update_epochs: u32,
    pub max_grad_norm: f64,
}
/// PPO algorithm hyperparameters for `TrainConfig::ppo`.
///
/// Derives `Clone` for consistency with the other config structs in this file
/// that already do. Only `target_kl` is optional (deserialized as `None` when
/// omitted); all other fields are required.
#[derive(Debug, Deserialize, Serialize, Clone)]
pub struct PPOConfig {
    pub gae: bool,
    pub gamma: f64,
    pub gae_lambda: f64,
    pub norm_adv: bool,
    pub clip_coef: f64,
    pub clip_vloss: bool,
    pub ent_coef: f64,
    pub vf_coef: f64,
    pub target_kl: Option<f64>,
    pub anneal_entropy: bool,
}
/// Rollout settings for `TrainConfig::rollout`.
///
/// Derives `Clone` for consistency with the other config structs in this file
/// that already do. All fields are required when deserializing.
#[derive(Debug, Deserialize, Serialize, Clone)]
pub struct RolloutConfig {
    pub steps: u32,
    /// NOTE(review): `u32` here but `u64` in `EvalConfig::num_envs`.
    pub num_envs: u32,
    /// NOTE(review): required `u32` here but `Option<u32>` in
    /// `EvalConfig::processes`.
    pub processes: u32,
}
/// Optional evaluation settings for `TrainConfig::eval`.
///
/// Carries its own `EnvConfig`, which is required here. Only `processes` and
/// `capture_samples` are optional (deserialized as `None` when omitted).
#[derive(Serialize, Deserialize)]
#[derive(Debug)]
pub struct EvalConfig {
    pub steps: u64,
    pub interval: u64,
    /// NOTE(review): `u64` here but `u32` in `RolloutConfig::num_envs`.
    pub num_envs: u64,
    pub processes: Option<u32>,
    pub env: EnvConfig,
    pub capture_videos: bool,
    /// NOTE(review): `Option<String>` here but `Option<u64>` in
    /// `TrainConfig::capture_samples`; confirm the intended type.
    pub capture_samples: Option<String>,
    pub capture_logits: bool,
    pub capture_samples_subsample: u64,
    pub run_on_first_step: bool,
    /// Required plain string (not optional in this struct).
    pub opponent: String,
    pub opponent_only: bool,
}
/// Relative-position encoding settings nested under
/// `RogueNetConfig::relpos_encoding`.
///
/// Only `rotation_vec_features` and `rotation_angle_feature` are optional
/// (deserialized as `None` when omitted); every other field is required.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct RelposEncodingConfig {
    pub extent: Vec<u32>,
    /// Feature names, stored as plain strings.
    pub position_features: Vec<String>,
    pub scale: f32,
    pub per_entity_values: bool,
    /// Entity names, stored as plain strings.
    pub exclude_entities: Vec<String>,
    pub value_relpos_projection: bool,
    pub key_relpos_projection: bool,
    pub per_entity_projections: bool,
    pub radial: bool,
    pub distance: bool,
    pub rotation_vec_features: Option<Vec<String>>,
    pub rotation_angle_feature: Option<String>,
    pub interpolate: bool,
    /// Free-form string; the accepted values are not constrained here.
    pub value_gate: String,
}