1use crate::wire::{BoundedReader, TtcReader, TtcWriter};
30use crate::{ProtocolError, Result};
31
/// First byte of every vector image; [`decode_vector`] rejects any other value.
pub const TNS_VECTOR_MAGIC_BYTE: u8 = 0xDB;

/// Image version written by [`encode_vector`] for dense, non-binary vectors.
pub const TNS_VECTOR_VERSION_BASE: u8 = 0;
/// Image version written by [`encode_vector`] for dense vectors in
/// [`VECTOR_FORMAT_BINARY`].
pub const TNS_VECTOR_VERSION_WITH_BINARY: u8 = 1;
/// Image version written for sparse vectors; also the highest version
/// [`decode_vector`] accepts.
pub const TNS_VECTOR_VERSION_WITH_SPARSE: u8 = 2;

/// Header flag: when this or [`TNS_VECTOR_FLAG_NORM_RESERVED`] is set, an
/// 8-byte slot follows the header. It is presumably a precomputed norm
/// (TODO confirm); this module writes it as zeros and skips it on decode.
/// [`encode_vector`] sets this flag on sparse and dense non-binary images.
pub const TNS_VECTOR_FLAG_NORM: u16 = 0x0002;
/// Header flag that also implies the 8-byte norm slot; [`encode_vector`]
/// sets it on every image it produces.
pub const TNS_VECTOR_FLAG_NORM_RESERVED: u16 = 0x0010;
/// Header flag marking a sparse image: a `u16` stored-entry count, the `u32`
/// indices, and the stored values follow the header.
pub const TNS_VECTOR_FLAG_SPARSE: u16 = 0x0020;

