1use crate::wire::{BoundedReader, ProtocolLimits, TtcReader, TtcWriter};
30use crate::{ProtocolError, Result};
31
/// First byte of every VECTOR image; the decoder rejects images that start with any other byte.
pub const TNS_VECTOR_MAGIC_BYTE: u8 = 0xDB;

/// Image version written for dense, non-binary vectors.
pub const TNS_VECTOR_VERSION_BASE: u8 = 0;
/// Image version written for dense binary vectors.
pub const TNS_VECTOR_VERSION_WITH_BINARY: u8 = 1;
/// Image version written for sparse vectors; also the highest version the decoder accepts.
pub const TNS_VECTOR_VERSION_WITH_SPARSE: u8 = 2;

/// Header flag: an 8-byte slot follows the element count. This module writes zeros there
/// and skips it on read (presumably a precomputed norm — confirm against server docs).
pub const TNS_VECTOR_FLAG_NORM: u16 = 0x0002;
/// Header flag the encoder always sets. When present, the decoder also skips the 8-byte
/// slot after the element count.
pub const TNS_VECTOR_FLAG_NORM_RESERVED: u16 = 0x0010;
/// Header flag: the image holds a sparse vector (entry count, indices, then values).
pub const TNS_VECTOR_FLAG_SPARSE: u16 = 0x0020;

