burn_dragon_language 0.5.0

Language modeling components for burn_dragon
Documentation
use std::path::PathBuf;

use serde::{Deserialize, Serialize};

use crate::tokenizer::TokenizerConfig;
use burn_dragon_core::{
    AttentionResidualConfig, BdhInitializationConfig, BlockAttentionResidualConfig,
    ClockedSlowMemoryConfig, DragonNormConfig, LanguageHeadConfig, LatentFanoutScheduleConfig,
    LowBitQuantizationConfig, LowBitRhoConfig, MambaSequenceConfig, ManifoldHyperConnectionsConfig,
    ResidualConnectorKind, RotaryEmbedding, SequenceKernelConfig, SummaryMemoryConfig,
    YNeuronRecurrenceConfig,
};

/// Settings for a text-generation run, loadable through serde.
///
/// `prompt` is the only field without a serde default; every other field
/// falls back to the default noted on it when absent from the input.
#[derive(Debug, Clone, Deserialize, Serialize, PartialEq)]
pub struct GenerationConfig {
    /// Prompt text. Carries no `#[serde(default)]` attribute.
    pub prompt: String,
    /// Optional token limit; defaults to `None`.
    // NOTE(review): the type is signed (`i64`); how zero or negative values
    // are interpreted is decided by the consumer and is not visible here.
    #[serde(default)]
    pub max_tokens: Option<i64>,
    /// Optional character limit; defaults to `None`.
    // NOTE(review): presumably bounds the decoded output length — confirm
    // against the generation loop.
    #[serde(default)]
    pub max_chars: Option<usize>,
    /// Sampling temperature; defaults to the value of `default_temperature`
    /// (`1.0`).
    #[serde(default = "default_temperature")]
    pub temperature: f32,
    /// Optional top-k setting; defaults to `None`.
    #[serde(default)]
    pub top_k: Option<usize>,
    /// Context handling strategy; defaults to the value of
    /// `default_context_strategy` (`ContextStrategyConfig::Infinite`).
    #[serde(default = "default_context_strategy")]
    pub context_strategy: ContextStrategyConfig,
    /// Tokenizer source selected for the prompt; defaults to
    /// `GenerationTokenizerSourceConfig::Dataset`.
    #[serde(default)]
    pub prompt_tokenizer: GenerationTokenizerSourceConfig,
    /// Tokenizer source selected for decoding; defaults to
    /// `GenerationTokenizerSourceConfig::Dataset`.
    #[serde(default)]
    pub decode_tokenizer: GenerationTokenizerSourceConfig,
    /// Requested output representation; defaults to
    /// `GenerationOutputFormat::Auto`.
    #[serde(default)]
    pub output_format: GenerationOutputFormat,
}

/// Selects which tokenizer a generation step uses.
///
/// Serialized as an internally tagged enum: the `source` key holds the
/// snake_case variant name (`"dataset"` or `"config"`). The default variant
/// is `Dataset`.
#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Default)]
#[serde(tag = "source", rename_all = "snake_case")]
pub enum GenerationTokenizerSourceConfig {
    /// Default variant; carries no extra data.
    // NOTE(review): presumably means "reuse the dataset's tokenizer" —
    // confirm against the code that resolves this enum.
    #[default]
    Dataset,
    /// Tokenizer described inline by a `TokenizerConfig`.
    Config {
        /// Optional cache directory; defaults to `None` when omitted.
        #[serde(default)]
        cache_dir: Option<PathBuf>,
        /// Tokenizer settings. Flattened, so its keys sit alongside
        /// `source` and `cache_dir` rather than under a nested `tokenizer`
        /// key.
        #[serde(flatten)]
        tokenizer: TokenizerConfig,
    },
}

/// Output representation requested for generated sequences.
///
/// Variants serialize as snake_case strings: `"auto"`, `"decoded_text"`,
/// `"token_ids"`. The default is `Auto`.
#[derive(Debug, Clone, Copy, Deserialize, Serialize, PartialEq, Eq, Default)]
#[serde(rename_all = "snake_case")]
pub enum GenerationOutputFormat {
    /// Default variant.
    // NOTE(review): how `Auto` picks between the other formats is decided by
    // the consumer and is not visible in this file.
    #[default]
    Auto,
    /// Serialized as `"decoded_text"`.
    DecodedText,
    /// Serialized as `"token_ids"`.
    TokenIds,
}

