use serde::{Deserialize, Serialize};
use crate::constants::*;
/// Hyper-parameters for the sensor encoder.
///
/// `num_patches` derives the patch count from the patch grid: `time_steps` is
/// split into chunks of `patch_h` and `num_channels` into groups of `patch_w`.
/// The remaining fields size the transformer. `Default` fills the fields from
/// crate-level constants plus a few literals.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SensorEncoderConfig {
/// Length of the time axis; divided by `patch_h` in `num_patches`.
pub time_steps: usize,
/// Number of sensor channels; grouped `patch_w` at a time (rounded up) in `num_patches`.
pub num_channels: usize,
/// Patch extent along the time axis.
pub patch_h: usize,
/// Patch extent along the channel axis.
pub patch_w: usize,
/// Transformer hidden width.
pub d_model: usize,
/// Number of transformer layers.
pub depth: usize,
/// Number of attention heads.
pub num_heads: usize,
/// Hidden size of the feed-forward sub-layer (`ModelSize` presets use `4 * d_model`).
pub mlp_dim: usize,
/// Dropout rate; 0.0 in `Default`.
pub dropout: f64,
/// Pooling strategy applied by the encoder (see `PoolType`); `Map` in `Default`.
pub pool_type: PoolType,
/// Presumably toggles zero-initialisation of the output head — confirm in the model code.
pub head_zeroinit: bool,
/// Chunk size for attention; 64 in `Default`.
/// NOTE(review): what is chunked (queries/keys) is not visible in this file.
pub attn_chunk_size: usize,
}
impl Default for SensorEncoderConfig {
fn default() -> Self {
Self {
time_steps: TIME_STEPS,
num_channels: NUM_CHANNELS,
patch_h: PATCH_H,
patch_w: PATCH_W,
d_model: VIT_WIDTH,
depth: VIT_DEPTH,
num_heads: VIT_HEADS,
mlp_dim: VIT_MLP_DIM,
dropout: 0.0,
pool_type: PoolType::Map,
head_zeroinit: false,
attn_chunk_size: 64,
}
}
}
impl SensorEncoderConfig {
    /// Number of patches produced for one input.
    ///
    /// The two axes are rounded differently:
    /// - The time axis is floor-divided by `patch_h`, so trailing time steps that
    ///   do not fill a whole patch are not counted.
    /// - The channel axis is ceil-divided by `patch_w`, so a partial final channel
    ///   group counts as a full patch.
    ///
    /// The result is the product of the two counts.
    ///
    /// NOTE(review): confirm that the patchifier really truncates time and pads
    /// channels; otherwise this count will disagree with the actual token count.
    ///
    /// # Panics
    ///
    /// Panics if `patch_h` or `patch_w` is zero.
    pub fn num_patches(&self) -> usize {
        assert!(
            self.patch_h > 0 && self.patch_w > 0,
            "patch_h and patch_w must be non-zero (got patch_h={}, patch_w={})",
            self.patch_h,
            self.patch_w
        );
        let patches_time = self.time_steps / self.patch_h;
        // Ceiling division without the `num_channels + patch_w - 1` form, which can overflow.
        let patches_channel =
            self.num_channels / self.patch_w + usize::from(self.num_channels % self.patch_w != 0);
        patches_time * patches_channel
    }
}
/// Preset model scales. The methods on `ModelSize` map each preset to concrete
/// transformer dimensions and to encoder / SensorLM configurations.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
pub enum ModelSize {
/// Width 192, 3 heads; the `Default` preset.
#[default]
Tiny,
/// Width 384, 6 heads.
Small,
/// Width `VIT_WIDTH`, `VIT_HEADS` heads.
Base,
}
impl ModelSize {
pub fn d_model(self) -> usize {
match self {
Self::Tiny => 192,
Self::Small => 384,
Self::Base => VIT_WIDTH, }
}
pub fn depth(self) -> usize {
12 }
pub fn num_heads(self) -> usize {
match self {
Self::Tiny => 3,
Self::Small => 6,
Self::Base => VIT_HEADS, }
}
pub fn mlp_dim(self) -> usize {
self.d_model() * 4
}
pub fn sensor_encoder_config(self) -> SensorEncoderConfig {
SensorEncoderConfig {
d_model: self.d_model(),
depth: self.depth(),
num_heads: self.num_heads(),
mlp_dim: self.mlp_dim(),
attn_chunk_size: 64, ..SensorEncoderConfig::default()
}
}
pub fn text_encoder_config(self) -> TextEncoderConfig {
TextEncoderConfig {
d_model: self.d_model(),
depth: self.depth(),
num_heads: self.num_heads(),
mlp_dim: self.mlp_dim(),
out_dim: Some(self.d_model()), ..TextEncoderConfig::default()
}
}
pub fn sensorlm_config(self) -> SensorLMConfig {
SensorLMConfig {
sensor_encoder: self.sensor_encoder_config(),
text_encoder: self.text_encoder_config(),
embed_dim: self.d_model(),
..SensorLMConfig::default()
}
}
pub fn approx_params(self) -> &'static str {
match self {
Self::Tiny => "~11 M",
Self::Small => "~44 M",
Self::Base => "~205 M",
}
}
}
/// Pooling strategy selected by `SensorEncoderConfig::pool_type`.
///
/// NOTE(review): the pooling itself is implemented outside this file. `Map` / `Gap`
/// presumably mean multihead-attention pooling / global average pooling — confirm.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum PoolType {
/// Default choice in `SensorEncoderConfig::default()`.
Map,
/// Alternative pooling mode.
Gap,
}
/// Hyper-parameters for the text encoder.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TextEncoderConfig {
/// Tokenizer vocabulary size; `VOCAB_SIZE` by default.
pub vocab_size: usize,
/// Maximum token sequence length.
/// The default of 1024 equals the largest `CaptionKey::max_tokens` value.
pub max_seq_len: usize,
/// Transformer hidden width.
pub d_model: usize,
/// Number of transformer layers.
pub depth: usize,
/// Number of attention heads.
pub num_heads: usize,
/// Hidden size of the feed-forward sub-layer.
pub mlp_dim: usize,
/// Dropout rate; 0.0 by default.
pub dropout: f64,
/// Optional output width (`Some(EMBED_DIM)` by default; `ModelSize` presets use `Some(d_model)`).
/// NOTE(review): how `None` is handled is not visible here.
pub out_dim: Option<usize>,
}
impl Default for TextEncoderConfig {
fn default() -> Self {
Self {
vocab_size: VOCAB_SIZE,
max_seq_len: 1024,
d_model: VIT_WIDTH,
depth: VIT_DEPTH,
num_heads: VIT_HEADS,
mlp_dim: VIT_MLP_DIM,
dropout: 0.0,
out_dim: Some(EMBED_DIM),
}
}
}
/// Top-level model configuration: a sensor encoder paired with a text encoder.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SensorLMConfig {
/// Sensor-side encoder settings.
pub sensor_encoder: SensorEncoderConfig,
/// Text-side encoder settings.
pub text_encoder: TextEncoderConfig,
/// Shared embedding width (`EMBED_DIM` by default; `d_model` for `ModelSize` presets).
pub embed_dim: usize,
/// Initial temperature (presumably for scaling similarity logits — confirm); `TEMPERATURE_INIT` by default.
pub temperature_init: f32,
/// Initial bias (presumably added to similarity logits — confirm); `BIAS_INIT` by default.
pub bias_init: f32,
}
impl Default for SensorLMConfig {
fn default() -> Self {
Self {
sensor_encoder: SensorEncoderConfig::default(),
text_encoder: TextEncoderConfig::default(),
embed_dim: EMBED_DIM,
temperature_init: TEMPERATURE_INIT,
bias_init: BIAS_INIT,
}
}
}
/// Learning-rate schedule selected by `TrainingConfig::lr_schedule`.
///
/// `TrainingConfig::default()` selects `RsqrtWithWarmupCooldown`.
// `Copy`/`PartialEq`/`Eq` match the other selector enums in this module
// (`ModelSize`, `PoolType`, `CaptionKey`), so schedules can be compared with `==`
// and passed by value. Adding these derives is backward-compatible.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum LrScheduleType {
    /// Reciprocal-square-root schedule with warmup and cooldown phases (per the
    /// variant name; phase lengths presumably come from `warmup_fraction` /
    /// `cooldown_fraction` — confirm in the scheduler).
    RsqrtWithWarmupCooldown,
    /// Cosine schedule.
    Cosine,
    /// Constant learning rate.
    Constant,
}
/// Settings for a training run.
///
/// `Default` sets `total_steps` to `TOTAL_EXAMPLES / DEFAULT_BATCH_SIZE` and takes
/// most optimiser settings from crate-level constants.
///
/// Field order is kept as declared (including `beta2` before `beta1`) because
/// serde's derived impls follow declaration order, and order-sensitive formats
/// depend on it.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TrainingConfig {
/// Preset model scale to train.
pub model_size: ModelSize,
/// Number of optimiser steps.
pub total_steps: usize,
/// Batch size; `DEFAULT_BATCH_SIZE` by default.
pub batch_size: usize,
/// Base learning rate (the schedule presumably scales it — confirm); `DEFAULT_LR` by default.
pub lr: f64,
/// Weight-decay coefficient; `DEFAULT_WD` by default.
pub weight_decay: f64,
/// Adam-style second-moment decay; `ADAM_BETA2` by default.
pub beta2: f64,
/// Adam-style first-moment decay; 0.9 by default.
pub beta1: f64,
/// Optimiser epsilon; 1e-8 by default.
pub epsilon: f64,
/// Gradient-norm clipping threshold; `GRAD_CLIP_NORM` by default.
pub grad_clip_norm: f64,
/// Warmup length as a fraction (presumably of `total_steps` — confirm); 0.2 by default.
pub warmup_fraction: f64,
/// Cooldown length as a fraction (presumably of `total_steps` — confirm); 0.2 by default.
pub cooldown_fraction: f64,
/// Learning-rate schedule shape.
pub lr_schedule: LrScheduleType,
/// Checkpoint interval (presumably in steps); 500 by default.
pub checkpoint_every: usize,
/// Logging interval (presumably in steps); 50 by default.
pub log_every: usize,
/// Random seed; 0 by default.
pub seed: u64,
/// Which caption variant to use; `HighLevelSummary` by default.
pub caption_key: CaptionKey,
/// Path to the tokenizer model file.
pub tokenizer_path: String,
/// Output directory; `./artifacts` by default, the same directory the default
/// inference checkpoint path points into.
pub artifact_dir: String,
/// Input data directory; `./data` by default.
pub data_dir: String,
/// Number of workers (presumably for data loading — confirm); 2 by default.
pub num_workers: usize,
/// Optional VRAM amount in GB; `None` by default.
/// NOTE(review): how `None` is interpreted is not visible here.
pub vram_gb: Option<f64>,
/// When true, skip the VRAM check.
pub skip_vram_check: bool,
/// When true, show a summary (presumably of the model/config — confirm).
pub show_summary: bool,
}
impl Default for TrainingConfig {
fn default() -> Self {
let total_examples = TOTAL_EXAMPLES;
let batch_size = DEFAULT_BATCH_SIZE;
let total_steps = total_examples / batch_size;
Self {
model_size: ModelSize::default(), total_steps,
batch_size,
lr: DEFAULT_LR,
weight_decay: DEFAULT_WD,
beta1: 0.9,
beta2: ADAM_BETA2,
epsilon: 1e-8,
grad_clip_norm: GRAD_CLIP_NORM,
warmup_fraction: 0.2,
cooldown_fraction: 0.2,
lr_schedule: LrScheduleType::RsqrtWithWarmupCooldown,
checkpoint_every: 500,
log_every: 50,
seed: 0,
caption_key: CaptionKey::HighLevelSummary,
tokenizer_path: "tokenizer.model".to_string(),
artifact_dir: "./artifacts".to_string(),
data_dir: "./data".to_string(),
num_workers: 2,
vram_gb: None,
skip_vram_check: false,
show_summary: false,
}
}
}
/// Selects which caption text is used; `CaptionKey::max_tokens` gives each
/// choice's token budget.
///
/// NOTE(review): the combined variants (`MiddleLow`, `HighLow`, ...) presumably
/// concatenate the corresponding levels — confirm in the data pipeline.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum CaptionKey {
/// 512-token budget.
LowLevel,
/// 512-token budget.
MiddleLevel,
/// 256-token budget; the default for training and inference configs.
HighLevelSummary,
/// 1024-token budget.
HighLevelAll,
/// 1024-token budget.
MiddleLow,
/// 1024-token budget.
HighLow,
/// 512-token budget.
HighMiddle,
/// 1024-token budget.
HighMiddleLow,
}
impl CaptionKey {
pub fn max_tokens(self) -> usize {
match self {
Self::LowLevel => 512,
Self::MiddleLevel => 512,
Self::HighLevelSummary => 256,
Self::HighLevelAll => 1024,
Self::MiddleLow => 1024,
Self::HighLow => 1024,
Self::HighMiddle => 512,
Self::HighMiddleLow => 1024,
}
}
}
/// Settings for running a trained model.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct InferenceConfig {
/// Path to the model checkpoint; `./artifacts/model_final.bin` by default.
pub checkpoint: String,
/// Path to the tokenizer model file.
pub tokenizer_path: String,
/// Maximum token sequence length; the default of 256 equals
/// `CaptionKey::HighLevelSummary.max_tokens()`.
pub max_seq_len: usize,
/// Batch size; 64 by default.
pub batch_size: usize,
/// When true, presumably runs in half precision — confirm; false by default.
pub fp16: bool,
/// Which caption variant to use; `HighLevelSummary` by default.
pub caption_key: CaptionKey,
}
impl Default for InferenceConfig {
fn default() -> Self {
Self {
checkpoint: "./artifacts/model_final.bin".to_string(),
tokenizer_path: "tokenizer.model".to_string(),
max_seq_len: 256,
batch_size: 64,
fp16: false,
caption_key: CaptionKey::HighLevelSummary,
}
}
}
/// Quantisation scheme selected by `QuantizationConfig::scheme`.
///
/// `QuantizationConfig::default()` selects `SymmetricPerTensor`.
// `Copy`/`PartialEq`/`Eq` match the other selector enums in this module
// (`ModelSize`, `PoolType`, `CaptionKey`), so schemes can be compared with `==`
// and passed by value. Adding these derives is backward-compatible.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum QuantScheme {
    /// Symmetric range, one set of quantisation parameters per tensor (per the
    /// variant name; the math lives in the quantiser).
    SymmetricPerTensor,
    /// Asymmetric range, one set of quantisation parameters per tensor.
    AsymmetricPerTensor,
    /// Symmetric range, one set of quantisation parameters per channel.
    SymmetricPerChannel,
}
/// Settings for producing a quantised checkpoint from a trained one.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QuantizationConfig {
/// Checkpoint to quantise; `./artifacts/model_final.bin` by default.
pub source_checkpoint: String,
/// Where the quantised checkpoint is written; `./artifacts/model_int8.bin` by default.
pub output_path: String,
/// Calibration data location; `./data/calibration.parquet` by default.
pub calibration_data: String,
/// Number of calibration batches; 100 by default.
pub num_calibration_batches: usize,
/// Batch size used during calibration; 32 by default.
pub calibration_batch_size: usize,
/// Quantisation scheme; `SymmetricPerTensor` by default.
pub scheme: QuantScheme,
/// When true, the text encoder is quantised as well; true by default.
pub quantise_text_encoder: bool,
/// Path to the tokenizer model file.
pub tokenizer_path: String,
}
impl Default for QuantizationConfig {
fn default() -> Self {
Self {
source_checkpoint: "./artifacts/model_final.bin".to_string(),
output_path: "./artifacts/model_int8.bin".to_string(),
calibration_data: "./data/calibration.parquet".to_string(),
num_calibration_batches: 100,
calibration_batch_size: 32,
scheme: QuantScheme::SymmetricPerTensor,
quantise_text_encoder: true,
tokenizer_path: "tokenizer.model".to_string(),
}
}
}