use crate::gpu::GpuModelConfig;
/// A single step of the token-generation loop, as planned by
/// [`BatchPlanner::start_with_prompt`] / [`BatchPlanner::plan_next`].
#[derive(Debug, Clone, PartialEq)]
pub enum GenerationStep {
    /// Process the initial prompt tokens before any generation happens.
    ProcessPrompt {
        tokens: Vec<usize>,
    },
    /// Produce one more token given the sequence so far.
    GenerateToken {
        tokens: Vec<usize>,
        /// Set from [`GenerationConfig::use_greedy_path`] (vocab size above the
        /// configured threshold) — presumably enables a greedy/argmax fast
        /// path in the executor; confirm against the caller.
        use_greedy_optimization: bool,
    },
    /// Generation has finished; `tokens` holds the prompt plus everything generated.
    Done {
        tokens: Vec<usize>,
    },
}
/// Limits and model dimensions that drive a [`BatchPlanner`] run.
#[derive(Debug, Clone)]
pub struct GenerationConfig {
    /// Maximum number of tokens to generate after the prompt.
    pub max_tokens: usize,
    /// Size of the model's output vocabulary.
    pub vocab_size: usize,
    /// Vocabulary size above which [`Self::use_greedy_path`] returns `true`.
    pub greedy_vocab_threshold: usize,
    /// Token id that ends generation early when produced, if any.
    pub stop_token: Option<usize>,
}
impl Default for GenerationConfig {
fn default() -> Self {
Self {
max_tokens: 100,
vocab_size: 32000,
greedy_vocab_threshold: 8192,
stop_token: None,
}
}
}
impl GenerationConfig {
    /// Vocabulary size above which the greedy generation path is preferred.
    ///
    /// Named constant replacing the magic `8192` literal so the threshold has
    /// one authoritative definition.
    pub const DEFAULT_GREEDY_VOCAB_THRESHOLD: usize = 8192;

    /// Builds a config from the model's dimensions with the given token budget.
    ///
    /// Uses the default greedy threshold and no stop token; adjust the fields
    /// afterwards if a stop token is needed.
    #[must_use]
    pub fn from_model(model_config: &GpuModelConfig, max_tokens: usize) -> Self {
        Self {
            max_tokens,
            vocab_size: model_config.vocab_size,
            greedy_vocab_threshold: Self::DEFAULT_GREEDY_VOCAB_THRESHOLD,
            stop_token: None,
        }
    }

    /// `true` when the vocabulary is strictly larger than the greedy threshold.
    #[must_use]
    pub fn use_greedy_path(&self) -> bool {
        self.vocab_size > self.greedy_vocab_threshold
    }
}
/// Stateful planner that sequences prompt processing, token generation, and
/// termination into [`GenerationStep`]s.
#[derive(Debug, Clone)]
pub struct BatchPlanner {
    /// Limits and thresholds controlling the run.
    config: GenerationConfig,
    /// Current phase: Initial -> Generating -> Done.
    state: PlannerState,
    /// Full token sequence so far (prompt followed by generated tokens).
    tokens: Vec<usize>,
    /// Tokens generated since the last `start_with_prompt`.
    generated_count: usize,
}
/// Internal lifecycle of a [`BatchPlanner`].
#[derive(Debug, Clone, PartialEq)]
enum PlannerState {
    /// No prompt has been supplied yet.
    Initial,
    /// Actively producing tokens.
    Generating,
    /// A stop condition was hit; no further tokens will be requested.
    Done,
}
impl BatchPlanner {
    /// Creates a planner in the `Initial` state with an empty token buffer.
    #[must_use]
    pub fn new(config: GenerationConfig) -> Self {
        Self {
            config,
            state: PlannerState::Initial,
            tokens: Vec::new(),
            generated_count: 0,
        }
    }

    /// Advances the state machine by one step.
    ///
    /// `last_token` is the token produced by the previous `GenerateToken`
    /// step, if any; it is recorded before the stop conditions are checked.
    ///
    /// NOTE(review): in the `Initial` state (before `start_with_prompt`) this
    /// reports `Done` with the (empty) buffer but does NOT transition to
    /// `Done`, so `is_done()` stays `false` — confirm that is intended.
    #[must_use]
    pub fn plan_next(&mut self, last_token: Option<usize>) -> GenerationStep {
        match self.state {
            PlannerState::Generating => self.advance_generation(last_token),
            PlannerState::Initial | PlannerState::Done => GenerationStep::Done {
                tokens: self.tokens.clone(),
            },
        }
    }

