use super::deserialize_sync::q_ipc_decode_sync;
use super::serialize::ENCODING;
use super::{Error, Result, K};
use bytes::{BufMut, BytesMut};
use std::convert::TryInto;
use std::io;
use tokio_util::codec::{Decoder, Encoder};
/// Length in bytes of the fixed kdb+ IPC message header.
const HEADER_SIZE: usize = 8;
/// Frames whose total size (header included) exceeds this many bytes are
/// candidates for compression.
const COMPRESSION_THRESHOLD: usize = 2_000;
/// Policy deciding when the encoder compresses outgoing messages.
///
/// In every mode, only frames larger than the compression threshold are
/// considered. Compression is also abandoned if it does not pay off.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum CompressionMode {
    /// Compress large messages, but only on non-local connections (the default).
    #[default]
    Auto,
    /// Compress large messages even on local connections.
    Always,
    /// Never compress.
    Never,
}
/// How strictly the decoder checks incoming message headers.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum ValidationMode {
    /// Reject a compressed flag above 1 or a message type above 2 (the default).
    #[default]
    Strict,
    /// Skip those header checks. A compressed flag other than 1 is then treated
    /// as "uncompressed".
    Lenient,
}
#[derive(Clone, Copy, Debug)]
pub struct MessageHeader {
pub encoding: u8,
pub message_type: u8,
pub compressed: u8,
pub _unused: u8,
pub length: u32,
}
impl MessageHeader {
pub const fn size() -> usize {
HEADER_SIZE
}
pub fn from_bytes(buf: &[u8]) -> Result<Self> {
if buf.len() < HEADER_SIZE {
return Err(Error::InvalidMessageSize);
}
let encoding = buf[0];
let message_type = buf[1];
let compressed = buf[2];
let _unused = buf[3];
let length = match encoding {
0 => u32::from_be_bytes([buf[4], buf[5], buf[6], buf[7]]),
_ => u32::from_le_bytes([buf[4], buf[5], buf[6], buf[7]]),
};
Ok(MessageHeader {
encoding,
message_type,
compressed,
_unused,
length,
})
}
pub fn to_bytes(&self) -> [u8; HEADER_SIZE] {
let mut bytes = [0u8; HEADER_SIZE];
bytes[0] = self.encoding;
bytes[1] = self.message_type;
bytes[2] = self.compressed;
bytes[3] = self._unused;
let length_bytes = match self.encoding {
0 => self.length.to_be_bytes(),
_ => self.length.to_le_bytes(),
};
bytes[4..8].copy_from_slice(&length_bytes);
bytes
}
}
/// Tokio codec that frames, (de)compresses and (de)serializes kdb+ IPC messages.
///
/// This type implements both `Encoder<KdbMessage>` and `Decoder`.
#[derive(Debug, Clone)]
pub struct KdbCodec {
    /// Whether the peer is local. Under `CompressionMode::Auto`, local peers
    /// never receive compressed frames.
    is_local: bool,
    /// When outgoing frames are compressed.
    compression_mode: CompressionMode,
    /// Whether incoming headers are checked for out-of-range flag/type bytes.
    validation_mode: ValidationMode,
    /// List-size limit forwarded to the payload deserializer.
    max_list_size: usize,
    /// Nesting-depth limit forwarded to the payload deserializer.
    max_recursion_depth: usize,
    /// Largest accepted incoming frame (header included); `None` disables the check.
    max_message_size: Option<usize>,
    /// Largest accepted decompressed payload; `None` disables the check.
    max_decompressed_size: Option<usize>,
}
#[bon::bon]
impl KdbCodec {
    /// Creates a codec with the default settings.
    ///
    /// The defaults are `CompressionMode::Auto`, `ValidationMode::Strict`, the
    /// crate-wide list-size and recursion limits, and the crate-wide message-size
    /// and decompressed-size caps.
    ///
    /// `is_local` marks the peer as local, which suppresses compression in `Auto` mode.
    pub fn new(is_local: bool) -> Self {
        Self::with_options(
            is_local,
            CompressionMode::Auto,
            ValidationMode::Strict,
            crate::MAX_LIST_SIZE,
            crate::MAX_RECURSION_DEPTH,
        )
    }

    /// Creates a codec with explicit modes and deserializer limits.
    ///
    /// The message-size and decompressed-size caps are set to the crate defaults.
    /// Use the setters to change or disable them.
    pub fn with_options(
        is_local: bool,
        compression_mode: CompressionMode,
        validation_mode: ValidationMode,
        max_list_size: usize,
        max_recursion_depth: usize,
    ) -> Self {
        Self {
            is_local,
            compression_mode,
            validation_mode,
            max_list_size,
            max_recursion_depth,
            max_message_size: Some(crate::MAX_MESSAGE_SIZE),
            max_decompressed_size: Some(crate::MAX_DECOMPRESSED_SIZE),
        }
    }