/// Context handling strategy used during generation.
///
/// Serialized as an internally tagged enum: the `type` key holds the
/// snake_case variant name (`"infinite"` or `"sliding"`). The default
/// variant is `Infinite`.
#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Default)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ContextStrategyConfig {
    /// Default variant; carries no parameters.
    // NOTE(review): presumably "never truncate the context" — confirm
    // against the generation loop.
    #[default]
    Infinite,
    /// Windowed strategy parameterized by `window`.
    Sliding {
        /// Window size; has no serde default, so it must accompany
        /// `type = "sliding"`.
        // NOTE(review): unit is presumably tokens — confirm with the
        // consumer.
        window: usize,
    },
}

/// Optional per-field overrides for a model configuration.
///
/// Every field is an `Option`, and the derived `Default` yields a value with
/// all fields set to `None`.
// NOTE(review): `None` presumably means "keep the base configuration's
// value"; the merge logic lives outside this file — confirm there.
#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Default)]
pub struct ModelOverrides {
    /// Override for the layer count.
    pub n_layer: Option<usize>,
    /// Override for the embedding width.
    pub n_embd: Option<usize>,
    /// Override for the head count.
    pub n_head: Option<usize>,
    /// Override for the MLP internal dimension multiplier.
    pub mlp_internal_dim_multiplier: Option<usize>,
    /// Override for the language head configuration.
    pub language_head: Option<LanguageHeadConfig>,
    /// Override for `latent_total`; also accepted under the key
    /// `neuron_space_dim` when deserializing.
    #[serde(alias = "neuron_space_dim")]
    pub latent_total: Option<usize>,
    /// Override for the initialization configuration; also accepted under
    /// the key `init` when deserializing.
    #[serde(alias = "init")]
    pub initialization: Option<BdhInitializationConfig>,
    /// Override for the sequence kernel configuration.
    pub sequence_kernel: Option<SequenceKernelConfig>,
    /// Override for the Mamba sequence configuration.
    pub mamba: Option<MambaSequenceConfig>,
    /// Override for the residual connector kind.
    pub residual_connector: Option<ResidualConnectorKind>,
    /// Override for the attention residual configuration.
    pub attention_residual: Option<AttentionResidualConfig>,
    /// Override for the block attention residual configuration.
    pub block_attention_residual: Option<BlockAttentionResidualConfig>,
    /// Override for the latent fanout schedule.
    pub latent_fanout_schedule: Option<LatentFanoutScheduleConfig>,
    /// Override for the ReLU threshold.
    pub relu_threshold: Option<f32>,
    /// Override for the dropout setting.
    pub dropout: Option<f64>,
    /// Override for the normalization configuration.
    pub normalization: Option<DragonNormConfig>,
    /// Override for the fused-kernels flag.
    pub fused_kernels: Option<bool>,
    /// Override for the block size.
    pub block_size: Option<usize>,
    /// Override for `rollout_fast_steps_per_slow_step`; also accepted under
    /// the key `rollout_fast_steps` when deserializing.
    #[serde(alias = "rollout_fast_steps")]
    pub rollout_fast_steps_per_slow_step: Option<usize>,
    /// Override for the rotary embedding setting.
    pub rotary_embedding: Option<RotaryEmbedding>,
    /// Override for the y-neuron recurrence configuration; also accepted
    /// under the key `y_sparse_recurrence` when deserializing.
    #[serde(alias = "y_sparse_recurrence")]
    pub y_neuron_recurrence: Option<YNeuronRecurrenceConfig>,
    /// Override for the clocked slow memory configuration.
    pub clocked_slow_memory: Option<ClockedSlowMemoryConfig>,
    /// Override for the summary memory configuration.
    pub summary_memory: Option<SummaryMemoryConfig>,
    /// Override for the manifold hyper-connections configuration.
    pub mhc: Option<ManifoldHyperConnectionsConfig>,
    /// Override for the low-bit quantization configuration.
    pub quant: Option<LowBitQuantizationConfig>,
    /// Override for the low-bit rho configuration.
    pub rho: Option<LowBitRhoConfig>,
}

/// Serde fallback for `GenerationConfig::context_strategy`.
///
/// Delegates to the enum's derived `Default`, whose `#[default]` variant is
/// `ContextStrategyConfig::Infinite`.
fn default_context_strategy() -> ContextStrategyConfig {
    ContextStrategyConfig::default()
}

/// Serde fallback for `GenerationConfig::temperature`: yields `1.0` when the
/// field is absent from the input.
fn default_temperature() -> f32 {
    const DEFAULT_TEMPERATURE: f32 = 1.0;
    DEFAULT_TEMPERATURE
}