1use std::convert::TryFrom;
2
3use crate::planner::ResolvedType;
4
5use super::error::{Result, StorageError};
6use super::value::SqlValue;
7
/// Largest Text/Blob payload, in bytes, that the decoder accepts (16 MiB).
///
/// Only the decoding path checks this limit; a declared length above it is
/// reported as corrupted data before any payload bytes are read.
const MAX_INLINE_BYTES: usize = 16 << 20;

/// Largest Vector element count that the decoder accepts (4 Mi elements).
///
/// Each element is a 4-byte `f32`, so this bounds the payload a single
/// length field can ask the decoder to read.
const MAX_VECTOR_LEN: usize = 4 << 20;

/// Stateless namespace type for the row encoding and decoding routines.
pub struct RowCodec;
22
23impl RowCodec {
24 pub fn encode(row: &[SqlValue]) -> Vec<u8> {
26 let column_count =
27 u16::try_from(row.len()).expect("row column count exceeds u16::MAX (design limit)");
28 let null_bytes = (column_count as usize).div_ceil(8);
29
30 let mut buf = Vec::with_capacity(2 + null_bytes + row.len() * 8);
32 buf.extend_from_slice(&column_count.to_le_bytes());
33
34 let mut null_bitmap = vec![0u8; null_bytes];
35 for (idx, val) in row.iter().enumerate() {
36 if val.is_null() {
37 null_bitmap[idx / 8] |= 1 << (idx % 8);
38 }
39 }
40 buf.extend_from_slice(&null_bitmap);
41
42 for value in row {
43 if value.is_null() {
44 continue;
45 }
46 buf.push(value.type_tag());
47 encode_value(value, &mut buf);
48 }
49
50 buf
51 }
52
53 pub fn decode(bytes: &[u8]) -> Result<Vec<SqlValue>> {
55 let mut cursor = 0;
56 if bytes.len() < 2 {
57 return Err(StorageError::CorruptedData {
58 reason: "missing column count".into(),
59 });
60 }
61
62 let column_count =
63 u16::from_le_bytes(bytes[cursor..cursor + 2].try_into().unwrap()) as usize;
64 cursor += 2;
65
66 let null_bytes = column_count.div_ceil(8);
67 if bytes.len() < cursor + null_bytes {
68 return Err(StorageError::CorruptedData {
69 reason: "missing null bitmap".into(),
70 });
71 }
72 let null_bitmap = &bytes[cursor..cursor + null_bytes];
73 cursor += null_bytes;
74
75 let mut values = Vec::with_capacity(column_count);
76 for idx in 0..column_count {
77 let is_null = (null_bitmap[idx / 8] & (1 << (idx % 8))) != 0;
78 if is_null {
79 values.push(SqlValue::Null);
80 continue;
81 }
82
83 if cursor >= bytes.len() {
84 return Err(StorageError::CorruptedData {
85 reason: "missing type tag".into(),
86 });
87 }
88 let tag = bytes[cursor];
89 cursor += 1;
90
91 let value = decode_value(tag, bytes, &mut cursor)?;
92 values.push(value);
93 }
94
95 if cursor != bytes.len() {
96 return Err(StorageError::CorruptedData {
97 reason: "trailing bytes after decoding row".into(),
98 });
99 }
100
101 Ok(values)
102 }
103
104 pub fn decode_with_schema(bytes: &[u8], schema: &[ResolvedType]) -> Result<Vec<SqlValue>> {
106 let values = Self::decode(bytes)?;
107
108 if values.len() != schema.len() {
109 return Err(StorageError::CorruptedData {
110 reason: format!(
111 "column count mismatch: encoded={}, expected={}",
112 values.len(),
113 schema.len()
114 ),
115 });
116 }
117
118 values
119 .into_iter()
120 .zip(schema.iter())
121 .map(|(value, ty)| ensure_type(value, ty))
122 .collect()
123 }
124}
125
126fn encode_value(value: &SqlValue, buf: &mut Vec<u8>) {
127 match value {
128 SqlValue::Null => {}
129 SqlValue::Integer(v) => buf.extend_from_slice(&v.to_le_bytes()),
130 SqlValue::BigInt(v) => buf.extend_from_slice(&v.to_le_bytes()),
131 SqlValue::Float(v) => buf.extend_from_slice(&v.to_bits().to_le_bytes()),
132 SqlValue::Double(v) => buf.extend_from_slice(&v.to_bits().to_le_bytes()),
133 SqlValue::Text(s) => {
134 let len = u32::try_from(s.len())
135 .expect("text length exceeds u32::MAX (design limit for row encoding)");
136 buf.extend_from_slice(&len.to_le_bytes());
137 buf.extend_from_slice(s.as_bytes());
138 }
139 SqlValue::Blob(bytes) => {
140 let len = u32::try_from(bytes.len())
141 .expect("blob length exceeds u32::MAX (design limit for row encoding)");
142 buf.extend_from_slice(&len.to_le_bytes());
143 buf.extend_from_slice(bytes);
144 }
145 SqlValue::Boolean(b) => buf.push(u8::from(*b)),
146 SqlValue::Timestamp(v) => buf.extend_from_slice(&v.to_le_bytes()),
147 SqlValue::Vector(values) => {
148 let len = u32::try_from(values.len())
149 .expect("vector length exceeds u32::MAX (design limit for row encoding)");
150 buf.extend_from_slice(&len.to_le_bytes());
151 for f in values {
152 buf.extend_from_slice(&f.to_bits().to_le_bytes());
153 }
154 }
155 }
156}
157
158fn decode_value(tag: u8, bytes: &[u8], cursor: &mut usize) -> Result<SqlValue> {
159 let mut take = |len: usize, reason: &'static str| -> Result<&[u8]> {
160 let end = cursor
161 .checked_add(len)
162 .ok_or_else(|| StorageError::CorruptedData {
163 reason: reason.to_string(),
164 })?;
165 if end > bytes.len() {
166 return Err(StorageError::CorruptedData {
167 reason: reason.to_string(),
168 });
169 }
170 let slice = &bytes[*cursor..end];
171 *cursor = end;
172 Ok(slice)
173 };
174
175 match tag {
176 0x00 => Ok(SqlValue::Null),
177 0x01 => {
178 let raw = take(4, "truncated Integer value")?;
179 Ok(SqlValue::Integer(i32::from_le_bytes(
180 raw.try_into().unwrap(),
181 )))
182 }
183 0x02 => {
184 let raw = take(8, "truncated BigInt value")?;
185 Ok(SqlValue::BigInt(i64::from_le_bytes(
186 raw.try_into().unwrap(),
187 )))
188 }
189 0x03 => {
190 let raw = take(4, "truncated Float value")?;
191 Ok(SqlValue::Float(f32::from_bits(u32::from_le_bytes(
192 raw.try_into().unwrap(),
193 ))))
194 }
195 0x04 => {
196 let raw = take(8, "truncated Double value")?;
197 Ok(SqlValue::Double(f64::from_bits(u64::from_le_bytes(
198 raw.try_into().unwrap(),
199 ))))
200 }
201 0x05 => {
202 let len_bytes = take(4, "truncated Text length")?;
203 let len = u32::from_le_bytes(len_bytes.try_into().unwrap()) as usize;
204 if len > MAX_INLINE_BYTES {
205 return Err(StorageError::CorruptedData {
206 reason: format!("text length exceeds limit: {len}"),
207 });
208 }
209 let raw = take(len, "truncated Text payload")?;
210 let s = String::from_utf8(raw.to_vec()).map_err(|_| StorageError::CorruptedData {
211 reason: "invalid UTF-8 in Text".into(),
212 })?;
213 Ok(SqlValue::Text(s))
214 }
215 0x06 => {
216 let len_bytes = take(4, "truncated Blob length")?;
217 let len = u32::from_le_bytes(len_bytes.try_into().unwrap()) as usize;
218 if len > MAX_INLINE_BYTES {
219 return Err(StorageError::CorruptedData {
220 reason: format!("blob length exceeds limit: {len}"),
221 });
222 }
223 let raw = take(len, "truncated Blob payload")?;
224 Ok(SqlValue::Blob(raw.to_vec()))
225 }
226 0x07 => {
227 let raw = take(1, "truncated Boolean")?[0];
228 match raw {
229 0 => Ok(SqlValue::Boolean(false)),
230 1 => Ok(SqlValue::Boolean(true)),
231 other => Err(StorageError::CorruptedData {
232 reason: format!("invalid boolean value: {}", other),
233 }),
234 }
235 }
236 0x08 => {
237 let raw = take(8, "truncated Timestamp value")?;
238 Ok(SqlValue::Timestamp(i64::from_le_bytes(
239 raw.try_into().unwrap(),
240 )))
241 }
242 0x09 => {
243 let len_bytes = take(4, "truncated Vector length")?;
244 let len = u32::from_le_bytes(len_bytes.try_into().unwrap()) as usize;
245 if len > MAX_VECTOR_LEN {
246 return Err(StorageError::CorruptedData {
247 reason: format!("vector length exceeds limit: {len}"),
248 });
249 }
250 let total = len
251 .checked_mul(4)
252 .ok_or_else(|| StorageError::CorruptedData {
253 reason: "vector length overflow".into(),
254 })?;
255 let raw = take(total, "truncated Vector payload")?;
256
257 let mut values = Vec::with_capacity(len);
258 for chunk in raw.chunks_exact(4) {
259 values.push(f32::from_bits(u32::from_le_bytes(
260 chunk.try_into().unwrap(),
261 )));
262 }
263 Ok(SqlValue::Vector(values))
264 }
265 other => Err(StorageError::CorruptedData {
266 reason: format!("unknown type tag: 0x{other:02x}"),
267 }),
268 }
269}
270
271fn ensure_type(value: SqlValue, expected: &ResolvedType) -> Result<SqlValue> {
272 use ResolvedType::*;
273 match (expected, value) {
274 (_, SqlValue::Null) => Ok(SqlValue::Null),
275 (Integer, SqlValue::Integer(v)) => Ok(SqlValue::Integer(v)),
276 (BigInt, SqlValue::BigInt(v)) => Ok(SqlValue::BigInt(v)),
277 (Float, SqlValue::Float(v)) => Ok(SqlValue::Float(v)),
278 (Double, SqlValue::Double(v)) => Ok(SqlValue::Double(v)),
279 (Text, SqlValue::Text(s)) => Ok(SqlValue::Text(s)),
280 (Blob, SqlValue::Blob(b)) => Ok(SqlValue::Blob(b)),
281 (Boolean, SqlValue::Boolean(v)) => Ok(SqlValue::Boolean(v)),
282 (Timestamp, SqlValue::Timestamp(v)) => Ok(SqlValue::Timestamp(v)),
283 (Vector { dimension, .. }, SqlValue::Vector(values)) => {
284 if values.len() as u32 == *dimension {
285 Ok(SqlValue::Vector(values))
286 } else {
287 Err(StorageError::TypeMismatch {
288 expected: format!("Vector(dim={})", dimension),
289 actual: format!("Vector(dim={})", values.len()),
290 })
291 }
292 }
293 (expected_ty, actual) => Err(StorageError::TypeMismatch {
294 expected: expected_ty.type_name().to_string(),
295 actual: actual.type_name().to_string(),
296 }),
297 }
298}
299
#[cfg(test)]
mod tests {
    use super::*;
    use proptest::prelude::*;

