use crate::buffer::ReadBuffer;
use crate::constants::{accept_flags, nsi_flags, version};
use crate::error::{Error, Result};
use crate::packet::Packet;
/// Fields decoded from an Accept packet by [`AcceptMessage::parse`].
#[derive(Debug)]
pub struct AcceptMessage {
    /// Protocol version reported by the peer; `parse` rejects values below
    /// `version::MIN_ACCEPTED`.
    pub protocol_version: u16,
    /// Raw 16-bit service-options field; bit `0x0400` drives `supports_oob`.
    pub service_options: u16,
    /// SDU size: taken from the 32-bit field when
    /// `protocol_version >= version::MIN_LARGE_SDU`, otherwise from the
    /// 16-bit field near the start of the payload.
    pub sdu: u32,
    /// TDU size, widened from the 16-bit payload field.
    pub tdu: u32,
    /// First flags byte; `parse` fails if `nsi_flags::NA_REQUIRED` is set here.
    pub flags: u8,
    /// Extended accept flags; `0` when `protocol_version < version::MIN_OOB_CHECK`
    /// because the field is not read for those versions.
    pub flags2: u32,
    /// Optional accept data, decoded lossily as UTF-8; `None` when the packet
    /// carries none or it could not be read.
    pub accept_data: Option<String>,
    /// `true` when `flags2` has `accept_flags::FAST_AUTH` set.
    pub supports_fast_auth: bool,
    /// `true` when bit `0x0400` of `service_options` is set.
    pub supports_oob: bool,
    /// `true` when `protocol_version >= version::MIN_END_OF_RESPONSE` and
    /// `flags2` has `accept_flags::HAS_END_OF_RESPONSE` set.
    pub supports_end_of_response: bool,
}
impl AcceptMessage {
    /// Parses the payload of an Accept packet into an [`AcceptMessage`].
    ///
    /// The payload is a fixed sequence of big-endian fields. Fields that only
    /// exist in newer protocol versions (32-bit SDU, extended flags) are read
    /// conditionally on the peer's `protocol_version`.
    ///
    /// # Errors
    ///
    /// - [`Error::UnexpectedPacketType`] if `packet` is not an Accept packet.
    /// - [`Error::ProtocolVersionNotSupported`] if the peer's protocol version is
    ///   below `version::MIN_ACCEPTED`.
    /// - [`Error::NativeNetworkEncryptionRequired`] if the `NA_REQUIRED` NSI flag
    ///   is set.
    /// - Any error from [`ReadBuffer`] when the payload is too short for the
    ///   fixed fields the negotiated version requires.
    pub fn parse(packet: &Packet) -> Result<Self> {
        // Service-option bit that indicates out-of-band support.
        const OOB_SERVICE_OPTION: u16 = 0x0400;

        if !packet.is_accept() {
            return Err(Error::UnexpectedPacketType {
                expected: crate::constants::PacketType::Accept,
                actual: packet.packet_type(),
            });
        }

        let mut buf = ReadBuffer::from_slice(&packet.payload);

        let protocol_version = buf.read_u16_be()?;
        if protocol_version < version::MIN_ACCEPTED {
            return Err(Error::ProtocolVersionNotSupported(
                protocol_version,
                version::MIN_ACCEPTED,
            ));
        }
        let service_options = buf.read_u16_be()?;
        let sdu_16 = u32::from(buf.read_u16_be()?);
        let tdu = u32::from(buf.read_u16_be()?);
        // Two bytes this parser does not use (presumably the hardware
        // byte-order marker — confirm against the protocol spec).
        buf.skip(2)?;
        let data_length = usize::from(buf.read_u16_be()?);
        let data_offset = usize::from(buf.read_u16_be()?);
        let flags = buf.read_u8()?;
        let _flags1 = buf.read_u8()?;
        if (flags & nsi_flags::NA_REQUIRED) != 0 {
            return Err(Error::NativeNetworkEncryptionRequired);
        }
        buf.skip(8)?;

        // Newer versions carry a 32-bit SDU that supersedes the 16-bit one.
        let sdu = if protocol_version >= version::MIN_LARGE_SDU {
            buf.read_u32_be()?
        } else {
            sdu_16
        };
        let flags2 = if protocol_version >= version::MIN_OOB_CHECK {
            buf.skip(5)?;
            buf.read_u32_be()?
        } else {
            0
        };

        let supports_fast_auth = (flags2 & accept_flags::FAST_AUTH) != 0;
        let supports_oob = (service_options & OOB_SERVICE_OPTION) != 0;
        let supports_end_of_response = protocol_version >= version::MIN_END_OF_RESPONSE
            && (flags2 & accept_flags::HAS_END_OF_RESPONSE) != 0;

        let accept_data = Self::extract_accept_data(&packet.payload, data_length, data_offset);

        Ok(Self {
            protocol_version,
            service_options,
            sdu,
            tdu,
            flags,
            flags2,
            accept_data,
            supports_fast_auth,
            supports_oob,
            supports_end_of_response,
        })
    }

