use bytes::Bytes;
use crate::buffer::{ReadBuffer, WriteBuffer};
use crate::capabilities::{Capabilities, DRIVER_NAME};
use crate::constants::{MessageType, PacketType, PACKET_HEADER_SIZE};
use crate::error::{Error, Result};
use crate::packet::PacketHeader;
/// State of the protocol-negotiation exchange.
///
/// The fields are filled in from the server's reply by
/// `ProtocolMessage::parse_response`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ProtocolMessage {
    /// Version byte the server reports immediately after the message type.
    pub server_version: u8,
    /// Single flags byte the server sends after the character-set id.
    pub server_flags: u8,
    /// NUL-terminated banner string sent by the server (decoded lossily as UTF-8).
    pub server_banner: Option<String>,
    /// Raw compile-time capability bytes, if the server sent any.
    pub server_compile_caps: Option<Vec<u8>>,
    /// Raw runtime capability bytes, if the server sent any.
    pub server_runtime_caps: Option<Vec<u8>>,
}
impl ProtocolMessage {
    /// Creates an empty message: numeric fields are zero and every optional
    /// server-reported field is `None`.
    pub fn new() -> Self {
        Self {
            server_version: 0,
            server_flags: 0,
            server_banner: None,
            server_compile_caps: None,
            server_runtime_caps: None,
        }
    }

    /// Serializes the client's protocol request as a complete data packet,
    /// header included.
    ///
    /// After the `PACKET_HEADER_SIZE`-byte header the body is: a zero `u16`
    /// (big-endian), the protocol message type, the version byte `6`, a `0`
    /// byte, the bytes of `DRIVER_NAME`, and a trailing NUL.
    ///
    /// `large_sdu` is passed straight through to `PacketHeader::write`.
    ///
    /// # Errors
    /// Propagates any error returned by the underlying buffer writes.
    pub fn build_request(&self, large_sdu: bool) -> Result<Bytes> {
        let mut buf = WriteBuffer::with_capacity(128);
        // Reserve space for the header; it is filled in once the total length is known.
        buf.write_zeros(PACKET_HEADER_SIZE)?;
        // NOTE(review): presumably the data-packet flags field — confirm.
        buf.write_u16_be(0)?;
        buf.write_u8(MessageType::Protocol as u8)?;
        buf.write_u8(6)?; // protocol version sent to the server
        buf.write_u8(0)?; // terminator byte preceding the driver name
        buf.write_bytes(DRIVER_NAME.as_bytes())?;
        buf.write_u8(0)?; // NUL terminator for the driver name

        // The packet is a handful of fixed bytes plus a constant driver name,
        // so this cast cannot truncate in practice.
        let total_len = buf.len() as u32;
        let header = PacketHeader::new(PacketType::Data, total_len);
        let mut header_buf = WriteBuffer::with_capacity(PACKET_HEADER_SIZE);
        header.write(&mut header_buf, large_sdu)?;

        // Overwrite the reserved zero bytes with the real header.
        let mut result = buf.into_inner();
        result[..PACKET_HEADER_SIZE].copy_from_slice(header_buf.as_slice());
        Ok(result.freeze())
    }

    /// Parses the server's protocol response `payload`, storing server
    /// details on `self` and negotiated settings on `caps`.
    ///
    /// Reads, in order:
    /// - two skipped bytes and the message type, which must be `MessageType::Protocol`;
    /// - the server version, one skipped byte, and the NUL-terminated banner;
    /// - `caps.charset_id` (little-endian `u16`) and the server flags byte;
    /// - a little-endian `u16` element count, followed by `5 * count` skipped bytes;
    /// - a big-endian `u16` FDO length and the FDO block, from which
    ///   `caps.ncharset_id` is extracted when the block is long enough;
    /// - the compile-time and runtime capability byte strings, each forwarded
    ///   to the matching `caps.adjust_for_*` method when present.
    ///
    /// # Errors
    /// Returns an error if the message type is wrong or the payload ends
    /// early. Fields set before the failing read keep their new values.
    pub fn parse_response(&mut self, payload: &[u8], caps: &mut Capabilities) -> Result<()> {
        let mut buf = ReadBuffer::from_slice(payload);
        buf.skip(2)?;
        let msg_type = buf.read_u8()?;
        if msg_type != MessageType::Protocol as u8 {
            // NOTE(review): both fields report `PacketType::Data`, so the offending
            // `msg_type` is lost; a dedicated error variant carrying it would be
            // more informative once one exists.
            return Err(Error::UnexpectedPacketType {
                expected: PacketType::Data,
                actual: PacketType::Data,
            });
        }

        self.server_version = buf.read_u8()?;
        buf.skip(1)?;
        self.server_banner = Some(Self::read_null_terminated_string(&mut buf)?);
        caps.charset_id = buf.read_u16_le()?;
        self.server_flags = buf.read_u8()?;

        // Each element is 5 bytes; their contents are not used here.
        let num_elem = buf.read_u16_le()? as usize;
        if num_elem > 0 {
            buf.skip(num_elem * 5)?;
        }

        let fdo_length = buf.read_u16_be()? as usize;
        if fdo_length > 0 {
            let fdo_data = buf.read_bytes_vec(fdo_length)?;
            // A block too short to hold the id leaves `ncharset_id` unchanged.
            if let Some(ncharset_id) = Self::ncharset_id_from_fdo(&fdo_data) {
                caps.ncharset_id = ncharset_id;
            }
        }

        self.server_compile_caps = buf.read_bytes_with_length()?;
        if let Some(ref server_ccaps) = self.server_compile_caps {
            caps.adjust_for_server_compile_caps(server_ccaps);
        }
        self.server_runtime_caps = buf.read_bytes_with_length()?;
        if let Some(ref server_rcaps) = self.server_runtime_caps {
            caps.adjust_for_server_runtime_caps(server_rcaps);
        }
        Ok(())
    }