    /// Resets counters, seeds the buffer with `prompt`, and enters the
    /// `Generating` state. Returns the prompt-processing step.
    #[must_use]
    pub fn start_with_prompt(&mut self, prompt: &[usize]) -> GenerationStep {
        self.generated_count = 0;
        self.tokens = prompt.to_vec();
        self.state = PlannerState::Generating;
        GenerationStep::ProcessPrompt {
            tokens: self.tokens.clone(),
        }
    }

    /// One `Generating`-state step: record `last_token` (when present), stop
    /// if a limit was hit, otherwise request the next token.
    fn advance_generation(&mut self, last_token: Option<usize>) -> GenerationStep {
        match last_token {
            None => self.next_token_step(),
            Some(token) => {
                self.tokens.push(token);
                self.generated_count += 1;
                if self.should_stop(token) {
                    self.state = PlannerState::Done;
                    GenerationStep::Done {
                        tokens: self.tokens.clone(),
                    }
                } else {
                    self.next_token_step()
                }
            }
        }
    }

    /// Builds the `GenerateToken` step for the current sequence.
    fn next_token_step(&self) -> GenerationStep {
        GenerationStep::GenerateToken {
            tokens: self.tokens.clone(),
            use_greedy_optimization: self.config.use_greedy_path(),
        }
    }

    /// `true` once the generation budget is spent or `last_token` equals the
    /// configured stop token.
    fn should_stop(&self, last_token: usize) -> bool {
        let budget_spent = self.generated_count >= self.config.max_tokens;
        let stop_token_hit = self.config.stop_token == Some(last_token);
        budget_spent || stop_token_hit
    }

    /// The token sequence accumulated so far (prompt plus generated tokens).
    #[must_use]
    pub fn tokens(&self) -> &[usize] {
        self.tokens.as_slice()
    }

    /// How many tokens have been generated since `start_with_prompt`.
    #[must_use]
    pub fn generated_count(&self) -> usize {
        self.generated_count
    }

    /// Whether the planner has reached the terminal `Done` state.
    #[must_use]
    pub fn is_done(&self) -> bool {
        matches!(self.state, PlannerState::Done)
    }

    /// Read access to the generation configuration.
    #[must_use]
    pub fn config(&self) -> &GenerationConfig {
        &self.config
    }
}
/// Precomputed dimensions for one transformer block's forward pass, derived
/// from [`GpuModelConfig`] by [`BlockForwardPlan::from_config`].
#[derive(Debug, Clone, PartialEq)]
pub struct BlockForwardPlan {
    /// Index of this block within the model.
    pub block_idx: usize,
    /// Model hidden (embedding) dimension.
    pub hidden_dim: usize,
    /// Key/value projection dimension (from `GpuModelConfig::kv_dim`).
    pub kv_dim: usize,
    /// Combined Q/K/V projection dimension (from `GpuModelConfig::qkv_dim`).
    pub qkv_dim: usize,
    /// Number of attention (query) heads.
    pub num_heads: usize,
    /// Number of key/value heads (< `num_heads` under grouped-query attention).
    pub num_kv_heads: usize,
    /// Per-head dimension (from `GpuModelConfig::head_dim`).
    pub head_dim: usize,
    /// Feed-forward intermediate dimension.
    pub intermediate_dim: usize,
    /// Whether the FFN uses SwiGLU; set from the presence of a gate weight.
    pub use_swiglu: bool,
    /// Query heads per KV head (`num_heads / num_kv_heads`).
    pub heads_per_kv: usize,
}
impl BlockForwardPlan {
    /// Derives the forward-pass plan for block `block_idx` from the model
    /// config. `has_gate_weight` selects the SwiGLU feed-forward variant.
    #[must_use]
    pub fn from_config(config: &GpuModelConfig, block_idx: usize, has_gate_weight: bool) -> Self {
        // NOTE(review): integer division — assumes num_kv_heads > 0 and that
        // it divides num_heads evenly; confirm upstream config validation.
        let heads_per_kv = config.num_heads / config.num_kv_heads;
        Self {
            block_idx,
            hidden_dim: config.hidden_dim,
            kv_dim: config.kv_dim(),
            qkv_dim: config.qkv_dim(),
            num_heads: config.num_heads,
            num_kv_heads: config.num_kv_heads,
            head_dim: config.head_dim(),
            intermediate_dim: config.intermediate_dim,
            use_swiglu: has_gate_weight,
            heads_per_kv,
        }
    }