    /// Builder-style constructor generated by `bon`.
    ///
    /// Unset members fall back to the defaults declared in the attributes below.
    ///
    /// NOTE(review): `max_message_size` / `max_decompressed_size` have no declared
    /// default. Left unset, they are presumably `None` (no limit), unlike `new`,
    /// which enables both caps. Confirm this difference is intended.
    #[builder]
    pub fn builder(
        #[builder(default = false)] is_local: bool,
        #[builder(default)] compression_mode: CompressionMode,
        #[builder(default)] validation_mode: ValidationMode,
        #[builder(default = crate::MAX_LIST_SIZE)] max_list_size: usize,
        #[builder(default = crate::MAX_RECURSION_DEPTH)] max_recursion_depth: usize,
        max_message_size: Option<usize>,
        max_decompressed_size: Option<usize>,
    ) -> Self {
        Self {
            is_local,
            compression_mode,
            validation_mode,
            max_list_size,
            max_recursion_depth,
            max_message_size,
            max_decompressed_size,
        }
    }

    /// Replaces the compression policy.
    pub fn set_compression_mode(&mut self, mode: CompressionMode) {
        self.compression_mode = mode;
    }

    /// Current compression policy.
    pub fn compression_mode(&self) -> CompressionMode {
        self.compression_mode
    }

    /// Replaces the header validation policy.
    pub fn set_validation_mode(&mut self, mode: ValidationMode) {
        self.validation_mode = mode;
    }

    /// Current header validation policy.
    pub fn validation_mode(&self) -> ValidationMode {
        self.validation_mode
    }

    /// Sets the list-size limit passed to the payload deserializer.
    pub fn set_max_list_size(&mut self, size: usize) {
        self.max_list_size = size;
    }

    /// List-size limit passed to the payload deserializer.
    pub fn max_list_size(&self) -> usize {
        self.max_list_size
    }

    /// Sets the nesting-depth limit passed to the payload deserializer.
    pub fn set_max_recursion_depth(&mut self, depth: usize) {
        self.max_recursion_depth = depth;
    }

    /// Nesting-depth limit passed to the payload deserializer.
    pub fn max_recursion_depth(&self) -> usize {
        self.max_recursion_depth
    }

    /// Sets the largest accepted incoming frame size; `None` disables the check.
    pub fn set_max_message_size(&mut self, size: Option<usize>) {
        self.max_message_size = size;
    }

    /// Largest accepted incoming frame size, if limited.
    pub fn max_message_size(&self) -> Option<usize> {
        self.max_message_size
    }

    /// Sets the largest accepted decompressed payload size; `None` disables the check.
    pub fn set_max_decompressed_size(&mut self, size: Option<usize>) {
        self.max_decompressed_size = size;
    }

    /// Largest accepted decompressed payload size, if limited.
    pub fn max_decompressed_size(&self) -> Option<usize> {
        self.max_decompressed_size
    }
}
/// One kdb+ IPC message: its message-type byte plus the q object it carries.
#[derive(Debug, Clone)]
pub struct KdbMessage {
    /// Message type byte written to or read from the frame header.
    pub message_type: u8,
    /// The q object carried by the message.
    pub payload: K,
}

impl KdbMessage {
    /// Pairs `payload` with the given `message_type`.
    pub fn new(message_type: u8, payload: K) -> Self {
        Self {
            message_type,
            payload,
        }
    }
}
impl Encoder<KdbMessage> for KdbCodec {
type Error = io::Error;
fn encode(&mut self, item: KdbMessage, dst: &mut BytesMut) -> io::Result<()> {
let payload_bytes = item.payload.q_ipc_encode();
let message_length = payload_bytes.len();
let total_length = (HEADER_SIZE + message_length) as u32;
let should_compress = match self.compression_mode {
CompressionMode::Never => false,
CompressionMode::Always => message_length > COMPRESSION_THRESHOLD - HEADER_SIZE,
CompressionMode::Auto => {
message_length > COMPRESSION_THRESHOLD - HEADER_SIZE && !self.is_local
}
};
if should_compress {
let mut raw = Vec::with_capacity(HEADER_SIZE + message_length);
raw.extend_from_slice(&[ENCODING, item.message_type, 0, 0, 0, 0, 0, 0]);
raw.extend_from_slice(&payload_bytes);
match compress_sync(raw) {
(true, compressed) => {
dst.reserve(compressed.len());
dst.put_slice(&compressed);
}
(false, mut uncompressed) => {
let total_length_bytes = match ENCODING {
0 => total_length.to_be_bytes(),
_ => total_length.to_le_bytes(),
};
uncompressed[4..8].copy_from_slice(&total_length_bytes);
dst.reserve(uncompressed.len());
dst.put_slice(&uncompressed);
}
}
} else {
let header = MessageHeader {
encoding: ENCODING,
message_type: item.message_type,
compressed: 0,
_unused: 0,
length: total_length,
};
dst.reserve(total_length as usize);
dst.put_slice(&header.to_bytes());
dst.put_slice(&payload_bytes);
}
Ok(())
}
}
impl Decoder for KdbCodec {
    type Item = KdbMessage;
    type Error = io::Error;

