use crate::wire::{BoundedReader, ProtocolLimits, TtcReader, TtcWriter};
use crate::{ProtocolError, Result};
/// First byte of every VECTOR image; the decoder rejects any other value.
pub const TNS_VECTOR_MAGIC_BYTE: u8 = 0xDB;
/// Image version the encoder emits for dense, non-binary vectors.
pub const TNS_VECTOR_VERSION_BASE: u8 = 0;
/// Image version the encoder emits for dense binary vectors.
pub const TNS_VECTOR_VERSION_WITH_BINARY: u8 = 1;
/// Image version the encoder emits for sparse vectors; also the highest
/// version the decoder accepts.
pub const TNS_VECTOR_VERSION_WITH_SPARSE: u8 = 2;
/// Header flag set by the encoder for sparse and dense non-binary vectors.
/// When present (or when [`TNS_VECTOR_FLAG_NORM_RESERVED`] is), the decoder
/// skips 8 bytes after the element count.
pub const TNS_VECTOR_FLAG_NORM: u16 = 0x0002;
/// Header flag the encoder always sets; it always writes 8 zero bytes after
/// the element count, which the decoder skips when this flag is present.
pub const TNS_VECTOR_FLAG_NORM_RESERVED: u16 = 0x0010;
/// Header flag marking a sparse image (dimension count, index list, values).
pub const TNS_VECTOR_FLAG_SPARSE: u16 = 0x0020;
/// Element format code for [`VectorValues::Float32`] (4 bytes per element).
pub const VECTOR_FORMAT_FLOAT32: u8 = 2;
/// Element format code for [`VectorValues::Float64`] (8 bytes per element).
pub const VECTOR_FORMAT_FLOAT64: u8 = 3;
/// Element format code for [`VectorValues::Int8`] (1 byte per element).
pub const VECTOR_FORMAT_INT8: u8 = 4;
/// Element format code for [`VectorValues::Binary`]; the dense header counts
/// bits, i.e. eight elements per stored byte.
pub const VECTOR_FORMAT_BINARY: u8 = 5;
/// Element storage of a vector, one variant per supported wire format.
#[derive(Clone, Debug, PartialEq)]
pub enum VectorValues {
/// 32-bit floating point elements (`VECTOR_FORMAT_FLOAT32`).
Float32(Vec<f32>),
/// 64-bit floating point elements (`VECTOR_FORMAT_FLOAT64`).
Float64(Vec<f64>),
/// Signed 8-bit integer elements (`VECTOR_FORMAT_INT8`).
Int8(Vec<i8>),
/// Raw bytes of a binary vector (`VECTOR_FORMAT_BINARY`); for dense vectors
/// the wire header records `len * 8` as the element count.
Binary(Vec<u8>),
}
impl VectorValues {
    /// Returns the wire format code (`VECTOR_FORMAT_*`) matching the variant.
    pub fn format(&self) -> u8 {
        match *self {
            Self::Float32(_) => VECTOR_FORMAT_FLOAT32,
            Self::Float64(_) => VECTOR_FORMAT_FLOAT64,
            Self::Int8(_) => VECTOR_FORMAT_INT8,
            Self::Binary(_) => VECTOR_FORMAT_BINARY,
        }
    }

    /// Returns the number of stored items: elements for the numeric variants,
    /// bytes for `Binary`.
    pub fn len(&self) -> usize {
        match self {
            Self::Float32(items) => items.len(),
            Self::Float64(items) => items.len(),
            Self::Int8(items) => items.len(),
            Self::Binary(bytes) => bytes.len(),
        }
    }

