mssql_codec/
message.rs

//! TDS message reassembly.
//!
//! TDS messages can span multiple packets. This module reassembles those
//! packets into complete messages, using the `END_OF_MESSAGE` status flag on
//! the final packet to detect where each message ends.

6// Allow expect() on Option that is guaranteed to be Some based on prior logic
7#![allow(clippy::expect_used)]
8
9use bytes::{Bytes, BytesMut};
10use tds_protocol::packet::{PacketStatus, PacketType};
11
12use crate::packet_codec::Packet;
13
14/// A complete TDS message reassembled from one or more packets.
15#[derive(Debug, Clone)]
16pub struct Message {
17    /// The packet type of this message.
18    pub packet_type: PacketType,
19    /// The complete message payload (all packets combined).
20    pub payload: Bytes,
21}
22
23impl Message {
24    /// Create a new message from a single packet.
25    #[must_use]
26    pub fn from_packet(packet: Packet) -> Self {
27        Self {
28            packet_type: packet.header.packet_type,
29            payload: packet.payload.freeze(),
30        }
31    }
32
33    /// Get the message payload length.
34    #[must_use]
35    pub fn len(&self) -> usize {
36        self.payload.len()
37    }
38
39    /// Check if the message is empty.
40    #[must_use]
41    pub fn is_empty(&self) -> bool {
42        self.payload.is_empty()
43    }
44}
45
46/// Reassembles multiple TDS packets into complete messages.
47///
48/// TDS messages are framed with the `END_OF_MESSAGE` status flag on the final
49/// packet. This assembler buffers packets until a complete message is received.
50#[derive(Debug)]
51pub struct MessageAssembler {
52    /// Buffer for accumulating packet payloads.
53    buffer: BytesMut,
54    /// Packet type of the message being assembled.
55    packet_type: Option<PacketType>,
56    /// Number of packets accumulated.
57    packet_count: usize,
58}
59
60impl MessageAssembler {
61    /// Create a new message assembler.
62    #[must_use]
63    pub fn new() -> Self {
64        Self {
65            buffer: BytesMut::new(),
66            packet_type: None,
67            packet_count: 0,
68        }
69    }
70
71    /// Create a new message assembler with pre-allocated capacity.
72    #[must_use]
73    pub fn with_capacity(capacity: usize) -> Self {
74        Self {
75            buffer: BytesMut::with_capacity(capacity),
76            packet_type: None,
77            packet_count: 0,
78        }
79    }
80
81    /// Push a packet into the assembler.
82    ///
83    /// Returns `Some(Message)` if this packet completes a message,
84    /// `None` if more packets are needed.
85    pub fn push(&mut self, packet: Packet) -> Option<Message> {
86        // Record the packet type from the first packet
87        if self.packet_type.is_none() {
88            self.packet_type = Some(packet.header.packet_type);
89        }
90
91        // Append payload to buffer
92        self.buffer.extend_from_slice(&packet.payload);
93        self.packet_count += 1;
94
95        tracing::trace!(
96            packet_type = ?packet.header.packet_type,
97            packet_count = self.packet_count,
98            buffer_len = self.buffer.len(),
99            is_eom = packet.header.status.contains(PacketStatus::END_OF_MESSAGE),
100            "assembling message"
101        );
102
103        // Check if this is the last packet
104        if packet.header.status.contains(PacketStatus::END_OF_MESSAGE) {
105            let message = Message {
106                packet_type: self.packet_type.take().expect("packet_type set above"),
107                payload: self.buffer.split().freeze(),
108            };
109            self.packet_count = 0;
110            Some(message)
111        } else {
112            None
113        }
114    }
115
116    /// Check if the assembler has partial data buffered.
117    #[must_use]
118    pub fn has_partial(&self) -> bool {
119        self.packet_type.is_some()
120    }
121
122    /// Get the number of packets accumulated so far.
123    #[must_use]
124    pub fn packet_count(&self) -> usize {
125        self.packet_count
126    }
127
128    /// Get the current buffer length.
129    #[must_use]
130    pub fn buffer_len(&self) -> usize {
131        self.buffer.len()
132    }
133
134    /// Clear any partial message data.
135    pub fn clear(&mut self) {
136        self.buffer.clear();
137        self.packet_type = None;
138        self.packet_count = 0;
139    }
140}
141
142impl Default for MessageAssembler {
143    fn default() -> Self {
144        Self::new()
145    }
146}
147
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
    use super::*;
    use tds_protocol::packet::PacketHeader;

    /// Build a `TabularResult` packet carrying `payload`, flagged as the final
    /// packet of its message when `is_eom` is set.
    fn make_packet(is_eom: bool, payload: &[u8]) -> Packet {
        let status = if is_eom {
            PacketStatus::END_OF_MESSAGE
        } else {
            PacketStatus::NORMAL
        };
        let header = PacketHeader::new(PacketType::TabularResult, status, 0);
        Packet::new(header, BytesMut::from(payload))
    }

    #[test]
    fn test_single_packet_message() {
        let mut assembler = MessageAssembler::new();
        let packet = make_packet(true, b"hello");

        let message = assembler.push(packet).expect("should complete message");
        assert_eq!(message.packet_type, PacketType::TabularResult);
        assert_eq!(&message.payload[..], b"hello");
        assert!(!assembler.has_partial());
        assert_eq!(assembler.packet_count(), 0);
        assert_eq!(assembler.buffer_len(), 0);
    }

    #[test]
    fn test_multi_packet_message() {
        let mut assembler = MessageAssembler::new();

        // First packet - not EOM
        let packet1 = make_packet(false, b"hello ");
        assert!(assembler.push(packet1).is_none());
        assert!(assembler.has_partial());
        assert_eq!(assembler.packet_count(), 1);
        assert_eq!(assembler.buffer_len(), 6);

        // Second packet - not EOM
        let packet2 = make_packet(false, b"world");
        assert!(assembler.push(packet2).is_none());
        assert_eq!(assembler.packet_count(), 2);
        assert_eq!(assembler.buffer_len(), 11);

        // Third packet - EOM
        let packet3 = make_packet(true, b"!");
        let message = assembler.push(packet3).expect("should complete message");

        assert_eq!(message.packet_type, PacketType::TabularResult);
        assert_eq!(&message.payload[..], b"hello world!");
        assert!(!assembler.has_partial());
        assert_eq!(assembler.packet_count(), 0);
        assert_eq!(assembler.buffer_len(), 0);
    }

    #[test]
    fn test_clear() {
        let mut assembler = MessageAssembler::new();

        let packet = make_packet(false, b"partial");
        assert!(assembler.push(packet).is_none());
        assert!(assembler.has_partial());

        assembler.clear();
        assert!(!assembler.has_partial());
        assert_eq!(assembler.buffer_len(), 0);
        assert_eq!(assembler.packet_count(), 0);
    }

    #[test]
    fn test_clear_then_reuse_has_no_leftovers() {
        let mut assembler = MessageAssembler::new();
        assert!(assembler.push(make_packet(false, b"stale")).is_none());
        assembler.clear();

        let message = assembler
            .push(make_packet(true, b"fresh"))
            .expect("should complete message");
        assert_eq!(&message.payload[..], b"fresh");
    }

    #[test]
    fn test_assembler_reusable_across_messages() {
        let mut assembler = MessageAssembler::new();

        let first = assembler
            .push(make_packet(true, b"first"))
            .expect("first message should complete");
        assert_eq!(&first.payload[..], b"first");

        assert!(assembler.push(make_packet(false, b"sec")).is_none());
        let second = assembler
            .push(make_packet(true, b"ond"))
            .expect("second message should complete");
        assert_eq!(&second.payload[..], b"second");

        // The first message's bytes must be unaffected by later buffering.
        assert_eq!(&first.payload[..], b"first");
        assert!(!assembler.has_partial());
        assert_eq!(assembler.packet_count(), 0);
    }

    #[test]
    fn test_empty_final_packet() {
        let mut assembler = MessageAssembler::new();
        let message = assembler
            .push(make_packet(true, b""))
            .expect("an empty EOM packet still completes a message");
        assert!(message.is_empty());
        assert_eq!(message.len(), 0);
        assert!(!assembler.has_partial());
    }

    #[test]
    fn test_message_from_packet() {
        let message = Message::from_packet(make_packet(true, b"abc"));
        assert_eq!(message.packet_type, PacketType::TabularResult);
        assert_eq!(&message.payload[..], b"abc");
        assert_eq!(message.len(), 3);
        assert!(!message.is_empty());
    }

    #[test]
    fn test_fresh_assemblers_are_empty() {
        let assemblers = [
            MessageAssembler::new(),
            MessageAssembler::with_capacity(64),
            MessageAssembler::default(),
        ];
        for assembler in &assemblers {
            assert!(!assembler.has_partial());
            assert_eq!(assembler.packet_count(), 0);
            assert_eq!(assembler.buffer_len(), 0);
        }
    }
}