// gel_db_protocol/buffer.rs
use crate::prelude::*;
use std::{collections::VecDeque, marker::PhantomData};

4#[derive(Default)]
7pub struct StructBuffer<M: StructLength> {
8 _phantom: PhantomData<M>,
9 accum: VecDeque<u8>,
10}
11
12impl<M: StructLength> StructBuffer<M> {
13 pub fn push<'a: 'b, 'b>(
32 &'b mut self,
33 bytes: &'a [u8],
34 mut f: impl for<'c> FnMut(Result<M::Struct<'c>, ParseError>),
35 ) {
36 if self.accum.is_empty() {
37 let mut offset = 0;
39 while offset < bytes.len() {
40 if let Some(len) = M::length_of_buf(&bytes[offset..]) {
41 if offset + len <= bytes.len() {
42 f(M::new(&bytes[offset..offset + len]));
43 offset += len;
44 } else {
45 break;
46 }
47 } else {
48 break;
49 }
50 }
51 if offset == bytes.len() {
52 return;
53 }
54 self.accum.extend(&bytes[offset..]);
55 } else {
56 self.accum.extend(bytes);
57 }
58
59 let contiguous = self.accum.make_contiguous();
61 let mut total_processed = 0;
62 while let Some(len) = M::length_of_buf(&contiguous[total_processed..]) {
63 if total_processed + len <= contiguous.len() {
64 let message_bytes = &contiguous[total_processed..total_processed + len];
65 f(M::new(message_bytes));
66 total_processed += len;
67 } else {
68 break;
69 }
70 }
71 if total_processed > 0 {
72 self.accum.rotate_left(total_processed);
73 self.accum.truncate(self.accum.len() - total_processed);
74 }
75 }
76
77 pub fn push_fallible<'a: 'b, 'b, E>(
96 &'b mut self,
97 bytes: &'a [u8],
98 mut f: impl for<'c> FnMut(Result<M::Struct<'c>, ParseError>) -> Result<(), E>,
99 ) -> Result<(), E> {
100 if self.accum.is_empty() {
101 let mut offset = 0;
103 while offset < bytes.len() {
104 if let Some(len) = M::length_of_buf(&bytes[offset..]) {
105 if offset + len <= bytes.len() {
106 let msg = M::new(&bytes[offset..offset + len]);
107 f(msg)?;
108 offset += len;
109 } else {
110 break;
111 }
112 } else {
113 break;
114 }
115 }
116 if offset == bytes.len() {
117 return Ok(());
118 }
119 self.accum.extend(&bytes[offset..]);
120 } else {
121 self.accum.extend(bytes);
122 }
123
124 let contiguous = self.accum.make_contiguous();
126 let mut total_processed = 0;
127 while let Some(len) = M::length_of_buf(&contiguous[total_processed..]) {
128 if total_processed + len <= contiguous.len() {
129 let message_bytes = &contiguous[total_processed..total_processed + len];
130 f(M::new(message_bytes))?;
131 total_processed += len;
132 } else {
133 break;
134 }
135 }
136 if total_processed > 0 {
137 self.accum.rotate_left(total_processed);
138 self.accum.truncate(self.accum.len() - total_processed);
139 }
140 Ok(())
141 }
142
143 pub fn into_inner(self) -> VecDeque<u8> {
144 self.accum
145 }
146
147 pub fn is_empty(&self) -> bool {
148 self.accum.is_empty()
149 }
150
151 pub fn len(&self) -> usize {
152 self.accum.len()
153 }
154}
155
#[cfg(test)]
mod tests {
    use crate::prelude::*;

    use super::StructBuffer;
    use crate::test_protocol::*;

    /// Builds a stream of three back-to-back messages (`Sync`,
    /// `CommandComplete`, `DataRow`) and returns it with each message's
    /// encoded length.
    fn test_data() -> (Vec<u8>, Vec<usize>) {
        let messages = [
            SyncBuilder::default().to_vec(),
            CommandCompleteBuilder { tag: "TAG" }.to_vec(),
            DataRowBuilder {
                values: &[Encoded::Value(b"1")],
            }
            .to_vec(),
        ];
        let lengths = messages.iter().map(Vec::len).collect();
        (messages.concat(), lengths)
    }

    /// Splits `buf` at `chunk_lengths` and feeds the chunks to fresh buffers
    /// through both `push` and `push_fallible`. Asserts that both APIs yield
    /// the same three messages, that those messages reassemble into `buf`,
    /// and that nothing is left buffered.
    fn process_chunks(buf: &[u8], chunk_lengths: &[usize]) {
        assert_eq!(
            chunk_lengths.iter().sum::<usize>(),
            buf.len(),
            "Sum of chunk lengths must equal total buffer length"
        );

        let mut messages: Vec<Vec<u8>> = Vec::new();
        let mut fallible_messages: Vec<Vec<u8>> = Vec::new();
        let mut buffer = StructBuffer::<Message>::default();
        let mut fallible_buffer = StructBuffer::<Message>::default();

        let mut start = 0;
        for &length in chunk_lengths {
            let chunk = &buf[start..start + length];
            start += length;

            buffer.push(chunk, |msg: Result<Message, ParseError>| {
                messages.push(msg.unwrap().to_vec());
            });
            fallible_buffer
                .push_fallible(
                    chunk,
                    |msg: Result<Message, ParseError>| -> Result<(), ParseError> {
                        fallible_messages.push(msg?.to_vec());
                        Ok(())
                    },
                )
                .unwrap();
        }

        assert!(buffer.is_empty(), "complete stream must leave nothing buffered");
        assert!(
            fallible_buffer.is_empty(),
            "complete stream must leave nothing buffered"
        );
        assert_eq!(messages.len(), 3);
        assert_eq!(messages, fallible_messages);
        assert_eq!(messages.concat(), buf);
    }

    #[test]
    fn test_message_buffer_chunked() {
        let (data, message_lengths) = test_data();
        process_chunks(&data, &message_lengths);
    }

    #[test]
    fn test_message_buffer_byte_by_byte() {
        let (data, _) = test_data();
        process_chunks(&data, &vec![1; data.len()]);
    }

    #[test]
    fn test_message_buffer_incremental_chunks() {
        let (data, _) = test_data();
        // Inclusive range, so the "everything in the first chunk" split is
        // exercised as well as "everything in the second chunk".
        for split in 0..=data.len() {
            process_chunks(&data, &[split, data.len() - split]);
        }
    }
}