use std::marker::PhantomData;
use std::sync::atomic::{AtomicUsize, Ordering};
use burn::tensor::backend::Backend;
use burn::tensor::Tensor;
use super::shard_manager::{ShardConfig, ShardManager, ShardLocation};
/// Configuration used when building a per-sequence KV cache.
#[derive(Debug, Clone)]
pub struct SequenceConfig {
/// Shard settings; cloned into `ShardManager::new` when a `SequenceKV` is constructed.
pub shard: ShardConfig,
/// Number of layers; `SequenceKV::new` creates one `LayerKV` per layer.
pub num_layers: usize,
/// Number of key/value heads.
/// NOTE(review): not read by the code in this file; the tests here use it as the
/// second tensor dimension — confirm against callers.
pub num_kv_heads: usize,
/// Size of each head.
/// NOTE(review): not read by the code in this file; the tests here use it as the
/// third tensor dimension — confirm against callers.
pub head_dim: usize,
/// Block size.
/// NOTE(review): not read by the code in this file; presumably an allocation
/// granularity measured in tokens — TODO confirm.
pub block_size: usize,
}
impl Default for SequenceConfig {
    /// Baseline configuration: default shard settings, 32 layers, 8 KV heads,
    /// a head dimension of 128 and a block size of 16.
    fn default() -> Self {
        let shard = ShardConfig::default();
        Self {
            shard,
            num_layers: 32,
            num_kv_heads: 8,
            head_dim: 128,
            block_size: 16,
        }
    }
}
/// Opaque identifier for a single sequence.
///
/// The handle only carries the numeric id it was created with. It cannot be
/// constructed outside this module because [`SequenceHandle::new`] is private.
///
/// The raw-pointer `PhantomData` marker opts the type out of the `Send` and
/// `Sync` auto traits, so a handle stays on the thread that obtained it.
/// NOTE(review): presumably deliberate — confirm before relaxing it.
#[derive(Debug)]
pub struct SequenceHandle {
    id: usize,
    // Zero-sized; exists purely to make the handle `!Send` and `!Sync`.
    _marker: PhantomData<*const ()>,
}

impl SequenceHandle {
    /// Wraps `id` in a new handle.
    fn new(id: usize) -> Self {
        Self {
            id,
            _marker: PhantomData,
        }
    }

    /// Returns the numeric id this handle was created with.
    #[must_use]
    pub fn id(&self) -> usize {
        self.id
    }
}
/// Key/value tensors cached for one layer.
///
/// NOTE(review): the tests in this file use tensors shaped
/// `[tokens, kv_heads, head_dim]`; nothing here enforces that layout — confirm.
#[derive(Debug)]
pub struct LayerKV<B: Backend> {
/// Cached keys; `None` until the first `append`.
pub keys: Option<Tensor<B, 3>>,
/// Cached values; `None` until the first `append`.
pub values: Option<Tensor<B, 3>>,
/// Running total of the dimension-0 sizes of all appended key tensors.
pub seq_len: usize,
}
impl<B: Backend> LayerKV<B> {
    /// Creates an empty cache: no keys, no values, length zero.
    pub fn new() -> Self {
        Self {
            keys: None,
            values: None,
            seq_len: 0,
        }
    }

    /// Appends `k` and `v` to the cached keys and values along dimension 0.
    ///
    /// The stored length grows by the dimension-0 size of `k`; `v` is not
    /// inspected, so callers are expected to pass matching tensors.
    pub fn append(&mut self, k: Tensor<B, 3>, v: Tensor<B, 3>) {
        // Read the size up front: `k` is moved into the cache below.
        let added = k.dims()[0];
        Self::concat_into(&mut self.keys, k);
        Self::concat_into(&mut self.values, v);
        self.seq_len += added;
    }

    /// Returns clones of the cached `(keys, values)`, or `None` if either is
    /// still unset.
    pub fn get(&self) -> Option<(Tensor<B, 3>, Tensor<B, 3>)> {
        let keys = self.keys.as_ref()?;
        let values = self.values.as_ref()?;
        Some((keys.clone(), values.clone()))
    }

