gel_db_protocol/
buffer.rs

1use crate::prelude::*;
2use std::{collections::VecDeque, marker::PhantomData};
3
4/// A buffer that accumulates bytes of sized structs and feeds them to provided sink function when messages
5/// are complete. This buffer handles partial messages and multiple messages in a single push.
6#[derive(Default)]
7pub struct StructBuffer<M: StructLength> {
8    _phantom: PhantomData<M>,
9    accum: VecDeque<u8>,
10}
11
12impl<M: StructLength> StructBuffer<M> {
13    /// Pushes bytes into the buffer, potentially feeding output to the function.
14    ///
15    /// # Lifetimes
16    /// - `'a`: The lifetime of the input byte slice.
17    /// - `'b`: The lifetime of the mutable reference to `self`.
18    /// - `'c`: A lifetime used in the closure's type, representing the lifetime of the `M::Struct` instances passed to it.
19    ///
20    /// The constraint `'a: 'b` ensures that the input bytes live at least as long as the mutable reference to `self`.
21    ///
22    /// The `for<'c>` syntax in the closure type is a higher-ranked trait bound. It indicates that the closure
23    /// must be able to handle `M::Struct` with any lifetime `'c`. This is crucial because:
24    ///
25    /// 1. It allows the `push` method to create `M::Struct` instances with lifetimes that are not known
26    ///    at the time the closure is defined.
27    /// 2. It ensures that the `M::Struct` instances passed to the closure are only valid for the duration
28    ///    of each call to the closure, not for the entire lifetime of the `push` method.
29    /// 3. It prevents the closure from storing or returning these `M::Struct` instances, as their lifetime
30    ///    is limited to the scope of each closure invocation.
31    pub fn push<'a: 'b, 'b>(
32        &'b mut self,
33        bytes: &'a [u8],
34        mut f: impl for<'c> FnMut(Result<M::Struct<'c>, ParseError>),
35    ) {
36        if self.accum.is_empty() {
37            // Fast path: try to process the input directly
38            let mut offset = 0;
39            while offset < bytes.len() {
40                if let Some(len) = M::length_of_buf(&bytes[offset..]) {
41                    if offset + len <= bytes.len() {
42                        f(M::new(&bytes[offset..offset + len]));
43                        offset += len;
44                    } else {
45                        break;
46                    }
47                } else {
48                    break;
49                }
50            }
51            if offset == bytes.len() {
52                return;
53            }
54            self.accum.extend(&bytes[offset..]);
55        } else {
56            self.accum.extend(bytes);
57        }
58
59        // Slow path: process accumulated data
60        let contiguous = self.accum.make_contiguous();
61        let mut total_processed = 0;
62        while let Some(len) = M::length_of_buf(&contiguous[total_processed..]) {
63            if total_processed + len <= contiguous.len() {
64                let message_bytes = &contiguous[total_processed..total_processed + len];
65                f(M::new(message_bytes));
66                total_processed += len;
67            } else {
68                break;
69            }
70        }
71        if total_processed > 0 {
72            self.accum.rotate_left(total_processed);
73            self.accum.truncate(self.accum.len() - total_processed);
74        }
75    }
76
77    /// Pushes bytes into the buffer, potentially feeding output to the function.
78    ///
79    /// # Lifetimes
80    /// - `'a`: The lifetime of the input byte slice.
81    /// - `'b`: The lifetime of the mutable reference to `self`.
82    /// - `'c`: A lifetime used in the closure's type, representing the lifetime of the `M::Struct` instances passed to it.
83    ///
84    /// The constraint `'a: 'b` ensures that the input bytes live at least as long as the mutable reference to `self`.
85    ///
86    /// The `for<'c>` syntax in the closure type is a higher-ranked trait bound. It indicates that the closure
87    /// must be able to handle `M::Struct` with any lifetime `'c`. This is crucial because:
88    ///
89    /// 1. It allows the `push` method to create `M::Struct` instances with lifetimes that are not known
90    ///    at the time the closure is defined.
91    /// 2. It ensures that the `M::Struct` instances passed to the closure are only valid for the duration
92    ///    of each call to the closure, not for the entire lifetime of the `push` method.
93    /// 3. It prevents the closure from storing or returning these `M::Struct` instances, as their lifetime
94    ///    is limited to the scope of each closure invocation.
95    pub fn push_fallible<'a: 'b, 'b, E>(
96        &'b mut self,
97        bytes: &'a [u8],
98        mut f: impl for<'c> FnMut(Result<M::Struct<'c>, ParseError>) -> Result<(), E>,
99    ) -> Result<(), E> {
100        if self.accum.is_empty() {
101            // Fast path: try to process the input directly
102            let mut offset = 0;
103            while offset < bytes.len() {
104                if let Some(len) = M::length_of_buf(&bytes[offset..]) {
105                    if offset + len <= bytes.len() {
106                        let msg = M::new(&bytes[offset..offset + len]);
107                        f(msg)?;
108                        offset += len;
109                    } else {
110                        break;
111                    }
112                } else {
113                    break;
114                }
115            }
116            if offset == bytes.len() {
117                return Ok(());
118            }
119            self.accum.extend(&bytes[offset..]);
120        } else {
121            self.accum.extend(bytes);
122        }
123
124        // Slow path: process accumulated data
125        let contiguous = self.accum.make_contiguous();
126        let mut total_processed = 0;
127        while let Some(len) = M::length_of_buf(&contiguous[total_processed..]) {
128            if total_processed + len <= contiguous.len() {
129                let message_bytes = &contiguous[total_processed..total_processed + len];
130                f(M::new(message_bytes))?;
131                total_processed += len;
132            } else {
133                break;
134            }
135        }
136        if total_processed > 0 {
137            self.accum.rotate_left(total_processed);
138            self.accum.truncate(self.accum.len() - total_processed);
139        }
140        Ok(())
141    }
142
143    pub fn into_inner(self) -> VecDeque<u8> {
144        self.accum
145    }
146
147    pub fn is_empty(&self) -> bool {
148        self.accum.is_empty()
149    }
150
151    pub fn len(&self) -> usize {
152        self.accum.len()
153    }
154}
155
#[cfg(test)]
mod tests {
    use crate::prelude::*;

