// rogue_net/config.rs

use serde::{Deserialize, Serialize};
2
3#[derive(Debug, Deserialize, Serialize)]
4pub struct TrainConfig {
5    pub version: u32,
6    pub env: EnvConfig,
7    pub net: RogueNetConfig,
8    pub optim: OptimizerConfig,
9    pub ppo: PPOConfig,
10    pub rollout: RolloutConfig,
11    pub eval: Option<EvalConfig>,
12    pub vf_net: Option<RogueNetConfig>,
13    pub name: String,
14    pub seed: u64,
15    pub total_timesteps: u64,
16    pub max_train_time: Option<u64>,
17    pub torch_deterministic: bool,
18    pub cuda: bool,
19    pub track: bool,
20    pub wandb_project_name: String,
21    pub wandb_entity: String,
22    pub capture_samples: Option<u64>,
23    pub capture_logits: bool,
24    pub capture_samples_subsample: u64,
25    pub trial: Option<String>,
26    pub data_dir: String,
27    pub cuda_empty_cache: bool,
28}
29
30#[derive(Debug, Deserialize, Serialize)]
31pub struct EnvConfig {
32    pub kwargs: String,
33    pub id: String,
34    pub validate: bool,
35}
36
37#[derive(Debug, Deserialize, Serialize, Clone)]
38/// Network architecture hyperparameters for RogueNet.
39pub struct RogueNetConfig {
40    /// Dropout probability for the embedding layer.
41    pub embd_pdrop: f64,
42    /// Dropout probability on attention block output.
43    pub resid_pdrop: f64,
44    /// Dropout probability on attention probabilities.
45    pub attn_pdrop: f64,
46    /// Number of transformer blocks.
47    pub n_layer: u32,
48    /// Number of attention heads.
49    pub n_head: u32,
50    /// Model width.
51    pub d_model: u32,
52    /// Replace attention with a pooling layer.
53    pub pooling: Option<String>,
54    /// Settings for relative position encoding.
55    pub relpos_encoding: Option<RelposEncodingConfig>,
56    /// Width of keys and queries used in entity-selection heads.
57    pub d_qk: u32,
58    /// Configuration for translating positions of all entities with respect
59    /// to a reference entity.
60    pub translation: Option<TranslationConfig>,
61}
62
63#[derive(Debug, Deserialize, Serialize, Clone)]
64pub struct TranslationConfig {
65    pub reference_entity: String,
66    pub position_features: Vec<String>,
67    pub rotation_vec_features: Option<Vec<String>>,
68    pub rotation_angle_feature: Option<String>,
69    pub add_dist_feature: bool,
70}
71
72#[derive(Debug, Deserialize, Serialize)]
73pub struct OptimizerConfig {
74    pub lr: f64,
75    pub bs: u32,
76    pub weight_decay: f64,
77    pub micro_bs: Option<u32>,
78    pub anneal_lr: bool,
79    pub update_epochs: u32,
80    pub max_grad_norm: f64,
81}
82
83#[derive(Debug, Deserialize, Serialize)]
84pub struct PPOConfig {
85    pub gae: bool,
86    pub gamma: f64,
87    pub gae_lambda: f64,
88    pub norm_adv: bool,
89    pub clip_coef: f64,
90    pub clip_vloss: bool,
91    pub ent_coef: f64,
92    pub vf_coef: f64,
93    pub target_kl: Option<f64>,
94    pub anneal_entropy: bool,
95}
96
97#[derive(Debug, Deserialize, Serialize)]
98pub struct RolloutConfig {
99    pub steps: u32,
100    pub num_envs: u32,
101    pub processes: u32,
102}
103
104#[derive(Debug, Deserialize, Serialize)]
105pub struct EvalConfig {
106    pub steps: u64,
107    pub interval: u64,
108    pub num_envs: u64,
109    pub processes: Option<u32>,
110    pub env: EnvConfig,
111    pub capture_videos: bool,
112    pub capture_samples: Option<String>,
113    pub capture_logits: bool,
114    pub capture_samples_subsample: u64,
115    pub run_on_first_step: bool,
116    pub opponent: String,
117    pub opponent_only: bool,
118}
119
120#[derive(Debug, Deserialize, Serialize, Clone)]
121pub struct RelposEncodingConfig {
122    pub extent: Vec<u32>,
123    pub position_features: Vec<String>,
124    pub scale: f32,
125    pub per_entity_values: bool,
126    pub exclude_entities: Vec<String>,
127    pub value_relpos_projection: bool,
128    pub key_relpos_projection: bool,
129    pub per_entity_projections: bool,
130    pub radial: bool,
131    pub distance: bool,
132    pub rotation_vec_features: Option<Vec<String>>,
133    pub rotation_angle_feature: Option<String>,
134    pub interpolate: bool,
135    pub value_gate: String,
136}