#![allow(clippy::expect_used)]
use bytes::{Bytes, BytesMut};
use tds_protocol::packet::{PacketStatus, PacketType};
use crate::packet_codec::Packet;
/// A complete TDS message: a packet type plus the payload bytes of the
/// packet(s) that carried it.
#[derive(Clone, Debug)]
pub struct Message {
    /// Packet type reported by the packet(s) this message was built from.
    pub packet_type: PacketType,
    /// Payload bytes of the message; for multi-packet messages this is the
    /// concatenation of every constituent packet's payload.
    pub payload: Bytes,
}

impl Message {
    /// Builds a message from a single, self-contained packet.
    ///
    /// The packet's `BytesMut` payload is frozen into an immutable [`Bytes`].
    #[must_use]
    pub fn from_packet(packet: Packet) -> Self {
        let kind = packet.header.packet_type;
        let frozen = packet.payload.freeze();
        Self {
            packet_type: kind,
            payload: frozen,
        }
    }

    /// Number of payload bytes in this message.
    #[must_use]
    pub fn len(&self) -> usize {
        self.payload.len()
    }

    /// Returns `true` when the message carries no payload bytes.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.payload.is_empty()
    }
}
/// Reassembles TDS messages that may span several packets.
///
/// Packets are fed in with [`MessageAssembler::push`]; their payloads are
/// accumulated until a packet carrying [`PacketStatus::END_OF_MESSAGE`]
/// arrives, at which point the concatenated bytes are returned as a single
/// [`Message`] and the assembler is reset for the next message.
#[derive(Debug)]
pub struct MessageAssembler {
    /// Concatenated payloads of the packets received so far for the current message.
    buffer: BytesMut,
    /// Type of the first packet of the current message; `None` when no message
    /// is in progress.
    packet_type: Option<PacketType>,
    /// Number of packets pushed into the current, not-yet-complete message.
    packet_count: usize,
}

impl MessageAssembler {
    /// Creates an idle assembler with an empty, unallocated buffer.
    #[must_use]
    pub fn new() -> Self {
        Self {
            buffer: BytesMut::new(),
            packet_type: None,
            packet_count: 0,
        }
    }

    /// Creates an idle assembler whose buffer is preallocated for `capacity` bytes.
    #[must_use]
    pub fn with_capacity(capacity: usize) -> Self {
        Self {
            buffer: BytesMut::with_capacity(capacity),
            packet_type: None,
            packet_count: 0,
        }
    }

    /// Appends `packet` to the message currently being assembled.
    ///
    /// Returns `Some(message)` when `packet` has the END_OF_MESSAGE status bit
    /// set. The message holds every payload pushed since the previous
    /// completed message (or since [`MessageAssembler::clear`]). Otherwise the
    /// payload is buffered and `None` is returned.
    ///
    /// The message's type is taken from its first packet. If a later packet of
    /// the same message reports a different type, a warning is logged and the
    /// first type is kept.
    pub fn push(&mut self, packet: Packet) -> Option<Message> {
        let incoming_type = packet.header.packet_type;
        let is_eom = packet.header.status.contains(PacketStatus::END_OF_MESSAGE);

        // The first packet of a message fixes its type. `get_or_insert` replaces
        // the former `is_none` check plus a later `expect` on `take()`.
        let message_type = *self.packet_type.get_or_insert(incoming_type);
        if message_type != incoming_type {
            // All packets of one message should share a type. A mismatch usually
            // means the stream is out of sync, so surface it instead of merging
            // silently.
            tracing::warn!(
                expected = ?message_type,
                actual = ?incoming_type,
                packets_so_far = self.packet_count,
                "packet type changed mid-message; keeping type of first packet"
            );
        }

        self.buffer.extend_from_slice(&packet.payload);
        self.packet_count += 1;
        tracing::trace!(
            packet_type = ?incoming_type,
            packet_count = self.packet_count,
            buffer_len = self.buffer.len(),
            is_eom,
            "assembling message"
        );

        if !is_eom {
            return None;
        }

        // Reset per-message state. `split` hands the accumulated bytes to the
        // message and leaves `self.buffer` empty, retaining any spare capacity.
        self.packet_type = None;
        self.packet_count = 0;
        Some(Message {
            packet_type: message_type,
            payload: self.buffer.split().freeze(),
        })
    }

