/// Model hyper-parameters, deserialized through serde.
///
/// Fields that carry a `#[serde(default…)]` attribute fall back to the named
/// default when absent from the input; the remaining fields have no fallback
/// declared here.
///
/// NOTE(review): most field meanings below are inferred from their names only;
/// confirm against the code that consumes this struct.
#[derive(Debug, Clone, serde::Deserialize)]
pub struct ModelConfig {
    /// Model width. `n_heads_fallback` divides it by `head_dim`, and
    /// `ffn_hidden_dim` derives the feed-forward size from it.
    pub dim: usize,
    /// Presumably the number of layers; not used in this part of the file.
    pub n_layers: usize,
    /// Divisor applied to `dim` in `n_heads_fallback`.
    pub head_dim: usize,
    /// Presumably the input feature dimension; TODO confirm.
    pub input_dim: usize,
    /// Presumably the encoder output dimension; TODO confirm.
    pub encoder_output_dim: usize,
    /// Presumably the encoder's latent downsampling factor; TODO confirm.
    pub encoder_latent_downsample_factor: usize,
    /// Falls back to `default_t_dim()` when missing from the input.
    #[serde(default = "default_t_dim")]
    pub t_dim: usize,
    /// Presumably the maximum sequence length; TODO confirm.
    pub max_seqlen: usize,
    /// Presumably the rotary-embedding dimension; TODO confirm.
    pub rope_dim: usize,
    /// Presumably the rotary-embedding base (theta); TODO confirm.
    pub rope_theta: f64,
    /// Falls back to `default_norm_eps()` when missing from the input.
    #[serde(default = "default_norm_eps")]
    pub norm_eps: f64,
    /// Optional scale factor applied in `ffn_hidden_dim`; `None` when missing
    /// from the input.
    #[serde(default)]
    pub ffn_dim_multiplier: Option<f64>,
    /// Granularity that `ffn_hidden_dim` rounds up to. Falls back to
    /// `default_multiple_of()` when missing from the input.
    #[serde(default = "default_multiple_of")]
    pub multiple_of: usize,
    /// Presumably a global sigma for STFT features; TODO confirm.
    pub stft_global_sigma: f64,
}
/// Serde fallback for `ModelConfig::t_dim` (referenced by name in that
/// field's `#[serde(default = "…")]` attribute).
fn default_t_dim() -> usize {
    const T_DIM_FALLBACK: usize = 64;
    T_DIM_FALLBACK
}
/// Serde fallback for `ModelConfig::norm_eps` (referenced by name in that
/// field's `#[serde(default = "…")]` attribute).
fn default_norm_eps() -> f64 {
    const NORM_EPS_FALLBACK: f64 = 1e-5;
    NORM_EPS_FALLBACK
}
/// Serde fallback for `ModelConfig::multiple_of` (referenced by name in that
/// field's `#[serde(default = "…")]` attribute).
fn default_multiple_of() -> usize {
    const MULTIPLE_OF_FALLBACK: usize = 256;
    MULTIPLE_OF_FALLBACK
}
impl ModelConfig {
    /// Returns `dim / head_dim` using integer division.
    ///
    /// # Panics
    ///
    /// Panics if `head_dim` is zero (integer division by zero).
    pub fn n_heads_fallback(&self) -> usize {
        let (width, per_head) = (self.dim, self.head_dim);
        width / per_head
    }

    /// Computes the feed-forward hidden size.
    ///
    /// Steps, in order:
    /// 1. start from `8 * dim / 3` (integer division), i.e. two thirds of
    ///    `4 * dim`;
    /// 2. if `ffn_dim_multiplier` is set, multiply by it in `f64` and cast
    ///    back to `usize` (the cast drops the fractional part);
    /// 3. round the result up to the next multiple of `multiple_of`.
    ///
    /// # Panics
    ///
    /// Panics if `multiple_of` is zero (integer division by zero).
    pub fn ffn_hidden_dim(&self) -> usize {
        let base = (8 * self.dim) / 3;
        let scaled = match self.ffn_dim_multiplier {
            Some(factor) => (factor * base as f64) as usize,
            None => base,
        };

        // Ceiling division, then scale back up to get the rounded size.
        let step = self.multiple_of;
        let blocks = (scaled + step - 1) / step;
        step * blocks
    }
}
/// Inference-time settings.
///
/// NOTE(review): field meanings are inferred from their names; confirm
/// against the sampling code that reads them.
#[derive(Debug, Clone)]
pub struct InferConfig {
    /// Presumably the number of sampler iterations; TODO confirm.
    pub sample_steps: usize,
    /// Presumably a guidance scale; TODO confirm.
    pub cfg: f32,
    /// Presumably a data normalisation constant; TODO confirm.
    pub data_norm: f32,
}

impl Default for InferConfig {
    /// Builds the stock configuration: 50 sample steps, `cfg` of 1.0 and
    /// `data_norm` of 10.0.
    fn default() -> Self {
        let sample_steps = 50;
        let cfg = 1.0;
        let data_norm = 10.0;
        InferConfig {
            sample_steps,
            cfg,
            data_norm,
        }
    }
}
/// Data-side settings.
///
/// NOTE(review): field meanings and units are inferred from their names;
/// confirm against the data pipeline that reads them.
#[derive(Debug, Clone)]
pub struct DataConfig {
    /// Presumably the number of fine-grained time points; TODO confirm.
    pub num_fine_time_pts: usize,
    /// Presumably the number of bins; TODO confirm.
    pub num_bins: usize,
    /// Presumably the per-axis lower bound for x/y/z; TODO confirm.
    pub xyz_min: [f32; 3],
    /// Presumably the per-axis upper bound for x/y/z; TODO confirm.
    pub xyz_max: [f32; 3],
}

impl Default for DataConfig {
    /// Builds the stock configuration: 32 fine time points, 50 bins, and
    /// bounds of -0.12 to 0.12 on every axis.
    fn default() -> Self {
        // The default bounds are the same magnitude on every axis.
        const HALF_EXTENT: f32 = 0.12;
        DataConfig {
            num_fine_time_pts: 32,
            num_bins: 50,
            xyz_min: [-HALF_EXTENT; 3],
            xyz_max: [HALF_EXTENT; 3],
        }
    }
}