    /// Stores `chunk` in `slot`, concatenating it after any existing tensor
    /// along dimension 0.
    fn concat_into(slot: &mut Option<Tensor<B, 3>>, chunk: Tensor<B, 3>) {
        let combined = if let Some(previous) = slot.take() {
            Tensor::cat(vec![previous, chunk], 0)
        } else {
            chunk
        };
        *slot = Some(combined);
    }
}
impl<B: Backend> Default for LayerKV<B> {
fn default() -> Self {
Self::new()
}
}
/// KV cache for one sequence: a `LayerKV` per layer plus a `ShardManager`.
pub struct SequenceKV<B: Backend> {
/// Identifier supplied at construction time.
id: usize,
/// One cache entry per layer; its length is `config.num_layers` at construction.
layers: Vec<LayerKV<B>>,
/// Built from `config.shard`; told about new tokens when layer 0 is appended to.
shard_manager: ShardManager,
/// Configuration this cache was built with.
config: SequenceConfig,
/// Device supplied at construction time; only stored and handed back by `device()`.
device: B::Device,
}
impl<B: Backend> SequenceKV<B> {
    /// Builds an empty cache with `config.num_layers` layers and a shard
    /// manager created from a clone of `config.shard`.
    pub fn new(id: usize, config: SequenceConfig, device: B::Device) -> Self {
        let layers = std::iter::repeat_with(LayerKV::new)
            .take(config.num_layers)
            .collect();
        let shard_manager = ShardManager::new(config.shard.clone());
        Self {
            id,
            layers,
            shard_manager,
            config,
            device,
        }
    }

    /// Identifier this cache was created with.
    pub fn id(&self) -> usize {
        self.id
    }

    /// Appends `k`/`v` to the cache of `layer`.
    ///
    /// An out-of-range `layer` is ignored silently. Only an append to layer 0
    /// notifies the shard manager, passing the dimension-0 size of `k`, so
    /// each batch of tokens is reported once rather than once per layer.
    pub fn append(&mut self, layer: usize, k: Tensor<B, 3>, v: Tensor<B, 3>) {
        if let Some(target) = self.layers.get_mut(layer) {
            if layer == 0 {
                self.shard_manager.allocate(k.dims()[0]);
            }
            target.append(k, v);
        }
    }

    /// Clones of the cached `(keys, values)` for `layer`, or `None` when the
    /// layer does not exist or holds nothing yet.
    pub fn get_kv(&self, layer: usize) -> Option<(Tensor<B, 3>, Tensor<B, 3>)> {
        self.layers.get(layer)?.get()
    }

    /// Length recorded by the first layer, or 0 when there are no layers.
    pub fn seq_len(&self) -> usize {
        self.layers.first().map_or(0, |first| first.seq_len)
    }

    /// Shard locations as reported by the shard manager.
    pub fn shards(&self) -> &[ShardLocation] {
        self.shard_manager.shards()
    }

    /// Device this cache was created with.
    pub fn device(&self) -> &B::Device {
        &self.device
    }

    /// Configuration this cache was created with.
    pub fn config(&self) -> &SequenceConfig {
        &self.config
    }

    /// Forwards the capacity question for `additional` to the shard manager.
    pub fn has_capacity(&self, additional: usize) -> bool {
        self.shard_manager.has_capacity(additional)
    }

    /// Empties every layer and resets the shard manager.
    pub fn reset(&mut self) {
        for slot in self.layers.iter_mut() {
            // A fresh `LayerKV` has no tensors and a length of zero.
            *slot = LayerKV::new();
        }
        self.shard_manager.reset();
    }
}
/// Creates `SequenceKV` caches that share one configuration and device.
pub struct SequenceFactory<B: Backend> {
/// Id given to the next created sequence; starts at 0 and is incremented by `create`.
next_id: AtomicUsize,
/// Configuration cloned into every created sequence.
config: SequenceConfig,
/// Device cloned into every created sequence.
device: B::Device,
}
impl<B: Backend> SequenceFactory<B> {
    /// Creates a factory whose first sequence will receive id 0.
    pub fn new(config: SequenceConfig, device: B::Device) -> Self {
        let next_id = AtomicUsize::new(0);
        Self {
            next_id,
            config,
            device,
        }
    }

