use burn::prelude::*;
/// A list of rank-3 tensor "blocks" plus an optional partial block.
///
/// NOTE(review): the unit test builds the initial block from `[2, 16, 64]`
/// token embeddings, so tensors here presumably follow a `(batch, seq, dim)`
/// layout — confirm against callers.
#[derive(Clone, Debug)]
pub struct BlockState<B: Backend> {
    /// Stored blocks; [`BlockState::new`] seeds index 0 with the token
    /// embeddings it is given.
    pub blocks: Vec<Tensor<B, 3>>,
    /// An additional block kept separately from `blocks`; `None` right after
    /// [`BlockState::new`].
    pub partial_block: Option<Tensor<B, 3>>,
}
impl<B: Backend> BlockState<B> {
    /// Creates a state whose only block is `token_embeddings`.
    ///
    /// The embeddings are moved (not copied) into `blocks` as its single
    /// entry, and `partial_block` starts out as `None`.
    pub fn new(token_embeddings: Tensor<B, 3>) -> Self {
        let blocks = vec![token_embeddings];
        Self {
            partial_block: None,
            blocks,
        }
    }

    /// Returns how many tensors are stored in `blocks`.
    ///
    /// `partial_block` is not included in this count, whether or not it is set.
    pub fn num_blocks(&self) -> usize {
        let Self { blocks, .. } = self;
        blocks.len()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use burn::backend::NdArray;

    /// Backend used by these unit tests (burn's `NdArray` backend).
    type TestBackend = NdArray;

    /// A freshly built state holds exactly one block and no partial block.
    #[test]
    fn test_new_state() {
        let embeddings = Tensor::<TestBackend, 3>::zeros([2, 16, 64], &Default::default());
        let state = BlockState::new(embeddings);

        assert_eq!(state.num_blocks(), 1);
        assert!(state.partial_block.is_none());
    }
}