    /// Returns `true` while a message has been started but not yet completed.
    #[must_use]
    pub fn has_partial(&self) -> bool {
        self.packet_type.is_some()
    }

    /// Number of packets pushed into the in-progress message (0 when idle).
    #[must_use]
    pub fn packet_count(&self) -> usize {
        self.packet_count
    }

    /// Number of payload bytes buffered for the in-progress message.
    #[must_use]
    pub fn buffer_len(&self) -> usize {
        self.buffer.len()
    }

    /// Discards any partially assembled message and returns to the idle state.
    pub fn clear(&mut self) {
        self.buffer.clear();
        self.packet_type = None;
        self.packet_count = 0;
    }
}
impl Default for MessageAssembler {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
    use super::*;
    use tds_protocol::packet::PacketHeader;

    /// Builds a `TabularResult` packet with the given payload, optionally
    /// flagged as the last packet of its message.
    fn make_packet(is_eom: bool, payload: &[u8]) -> Packet {
        let status = if is_eom {
            PacketStatus::END_OF_MESSAGE
        } else {
            PacketStatus::NORMAL
        };
        let header = PacketHeader::new(PacketType::TabularResult, status, 0);
        Packet::new(header, BytesMut::from(payload))
    }

    #[test]
    fn test_single_packet_message() {
        let mut assembler = MessageAssembler::new();
        let packet = make_packet(true, b"hello");
        let message = assembler.push(packet).expect("should complete message");
        assert_eq!(message.packet_type, PacketType::TabularResult);
        assert_eq!(&message.payload[..], b"hello");
        assert!(!assembler.has_partial());
    }

    #[test]
    fn test_multi_packet_message() {
        let mut assembler = MessageAssembler::new();
        let packet1 = make_packet(false, b"hello ");
        assert!(assembler.push(packet1).is_none());
        assert!(assembler.has_partial());
        assert_eq!(assembler.packet_count(), 1);
        let packet2 = make_packet(false, b"world");
        assert!(assembler.push(packet2).is_none());
        assert_eq!(assembler.packet_count(), 2);
        let packet3 = make_packet(true, b"!");
        let message = assembler.push(packet3).expect("should complete message");
        assert_eq!(message.packet_type, PacketType::TabularResult);
        assert_eq!(&message.payload[..], b"hello world!");
        assert!(!assembler.has_partial());
        assert_eq!(assembler.packet_count(), 0);
    }

    #[test]
    fn test_clear() {
        let mut assembler = MessageAssembler::new();
        let packet = make_packet(false, b"partial");
        assembler.push(packet);
        assert!(assembler.has_partial());
        assembler.clear();
        assert!(!assembler.has_partial());
        assert_eq!(assembler.buffer_len(), 0);
        assert_eq!(assembler.packet_count(), 0);

        // Nothing from the discarded partial message may leak into the next one.
        let message = assembler
            .push(make_packet(true, b"fresh"))
            .expect("should complete message");
        assert_eq!(&message.payload[..], b"fresh");
    }

    #[test]
    fn test_reuse_across_messages() {
        let mut assembler = MessageAssembler::with_capacity(16);
        let first = assembler
            .push(make_packet(true, b"first"))
            .expect("first message should complete");
        assert_eq!(&first.payload[..], b"first");
        assert_eq!(assembler.buffer_len(), 0);

        assert!(assembler.push(make_packet(false, b"sec")).is_none());
        assert_eq!(assembler.buffer_len(), 3);
        let second = assembler
            .push(make_packet(true, b"ond"))
            .expect("second message should complete");
        assert_eq!(&second.payload[..], b"second");

        // Reusing the internal buffer must not disturb an already-emitted message.
        assert_eq!(&first.payload[..], b"first");
        assert!(!assembler.has_partial());
        assert_eq!(assembler.packet_count(), 0);
    }

    #[test]
    fn test_message_from_packet() {
        let message = Message::from_packet(make_packet(true, b"abc"));
        assert_eq!(message.packet_type, PacketType::TabularResult);
        assert_eq!(message.len(), 3);
        assert!(!message.is_empty());

        let empty = Message::from_packet(make_packet(true, b""));
        assert_eq!(empty.len(), 0);
        assert!(empty.is_empty());
    }
}