use crate::error::{ModelError, ModelResult};
use crate::AutoregressiveModel;
use kizzasi_core::HiddenState;
use scirs2_core::ndarray::{s, Array2, Array3};
/// Wraps a single [`AutoregressiveModel`] so that several independent
/// sequences can be advanced through it, by swapping per-item hidden states
/// in and out of the underlying model on every step.
pub struct BatchedModel<M: AutoregressiveModel> {
    // The one underlying model; its internal states are overwritten per item.
    model: M,
    // Number of independent streams tracked by this wrapper.
    batch_size: usize,
    // One saved hidden-state stack per batch item.
    batch_states: Vec<Vec<HiddenState>>,
    // Steps taken by each batch item since its last reset.
    sequence_lengths: Vec<usize>,
    // Cached maximum of `sequence_lengths`.
    max_length: usize,
}
impl<M: AutoregressiveModel> BatchedModel<M> {
    /// Wraps `model` so that `batch_size` independent hidden-state streams
    /// can be advanced through it.
    ///
    /// # Errors
    /// Returns an error if `batch_size` is zero.
    pub fn new(model: M, batch_size: usize) -> ModelResult<Self> {
        if batch_size == 0 {
            return Err(ModelError::invalid_config("batch_size must be > 0"));
        }
        // Every stream starts from the model's current (initial) states.
        let template_states = model.get_states();
        let batch_states = (0..batch_size).map(|_| template_states.clone()).collect();
        Ok(Self {
            model,
            batch_size,
            batch_states,
            sequence_lengths: vec![0; batch_size],
            max_length: 0,
        })
    }

    /// Number of independent streams this wrapper tracks.
    pub fn batch_size(&self) -> usize {
        self.batch_size
    }

    /// Advances every stream by one autoregressive step.
    ///
    /// `inputs` must be `(batch_size, input_dim)`; returns
    /// `(batch_size, hidden_dim)`. Each item's saved states are loaded into
    /// the model, the model is stepped, and the updated states are saved
    /// back, keeping the streams independent.
    ///
    /// Note: if the model errors mid-batch, items processed earlier in the
    /// batch have already been advanced.
    ///
    /// # Errors
    /// Returns a dimension mismatch when `inputs` has a row count other than
    /// `batch_size`, or propagates any error from the underlying model.
    pub fn predict_batch(&mut self, inputs: &Array2<f32>) -> ModelResult<Array2<f32>> {
        let (batch_size, _input_dim) = inputs.dim();
        if batch_size != self.batch_size {
            return Err(ModelError::dimension_mismatch(
                "batch predict",
                self.batch_size,
                batch_size,
            ));
        }
        let output_dim = self.model.hidden_dim();
        let mut outputs = Array2::zeros((batch_size, output_dim));
        for batch_idx in 0..batch_size {
            let input = inputs.row(batch_idx).to_owned();
            // Swap this stream's states in, step, then swap the result back out.
            self.model
                .set_states(self.batch_states[batch_idx].clone())?;
            let output = self.model.step(&input)?;
            outputs.row_mut(batch_idx).assign(&output);
            self.batch_states[batch_idx] = self.model.get_states();
            self.sequence_lengths[batch_idx] += 1;
        }
        self.max_length = *self.sequence_lengths.iter().max().unwrap_or(&0);
        Ok(outputs)
    }

    /// Runs a full `(batch, seq_len, input_dim)` tensor through the model one
    /// time step at a time, returning `(batch, seq_len, hidden_dim)`.
    ///
    /// # Errors
    /// Same failure modes as [`Self::predict_batch`].
    pub fn predict_sequence_batch(&mut self, inputs: &Array3<f32>) -> ModelResult<Array3<f32>> {
        let (batch_size, seq_len, _input_dim) = inputs.dim();
        if batch_size != self.batch_size {
            return Err(ModelError::dimension_mismatch(
                "batch forward_sequence",
                self.batch_size,
                batch_size,
            ));
        }
        let output_dim = self.model.hidden_dim();
        let mut outputs = Array3::zeros((batch_size, seq_len, output_dim));
        for t in 0..seq_len {
            let step_inputs = inputs.slice(s![.., t, ..]).to_owned();
            let step_outputs = self.predict_batch(&step_inputs)?;
            outputs.slice_mut(s![.., t, ..]).assign(&step_outputs);
        }
        Ok(outputs)
    }