    /// Compares two values, using the raw bit pattern for floating-point
    /// payloads instead of `==` (presumably so NaN values still compare equal
    /// after a round trip).
    fn values_equal(a: &SqlValue, b: &SqlValue) -> bool {
        match (a, b) {
            (SqlValue::Float(x), SqlValue::Float(y)) => x.to_bits() == y.to_bits(),
            (SqlValue::Double(x), SqlValue::Double(y)) => x.to_bits() == y.to_bits(),
            // `Iterator::eq` also fails when the lengths differ.
            (SqlValue::Vector(xs), SqlValue::Vector(ys)) => xs
                .iter()
                .map(|x| x.to_bits())
                .eq(ys.iter().map(|y| y.to_bits())),
            _ => a == b,
        }
    }

    /// Element-wise `values_equal` over two rows of the same length.
    fn row_equal(a: &[SqlValue], b: &[SqlValue]) -> bool {
        if a.len() != b.len() {
            return false;
        }
        a.iter().zip(b).all(|(lhs, rhs)| values_equal(lhs, rhs))
    }

    /// Strategy producing one arbitrary value of any supported variant.
    fn sql_value_strategy() -> impl Strategy<Value = SqlValue> {
        prop_oneof![
            Just(SqlValue::Null),
            any::<i32>().prop_map(SqlValue::Integer),
            any::<i64>().prop_map(SqlValue::BigInt),
            any::<f32>().prop_map(SqlValue::Float),
            any::<f64>().prop_map(SqlValue::Double),
            ".*".prop_map(SqlValue::Text),
            proptest::collection::vec(any::<u8>(), 0..32).prop_map(SqlValue::Blob),
            any::<bool>().prop_map(SqlValue::Boolean),
            any::<i64>().prop_map(SqlValue::Timestamp),
            proptest::collection::vec(any::<f32>(), 0..8).prop_map(SqlValue::Vector),
        ]
    }

