1#![allow(clippy::expect_used)]
8
9use bytes::{Bytes, BytesMut};
10use tds_protocol::packet::{PacketStatus, PacketType};
11
12use crate::packet_codec::Packet;
13
14#[derive(Debug, Clone)]
16pub struct Message {
17 pub packet_type: PacketType,
19 pub payload: Bytes,
21}
22
23impl Message {
24 #[must_use]
26 pub fn from_packet(packet: Packet) -> Self {
27 Self {
28 packet_type: packet.header.packet_type,
29 payload: packet.payload.freeze(),
30 }
31 }
32
33 #[must_use]
35 pub fn len(&self) -> usize {
36 self.payload.len()
37 }
38
39 #[must_use]
41 pub fn is_empty(&self) -> bool {
42 self.payload.is_empty()
43 }
44}
45
46#[derive(Debug)]
51pub struct MessageAssembler {
52 buffer: BytesMut,
54 packet_type: Option<PacketType>,
56 packet_count: usize,
58}
59
60impl MessageAssembler {
61 #[must_use]
63 pub fn new() -> Self {
64 Self {
65 buffer: BytesMut::new(),
66 packet_type: None,
67 packet_count: 0,
68 }
69 }
70
71 #[must_use]
73 pub fn with_capacity(capacity: usize) -> Self {
74 Self {
75 buffer: BytesMut::with_capacity(capacity),
76 packet_type: None,
77 packet_count: 0,
78 }
79 }
80
81 pub fn push(&mut self, packet: Packet) -> Option<Message> {
86 if self.packet_type.is_none() {
88 self.packet_type = Some(packet.header.packet_type);
89 }
90
91 self.buffer.extend_from_slice(&packet.payload);
93 self.packet_count += 1;
94
95 tracing::trace!(
96 packet_type = ?packet.header.packet_type,
97 packet_count = self.packet_count,
98 buffer_len = self.buffer.len(),
99 is_eom = packet.header.status.contains(PacketStatus::END_OF_MESSAGE),
100 "assembling message"
101 );
102
103 if packet.header.status.contains(PacketStatus::END_OF_MESSAGE) {
105 let message = Message {
106 packet_type: self.packet_type.take().expect("packet_type set above"),
107 payload: self.buffer.split().freeze(),
108 };
109 self.packet_count = 0;
110 Some(message)
111 } else {
112 None
113 }
114 }
115
116 #[must_use]
118 pub fn has_partial(&self) -> bool {
119 self.packet_type.is_some()
120 }
121
122 #[must_use]
124 pub fn packet_count(&self) -> usize {
125 self.packet_count
126 }
127
128 #[must_use]
130 pub fn buffer_len(&self) -> usize {
131 self.buffer.len()
132 }
133
134 pub fn clear(&mut self) {
136 self.buffer.clear();
137 self.packet_type = None;
138 self.packet_count = 0;
139 }
140}
141
142impl Default for MessageAssembler {
143 fn default() -> Self {
144 Self::new()
145 }
146}
147
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
    use super::*;
    use tds_protocol::packet::PacketHeader;

    /// Builds a `TabularResult` packet with `payload`, flagged as the last packet
    /// of its message when `is_eom` is true.
    fn make_packet(is_eom: bool, payload: &[u8]) -> Packet {
        let status = if is_eom {
            PacketStatus::END_OF_MESSAGE
        } else {
            PacketStatus::NORMAL
        };
        let header = PacketHeader::new(PacketType::TabularResult, status, 0);
        Packet::new(header, BytesMut::from(payload))
    }

    #[test]
    fn test_single_packet_message() {
        let mut assembler = MessageAssembler::new();
        let packet = make_packet(true, b"hello");

        let message = assembler.push(packet).expect("should complete message");
        assert_eq!(message.packet_type, PacketType::TabularResult);
        assert_eq!(&message.payload[..], b"hello");
        assert!(!assembler.has_partial());
    }

    #[test]
    fn test_multi_packet_message() {
        let mut assembler = MessageAssembler::new();

        let packet1 = make_packet(false, b"hello ");
        assert!(assembler.push(packet1).is_none());
        assert!(assembler.has_partial());
        assert_eq!(assembler.packet_count(), 1);

        let packet2 = make_packet(false, b"world");
        assert!(assembler.push(packet2).is_none());
        assert_eq!(assembler.packet_count(), 2);

        let packet3 = make_packet(true, b"!");
        let message = assembler.push(packet3).expect("should complete message");

        assert_eq!(message.packet_type, PacketType::TabularResult);
        assert_eq!(&message.payload[..], b"hello world!");
        assert!(!assembler.has_partial());
        assert_eq!(assembler.packet_count(), 0);
    }

    #[test]
    fn test_clear() {
        let mut assembler = MessageAssembler::new();

        let packet = make_packet(false, b"partial");
        assembler.push(packet);
        assert!(assembler.has_partial());

        assembler.clear();
        assert!(!assembler.has_partial());
        assert_eq!(assembler.buffer_len(), 0);
    }

    #[test]
    fn test_consecutive_messages_are_independent() {
        let mut assembler = MessageAssembler::new();

        let first = assembler
            .push(make_packet(true, b"first"))
            .expect("should complete first message");

        assert!(assembler.push(make_packet(false, b"sec")).is_none());
        let second = assembler
            .push(make_packet(true, b"ond"))
            .expect("should complete second message");

        // Reusing the buffer for the second message must not disturb the first.
        assert_eq!(&first.payload[..], b"first");
        assert_eq!(&second.payload[..], b"second");
        assert_eq!(assembler.buffer_len(), 0);
        assert_eq!(assembler.packet_count(), 0);
        assert!(!assembler.has_partial());
    }

    #[test]
    fn test_empty_end_of_message_packet() {
        let mut assembler = MessageAssembler::new();

        let message = assembler
            .push(make_packet(true, b""))
            .expect("should complete message");
        assert!(message.is_empty());
        assert_eq!(message.len(), 0);
        assert!(!assembler.has_partial());
    }

    #[test]
    fn test_reuse_after_clear() {
        let mut assembler = MessageAssembler::with_capacity(16);

        assembler.push(make_packet(false, b"stale"));
        assembler.clear();
        assert_eq!(assembler.packet_count(), 0);

        let message = assembler
            .push(make_packet(true, b"fresh"))
            .expect("should complete message");
        assert_eq!(&message.payload[..], b"fresh");
    }

    #[test]
    fn test_message_from_packet() {
        let message = Message::from_packet(make_packet(false, b"abc"));
        assert_eq!(message.packet_type, PacketType::TabularResult);
        assert_eq!(message.len(), 3);
        assert!(!message.is_empty());
    }
}