    /// Creates a new sequence, returning its handle together with an empty
    /// cache built from clones of the factory's configuration and device.
    ///
    /// Handle and cache carry the same id.
    pub fn create(&self) -> (SequenceHandle, SequenceKV<B>) {
        // The counter is only used to hand out distinct ids; it does not
        // order any other memory access, hence `Relaxed`.
        let id = self.next_id.fetch_add(1, Ordering::Relaxed);
        (
            SequenceHandle::new(id),
            SequenceKV::new(id, self.config.clone(), self.device.clone()),
        )
    }

    /// Number of sequences created so far (the current counter value).
    pub fn num_created(&self) -> usize {
        self.next_id.load(Ordering::Relaxed)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use burn::backend::ndarray::NdArray;

    type TestBackend = NdArray<f32>;

    /// Default device of the test backend.
    fn test_device() -> <TestBackend as Backend>::Device {
        <TestBackend as Backend>::Device::default()
    }

    /// Default configuration with the three dimensions under test overridden.
    fn small_config(num_layers: usize, num_kv_heads: usize, head_dim: usize) -> SequenceConfig {
        SequenceConfig {
            num_layers,
            num_kv_heads,
            head_dim,
            ..Default::default()
        }
    }

    /// A single append is reflected in both the length and the tensor shapes.
    #[test]
    fn test_sequence_kv_basic() {
        let device = test_device();
        let mut cache = SequenceKV::<TestBackend>::new(0, small_config(2, 4, 8), device.clone());

        let keys_in = Tensor::zeros([10, 4, 8], &device);
        let values_in = Tensor::zeros([10, 4, 8], &device);
        cache.append(0, keys_in, values_in);

        assert_eq!(cache.seq_len(), 10);
        let (keys_out, values_out) = cache.get_kv(0).unwrap();
        assert_eq!(keys_out.dims(), [10, 4, 8]);
        assert_eq!(values_out.dims(), [10, 4, 8]);
    }

    /// Repeated appends accumulate along dimension 0.
    #[test]
    fn test_sequence_kv_multiple_appends() {
        let device = test_device();
        let mut cache = SequenceKV::<TestBackend>::new(0, small_config(1, 2, 4), device.clone());

        for _round in 0..5 {
            let keys_in = Tensor::zeros([8, 2, 4], &device);
            let values_in = Tensor::zeros([8, 2, 4], &device);
            cache.append(0, keys_in, values_in);
        }

        assert_eq!(cache.seq_len(), 40);
        let (keys_out, _) = cache.get_kv(0).unwrap();
        assert_eq!(keys_out.dims(), [40, 2, 4]);
    }

    /// The factory hands out consecutive ids starting at zero.
    #[test]
    fn test_sequence_factory() {
        let factory =
            SequenceFactory::<TestBackend>::new(SequenceConfig::default(), test_device());

        let (first, _first_kv) = factory.create();
        let (second, _second_kv) = factory.create();

        assert_eq!(first.id(), 0);
        assert_eq!(second.id(), 1);
        assert_eq!(factory.num_created(), 2);
    }

    /// Appending to one sequence leaves a sibling sequence untouched.
    #[test]
    fn test_sequence_isolation() {
        let device = test_device();
        let factory = SequenceFactory::<TestBackend>::new(small_config(1, 2, 4), device.clone());

        let (_, mut written) = factory.create();
        let (_, untouched) = factory.create();

        let keys_in = Tensor::ones([5, 2, 4], &device);
        let values_in = Tensor::ones([5, 2, 4], &device);
        written.append(0, keys_in, values_in);

        assert_eq!(written.seq_len(), 5);
        assert_eq!(untouched.seq_len(), 0);
    }
}