    /// Resets the streams at `indices` to the model's initial states and
    /// zeroes their sequence lengths; other streams are untouched.
    ///
    /// All indices are validated before anything is mutated, so an
    /// out-of-range index leaves the wrapper (and the wrapped model)
    /// unchanged.
    ///
    /// # Errors
    /// Returns an error if any index is `>= batch_size`.
    pub fn reset_batch_items(&mut self, indices: &[usize]) -> ModelResult<()> {
        // Validate up front so the reset is all-or-nothing.
        if let Some(&idx) = indices.iter().find(|&&idx| idx >= self.batch_size) {
            return Err(ModelError::invalid_config(format!(
                "batch index {} out of range for batch_size {}",
                idx, self.batch_size
            )));
        }
        self.model.reset();
        let template_states = self.model.get_states();
        for &idx in indices {
            self.batch_states[idx] = template_states.clone();
            self.sequence_lengths[idx] = 0;
        }
        // Lengths changed, so the cached maximum may now be stale.
        self.max_length = self.sequence_lengths.iter().copied().max().unwrap_or(0);
        Ok(())
    }

    /// Resets every stream to the model's initial states and clears all
    /// sequence-length bookkeeping.
    pub fn reset_all(&mut self) {
        self.model.reset();
        let template_states = self.model.get_states();
        self.batch_states = (0..self.batch_size)
            .map(|_| template_states.clone())
            .collect();
        self.sequence_lengths = vec![0; self.batch_size];
        self.max_length = 0;
    }

    /// Steps taken so far by each stream (indexed by batch position).
    pub fn sequence_lengths(&self) -> &[usize] {
        &self.sequence_lengths
    }

    /// Longest sequence length across all streams.
    pub fn max_sequence_length(&self) -> usize {
        self.max_length
    }

    /// Borrows the saved hidden states of one stream.
    ///
    /// # Errors
    /// Returns an error if `batch_idx >= batch_size`.
    pub fn get_batch_item_states(&self, batch_idx: usize) -> ModelResult<&Vec<HiddenState>> {
        if batch_idx >= self.batch_size {
            return Err(ModelError::invalid_config(format!(
                "batch index {} out of range for batch_size {}",
                batch_idx, self.batch_size
            )));
        }
        Ok(&self.batch_states[batch_idx])
    }

    /// Replaces the saved hidden states of one stream.
    ///
    /// # Errors
    /// Returns an error if `batch_idx` is out of range or `states` does not
    /// contain exactly one entry per model layer.
    pub fn set_batch_item_states(
        &mut self,
        batch_idx: usize,
        states: Vec<HiddenState>,
    ) -> ModelResult<()> {
        if batch_idx >= self.batch_size {
            return Err(ModelError::invalid_config(format!(
                "batch index {} out of range for batch_size {}",
                batch_idx, self.batch_size
            )));
        }
        if states.len() != self.model.num_layers() {
            return Err(ModelError::invalid_config(format!(
                "expected {} layer states, got {}",
                self.model.num_layers(),
                states.len()
            )));
        }
        self.batch_states[batch_idx] = states;
        Ok(())
    }

    /// Borrows the wrapped model.
    pub fn model(&self) -> &M {
        &self.model
    }

