use crate::buffer::{ReadBuffer, WriteBuffer};
use crate::constants::{collection_type, obj_flags, OracleType};
use crate::dbobject::{CollectionType, DbObject, DbObjectType};
use crate::error::{Error, Result};
use crate::row::Value;
use crate::types::{decode_oracle_number, encode_oracle_number};
/// Length-prefix byte signalling that the real length follows as a 4-byte
/// big-endian integer (see `read_length` / `write_length`).
const LONG_LENGTH_INDICATOR: u8 = 254;
/// Length-prefix byte that marks a NULL element value in a collection image.
const NULL_LENGTH_INDICATOR: u8 = 255;
/// Decodes a pickled collection image into a [`DbObject`].
///
/// An empty `data` slice yields an empty collection. Otherwise the image
/// header is consumed and the collection-type byte stored in the image is
/// skipped; the collection type declared on `obj_type` is used instead.
/// Each element is then decoded with `obj_type`'s element type (`Varchar`
/// when unset). PL/SQL index-by tables carry a 4-byte index before every
/// element, which is read and discarded.
///
/// # Errors
/// Propagates header, buffer-underrun and element conversion failures.
pub fn decode_collection(obj_type: &DbObjectType, data: &[u8]) -> Result<DbObject> {
    let mut obj = DbObject::collection(obj_type.full_name());
    if data.is_empty() {
        return Ok(obj);
    }

    let mut reader = ReadBuffer::from(data);
    read_header(&mut reader)?;
    // Collection-type code from the image; `obj_type` is trusted instead.
    reader.read_u8()?;
    let count = read_length(&mut reader)?;

    let elem_type = obj_type.element_type.unwrap_or(OracleType::Varchar);
    let has_index = matches!(
        obj_type.collection_type,
        Some(CollectionType::PlsqlIndexTable)
    );

    for _ in 0..count {
        if has_index {
            // Index-by position; not retained in the decoded object.
            reader.read_u32_be()?;
        }
        obj.elements.push(decode_value(&mut reader, elem_type)?);
    }
    Ok(obj)
}
pub fn encode_collection(obj: &DbObject, obj_type: &DbObjectType) -> Result<Vec<u8>> {
let mut buf = WriteBuffer::new();
let flags = obj_flags::IS_VERSION_81 | obj_flags::IS_COLLECTION;
buf.write_u8(flags)?;
buf.write_u8(obj_flags::IMAGE_VERSION)?;
let length_pos = buf.len();
buf.write_u8(LONG_LENGTH_INDICATOR)?;
buf.write_u32_be(0)?;
buf.write_u8(1)?; buf.write_u8(1)?;
let coll_flags = match obj_type.collection_type {
Some(CollectionType::Varray) => collection_type::VARRAY,
Some(CollectionType::NestedTable) => collection_type::NESTED_TABLE,
Some(CollectionType::PlsqlIndexTable) => collection_type::PLSQL_INDEX_TABLE,
None => collection_type::VARRAY,
};
buf.write_u8(coll_flags)?;
write_length(&mut buf, obj.elements.len())?;
let element_type = obj_type.element_type.unwrap_or(OracleType::Varchar);
let coll_type = obj_type.collection_type.unwrap_or(CollectionType::Varray);
for (idx, value) in obj.elements.iter().enumerate() {
if coll_type == CollectionType::PlsqlIndexTable {
buf.write_u32_be(idx as u32)?;
}
encode_value(&mut buf, value, element_type)?;
}
let total_len = buf.len();
let data = buf.as_ref();
let mut result = data.to_vec();
result[length_pos + 1..length_pos + 5].copy_from_slice(&(total_len as u32).to_be_bytes());
Ok(result)
}
fn read_header(buf: &mut ReadBuffer) -> Result<(u8, u8)> {
let flags = buf.read_u8()?;
let version = buf.read_u8()?;
skip_length(buf)?;
if flags & obj_flags::IS_DEGENERATE != 0 {
return Err(Error::DataConversionError(
"DbObject stored in LOB is not supported".to_string(),
));
}
if buf.remaining() > 0 {
let next_byte = buf.peek_u8()?;
if next_byte == 1 {
if buf.remaining() >= 2 {
buf.skip(1)?; let content_len = next_byte as usize;
buf.skip(content_len)?; }
}
}
Ok((flags, version))
}
/// Reads a pickle length.
///
/// The length is a single byte, unless that byte is `LONG_LENGTH_INDICATOR`;
/// in that case it is the 4-byte big-endian value that follows.
fn read_length(buf: &mut ReadBuffer) -> Result<u32> {
    match buf.read_u8()? {
        LONG_LENGTH_INDICATOR => buf.read_u32_be(),
        short => Ok(u32::from(short)),
    }
}
/// Advances past a pickle length without decoding it.
///
/// Consumes one byte, plus 4 more when that byte is `LONG_LENGTH_INDICATOR`.
fn skip_length(buf: &mut ReadBuffer) -> Result<()> {
    let indicator = buf.read_u8()?;
    if indicator != LONG_LENGTH_INDICATOR {
        return Ok(());
    }
    buf.skip(4)?;
    Ok(())
}
/// Writes `len` as a pickle length.
///
/// Lengths up to `obj_flags::MAX_SHORT_LENGTH` are written as a single byte.
/// Longer ones are written as `LONG_LENGTH_INDICATOR` followed by the value
/// as a 4-byte big-endian integer.
fn write_length(buf: &mut WriteBuffer, len: usize) -> Result<()> {
    let fits_short = len <= obj_flags::MAX_SHORT_LENGTH as usize;
    if !fits_short {
        buf.write_u8(LONG_LENGTH_INDICATOR)?;
        buf.write_u32_be(len as u32)?;
        return Ok(());
    }
    buf.write_u8(len as u8)?;
    Ok(())
}
/// Decodes one length-prefixed collection element of type `oracle_type`.
///
/// Any of the following yields `Value::Null`:
/// - a leading `obj_flags::ATOMIC_NULL` byte;
/// - a leading `NULL_LENGTH_INDICATOR` byte;
/// - a zero length.
///
/// Otherwise every arm consumes the full declared payload length, even when
/// the payload has an unexpected size. This keeps `buf` positioned at the
/// next element. Types without a dedicated decoder are returned as raw
/// `Value::Bytes`.
///
/// # Errors
/// Propagates buffer underruns and conversion failures: invalid UTF-8, or
/// malformed NUMBER, DATE or TIMESTAMP payloads.
fn decode_value(buf: &mut ReadBuffer, oracle_type: OracleType) -> Result<Value> {
    let first_byte = buf.read_u8()?;
    if first_byte == obj_flags::ATOMIC_NULL || first_byte == NULL_LENGTH_INDICATOR {
        return Ok(Value::Null);
    }
    let len = if first_byte == LONG_LENGTH_INDICATOR {
        buf.read_u32_be()? as usize
    } else {
        first_byte as usize
    };
    if len == 0 {
        return Ok(Value::Null);
    }
    match oracle_type {
        OracleType::Varchar | OracleType::Char => {
            let bytes = buf.read_bytes_vec(len)?;
            let s = String::from_utf8(bytes).map_err(|e| {
                Error::DataConversionError(format!("Invalid UTF-8: {}", e))
            })?;
            Ok(Value::String(s))
        }
        OracleType::Number => {
            let bytes = buf.read_bytes_vec(len)?;
            let num = decode_oracle_number(&bytes)?;
            if num.is_integer {
                if let Ok(i) = num.to_i64() {
                    return Ok(Value::Integer(i));
                }
            }
            Ok(Value::Number(num))
        }
        OracleType::BinaryInteger => {
            if len >= 4 {
                // BINARY_INTEGER is a signed 32-bit value. Reinterpret the
                // big-endian bits as i32 so negative numbers survive. This
                // matches `encode_value`, which writes the two's-complement
                // bits.
                let value = buf.read_u32_be()? as i32;
                if len > 4 {
                    buf.skip(len - 4)?;
                }
                Ok(Value::Integer(i64::from(value)))
            } else {
                buf.skip(len)?;
                Ok(Value::Null)
            }
        }
        OracleType::Raw => {
            let bytes = buf.read_bytes_vec(len)?;
            Ok(Value::Bytes(bytes))
        }
        OracleType::BinaryDouble => {
            if len == 8 {
                let bytes = buf.read_bytes_vec(8)?;
                let f = crate::types::decode_binary_double(&bytes);
                Ok(Value::Float(f))
            } else {
                buf.skip(len)?;
                Ok(Value::Null)
            }
        }
        OracleType::BinaryFloat => {
            if len == 4 {
                let bytes = buf.read_bytes_vec(4)?;
                let f = crate::types::decode_binary_float(&bytes);
                Ok(Value::Float(f as f64))
            } else {
                buf.skip(len)?;
                Ok(Value::Null)
            }
        }
        OracleType::Boolean => {
            if len >= 4 {
                let value = buf.read_u32_be()?;
                // Consume any trailing payload bytes; leaving them unread
                // would corrupt the decoding of every following element.
                if len > 4 {
                    buf.skip(len - 4)?;
                }
                Ok(Value::Boolean(value != 0))
            } else {
                let bytes = buf.read_bytes_vec(len)?;
                let b = bytes.last().map(|&v| v != 0).unwrap_or(false);
                Ok(Value::Boolean(b))
            }
        }
        OracleType::Date => {
            let bytes = buf.read_bytes_vec(len)?;
            let date = crate::types::decode_oracle_date(&bytes)?;
            Ok(Value::Date(date))
        }
        OracleType::Timestamp | OracleType::TimestampTz | OracleType::TimestampLtz => {
            let bytes = buf.read_bytes_vec(len)?;
            let ts = crate::types::decode_oracle_timestamp(&bytes)?;
            Ok(Value::Timestamp(ts))
        }
        _ => {
            let bytes = buf.read_bytes_vec(len)?;
            Ok(Value::Bytes(bytes))
        }
    }
}
fn encode_value(buf: &mut WriteBuffer, value: &Value, oracle_type: OracleType) -> Result<()> {
match value {
Value::Null => {
buf.write_u8(NULL_LENGTH_INDICATOR)?;
}
Value::String(s) => {
let bytes = s.as_bytes();
write_length(buf, bytes.len())?;
buf.write_bytes(bytes)?;
}
Value::Integer(n) => {
match oracle_type {
OracleType::BinaryInteger => {
buf.write_u8(4)?;
buf.write_u32_be(*n as u32)?;
}
_ => {
let encoded = encode_oracle_number(&n.to_string())?;
write_length(buf, encoded.len())?;
buf.write_bytes(&encoded)?;
}
}
}
Value::Float(f) => {
match oracle_type {
OracleType::BinaryDouble => {
let encoded = crate::types::encode_binary_double(*f);
buf.write_u8(8)?;
buf.write_bytes(&encoded)?;
}
OracleType::BinaryFloat => {
let encoded = crate::types::encode_binary_float(*f as f32);
buf.write_u8(4)?;
buf.write_bytes(&encoded)?;
}
_ => {
let encoded = encode_oracle_number(&f.to_string())?;
write_length(buf, encoded.len())?;
buf.write_bytes(&encoded)?;
}
}
}
Value::Number(n) => {
let encoded = encode_oracle_number(n.as_str())?;
write_length(buf, encoded.len())?;
buf.write_bytes(&encoded)?;
}
Value::Bytes(b) => {
write_length(buf, b.len())?;
buf.write_bytes(b)?;
}
Value::Boolean(b) => {
buf.write_u8(4)?;
buf.write_u32_be(if *b { 1 } else { 0 })?;
}
Value::Date(d) => {
let bytes = d.to_oracle_bytes();
write_length(buf, bytes.len())?;
buf.write_bytes(&bytes)?;
}
Value::Timestamp(ts) => {
let bytes = ts.to_oracle_bytes();
write_length(buf, bytes.len())?;
buf.write_bytes(&bytes)?;
}
Value::RowId(_) | Value::Lob(_) | Value::Json(_) | Value::Vector(_)
| Value::Cursor(_) | Value::Collection(_) => {
return Err(Error::DataConversionError(
format!("Type {:?} not supported in collections", value),
));
}
}
Ok(())
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a collection object named `full_name` holding `values` in order.
    fn collection_of(full_name: &'static str, values: Vec<Value>) -> DbObject {
        let mut obj = DbObject::collection(full_name);
        for value in values {
            obj.append(value);
        }
        obj
    }

    #[test]
    fn test_length_encoding_short() {
        let mut out = WriteBuffer::new();
        write_length(&mut out, 100).unwrap();
        assert_eq!(out.as_ref(), &[100u8]);
    }

    #[test]
    fn test_length_encoding_long() {
        // 1000 needs the long form: indicator byte, then 1000 as big-endian u32.
        let mut out = WriteBuffer::new();
        write_length(&mut out, 1000).unwrap();
        assert_eq!(out.as_ref(), &[254, 0, 0, 3, 232]);
    }

    #[test]
    fn test_decode_empty_collection() {
        let obj_type = DbObjectType::collection("TEST", "NUM_ARRAY", CollectionType::Varray, OracleType::Number);
        let decoded = decode_collection(&obj_type, &[]).unwrap();
        assert!(decoded.is_collection);
        assert!(decoded.elements.is_empty());
    }

    #[test]
    fn test_encode_decode_roundtrip() {
        let obj_type = DbObjectType::collection("TEST", "NUM_ARRAY", CollectionType::Varray, OracleType::Number);
        let obj = collection_of("TEST.NUM_ARRAY", (1..=3).map(Value::Integer).collect());
        let encoded = encode_collection(&obj, &obj_type).unwrap();
        let decoded = decode_collection(&obj_type, &encoded).unwrap();
        assert_eq!(decoded.elements.len(), 3);
        for (expected, element) in (1..=3).zip(&decoded.elements) {
            assert_eq!(element.as_i64(), Some(expected));
        }
    }

    #[test]
    fn test_encode_decode_strings() {
        let words = ["hello", "world"];
        let obj_type = DbObjectType::collection("TEST", "STR_ARRAY", CollectionType::Varray, OracleType::Varchar);
        let obj = collection_of(
            "TEST.STR_ARRAY",
            words.iter().map(|w| Value::String(w.to_string())).collect(),
        );
        let encoded = encode_collection(&obj, &obj_type).unwrap();
        let decoded = decode_collection(&obj_type, &encoded).unwrap();
        assert_eq!(decoded.elements.len(), words.len());
        for (word, element) in words.iter().zip(&decoded.elements) {
            assert_eq!(element.as_str(), Some(*word));
        }
    }

    #[test]
    fn test_wire_collection_image_flags_must_be_0x88() {
        let obj_type = DbObjectType::collection("TEST", "NUM_ARRAY", CollectionType::Varray, OracleType::Number);
        let obj = collection_of("TEST.NUM_ARRAY", vec![Value::Integer(42)]);
        let encoded = encode_collection(&obj, &obj_type).unwrap();
        assert_eq!(encoded[0], 0x88,
            "Collection image_flags must be 0x88 (IS_VERSION_81 | IS_COLLECTION), not 0x04");
        assert_eq!(encoded[1], 0x01);
    }

    #[test]
    fn test_wire_pickle_header_layout() {
        let obj_type = DbObjectType::collection("TEST", "ARR", CollectionType::Varray, OracleType::Number);
        let obj = collection_of(
            "TEST.ARR",
            vec![Value::Integer(10), Value::Integer(20), Value::Integer(30)],
        );
        let encoded = encode_collection(&obj, &obj_type).unwrap();

        assert_eq!(encoded[0], 0x88, "image_flags");
        assert_eq!(encoded[1], 0x01, "image_version");
        assert_eq!(encoded[2], 0xFE, "length_indicator (TNS_LONG_LENGTH_INDICATOR)");

        let mut len_bytes = [0u8; 4];
        len_bytes.copy_from_slice(&encoded[3..7]);
        let total_len = u32::from_be_bytes(len_bytes);
        assert_eq!(total_len as usize, encoded.len(),
            "Length field must equal total pickle size, not data-after-header");

        assert_eq!(encoded[7], 0x01, "prefix_seg_len");
        assert_eq!(encoded[8], 0x01, "prefix_seg_content");
        assert_eq!(encoded[9], 0x03, "collection_type (VARRAY)");
        assert_eq!(encoded[10], 0x03, "element_count");
    }

    #[test]
    fn test_wire_collection_type_codes() {
        let cases = [
            ("V", "T.V", CollectionType::Varray, 3u8, "VARRAY wire code must be 3"),
            ("N", "T.N", CollectionType::NestedTable, 2u8, "Nested Table wire code must be 2"),
            ("I", "T.I", CollectionType::PlsqlIndexTable, 1u8, "PL/SQL Index Table wire code must be 1"),
        ];
        for &(name, full_name, kind, code, msg) in cases.iter() {
            let obj_type = DbObjectType::collection("T", name, kind, OracleType::Number);
            let obj = collection_of(full_name, vec![Value::Integer(1)]);
            let encoded = encode_collection(&obj, &obj_type).unwrap();
            assert_eq!(encoded[9], code, "{}", msg);
        }
    }
}