feophantlib/engine/io/row_formats/
null_mask.rs

//! Implementation of the null bitmask that records which columns of a row are null.
//! I'm not using an off-the-shelf bit-vector crate because the bitvector library collides with nom.
3use crate::engine::{io::EncodedSize, objects::SqlTuple};
4use bytes::{Buf, BufMut, Bytes, BytesMut};
5use thiserror::Error;
6
/// Namespace type for encoding and decoding the per-row null bitmask.
///
/// It carries no data; all functionality is exposed through associated functions.
#[derive(Debug, Clone, Copy, Default)]
pub struct NullMask {}
8
9impl NullMask {
10    /// Writes out a bit vector that sets a bit for every null value.
11    ///
12    /// # Examples
13    /// ```
14    /// # use feophantlib::engine::{io::row_formats::NullMask, objects::{SqlTuple, types::BaseSqlTypes}};
15    /// # use hex_literal::hex;
16    /// # use bytes::Bytes;
17    ///
18    /// let test = SqlTuple(vec![Some(BaseSqlTypes::Bool(true)),
19    ///     Some(BaseSqlTypes::Bool(true)),
20    ///     Some(BaseSqlTypes::Bool(true)),
21    ///     ]);
22    ///
23    /// let mask = NullMask::serialize(&test);
24    /// assert_eq!(hex!("00").to_vec(), mask);
25    /// ```
26    pub fn serialize(input: &SqlTuple) -> Bytes {
27        if input.0.is_empty() {
28            return Bytes::new();
29        }
30
31        let mut buffer = BytesMut::new();
32
33        let mut value: u8 = 0;
34        let mut mask: u8 = 0x80;
35        let mut i = 0;
36        loop {
37            if input.0[i].is_none() {
38                value |= mask;
39            }
40
41            if (i + 1) == input.0.len() {
42                if (i + 1) % 8 != 0 {
43                    buffer.put_u8(value);
44                }
45                break;
46            }
47
48            if (i + 1) % 8 == 0 && i > 0 {
49                buffer.put_u8(value);
50                value = 0;
51                mask = 0x80;
52            } else {
53                mask >>= 1;
54            }
55
56            i += 1;
57        }
58
59        buffer.freeze()
60    }
61
62    pub fn parse(buffer: &mut impl Buf, column_count: usize) -> Result<Vec<bool>, NullMaskError> {
63        let mut nulls = Vec::with_capacity(((column_count + 7) / 8) * 8);
64
65        if buffer.remaining() <= column_count / 8 {
66            return Err(NullMaskError::BufferTooShort(
67                buffer.remaining(),
68                column_count / 8,
69            ));
70        }
71
72        let mut remaining_columns = column_count;
73        while remaining_columns > 0 {
74            let mut temp = buffer.get_u8();
75            for _ in 0..8 {
76                if temp & 0x80 > 0 {
77                    nulls.push(true);
78                } else {
79                    nulls.push(false);
80                }
81                temp <<= 1;
82            }
83            remaining_columns = remaining_columns.saturating_sub(8);
84        }
85
86        //This is needed since we encode more values than columns
87        nulls.resize(column_count, false);
88
89        Ok(nulls)
90    }
91}
92
93impl EncodedSize<&SqlTuple> for NullMask {
94    fn encoded_size(input: &SqlTuple) -> usize {
95        //Discussion here: https://github.com/rust-lang/rfcs/issues/2844
96        (input.len() + 8 - 1) / 8
97    }
98}
99
100#[derive(Debug, Error, PartialEq)]
101pub enum NullMaskError {
102    #[error("Buffer too short to parse found {0} bytes, need {1}")]
103    BufferTooShort(usize, usize),
104}
105
106#[cfg(test)]
107mod tests {
108    use crate::engine::objects::types::BaseSqlTypes;
109
110    use super::*;
111    use hex_literal::hex;
112
113    fn get_tuple() -> SqlTuple {
114        SqlTuple(vec![
115            None,
116            Some(BaseSqlTypes::Bool(true)),
117            None,
118            Some(BaseSqlTypes::Bool(true)),
119            None,
120            Some(BaseSqlTypes::Bool(true)),
121            None,
122            Some(BaseSqlTypes::Bool(true)),
123            None,
124            Some(BaseSqlTypes::Bool(true)),
125        ])
126    }
127
128    #[test]
129    fn sizes_match() -> Result<(), Box<dyn std::error::Error>> {
130        let test = SqlTuple(vec![
131            Some(BaseSqlTypes::Text("test".to_string())),
132            Some(BaseSqlTypes::Integer(0)),
133        ]);
134
135        let calc_len = NullMask::encoded_size(&test);
136
137        let buffer = NullMask::serialize(&test);
138
139        assert_eq!(calc_len, buffer.len());
140        Ok(())
141    }
142
143    #[test]
144    fn test_null_mask_serialize() {
145        let test = get_tuple();
146
147        let result = NullMask::serialize(&test);
148
149        assert_eq!(hex!("aa 80").to_vec(), result.to_vec());
150    }
151
152    #[test]
153    fn test_null_mask_single() {
154        let test = SqlTuple(vec![None]);
155
156        let result = NullMask::serialize(&test);
157
158        assert_eq!(hex!("80").to_vec(), result.to_vec());
159    }
160
161    #[test]
162    fn test_null_mask_parse() -> Result<(), Box<dyn std::error::Error>> {
163        let test = vec![
164            true, false, true, false, true, false, true, false, true, false,
165        ];
166
167        let res = NullMask::parse(&mut Bytes::from_static(&hex!("aa 80")), 10)?;
168
169        assert_eq!(res, test);
170        Ok(())
171    }
172
173    #[test]
174    fn test_null_mask_parse_short() -> Result<(), Box<dyn std::error::Error>> {
175        let res = NullMask::parse(&mut Bytes::from_static(&hex!("80")), 9);
176        assert_eq!(res, Err(NullMaskError::BufferTooShort(1, 1)));
177        Ok(())
178    }
179
180    #[test]
181    fn test_null_mask_roundtrip() -> Result<(), Box<dyn std::error::Error>> {
182        let test = get_tuple();
183
184        let end = vec![
185            true, false, true, false, true, false, true, false, true, false, false, false,
186        ];
187
188        let mut result = NullMask::serialize(&test);
189        assert_eq!(Bytes::from_static(&hex!("aa 80")), result);
190        let parse = NullMask::parse(&mut result, 12)?;
191
192        assert_eq!(end, parse);
193        Ok(())
194    }
195
196    #[test]
197    fn test_encoded_size() {
198        assert_eq!(2, NullMask::encoded_size(&get_tuple()));
199
200        let none_null = SqlTuple(vec![
201            Some(BaseSqlTypes::Bool(true)),
202            Some(BaseSqlTypes::Bool(true)),
203        ]);
204        assert_eq!(1, NullMask::encoded_size(&none_null));
205        assert_eq!(1, NullMask::serialize(&none_null).len());
206    }
207}