Skip to main content

luna_rs/
config.rs

//! Model and runtime configuration for LUNA inference.
//!
//! `ModelConfig` mirrors the Python LUNA hyperparameters.
//! Field names match the HuggingFace `config.json` `"model"` sub-object.

// ── ModelConfig ───────────────────────────────────────────────────────────────

8#[derive(Debug, Clone, serde::Deserialize)]
9pub struct ModelConfig {
10    /// Patch size in time-samples (default 40).
11    #[serde(default = "default_patch_size")]
12    pub patch_size: usize,
13
14    /// Number of learned cross-attention queries (default 4).
15    #[serde(default = "default_num_queries")]
16    pub num_queries: usize,
17
18    /// Per-query / per-channel embedding dimension (default 64).
19    #[serde(default = "default_embed_dim")]
20    pub embed_dim: usize,
21
22    /// Number of Rotary Transformer encoder blocks (default 8).
23    #[serde(default = "default_depth")]
24    pub depth: usize,
25
26    /// Number of attention heads per transformer block.
27    /// Actual head count in the temporal encoder is `num_heads * num_queries`
28    /// because the effective dim is `embed_dim * num_queries`.
29    #[serde(default = "default_num_heads")]
30    pub num_heads: usize,
31
32    /// MLP expansion ratio inside transformer blocks (default 4.0).
33    #[serde(default = "default_mlp_ratio")]
34    pub mlp_ratio: f64,
35
36    /// Number of output classes.  0 = reconstruction (pre-training).
37    #[serde(default)]
38    pub num_classes: usize,
39
40    /// Drop-path rate for stochastic depth (default 0.0).
41    #[serde(default)]
42    pub drop_path: f64,
43
44    /// Layer normalisation epsilon (default 1e-5).
45    #[serde(default = "default_norm_eps")]
46    pub norm_eps: f64,
47}
48
// Serde fallbacks for `ModelConfig`: each returns the default documented on
// the matching field and is also reused by `ModelConfig::default`.

fn default_patch_size() -> usize {
    40
}

fn default_num_queries() -> usize {
    4
}

fn default_embed_dim() -> usize {
    64
}

fn default_depth() -> usize {
    8
}

fn default_num_heads() -> usize {
    2
}

fn default_mlp_ratio() -> f64 {
    4.0
}

fn default_norm_eps() -> f64 {
    1e-5
}

57impl Default for ModelConfig {
58    fn default() -> Self {
59        Self {
60            patch_size:  default_patch_size(),
61            num_queries: default_num_queries(),
62            embed_dim:   default_embed_dim(),
63            depth:       default_depth(),
64            num_heads:   default_num_heads(),
65            mlp_ratio:   default_mlp_ratio(),
66            num_classes: 0,
67            drop_path:   0.0,
68            norm_eps:    default_norm_eps(),
69        }
70    }
71}
72
73impl ModelConfig {
74    /// Effective hidden dimension after query concatenation: `embed_dim * num_queries`.
75    pub fn hidden_dim(&self) -> usize {
76        self.embed_dim * self.num_queries
77    }
78
79    /// FFN hidden dimension inside transformer blocks.
80    pub fn ffn_hidden_dim(&self) -> usize {
81        (self.hidden_dim() as f64 * self.mlp_ratio) as usize
82    }
83
84    /// Attention head dimension in the temporal encoder.
85    pub fn head_dim(&self) -> usize {
86        self.hidden_dim() / (self.num_heads * self.num_queries)
87    }
88
89    /// Total number of attention heads in the temporal encoder.
90    pub fn total_heads(&self) -> usize {
91        self.num_heads * self.num_queries
92    }
93}
94
// ── DataConfig ────────────────────────────────────────────────────────────────

/// Data-side parameters: sampling rate, epoch length and the bounding box
/// used to normalise channel positions.
#[derive(Debug, Clone, PartialEq)]
pub struct DataConfig {
    /// Sampling rate after resampling (Hz).
    pub sample_rate: f32,
    /// Epoch duration in seconds.
    pub epoch_dur: f32,
    /// Lower corner of the bounding box for channel position normalisation
    /// (metres), one value per axis.
    pub xyz_min: [f32; 3],
    /// Upper corner of the bounding box for channel position normalisation
    /// (metres), one value per axis.
    pub xyz_max: [f32; 3],
}

108impl Default for DataConfig {
109    fn default() -> Self {
110        Self {
111            sample_rate: 256.0,
112            epoch_dur:   5.0,
113            xyz_min: [-0.12, -0.12, -0.12],
114            xyz_max: [ 0.12,  0.12,  0.12],
115        }
116    }
117}
118
119impl DataConfig {
120    /// Number of time samples per epoch.
121    pub fn epoch_samples(&self) -> usize {
122        (self.sample_rate * self.epoch_dur) as usize
123    }
124}