    /// Hand-builds a one-column, non-null row consisting of the header, the
    /// given type tag and a little-endian `u32` length field with no payload.
    fn single_column_with_length(tag: u8, declared_len: u32) -> Vec<u8> {
        // column count = 1 (u16 LE), null bitmap = 0, then the type tag
        let mut bytes = vec![1, 0, 0, tag];
        bytes.extend_from_slice(&declared_len.to_le_bytes());
        bytes
    }

    /// Asserts that decoding `bytes` fails with `CorruptedData`.
    fn assert_corrupted(bytes: &[u8]) {
        let err = RowCodec::decode(bytes).unwrap_err();
        assert!(matches!(err, StorageError::CorruptedData { .. }));
    }

    #[test]
    fn roundtrip_preserves_all_types() {
        let row = vec![
            SqlValue::Null,
            SqlValue::Integer(42),
            SqlValue::BigInt(-42),
            SqlValue::Float(1.5),
            SqlValue::Double(-2.5),
            SqlValue::Text("hello".into()),
            SqlValue::Blob(vec![0x01, 0x02]),
            SqlValue::Boolean(true),
            SqlValue::Timestamp(1_700_000_000),
            SqlValue::Vector(vec![0.1, 0.2, 0.3]),
        ];

        let decoded = RowCodec::decode(&RowCodec::encode(&row)).unwrap();

        assert!(row_equal(&row, &decoded));
    }