    /// Mutably borrows the wrapped model. Note that `predict_batch` will
    /// overwrite the model's internal states on its next call.
    pub fn model_mut(&mut self) -> &mut M {
        &mut self.model
    }
}
/// Heuristically adapts a batch size between configured bounds based on
/// observed per-batch latency: grows when under the target, shrinks when far
/// over it.
pub struct DynamicBatcher<M: AutoregressiveModel> {
    // Held for future batched execution; not used by the sizing heuristic yet.
    #[allow(dead_code)]
    model: M,
    // Inclusive lower bound for the adaptive batch size; validated to be >= 1.
    min_batch_size: usize,
    // Inclusive upper bound for the adaptive batch size.
    max_batch_size: usize,
    // Current recommendation, kept within [min_batch_size, max_batch_size].
    current_batch_size: usize,
    // Per-batch latency budget, in microseconds.
    target_latency_us: u64,
}
impl<M: AutoregressiveModel> DynamicBatcher<M> {
    /// Creates a batcher whose recommendation starts at `min_batch_size`.
    ///
    /// # Errors
    /// Returns an error if `min_batch_size` is zero or
    /// `max_batch_size < min_batch_size`.
    pub fn new(
        model: M,
        min_batch_size: usize,
        max_batch_size: usize,
        target_latency_us: u64,
    ) -> ModelResult<Self> {
        if min_batch_size == 0 {
            return Err(ModelError::invalid_config("min_batch_size must be > 0"));
        }
        if max_batch_size < min_batch_size {
            return Err(ModelError::invalid_config(
                "max_batch_size must be >= min_batch_size",
            ));
        }
        Ok(Self {
            model,
            min_batch_size,
            max_batch_size,
            current_batch_size: min_batch_size,
            target_latency_us,
        })
    }

    /// Batch size the caller should use for the next batch.
    pub fn current_batch_size(&self) -> usize {
        self.current_batch_size
    }

    /// Adjusts the recommendation from one latency observation: grow by one
    /// when under target, shrink by one when more than twice the target.
    /// Latencies in between leave the size unchanged (hysteresis band).
    pub fn update_batch_size(&mut self, observed_latency_us: u64) {
        if observed_latency_us < self.target_latency_us {
            // Headroom available: grow, clamped to the upper bound.
            self.current_batch_size = self
                .current_batch_size
                .saturating_add(1)
                .min(self.max_batch_size);
        } else if observed_latency_us > self.target_latency_us.saturating_mul(2) {
            // Way over budget: shrink, clamped to the lower bound.
            // saturating_mul above prevents the threshold from wrapping for
            // very large targets; saturating_sub is pure defense since the
            // invariant current >= min >= 1 already holds.
            self.current_batch_size = self
                .current_batch_size
                .saturating_sub(1)
                .max(self.min_batch_size);
        }
    }

    /// Restarts adaptation from the configured minimum.
    pub fn reset(&mut self) {
        self.current_batch_size = self.min_batch_size;
    }
}
pub mod padding {
    use super::*;

    /// Right-pads a batch of variable-length sequences with `pad_value` into
    /// a single `(batch, max_len, features)` tensor, returning the tensor
    /// together with each sequence's original length.
    ///
    /// The feature dimension is taken from the first sequence; callers are
    /// expected to pass sequences that all share it.
    pub fn pad_sequences(sequences: &[Array2<f32>], pad_value: f32) -> (Array3<f32>, Vec<usize>) {
        if sequences.is_empty() {
            return (Array3::zeros((0, 0, 0)), vec![]);
        }
        let batch = sequences.len();
        let longest = sequences.iter().map(|seq| seq.nrows()).max().unwrap_or(0);
        let features = sequences[0].ncols();
        let lengths: Vec<usize> = sequences.iter().map(|seq| seq.nrows()).collect();
        // Start fully padded, then copy each sequence over its prefix.
        let mut padded = Array3::from_elem((batch, longest, features), pad_value);
        for (b, seq) in sequences.iter().enumerate() {
            for t in 0..seq.nrows() {
                let row = seq.row(t);
                for f in 0..features {
                    padded[[b, t, f]] = row[f];
                }
            }
        }
        (padded, lengths)
    }

