1use crate::wire::{BoundedReader, ProtocolLimits, TtcReader, TtcWriter};
30use crate::{ProtocolError, Result};
31
/// First byte of every vector image; [`decode_vector`] rejects any other value.
pub const TNS_VECTOR_MAGIC_BYTE: u8 = 0xdb;

/// Image version written for dense, non-binary vectors.
pub const TNS_VECTOR_VERSION_BASE: u8 = 0;
/// Image version written for dense binary vectors.
pub const TNS_VECTOR_VERSION_WITH_BINARY: u8 = 1;
/// Image version written for sparse vectors; the highest version the decoder accepts.
pub const TNS_VECTOR_VERSION_WITH_SPARSE: u8 = 2;

/// Header flag: an 8-byte norm slot follows the element count.
pub const TNS_VECTOR_FLAG_NORM: u16 = 1 << 1;
/// Header flag: 8 bytes are reserved for the norm after the element count.
pub const TNS_VECTOR_FLAG_NORM_RESERVED: u16 = 1 << 4;
/// Header flag: the image carries a sparse vector (indices + values).
pub const TNS_VECTOR_FLAG_SPARSE: u16 = 1 << 5;

/// Element format code for 32-bit floats.
pub const VECTOR_FORMAT_FLOAT32: u8 = 2;
/// Element format code for 64-bit floats.
pub const VECTOR_FORMAT_FLOAT64: u8 = 3;
/// Element format code for signed 8-bit integers.
pub const VECTOR_FORMAT_INT8: u8 = 4;
/// Element format code for bit-packed binary vectors (eight dimensions per byte).
pub const VECTOR_FORMAT_BINARY: u8 = 5;
50
51#[derive(Clone, Debug, PartialEq)]
56pub enum VectorValues {
57 Float32(Vec<f32>),
59 Float64(Vec<f64>),
61 Int8(Vec<i8>),
63 Binary(Vec<u8>),
65}
66
67impl VectorValues {
68 pub fn format(&self) -> u8 {
70 match self {
71 VectorValues::Float32(_) => VECTOR_FORMAT_FLOAT32,
72 VectorValues::Float64(_) => VECTOR_FORMAT_FLOAT64,
73 VectorValues::Int8(_) => VECTOR_FORMAT_INT8,
74 VectorValues::Binary(_) => VECTOR_FORMAT_BINARY,
75 }
76 }
77
78 pub fn len(&self) -> usize {
81 match self {
82 VectorValues::Float32(v) => v.len(),
83 VectorValues::Float64(v) => v.len(),
84 VectorValues::Int8(v) => v.len(),
85 VectorValues::Binary(v) => v.len(),
86 }
87 }
88
89 pub fn is_empty(&self) -> bool {
90 self.len() == 0
91 }
92}
93
94#[derive(Clone, Debug, PartialEq)]
97pub enum Vector {
98 Dense(VectorValues),
99 Sparse {
100 num_dimensions: u32,
101 indices: Vec<u32>,
102 values: VectorValues,
103 },
104}
105
106pub fn decode_vector(data: &[u8]) -> Result<Vector> {
108 decode_vector_with_limits(data, ProtocolLimits::DEFAULT)
109}
110
111pub fn decode_vector_with_limits(data: &[u8], limits: ProtocolLimits) -> Result<Vector> {
113 limits.check_response_bytes(data.len())?;
114 let mut reader = TtcReader::with_limits(data, limits)?;
115
116 let magic = reader.read_u8()?;
117 if magic != TNS_VECTOR_MAGIC_BYTE {
118 return Err(ProtocolError::TtcDecode("vector: bad magic byte"));
119 }
120 let version = reader.read_u8()?;
121 if version > TNS_VECTOR_VERSION_WITH_SPARSE {
122 return Err(ProtocolError::TtcDecode("vector: unsupported version"));
123 }
124 let flags = read_u16be(&mut reader)?;
125 let format = reader.read_u8()?;
126 let mut num_elements = read_u32be(&mut reader)?;
127 reader
128 .limits()
129 .check_vector_dimensions(num_elements as usize)?;
130 if flags & TNS_VECTOR_FLAG_NORM_RESERVED != 0 || flags & TNS_VECTOR_FLAG_NORM != 0 {
131 reader.skip(8)?;
132 }
133
134 if flags & TNS_VECTOR_FLAG_SPARSE != 0 {
135 let num_dimensions = num_elements;
136 let num_sparse = read_u16be(&mut reader)?;
137 reader
138 .limits()
139 .check_vector_dimensions(usize::from(num_sparse))?;
140 let mut indices: Vec<u32> = reader.with_capacity_limited(
144 usize::from(num_sparse),
145 4,
146 ProtocolLimits::check_vector_dimensions,
147 )?;
148 for _ in 0..num_sparse {
149 indices.push(read_u32be(&mut reader)?);
150 }
151 let values = decode_values(&mut reader, u32::from(num_sparse), format)?;
152 return Ok(Vector::Sparse {
153 num_dimensions,
154 indices,
155 values,
156 });
157 }
158
159 if format == VECTOR_FORMAT_BINARY {
161 num_elements /= 8;
162 }
163 let values = decode_values(&mut reader, num_elements, format)?;
164 Ok(Vector::Dense(values))
165}
166
167fn decode_values(reader: &mut TtcReader<'_>, count: u32, format: u8) -> Result<VectorValues> {
168 let count = count as usize;
169 reader.limits().check_vector_dimensions(count)?;
170 match format {
179 VECTOR_FORMAT_FLOAT32 => {
180 let mut out: Vec<f32> =
181 reader.with_capacity_limited(count, 4, ProtocolLimits::check_vector_dimensions)?;
182 for _ in 0..count {
183 let raw = reader.read_raw(4)?;
184 out.push(decode_binary_float([raw[0], raw[1], raw[2], raw[3]]));
185 }
186 Ok(VectorValues::Float32(out))
187 }
188 VECTOR_FORMAT_FLOAT64 => {
189 let mut out: Vec<f64> =
190 reader.with_capacity_limited(count, 8, ProtocolLimits::check_vector_dimensions)?;
191 for _ in 0..count {
192 let raw = reader.read_raw(8)?;
193 out.push(decode_binary_double([
194 raw[0], raw[1], raw[2], raw[3], raw[4], raw[5], raw[6], raw[7],
195 ]));
196 }
197 Ok(VectorValues::Float64(out))
198 }
199 VECTOR_FORMAT_INT8 => {
200 let mut out: Vec<i8> =
201 reader.with_capacity_limited(count, 1, ProtocolLimits::check_vector_dimensions)?;
202 for _ in 0..count {
203 out.push(reader.read_u8()? as i8);
204 }
205 Ok(VectorValues::Int8(out))
206 }
207 VECTOR_FORMAT_BINARY => Ok(VectorValues::Binary(reader.read_raw(count)?.to_vec())),
208 _ => Err(ProtocolError::TtcDecode(
209 "vector: unsupported element format",
210 )),
211 }
212}
213
214pub fn encode_vector(vector: &Vector) -> Vec<u8> {
217 let mut buf = Vec::new();
218
219 let mut flags = TNS_VECTOR_FLAG_NORM_RESERVED;
220 let (format, version, num_elements) = match vector {
221 Vector::Sparse {
222 num_dimensions,
223 values,
224 ..
225 } => {
226 flags |= TNS_VECTOR_FLAG_SPARSE | TNS_VECTOR_FLAG_NORM;
227 (
228 values.format(),
229 TNS_VECTOR_VERSION_WITH_SPARSE,
230 *num_dimensions,
231 )
232 }
233 Vector::Dense(values) => {
234 let format = values.format();
235 if format == VECTOR_FORMAT_BINARY {
236 (
237 format,
238 TNS_VECTOR_VERSION_WITH_BINARY,
239 (values.len() as u32) * 8,
240 )
241 } else {
242 flags |= TNS_VECTOR_FLAG_NORM;
243 (format, TNS_VECTOR_VERSION_BASE, values.len() as u32)
244 }
245 }
246 };
247
248 buf.push(TNS_VECTOR_MAGIC_BYTE);
249 buf.push(version);
250 buf.extend_from_slice(&flags.to_be_bytes());
251 buf.push(format);
252 buf.extend_from_slice(&num_elements.to_be_bytes());
253 buf.extend_from_slice(&[0u8; 8]); match vector {
256 Vector::Dense(values) => encode_values(&mut buf, values),
257 Vector::Sparse {
258 indices, values, ..
259 } => {
260 let num_sparse = indices.len() as u16;
261 buf.extend_from_slice(&num_sparse.to_be_bytes());
262 for index in indices {
263 buf.extend_from_slice(&index.to_be_bytes());
264 }
265 encode_values(&mut buf, values);
266 }
267 }
268
269 buf
270}
271
272fn encode_values(buf: &mut Vec<u8>, values: &VectorValues) {
273 match values {
274 VectorValues::Float32(v) => {
275 for value in v {
276 buf.extend_from_slice(&encode_binary_float(*value));
277 }
278 }
279 VectorValues::Float64(v) => {
280 for value in v {
281 buf.extend_from_slice(&encode_binary_double(*value));
282 }
283 }
284 VectorValues::Int8(v) => {
285 for value in v {
286 buf.push(*value as u8);
287 }
288 }
289 VectorValues::Binary(v) => buf.extend_from_slice(v),
290 }
291}
292
/// Undoes the binary-double transform applied by the encoder to a big-endian 8-byte word.
///
/// Words with the top bit set came from non-negative values, so that bit is cleared.
/// Otherwise the value was negative and every bit was inverted, so the bits are inverted back.
fn decode_binary_double(bytes: [u8; 8]) -> f64 {
    const SIGN: u64 = 1 << 63;
    let stored = u64::from_be_bytes(bytes);
    let bits = if stored & SIGN == 0 {
        !stored
    } else {
        stored & !SIGN
    };
    f64::from_bits(bits)
}
312
/// Undoes the binary-float transform applied by the encoder to a big-endian 4-byte word.
///
/// Words with the top bit set came from non-negative values, so that bit is cleared.
/// Otherwise the value was negative and every bit was inverted, so the bits are inverted back.
fn decode_binary_float(bytes: [u8; 4]) -> f32 {
    const SIGN: u32 = 1 << 31;
    let stored = u32::from_be_bytes(bytes);
    let bits = if stored & SIGN == 0 {
        !stored
    } else {
        stored & !SIGN
    };
    f32::from_bits(bits)
}
325
/// Applies the binary-double transform to `value` and returns the big-endian bytes.
///
/// Values without the sign bit get it set. Values with the sign bit have every bit inverted.
fn encode_binary_double(value: f64) -> [u8; 8] {
    const SIGN: u64 = 1 << 63;
    let bits = value.to_bits();
    let stored = if bits & SIGN != 0 { !bits } else { bits | SIGN };
    stored.to_be_bytes()
}
338
/// Applies the binary-float transform to `value` and returns the big-endian bytes.
///
/// Values without the sign bit get it set. Values with the sign bit have every bit inverted.
fn encode_binary_float(value: f32) -> [u8; 4] {
    const SIGN: u32 = 1 << 31;
    let bits = value.to_bits();
    let stored = if bits & SIGN != 0 { !bits } else { bits | SIGN };
    stored.to_be_bytes()
}
351
352fn read_u16be(reader: &mut TtcReader<'_>) -> Result<u16> {
355 let raw = reader.read_raw(2)?;
356 Ok(u16::from_be_bytes([raw[0], raw[1]]))
357}
358
359fn read_u32be(reader: &mut TtcReader<'_>) -> Result<u32> {
360 let raw = reader.read_raw(4)?;
361 Ok(u32::from_be_bytes([raw[0], raw[1], raw[2], raw[3]]))
362}
363
364pub fn write_vector_image(writer: &mut TtcWriter, image: &[u8]) -> Result<()> {
370 write_qlocator(writer, image.len() as u64, true);
371 writer.write_bytes_with_length(image)?;
372 Ok(())
373}
374
375pub fn write_oson_aq_payload(writer: &mut TtcWriter, image: &[u8]) -> Result<()> {
379 write_qlocator(writer, image.len() as u64, false);
380 writer.write_bytes_with_length(image)?;
381 Ok(())
382}
383
/// Writes a fixed-size LOB QLocator announcing `data_length` bytes of inline data.
///
/// The locator is described as 40 bytes long: a `u16` internal length of 38 followed by
/// exactly 38 bytes of fields. Its flag byte marks it value-based, BLOB and abstract,
/// per the constant names below.
///
/// `write_length` controls whether the locator length (40) is repeated as a single byte
/// after the `ub4` length.
fn write_qlocator(writer: &mut TtcWriter, data_length: u64, write_length: bool) {
    const TNS_LOB_QLOCATOR_VERSION: u16 = 4;
    const TNS_LOB_LOC_FLAGS_VALUE_BASED: u8 = 0x20;
    const TNS_LOB_LOC_FLAGS_BLOB: u8 = 0x01;
    const TNS_LOB_LOC_FLAGS_ABSTRACT: u8 = 0x40;
    const TNS_LOB_LOC_FLAGS_INIT: u8 = 0x08;

    // Locator length.
    writer.write_ub4(40);
    if write_length {
        // Locator length repeated as one byte.
        writer.write_u8(40);
    }
    // Internal length: the 38 bytes written after this field.
    writer.write_u16be(38);
    writer.write_u16be(TNS_LOB_QLOCATOR_VERSION);
    writer.write_u8(
        TNS_LOB_LOC_FLAGS_VALUE_BASED | TNS_LOB_LOC_FLAGS_BLOB | TNS_LOB_LOC_FLAGS_ABSTRACT,
    );
    writer.write_u8(TNS_LOB_LOC_FLAGS_INIT);
    // Two u16 fields with fixed values 0 and 1.
    // NOTE(review): their meaning is not stated in this file.
    writer.write_u16be(0);
    writer.write_u16be(1);
    // Length in bytes of the data that follows the locator.
    writer.write_u64be(data_length);
    // Remaining fields are zero-filled: 3 x u16, then 2 x u64.
    writer.write_u16be(0);
    writer.write_u16be(0);
    writer.write_u16be(0);
    writer.write_u64be(0);
    writer.write_u64be(0);
}
414
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::Value;

    /// Encodes `vector`, decodes the image with default limits, and asserts the result
    /// equals the input.
    fn roundtrip(vector: Vector) {
        let image = encode_vector(&vector);
        let decoded = decode_vector(&image).expect("decode");
        assert_eq!(decoded, vector);
    }

    // Realistically sized vectors must still decode in full under the default limits.
    #[test]
    fn legitimate_large_vector_still_decodes_fully() {
        let big_f32: Vec<f32> = (0..4096).map(|i| i as f32 * 0.5 - 1024.0).collect();
        roundtrip(Vector::Dense(VectorValues::Float32(big_f32)));
        let big_f64: Vec<f64> = (0..2048).map(|i| i as f64 * 0.25).collect();
        roundtrip(Vector::Dense(VectorValues::Float64(big_f64)));
        // Sparse: 1000 stored entries in a 100_000-dimension space.
        roundtrip(Vector::Sparse {
            num_dimensions: 100_000,
            indices: (0..1000).map(|i| i * 7).collect(),
            values: VectorValues::Float32((0..1000).map(|i| i as f32).collect()),
        });
    }

    // Each dense element format survives encode -> decode.
    #[test]
    fn roundtrips_every_dense_format() {
        roundtrip(Vector::Dense(VectorValues::Float32(vec![
            1.5, -2.25, 3.0, 0.0,
        ])));
        roundtrip(Vector::Dense(VectorValues::Float64(vec![
            6501.0, 25.25, 18.125, -3.5,
        ])));
        roundtrip(Vector::Dense(VectorValues::Int8(vec![
            -5, 1, -2, 127, -128,
        ])));
        roundtrip(Vector::Dense(VectorValues::Binary(vec![0xA5, 0x3C])));
    }

    // Each sparse element format (float64, float32, int8) survives encode -> decode.
    #[test]
    fn roundtrips_every_sparse_format() {
        roundtrip(Vector::Sparse {
            num_dimensions: 8,
            indices: vec![1, 4, 6],
            values: VectorValues::Float64(vec![1.5, -2.0, 9.25]),
        });
        roundtrip(Vector::Sparse {
            num_dimensions: 6,
            indices: vec![0, 3],
            values: VectorValues::Float32(vec![2.5, -7.0]),
        });
        roundtrip(Vector::Sparse {
            num_dimensions: 5,
            indices: vec![2],
            values: VectorValues::Int8(vec![42]),
        });
    }

    // Pins the exact on-wire bytes of the float transform for both widths.
    #[test]
    fn float_elements_use_oracle_binary_transform() {
        let image = encode_vector(&Vector::Dense(VectorValues::Float64(vec![1.0, -2.0])));
        // Skip the 17-byte header: magic(1) + version(1) + flags(2) + format(1) +
        // count(4) + norm slot(8).
        let body = &image[17..];
        // Non-negative values get the sign bit set.
        assert_eq!(&body[0..8], &[0xbf, 0xf0, 0, 0, 0, 0, 0, 0], "f64 +1.0");
        // Negative values have every bit inverted.
        assert_eq!(
            &body[8..16],
            &[0x3f, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff],
            "f64 -2.0"
        );

        let image32 = encode_vector(&Vector::Dense(VectorValues::Float32(vec![1.0, -2.0])));
        // Same 17-byte header for f32 images.
        let body32 = &image32[17..];
        assert_eq!(&body32[0..4], &[0xbf, 0x80, 0, 0], "f32 +1.0");
        assert_eq!(&body32[4..8], &[0x3f, 0xff, 0xff, 0xff], "f32 -2.0");

        // The decoder reverses the transform.
        assert_eq!(
            decode_vector(&image).expect("decode f64"),
            Vector::Dense(VectorValues::Float64(vec![1.0, -2.0]))
        );
        assert_eq!(
            decode_vector(&image32).expect("decode f32"),
            Vector::Dense(VectorValues::Float32(vec![1.0, -2.0]))
        );
    }

    // A first byte other than TNS_VECTOR_MAGIC_BYTE is rejected with TtcDecode.
    #[test]
    fn rejects_bad_magic() {
        let err = decode_vector(&[0x00, 0, 0, 0, 0, 0, 0, 0, 0]).expect_err("bad magic must fail");
        assert!(matches!(err, ProtocolError::TtcDecode(_)));
    }

    // A version above TNS_VECTOR_VERSION_WITH_SPARSE is rejected with TtcDecode.
    #[test]
    fn rejects_unsupported_version() {
        let mut image = encode_vector(&Vector::Dense(VectorValues::Int8(vec![1])));
        // Byte 1 of the image is the version.
        image[1] = 99;
        let err = decode_vector(&image).expect_err("bad version must fail");
        assert!(matches!(err, ProtocolError::TtcDecode(_)));
    }

    // Fuzz regression. The header is magic 219 (0xDB), version 0, flags 0x0012 (both
    // norm flags), format 3 (FLOAT64) and count 0x3600_0000, followed only by the norm
    // slot. Decoding must fail rather than allocate for the claimed count.
    #[test]
    fn fuzz_regression_oom_oversized_element_count() {
        let input = [219, 0, 0, 18, 3, 54, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0];
        let err = decode_vector(&input).expect_err("oversized count must fail closed");
        assert!(
            matches!(
                err,
                ProtocolError::TtcDecode(_) | ProtocolError::ResourceLimit { .. }
            ),
            "got {err:?}"
        );
    }

    // A dense image with 5 elements is rejected when the caller caps dimensions at 4,
    // and the limit error reports both numbers.
    #[test]
    fn decode_vector_with_limits_rejects_dense_dimensions() {
        let image = encode_vector(&Vector::Dense(VectorValues::Int8(vec![1, 2, 3, 4, 5])));
        let limits = ProtocolLimits {
            max_vector_dimensions: 4,
            ..ProtocolLimits::DEFAULT
        };
        assert!(matches!(
            decode_vector_with_limits(&image, limits),
            Err(ProtocolError::ResourceLimit {
                limit: "vector_dimensions",
                observed: 5,
                maximum: 4,
            })
        ));
    }

    // A sparse header claiming 0xFFFF entries with no index/value bytes behind it must
    // fail with TtcDecode rather than pre-allocate for the claimed count.
    #[test]
    fn sparse_oversized_index_count_fails_closed_not_oom() {
        let input = [
            TNS_VECTOR_MAGIC_BYTE,
            TNS_VECTOR_VERSION_WITH_SPARSE,
            // flags = 0x0020 (sparse, no norm slot)
            0x00,
            0x20,
            VECTOR_FORMAT_FLOAT64,
            // num_dimensions = 0
            0x00,
            0x00,
            0x00,
            0x00,
            // num_sparse = 0xFFFF; nothing follows
            0xFF,
            0xFF,
        ];
        let err = decode_vector(&input).expect_err("oversized sparse count must fail closed");
        assert!(matches!(err, ProtocolError::TtcDecode(_)), "got {err:?}");
    }

    // Dense binary images carry the dimension count in bits (2 bytes -> 16) and use
    // TNS_VECTOR_VERSION_WITH_BINARY.
    #[test]
    fn binary_dense_bit_count_header() {
        let image = encode_vector(&Vector::Dense(VectorValues::Binary(vec![0xA5, 0x3C])));
        // Bytes 5..9 hold the big-endian element count.
        let num_elements = u32::from_be_bytes([image[5], image[6], image[7], image[8]]);
        assert_eq!(num_elements, 16);
        assert_eq!(image[1], TNS_VECTOR_VERSION_WITH_BINARY);
    }

    /// Builds a `Vector` from one entry of the golden JSON capture.
    ///
    /// Reads `typecode` ("f", "d", "b", "B" -> Float32, Float64, Int8, Binary).
    /// NOTE(review): these presumably mirror Python `array` typecodes.
    ///
    /// A `kind` of "sparse" reads `num_dimensions`, `indices` and `values`; anything else
    /// builds a dense vector from `values`. Panics on malformed entries.
    fn build_from_golden(entry: &Value) -> Vector {
        let typecode = entry["typecode"].as_str().expect("typecode");
        let f64_at = |x: &Value| x.as_f64().expect("number");
        let i64_at = |x: &Value| x.as_i64().expect("int");
        let u64_at = |x: &Value| x.as_u64().expect("uint");
        let make_values = |arr: &Value| -> VectorValues {
            let v = arr.as_array().expect("array");
            match typecode {
                "f" => VectorValues::Float32(v.iter().map(|x| f64_at(x) as f32).collect()),
                "d" => VectorValues::Float64(v.iter().map(f64_at).collect()),
                "b" => VectorValues::Int8(v.iter().map(|x| i64_at(x) as i8).collect()),
                "B" => VectorValues::Binary(v.iter().map(|x| u64_at(x) as u8).collect()),
                other => panic!("unknown typecode {other}"),
            }
        };
        if entry["kind"] == "sparse" {
            Vector::Sparse {
                num_dimensions: u64_at(&entry["num_dimensions"]) as u32,
                indices: entry["indices"]
                    .as_array()
                    .expect("indices array")
                    .iter()
                    .map(|x| u64_at(x) as u32)
                    .collect(),
                values: make_values(&entry["values"]),
            }
        } else {
            Vector::Dense(make_values(&entry["values"]))
        }
    }

    // For every golden case, encoding must reproduce `image_hex` byte-for-byte and
    // decoding that image must give back the same vector.
    #[test]
    fn matches_golden_capture() {
        let raw = include_str!("../tests/golden/vectors.json");
        let golden: Value = serde_json::from_str(raw).expect("parse golden json");
        let obj = golden.as_object().expect("golden is an object");
        assert!(!obj.is_empty(), "golden capture must not be empty");
        for (name, entry) in obj {
            let expected_hex = entry["image_hex"].as_str().expect("image_hex");
            let expected = hex::decode(expected_hex).expect("decode golden hex");

            let vector = build_from_golden(entry);
            // Encode direction.
            let image = encode_vector(&vector);
            assert_eq!(
                hex::encode(&image),
                expected_hex,
                "encode mismatch for golden case {name}"
            );

            // Decode direction, from the captured bytes rather than our own encoding.
            let decoded = decode_vector(&expected).expect("decode golden image");
            assert_eq!(decoded, vector, "decode mismatch for golden case {name}");
        }
    }
}