    /// Extracts the national character-set id from a raw FDO block.
    ///
    /// The id is the big-endian `u16` stored at `offset + 3 ..= offset + 4`,
    /// where `offset = 6 + fdo[5] + fdo[6]`. Returns `None` when any of
    /// those bytes lies outside `fdo`.
    fn ncharset_id_from_fdo(fdo: &[u8]) -> Option<u16> {
        let offset = 6 + usize::from(*fdo.get(5)?) + usize::from(*fdo.get(6)?);
        let hi = *fdo.get(offset + 3)?;
        let lo = *fdo.get(offset + 4)?;
        Some(u16::from_be_bytes([hi, lo]))
    }

    /// Reads bytes up to, and consuming, a NUL terminator. The NUL itself is
    /// dropped and the bytes are decoded as UTF-8, replacing invalid sequences
    /// with U+FFFD.
    ///
    /// # Errors
    /// Propagates the read error if the buffer ends before a NUL is found.
    fn read_null_terminated_string(buf: &mut ReadBuffer) -> Result<String> {
        let mut bytes = Vec::new();
        loop {
            match buf.read_u8()? {
                0 => break,
                b => bytes.push(b),
            }
        }
        // Reuse the collected buffer when it is already valid UTF-8; only
        // invalid input pays for a lossy re-encoding.
        Ok(String::from_utf8(bytes)
            .unwrap_or_else(|err| String::from_utf8_lossy(err.as_bytes()).into_owned()))
    }
}
impl Default for ProtocolMessage {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a protocol response up to and including the element list.
    ///
    /// The layout is: two leading zero bytes, `msg_type`, server version `6`,
    /// one skipped byte, the NUL-terminated `banner`, and the little-endian
    /// `charset_id`. These are followed by a zero flags byte, the little-endian
    /// `num_elem`, and `num_elem` five-byte filler elements.
    fn response_prefix(msg_type: u8, banner: &str, charset_id: u16, num_elem: u16) -> Vec<u8> {
        let mut payload = vec![0x00, 0x00, msg_type, 6, 0];
        payload.extend_from_slice(banner.as_bytes());
        payload.push(0);
        payload.extend_from_slice(&charset_id.to_le_bytes());
        payload.push(0);
        payload.extend_from_slice(&num_elem.to_le_bytes());
        let filler_len = payload.len() + usize::from(num_elem) * 5;
        payload.resize(filler_len, 0xAA);
        payload
    }

    /// Appends an FDO block preceded by its big-endian `u16` length.
    fn push_fdo(payload: &mut Vec<u8>, fdo: &[u8]) {
        let len = u16::try_from(fdo.len()).expect("test FDO block fits in a u16 length");
        payload.extend_from_slice(&len.to_be_bytes());
        payload.extend_from_slice(fdo);
    }

    /// Appends the `255, 255` capability bytes used by the minimal response
    /// (presumably read as "absent" by `read_bytes_with_length`).
    fn push_null_caps(payload: &mut Vec<u8>) {
        payload.extend_from_slice(&[255, 255]);
    }

