use std::path::PathBuf;
use serde::{Deserialize, Serialize};
use crate::tokenizer::TokenizerConfig;
use burn_dragon_core::{
AttentionResidualConfig, BdhInitializationConfig, BlockAttentionResidualConfig,
ClockedSlowMemoryConfig, DragonNormConfig, LanguageHeadConfig, LatentFanoutScheduleConfig,
LowBitQuantizationConfig, LowBitRhoConfig, MambaSequenceConfig, ManifoldHyperConnectionsConfig,
ResidualConnectorKind, RotaryEmbedding, SequenceKernelConfig, SummaryMemoryConfig,
YNeuronRecurrenceConfig,
};
/// Settings for a generation run, (de)serializable through serde.
///
/// `prompt` is the only field without a serde default, so it must be present
/// when deserializing; every other field may be omitted and falls back to the
/// default noted on it.
#[derive(Debug, Clone, Deserialize, Serialize, PartialEq)]
pub struct GenerationConfig {
/// Prompt text. Required on input (no serde default).
pub prompt: String,
/// Optional token limit; `None` when omitted.
///
/// NOTE(review): this is a signed `i64` while `max_chars` is `usize` —
/// confirm whether negative values carry a special meaning for consumers.
#[serde(default)]
pub max_tokens: Option<i64>,
/// Optional character limit; `None` when omitted.
#[serde(default)]
pub max_chars: Option<usize>,
/// Temperature value; falls back to `default_temperature()` (`1.0`) when omitted.
#[serde(default = "default_temperature")]
pub temperature: f32,
/// Optional top-k value; `None` when omitted.
#[serde(default)]
pub top_k: Option<usize>,
/// Context strategy; falls back to `default_context_strategy()`
/// (`ContextStrategyConfig::Infinite`) when omitted.
#[serde(default = "default_context_strategy")]
pub context_strategy: ContextStrategyConfig,
/// Tokenizer source for the prompt side (presumably used to encode
/// `prompt` — confirm against the caller). Defaults to `Dataset`.
#[serde(default)]
pub prompt_tokenizer: GenerationTokenizerSourceConfig,
/// Tokenizer source for the decode side (presumably used to turn generated
/// tokens back into text — confirm against the caller). Defaults to `Dataset`.
#[serde(default)]
pub decode_tokenizer: GenerationTokenizerSourceConfig,
/// Requested output format; defaults to `GenerationOutputFormat::Auto`.
#[serde(default)]
pub output_format: GenerationOutputFormat,
}
/// Selects where a tokenizer comes from.
///
/// Serialized as an internally tagged enum: the variant is carried in a
/// `source` key whose value is the snake_case variant name (`"dataset"` or
/// `"config"`). The default is `Dataset`.
#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Default)]
#[serde(tag = "source", rename_all = "snake_case")]
pub enum GenerationTokenizerSourceConfig {
/// No inline tokenizer settings (presumably means "use the tokenizer that
/// belongs to the dataset" — confirm against the consumer).
#[default]
Dataset,
/// Tokenizer described inline by a `TokenizerConfig`.
Config {
/// Optional cache directory; `None` when omitted.
#[serde(default)]
cache_dir: Option<PathBuf>,
/// Tokenizer settings. Flattened, so the `TokenizerConfig` keys sit at the
/// same level as `source` and `cache_dir` rather than under a nested key.
#[serde(flatten)]
tokenizer: TokenizerConfig,
},
}
/// Output format selector for generation.
///
/// Variants are (de)serialized under their snake_case names: `auto`,
/// `decoded_text`, `token_ids`. The default is `Auto`.
#[derive(Debug, Clone, Copy, Deserialize, Serialize, PartialEq, Eq, Default)]
#[serde(rename_all = "snake_case")]
pub enum GenerationOutputFormat {
/// Default variant (presumably lets the consumer pick the format — confirm).
#[default]
Auto,
/// Presumably emits decoded text — confirm against the consumer.
DecodedText,
/// Presumably emits raw token ids — confirm against the consumer.
TokenIds,
}
/// Context strategy selector.
///
/// Serialized as an internally tagged enum: the variant is carried in a
/// `type` key whose value is the snake_case variant name (`"infinite"` or
/// `"sliding"`). The default is `Infinite`.
#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Default)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ContextStrategyConfig {
/// Default variant; carries no parameters.
#[default]
Infinite,
/// Variant parameterized by a window size.
Sliding {
/// Window size. Required when `type` is `"sliding"` (no serde default).
/// Units are not visible here — presumably tokens; confirm.
window: usize,
},
}
/// Optional per-field overrides for model settings.
///
/// Every field is an `Option`, and the derived `Default` yields all `None`.
/// With serde's derive, an `Option` field missing from the input deserializes
/// to `None`, so any subset of keys may be supplied. `None` presumably means
/// "keep the base value" — confirm against the code that applies these.
///
/// Several fields also accept an alternate key on input via `#[serde(alias)]`;
/// serialization always uses the field name itself.
#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Default)]
pub struct ModelOverrides {
pub n_layer: Option<usize>,
pub n_embd: Option<usize>,
pub n_head: Option<usize>,
pub mlp_internal_dim_multiplier: Option<usize>,
pub language_head: Option<LanguageHeadConfig>,
/// Also accepted on input under the key `neuron_space_dim`.
#[serde(alias = "neuron_space_dim")]
pub latent_total: Option<usize>,
/// Also accepted on input under the key `init`.
#[serde(alias = "init")]
pub initialization: Option<BdhInitializationConfig>,
pub sequence_kernel: Option<SequenceKernelConfig>,
pub mamba: Option<MambaSequenceConfig>,
pub residual_connector: Option<ResidualConnectorKind>,
pub attention_residual: Option<AttentionResidualConfig>,
pub block_attention_residual: Option<BlockAttentionResidualConfig>,
pub latent_fanout_schedule: Option<LatentFanoutScheduleConfig>,
// Note: `relu_threshold` is `f32` while `dropout` below is `f64`.
pub relu_threshold: Option<f32>,
pub dropout: Option<f64>,
pub normalization: Option<DragonNormConfig>,
pub fused_kernels: Option<bool>,
pub block_size: Option<usize>,
/// Also accepted on input under the key `rollout_fast_steps`.
#[serde(alias = "rollout_fast_steps")]
pub rollout_fast_steps_per_slow_step: Option<usize>,
pub rotary_embedding: Option<RotaryEmbedding>,
/// Also accepted on input under the key `y_sparse_recurrence`.
#[serde(alias = "y_sparse_recurrence")]
pub y_neuron_recurrence: Option<YNeuronRecurrenceConfig>,
pub clocked_slow_memory: Option<ClockedSlowMemoryConfig>,
pub summary_memory: Option<SummaryMemoryConfig>,
/// Manifold hyper-connections settings (expansion taken from the type name).
pub mhc: Option<ManifoldHyperConnectionsConfig>,
/// Low-bit quantization settings (per the type name).
pub quant: Option<LowBitQuantizationConfig>,
/// Low-bit rho settings (per the type name).
pub rho: Option<LowBitRhoConfig>,
}
/// Serde fallback for `GenerationConfig::context_strategy`.
///
/// Yields the enum's `#[default]` variant, `ContextStrategyConfig::Infinite`.
fn default_context_strategy() -> ContextStrategyConfig {
    ContextStrategyConfig::default()
}
/// Serde fallback for `GenerationConfig::temperature` when the key is absent.
///
/// Always returns `1.0`.
fn default_temperature() -> f32 {
    const DEFAULT_TEMPERATURE: f32 = 1.0;
    DEFAULT_TEMPERATURE
}