    /// Inverse of [`pad_sequences`]: slices each item of the padded tensor
    /// back down to its recorded length.
    ///
    /// # Panics
    /// Panics if `lengths.len()` differs from the batch dimension of
    /// `padded`.
    pub fn unpad_sequences(padded: &Array3<f32>, lengths: &[usize]) -> Vec<Array2<f32>> {
        let (batch, _, features) = padded.dim();
        assert_eq!(batch, lengths.len());
        let mut result = Vec::with_capacity(batch);
        for (b, &len) in lengths.iter().enumerate() {
            result.push(Array2::from_shape_fn((len, features), |(t, f)| {
                padded[[b, t, f]]
            }));
        }
        result
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::mamba::{Mamba, MambaConfig};

    // Shared fixture: a small two-layer Mamba with a 64-dim hidden state.
    fn small_mamba() -> Mamba {
        let config = MambaConfig::default().hidden_dim(64).num_layers(2);
        Mamba::new(config).expect("Failed to create Mamba model")
    }

    #[test]
    fn test_batched_model_creation() {
        let batched = BatchedModel::new(small_mamba(), 4);
        assert!(batched.is_ok());
        let batched = batched.expect("Failed to create batched model");
        assert_eq!(batched.batch_size(), 4);
    }

    #[test]
    fn test_batch_prediction() {
        let mut batched =
            BatchedModel::new(small_mamba(), 4).expect("Failed to create batched model");
        let outputs = batched.predict_batch(&Array2::zeros((4, 1)));
        assert!(outputs.is_ok());
        assert_eq!(outputs.expect("Failed to predict batch").dim(), (4, 64));
    }

    #[test]
    fn test_sequence_batch_prediction() {
        let mut batched =
            BatchedModel::new(small_mamba(), 2).expect("Failed to create batched model");
        let outputs = batched.predict_sequence_batch(&Array3::zeros((2, 10, 1)));
        assert!(outputs.is_ok());
        assert_eq!(
            outputs.expect("Failed to predict sequence batch").dim(),
            (2, 10, 64)
        );
    }

    #[test]
    fn test_batch_reset() {
        let mut batched =
            BatchedModel::new(small_mamba(), 4).expect("Failed to create batched model");
        let _ = batched
            .predict_batch(&Array2::zeros((4, 1)))
            .expect("Failed to predict batch");
        batched
            .reset_batch_items(&[0, 2])
            .expect("Failed to reset batch items");
        // Items 0 and 2 were reset; 1 and 3 keep their single step.
        for (idx, &want) in [0usize, 1, 0, 1].iter().enumerate() {
            assert_eq!(batched.sequence_lengths()[idx], want);
        }
        batched.reset_all();
        assert!(batched.sequence_lengths().iter().all(|&len| len == 0));
    }

    #[test]
    fn test_padding() {
        // Builds a (rows, 2) sequence filled with 1.0, 2.0, 3.0, ...
        let make = |rows: usize| {
            let data: Vec<f32> = (1..=rows * 2).map(|v| v as f32).collect();
            Array2::from_shape_vec((rows, 2), data).expect("Failed to create test array")
        };
        let sequences = vec![make(3), make(5), make(2)];
        let (padded, lengths) = padding::pad_sequences(&sequences, 0.0);
        assert_eq!(padded.dim(), (3, 5, 2));
        assert_eq!(lengths, vec![3, 5, 2]);
        let unpadded = padding::unpad_sequences(&padded, &lengths);
        assert_eq!(unpadded.len(), 3);
        for (seq, &len) in unpadded.iter().zip(lengths.iter()) {
            assert_eq!(seq.dim(), (len, 2));
        }
    }

    #[test]
    fn test_dynamic_batcher() {
        let model = Mamba::new(MambaConfig::default()).expect("Failed to create Mamba model");
        let mut batcher =
            DynamicBatcher::new(model, 1, 16, 1000).expect("Failed to create dynamic batcher");
        assert_eq!(batcher.current_batch_size(), 1);
        batcher.update_batch_size(500);
        assert_eq!(batcher.current_batch_size(), 2);
        batcher.update_batch_size(3000);
        assert_eq!(batcher.current_batch_size(), 1);
    }
}