/// Element format code for 32-bit floats.
pub const VECTOR_FORMAT_FLOAT32: u8 = 2;
/// Element format code for 64-bit floats.
pub const VECTOR_FORMAT_FLOAT64: u8 = 3;
/// Element format code for signed 8-bit integers.
pub const VECTOR_FORMAT_INT8: u8 = 4;
/// Element format code for packed bits: each stored byte holds eight
/// dimensions (dense headers carry the bit count, i.e. bytes × 8).
pub const VECTOR_FORMAT_BINARY: u8 = 5;
50
/// Element storage for a vector: one variant per supported element format.
///
/// [`VectorValues::format`] maps each variant to its `VECTOR_FORMAT_*` code.
#[derive(Clone, Debug, PartialEq)]
pub enum VectorValues {
    /// Elements encoded with `VECTOR_FORMAT_FLOAT32`.
    Float32(Vec<f32>),
    /// Elements encoded with `VECTOR_FORMAT_FLOAT64`.
    Float64(Vec<f64>),
    /// Elements encoded with `VECTOR_FORMAT_INT8`.
    Int8(Vec<i8>),
    /// Packed bytes encoded with `VECTOR_FORMAT_BINARY`. Each byte holds
    /// eight dimensions, so a dense vector's dimension count is `8 × len()`.
    Binary(Vec<u8>),
}
66
67impl VectorValues {
68 pub fn format(&self) -> u8 {
70 match self {
71 VectorValues::Float32(_) => VECTOR_FORMAT_FLOAT32,
72 VectorValues::Float64(_) => VECTOR_FORMAT_FLOAT64,
73 VectorValues::Int8(_) => VECTOR_FORMAT_INT8,
74 VectorValues::Binary(_) => VECTOR_FORMAT_BINARY,
75 }
76 }
77
78 pub fn len(&self) -> usize {
81 match self {
82 VectorValues::Float32(v) => v.len(),
83 VectorValues::Float64(v) => v.len(),
84 VectorValues::Int8(v) => v.len(),
85 VectorValues::Binary(v) => v.len(),
86 }
87 }
88
89 pub fn is_empty(&self) -> bool {
90 self.len() == 0
91 }
92}
93
/// A vector value as carried in a VECTOR image: either every dimension
/// stored explicitly, or only selected dimensions.
#[derive(Clone, Debug, PartialEq)]
pub enum Vector {
    /// All dimensions are stored, in order.
    Dense(VectorValues),
    /// Only some dimensions are stored, as parallel index/value lists.
    Sparse {
        /// Total dimensionality; written as the header element count.
        num_dimensions: u32,
        /// Positions of the stored values, paired element-wise with `values`.
        indices: Vec<u32>,
        /// Stored values; expected to have the same length as `indices`.
        values: VectorValues,
    },
}
105
106pub fn decode_vector(data: &[u8]) -> Result<Vector> {
108 let mut reader = TtcReader::new(data);
109
110 let magic = reader.read_u8()?;
111 if magic != TNS_VECTOR_MAGIC_BYTE {
112 return Err(ProtocolError::TtcDecode("vector: bad magic byte"));
113 }
114 let version = reader.read_u8()?;
115 if version > TNS_VECTOR_VERSION_WITH_SPARSE {
116 return Err(ProtocolError::TtcDecode("vector: unsupported version"));
117 }
118 let flags = read_u16be(&mut reader)?;
119 let format = reader.read_u8()?;
120 let mut num_elements = read_u32be(&mut reader)?;
121 if flags & TNS_VECTOR_FLAG_NORM_RESERVED != 0 || flags & TNS_VECTOR_FLAG_NORM != 0 {
122 reader.skip(8)?;
123 }
124
125 if flags & TNS_VECTOR_FLAG_SPARSE != 0 {
126 let num_dimensions = num_elements;
127 let num_sparse = read_u16be(&mut reader)?;
128 let mut indices: Vec<u32> = reader.with_capacity_bounded(usize::from(num_sparse), 4);
132 for _ in 0..num_sparse {
133 indices.push(read_u32be(&mut reader)?);
134 }
135 let values = decode_values(&mut reader, u32::from(num_sparse), format)?;
136 return Ok(Vector::Sparse {
137 num_dimensions,
138 indices,
139 values,
140 });
141 }
142
143 if format == VECTOR_FORMAT_BINARY {
145 num_elements /= 8;
146 }
147 let values = decode_values(&mut reader, num_elements, format)?;
148 Ok(Vector::Dense(values))
149}
150
151fn decode_values(reader: &mut TtcReader<'_>, count: u32, format: u8) -> Result<VectorValues> {
152 let count = count as usize;
153 match format {
162 VECTOR_FORMAT_FLOAT32 => {
163 let mut out: Vec<f32> = reader.with_capacity_bounded(count, 4);
164 for _ in 0..count {
165 let raw = reader.read_raw(4)?;
166 out.push(decode_binary_float([raw[0], raw[1], raw[2], raw[3]]));
167 }
168 Ok(VectorValues::Float32(out))
169 }
170 VECTOR_FORMAT_FLOAT64 => {
171 let mut out: Vec<f64> = reader.with_capacity_bounded(count, 8);
172 for _ in 0..count {
173 let raw = reader.read_raw(8)?;
174 out.push(decode_binary_double([
175 raw[0], raw[1], raw[2], raw[3], raw[4], raw[5], raw[6], raw[7],
176 ]));
177 }
178 Ok(VectorValues::Float64(out))
179 }
180 VECTOR_FORMAT_INT8 => {
181 let mut out: Vec<i8> = reader.with_capacity_bounded(count, 1);
182 for _ in 0..count {
183 out.push(reader.read_u8()? as i8);
184 }
185 Ok(VectorValues::Int8(out))
186 }
187 VECTOR_FORMAT_BINARY => Ok(VectorValues::Binary(reader.read_raw(count)?.to_vec())),
188 _ => Err(ProtocolError::TtcDecode(
189 "vector: unsupported element format",
190 )),
191 }
192}
193
194pub fn encode_vector(vector: &Vector) -> Vec<u8> {
197 let mut buf = Vec::new();
198
199 let mut flags = TNS_VECTOR_FLAG_NORM_RESERVED;
200 let (format, version, num_elements) = match vector {
201 Vector::Sparse {
202 num_dimensions,
203 values,
204 ..
205 } => {
206 flags |= TNS_VECTOR_FLAG_SPARSE | TNS_VECTOR_FLAG_NORM;
207 (
208 values.format(),
209 TNS_VECTOR_VERSION_WITH_SPARSE,
210 *num_dimensions,
211 )
212 }
213 Vector::Dense(values) => {
214 let format = values.format();
215 if format == VECTOR_FORMAT_BINARY {
216 (
217 format,
218 TNS_VECTOR_VERSION_WITH_BINARY,
219 (values.len() as u32) * 8,
220 )
221 } else {
222 flags |= TNS_VECTOR_FLAG_NORM;
223 (format, TNS_VECTOR_VERSION_BASE, values.len() as u32)
224 }
225 }
226 };
227
228 buf.push(TNS_VECTOR_MAGIC_BYTE);
229 buf.push(version);
230 buf.extend_from_slice(&flags.to_be_bytes());
231 buf.push(format);
232 buf.extend_from_slice(&num_elements.to_be_bytes());
233 buf.extend_from_slice(&[0u8; 8]); match vector {
236 Vector::Dense(values) => encode_values(&mut buf, values),
237 Vector::Sparse {
238 indices, values, ..
239 } => {
240 let num_sparse = indices.len() as u16;
241 buf.extend_from_slice(&num_sparse.to_be_bytes());
242 for index in indices {
243 buf.extend_from_slice(&index.to_be_bytes());
244 }
245 encode_values(&mut buf, values);
246 }
247 }
248
249 buf
250}
251
252fn encode_values(buf: &mut Vec<u8>, values: &VectorValues) {
253 match values {
254 VectorValues::Float32(v) => {
255 for value in v {
256 buf.extend_from_slice(&encode_binary_float(*value));
257 }
258 }
259 VectorValues::Float64(v) => {
260 for value in v {
261 buf.extend_from_slice(&encode_binary_double(*value));
262 }
263 }
264 VectorValues::Int8(v) => {
265 for value in v {
266 buf.push(*value as u8);
267 }
268 }
269 VectorValues::Binary(v) => buf.extend_from_slice(v),
270 }
271}
272
/// Inverse of `encode_binary_double`: recovers an `f64` from its on-wire
/// big-endian bytes.
///
/// A set top bit marks an originally non-negative value, so only that bit is
/// cleared. A clear top bit marks an originally negative value, so every bit
/// is inverted back.
fn decode_binary_double(bytes: [u8; 8]) -> f64 {
    const TOP_BIT: u64 = 1 << 63;
    let stored = u64::from_be_bytes(bytes);
    let bits = if stored & TOP_BIT == 0 {
        !stored
    } else {
        stored & !TOP_BIT
    };
    f64::from_bits(bits)
}
292
/// Inverse of `encode_binary_float`: recovers an `f32` from its on-wire
/// big-endian bytes.
///
/// A set top bit marks an originally non-negative value, so that bit is
/// toggled off. A clear top bit marks an originally negative value, so every
/// bit is inverted back.
fn decode_binary_float(bytes: [u8; 4]) -> f32 {
    const TOP_BIT: u32 = 0x8000_0000;
    let stored = u32::from_be_bytes(bytes);
    let bits = if stored & TOP_BIT == 0 {
        !stored
    } else {
        stored ^ TOP_BIT
    };
    f32::from_bits(bits)
}
305
/// Converts an `f64` to its on-wire big-endian bytes.
///
/// Values whose sign bit is clear get it set. Values whose sign bit is set
/// (including `-0.0`) have every bit inverted.
fn encode_binary_double(value: f64) -> [u8; 8] {
    const SIGN: u64 = 1 << 63;
    let bits = value.to_bits();
    let transformed = if bits & SIGN != 0 { !bits } else { bits | SIGN };
    transformed.to_be_bytes()
}
318
/// Converts an `f32` to its on-wire big-endian bytes.
///
/// Values whose sign bit is clear get it set. Values whose sign bit is set
/// (including `-0.0`) have every bit inverted.
fn encode_binary_float(value: f32) -> [u8; 4] {
    const SIGN: u32 = 0x8000_0000;
    let bits = value.to_bits();
    let transformed = if bits & SIGN != 0 { !bits } else { bits | SIGN };
    transformed.to_be_bytes()
}
331
332fn read_u16be(reader: &mut TtcReader<'_>) -> Result<u16> {
335 let raw = reader.read_raw(2)?;
336 Ok(u16::from_be_bytes([raw[0], raw[1]]))
337}
338
339fn read_u32be(reader: &mut TtcReader<'_>) -> Result<u32> {
340 let raw = reader.read_raw(4)?;
341 Ok(u32::from_be_bytes([raw[0], raw[1], raw[2], raw[3]]))
342}
343
344pub fn write_vector_image(writer: &mut TtcWriter, image: &[u8]) -> Result<()> {
350 write_qlocator(writer, image.len() as u64, true);
351 writer.write_bytes_with_length(image)?;
352 Ok(())
353}
354
355pub fn write_oson_aq_payload(writer: &mut TtcWriter, image: &[u8]) -> Result<()> {
359 write_qlocator(writer, image.len() as u64, false);
360 writer.write_bytes_with_length(image)?;
361 Ok(())
362}
363
364fn write_qlocator(writer: &mut TtcWriter, data_length: u64, write_length: bool) {
369 const TNS_LOB_QLOCATOR_VERSION: u16 = 4;
370 const TNS_LOB_LOC_FLAGS_VALUE_BASED: u8 = 0x20;
371 const TNS_LOB_LOC_FLAGS_BLOB: u8 = 0x01;
372 const TNS_LOB_LOC_FLAGS_ABSTRACT: u8 = 0x40;
373 const TNS_LOB_LOC_FLAGS_INIT: u8 = 0x08;
374
375 writer.write_ub4(40); if write_length {
377 writer.write_u8(40); }
379 writer.write_u16be(38); writer.write_u16be(TNS_LOB_QLOCATOR_VERSION);
381 writer.write_u8(
382 TNS_LOB_LOC_FLAGS_VALUE_BASED | TNS_LOB_LOC_FLAGS_BLOB | TNS_LOB_LOC_FLAGS_ABSTRACT,
383 );
384 writer.write_u8(TNS_LOB_LOC_FLAGS_INIT);
385 writer.write_u16be(0); writer.write_u16be(1); writer.write_u64be(data_length);
388 writer.write_u16be(0); writer.write_u16be(0); writer.write_u16be(0); writer.write_u64be(0); writer.write_u64be(0); }
394
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::Value;

    /// Encodes `vector`, decodes the resulting image, and asserts the decoded
    /// value equals the input.
    fn roundtrip(vector: Vector) {
        let image = encode_vector(&vector);
        let decoded = decode_vector(&image).expect("decode");
        assert_eq!(decoded, vector);
    }

    /// Large but well-formed vectors must still decode completely. Presumably
    /// this guards that the decoder's bounded capacity hints (used against
    /// forged counts) never cap legitimate data.
    #[test]
    fn legitimate_large_vector_still_decodes_fully() {
        let big_f32: Vec<f32> = (0..4096).map(|i| i as f32 * 0.5 - 1024.0).collect();
        roundtrip(Vector::Dense(VectorValues::Float32(big_f32)));
        let big_f64: Vec<f64> = (0..2048).map(|i| i as f64 * 0.25).collect();
        roundtrip(Vector::Dense(VectorValues::Float64(big_f64)));
        // Sparse: 1000 stored entries spread across 100k dimensions.
        roundtrip(Vector::Sparse {
            num_dimensions: 100_000,
            indices: (0..1000).map(|i| i * 7).collect(),
            values: VectorValues::Float32((0..1000).map(|i| i as f32).collect()),
        });
    }

    /// Every dense element format survives an encode/decode round trip.
    #[test]
    fn roundtrips_every_dense_format() {
        roundtrip(Vector::Dense(VectorValues::Float32(vec![
            1.5, -2.25, 3.0, 0.0,
        ])));
        roundtrip(Vector::Dense(VectorValues::Float64(vec![
            6501.0, 25.25, 18.125, -3.5,
        ])));
        roundtrip(Vector::Dense(VectorValues::Int8(vec![
            -5, 1, -2, 127, -128,
        ])));
        roundtrip(Vector::Dense(VectorValues::Binary(vec![0xA5, 0x3C])));
    }

    /// Sparse vectors of each numeric format survive an encode/decode round trip.
    #[test]
    fn roundtrips_every_sparse_format() {
        roundtrip(Vector::Sparse {
            num_dimensions: 8,
            indices: vec![1, 4, 6],
            values: VectorValues::Float64(vec![1.5, -2.0, 9.25]),
        });
        roundtrip(Vector::Sparse {
            num_dimensions: 6,
            indices: vec![0, 3],
            values: VectorValues::Float32(vec![2.5, -7.0]),
        });
        roundtrip(Vector::Sparse {
            num_dimensions: 5,
            indices: vec![2],
            values: VectorValues::Int8(vec![42]),
        });
    }

    /// Pins the exact on-wire bytes of float elements. Non-negative values
    /// get the sign bit set; negative values are fully bit-inverted.
    #[test]
    fn float_elements_use_oracle_binary_transform() {
        let image = encode_vector(&Vector::Dense(VectorValues::Float64(vec![1.0, -2.0])));
        // Element data starts after the 9-byte header and the 8-byte norm slot.
        let body = &image[17..];
        assert_eq!(&body[0..8], &[0xbf, 0xf0, 0, 0, 0, 0, 0, 0], "f64 +1.0");
        assert_eq!(
            &body[8..16],
            &[0x3f, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff],
            "f64 -2.0"
        );

        let image32 = encode_vector(&Vector::Dense(VectorValues::Float32(vec![1.0, -2.0])));
        let body32 = &image32[17..];
        assert_eq!(&body32[0..4], &[0xbf, 0x80, 0, 0], "f32 +1.0");
        assert_eq!(&body32[4..8], &[0x3f, 0xff, 0xff, 0xff], "f32 -2.0");

        // The same images must decode back to the original values.
        assert_eq!(
            decode_vector(&image).expect("decode f64"),
            Vector::Dense(VectorValues::Float64(vec![1.0, -2.0]))
        );
        assert_eq!(
            decode_vector(&image32).expect("decode f32"),
            Vector::Dense(VectorValues::Float32(vec![1.0, -2.0]))
        );
    }

    /// A first byte other than `TNS_VECTOR_MAGIC_BYTE` is a decode error.
    #[test]
    fn rejects_bad_magic() {
        let err = decode_vector(&[0x00, 0, 0, 0, 0, 0, 0, 0, 0]).expect_err("bad magic must fail");
        assert!(matches!(err, ProtocolError::TtcDecode(_)));
    }

    /// A version byte above `TNS_VECTOR_VERSION_WITH_SPARSE` is a decode error.
    #[test]
    fn rejects_unsupported_version() {
        let mut image = encode_vector(&Vector::Dense(VectorValues::Int8(vec![1])));
        // Byte 1 is the version.
        image[1] = 99;
        let err = decode_vector(&image).expect_err("bad version must fail");
        assert!(matches!(err, ProtocolError::TtcDecode(_)));
    }

    /// Fuzz regression: a dense FLOAT64 header (flags 0x0012, so an 8-byte
    /// norm slot follows) claims 0x36000000 elements but carries no element
    /// data. It must fail with a decode error rather than exhaust memory.
    #[test]
    fn fuzz_regression_oom_oversized_element_count() {
        let input = [219, 0, 0, 18, 3, 54, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0];
        let err = decode_vector(&input).expect_err("oversized count must fail closed");
        assert!(matches!(err, ProtocolError::TtcDecode(_)), "got {err:?}");
    }

    /// A sparse header claiming 0xFFFF stored entries with no index data must
    /// fail with a decode error rather than over-allocate.
    #[test]
    fn sparse_oversized_index_count_fails_closed_not_oom() {
        let input = [
            TNS_VECTOR_MAGIC_BYTE,
            TNS_VECTOR_VERSION_WITH_SPARSE,
            // Flags 0x0020: sparse only, so no norm slot follows the header.
            0x00,
            0x20,
            VECTOR_FORMAT_FLOAT64,
            // Header element count (num_dimensions) = 0.
            0x00,
            0x00,
            0x00,
            0x00,
            // Stored-entry count = 0xFFFF, with nothing after it.
            0xFF,
            0xFF,
        ];
        let err = decode_vector(&input).expect_err("oversized sparse count must fail closed");
        assert!(matches!(err, ProtocolError::TtcDecode(_)), "got {err:?}");
    }

    /// Dense binary images carry the bit count (bytes × 8) in the header and
    /// use `TNS_VECTOR_VERSION_WITH_BINARY`.
    #[test]
    fn binary_dense_bit_count_header() {
        let image = encode_vector(&Vector::Dense(VectorValues::Binary(vec![0xA5, 0x3C])));
        // Bytes 5..9 are the big-endian u32 element count.
        let num_elements = u32::from_be_bytes([image[5], image[6], image[7], image[8]]);
        assert_eq!(num_elements, 16);
        assert_eq!(image[1], TNS_VECTOR_VERSION_WITH_BINARY);
    }

    /// Builds the expected `Vector` for one golden-capture JSON entry.
    ///
    /// `typecode` selects the element type. The codes "f"/"d"/"b"/"B" look
    /// like Python `array` typecodes (an assumption about how the capture
    /// was produced). `kind == "sparse"` selects the sparse variant.
    fn build_from_golden(entry: &Value) -> Vector {
        let typecode = entry["typecode"].as_str().expect("typecode");
        let f64_at = |x: &Value| x.as_f64().expect("number");
        let i64_at = |x: &Value| x.as_i64().expect("int");
        let u64_at = |x: &Value| x.as_u64().expect("uint");
        let make_values = |arr: &Value| -> VectorValues {
            let v = arr.as_array().expect("array");
            match typecode {
                "f" => VectorValues::Float32(v.iter().map(|x| f64_at(x) as f32).collect()),
                "d" => VectorValues::Float64(v.iter().map(f64_at).collect()),
                "b" => VectorValues::Int8(v.iter().map(|x| i64_at(x) as i8).collect()),
                "B" => VectorValues::Binary(v.iter().map(|x| u64_at(x) as u8).collect()),
                other => panic!("unknown typecode {other}"),
            }
        };
        if entry["kind"] == "sparse" {
            Vector::Sparse {
                num_dimensions: u64_at(&entry["num_dimensions"]) as u32,
                indices: entry["indices"]
                    .as_array()
                    .expect("indices array")
                    .iter()
                    .map(|x| u64_at(x) as u32)
                    .collect(),
                values: make_values(&entry["values"]),
            }
        } else {
            Vector::Dense(make_values(&entry["values"]))
        }
    }

    /// For every captured image in `tests/golden/vectors.json`:
    ///
    /// * encoding the described vector must reproduce `image_hex` byte for
    ///   byte;
    /// * decoding `image_hex` must yield the described vector.
    #[test]
    fn matches_golden_capture() {
        let raw = include_str!("../tests/golden/vectors.json");
        let golden: Value = serde_json::from_str(raw).expect("parse golden json");
        let obj = golden.as_object().expect("golden is an object");
        assert!(!obj.is_empty(), "golden capture must not be empty");
        for (name, entry) in obj {
            let expected_hex = entry["image_hex"].as_str().expect("image_hex");
            let expected = hex::decode(expected_hex).expect("decode golden hex");

            let vector = build_from_golden(entry);
            // Encoder side: byte-exact match with the captured image.
            let image = encode_vector(&vector);
            assert_eq!(
                hex::encode(&image),
                expected_hex,
                "encode mismatch for golden case {name}"
            );

            // Decoder side: the captured image decodes to the described vector.
            let decoded = decode_vector(&expected).expect("decode golden image");
            assert_eq!(decoded, vector, "decode mismatch for golden case {name}");
        }
    }
}