    #[test]
    fn null_bitmap_is_respected() {
        let row = vec![SqlValue::Integer(1), SqlValue::Null, SqlValue::Integer(2)];
        let decoded = RowCodec::decode(&RowCodec::encode(&row)).unwrap();
        assert!(matches!(decoded[1], SqlValue::Null));
    }

    #[test]
    fn corruption_is_detected_for_truncated_payload() {
        let encoded = RowCodec::encode(&[SqlValue::Text("abc".into())]);
        // Drop the final payload byte so the declared length overruns the buffer.
        assert_corrupted(&encoded[..encoded.len() - 1]);
    }

    #[test]
    fn corruption_is_detected_for_unknown_tag() {
        // One non-null column whose type tag is not a known value.
        assert_corrupted(&[1, 0, 0, 0xFF]);
    }

    #[test]
    fn oversized_lengths_are_rejected() {
        let too_large = (super::MAX_INLINE_BYTES as u32) + 1;
        assert_corrupted(&single_column_with_length(0x05, too_large));
    }

    #[test]
    fn oversized_vector_is_rejected() {
        let too_large = (super::MAX_VECTOR_LEN as u32) + 1;
        assert_corrupted(&single_column_with_length(0x09, too_large));
    }

    #[test]
    fn decode_with_schema_validates_types() {
        // Two encoded elements against a schema that demands three.
        let encoded = RowCodec::encode(&[SqlValue::Vector(vec![1.0, 2.0])]);
        let schema = vec![ResolvedType::Vector {
            dimension: 3,
            metric: crate::ast::ddl::VectorMetric::Cosine,
        }];
        let err = RowCodec::decode_with_schema(&encoded, &schema).unwrap_err();
        assert!(matches!(err, StorageError::TypeMismatch { .. }));
    }

    proptest! {
        #[test]
        fn proptest_roundtrip(row in proptest::collection::vec(sql_value_strategy(), 0..16)) {
            let decoded = RowCodec::decode(&RowCodec::encode(&row)).unwrap();
            prop_assert!(row_equal(&row, &decoded));
        }

        #[test]
        fn decode_with_schema_matches_lengths(row in proptest::collection::vec(sql_value_strategy(), 1..5)) {
            let schema: Vec<ResolvedType> = row.iter().map(|v| v.resolved_type()).collect();
            let decoded = RowCodec::decode_with_schema(&RowCodec::encode(&row), &schema).unwrap();
            prop_assert!(row_equal(&row, &decoded));
        }
    }
}