gel_db_protocol/
buffer.rs

1use crate::prelude::*;
2use std::{collections::VecDeque, marker::PhantomData};
3
4/// A buffer that accumulates bytes of sized structs and feeds them to provided sink function when messages
5/// are complete. This buffer handles partial messages and multiple messages in a single push.
6#[derive(derive_more::Debug)]
7pub struct StructBuffer<M: StructLength> {
8    #[debug(skip)]
9    _phantom: PhantomData<M>,
10    #[debug("Length={}", accum.len())]
11    accum: VecDeque<u8>,
12}
13
14impl<M: StructLength> Default for StructBuffer<M> {
15    fn default() -> Self {
16        Self {
17            _phantom: PhantomData,
18            accum: VecDeque::with_capacity(32 * 1024),
19        }
20    }
21}
22
23impl<M: StructLength> StructBuffer<M> {
24    /// Pushes bytes into the buffer, potentially feeding output to the function.
25    ///
26    /// # Lifetimes
27    /// - `'a`: The lifetime of the input byte slice.
28    /// - `'b`: The lifetime of the mutable reference to `self`.
29    /// - `'c`: A lifetime used in the closure's type, representing the lifetime of the `M::Struct` instances passed to it.
30    ///
31    /// The constraint `'a: 'b` ensures that the input bytes live at least as long as the mutable reference to `self`.
32    ///
33    /// The `for<'c>` syntax in the closure type is a higher-ranked trait bound. It indicates that the closure
34    /// must be able to handle `M::Struct` with any lifetime `'c`. This is crucial because:
35    ///
36    /// 1. It allows the `push` method to create `M::Struct` instances with lifetimes that are not known
37    ///    at the time the closure is defined.
38    /// 2. It ensures that the `M::Struct` instances passed to the closure are only valid for the duration
39    ///    of each call to the closure, not for the entire lifetime of the `push` method.
40    /// 3. It prevents the closure from storing or returning these `M::Struct` instances, as their lifetime
41    ///    is limited to the scope of each closure invocation.
42    pub fn push<'a: 'b, 'b>(
43        &'b mut self,
44        bytes: &'a [u8],
45        mut f: impl for<'c> FnMut(Result<M::Struct<'c>, ParseError>),
46    ) {
47        if self.accum.is_empty() {
48            // Fast path: try to process the input directly
49            let mut offset = 0;
50            while offset < bytes.len() {
51                if let Some(len) = M::length_of_buf(&bytes[offset..]) {
52                    if offset + len <= bytes.len() {
53                        f(M::new(&bytes[offset..offset + len]));
54                        offset += len;
55                    } else {
56                        break;
57                    }
58                } else {
59                    break;
60                }
61            }
62            if offset == bytes.len() {
63                return;
64            }
65            self.accum.extend(&bytes[offset..]);
66        } else {
67            self.accum.extend(bytes);
68        }
69
70        // Slow path: process accumulated data
71        let contiguous = self.accum.make_contiguous();
72        let mut total_processed = 0;
73        while let Some(len) = M::length_of_buf(&contiguous[total_processed..]) {
74            if total_processed + len <= contiguous.len() {
75                let message_bytes = &contiguous[total_processed..total_processed + len];
76                f(M::new(message_bytes));
77                total_processed += len;
78            } else {
79                break;
80            }
81        }
82        if total_processed > 0 {
83            self.accum.rotate_left(total_processed);
84            self.accum.truncate(self.accum.len() - total_processed);
85        }
86    }
87
88    /// Pushes bytes into the buffer, potentially feeding output to the function.
89    ///
90    /// # Lifetimes
91    /// - `'a`: The lifetime of the input byte slice.
92    /// - `'b`: The lifetime of the mutable reference to `self`.
93    /// - `'c`: A lifetime used in the closure's type, representing the lifetime of the `M::Struct` instances passed to it.
94    ///
95    /// The constraint `'a: 'b` ensures that the input bytes live at least as long as the mutable reference to `self`.
96    ///
97    /// The `for<'c>` syntax in the closure type is a higher-ranked trait bound. It indicates that the closure
98    /// must be able to handle `M::Struct` with any lifetime `'c`. This is crucial because:
99    ///
100    /// 1. It allows the `push` method to create `M::Struct` instances with lifetimes that are not known
101    ///    at the time the closure is defined.
102    /// 2. It ensures that the `M::Struct` instances passed to the closure are only valid for the duration
103    ///    of each call to the closure, not for the entire lifetime of the `push` method.
104    /// 3. It prevents the closure from storing or returning these `M::Struct` instances, as their lifetime
105    ///    is limited to the scope of each closure invocation.
106    pub fn push_fallible<'a: 'b, 'b, E>(
107        &'b mut self,
108        bytes: &'a [u8],
109        mut f: impl for<'c> FnMut(Result<M::Struct<'c>, ParseError>) -> Result<(), E>,
110    ) -> Result<(), E> {
111        if self.accum.is_empty() {
112            // Fast path: try to process the input directly
113            let mut offset = 0;
114            while offset < bytes.len() {
115                if let Some(len) = M::length_of_buf(&bytes[offset..]) {
116                    if offset + len <= bytes.len() {
117                        let msg = M::new(&bytes[offset..offset + len]);
118                        f(msg)?;
119                        offset += len;
120                    } else {
121                        break;
122                    }
123                } else {
124                    break;
125                }
126            }
127            if offset == bytes.len() {
128                return Ok(());
129            }
130            self.accum.extend(&bytes[offset..]);
131        } else {
132            self.accum.extend(bytes);
133        }
134
135        // Slow path: process accumulated data
136        let contiguous = self.accum.make_contiguous();
137        let mut total_processed = 0;
138        while let Some(len) = M::length_of_buf(&contiguous[total_processed..]) {
139            if total_processed + len <= contiguous.len() {
140                let message_bytes = &contiguous[total_processed..total_processed + len];
141                f(M::new(message_bytes))?;
142                total_processed += len;
143            } else {
144                break;
145            }
146        }
147        if total_processed > 0 {
148            self.accum.rotate_left(total_processed);
149            self.accum.truncate(self.accum.len() - total_processed);
150        }
151        Ok(())
152    }
153
154    pub fn into_inner(self) -> VecDeque<u8> {
155        self.accum
156    }
157
158    pub fn is_empty(&self) -> bool {
159        self.accum.is_empty()
160    }
161
162    pub fn len(&self) -> usize {
163        self.accum.len()
164    }
165}
166
#[cfg(test)]
mod tests {
    use super::StructBuffer;
    use crate::prelude::*;
    use crate::test_protocol::*;