    #[test]
    fn test_build_protocol_request() {
        let msg = ProtocolMessage::new();
        let packet = msg.build_request(false).unwrap();
        assert!(packet.len() > PACKET_HEADER_SIZE);
        assert_eq!(packet[4], PacketType::Data as u8);
        assert_eq!(packet[PACKET_HEADER_SIZE + 2], MessageType::Protocol as u8);
        assert_eq!(packet[PACKET_HEADER_SIZE + 3], 6);
        assert_eq!(packet[PACKET_HEADER_SIZE + 4], 0);
        // The driver name is NUL-terminated.
        assert_eq!(packet.last(), Some(&0));
        let packet_str = String::from_utf8_lossy(&packet);
        assert!(packet_str.contains("oracle-rs"));
    }

    #[test]
    fn test_parse_protocol_response_minimal() {
        let mut payload =
            response_prefix(MessageType::Protocol as u8, "Oracle Database 19c", 873, 0);
        push_fdo(&mut payload, &[]);
        push_null_caps(&mut payload);
        let mut msg = ProtocolMessage::new();
        let mut caps = Capabilities::new();
        let result = msg.parse_response(&payload, &mut caps);
        assert!(result.is_ok());
        assert_eq!(msg.server_version, 6);
        assert_eq!(msg.server_flags, 0);
        assert_eq!(msg.server_banner.as_deref(), Some("Oracle Database 19c"));
        assert_eq!(caps.charset_id, 873);
    }

    #[test]
    fn test_parse_protocol_response_with_caps() {
        use crate::constants::{ccap_index, ccap_value, rcap_index, rcap_value};
        let mut payload = response_prefix(MessageType::Protocol as u8, "Test", 873, 0);
        push_fdo(&mut payload, &[]);
        let mut compile_caps = vec![0u8; ccap_index::MAX];
        compile_caps[ccap_index::FIELD_VERSION] = ccap_value::FIELD_VERSION_19_1;
        payload.push(compile_caps.len() as u8);
        payload.extend_from_slice(&compile_caps);
        let mut runtime_caps = vec![0u8; rcap_index::MAX];
        runtime_caps[rcap_index::TTC] = rcap_value::TTC_32K;
        payload.push(runtime_caps.len() as u8);
        payload.extend_from_slice(&runtime_caps);
        let mut msg = ProtocolMessage::new();
        let mut caps = Capabilities::new();
        let result = msg.parse_response(&payload, &mut caps);
        assert!(result.is_ok());
        assert_eq!(caps.ttc_field_version, ccap_value::FIELD_VERSION_19_1);
        assert_eq!(caps.max_string_size, 32767);
    }

    #[test]
    fn test_parse_protocol_response_reads_ncharset_from_fdo() {
        // offset = 6 + fdo[5] + fdo[6] = 9, so the id lives at bytes 12..=13.
        let mut fdo = vec![0u8; 14];
        fdo[5] = 2;
        fdo[6] = 1;
        fdo[12] = 0x07;
        fdo[13] = 0xD0;
        // Two filler elements also exercise the element-skipping path.
        let mut payload = response_prefix(MessageType::Protocol as u8, "Test", 873, 2);
        push_fdo(&mut payload, &fdo);
        push_null_caps(&mut payload);
        let mut msg = ProtocolMessage::new();
        let mut caps = Capabilities::new();
        assert!(msg.parse_response(&payload, &mut caps).is_ok());
        assert_eq!(caps.charset_id, 873);
        assert_eq!(caps.ncharset_id, 0x07D0);
    }

    #[test]
    fn test_parse_protocol_response_ignores_truncated_fdo() {
        // offset = 6, so the id would be at bytes 9..=10, past the end of an 8-byte block.
        let fdo = [0u8; 8];
        let mut payload = response_prefix(MessageType::Protocol as u8, "Test", 873, 0);
        push_fdo(&mut payload, &fdo);
        push_null_caps(&mut payload);
        let mut msg = ProtocolMessage::new();
        let mut caps = Capabilities::new();
        let before = caps.ncharset_id;
        assert!(msg.parse_response(&payload, &mut caps).is_ok());
        assert_eq!(caps.ncharset_id, before);
    }

    #[test]
    fn test_parse_protocol_response_rejects_wrong_message_type() {
        let wrong = (MessageType::Protocol as u8).wrapping_add(1);
        let payload = response_prefix(wrong, "Test", 873, 0);
        let mut msg = ProtocolMessage::new();
        let mut caps = Capabilities::new();
        assert!(msg.parse_response(&payload, &mut caps).is_err());
        // The check runs before any server field is stored.
        assert_eq!(msg.server_version, 0);
        assert!(msg.server_banner.is_none());
    }
}