1#[derive(Debug, Clone, serde::Deserialize)]
9pub struct ModelConfig {
10 #[serde(default = "default_patch_size")]
12 pub patch_size: usize,
13
14 #[serde(default = "default_num_queries")]
16 pub num_queries: usize,
17
18 #[serde(default = "default_embed_dim")]
20 pub embed_dim: usize,
21
22 #[serde(default = "default_depth")]
24 pub depth: usize,
25
26 #[serde(default = "default_num_heads")]
30 pub num_heads: usize,
31
32 #[serde(default = "default_mlp_ratio")]
34 pub mlp_ratio: f64,
35
36 #[serde(default)]
38 pub num_classes: usize,
39
40 #[serde(default)]
42 pub drop_path: f64,
43
44 #[serde(default = "default_norm_eps")]
46 pub norm_eps: f64,
47}
48
/// Fallback for `ModelConfig::patch_size` (used by serde and by `Default`).
fn default_patch_size() -> usize {
    40
}
/// Fallback for `ModelConfig::num_queries` (used by serde and by `Default`).
fn default_num_queries() -> usize {
    4
}
/// Fallback for `ModelConfig::embed_dim` (used by serde and by `Default`).
fn default_embed_dim() -> usize {
    64
}
/// Fallback for `ModelConfig::depth` (used by serde and by `Default`).
fn default_depth() -> usize {
    8
}
/// Fallback for `ModelConfig::num_heads` (used by serde and by `Default`).
fn default_num_heads() -> usize {
    2
}
/// Fallback for `ModelConfig::mlp_ratio` (used by serde and by `Default`).
fn default_mlp_ratio() -> f64 {
    4.0
}
/// Fallback for `ModelConfig::norm_eps` (used by serde and by `Default`).
fn default_norm_eps() -> f64 {
    1e-5
}
56
57impl Default for ModelConfig {
58 fn default() -> Self {
59 Self {
60 patch_size: default_patch_size(),
61 num_queries: default_num_queries(),
62 embed_dim: default_embed_dim(),
63 depth: default_depth(),
64 num_heads: default_num_heads(),
65 mlp_ratio: default_mlp_ratio(),
66 num_classes: 0,
67 drop_path: 0.0,
68 norm_eps: default_norm_eps(),
69 }
70 }
71}
72
73impl ModelConfig {
74 pub fn hidden_dim(&self) -> usize {
76 self.embed_dim * self.num_queries
77 }
78
79 pub fn ffn_hidden_dim(&self) -> usize {
81 (self.hidden_dim() as f64 * self.mlp_ratio) as usize
82 }
83
84 pub fn head_dim(&self) -> usize {
86 self.hidden_dim() / (self.num_heads * self.num_queries)
87 }
88
89 pub fn total_heads(&self) -> usize {
91 self.num_heads * self.num_queries
92 }
93}
94
/// Parameters describing the input data.
///
/// NOTE(review): units are not stated anywhere in this file; the notes on
/// each field are assumptions to confirm against the data-loading code.
#[derive(Debug, Clone)]
pub struct DataConfig {
    /// Sampling rate (presumably in Hz — confirm).
    pub sample_rate: f32,
    /// Duration of one epoch (presumably in seconds — confirm).
    pub epoch_dur: f32,
    /// Lower bound for each of the three coordinates, in `[x, y, z]` order
    /// judging by the field name.
    pub xyz_min: [f32; 3],
    /// Upper bound for each of the three coordinates, in `[x, y, z]` order
    /// judging by the field name.
    pub xyz_max: [f32; 3],
}
107
108impl Default for DataConfig {
109 fn default() -> Self {
110 Self {
111 sample_rate: 256.0,
112 epoch_dur: 5.0,
113 xyz_min: [-0.12, -0.12, -0.12],
114 xyz_max: [ 0.12, 0.12, 0.12],
115 }
116 }
117}
118
119impl DataConfig {
120 pub fn epoch_samples(&self) -> usize {
122 (self.sample_rate * self.epoch_dur) as usize
123 }
124}