use crate::tensor::DenseTensor;
use super::model::LlamaModel;
use super::generation::GenerationConfig;
use super::kv_cache::KVCache;
#[derive(Debug, Clone)]
pub struct BatchData {
pub input_ids: Vec<Vec<usize>>,
pub attention_mask: Option<DenseTensor>,
pub position_ids: Option<Vec<Vec<usize>>>,
pub seq_lengths: Vec<usize>,
}
impl BatchData {
pub fn new(input_ids: Vec<Vec<usize>>) -> Self {
let seq_lengths: Vec<usize> = input_ids.iter().map(|ids| ids.len()).collect();
let max_len = seq_lengths.iter().max().copied().unwrap_or(0);
let mut padded_ids = Vec::new();
for ids in &input_ids {
let mut padded = ids.clone();
while padded.len() < max_len {
padded.push(0); }
padded_ids.push(padded);
}
let batch_size = input_ids.len();
let mut mask_data = Vec::with_capacity(batch_size * max_len * max_len);
for &seq_len in seq_lengths.iter() {
for j in 0..max_len {
for k in 0..max_len {
let can_attend = (j < seq_len && k < seq_len) as u8 as f64;
mask_data.push(if can_attend == 1.0 { 0.0 } else { f64::NEG_INFINITY });
}
}
}
let attention_mask = Some(DenseTensor::new(mask_data, vec![batch_size, max_len, max_len]));
Self {
input_ids: padded_ids,
attention_mask,
position_ids: None,
seq_lengths,
}
}
pub fn batch_size(&self) -> usize {
self.input_ids.len()
}
pub fn max_seq_len(&self) -> usize {
self.seq_lengths.iter().max().copied().unwrap_or(0)
}
pub fn padded_input_ids(&self) -> &[Vec<usize>] {
&self.input_ids
}
}
#[derive(Debug, Clone)]
pub struct InferenceRequest {
pub id: usize,
pub input_ids: Vec<usize>,
pub config: GenerationConfig,
pub generated: Vec<usize>,
pub completed: bool,
pub priority: usize,
}
impl InferenceRequest {
pub fn new(id: usize, input_ids: Vec<usize>, config: GenerationConfig) -> Self {
Self {
id,
input_ids: input_ids.clone(),
config,
generated: input_ids,
completed: false,
priority: 0,
}
}
pub fn append_token(&mut self, token: usize) {
self.generated.push(token);
if self.generated.len() >= self.config.max_length {
self.completed = true;
}
if let Some(eos) = self.config.eos_token_id {
if token == eos {
self.completed = true;
}
}
}
pub fn current_len(&self) -> usize {
self.generated.len()
}
}
#[derive(Debug)]
pub struct RequestScheduler {
pending: Vec<InferenceRequest>,
active: Vec<InferenceRequest>,
completed: Vec<InferenceRequest>,
next_id: usize,
max_batch_size: usize,
}
impl RequestScheduler {
pub fn new(max_batch_size: usize) -> Self {
Self {
pending: Vec::new(),
active: Vec::new(),
completed: Vec::new(),
next_id: 0,
max_batch_size,
}
}
pub fn add_request(&mut self, input_ids: Vec<usize>, config: GenerationConfig) -> usize {
let id = self.next_id;
self.next_id += 1;
let request = InferenceRequest::new(id, input_ids, config);
self.pending.push(request);
id
}
pub fn schedule(&mut self) -> Vec<&mut InferenceRequest> {
self.active.retain(|req| {
!req.completed
});
while !self.pending.is_empty() && self.active.len() < self.max_batch_size {
let request = self.pending.remove(0);
self.active.push(request);
}
self.active.iter_mut().collect()
}
pub fn num_pending(&self) -> usize {
self.pending.len()
}
pub fn num_active(&self) -> usize {
self.active.len()
}
pub fn num_completed(&self) -> usize {
self.completed.len()
}
pub fn pop_completed(&mut self) -> Vec<InferenceRequest> {
std::mem::take(&mut self.completed)
}
}
#[derive(Debug)]
pub struct BatchInference<'a> {
model: &'a LlamaModel,
kv_caches: Vec<KVCache>,
batch_size: usize,
}
impl<'a> BatchInference<'a> {
pub fn new(model: &'a LlamaModel, max_batch_size: usize, max_seq_len: usize) -> Self {
let kv_caches = vec![
KVCache::new(
model.num_layers(),
max_seq_len,
model.hidden_dim(),
model.config.get_num_key_value_heads(),
);
max_batch_size
];
Self {
model,
kv_caches,
batch_size: 0,
}
}
pub fn forward(&mut self, batch: &BatchData) -> DenseTensor {
let batch_size = batch.batch_size();
self.batch_size = batch_size;
self.model.forward(&batch.input_ids, batch.attention_mask.as_ref())
}
pub fn step(&mut self, requests: &[&mut InferenceRequest]) -> Vec<usize> {
let input_ids: Vec<Vec<usize>> = requests
.iter()
.map(|req| vec![*req.generated.last().unwrap()])
.collect();
let batch = BatchData::new(input_ids);
let logits = self.forward(&batch);
let mut tokens = Vec::new();
for (i, req) in requests.iter().enumerate() {
let seq_len = req.current_len();
let token_logits = logits.get_row(i * seq_len + seq_len - 1);
let mut probs = token_logits.clone();
if req.config.temperature != 1.0 {
probs = probs.scale(1.0 / req.config.temperature);
}
probs = probs.softmax(-1);
let token = if req.config.do_sample {
self.sample_from_probs(probs.data())
} else {
self.argmax(probs.data())
};
tokens.push(token);
}
tokens
}
pub fn generate_continuous(&mut self, scheduler: &mut RequestScheduler) -> Vec<Vec<usize>> {
let mut results: Vec<Option<Vec<usize>>> = Vec::new();
for _ in 0..scheduler.next_id {
results.push(None);
}
while scheduler.num_active() > 0 || scheduler.num_pending() > 0 {
let mut active_requests = scheduler.schedule();
if active_requests.is_empty() {
break;
}
let tokens = self.step(&active_requests);
for (req, token) in active_requests.iter_mut().zip(tokens) {
req.append_token(token);
if req.completed {
results[req.id] = Some(req.generated.clone());
}
}
}
results.into_iter().flatten().collect()
}
pub fn reset(&mut self) {
for cache in &mut self.kv_caches {
cache.reset();
}
}
fn argmax(&self, probs: &[f64]) -> usize {
probs
.iter()
.enumerate()
.max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
.map(|(i, _)| i)
.unwrap_or(0)
}
fn sample_from_probs(&self, probs: &[f64]) -> usize {
use rand::Rng;
let mut rng = rand::thread_rng();
let r: f64 = rng.gen();
let mut cumulative = 0.0;
for (i, &prob) in probs.iter().enumerate() {
cumulative += prob;
if r < cumulative {
return i;
}
}
probs.len() - 1
}
}
pub mod utils {
use super::*;
pub fn pad_sequences(sequences: &[Vec<usize>], pad_token: usize) -> (Vec<Vec<usize>>, Vec<usize>) {
let max_len = sequences.iter().map(|s| s.len()).max().unwrap_or(0);
let mut padded = Vec::new();
let mut lengths = Vec::new();
for seq in sequences {
lengths.push(seq.len());
let mut padded_seq = seq.clone();
while padded_seq.len() < max_len {
padded_seq.push(pad_token);
}
padded.push(padded_seq);
}
(padded, lengths)
}
pub fn create_attention_mask(lengths: &[usize]) -> DenseTensor {
let batch_size = lengths.len();
let max_len = lengths.iter().max().copied().unwrap_or(0);
let mut data = Vec::with_capacity(batch_size * max_len * max_len);
for &seq_len in lengths.iter() {
for j in 0..max_len {
for k in 0..max_len {
let can_attend = (j < seq_len && k < seq_len) as u8 as f64;
data.push(if can_attend == 1.0 { 0.0 } else { f64::NEG_INFINITY });
}
}
}
DenseTensor::new(data, vec![batch_size, max_len, max_len])
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::transformer::model::LlamaModel;
use crate::transformer::layers::{MultiHeadAttention, FeedForward, RMSNorm};
use crate::transformer::loader::LlamaConfig;
use crate::tensor::DenseTensor;
fn create_test_model() -> LlamaModel {
let config = LlamaConfig::llama_7b();
let embed_tokens = DenseTensor::ones(vec![config.vocab_size, config.hidden_size]);
let hidden_dim = config.hidden_size;
let num_heads = config.num_attention_heads;
let w_q = DenseTensor::ones(vec![hidden_dim, hidden_dim]);
let w_k = DenseTensor::ones(vec![hidden_dim, hidden_dim]);
let w_v = DenseTensor::ones(vec![hidden_dim, hidden_dim]);
let w_o = DenseTensor::ones(vec![hidden_dim, hidden_dim]);
let self_attn = MultiHeadAttention::standard(w_q, w_k, w_v, w_o, num_heads);
let gate_proj = DenseTensor::ones(vec![hidden_dim, config.intermediate_size]);
let up_proj = DenseTensor::ones(vec![hidden_dim, config.intermediate_size]);
let down_proj = DenseTensor::ones(vec![config.intermediate_size, hidden_dim]);
let mlp = FeedForward::swiglu(gate_proj, up_proj, down_proj);
let input_layernorm = RMSNorm::default(hidden_dim);
let post_attention_layernorm = RMSNorm::default(hidden_dim);
let layer = super::super::model::LlamaDecoderLayer::new(
self_attn, mlp, input_layernorm, post_attention_layernorm
);
let layers = vec![layer; 2];
let norm = RMSNorm::default(hidden_dim);
LlamaModel::new(config, embed_tokens, layers, norm, None)
}
#[test]
fn test_batch_data_creation() {
let input_ids = vec![
vec![1, 2, 3],
vec![4, 5],
vec![6, 7, 8, 9],
];
let batch = BatchData::new(input_ids.clone());
assert_eq!(batch.batch_size(), 3);
assert_eq!(batch.max_seq_len(), 4);
assert_eq!(batch.seq_lengths, vec![3, 2, 4]);
}
#[test]
fn test_inference_request() {
let config = GenerationConfig::greedy();
let mut request = InferenceRequest::new(0, vec![1, 2, 3], config);
assert!(!request.completed);
assert_eq!(request.current_len(), 3);
request.append_token(4);
assert_eq!(request.current_len(), 4);
}
#[test]
fn test_request_scheduler() {
let mut scheduler = RequestScheduler::new(2);
let _id1 = scheduler.add_request(vec![1, 2, 3], GenerationConfig::greedy());
let _id2 = scheduler.add_request(vec![4, 5], GenerationConfig::greedy());
let _id3 = scheduler.add_request(vec![6, 7, 8], GenerationConfig::greedy());
assert_eq!(scheduler.num_pending(), 3);
assert_eq!(scheduler.num_active(), 0);
let active = scheduler.schedule();
assert_eq!(active.len(), 2); assert_eq!(scheduler.num_pending(), 1);
assert_eq!(scheduler.num_active(), 2);
}
#[test]
fn test_batch_inference_creation() {
let model = create_test_model();
let batch_infer = BatchInference::new(&model, 4, 512);
assert_eq!(batch_infer.kv_caches.len(), 4);
}
#[test]
fn test_pad_sequences() {
let sequences = vec![
vec![1, 2],
vec![3, 4, 5],
vec![6],
];
let (padded, lengths) = utils::pad_sequences(&sequences, 0);
assert_eq!(padded, vec![
vec![1, 2, 0],
vec![3, 4, 5],
vec![6, 0, 0],
]);
assert_eq!(lengths, vec![2, 3, 1]);
}
}