    /// Extracts the accept data described by the header's length/offset fields.
    ///
    /// `data_offset` is measured from the start of the whole packet (header
    /// included), while `payload` begins right after the packet header, so the
    /// header size is subtracted before slicing. Returns `None` when there is no
    /// data or when the described range does not lie within `payload`; the bytes
    /// are decoded lossily as UTF-8.
    fn extract_accept_data(
        payload: &[u8],
        data_length: usize,
        data_offset: usize,
    ) -> Option<String> {
        if data_length == 0 {
            return None;
        }
        let start = data_offset.checked_sub(crate::constants::PACKET_HEADER_SIZE)?;
        let end = start.checked_add(data_length)?;
        let bytes = payload.get(start..end)?;
        Some(String::from_utf8_lossy(bytes).into_owned())
    }

    /// Returns `true` when the peer's protocol version is at least
    /// `version::MIN_LARGE_SDU`, i.e. when `sdu` was read from the 32-bit field.
    pub fn uses_large_sdu(&self) -> bool {
        self.protocol_version >= version::MIN_LARGE_SDU
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::constants::PacketType;
    use crate::packet::PacketHeader;
    use bytes::Bytes;

    /// Wraps `payload` in an Accept packet whose header length covers the
    /// packet header plus the payload.
    fn make_accept_packet(payload: &[u8]) -> Packet {
        let total_len = (crate::constants::PACKET_HEADER_SIZE + payload.len()) as u32;
        let header = PacketHeader::new(PacketType::Accept, total_len);
        Packet::new(header, Bytes::copy_from_slice(payload))
    }

    #[test]
    fn test_parse_accept_basic() {
        let payload = [
            0x01, 0x3F, // protocol version (319)
            0x00, 0x01, // service options
            0x20, 0x00, // 16-bit SDU
            0xFF, 0xFF, // 16-bit TDU
            0x00, 0x00, // skipped by the parser
            0x00, 0x00, // accept data length
            0x00, 0x00, // accept data offset
            0x04, // flags
            0x04, // flags1
            0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // 8 skipped bytes
            0x00, 0x00, 0x20, 0x00, // 32-bit SDU
            0x00, 0x00, 0x00, 0x00, 0x00, // 5 skipped bytes
            0x10, 0x00, 0x00, 0x00, // flags2
        ];
        let parsed = AcceptMessage::parse(&make_accept_packet(&payload)).unwrap();

        assert_eq!(parsed.protocol_version, 319);
        assert_eq!(parsed.sdu, 8192);
        assert!(parsed.supports_fast_auth);
        assert!(parsed.uses_large_sdu());
    }

    #[test]
    fn test_parse_accept_old_version() {
        let payload = [
            0x01, 0x3B, // protocol version (315)
            0x00, 0x01, // service options
            0x20, 0x00, // 16-bit SDU
            0xFF, 0xFF, // 16-bit TDU
            0x00, 0x00, // skipped by the parser
            0x00, 0x00, // accept data length
            0x00, 0x00, // accept data offset
            0x04, // flags
            0x04, // flags1
            0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // 8 skipped bytes
            0x00, 0x00, 0x20, 0x00, // 32-bit SDU
        ];
        let parsed = AcceptMessage::parse(&make_accept_packet(&payload)).unwrap();

        assert_eq!(parsed.protocol_version, 315);
        assert_eq!(parsed.sdu, 8192);
        assert!(!parsed.supports_fast_auth);
        assert!(parsed.uses_large_sdu());
    }

    #[test]
    fn test_parse_accept_version_too_old() {
        let payload = [
            0x01, 0x2C, // protocol version (300)
            0x00, 0x01, // service options
            0x20, 0x00, // 16-bit SDU
            0xFF, 0xFF, // 16-bit TDU
            0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // remaining fixed fields
        ];
        let outcome = AcceptMessage::parse(&make_accept_packet(&payload));

        assert!(outcome.is_err());
    }

    #[test]
    fn test_parse_accept_na_required() {
        let payload = [
            0x01, 0x3F, // protocol version (319)
            0x00, 0x01, // service options
            0x20, 0x00, // 16-bit SDU
            0xFF, 0xFF, // 16-bit TDU
            0x00, 0x00, // skipped by the parser
            0x00, 0x00, // accept data length
            0x00, 0x00, // accept data offset
            0x10, // flags (NA_REQUIRED)
            0x00, // flags1
        ];
        let outcome = AcceptMessage::parse(&make_accept_packet(&payload));

        assert!(matches!(outcome, Err(Error::NativeNetworkEncryptionRequired)));
    }
}