Skip to main content

cjc_runtime/
scratchpad.rs

1//! KV-cache scratchpad -- zero-allocation state persistence for transformer inference.
2//!
3//! Provides [`Scratchpad`], a pre-allocated linear buffer for appending
4//! key/value token vectors without per-token heap allocation. The entire
5//! `[max_seq_len, dim]` storage is allocated once at construction; subsequent
6//! [`append`](Scratchpad::append) calls copy data into existing storage.
7//!
8//! # NoGC guarantee
9//!
10//! After construction, `append` performs no heap allocation -- it writes
11//! directly into the pre-allocated [`Buffer`]. The [`as_tensor`](Scratchpad::as_tensor)
12//! method returns a zero-copy view via `Rc` clone of the underlying buffer.
13//!
14//! # Relationship to [`PagedKvCache`](crate::paged_kv::PagedKvCache)
15//!
16//! `Scratchpad` uses a single contiguous buffer (simpler, better for small
17//! sequences). [`PagedKvCache`](crate::paged_kv::PagedKvCache) uses block
18//! paging (better for large sequences where contiguous allocation may
19//! fragment).
20
21use std::fmt;
22
23use crate::buffer::Buffer;
24use crate::error::RuntimeError;
25use crate::tensor::Tensor;
26
27// ---------------------------------------------------------------------------
28// 2b. KV-Cache Scratchpad (Zero-Allocation State Persistence)
29// ---------------------------------------------------------------------------
30
/// A pre-allocated scratch buffer for KV-cache. Allows appending new
/// key/value vectors without re-allocation, up to a fixed `max_seq_len`.
///
/// Layout: `[max_seq_len, dim]` with a `current_len` cursor. Rows
/// `0..current_len` hold appended tokens; rows at or beyond `current_len`
/// are unused capacity.
/// All memory is allocated once at construction; `append` only copies
/// new data into existing storage (zero GC pressure per token).
///
/// NOTE(review): `Clone` is derived, so cloning a `Scratchpad` clones the
/// `Buffer` handle. `as_tensor` relies on that clone sharing storage, so a
/// cloned scratchpad presumably shares storage with the original until one
/// of them is written to -- confirm against `Buffer`.
#[derive(Debug, Clone)]
pub struct Scratchpad {
    /// Flat backing storage of `max_seq_len * dim` elements. Token `t`
    /// occupies indices `t * dim .. (t + 1) * dim`.
    buffer: Buffer<f64>,
    /// Maximum sequence length (pre-allocated).
    max_seq_len: usize,
    /// Hidden dimension per token.
    dim: usize,
    /// Current number of tokens stored. Never exceeds `max_seq_len`.
    current_len: usize,
}
48
49impl Scratchpad {
50    /// Create a new scratchpad pre-allocated for `max_seq_len` tokens of
51    /// dimension `dim`. Zero-fills all storage upfront.
52    pub fn new(max_seq_len: usize, dim: usize) -> Self {
53        Scratchpad {
54            buffer: Buffer::alloc(max_seq_len * dim, 0.0),
55            max_seq_len,
56            dim,
57            current_len: 0,
58        }
59    }
60
61    /// Number of tokens currently stored.
62    pub fn len(&self) -> usize {
63        self.current_len
64    }
65
66    /// Whether no tokens are stored.
67    pub fn is_empty(&self) -> bool {
68        self.current_len == 0
69    }
70
71    /// Maximum sequence length this scratchpad can hold.
72    pub fn capacity(&self) -> usize {
73        self.max_seq_len
74    }
75
76    /// Hidden dimension per token.
77    pub fn dim(&self) -> usize {
78        self.dim
79    }
80
81    /// Append a single token vector `[dim]` to the cache.
82    /// Returns an error if the cache is full. **Zero allocation.**
83    pub fn append(&mut self, token_vec: &[f64]) -> Result<(), RuntimeError> {
84        if token_vec.len() != self.dim {
85            return Err(RuntimeError::ShapeMismatch {
86                expected: self.dim,
87                got: token_vec.len(),
88            });
89        }
90        if self.current_len >= self.max_seq_len {
91            return Err(RuntimeError::InvalidOperation(
92                format!(
93                    "Scratchpad full: {} / {} tokens",
94                    self.current_len, self.max_seq_len
95                ),
96            ));
97        }
98        let base = self.current_len * self.dim;
99        self.buffer.make_unique();
100        for (i, &val) in token_vec.iter().enumerate() {
101            self.buffer.set(base + i, val)?;
102        }
103        self.current_len += 1;
104        Ok(())
105    }
106
107    /// Append a batch of token vectors from a tensor of shape `[n, dim]`.
108    /// **Zero allocation** — writes directly into pre-allocated storage.
109    pub fn append_tensor(&mut self, t: &Tensor) -> Result<(), RuntimeError> {
110        if t.ndim() != 2 || t.shape()[1] != self.dim {
111            return Err(RuntimeError::InvalidOperation(
112                format!(
113                    "append_tensor: expected shape [n, {}], got {:?}",
114                    self.dim,
115                    t.shape()
116                ),
117            ));
118        }
119        let n = t.shape()[0];
120        if self.current_len + n > self.max_seq_len {
121            return Err(RuntimeError::InvalidOperation(
122                format!(
123                    "Scratchpad overflow: {} + {} > {} max",
124                    self.current_len, n, self.max_seq_len
125                ),
126            ));
127        }
128        let data = t.to_vec();
129        self.buffer.make_unique();
130        let base = self.current_len * self.dim;
131        for (i, &val) in data.iter().enumerate() {
132            self.buffer.set(base + i, val)?;
133        }
134        self.current_len += n;
135        Ok(())
136    }
137
138    /// Get a Tensor view `[current_len, dim]` of the stored data.
139    /// Shares the underlying buffer (zero-copy).
140    pub fn as_tensor(&self) -> Tensor {
141        let shape = vec![self.current_len, self.dim];
142        Tensor {
143            buffer: self.buffer.clone(), // Rc clone, not data copy
144            shape: shape.clone(),
145            strides: Tensor::compute_strides(&shape),
146            offset: 0,
147        }
148    }
149
150    /// Reset the cache to empty without deallocating.
151    /// The underlying buffer is retained for reuse.
152    pub fn clear(&mut self) {
153        self.current_len = 0;
154    }
155}
156
157impl fmt::Display for Scratchpad {
158    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
159        write!(
160            f,
161            "Scratchpad(len={}, capacity={}, dim={})",
162            self.current_len, self.max_seq_len, self.dim
163        )
164    }
165}
166