//! `cjc_runtime/scratchpad.rs` — pre-allocated KV-cache scratchpad
//! (append-only key/value storage with a fixed capacity).

use std::fmt;

use crate::buffer::Buffer;
use crate::error::RuntimeError;
use crate::tensor::Tensor;
6
7// ---------------------------------------------------------------------------
8// 2b. KV-Cache Scratchpad (Zero-Allocation State Persistence)
9// ---------------------------------------------------------------------------
10
11/// A pre-allocated scratch buffer for KV-cache. Allows appending new
12/// key/value vectors without re-allocation, up to a fixed `max_seq_len`.
13///
14/// Layout: `[max_seq_len, dim]` with a `current_len` cursor.
15/// All memory is allocated once at construction; `append` only copies
16/// new data into existing storage (zero GC pressure per token).
17#[derive(Debug, Clone)]
18pub struct Scratchpad {
19    /// Underlying tensor of shape `[max_seq_len, dim]`.
20    buffer: Buffer<f64>,
21    /// Maximum sequence length (pre-allocated).
22    max_seq_len: usize,
23    /// Hidden dimension per token.
24    dim: usize,
25    /// Current number of tokens stored.
26    current_len: usize,
27}
28
29impl Scratchpad {
30    /// Create a new scratchpad pre-allocated for `max_seq_len` tokens of
31    /// dimension `dim`. Zero-fills all storage upfront.
32    pub fn new(max_seq_len: usize, dim: usize) -> Self {
33        Scratchpad {
34            buffer: Buffer::alloc(max_seq_len * dim, 0.0),
35            max_seq_len,
36            dim,
37            current_len: 0,
38        }
39    }
40
41    /// Number of tokens currently stored.
42    pub fn len(&self) -> usize {
43        self.current_len
44    }
45
46    /// Whether no tokens are stored.
47    pub fn is_empty(&self) -> bool {
48        self.current_len == 0
49    }
50
51    /// Maximum sequence length this scratchpad can hold.
52    pub fn capacity(&self) -> usize {
53        self.max_seq_len
54    }
55
56    /// Hidden dimension per token.
57    pub fn dim(&self) -> usize {
58        self.dim
59    }
60
61    /// Append a single token vector `[dim]` to the cache.
62    /// Returns an error if the cache is full. **Zero allocation.**
63    pub fn append(&mut self, token_vec: &[f64]) -> Result<(), RuntimeError> {
64        if token_vec.len() != self.dim {
65            return Err(RuntimeError::ShapeMismatch {
66                expected: self.dim,
67                got: token_vec.len(),
68            });
69        }
70        if self.current_len >= self.max_seq_len {
71            return Err(RuntimeError::InvalidOperation(
72                format!(
73                    "Scratchpad full: {} / {} tokens",
74                    self.current_len, self.max_seq_len
75                ),
76            ));
77        }
78        let base = self.current_len * self.dim;
79        self.buffer.make_unique();
80        for (i, &val) in token_vec.iter().enumerate() {
81            self.buffer.set(base + i, val)?;
82        }
83        self.current_len += 1;
84        Ok(())
85    }
86
87    /// Append a batch of token vectors from a tensor of shape `[n, dim]`.
88    /// **Zero allocation** — writes directly into pre-allocated storage.
89    pub fn append_tensor(&mut self, t: &Tensor) -> Result<(), RuntimeError> {
90        if t.ndim() != 2 || t.shape()[1] != self.dim {
91            return Err(RuntimeError::InvalidOperation(
92                format!(
93                    "append_tensor: expected shape [n, {}], got {:?}",
94                    self.dim,
95                    t.shape()
96                ),
97            ));
98        }
99        let n = t.shape()[0];
100        if self.current_len + n > self.max_seq_len {
101            return Err(RuntimeError::InvalidOperation(
102                format!(
103                    "Scratchpad overflow: {} + {} > {} max",
104                    self.current_len, n, self.max_seq_len
105                ),
106            ));
107        }
108        let data = t.to_vec();
109        self.buffer.make_unique();
110        let base = self.current_len * self.dim;
111        for (i, &val) in data.iter().enumerate() {
112            self.buffer.set(base + i, val)?;
113        }
114        self.current_len += n;
115        Ok(())
116    }
117
118    /// Get a Tensor view `[current_len, dim]` of the stored data.
119    /// Shares the underlying buffer (zero-copy).
120    pub fn as_tensor(&self) -> Tensor {
121        let shape = vec![self.current_len, self.dim];
122        Tensor {
123            buffer: self.buffer.clone(), // Rc clone, not data copy
124            shape: shape.clone(),
125            strides: Tensor::compute_strides(&shape),
126            offset: 0,
127        }
128    }
129
130    /// Reset the cache to empty without deallocating.
131    /// The underlying buffer is retained for reuse.
132    pub fn clear(&mut self) {
133        self.current_len = 0;
134    }
135}
136
137impl fmt::Display for Scratchpad {
138    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
139        write!(
140            f,
141            "Scratchpad(len={}, capacity={}, dim={})",
142            self.current_len, self.max_seq_len, self.dim
143        )
144    }
145}
146