use bytes::Bytes;
use crate::buffer::WriteBuffer;
use crate::config::Config;
use crate::constants::{
connection, nsi_flags, service_options, version, PacketType, PACKET_HEADER_SIZE,
};
use crate::error::Result;
use crate::packet::PacketHeader;
/// A connect request that can be serialized into a connect packet.
///
/// Usually created with [`ConnectMessage::from_config`] and serialized with
/// [`ConnectMessage::build`] or [`ConnectMessage::build_with_continuation`].
/// All fields are plain data, so the type is cheaply cloneable and comparable.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ConnectMessage {
    /// Protocol version requested; the first field of the serialized connect body.
    pub version_desired: u16,
    /// Minimum protocol version (presumably the lowest the client will accept).
    pub version_minimum: u16,
    /// Service-option bit flags (see `constants::service_options`).
    pub service_options: u16,
    /// SDU size. Serialized twice: once saturated to 16 bits and once as a full `u32`.
    pub sdu: u32,
    /// TDU size. Serialized twice: once saturated to 16 bits and once as a full `u32`.
    pub tdu: u32,
    /// Protocol-characteristics bit field (see `constants::connection`).
    pub protocol_characteristics: u16,
    /// NSI flags (see `constants::nsi_flags`); written into two consecutive bytes.
    pub nsi_flags: u8,
    /// First 32-bit connect-flags word.
    pub connect_flags_1: u32,
    /// Second 32-bit connect-flags word.
    pub connect_flags_2: u32,
    /// Connect descriptor string sent after the fixed connect body.
    pub connect_data: String,
    /// Whether out-of-band data is supported. Not itself written by `build`.
    pub supports_oob: bool,
}
impl ConnectMessage {
pub fn from_config(config: &Config) -> Self {
let connect_data = config.build_connect_string();
let mut service_opts = service_options::DONT_CARE;
let mut connect_flags_2 = 0u32;
service_opts |= service_options::CAN_RECV_ATTENTION;
connect_flags_2 |= connection::CHECK_OOB;
Self {
version_desired: version::DESIRED,
version_minimum: version::MINIMUM,
service_options: service_opts,
sdu: config.sdu,
tdu: connection::DEFAULT_TDU as u32,
protocol_characteristics: connection::PROTOCOL_CHARACTERISTICS,
nsi_flags: nsi_flags::SUPPORT_SECURITY_RENEG | nsi_flags::DISABLE_NA,
connect_flags_1: 0,
connect_flags_2,
connect_data,
supports_oob: true,
}
}
pub fn build(&self) -> Result<Bytes> {
let connect_data_bytes = self.connect_data.as_bytes();
let connect_data_len = connect_data_bytes.len();
let needs_split = connect_data_len > connection::MAX_CONNECT_DATA as usize;
let mut buf = WriteBuffer::with_capacity(512);
buf.write_zeros(PACKET_HEADER_SIZE)?;
buf.write_u16_be(self.version_desired)?;
buf.write_u16_be(self.version_minimum)?;
buf.write_u16_be(self.service_options)?;
buf.write_u16_be(self.sdu.min(65535) as u16)?;
buf.write_u16_be(self.tdu.min(65535) as u16)?;
buf.write_u16_be(self.protocol_characteristics)?;
buf.write_u16_be(0)?;
buf.write_u16_be(1)?;
buf.write_u16_be(connect_data_len as u16)?;
buf.write_u16_be(74)?;
buf.write_u32_be(0)?;
buf.write_u8(self.nsi_flags)?;
buf.write_u8(self.nsi_flags)?;
buf.write_zeros(24)?;
buf.write_u32_be(self.sdu)?;
buf.write_u32_be(self.tdu)?;
buf.write_u32_be(self.connect_flags_1)?;
buf.write_u32_be(self.connect_flags_2)?;
if !needs_split {
buf.write_bytes(connect_data_bytes)?;
}
let total_len = buf.len() as u32;
let header = if needs_split {
PacketHeader::new(PacketType::Connect, total_len)
} else {
PacketHeader::new(PacketType::Connect, total_len)
};
let mut header_buf = WriteBuffer::with_capacity(PACKET_HEADER_SIZE);
header.write(&mut header_buf, false)?;
let mut result = buf.into_inner();
result[..PACKET_HEADER_SIZE].copy_from_slice(header_buf.as_slice());
Ok(result.freeze())
}
pub fn build_with_continuation(&self) -> Result<(Bytes, Option<Bytes>)> {
let connect_data_bytes = self.connect_data.as_bytes();
let connect_data_len = connect_data_bytes.len();
let needs_split = connect_data_len > connection::MAX_CONNECT_DATA as usize;
if !needs_split {
return Ok((self.build()?, None));
}
let mut connect_buf = WriteBuffer::with_capacity(128);
connect_buf.write_zeros(PACKET_HEADER_SIZE)?;
connect_buf.write_u16_be(self.version_desired)?;
connect_buf.write_u16_be(self.version_minimum)?;
connect_buf.write_u16_be(self.service_options)?;
connect_buf.write_u16_be(self.sdu.min(65535) as u16)?;
connect_buf.write_u16_be(self.tdu.min(65535) as u16)?;
connect_buf.write_u16_be(self.protocol_characteristics)?;
connect_buf.write_u16_be(0)?; connect_buf.write_u16_be(1)?; connect_buf.write_u16_be(connect_data_len as u16)?;
connect_buf.write_u16_be(74)?; connect_buf.write_u32_be(0)?; connect_buf.write_u8(self.nsi_flags)?;
connect_buf.write_u8(self.nsi_flags)?;
connect_buf.write_zeros(24)?; connect_buf.write_u32_be(self.sdu)?;
connect_buf.write_u32_be(self.tdu)?;
connect_buf.write_u32_be(self.connect_flags_1)?;
connect_buf.write_u32_be(self.connect_flags_2)?;
let connect_len = connect_buf.len() as u32;
let header = PacketHeader::new(PacketType::Connect, connect_len);
let mut header_buf = WriteBuffer::with_capacity(PACKET_HEADER_SIZE);
header.write(&mut header_buf, false)?;
let mut connect_result = connect_buf.into_inner();
connect_result[..PACKET_HEADER_SIZE].copy_from_slice(header_buf.as_slice());
let mut data_buf = WriteBuffer::with_capacity(PACKET_HEADER_SIZE + 2 + connect_data_len);
data_buf.write_zeros(PACKET_HEADER_SIZE)?;
data_buf.write_u16_be(0)?;
data_buf.write_bytes(connect_data_bytes)?;
let data_len = data_buf.len() as u32;
let data_header = PacketHeader::new(PacketType::Data, data_len);
let mut data_header_buf = WriteBuffer::with_capacity(PACKET_HEADER_SIZE);
data_header.write(&mut data_header_buf, false)?;
let mut data_result = data_buf.into_inner();
data_result[..PACKET_HEADER_SIZE].copy_from_slice(data_header_buf.as_slice());
Ok((connect_result.freeze(), Some(data_result.freeze())))
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Size of the fixed connect body between the packet header and the connect data.
    const FIXED_BODY_LEN: usize = 66;
    /// Offset of the connect-data length field within the connect body.
    const DATA_LEN_FIELD: usize = 16;
    /// Offset of the connect-data offset field within the connect body.
    const DATA_OFFSET_FIELD: usize = 18;

    /// Reads a big-endian `u16` from `packet` at byte offset `at`.
    fn read_u16_be(packet: &[u8], at: usize) -> u16 {
        u16::from_be_bytes([packet[at], packet[at + 1]])
    }

    #[test]
    fn test_connect_message_from_config() {
        let config = Config::new("localhost", 1521, "FREEPDB1", "user", "pass");
        let msg = ConnectMessage::from_config(&config);
        assert_eq!(msg.version_desired, version::DESIRED);
        assert_eq!(msg.version_minimum, version::MINIMUM);
        assert_eq!(msg.sdu, config.sdu);
        assert!(msg.connect_data.contains("FREEPDB1"));
        assert!(msg.connect_data.contains("localhost"));
    }

    #[test]
    fn test_connect_message_build() {
        let config = Config::new("localhost", 1521, "FREEPDB1", "user", "pass");
        let msg = ConnectMessage::from_config(&config);
        let packet = msg.build().unwrap();
        assert!(packet.len() > PACKET_HEADER_SIZE);
        assert_eq!(packet[4], PacketType::Connect as u8);
        assert_eq!(packet[8], (version::DESIRED >> 8) as u8);
        assert_eq!(packet[9], (version::DESIRED & 0xff) as u8);
        // The body must advertise the full connect-data length and point just past itself.
        assert_eq!(
            read_u16_be(&packet, PACKET_HEADER_SIZE + DATA_LEN_FIELD) as usize,
            msg.connect_data.len()
        );
        assert_eq!(
            read_u16_be(&packet, PACKET_HEADER_SIZE + DATA_OFFSET_FIELD) as usize,
            PACKET_HEADER_SIZE + FIXED_BODY_LEN
        );
    }

    #[test]
    fn test_connect_message_small_data() {
        let config = Config::new("localhost", 1521, "SVC", "u", "p");
        let msg = ConnectMessage::from_config(&config);
        let (connect, data) = msg.build_with_continuation().unwrap();
        assert!(data.is_none());
        // Small connect data is carried inline, immediately after the fixed body.
        assert_eq!(
            connect.len(),
            PACKET_HEADER_SIZE + FIXED_BODY_LEN + msg.connect_data.len()
        );
        assert!(connect.ends_with(msg.connect_data.as_bytes()));
    }

    #[test]
    fn test_connect_message_large_data() {
        let long_service = "A".repeat(300);
        let config = Config::new("localhost", 1521, &long_service, "u", "p");
        let msg = ConnectMessage::from_config(&config);
        let (connect, data) = msg.build_with_continuation().unwrap();
        let data_packet = data.expect("a 300-byte service name must force a continuation packet");

        // The connect packet carries only the fixed body, yet still advertises the data length.
        assert_eq!(connect[4], PacketType::Connect as u8);
        assert_eq!(connect.len(), PACKET_HEADER_SIZE + FIXED_BODY_LEN);
        assert_eq!(
            read_u16_be(&connect, PACKET_HEADER_SIZE + DATA_LEN_FIELD) as usize,
            msg.connect_data.len()
        );
        // `build()` and the first packet of the continuation pair must agree.
        assert_eq!(msg.build().unwrap(), connect);

        // The continuation is a Data packet: zero data flags, then the whole connect data.
        assert!(data_packet.len() > PACKET_HEADER_SIZE + 2);
        assert_eq!(data_packet[4], PacketType::Data as u8);
        assert_eq!(data_packet.len(), PACKET_HEADER_SIZE + 2 + msg.connect_data.len());
        assert_eq!(read_u16_be(&data_packet, PACKET_HEADER_SIZE), 0);
        assert_eq!(&data_packet[PACKET_HEADER_SIZE + 2..], msg.connect_data.as_bytes());
    }
}