use serde::{Deserialize, Serialize};

3#[derive(Debug, Deserialize, Serialize)]
4pub struct TrainConfig {
5 pub version: u32,
6 pub env: EnvConfig,
7 pub net: RogueNetConfig,
8 pub optim: OptimizerConfig,
9 pub ppo: PPOConfig,
10 pub rollout: RolloutConfig,
11 pub eval: Option<EvalConfig>,
12 pub vf_net: Option<RogueNetConfig>,
13 pub name: String,
14 pub seed: u64,
15 pub total_timesteps: u64,
16 pub max_train_time: Option<u64>,
17 pub torch_deterministic: bool,
18 pub cuda: bool,
19 pub track: bool,
20 pub wandb_project_name: String,
21 pub wandb_entity: String,
22 pub capture_samples: Option<u64>,
23 pub capture_logits: bool,
24 pub capture_samples_subsample: u64,
25 pub trial: Option<String>,
26 pub data_dir: String,
27 pub cuda_empty_cache: bool,
28}
29
30#[derive(Debug, Deserialize, Serialize)]
31pub struct EnvConfig {
32 pub kwargs: String,
33 pub id: String,
34 pub validate: bool,
35}
36
37#[derive(Debug, Deserialize, Serialize, Clone)]
38pub struct RogueNetConfig {
40 pub embd_pdrop: f64,
42 pub resid_pdrop: f64,
44 pub attn_pdrop: f64,
46 pub n_layer: u32,
48 pub n_head: u32,
50 pub d_model: u32,
52 pub pooling: Option<String>,
54 pub relpos_encoding: Option<RelposEncodingConfig>,
56 pub d_qk: u32,
58 pub translation: Option<TranslationConfig>,
61}
62
63#[derive(Debug, Deserialize, Serialize, Clone)]
64pub struct TranslationConfig {
65 pub reference_entity: String,
66 pub position_features: Vec<String>,
67 pub rotation_vec_features: Option<Vec<String>>,
68 pub rotation_angle_feature: Option<String>,
69 pub add_dist_feature: bool,
70}
71
72#[derive(Debug, Deserialize, Serialize)]
73pub struct OptimizerConfig {
74 pub lr: f64,
75 pub bs: u32,
76 pub weight_decay: f64,
77 pub micro_bs: Option<u32>,
78 pub anneal_lr: bool,
79 pub update_epochs: u32,
80 pub max_grad_norm: f64,
81}
82
83#[derive(Debug, Deserialize, Serialize)]
84pub struct PPOConfig {
85 pub gae: bool,
86 pub gamma: f64,
87 pub gae_lambda: f64,
88 pub norm_adv: bool,
89 pub clip_coef: f64,
90 pub clip_vloss: bool,
91 pub ent_coef: f64,
92 pub vf_coef: f64,
93 pub target_kl: Option<f64>,
94 pub anneal_entropy: bool,
95}
96
97#[derive(Debug, Deserialize, Serialize)]
98pub struct RolloutConfig {
99 pub steps: u32,
100 pub num_envs: u32,
101 pub processes: u32,
102}
103
104#[derive(Debug, Deserialize, Serialize)]
105pub struct EvalConfig {
106 pub steps: u64,
107 pub interval: u64,
108 pub num_envs: u64,
109 pub processes: Option<u32>,
110 pub env: EnvConfig,
111 pub capture_videos: bool,
112 pub capture_samples: Option<String>,
113 pub capture_logits: bool,
114 pub capture_samples_subsample: u64,
115 pub run_on_first_step: bool,
116 pub opponent: String,
117 pub opponent_only: bool,
118}
119
120#[derive(Debug, Deserialize, Serialize, Clone)]
121pub struct RelposEncodingConfig {
122 pub extent: Vec<u32>,
123 pub position_features: Vec<String>,
124 pub scale: f32,
125 pub per_entity_values: bool,
126 pub exclude_entities: Vec<String>,
127 pub value_relpos_projection: bool,
128 pub key_relpos_projection: bool,
129 pub per_entity_projections: bool,
130 pub radial: bool,
131 pub distance: bool,
132 pub rotation_vec_features: Option<Vec<String>>,
133 pub rotation_angle_feature: Option<String>,
134 pub interpolate: bool,
135 pub value_gate: String,
136}