    use super::StructBuffer;
    use crate::test_protocol::*;

    /// Encodes three messages back to back and returns the combined bytes together with the
    /// encoded length of each individual message, in order.
    fn test_data() -> (Vec<u8>, Vec<usize>) {
        let encoded = [
            SyncBuilder::default().to_vec(),
            CommandCompleteBuilder { tag: "TAG" }.to_vec(),
            DataRowBuilder {
                values: &[Encoded::Value(b"1")],
            }
            .to_vec(),
        ];
        let lengths = encoded.iter().map(Vec::len).collect();
        (encoded.concat(), lengths)
    }

    /// Feeds `buf` into a fresh [`StructBuffer`] in pieces of the given lengths, then checks
    /// that exactly three messages came out and that re-encoding them reproduces `buf`.
    fn process_chunks(buf: &[u8], chunk_lengths: &[usize]) {
        assert_eq!(
            chunk_lengths.iter().sum::<usize>(),
            buf.len(),
            "Sum of chunk lengths must equal total buffer length"
        );

        let mut decoded: Vec<Vec<u8>> = Vec::new();
        let mut buffer = StructBuffer::<Message>::default();
        let mut sink = |msg: Result<Message, ParseError>| {
            let msg = msg.unwrap();
            eprintln!("Message: {msg:?}");
            decoded.push(msg.to_vec());
        };

        // Walk the input with a shrinking cursor instead of tracking start/end indices.
        let mut rest = buf;
        for &length in chunk_lengths {
            let (chunk, tail) = rest.split_at(length);
            eprintln!("Chunk: {chunk:?}");

            buffer.push(chunk, &mut sink);
            rest = tail;
        }

        assert_eq!(decoded.len(), 3);

        let out: Vec<u8> = decoded.into_iter().flatten().collect();

        assert_eq!(&out, buf);
    }

    /// One push per message: every chunk is exactly one complete message.
    #[test]
    fn test_message_buffer_chunked() {
        let (bytes, per_message) = test_data();
        process_chunks(&bytes, &per_message);
    }

    /// Worst-case fragmentation: the stream arrives one byte at a time.
    #[test]
    fn test_message_buffer_byte_by_byte() {
        let (bytes, _) = test_data();
        let single_bytes: Vec<usize> = vec![1; bytes.len()];
        process_chunks(&bytes, &single_bytes);
    }

    /// Two pushes, with the split point swept across every offset of the stream.
    #[test]
    fn test_message_buffer_incremental_chunks() {
        let (bytes, _) = test_data();
        let total = bytes.len();
        for split in 0..total {
            process_chunks(&bytes, &[split, total - split]);
        }
    }
}