// gel_db_protocol/buffer.rs
use crate::prelude::*;
use std::{collections::VecDeque, convert::Infallible, marker::PhantomData};

4#[derive(derive_more::Debug)]
7pub struct StructBuffer<M: StructLength> {
8 #[debug(skip)]
9 _phantom: PhantomData<M>,
10 #[debug("Length={}", accum.len())]
11 accum: VecDeque<u8>,
12}
13
14impl<M: StructLength> Default for StructBuffer<M> {
15 fn default() -> Self {
16 Self {
17 _phantom: PhantomData,
18 accum: VecDeque::with_capacity(32 * 1024),
19 }
20 }
21}
22
23impl<M: StructLength> StructBuffer<M> {
24 pub fn push<'a: 'b, 'b>(
43 &'b mut self,
44 bytes: &'a [u8],
45 mut f: impl for<'c> FnMut(Result<M::Struct<'c>, ParseError>),
46 ) {
47 if self.accum.is_empty() {
48 let mut offset = 0;
50 while offset < bytes.len() {
51 if let Some(len) = M::length_of_buf(&bytes[offset..]) {
52 if offset + len <= bytes.len() {
53 f(M::new(&bytes[offset..offset + len]));
54 offset += len;
55 } else {
56 break;
57 }
58 } else {
59 break;
60 }
61 }
62 if offset == bytes.len() {
63 return;
64 }
65 self.accum.extend(&bytes[offset..]);
66 } else {
67 self.accum.extend(bytes);
68 }
69
70 let contiguous = self.accum.make_contiguous();
72 let mut total_processed = 0;
73 while let Some(len) = M::length_of_buf(&contiguous[total_processed..]) {
74 if total_processed + len <= contiguous.len() {
75 let message_bytes = &contiguous[total_processed..total_processed + len];
76 f(M::new(message_bytes));
77 total_processed += len;
78 } else {
79 break;
80 }
81 }
82 if total_processed > 0 {
83 self.accum.rotate_left(total_processed);
84 self.accum.truncate(self.accum.len() - total_processed);
85 }
86 }
87
88 pub fn push_fallible<'a: 'b, 'b, E>(
107 &'b mut self,
108 bytes: &'a [u8],
109 mut f: impl for<'c> FnMut(Result<M::Struct<'c>, ParseError>) -> Result<(), E>,
110 ) -> Result<(), E> {
111 if self.accum.is_empty() {
112 let mut offset = 0;
114 while offset < bytes.len() {
115 if let Some(len) = M::length_of_buf(&bytes[offset..]) {
116 if offset + len <= bytes.len() {
117 let msg = M::new(&bytes[offset..offset + len]);
118 f(msg)?;
119 offset += len;
120 } else {
121 break;
122 }
123 } else {
124 break;
125 }
126 }
127 if offset == bytes.len() {
128 return Ok(());
129 }
130 self.accum.extend(&bytes[offset..]);
131 } else {
132 self.accum.extend(bytes);
133 }
134
135 let contiguous = self.accum.make_contiguous();
137 let mut total_processed = 0;
138 while let Some(len) = M::length_of_buf(&contiguous[total_processed..]) {
139 if total_processed + len <= contiguous.len() {
140 let message_bytes = &contiguous[total_processed..total_processed + len];
141 f(M::new(message_bytes))?;
142 total_processed += len;
143 } else {
144 break;
145 }
146 }
147 if total_processed > 0 {
148 self.accum.rotate_left(total_processed);
149 self.accum.truncate(self.accum.len() - total_processed);
150 }
151 Ok(())
152 }
153
154 pub fn into_inner(self) -> VecDeque<u8> {
155 self.accum
156 }
157
158 pub fn is_empty(&self) -> bool {
159 self.accum.is_empty()
160 }
161
162 pub fn len(&self) -> usize {
163 self.accum.len()
164 }
165}
166
#[cfg(test)]
mod tests {
    use crate::prelude::*;

    use super::StructBuffer;
    use crate::test_protocol::*;

    /// Builds a stream of three back-to-back messages (Sync, CommandComplete, DataRow) and
    /// returns it together with the encoded length of each message, in order.
    fn test_data() -> (Vec<u8>, Vec<usize>) {
        let encoded = [
            SyncBuilder::default().to_vec(),
            CommandCompleteBuilder { tag: "TAG" }.to_vec(),
            DataRowBuilder {
                values: &[Encoded::Value(b"1")],
            }
            .to_vec(),
        ];
        let lengths: Vec<usize> = encoded.iter().map(Vec::len).collect();
        (encoded.concat(), lengths)
    }

    /// Feeds `buf` to a fresh `StructBuffer`, split into consecutive chunks of
    /// `chunk_lengths` bytes. Asserts that exactly three messages come out and that they
    /// reassemble to `buf` byte-for-byte.
    fn process_chunks(buf: &[u8], chunk_lengths: &[usize]) {
        let total: usize = chunk_lengths.iter().sum();
        assert_eq!(
            total,
            buf.len(),
            "Sum of chunk lengths must equal total buffer length"
        );

        let mut messages: Vec<Vec<u8>> = Vec::new();
        let mut buffer = StructBuffer::<Message>::default();
        let mut on_message = |msg: Result<Message, ParseError>| {
            let msg = msg.unwrap();
            eprintln!("Message: {msg:?}");
            messages.push(msg.to_vec());
        };

        let mut remaining = buf;
        for &length in chunk_lengths {
            let (chunk, rest) = remaining.split_at(length);
            eprintln!("Chunk: {chunk:?}");
            buffer.push(chunk, &mut on_message);
            remaining = rest;
        }

        assert_eq!(messages.len(), 3);
        assert_eq!(&messages.concat(), buf);
    }

    #[test]
    fn test_message_buffer_chunked() {
        let (data, lengths) = test_data();
        process_chunks(&data, &lengths);
    }

    #[test]
    fn test_message_buffer_byte_by_byte() {
        let (data, _) = test_data();
        let single_bytes = vec![1usize; data.len()];
        process_chunks(&data, &single_bytes);
    }

    #[test]
    fn test_message_buffer_incremental_chunks() {
        let (data, _) = test_data();
        let total = data.len();
        (0..total).for_each(|split| process_chunks(&data, &[split, total - split]));
    }
}