    /// Builds a byte stream of three back-to-back messages and returns it
    /// together with the encoded length of each message, in order.
    fn test_data() -> (Vec<u8>, Vec<usize>) {
        let encoded = [
            SyncBuilder::default().to_vec(),
            CommandCompleteBuilder { tag: "TAG" }.to_vec(),
            DataRowBuilder {
                values: &[Encoded::Value(b"1")],
            }
            .to_vec(),
        ];
        let lengths = encoded.iter().map(Vec::len).collect();
        (encoded.concat(), lengths)
    }

    /// Feeds `buf` to a fresh buffer in pieces of the given sizes and checks
    /// that exactly three messages come out and that they re-encode to `buf`.
    fn process_chunks(buf: &[u8], chunk_lengths: &[usize]) {
        let total: usize = chunk_lengths.iter().sum();
        assert_eq!(
            total,
            buf.len(),
            "Sum of chunk lengths must equal total buffer length"
        );

        let mut buffer = StructBuffer::<Message>::default();
        let mut received: Vec<Vec<u8>> = Vec::new();
        let mut collect = |msg: Result<Message, ParseError>| {
            let msg = msg.unwrap();
            eprintln!("Message: {msg:?}");
            received.push(msg.to_vec());
        };

        // Peel one chunk at a time off the front of the remaining input.
        let mut remaining = buf;
        for &length in chunk_lengths {
            let (chunk, tail) = remaining.split_at(length);
            eprintln!("Chunk: {chunk:?}");

            buffer.push(chunk, &mut collect);
            remaining = tail;
        }

        assert_eq!(received.len(), 3);

        let reassembled = received.concat();
        assert_eq!(&reassembled, buf);
    }

    /// Each chunk is exactly one message.
    #[test]
    fn test_message_buffer_chunked() {
        let (stream, message_lengths) = test_data();
        process_chunks(&stream, &message_lengths);
    }

    /// The stream arrives one byte per push.
    #[test]
    fn test_message_buffer_byte_by_byte() {
        let (stream, _) = test_data();
        let single_bytes = vec![1_usize; stream.len()];
        process_chunks(&stream, &single_bytes);
    }

    /// The stream arrives in two pushes, split at every possible offset.
    #[test]
    fn test_message_buffer_incremental_chunks() {
        let (stream, _) = test_data();
        let total = stream.len();
        for split in 0..total {
            process_chunks(&stream, &[split, total - split]);
        }
    }
}