    /// Grouped-query attention is in effect when at least two query heads
    /// share a single KV head.
    #[must_use]
    pub fn is_gqa(&self) -> bool {
        self.heads_per_kv >= 2
    }

    /// Size of the attention output projection (equals the hidden dimension).
    #[must_use]
    pub fn attention_output_size(&self) -> usize {
        self.hidden_dim
    }
}
/// How the next token is chosen from the logits; selected by [`plan_sampling`].
#[derive(Debug, Clone, PartialEq, Default)]
pub enum SamplingStrategy {
    /// Always take the highest-probability token (the default).
    #[default]
    Greedy,
    /// Sample from the `k` most likely tokens.
    TopK {
        k: usize,
    },
    /// Nucleus sampling with cumulative probability mass `p`.
    TopP {
        p: f32,
    },
    /// Temperature-scaled sampling with factor `temp`.
    Temperature {
        temp: f32,
    },
}
/// A temperature is usable when it is strictly positive and is not the
/// identity value `1.0` (which would leave the logits unchanged).
/// NaN fails the positivity comparison and is therefore rejected too.
#[inline]
fn is_valid_temperature(temp: f32) -> bool {
    let strictly_positive = temp > 0.0;
    let not_identity = temp != 1.0;
    strictly_positive && not_identity
}
/// A top-p value is usable only inside the open interval (0, 1).
/// NaN and infinities fail both comparisons and are rejected.
#[inline]
fn is_valid_top_p(p: f32) -> bool {
    if p <= 0.0 {
        return false;
    }
    p < 1.0
}
/// A top-k value is usable when it is at least 1 and below `usize::MAX`
/// (the latter presumably serving as an "unset" sentinel upstream).
#[inline]
fn is_valid_top_k(k: usize) -> bool {
    (1..usize::MAX).contains(&k)
}
/// Chooses a sampling strategy from optional user knobs.
///
/// Precedence when several knobs are set and valid: temperature first, then
/// top-p, then top-k; if none applies the result is
/// [`SamplingStrategy::Greedy`]. Out-of-range values are silently ignored
/// rather than rejected.
#[must_use]
pub fn plan_sampling(
    temperature: Option<f32>,
    top_k: Option<usize>,
    top_p: Option<f32>,
) -> SamplingStrategy {
    let by_temperature = temperature
        .filter(|&t| is_valid_temperature(t))
        .map(|temp| SamplingStrategy::Temperature { temp });
    let by_top_p = || {
        top_p
            .filter(|&p| is_valid_top_p(p))
            .map(|p| SamplingStrategy::TopP { p })
    };
    let by_top_k = || {
        top_k
            .filter(|&k| is_valid_top_k(k))
            .map(|k| SamplingStrategy::TopK { k })
    };
    by_temperature
        .or_else(by_top_p)
        .or_else(by_top_k)
        .unwrap_or(SamplingStrategy::Greedy)
}
/// Which device computes the final LM-head (hidden -> vocab logits) matmul.
#[derive(Debug, Clone, PartialEq)]
pub enum LmHeadPath {
    /// Run on the CPU against transposed weights (large vocab or a weight
    /// matrix that exceeds the GPU buffer limit).
    CpuTransposed,
    /// Run on the GPU.
    Gpu,
}

/// Picks the LM-head execution path.
///
/// Falls back to [`LmHeadPath::CpuTransposed`] when the vocabulary has more
/// than 8192 entries or the weight matrix (`vocab_size * hidden_dim`
/// elements) would not fit within `gpu_buffer_limit` elements.
///
/// The element count is computed with `checked_mul` so a pathological
/// configuration routes to the CPU path instead of wrapping (release) or
/// panicking (debug) on overflow.
#[must_use]
pub fn plan_lm_head_path(
    vocab_size: usize,
    hidden_dim: usize,
    gpu_buffer_limit: usize,
) -> LmHeadPath {
    // Overflow means the matrix is certainly too large for the GPU buffer.
    let fits_gpu_buffer = vocab_size
        .checked_mul(hidden_dim)
        .map_or(false, |elements| elements <= gpu_buffer_limit);
    if vocab_size <= 8192 && fits_gpu_buffer {
        LmHeadPath::Gpu
    } else {
        LmHeadPath::CpuTransposed
    }
}
// Textually splices in additional GenerationConfig-related items maintained
// in a sibling file.
include!("planner_generation_config.rs");