    /// Extracts one complete kdb+ IPC message from the front of `src`.
    ///
    /// Returns `Ok(None)` while fewer bytes than a full frame are buffered. Once the
    /// header is known, it also reserves room for the rest of the frame. A complete
    /// frame is split off `src`, decompressed when its compressed flag is `1`, and
    /// deserialized into a [`KdbMessage`].
    ///
    /// # Errors
    /// Every failure is reported as `io::ErrorKind::InvalidData`:
    /// - an unparsable header;
    /// - in strict mode, a compressed flag above 1 or a message type above 2;
    /// - a declared length below the header size or above `max_message_size`;
    /// - a decompression or deserialization failure.
    fn decode(&mut self, src: &mut BytesMut) -> io::Result<Option<Self::Item>> {
        fn invalid_data(message: String) -> io::Error {
            io::Error::new(io::ErrorKind::InvalidData, message)
        }

        // Not even a full header buffered yet.
        if src.len() < HEADER_SIZE {
            return Ok(None);
        }

        let header = match MessageHeader::from_bytes(&src[..HEADER_SIZE]) {
            Ok(header) => header,
            Err(e) => return Err(invalid_data(format!("Invalid header: {}", e))),
        };

        if matches!(self.validation_mode, ValidationMode::Strict) {
            if header.compressed > 1 {
                return Err(invalid_data(format!(
                    "Invalid compressed flag: {}. Expected 0 (uncompressed) or 1 (compressed)",
                    header.compressed
                )));
            }
            if header.message_type > 2 {
                return Err(invalid_data(format!(
                    "Invalid message type: {}. Expected 0 (async), 1 (sync), or 2 (response)",
                    header.message_type
                )));
            }
        }

        let total_length = header.length as usize;
        if total_length < HEADER_SIZE {
            return Err(invalid_data(format!(
                "Invalid message size: {}. Must be at least {} bytes (header size)",
                total_length, HEADER_SIZE
            )));
        }
        if let Some(max_size) = self.max_message_size.filter(|&max| total_length > max) {
            return Err(invalid_data(format!(
                "Message size {} exceeds maximum allowed size of {} bytes",
                total_length, max_size
            )));
        }

        // Frame incomplete: make room for the remainder and wait.
        let buffered = src.len();
        if buffered < total_length {
            src.reserve(total_length - buffered);
            return Ok(None);
        }

        let frame = src.split_to(total_length);
        let body = &frame[HEADER_SIZE..];
        let raw_payload = if header.compressed != 1 {
            body.to_vec()
        } else {
            decompress_sync(body.to_vec(), header.encoding, self.max_decompressed_size)
                .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?
        };
        let payload = q_ipc_decode_sync(
            &raw_payload,
            header.encoding,
            self.max_list_size,
            self.max_recursion_depth,
        )
        .map_err(|e| invalid_data(e.to_string()))?;

        Ok(Some(KdbMessage {
            message_type: header.message_type,
            payload,
        }))
    }
}
pub fn io_error_to_kdb_error(err: io::Error) -> Error {
Error::NetworkError(err.to_string())
}
pub fn compress_sync(raw: Vec<u8>) -> (bool, Vec<u8>) {
let mut i = 0_u8;
let mut f = 0_u8;
let mut h0 = 0_usize;
let mut h = 0_usize;
let mut g: bool;
let mut compressed: Vec<u8> = Vec::with_capacity((raw.len()) / 2);
compressed.resize((raw.len()) / 2, 0_u8);
let mut c = 12;
let mut d = c;
let e = compressed.len();
let mut p = 0_usize;
let mut q: usize;
let mut r: usize;
let mut s0 = 0_usize;
let mut s = 8_usize;
let t = raw.len();
let mut a = [0_i32; 256];
compressed[0..4].copy_from_slice(&raw[0..4]);
compressed[2] = 1;
let raw_size = match ENCODING {
0 => (t as u32).to_be_bytes(),
_ => (t as u32).to_le_bytes(),
};
compressed[8..12].copy_from_slice(&raw_size);
while s < t {
if i == 0 {
if d > e - 17 {
return (false, raw);
}
i = 1;
compressed[c] = f;
c = d;
d += 1;
f = 0;
}
g = s > t - 3;
if !g {
h = (raw[s] ^ raw[s + 1]) as usize;
p = a[h] as usize;
g = (0 == p) || (0 != (raw[s] ^ raw[p]));
}
if 0 < s0 {
a[h0] = s0 as i32;
s0 = 0;
}
if g {
h0 = h;
s0 = s;
compressed[d] = raw[s];
d += 1;
s += 1;
} else {
a[h] = s as i32;
f |= i;
p += 2;
s += 2;
r = s;
q = if s + 255 > t { t } else { s + 255 };
while (s < q) && (raw[p] == raw[s]) {
s += 1;
if s < q {
p += 1;
}
}
compressed[d] = h as u8;
d += 1;
compressed[d] = (s - r) as u8;
d += 1;
}
i = i.wrapping_mul(2);
}
compressed[c] = f;
let compressed_size = match ENCODING {
0 => (d as u32).to_be_bytes(),
_ => (d as u32).to_le_bytes(),
};
compressed[4..8].copy_from_slice(&compressed_size);
let _ = compressed.split_off(d);
(true, compressed)
}
/// Decompresses the body of a compressed kdb+ IPC frame.
///
/// `compressed` is expected to be the frame without its 8-byte message header
/// (the `Decoder` in this file passes `&message_data[HEADER_SIZE..]`). It starts
/// with a 4-byte size field holding the uncompressed frame size *including* the
/// 8-byte header. The field is big-endian when `encoding == 0` and little-endian
/// otherwise. The compressed stream follows. The returned buffer holds only the
/// uncompressed payload, i.e. that size minus 8.
///
/// The stream is a sequence of groups: one control byte followed by up to eight
/// items. A clear control bit means a literal byte. A set bit means a
/// back-reference made of two bytes:
/// - a hash byte, which indexes a table of earlier output positions keyed by the
///   XOR of two adjacent output bytes;
/// - an extra-length byte `n`.
///
/// A back-reference copies `2 + n` bytes forward, one byte at a time, from that
/// earlier position. The copy may overlap the bytes being written.
///
/// # Errors
/// Returns `Error::DeserializationError` in these cases:
/// - the size field is missing or below 8;
/// - the decompressed size exceeds `max_decompressed_size`;
/// - the stream is truncated or references data outside the output buffer.
pub fn decompress_sync(
    compressed: Vec<u8>,
    encoding: u8,
    max_decompressed_size: Option<usize>,
) -> Result<Vec<u8>> {
    let mut n = 0; // extra length of the current back-reference
    let mut r: usize; // back-reference source index
    let mut f = 0_usize; // current control byte
    let mut s = 0_usize; // write index into `decompressed`
    let mut p = s; // next output position not yet recorded in the hash table
    let mut i = 0_usize; // current control bit; 0 means "read a new control byte"
    if compressed.len() < 4 {
        return Err(Error::DeserializationError(format!(
            "Invalid compressed data: need at least 4 bytes for size field, got {}",
            compressed.len()
        )));
    }
    let size_with_header = match encoding {
        0 => i32::from_be_bytes(compressed[0..4].try_into().map_err(|_| {
            Error::DeserializationError(
                "Invalid compressed data: header size field must be 4 bytes".to_string(),
            )
        })?),
        _ => i32::from_le_bytes(compressed[0..4].try_into().map_err(|_| {
            Error::DeserializationError(
                "Invalid compressed data: header size field must be 4 bytes".to_string(),
            )
        })?),
    };
    if size_with_header < 8 {
        return Err(Error::DeserializationError(format!(
            "Invalid compressed data: size {} is less than minimum header size",
            size_with_header
        )));
    }
    // The size field counts the 8-byte message header, which is not reproduced here.
    let size = (size_with_header - 8) as usize;
    // Check the declared size before allocating, to guard against compression bombs.
    if let Some(max_size) = max_decompressed_size {
        if size > max_size {
            return Err(Error::DeserializationError(format!(
                "Decompressed size {} exceeds maximum allowed size {} (possible compression bomb)",
                size, max_size
            )));
        }
    }
    let mut decompressed: Vec<u8> = Vec::with_capacity(size);
    decompressed.resize(size, 0_u8);
    let mut d = 4; // read index into `compressed`, just past the size field
    // Hash (XOR of two adjacent output bytes) -> most recent output position.
    let mut aa = [0_i32; 256];
    while s < decompressed.len() {
        if i == 0 {
            // Start of a group: fetch the next control byte.
            if d >= compressed.len() {
                return Err(Error::DeserializationError(
                    "Invalid compressed data: unexpected end of compressed data".to_string(),
                ));
            }
            f = (0xff & compressed[d]) as usize;
            d += 1;
            i = 1;
        }
        if (f & i) != 0 {
            // Back-reference: hash byte, then extra-length byte.
            if d + 2 > compressed.len() {
                return Err(Error::DeserializationError(
                    "Invalid compressed data: insufficient data for back-reference".to_string(),
                ));
            }
            r = aa[(0xff & compressed[d]) as usize] as usize;
            d += 1;
            if r >= decompressed.len() {
                return Err(Error::DeserializationError(format!(
                    "Invalid compressed data: back-reference {} exceeds buffer size {}",
                    r,
                    decompressed.len()
                )));
            }
            if s >= decompressed.len() {
                return Err(Error::DeserializationError(format!(
                    "Invalid compressed data: write index {} exceeds buffer size {}",
                    s,
                    decompressed.len()
                )));
            }
            // First of the two bytes every back-reference copies.
            decompressed[s] = decompressed[r];
            s += 1;
            r += 1;
            if r >= decompressed.len() {
                return Err(Error::DeserializationError(format!(
                    "Invalid compressed data: back-reference {} exceeds buffer size {}",
                    r,
                    decompressed.len()
                )));
            }
            if s >= decompressed.len() {
                return Err(Error::DeserializationError(format!(
                    "Invalid compressed data: write index {} exceeds buffer size {}",
                    s,
                    decompressed.len()
                )));
            }
            // Second guaranteed byte.
            decompressed[s] = decompressed[r];
            s += 1;
            r += 1;
            n = (0xff & compressed[d]) as usize;
            d += 1;
            if r + n > decompressed.len() {
                return Err(Error::DeserializationError(format!(
                    "Invalid compressed data: back-reference range {}..{} exceeds buffer size {}",
                    r,
                    r + n,
                    decompressed.len()
                )));
            }
            if s + n > decompressed.len() {
                return Err(Error::DeserializationError(format!(
                    "Invalid compressed data: write range {}..{} exceeds buffer size {}",
                    s,
                    s + n,
                    decompressed.len()
                )));
            }
            // Byte-by-byte forward copy: source and destination may overlap, and
            // overlapping reads must see bytes written earlier in this same loop
            // (so memmove-style `copy_within` would NOT be equivalent).
            // `s` itself is advanced past these `n` bytes further below.
            for m in 0..n {
                decompressed[s + m] = decompressed[r + m];
            }
        } else {
            // Literal byte.
            if d >= compressed.len() {
                return Err(Error::DeserializationError(
                    "Invalid compressed data: unexpected end of compressed data".to_string(),
                ));
            }
            decompressed[s] = compressed[d];
            s += 1;
            d += 1;
        }
        // Index every newly written position except the last byte, which has no
        // successor to hash with yet.
        while p < s - 1 {
            aa[((0xff & decompressed[p]) ^ (0xff & decompressed[p + 1])) as usize] = p as i32;
            p += 1;
        }
        if (f & i) != 0 {
            // Skip past the extra back-reference bytes; they are not indexed.
            s += n;
            p = s;
        }
        // Advance to the next control bit; after 8 items a new control byte is needed.
        i *= 2;
        if i == 256 {
            i = 0;
        }
    }
    Ok(decompressed)
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{k, qmsg_type};

    /// Reads a 4-byte length field in the byte order selected by `ENCODING`.
    fn read_u32(bytes: &[u8]) -> u32 {
        let field = [bytes[0], bytes[1], bytes[2], bytes[3]];
        match ENCODING {
            0 => u32::from_be_bytes(field),
            _ => u32::from_le_bytes(field),
        }
    }

    /// Encodes a length field in the byte order selected by `ENCODING`.
    fn length_bytes(len: u32) -> [u8; 4] {
        match ENCODING {
            0 => len.to_be_bytes(),
            _ => len.to_le_bytes(),
        }
    }

    #[test]
    fn test_compress_decompress_roundtrip() {
        let large_list = k!(long: vec![1; 3000]);
        let message = KdbMessage::new(1, large_list);
        let mut codec = KdbCodec::new(false);
        let mut buffer = BytesMut::new();
        codec.encode(message, &mut buffer).unwrap();
        assert!(!buffer.is_empty());
        let decoded = codec.decode(&mut buffer).unwrap();
        assert!(decoded.is_some());
        let response = decoded.unwrap();
        assert_eq!(response.message_type, 1);
        let decoded_list = response.payload.as_vec::<i64>().unwrap();
        assert_eq!(decoded_list.len(), 3000);
        assert_eq!(decoded_list[0], 1);
    }

    #[test]
    fn test_small_message_no_compression() {
        let small_list = k!(long: vec![1, 2, 3, 4, 5]);
        let message = KdbMessage::new(1, small_list);
        let mut codec = KdbCodec::new(false);
        let mut buffer = BytesMut::new();
        codec.encode(message, &mut buffer).unwrap();
        assert_eq!(buffer[2], 0);
        let decoded = codec.decode(&mut buffer).unwrap();
        assert!(decoded.is_some());
        let response = decoded.unwrap();
        let decoded_list = response.payload.as_vec::<i64>().unwrap();
        assert_eq!(*decoded_list, vec![1, 2, 3, 4, 5]);
    }

    #[test]
    fn test_local_connection_no_compression() {
        let large_list = k!(long: vec![42; 3000]);
        let message = KdbMessage::new(1, large_list);
        let mut codec = KdbCodec::new(true);
        let mut buffer = BytesMut::new();
        codec.encode(message, &mut buffer).unwrap();
        assert_eq!(buffer[2], 0);
        let decoded = codec.decode(&mut buffer).unwrap();
        assert!(decoded.is_some());
        let response = decoded.unwrap();
        let decoded_list = response.payload.as_vec::<i64>().unwrap();
        assert_eq!(decoded_list.len(), 3000);
        assert_eq!(decoded_list[0], 42);
    }

    #[test]
    fn test_string_query_encoding() {
        let mut codec = KdbCodec::new(true);
        let mut buffer = BytesMut::new();
        let query = k!(string: "1+1");
        let message = KdbMessage::new(qmsg_type::synchronous, query);
        codec.encode(message, &mut buffer).unwrap();
        assert!(buffer.len() > HEADER_SIZE);
        assert_eq!(buffer[1], qmsg_type::synchronous);
        assert_eq!(buffer[2], 0);
    }

    #[test]
    fn test_message_header_roundtrip() {
        let header = MessageHeader {
            encoding: ENCODING,
            message_type: 1,
            compressed: 1,
            _unused: 0,
            length: 1234,
        };
        let bytes = header.to_bytes();
        let parsed = MessageHeader::from_bytes(&bytes).unwrap();
        assert_eq!(parsed.encoding, header.encoding);
        assert_eq!(parsed.message_type, header.message_type);
        assert_eq!(parsed.compressed, header.compressed);
        assert_eq!(parsed.length, header.length);
    }

    #[test]
    fn test_compress_decompress_direct() {
        let payload = vec![42u8; 2000];
        let mut raw = Vec::new();
        raw.extend_from_slice(&[ENCODING, 1, 0, 0, 0, 0, 0, 0]);
        raw.extend_from_slice(&payload);
        let original_size = raw.len();
        let (was_compressed, compressed_data) = compress_sync(raw);
        assert!(was_compressed, "Large repetitive data should compress");
        assert_eq!(compressed_data[0], ENCODING);
        assert_eq!(compressed_data[1], 1);
        assert_eq!(compressed_data[2], 1);
        // Bytes 4..8: compressed frame size; bytes 8..12: original frame size.
        let compressed_size = read_u32(&compressed_data[4..8]);
        assert_eq!(compressed_size as usize, compressed_data.len());
        let uncompressed_size = read_u32(&compressed_data[8..12]);
        assert_eq!(uncompressed_size as usize, original_size);
        let payload_data = &compressed_data[HEADER_SIZE..];
        let decompressed = decompress_sync(payload_data.to_vec(), ENCODING, None)
            .expect("Decompression should succeed");
        assert_eq!(
            decompressed, payload,
            "Decompressed payload should match original"
        );
    }

    #[test]
    fn test_compression_with_large_data() {
        let large_payload = vec![42u8; 3000];
        let mut raw = Vec::new();
        raw.extend_from_slice(&[ENCODING, 1, 0, 0, 0, 0, 0, 0]);
        raw.extend_from_slice(&large_payload);
        let original_size = raw.len();
        let (was_compressed, compressed_data) = compress_sync(raw);
        assert!(was_compressed, "Large repetitive data should compress");
        assert!(
            compressed_data.len() < original_size,
            "Compressed size {} should be less than original size {}",
            compressed_data.len(),
            original_size
        );
        let payload_data = &compressed_data[HEADER_SIZE..];
        let decompressed = decompress_sync(payload_data.to_vec(), ENCODING, None)
            .expect("Decompression should succeed");
        assert_eq!(decompressed, large_payload);
    }

    #[test]
    fn test_codec_with_compression_end_to_end() {
        let large_list = k!(long: vec![123; 2500]);
        let message = KdbMessage::new(qmsg_type::synchronous, large_list);
        let mut codec = KdbCodec::new(false);
        let mut buffer = BytesMut::new();
        codec.encode(message, &mut buffer).unwrap();
        let header = MessageHeader::from_bytes(&buffer[..HEADER_SIZE]).unwrap();
        assert_eq!(header.compressed, 1, "Large message should be compressed");
        let decoded = codec.decode(&mut buffer).unwrap();
        assert!(decoded.is_some());
        let response = decoded.unwrap();
        assert_eq!(response.message_type, qmsg_type::synchronous);
        let decoded_list = response.payload.as_vec::<i64>().unwrap();
        assert_eq!(decoded_list.len(), 2500);
        assert_eq!(decoded_list[0], 123);
        assert_eq!(decoded_list[2499], 123);
    }

    #[test]
    fn test_compression_mode_never() {
        let large_list = k!(long: vec![42; 3000]);
        let message = KdbMessage::new(qmsg_type::synchronous, large_list);
        let mut codec = KdbCodec::builder()
            .is_local(false)
            .compression_mode(CompressionMode::Never)
            .validation_mode(ValidationMode::Strict)
            .build();
        let mut buffer = BytesMut::new();
        codec.encode(message, &mut buffer).unwrap();
        let header = MessageHeader::from_bytes(&buffer[..HEADER_SIZE]).unwrap();
        assert_eq!(header.compressed, 0, "Never mode should not compress");
    }

    #[test]
    fn test_compression_mode_always() {
        let large_list = k!(long: vec![42; 3000]);
        let message = KdbMessage::new(qmsg_type::synchronous, large_list);
        let mut codec = KdbCodec::builder()
            .is_local(true)
            .compression_mode(CompressionMode::Always)
            .validation_mode(ValidationMode::Strict)
            .build();
        let mut buffer = BytesMut::new();
        codec.encode(message, &mut buffer).unwrap();
        let header = MessageHeader::from_bytes(&buffer[..HEADER_SIZE]).unwrap();
        assert_eq!(
            header.compressed, 1,
            "Always mode should compress even on local"
        );
    }

    #[test]
    fn test_compression_mode_auto_local() {
        let large_list = k!(long: vec![42; 3000]);
        let message = KdbMessage::new(qmsg_type::synchronous, large_list);
        let mut codec = KdbCodec::builder()
            .is_local(true)
            .compression_mode(CompressionMode::Auto)
            .validation_mode(ValidationMode::Strict)
            .build();
        let mut buffer = BytesMut::new();
        codec.encode(message, &mut buffer).unwrap();
        let header = MessageHeader::from_bytes(&buffer[..HEADER_SIZE]).unwrap();
        assert_eq!(
            header.compressed, 0,
            "Auto mode should not compress local connections"
        );
    }

    #[test]
    fn test_compression_mode_auto_remote() {
        let large_list = k!(long: vec![42; 3000]);
        let message = KdbMessage::new(qmsg_type::synchronous, large_list);
        let mut codec = KdbCodec::builder()
            .is_local(false)
            .compression_mode(CompressionMode::Auto)
            .validation_mode(ValidationMode::Strict)
            .build();
        let mut buffer = BytesMut::new();
        codec.encode(message, &mut buffer).unwrap();
        let header = MessageHeader::from_bytes(&buffer[..HEADER_SIZE]).unwrap();
        assert_eq!(
            header.compressed, 1,
            "Auto mode should compress remote large messages"
        );
    }

    #[test]
    fn test_validation_mode_strict_invalid_compressed() {
        let mut codec = KdbCodec::builder()
            .is_local(false)
            .compression_mode(CompressionMode::Never)
            .validation_mode(ValidationMode::Strict)
            .build();
        let mut buffer = BytesMut::new();
        // Header claiming a 20-byte frame with an out-of-range compressed flag (2).
        buffer.extend_from_slice(&[ENCODING, 1, 2, 0]);
        buffer.extend_from_slice(&length_bytes(20));
        buffer.extend_from_slice(&[0; 12]);
        let result = codec.decode(&mut buffer);
        assert!(
            result.is_err(),
            "Strict mode should reject invalid compressed flag"
        );
        let err = result.unwrap_err();
        assert!(
            err.to_string().contains("Invalid compressed flag"),
            "Error message should mention compressed flag, got: {}",
            err
        );
    }

    #[test]
    fn test_validation_mode_strict_invalid_message_type() {
        let mut codec = KdbCodec::builder()
            .is_local(false)
            .compression_mode(CompressionMode::Never)
            .validation_mode(ValidationMode::Strict)
            .build();
        let mut buffer = BytesMut::new();
        // Header claiming a 20-byte frame with an out-of-range message type (3).
        buffer.extend_from_slice(&[ENCODING, 3, 0, 0]);
        buffer.extend_from_slice(&length_bytes(20));
        buffer.extend_from_slice(&[0; 12]);
        let result = codec.decode(&mut buffer);
        assert!(
            result.is_err(),
            "Strict mode should reject invalid message type"
        );
        let err = result.unwrap_err();
        assert!(
            err.to_string().contains("Invalid message type"),
            "Error message should mention message type, got: {}",
            err
        );
    }

    #[test]
    fn test_validation_mode_lenient_accepts_invalid() {
        let mut codec = KdbCodec::builder()
            .is_local(false)
            .compression_mode(CompressionMode::Never)
            .validation_mode(ValidationMode::Lenient)
            .build();
        let small_int = k!(int: 42);
        let payload_bytes = small_int.q_ipc_encode();
        let total_length = (HEADER_SIZE + payload_bytes.len()) as u32;
        let mut buffer = BytesMut::new();
        // Message type 5 and compressed flag 3 are both out of range.
        buffer.extend_from_slice(&[ENCODING, 5, 3, 0]);
        buffer.extend_from_slice(&length_bytes(total_length));
        buffer.extend_from_slice(&payload_bytes);
        let result = codec.decode(&mut buffer);
        assert!(
            result.is_ok(),
            "Lenient mode should accept non-standard values"
        );
        assert!(
            result.unwrap().is_some(),
            "Should decode message successfully"
        );
    }

    #[test]
    fn test_codec_getters_setters() {
        let mut codec = KdbCodec::new(false);
        assert_eq!(codec.compression_mode(), CompressionMode::Auto);
        assert_eq!(codec.validation_mode(), ValidationMode::Strict);
        codec.set_compression_mode(CompressionMode::Always);
        codec.set_validation_mode(ValidationMode::Lenient);
        assert_eq!(codec.compression_mode(), CompressionMode::Always);
        assert_eq!(codec.validation_mode(), ValidationMode::Lenient);
    }

    #[test]
    fn test_compression_mode_small_message() {
        let small_int = k!(int: 42);
        let message = KdbMessage::new(qmsg_type::synchronous, small_int);
        let mut codec = KdbCodec::builder()
            .is_local(false)
            .compression_mode(CompressionMode::Always)
            .validation_mode(ValidationMode::Strict)
            .build();
        let mut buffer = BytesMut::new();
        codec.encode(message, &mut buffer).unwrap();
        let header = MessageHeader::from_bytes(&buffer[..HEADER_SIZE]).unwrap();
        assert_eq!(
            header.compressed, 0,
            "Small messages should not be compressed"
        );
    }

    #[test]
    fn test_codec_builder_pattern() {
        let codec = KdbCodec::builder()
            .is_local(false)
            .compression_mode(CompressionMode::Always)
            .validation_mode(ValidationMode::Lenient)
            .build();
        assert_eq!(codec.compression_mode(), CompressionMode::Always);
        assert_eq!(codec.validation_mode(), ValidationMode::Lenient);
    }

    #[test]
    fn test_codec_builder_with_defaults() {
        let codec = KdbCodec::builder().build();
        assert_eq!(codec.compression_mode(), CompressionMode::Auto);
        assert_eq!(codec.validation_mode(), ValidationMode::Strict);
    }

    #[test]
    fn test_codec_builder_partial() {
        let codec = KdbCodec::builder()
            .compression_mode(CompressionMode::Never)
            .build();
        assert_eq!(codec.compression_mode(), CompressionMode::Never);
        assert_eq!(codec.validation_mode(), ValidationMode::Strict);
    }
}