cjc_runtime/
scratchpad.rs1use std::fmt;
2
3use crate::buffer::Buffer;
4use crate::error::RuntimeError;
5use crate::tensor::Tensor;
6
7#[derive(Debug, Clone)]
18pub struct Scratchpad {
19 buffer: Buffer<f64>,
21 max_seq_len: usize,
23 dim: usize,
25 current_len: usize,
27}
28
29impl Scratchpad {
30 pub fn new(max_seq_len: usize, dim: usize) -> Self {
33 Scratchpad {
34 buffer: Buffer::alloc(max_seq_len * dim, 0.0),
35 max_seq_len,
36 dim,
37 current_len: 0,
38 }
39 }
40
41 pub fn len(&self) -> usize {
43 self.current_len
44 }
45
46 pub fn is_empty(&self) -> bool {
48 self.current_len == 0
49 }
50
51 pub fn capacity(&self) -> usize {
53 self.max_seq_len
54 }
55
56 pub fn dim(&self) -> usize {
58 self.dim
59 }
60
61 pub fn append(&mut self, token_vec: &[f64]) -> Result<(), RuntimeError> {
64 if token_vec.len() != self.dim {
65 return Err(RuntimeError::ShapeMismatch {
66 expected: self.dim,
67 got: token_vec.len(),
68 });
69 }
70 if self.current_len >= self.max_seq_len {
71 return Err(RuntimeError::InvalidOperation(
72 format!(
73 "Scratchpad full: {} / {} tokens",
74 self.current_len, self.max_seq_len
75 ),
76 ));
77 }
78 let base = self.current_len * self.dim;
79 self.buffer.make_unique();
80 for (i, &val) in token_vec.iter().enumerate() {
81 self.buffer.set(base + i, val)?;
82 }
83 self.current_len += 1;
84 Ok(())
85 }
86
87 pub fn append_tensor(&mut self, t: &Tensor) -> Result<(), RuntimeError> {
90 if t.ndim() != 2 || t.shape()[1] != self.dim {
91 return Err(RuntimeError::InvalidOperation(
92 format!(
93 "append_tensor: expected shape [n, {}], got {:?}",
94 self.dim,
95 t.shape()
96 ),
97 ));
98 }
99 let n = t.shape()[0];
100 if self.current_len + n > self.max_seq_len {
101 return Err(RuntimeError::InvalidOperation(
102 format!(
103 "Scratchpad overflow: {} + {} > {} max",
104 self.current_len, n, self.max_seq_len
105 ),
106 ));
107 }
108 let data = t.to_vec();
109 self.buffer.make_unique();
110 let base = self.current_len * self.dim;
111 for (i, &val) in data.iter().enumerate() {
112 self.buffer.set(base + i, val)?;
113 }
114 self.current_len += n;
115 Ok(())
116 }
117
118 pub fn as_tensor(&self) -> Tensor {
121 let shape = vec![self.current_len, self.dim];
122 Tensor {
123 buffer: self.buffer.clone(), shape: shape.clone(),
125 strides: Tensor::compute_strides(&shape),
126 offset: 0,
127 }
128 }
129
130 pub fn clear(&mut self) {
133 self.current_len = 0;
134 }
135}
136
137impl fmt::Display for Scratchpad {
138 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
139 write!(
140 f,
141 "Scratchpad(len={}, capacity={}, dim={})",
142 self.current_len, self.max_seq_len, self.dim
143 )
144 }
145}
146