/// Element format code for `f32` values (4 bytes each on the wire).
pub const VECTOR_FORMAT_FLOAT32: u8 = 2;
/// Element format code for `f64` values (8 bytes each on the wire).
pub const VECTOR_FORMAT_FLOAT64: u8 = 3;
/// Element format code for `i8` values (1 byte each on the wire).
pub const VECTOR_FORMAT_INT8: u8 = 4;
/// Element format code for packed binary vectors (8 dimensions per byte; the header
/// count is in bits).
pub const VECTOR_FORMAT_BINARY: u8 = 5;
50
/// Element payload of a VECTOR value, tagged by element type.
#[derive(Clone, Debug, PartialEq)]
pub enum VectorValues {
    /// `f32` elements (wire format `VECTOR_FORMAT_FLOAT32`).
    Float32(Vec<f32>),
    /// `f64` elements (wire format `VECTOR_FORMAT_FLOAT64`).
    Float64(Vec<f64>),
    /// `i8` elements (wire format `VECTOR_FORMAT_INT8`).
    Int8(Vec<i8>),
    /// Packed binary dimensions, 8 per byte (wire format `VECTOR_FORMAT_BINARY`).
    Binary(Vec<u8>),
}
66
67impl VectorValues {
68 pub fn format(&self) -> u8 {
70 match self {
71 VectorValues::Float32(_) => VECTOR_FORMAT_FLOAT32,
72 VectorValues::Float64(_) => VECTOR_FORMAT_FLOAT64,
73 VectorValues::Int8(_) => VECTOR_FORMAT_INT8,
74 VectorValues::Binary(_) => VECTOR_FORMAT_BINARY,
75 }
76 }
77
78 pub fn len(&self) -> usize {
81 match self {
82 VectorValues::Float32(v) => v.len(),
83 VectorValues::Float64(v) => v.len(),
84 VectorValues::Int8(v) => v.len(),
85 VectorValues::Binary(v) => v.len(),
86 }
87 }
88
89 pub fn is_empty(&self) -> bool {
90 self.len() == 0
91 }
92}
93
/// A VECTOR value, either dense or sparse.
#[derive(Clone, Debug, PartialEq)]
pub enum Vector {
    /// Every dimension is stored explicitly.
    Dense(VectorValues),
    /// Only the listed dimensions are stored; `indices[i]` pairs with the `i`-th value.
    Sparse {
        /// Total dimension count, carried in the image header's element-count field.
        num_dimensions: u32,
        /// Positions of the stored entries. The encoder requires exactly one per value and
        /// at most `u16::MAX` of them.
        indices: Vec<u32>,
        /// Values of the stored entries.
        values: VectorValues,
    },
}
105
106pub fn decode_vector(data: &[u8]) -> Result<Vector> {
108 decode_vector_with_limits(data, ProtocolLimits::DEFAULT)
109}
110
111pub fn decode_vector_with_limits(data: &[u8], limits: ProtocolLimits) -> Result<Vector> {
113 limits.check_response_bytes(data.len())?;
114 let mut reader = TtcReader::with_limits(data, limits)?;
115
116 let magic = reader.read_u8()?;
117 if magic != TNS_VECTOR_MAGIC_BYTE {
118 return Err(ProtocolError::TtcDecode("vector: bad magic byte"));
119 }
120 let version = reader.read_u8()?;
121 if version > TNS_VECTOR_VERSION_WITH_SPARSE {
122 return Err(ProtocolError::TtcDecode("vector: unsupported version"));
123 }
124 let flags = read_u16be(&mut reader)?;
125 let format = reader.read_u8()?;
126 let mut num_elements = read_u32be(&mut reader)?;
127 reader
128 .limits()
129 .check_vector_dimensions(num_elements as usize)?;
130 if flags & TNS_VECTOR_FLAG_NORM_RESERVED != 0 || flags & TNS_VECTOR_FLAG_NORM != 0 {
131 reader.skip(8)?;
132 }
133
134 if flags & TNS_VECTOR_FLAG_SPARSE != 0 {
135 let num_dimensions = num_elements;
136 let num_sparse = read_u16be(&mut reader)?;
137 reader
138 .limits()
139 .check_vector_dimensions(usize::from(num_sparse))?;
140 let mut indices: Vec<u32> = reader.with_capacity_limited(
144 usize::from(num_sparse),
145 4,
146 ProtocolLimits::check_vector_dimensions,
147 )?;
148 for _ in 0..num_sparse {
149 indices.push(read_u32be(&mut reader)?);
150 }
151 let values = decode_values(&mut reader, u32::from(num_sparse), format)?;
152 return Ok(Vector::Sparse {
153 num_dimensions,
154 indices,
155 values,
156 });
157 }
158
159 if format == VECTOR_FORMAT_BINARY {
161 num_elements /= 8;
162 }
163 let values = decode_values(&mut reader, num_elements, format)?;
164 Ok(Vector::Dense(values))
165}
166
167fn decode_values(reader: &mut TtcReader<'_>, count: u32, format: u8) -> Result<VectorValues> {
168 let count = count as usize;
169 reader.limits().check_vector_dimensions(count)?;
170 match format {
179 VECTOR_FORMAT_FLOAT32 => {
180 let mut out: Vec<f32> =
181 reader.with_capacity_limited(count, 4, ProtocolLimits::check_vector_dimensions)?;
182 for _ in 0..count {
183 let raw = reader.read_raw(4)?;
184 out.push(decode_binary_float([raw[0], raw[1], raw[2], raw[3]]));
185 }
186 Ok(VectorValues::Float32(out))
187 }
188 VECTOR_FORMAT_FLOAT64 => {
189 let mut out: Vec<f64> =
190 reader.with_capacity_limited(count, 8, ProtocolLimits::check_vector_dimensions)?;
191 for _ in 0..count {
192 let raw = reader.read_raw(8)?;
193 out.push(decode_binary_double([
194 raw[0], raw[1], raw[2], raw[3], raw[4], raw[5], raw[6], raw[7],
195 ]));
196 }
197 Ok(VectorValues::Float64(out))
198 }
199 VECTOR_FORMAT_INT8 => {
200 let mut out: Vec<i8> =
201 reader.with_capacity_limited(count, 1, ProtocolLimits::check_vector_dimensions)?;
202 for _ in 0..count {
203 out.push(reader.read_u8()? as i8);
204 }
205 Ok(VectorValues::Int8(out))
206 }
207 VECTOR_FORMAT_BINARY => Ok(VectorValues::Binary(reader.read_raw(count)?.to_vec())),
208 _ => Err(ProtocolError::TtcDecode(
209 "vector: unsupported element format",
210 )),
211 }
212}
213
214pub fn encode_vector(vector: &Vector) -> Vec<u8> {
225 match encode_vector_checked(vector) {
226 Ok(image) => image,
227 Err(err) => panic!("invalid VECTOR value for encoding: {err}"),
228 }
229}
230
231pub(crate) fn encode_vector_checked(vector: &Vector) -> Result<Vec<u8>> {
232 let mut buf = Vec::new();
233
234 let mut flags = TNS_VECTOR_FLAG_NORM_RESERVED;
235 let (format, version, num_elements) = match vector {
236 Vector::Sparse {
237 num_dimensions,
238 values,
239 ..
240 } => {
241 flags |= TNS_VECTOR_FLAG_SPARSE | TNS_VECTOR_FLAG_NORM;
242 (
243 values.format(),
244 TNS_VECTOR_VERSION_WITH_SPARSE,
245 *num_dimensions,
246 )
247 }
248 Vector::Dense(values) => {
249 let format = values.format();
250 if format == VECTOR_FORMAT_BINARY {
251 (
252 format,
253 TNS_VECTOR_VERSION_WITH_BINARY,
254 (values.len() as u32) * 8,
255 )
256 } else {
257 flags |= TNS_VECTOR_FLAG_NORM;
258 (format, TNS_VECTOR_VERSION_BASE, values.len() as u32)
259 }
260 }
261 };
262
263 buf.push(TNS_VECTOR_MAGIC_BYTE);
264 buf.push(version);
265 buf.extend_from_slice(&flags.to_be_bytes());
266 buf.push(format);
267 buf.extend_from_slice(&num_elements.to_be_bytes());
268 buf.extend_from_slice(&[0u8; 8]); match vector {
271 Vector::Dense(values) => encode_values(&mut buf, values),
272 Vector::Sparse {
273 indices, values, ..
274 } => {
275 if indices.len() != values.len() {
276 return Err(ProtocolError::TtcDecode(
277 "vector: sparse index/value count mismatch",
278 ));
279 }
280 let num_sparse =
281 u16::try_from(indices.len()).map_err(|_| ProtocolError::InvalidPacketLength {
282 length: indices.len(),
283 minimum: 0,
284 })?;
285 buf.extend_from_slice(&num_sparse.to_be_bytes());
286 for index in indices {
287 buf.extend_from_slice(&index.to_be_bytes());
288 }
289 encode_values(&mut buf, values);
290 }
291 }
292
293 Ok(buf)
294}
295
296fn encode_values(buf: &mut Vec<u8>, values: &VectorValues) {
297 match values {
298 VectorValues::Float32(v) => {
299 for value in v {
300 buf.extend_from_slice(&encode_binary_float(*value));
301 }
302 }
303 VectorValues::Float64(v) => {
304 for value in v {
305 buf.extend_from_slice(&encode_binary_double(*value));
306 }
307 }
308 VectorValues::Int8(v) => {
309 for value in v {
310 buf.push(*value as u8);
311 }
312 }
313 VectorValues::Binary(v) => buf.extend_from_slice(v),
314 }
315}
316
/// Reverses the binary-double byte transform used for VECTOR `f64` elements.
///
/// If the encoded word has its top bit set, the value was non-negative: clearing that bit
/// restores the IEEE-754 bits. Otherwise, every bit was complemented; complementing again
/// restores them. The input is big-endian.
fn decode_binary_double(bytes: [u8; 8]) -> f64 {
    const TOP_BIT: u64 = 1 << 63;
    let word = u64::from_be_bytes(bytes);
    let bits = if word & TOP_BIT == 0 {
        !word
    } else {
        word & !TOP_BIT
    };
    f64::from_bits(bits)
}
336
/// Reverses the binary-float byte transform used for VECTOR `f32` elements.
///
/// If the encoded word has its top bit set, the value was non-negative: clearing that bit
/// restores the IEEE-754 bits. Otherwise, every bit was complemented; complementing again
/// restores them. The input is big-endian.
fn decode_binary_float(bytes: [u8; 4]) -> f32 {
    const TOP_BIT: u32 = 1 << 31;
    let word = u32::from_be_bytes(bytes);
    let bits = if word & TOP_BIT == 0 {
        !word
    } else {
        word & !TOP_BIT
    };
    f32::from_bits(bits)
}
349
/// Applies the binary-double byte transform used for VECTOR `f64` elements.
///
/// Values with a clear sign bit get that bit set; values with the sign bit set have all
/// bits complemented. The result is big-endian.
fn encode_binary_double(value: f64) -> [u8; 8] {
    const SIGN: u64 = 1 << 63;
    let bits = value.to_bits();
    let transformed = if bits & SIGN != 0 { !bits } else { bits | SIGN };
    transformed.to_be_bytes()
}
362
/// Applies the binary-float byte transform used for VECTOR `f32` elements.
///
/// Values with a clear sign bit get that bit set; values with the sign bit set have all
/// bits complemented. The result is big-endian.
fn encode_binary_float(value: f32) -> [u8; 4] {
    const SIGN: u32 = 1 << 31;
    let bits = value.to_bits();
    let transformed = if bits & SIGN != 0 { !bits } else { bits | SIGN };
    transformed.to_be_bytes()
}
375
376fn read_u16be(reader: &mut TtcReader<'_>) -> Result<u16> {
379 let raw = reader.read_raw(2)?;
380 Ok(u16::from_be_bytes([raw[0], raw[1]]))
381}
382
383fn read_u32be(reader: &mut TtcReader<'_>) -> Result<u32> {
384 let raw = reader.read_raw(4)?;
385 Ok(u32::from_be_bytes([raw[0], raw[1], raw[2], raw[3]]))
386}
387
388pub fn write_vector_image(writer: &mut TtcWriter, image: &[u8]) -> Result<()> {
394 write_qlocator(writer, image.len() as u64, true);
395 writer.write_bytes_with_length(image)?;
396 Ok(())
397}
398
399pub fn write_oson_aq_payload(writer: &mut TtcWriter, image: &[u8]) -> Result<()> {
403 write_qlocator(writer, image.len() as u64, false);
404 writer.write_bytes_with_length(image)?;
405 Ok(())
406}
407
/// Writes a fixed-layout LOB locator describing a `data_length`-byte value.
///
/// Fields, in write order:
/// - `ub4` 40.
/// - A single `u8` 40, only when `write_length` is true.
/// - `u16` 38: the size of every field that follows it
///   (2 + 1 + 1 + 2 + 2 + 8 + 2 + 2 + 2 + 8 + 8 = 38 bytes). The leading 40 equals this
///   2-byte field plus those 38 bytes.
/// - The locator version, two flag bytes, two `u16` fields (0 then 1), and `data_length`
///   as a `u64`.
/// - Zero-filled `u16` and `u64` fields.
///
/// NOTE(review): the meaning of the fixed 0/1 fields is not established in this module.
/// The order of these writes is the wire layout, so do not reorder them.
fn write_qlocator(writer: &mut TtcWriter, data_length: u64, write_length: bool) {
    const TNS_LOB_QLOCATOR_VERSION: u16 = 4;
    const TNS_LOB_LOC_FLAGS_VALUE_BASED: u8 = 0x20;
    const TNS_LOB_LOC_FLAGS_BLOB: u8 = 0x01;
    const TNS_LOB_LOC_FLAGS_ABSTRACT: u8 = 0x40;
    const TNS_LOB_LOC_FLAGS_INIT: u8 = 0x08;

    // 40 = the 2-byte length field below plus the 38 bytes it describes.
    writer.write_ub4(40);
    if write_length {
        // The same 40 again as a single byte; only written when requested.
        writer.write_u8(40);
    }
    // Number of locator bytes that follow this field.
    writer.write_u16be(38);
    writer.write_u16be(TNS_LOB_QLOCATOR_VERSION);
    writer.write_u8(
        TNS_LOB_LOC_FLAGS_VALUE_BASED | TNS_LOB_LOC_FLAGS_BLOB | TNS_LOB_LOC_FLAGS_ABSTRACT,
    );
    writer.write_u8(TNS_LOB_LOC_FLAGS_INIT);
    writer.write_u16be(0);
    writer.write_u16be(1);
    // Length in bytes of the value that follows the locator.
    writer.write_u64be(data_length);
    // Remaining locator fields are all written as zeros.
    writer.write_u16be(0);
    writer.write_u16be(0);
    writer.write_u16be(0);
    writer.write_u64be(0);
    writer.write_u64be(0);
}
438
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::Value;

    /// Encodes `vector`, decodes the image again, and asserts the result is unchanged.
    fn roundtrip(vector: Vector) {
        let image = encode_vector(&vector);
        let decoded = decode_vector(&image).expect("decode");
        assert_eq!(decoded, vector);
    }

    /// The allocation guards must not reject realistically sized vectors under the
    /// default limits.
    #[test]
    fn legitimate_large_vector_still_decodes_fully() {
        let big_f32: Vec<f32> = (0..4096).map(|i| i as f32 * 0.5 - 1024.0).collect();
        roundtrip(Vector::Dense(VectorValues::Float32(big_f32)));
        let big_f64: Vec<f64> = (0..2048).map(|i| i as f64 * 0.25).collect();
        roundtrip(Vector::Dense(VectorValues::Float64(big_f64)));
        // Large dimension count with a modest number of stored entries.
        roundtrip(Vector::Sparse {
            num_dimensions: 100_000,
            indices: (0..1000).map(|i| i * 7).collect(),
            values: VectorValues::Float32((0..1000).map(|i| i as f32).collect()),
        });
    }

    #[test]
    fn roundtrips_every_dense_format() {
        roundtrip(Vector::Dense(VectorValues::Float32(vec![
            1.5, -2.25, 3.0, 0.0,
        ])));
        roundtrip(Vector::Dense(VectorValues::Float64(vec![
            6501.0, 25.25, 18.125, -3.5,
        ])));
        // Includes both i8 extremes.
        roundtrip(Vector::Dense(VectorValues::Int8(vec![
            -5, 1, -2, 127, -128,
        ])));
        roundtrip(Vector::Dense(VectorValues::Binary(vec![0xA5, 0x3C])));
    }

    #[test]
    fn roundtrips_every_sparse_format() {
        roundtrip(Vector::Sparse {
            num_dimensions: 8,
            indices: vec![1, 4, 6],
            values: VectorValues::Float64(vec![1.5, -2.0, 9.25]),
        });
        roundtrip(Vector::Sparse {
            num_dimensions: 6,
            indices: vec![0, 3],
            values: VectorValues::Float32(vec![2.5, -7.0]),
        });
        roundtrip(Vector::Sparse {
            num_dimensions: 5,
            indices: vec![2],
            values: VectorValues::Int8(vec![42]),
        });
    }

    /// The largest entry count that fits the 16-bit sparse count field must round-trip.
    #[test]
    fn sparse_int8_roundtrips_max_u16_count() {
        let indices = (0..u16::MAX).map(u32::from).collect::<Vec<_>>();
        let values = VectorValues::Int8((0..u16::MAX).map(|i| (i % 127) as i8).collect::<Vec<_>>());
        let vector = Vector::Sparse {
            num_dimensions: u32::from(u16::MAX),
            indices,
            values,
        };

        let image = encode_vector_checked(&vector).expect("encode max u16 sparse vector");
        let decoded = decode_vector(&image).expect("decode max u16 sparse vector");
        assert_eq!(decoded, vector);
    }

    /// One entry past `u16::MAX` must fail with `InvalidPacketLength`, not be truncated.
    #[test]
    fn sparse_int8_rejects_count_that_exceeds_wire_field() {
        let count = usize::from(u16::MAX) + 1;
        let vector = Vector::Sparse {
            num_dimensions: count as u32,
            indices: (0..count as u32).collect(),
            values: VectorValues::Int8(vec![1; count]),
        };

        let err = encode_vector_checked(&vector).expect_err("oversized sparse count must fail");
        assert!(
            matches!(
                err,
                ProtocolError::InvalidPacketLength {
                    length,
                    minimum: 0
                } if length == count
            ),
            "got {err:?}"
        );
    }

    #[test]
    fn sparse_encode_rejects_mismatched_index_value_counts() {
        let vector = Vector::Sparse {
            num_dimensions: 4,
            indices: vec![0, 1, 2],
            values: VectorValues::Int8(vec![7, 8]),
        };

        let err = encode_vector_checked(&vector).expect_err("mismatched sparse vector must fail");
        assert!(matches!(err, ProtocolError::TtcDecode(_)), "got {err:?}");
    }

    /// Pins the exact float byte transform on the wire, not just the round trip.
    #[test]
    fn float_elements_use_oracle_binary_transform() {
        let image = encode_vector(&Vector::Dense(VectorValues::Float64(vec![1.0, -2.0])));
        // Header is magic(1) + version(1) + flags(2) + format(1) + count(4) + slot(8) = 17 bytes.
        let body = &image[17..];
        assert_eq!(&body[0..8], &[0xbf, 0xf0, 0, 0, 0, 0, 0, 0], "f64 +1.0");
        assert_eq!(
            &body[8..16],
            &[0x3f, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff],
            "f64 -2.0"
        );

        let image32 = encode_vector(&Vector::Dense(VectorValues::Float32(vec![1.0, -2.0])));
        let body32 = &image32[17..];
        assert_eq!(&body32[0..4], &[0xbf, 0x80, 0, 0], "f32 +1.0");
        assert_eq!(&body32[4..8], &[0x3f, 0xff, 0xff, 0xff], "f32 -2.0");

        assert_eq!(
            decode_vector(&image).expect("decode f64"),
            Vector::Dense(VectorValues::Float64(vec![1.0, -2.0]))
        );
        assert_eq!(
            decode_vector(&image32).expect("decode f32"),
            Vector::Dense(VectorValues::Float32(vec![1.0, -2.0]))
        );
    }

    #[test]
    fn rejects_bad_magic() {
        let err = decode_vector(&[0x00, 0, 0, 0, 0, 0, 0, 0, 0]).expect_err("bad magic must fail");
        assert!(matches!(err, ProtocolError::TtcDecode(_)));
    }

    #[test]
    fn rejects_unsupported_version() {
        let mut image = encode_vector(&Vector::Dense(VectorValues::Int8(vec![1])));
        // Byte 1 is the version; 99 is above the highest supported version.
        image[1] = 99;
        let err = decode_vector(&image).expect_err("bad version must fail");
        assert!(matches!(err, ProtocolError::TtcDecode(_)));
    }

    /// Fuzz regression: a huge element count must error out before any allocation.
    #[test]
    fn fuzz_regression_oom_oversized_element_count() {
        // Fields: magic 0xDB, version 0, flags 0x0012 (NORM | NORM_RESERVED), format 3 (f64),
        // count bytes 54,0,0,0 = 0x3600_0000, then 8 slot bytes.
        let input = [219, 0, 0, 18, 3, 54, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0];
        let err = decode_vector(&input).expect_err("oversized count must fail closed");
        assert!(
            matches!(
                err,
                ProtocolError::TtcDecode(_) | ProtocolError::ResourceLimit { .. }
            ),
            "got {err:?}"
        );
    }

    #[test]
    fn decode_vector_with_limits_rejects_dense_dimensions() {
        let image = encode_vector(&Vector::Dense(VectorValues::Int8(vec![1, 2, 3, 4, 5])));
        let limits = ProtocolLimits {
            max_vector_dimensions: 4,
            ..ProtocolLimits::DEFAULT
        };
        assert!(matches!(
            decode_vector_with_limits(&image, limits),
            Err(ProtocolError::ResourceLimit {
                limit: "vector_dimensions",
                observed: 5,
                maximum: 4,
            })
        ));
    }

    /// A sparse header claiming 65535 entries with no data behind it must fail, not OOM.
    #[test]
    fn sparse_oversized_index_count_fails_closed_not_oom() {
        let input = [
            TNS_VECTOR_MAGIC_BYTE,
            TNS_VECTOR_VERSION_WITH_SPARSE,
            // Flags 0x0020: sparse only, so no 8-byte slot follows the count.
            0x00,
            0x20,
            VECTOR_FORMAT_FLOAT64,
            // Count (total dimensions) = 0.
            0x00,
            0x00,
            0x00,
            0x00,
            // Sparse entry count = 0xFFFF, with no index bytes after it.
            0xFF,
            0xFF,
        ];
        let err = decode_vector(&input).expect_err("oversized sparse count must fail closed");
        assert!(matches!(err, ProtocolError::TtcDecode(_)), "got {err:?}");
    }

    /// Dense binary headers count bits: 2 packed bytes are advertised as 16 dimensions.
    #[test]
    fn binary_dense_bit_count_header() {
        let image = encode_vector(&Vector::Dense(VectorValues::Binary(vec![0xA5, 0x3C])));
        // Bytes 5..9 hold the big-endian count (after magic, version, flags, format).
        let num_elements = u32::from_be_bytes([image[5], image[6], image[7], image[8]]);
        assert_eq!(num_elements, 16);
        assert_eq!(image[1], TNS_VECTOR_VERSION_WITH_BINARY);
    }

    /// Builds a `Vector` from one golden-capture JSON entry.
    ///
    /// The entry provides:
    /// - `typecode`: `f`, `d`, `b` or `B`, mapping to f32, f64, i8 or binary;
    /// - `kind`: `sparse` selects the sparse form;
    /// - `values`;
    /// - for sparse entries, `num_dimensions` and `indices`.
    fn build_from_golden(entry: &Value) -> Vector {
        let typecode = entry["typecode"].as_str().expect("typecode");
        let f64_at = |x: &Value| x.as_f64().expect("number");
        let i64_at = |x: &Value| x.as_i64().expect("int");
        let u64_at = |x: &Value| x.as_u64().expect("uint");
        let make_values = |arr: &Value| -> VectorValues {
            let v = arr.as_array().expect("array");
            match typecode {
                "f" => VectorValues::Float32(v.iter().map(|x| f64_at(x) as f32).collect()),
                "d" => VectorValues::Float64(v.iter().map(f64_at).collect()),
                "b" => VectorValues::Int8(v.iter().map(|x| i64_at(x) as i8).collect()),
                "B" => VectorValues::Binary(v.iter().map(|x| u64_at(x) as u8).collect()),
                other => panic!("unknown typecode {other}"),
            }
        };
        if entry["kind"] == "sparse" {
            Vector::Sparse {
                num_dimensions: u64_at(&entry["num_dimensions"]) as u32,
                indices: entry["indices"]
                    .as_array()
                    .expect("indices array")
                    .iter()
                    .map(|x| u64_at(x) as u32)
                    .collect(),
                values: make_values(&entry["values"]),
            }
        } else {
            Vector::Dense(make_values(&entry["values"]))
        }
    }

    /// Every golden case must encode to its captured `image_hex`, and the captured image
    /// must decode back to the same vector.
    #[test]
    fn matches_golden_capture() {
        let raw = include_str!("../tests/golden/vectors.json");
        let golden: Value = serde_json::from_str(raw).expect("parse golden json");
        let obj = golden.as_object().expect("golden is an object");
        assert!(!obj.is_empty(), "golden capture must not be empty");
        for (name, entry) in obj {
            let expected_hex = entry["image_hex"].as_str().expect("image_hex");
            let expected = hex::decode(expected_hex).expect("decode golden hex");

            let vector = build_from_golden(entry);
            // Encoder direction: byte-exact match against the capture.
            let image = encode_vector(&vector);
            assert_eq!(
                hex::encode(&image),
                expected_hex,
                "encode mismatch for golden case {name}"
            );

            // Decoder direction: the captured bytes decode to the same value.
            let decoded = decode_vector(&expected).expect("decode golden image");
            assert_eq!(decoded, vector, "decode mismatch for golden case {name}");
        }
    }
}