    /// Returns `true` when no items are stored.
    pub fn is_empty(&self) -> bool {
        match self {
            Self::Float32(items) => items.is_empty(),
            Self::Float64(items) => items.is_empty(),
            Self::Int8(items) => items.is_empty(),
            Self::Binary(bytes) => bytes.is_empty(),
        }
    }
}
/// A decoded VECTOR value, either fully populated or sparse.
#[derive(Clone, Debug, PartialEq)]
pub enum Vector {
/// Every element is stored explicitly.
Dense(VectorValues),
/// Only some positions are stored.
Sparse {
/// Value carried in the header's element-count field for sparse images.
num_dimensions: u32,
/// One index per stored value; the encoder requires
/// `indices.len() == values.len()` and at most `u16::MAX` entries.
indices: Vec<u32>,
/// The stored values, parallel to `indices`.
values: VectorValues,
},
}
/// Decodes a VECTOR image using `ProtocolLimits::DEFAULT`.
///
/// # Errors
///
/// Returns whatever [`decode_vector_with_limits`] returns for the same input.
pub fn decode_vector(data: &[u8]) -> Result<Vector> {
    let default_limits = ProtocolLimits::DEFAULT;
    decode_vector_with_limits(data, default_limits)
}
/// Decodes a VECTOR image while enforcing the supplied protocol limits.
///
/// Layout read, in order: magic byte, version byte, big-endian `u16` flags,
/// format byte, big-endian `u32` element count, 8 skipped bytes when either
/// NORM flag is set, then either the dense values or, for sparse images, a
/// big-endian `u16` entry count followed by that many `u32` indices and values.
///
/// # Errors
///
/// * `ProtocolError::TtcDecode` for a wrong magic byte, a version above
///   `TNS_VECTOR_VERSION_WITH_SPARSE`, or an unknown element format.
/// * Any error produced by the limit checks or by the reader.
pub fn decode_vector_with_limits(data: &[u8], limits: ProtocolLimits) -> Result<Vector> {
    limits.check_response_bytes(data.len())?;
    let mut reader = TtcReader::with_limits(data, limits)?;

    if reader.read_u8()? != TNS_VECTOR_MAGIC_BYTE {
        return Err(ProtocolError::TtcDecode("vector: bad magic byte"));
    }
    if reader.read_u8()? > TNS_VECTOR_VERSION_WITH_SPARSE {
        return Err(ProtocolError::TtcDecode("vector: unsupported version"));
    }

    let flags = read_u16be(&mut reader)?;
    let format = reader.read_u8()?;
    let declared = read_u32be(&mut reader)?;
    // Reject absurd counts before anything is sized from them.
    reader.limits().check_vector_dimensions(declared as usize)?;

    // Either NORM flag means 8 bytes sit between the header and the payload.
    const NORM_MASK: u16 = TNS_VECTOR_FLAG_NORM_RESERVED | TNS_VECTOR_FLAG_NORM;
    if flags & NORM_MASK != 0 {
        reader.skip(8)?;
    }

    if flags & TNS_VECTOR_FLAG_SPARSE == 0 {
        // Dense binary headers count bits; the payload holds one byte per 8.
        let stored = if format == VECTOR_FORMAT_BINARY {
            declared / 8
        } else {
            declared
        };
        return decode_values(&mut reader, stored, format).map(Vector::Dense);
    }

    // Sparse image: the header count is the dimension count and the number of
    // stored entries follows as a u16.
    let entry_count = read_u16be(&mut reader)?;
    reader
        .limits()
        .check_vector_dimensions(usize::from(entry_count))?;
    let mut indices: Vec<u32> = reader.with_capacity_limited(
        usize::from(entry_count),
        4,
        ProtocolLimits::check_vector_dimensions,
    )?;
    for _ in 0..entry_count {
        let index = read_u32be(&mut reader)?;
        indices.push(index);
    }
    let values = decode_values(&mut reader, u32::from(entry_count), format)?;

    Ok(Vector::Sparse {
        num_dimensions: declared,
        indices,
        values,
    })
}
/// Reads `count` elements of the given wire `format` from `reader`.
///
/// For `VECTOR_FORMAT_BINARY`, `count` is the number of bytes to read.
///
/// # Errors
///
/// Returns `ProtocolError::TtcDecode` for an unknown format and propagates
/// any limit or reader error.
fn decode_values(reader: &mut TtcReader<'_>, count: u32, format: u8) -> Result<VectorValues> {
    let total = count as usize;
    reader.limits().check_vector_dimensions(total)?;

    let values = match format {
        VECTOR_FORMAT_FLOAT32 => {
            let mut floats: Vec<f32> =
                reader.with_capacity_limited(total, 4, ProtocolLimits::check_vector_dimensions)?;
            for _ in 0..total {
                let chunk = reader.read_raw(4)?;
                let wire = [chunk[0], chunk[1], chunk[2], chunk[3]];
                floats.push(decode_binary_float(wire));
            }
            VectorValues::Float32(floats)
        }
        VECTOR_FORMAT_FLOAT64 => {
            let mut doubles: Vec<f64> =
                reader.with_capacity_limited(total, 8, ProtocolLimits::check_vector_dimensions)?;
            for _ in 0..total {
                let chunk = reader.read_raw(8)?;
                let wire = [
                    chunk[0], chunk[1], chunk[2], chunk[3], chunk[4], chunk[5], chunk[6], chunk[7],
                ];
                doubles.push(decode_binary_double(wire));
            }
            VectorValues::Float64(doubles)
        }
        VECTOR_FORMAT_INT8 => {
            let mut ints: Vec<i8> =
                reader.with_capacity_limited(total, 1, ProtocolLimits::check_vector_dimensions)?;
            for _ in 0..total {
                // Reinterpret the raw byte as two's-complement.
                let byte = reader.read_u8()?;
                ints.push(byte as i8);
            }
            VectorValues::Int8(ints)
        }
        VECTOR_FORMAT_BINARY => VectorValues::Binary(reader.read_raw(total)?.to_vec()),
        _ => {
            return Err(ProtocolError::TtcDecode(
                "vector: unsupported element format",
            ))
        }
    };
    Ok(values)
}
/// Encodes `vector` into its VECTOR image bytes.
///
/// # Panics
///
/// Panics with the underlying error if [`encode_vector_checked`] rejects the
/// value (for example a sparse vector whose index and value counts differ).
pub fn encode_vector(vector: &Vector) -> Vec<u8> {
    encode_vector_checked(vector)
        .unwrap_or_else(|err| panic!("invalid VECTOR value for encoding: {err}"))
}
/// Encodes `vector` into its VECTOR image, reporting invalid values as errors.
///
/// Layout written: magic byte, version, big-endian `u16` flags, format byte,
/// big-endian `u32` element count, 8 zero bytes, then the payload. Sparse
/// payloads are a big-endian `u16` entry count, the `u32` indices and the
/// values; dense payloads are the values alone. Dense binary vectors record
/// the bit count (`bytes * 8`) in the header.
///
/// # Errors
///
/// * `ProtocolError::TtcDecode` when a sparse vector has a different number
///   of indices and values.
/// * `ProtocolError::InvalidPacketLength` when a sparse vector has more than
///   `u16::MAX` entries, or when a dense vector's element count (bit count
///   for binary vectors) does not fit the 32-bit header field.
pub(crate) fn encode_vector_checked(vector: &Vector) -> Result<Vec<u8>> {
    let mut buf = Vec::new();
    let mut flags = TNS_VECTOR_FLAG_NORM_RESERVED;
    let (format, version, num_elements) = match vector {
        Vector::Sparse {
            num_dimensions,
            values,
            ..
        } => {
            flags |= TNS_VECTOR_FLAG_SPARSE | TNS_VECTOR_FLAG_NORM;
            (
                values.format(),
                TNS_VECTOR_VERSION_WITH_SPARSE,
                *num_dimensions,
            )
        }
        Vector::Dense(values) => {
            let format = values.format();
            // The header count is a u32. A silent `as` cast or wrapping
            // multiply would emit a header that disagrees with the payload,
            // so oversized inputs are rejected instead.
            let too_long = || ProtocolError::InvalidPacketLength {
                length: values.len(),
                minimum: 0,
            };
            let stored = u32::try_from(values.len()).map_err(|_| too_long())?;
            if format == VECTOR_FORMAT_BINARY {
                // Binary vectors advertise bits, eight per stored byte.
                let bits = stored.checked_mul(8).ok_or_else(too_long)?;
                (format, TNS_VECTOR_VERSION_WITH_BINARY, bits)
            } else {
                flags |= TNS_VECTOR_FLAG_NORM;
                (format, TNS_VECTOR_VERSION_BASE, stored)
            }
        }
    };
    buf.push(TNS_VECTOR_MAGIC_BYTE);
    buf.push(version);
    buf.extend_from_slice(&flags.to_be_bytes());
    buf.push(format);
    buf.extend_from_slice(&num_elements.to_be_bytes());
    // NORM_RESERVED is always set, so the 8-byte slot is always present.
    buf.extend_from_slice(&[0u8; 8]);
    match vector {
        Vector::Dense(values) => encode_values(&mut buf, values),
        Vector::Sparse {
            indices, values, ..
        } => {
            if indices.len() != values.len() {
                return Err(ProtocolError::TtcDecode(
                    "vector: sparse index/value count mismatch",
                ));
            }
            // The sparse entry count is a 16-bit wire field.
            let num_sparse =
                u16::try_from(indices.len()).map_err(|_| ProtocolError::InvalidPacketLength {
                    length: indices.len(),
                    minimum: 0,
                })?;
            buf.extend_from_slice(&num_sparse.to_be_bytes());
            for index in indices {
                buf.extend_from_slice(&index.to_be_bytes());
            }
            encode_values(&mut buf, values);
        }
    }
    Ok(buf)
}
/// Appends the wire encoding of every element in `values` to `buf`.
///
/// Floats go through the binary float/double transform, `Int8` elements are
/// written as their raw byte, and `Binary` bytes are copied verbatim.
fn encode_values(buf: &mut Vec<u8>, values: &VectorValues) {
    match values {
        VectorValues::Float32(floats) => {
            buf.extend(floats.iter().flat_map(|&item| encode_binary_float(item)));
        }
        VectorValues::Float64(doubles) => {
            buf.extend(doubles.iter().flat_map(|&item| encode_binary_double(item)));
        }
        VectorValues::Int8(ints) => {
            buf.extend(ints.iter().map(|&item| item as u8));
        }
        VectorValues::Binary(bytes) => buf.extend_from_slice(bytes),
    }
}
/// Reverses the wire transform applied to an `f64`.
///
/// A set top bit marks a value whose IEEE sign bit was clear, so the bit is
/// simply cleared; otherwise every bit was inverted on encode and is inverted
/// back here.
fn decode_binary_double(bytes: [u8; 8]) -> f64 {
    const TOP_BIT: u64 = 1 << 63;
    let wire = u64::from_be_bytes(bytes);
    let ieee = if wire & TOP_BIT != 0 {
        wire & !TOP_BIT
    } else {
        !wire
    };
    f64::from_bits(ieee)
}
/// Reverses the wire transform applied to an `f32`.
///
/// A set top bit marks a value whose IEEE sign bit was clear, so the bit is
/// simply cleared; otherwise every bit was inverted on encode and is inverted
/// back here.
fn decode_binary_float(bytes: [u8; 4]) -> f32 {
    const TOP_BIT: u32 = 1 << 31;
    let wire = u32::from_be_bytes(bytes);
    let ieee = if wire & TOP_BIT != 0 {
        wire & !TOP_BIT
    } else {
        !wire
    };
    f32::from_bits(ieee)
}
/// Applies the wire transform to an `f64` and returns its big-endian bytes.
///
/// Values with a clear IEEE sign bit get the top bit set; values with the
/// sign bit set have every bit inverted.
fn encode_binary_double(value: f64) -> [u8; 8] {
    const TOP_BIT: u64 = 1 << 63;
    let ieee = value.to_bits();
    let wire = if ieee & TOP_BIT == 0 {
        ieee | TOP_BIT
    } else {
        !ieee
    };
    wire.to_be_bytes()
}
/// Applies the wire transform to an `f32` and returns its big-endian bytes.
///
/// Values with a clear IEEE sign bit get the top bit set; values with the
/// sign bit set have every bit inverted.
fn encode_binary_float(value: f32) -> [u8; 4] {
    const TOP_BIT: u32 = 1 << 31;
    let ieee = value.to_bits();
    let wire = if ieee & TOP_BIT == 0 {
        ieee | TOP_BIT
    } else {
        !ieee
    };
    wire.to_be_bytes()
}
/// Reads two raw bytes from `reader` and combines them as a big-endian `u16`.
fn read_u16be(reader: &mut TtcReader<'_>) -> Result<u16> {
    let pair = reader.read_raw(2)?;
    let (high, low) = (u16::from(pair[0]), u16::from(pair[1]));
    Ok((high << 8) | low)
}
/// Reads four raw bytes from `reader` and combines them as a big-endian `u32`.
fn read_u32be(reader: &mut TtcReader<'_>) -> Result<u32> {
    let quad = reader.read_raw(4)?;
    // Most significant byte first: shift the accumulator and append each byte.
    let word = [quad[0], quad[1], quad[2], quad[3]]
        .iter()
        .fold(0u32, |acc, &byte| (acc << 8) | u32::from(byte));
    Ok(word)
}
/// Writes a qlocator describing `image` (including the extra one-byte length
/// after the leading length field) followed by `image` itself as
/// length-prefixed bytes.
///
/// # Errors
///
/// Propagates any error from `TtcWriter::write_bytes_with_length`.
pub fn write_vector_image(writer: &mut TtcWriter, image: &[u8]) -> Result<()> {
    let payload_len = image.len() as u64;
    write_qlocator(writer, payload_len, true);
    writer.write_bytes_with_length(image)?;
    Ok(())
}
/// Writes a qlocator describing `image` (without the extra one-byte length
/// after the leading length field) followed by `image` itself as
/// length-prefixed bytes.
///
/// # Errors
///
/// Propagates any error from `TtcWriter::write_bytes_with_length`.
pub fn write_oson_aq_payload(writer: &mut TtcWriter, image: &[u8]) -> Result<()> {
    let payload_len = image.len() as u64;
    write_qlocator(writer, payload_len, false);
    writer.write_bytes_with_length(image)?;
    Ok(())
}
/// Writes a 40-byte LOB qlocator that carries `data_length`.
///
/// The locator is preceded by a `ub4` holding 40 and, when `write_length` is
/// true, by one additional byte holding 40. The locator body starts with a
/// `u16` of 38, which is the 40-byte total minus that 2-byte field.
///
/// NOTE(review): the meaning of the individual zero/one fields is not visible
/// in this file; names in the comments below are descriptive only and should
/// be confirmed against the protocol reference.
fn write_qlocator(writer: &mut TtcWriter, data_length: u64, write_length: bool) {
    const TNS_LOB_QLOCATOR_VERSION: u16 = 4;
    const TNS_LOB_LOC_FLAGS_VALUE_BASED: u8 = 0x20;
    const TNS_LOB_LOC_FLAGS_BLOB: u8 = 0x01;
    const TNS_LOB_LOC_FLAGS_ABSTRACT: u8 = 0x40;
    const TNS_LOB_LOC_FLAGS_INIT: u8 = 0x08;

    // Length prefix (and optional second one-byte length).
    writer.write_ub4(40);
    if write_length {
        writer.write_u8(40);
    }

    // Locator body: 2 + 2 + 1 + 1 + 2 + 2 + 8 + 3 * 2 + 2 * 8 = 40 bytes.
    writer.write_u16be(38);
    writer.write_u16be(TNS_LOB_QLOCATOR_VERSION);
    let locator_flags =
        TNS_LOB_LOC_FLAGS_VALUE_BASED | TNS_LOB_LOC_FLAGS_BLOB | TNS_LOB_LOC_FLAGS_ABSTRACT;
    writer.write_u8(locator_flags);
    writer.write_u8(TNS_LOB_LOC_FLAGS_INIT);
    writer.write_u16be(0);
    writer.write_u16be(1);
    writer.write_u64be(data_length);

    // Trailing zero fields: three u16 followed by two u64.
    for _ in 0..3 {
        writer.write_u16be(0);
    }
    for _ in 0..2 {
        writer.write_u64be(0);
    }
}
// Unit tests for the VECTOR image codec: round-trips, exact wire bytes,
// rejection of malformed or oversized input, and a golden-capture comparison.
#[cfg(test)]
mod tests {
use super::*;
use serde_json::Value;
// Encodes `vector`, decodes the result with default limits and asserts the
// decoded value equals the input.
fn roundtrip(vector: Vector) {
let image = encode_vector(&vector);
let decoded = decode_vector(&image).expect("decode");
assert_eq!(decoded, vector);
}
// Sizable but valid inputs (4096 f32, 2048 f64, 1000 sparse entries over
// 100_000 dimensions) must pass the default limits and round-trip.
#[test]
fn legitimate_large_vector_still_decodes_fully() {
let big_f32: Vec<f32> = (0..4096).map(|i| i as f32 * 0.5 - 1024.0).collect();
roundtrip(Vector::Dense(VectorValues::Float32(big_f32)));
let big_f64: Vec<f64> = (0..2048).map(|i| i as f64 * 0.25).collect();
roundtrip(Vector::Dense(VectorValues::Float64(big_f64)));
roundtrip(Vector::Sparse {
num_dimensions: 100_000,
indices: (0..1000).map(|i| i * 7).collect(),
values: VectorValues::Float32((0..1000).map(|i| i as f32).collect()),
});
}
// One round-trip per dense element format, including the binary format.
#[test]
fn roundtrips_every_dense_format() {
roundtrip(Vector::Dense(VectorValues::Float32(vec![
1.5, -2.25, 3.0, 0.0,
])));
roundtrip(Vector::Dense(VectorValues::Float64(vec![
6501.0, 25.25, 18.125, -3.5,
])));
roundtrip(Vector::Dense(VectorValues::Int8(vec![
-5, 1, -2, 127, -128,
])));
roundtrip(Vector::Dense(VectorValues::Binary(vec![0xA5, 0x3C])));
}
// One round-trip per sparse element format exercised here (f64, f32, i8).
#[test]
fn roundtrips_every_sparse_format() {
roundtrip(Vector::Sparse {
num_dimensions: 8,
indices: vec![1, 4, 6],
values: VectorValues::Float64(vec![1.5, -2.0, 9.25]),
});
roundtrip(Vector::Sparse {
num_dimensions: 6,
indices: vec![0, 3],
values: VectorValues::Float32(vec![2.5, -7.0]),
});
roundtrip(Vector::Sparse {
num_dimensions: 5,
indices: vec![2],
values: VectorValues::Int8(vec![42]),
});
}
// The sparse entry count is a u16 on the wire; exactly u16::MAX entries is
// the largest count that must still encode and decode.
#[test]
fn sparse_int8_roundtrips_max_u16_count() {
let indices = (0..u16::MAX).map(u32::from).collect::<Vec<_>>();
let values = VectorValues::Int8((0..u16::MAX).map(|i| (i % 127) as i8).collect::<Vec<_>>());
let vector = Vector::Sparse {
num_dimensions: u32::from(u16::MAX),
indices,
values,
};
let image = encode_vector_checked(&vector).expect("encode max u16 sparse vector");
let decoded = decode_vector(&image).expect("decode max u16 sparse vector");
assert_eq!(decoded, vector);
}
// One entry past u16::MAX cannot be represented and must be reported as
// InvalidPacketLength carrying the offending count.
#[test]
fn sparse_int8_rejects_count_that_exceeds_wire_field() {
let count = usize::from(u16::MAX) + 1;
let vector = Vector::Sparse {
num_dimensions: count as u32,
indices: (0..count as u32).collect(),
values: VectorValues::Int8(vec![1; count]),
};
let err = encode_vector_checked(&vector).expect_err("oversized sparse count must fail");
assert!(
matches!(
err,
ProtocolError::InvalidPacketLength {
length,
minimum: 0
} if length == count
),
"got {err:?}"
);
}
// Three indices with two values is rejected by the checked encoder.
#[test]
fn sparse_encode_rejects_mismatched_index_value_counts() {
let vector = Vector::Sparse {
num_dimensions: 4,
indices: vec![0, 1, 2],
values: VectorValues::Int8(vec![7, 8]),
};
let err = encode_vector_checked(&vector).expect_err("mismatched sparse vector must fail");
assert!(matches!(err, ProtocolError::TtcDecode(_)), "got {err:?}");
}
// Pins the exact bytes of the float transform. The payload starts at offset
// 17: magic (1) + version (1) + flags (2) + format (1) + count (4) + 8 bytes.
#[test]
fn float_elements_use_oracle_binary_transform() {
let image = encode_vector(&Vector::Dense(VectorValues::Float64(vec![1.0, -2.0])));
let body = &image[17..]; assert_eq!(&body[0..8], &[0xbf, 0xf0, 0, 0, 0, 0, 0, 0], "f64 +1.0");
assert_eq!(
&body[8..16],
&[0x3f, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff],
"f64 -2.0"
);
let image32 = encode_vector(&Vector::Dense(VectorValues::Float32(vec![1.0, -2.0])));
let body32 = &image32[17..];
assert_eq!(&body32[0..4], &[0xbf, 0x80, 0, 0], "f32 +1.0");
assert_eq!(&body32[4..8], &[0x3f, 0xff, 0xff, 0xff], "f32 -2.0");
assert_eq!(
decode_vector(&image).expect("decode f64"),
Vector::Dense(VectorValues::Float64(vec![1.0, -2.0]))
);
assert_eq!(
decode_vector(&image32).expect("decode f32"),
Vector::Dense(VectorValues::Float32(vec![1.0, -2.0]))
);
}
// A first byte other than TNS_VECTOR_MAGIC_BYTE is a decode error.
#[test]
fn rejects_bad_magic() {
let err = decode_vector(&[0x00, 0, 0, 0, 0, 0, 0, 0, 0]).expect_err("bad magic must fail");
assert!(matches!(err, ProtocolError::TtcDecode(_)));
}
// Overwrites the version byte (offset 1) with a value above the supported
// maximum and expects a decode error.
#[test]
fn rejects_unsupported_version() {
let mut image = encode_vector(&Vector::Dense(VectorValues::Int8(vec![1])));
image[1] = 99; let err = decode_vector(&image).expect_err("bad version must fail");
assert!(matches!(err, ProtocolError::TtcDecode(_)));
}
// Fuzz input whose header claims 0x36000000 float64 elements (count bytes
// 54, 0, 0, 0) with almost no payload; decoding must return an error.
#[test]
fn fuzz_regression_oom_oversized_element_count() {
let input = [219, 0, 0, 18, 3, 54, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0];
let err = decode_vector(&input).expect_err("oversized count must fail closed");
assert!(
matches!(
err,
ProtocolError::TtcDecode(_) | ProtocolError::ResourceLimit { .. }
),
"got {err:?}"
);
}
// A caller-supplied limit of 4 dimensions must reject a 5-element vector
// with a ResourceLimit error naming "vector_dimensions".
#[test]
fn decode_vector_with_limits_rejects_dense_dimensions() {
let image = encode_vector(&Vector::Dense(VectorValues::Int8(vec![1, 2, 3, 4, 5])));
let limits = ProtocolLimits {
max_vector_dimensions: 4,
..ProtocolLimits::DEFAULT
};
assert!(matches!(
decode_vector_with_limits(&image, limits),
Err(ProtocolError::ResourceLimit {
limit: "vector_dimensions",
observed: 5,
maximum: 4,
})
));
}
// Hand-built sparse header: flags 0x0020 (sparse only, so no 8-byte skip),
// float64 format, zero dimensions, then a sparse count of 0xFFFF with no
// index data behind it. Decoding must fail with a decode error.
#[test]
fn sparse_oversized_index_count_fails_closed_not_oom() {
let input = [
TNS_VECTOR_MAGIC_BYTE,
TNS_VECTOR_VERSION_WITH_SPARSE,
0x00,
0x20, VECTOR_FORMAT_FLOAT64,
0x00,
0x00,
0x00,
0x00, 0xFF,
0xFF, ];
let err = decode_vector(&input).expect_err("oversized sparse count must fail closed");
assert!(matches!(err, ProtocolError::TtcDecode(_)), "got {err:?}");
}
// Two binary bytes must be advertised as 16 elements (bits) in header bytes
// 5..=8, with the binary-capable version in byte 1.
#[test]
fn binary_dense_bit_count_header() {
let image = encode_vector(&Vector::Dense(VectorValues::Binary(vec![0xA5, 0x3C])));
let num_elements = u32::from_be_bytes([image[5], image[6], image[7], image[8]]);
assert_eq!(num_elements, 16);
assert_eq!(image[1], TNS_VECTOR_VERSION_WITH_BINARY);
}
// Builds a `Vector` from one golden JSON entry. `typecode` selects the
// element type ("f" f32, "d" f64, "b" i8, "B" binary bytes); an entry whose
// `kind` is "sparse" also supplies `num_dimensions` and `indices`.
fn build_from_golden(entry: &Value) -> Vector {
let typecode = entry["typecode"].as_str().expect("typecode");
let f64_at = |x: &Value| x.as_f64().expect("number");
let i64_at = |x: &Value| x.as_i64().expect("int");
let u64_at = |x: &Value| x.as_u64().expect("uint");
let make_values = |arr: &Value| -> VectorValues {
let v = arr.as_array().expect("array");
match typecode {
"f" => VectorValues::Float32(v.iter().map(|x| f64_at(x) as f32).collect()),
"d" => VectorValues::Float64(v.iter().map(f64_at).collect()),
"b" => VectorValues::Int8(v.iter().map(|x| i64_at(x) as i8).collect()),
"B" => VectorValues::Binary(v.iter().map(|x| u64_at(x) as u8).collect()),
other => panic!("unknown typecode {other}"),
}
};
if entry["kind"] == "sparse" {
Vector::Sparse {
num_dimensions: u64_at(&entry["num_dimensions"]) as u32,
indices: entry["indices"]
.as_array()
.expect("indices array")
.iter()
.map(|x| u64_at(x) as u32)
.collect(),
values: make_values(&entry["values"]),
}
} else {
Vector::Dense(make_values(&entry["values"]))
}
}
// For every entry in the golden JSON file: encoding the described vector
// must reproduce `image_hex` exactly, and decoding `image_hex` must yield the
// described vector.
#[test]
fn matches_golden_capture() {
let raw = include_str!("../tests/golden/vectors.json");
let golden: Value = serde_json::from_str(raw).expect("parse golden json");
let obj = golden.as_object().expect("golden is an object");
assert!(!obj.is_empty(), "golden capture must not be empty");
for (name, entry) in obj {
let expected_hex = entry["image_hex"].as_str().expect("image_hex");
let expected = hex::decode(expected_hex).expect("decode golden hex");
let vector = build_from_golden(entry);
let image = encode_vector(&vector);
assert_eq!(
hex::encode(&image),
expected_hex,
"encode mismatch for golden case {name}"
);
let decoded = decode_vector(&expected).expect("decode golden image");
assert_eq!(decoded, vector, "decode mismatch for golden case {name}");
}
}
}