use std::fmt;
use crate::buffer::Buffer;
use crate::error::RuntimeError;
use crate::tensor::Tensor;
/// Fixed-capacity store of `dim`-wide `f64` rows, filled by appending.
///
/// Rows are laid out contiguously in `buffer` (row `r` starts at index
/// `r * dim`); only the first `current_len` rows are considered live.
#[derive(Debug, Clone)]
pub struct Scratchpad {
    /// Backing storage, allocated with `max_seq_len * dim` elements in `new`.
    buffer: Buffer<f64>,
    /// Maximum number of rows the scratchpad can hold.
    max_seq_len: usize,
    /// Number of `f64` values per row.
    dim: usize,
    /// Number of rows currently stored; never exceeds `max_seq_len`.
    current_len: usize,
}
impl Scratchpad {
    /// Creates an empty scratchpad that can hold up to `max_seq_len` rows of
    /// `dim` values each.
    ///
    /// The whole backing buffer (`max_seq_len * dim` elements, initialised with
    /// `0.0`) is allocated up front; appends only write into it.
    ///
    /// # Panics
    ///
    /// Panics if `max_seq_len * dim` overflows `usize`.
    pub fn new(max_seq_len: usize, dim: usize) -> Self {
        // Without the check a release build would silently wrap and allocate a
        // buffer much smaller than the advertised capacity.
        let total = max_seq_len
            .checked_mul(dim)
            .expect("Scratchpad::new: max_seq_len * dim overflows usize");
        Scratchpad {
            buffer: Buffer::alloc(total, 0.0),
            max_seq_len,
            dim,
            current_len: 0,
        }
    }

    /// Number of rows currently stored.
    pub fn len(&self) -> usize {
        self.current_len
    }

    /// Returns `true` when no rows are stored.
    pub fn is_empty(&self) -> bool {
        self.current_len == 0
    }

    /// Maximum number of rows this scratchpad can hold.
    pub fn capacity(&self) -> usize {
        self.max_seq_len
    }

    /// Number of values per row.
    pub fn dim(&self) -> usize {
        self.dim
    }

    /// Appends a single row after the last stored row.
    ///
    /// # Errors
    ///
    /// - [`RuntimeError::ShapeMismatch`] if `token_vec.len() != self.dim()`.
    /// - [`RuntimeError::InvalidOperation`] if `capacity()` rows are already stored.
    /// - Any error returned by the underlying buffer write.
    pub fn append(&mut self, token_vec: &[f64]) -> Result<(), RuntimeError> {
        if token_vec.len() != self.dim {
            return Err(RuntimeError::ShapeMismatch {
                expected: self.dim,
                got: token_vec.len(),
            });
        }
        if self.current_len >= self.max_seq_len {
            return Err(RuntimeError::InvalidOperation(format!(
                "Scratchpad full: {} / {} tokens",
                self.current_len, self.max_seq_len
            )));
        }
        self.write_at_end(token_vec)?;
        self.current_len += 1;
        Ok(())
    }

    /// Appends every row of a 2-D tensor of shape `[n, dim]`.
    ///
    /// Either all `n` rows are accepted or the length is left unchanged.
    ///
    /// # Errors
    ///
    /// - [`RuntimeError::InvalidOperation`] if `t` is not 2-D with `dim()` columns.
    /// - [`RuntimeError::InvalidOperation`] if the `n` rows do not fit in the
    ///   remaining capacity.
    /// - Any error returned by the underlying buffer write.
    pub fn append_tensor(&mut self, t: &Tensor) -> Result<(), RuntimeError> {
        if t.ndim() != 2 || t.shape()[1] != self.dim {
            return Err(RuntimeError::InvalidOperation(format!(
                "append_tensor: expected shape [n, {}], got {:?}",
                self.dim,
                t.shape()
            )));
        }
        let n = t.shape()[0];
        // `current_len <= max_seq_len` always holds, so this subtraction cannot
        // underflow, and unlike `current_len + n` it cannot overflow for huge
        // `n` (possible e.g. when `dim == 0`, where the tensor holds no data).
        if n > self.max_seq_len - self.current_len {
            return Err(RuntimeError::InvalidOperation(format!(
                "Scratchpad overflow: {} + {} > {} max",
                self.current_len, n, self.max_seq_len
            )));
        }
        let data = t.to_vec();
        self.write_at_end(&data)?;
        self.current_len += n;
        Ok(())
    }

    /// Returns a `[len(), dim()]` tensor over the stored rows.
    ///
    /// The tensor is built from a clone of this scratchpad's buffer with offset
    /// 0 and strides from `Tensor::compute_strides`; rows beyond `len()` lie
    /// outside its shape.
    /// NOTE(review): whether `Buffer::clone` shares or copies the storage is
    /// defined in `crate::buffer` — confirm before relying on aliasing.
    pub fn as_tensor(&self) -> Tensor {
        let shape = vec![self.current_len, self.dim];
        let strides = Tensor::compute_strides(&shape);
        Tensor {
            buffer: self.buffer.clone(),
            shape,
            strides,
            offset: 0,
        }
    }

    /// Resets the length to zero without touching the buffer contents.
    ///
    /// Stale values are overwritten by later appends and are never covered by
    /// the shape returned from `as_tensor`.
    pub fn clear(&mut self) {
        self.current_len = 0;
    }

    /// Writes `data` into the buffer starting at the first free row.
    ///
    /// Callers must already have checked that `data` fits in the remaining
    /// capacity; `current_len` is not updated here.
    fn write_at_end(&mut self, data: &[f64]) -> Result<(), RuntimeError> {
        let base = self.current_len * self.dim;
        // NOTE(review): presumably detaches the buffer from any other handle
        // (e.g. one returned by `as_tensor`) before mutating — confirm.
        self.buffer.make_unique();
        for (i, &val) in data.iter().enumerate() {
            self.buffer.set(base + i, val)?;
        }
        Ok(())
    }
}
impl fmt::Display for Scratchpad {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(
f,
"Scratchpad(len={}, capacity={}, dim={})",
self.current_len, self.max_seq_len, self.dim
)
}
}