mssql_codec/
connection.rs

1//! Split I/O connection for cancellation safety.
2//!
3//! Per ADR-005, the TCP stream is split into separate read and write halves
4//! to allow sending Attention packets while blocked on reading results.
5
6use std::sync::Arc;
7
8use bytes::{Bytes, BytesMut};
9use futures_util::{SinkExt, StreamExt};
10use tds_protocol::packet::{PACKET_HEADER_SIZE, PacketHeader, PacketStatus, PacketType};
11use tokio::io::{AsyncRead, AsyncWrite, ReadHalf, WriteHalf};
12use tokio::sync::{Mutex, Notify};
13
14use crate::error::CodecError;
15use crate::framed::{PacketReader, PacketWriter};
16use crate::message::{Message, MessageAssembler};
17use crate::packet_codec::{Packet, TdsCodec};
18
19/// A TDS connection with split I/O for cancellation safety.
20///
21/// This struct splits the underlying transport into read and write halves,
22/// allowing Attention packets to be sent even while blocked reading results.
23///
24/// # Cancellation
25///
26/// SQL Server uses out-of-band "Attention" packets to cancel running queries.
27/// Without split I/O, the driver would be unable to send cancellation while
28/// blocked awaiting a read (e.g., processing a large result set).
29///
30/// # Example
31///
32/// ```rust,ignore
33/// use mssql_codec::Connection;
34/// use tokio::net::TcpStream;
35///
36/// let stream = TcpStream::connect("localhost:1433").await?;
37/// let conn = Connection::new(stream);
38///
39/// // Can cancel from another task while reading
40/// let cancel_handle = conn.cancel_handle();
41/// tokio::spawn(async move {
42///     tokio::time::sleep(Duration::from_secs(5)).await;
43///     cancel_handle.cancel().await?;
44/// });
45/// ```
46pub struct Connection<T>
47where
48    T: AsyncRead + AsyncWrite,
49{
50    /// Read half wrapped in a packet reader.
51    reader: PacketReader<ReadHalf<T>>,
52    /// Write half protected by mutex for concurrent cancel access.
53    writer: Arc<Mutex<PacketWriter<WriteHalf<T>>>>,
54    /// Message assembler for multi-packet messages.
55    assembler: MessageAssembler,
56    /// Notification for cancellation completion.
57    cancel_notify: Arc<Notify>,
58    /// Flag indicating cancellation is in progress.
59    cancelling: Arc<std::sync::atomic::AtomicBool>,
60}
61
62impl<T> Connection<T>
63where
64    T: AsyncRead + AsyncWrite,
65{
66    /// Create a new connection from a transport.
67    ///
68    /// The transport is immediately split into read and write halves.
69    pub fn new(transport: T) -> Self {
70        let (read_half, write_half) = tokio::io::split(transport);
71
72        Self {
73            reader: PacketReader::new(read_half),
74            writer: Arc::new(Mutex::new(PacketWriter::new(write_half))),
75            assembler: MessageAssembler::new(),
76            cancel_notify: Arc::new(Notify::new()),
77            cancelling: Arc::new(std::sync::atomic::AtomicBool::new(false)),
78        }
79    }
80
81    /// Create a new connection with custom codecs.
82    pub fn with_codecs(transport: T, read_codec: TdsCodec, write_codec: TdsCodec) -> Self {
83        let (read_half, write_half) = tokio::io::split(transport);
84
85        Self {
86            reader: PacketReader::with_codec(read_half, read_codec),
87            writer: Arc::new(Mutex::new(PacketWriter::with_codec(
88                write_half,
89                write_codec,
90            ))),
91            assembler: MessageAssembler::new(),
92            cancel_notify: Arc::new(Notify::new()),
93            cancelling: Arc::new(std::sync::atomic::AtomicBool::new(false)),
94        }
95    }
96
97    /// Get a handle for cancelling queries on this connection.
98    ///
99    /// The handle can be cloned and sent to other tasks.
100    #[must_use]
101    pub fn cancel_handle(&self) -> CancelHandle<T> {
102        CancelHandle {
103            writer: Arc::clone(&self.writer),
104            notify: Arc::clone(&self.cancel_notify),
105            cancelling: Arc::clone(&self.cancelling),
106        }
107    }
108
109    /// Check if a cancellation is currently in progress.
110    #[must_use]
111    pub fn is_cancelling(&self) -> bool {
112        self.cancelling.load(std::sync::atomic::Ordering::Acquire)
113    }
114
115    /// Read the next complete message from the connection.
116    ///
117    /// This handles multi-packet message reassembly automatically.
118    pub async fn read_message(&mut self) -> Result<Option<Message>, CodecError> {
119        loop {
120            // Check for cancellation
121            if self.is_cancelling() {
122                // Drain until we see DONE with ATTENTION flag
123                return self.drain_after_cancel().await;
124            }
125
126            match self.reader.next().await {
127                Some(Ok(packet)) => {
128                    if let Some(message) = self.assembler.push(packet) {
129                        return Ok(Some(message));
130                    }
131                    // Continue reading packets until message complete
132                }
133                Some(Err(e)) => return Err(e),
134                None => {
135                    // Connection closed
136                    if self.assembler.has_partial() {
137                        return Err(CodecError::ConnectionClosed);
138                    }
139                    return Ok(None);
140                }
141            }
142        }
143    }
144
145    /// Read a single packet from the connection.
146    ///
147    /// This is lower-level than `read_message` and doesn't perform reassembly.
148    pub async fn read_packet(&mut self) -> Result<Option<Packet>, CodecError> {
149        match self.reader.next().await {
150            Some(result) => result.map(Some),
151            None => Ok(None),
152        }
153    }
154
155    /// Send a packet on the connection.
156    pub async fn send_packet(&mut self, packet: Packet) -> Result<(), CodecError> {
157        let mut writer = self.writer.lock().await;
158        writer.send(packet).await
159    }
160
161    /// Send a complete message, splitting into multiple packets if needed.
162    pub async fn send_message(
163        &mut self,
164        packet_type: PacketType,
165        payload: Bytes,
166        max_packet_size: usize,
167    ) -> Result<(), CodecError> {
168        let max_payload = max_packet_size - PACKET_HEADER_SIZE;
169        let chunks: Vec<_> = payload.chunks(max_payload).collect();
170        let total_chunks = chunks.len();
171
172        let mut writer = self.writer.lock().await;
173
174        for (i, chunk) in chunks.into_iter().enumerate() {
175            let is_last = i == total_chunks - 1;
176            let status = if is_last {
177                PacketStatus::END_OF_MESSAGE
178            } else {
179                PacketStatus::NORMAL
180            };
181
182            let header = PacketHeader::new(packet_type, status, 0);
183            let packet = Packet::new(header, BytesMut::from(chunk));
184
185            writer.send(packet).await?;
186        }
187
188        Ok(())
189    }
190
191    /// Flush the write buffer.
192    pub async fn flush(&mut self) -> Result<(), CodecError> {
193        let mut writer = self.writer.lock().await;
194        writer.flush().await
195    }
196
197    /// Drain packets after cancellation until DONE with ATTENTION is received.
198    async fn drain_after_cancel(&mut self) -> Result<Option<Message>, CodecError> {
199        tracing::debug!("draining packets after cancellation");
200
201        // Clear any partial message
202        self.assembler.clear();
203
204        loop {
205            match self.reader.next().await {
206                Some(Ok(packet)) => {
207                    // Check for DONE token with ATTENTION flag
208                    // The DONE token is at the start of the payload
209                    if packet.header.packet_type == PacketType::TabularResult
210                        && !packet.payload.is_empty()
211                    {
212                        // TokenType::Done = 0xFD
213                        // Check if this packet contains a Done token
214                        // and the status has ATTN flag (0x0020)
215                        if self.check_attention_done(&packet) {
216                            tracing::debug!("received DONE with ATTENTION, cancellation complete");
217                            self.cancelling
218                                .store(false, std::sync::atomic::Ordering::Release);
219                            self.cancel_notify.notify_waiters();
220                            return Ok(None);
221                        }
222                    }
223                    // Continue draining
224                }
225                Some(Err(e)) => {
226                    self.cancelling
227                        .store(false, std::sync::atomic::Ordering::Release);
228                    return Err(e);
229                }
230                None => {
231                    self.cancelling
232                        .store(false, std::sync::atomic::Ordering::Release);
233                    return Ok(None);
234                }
235            }
236        }
237    }
238
239    /// Check if a packet contains a DONE token with ATTENTION flag.
240    fn check_attention_done(&self, packet: &Packet) -> bool {
241        // Look for DONE token (0xFD) with ATTN status flag (bit 5)
242        // DONE token format: token_type(1) + status(2) + cur_cmd(2) + row_count(8)
243        let payload = &packet.payload;
244
245        for i in 0..payload.len() {
246            if payload[i] == 0xFD && i + 3 <= payload.len() {
247                // Found DONE token, check status
248                let status = u16::from_le_bytes([payload[i + 1], payload[i + 2]]);
249                // DONE_ATTN = 0x0020
250                if status & 0x0020 != 0 {
251                    return true;
252                }
253            }
254        }
255
256        false
257    }
258
259    /// Get a reference to the read codec.
260    pub fn read_codec(&self) -> &TdsCodec {
261        self.reader.codec()
262    }
263
264    /// Get a mutable reference to the read codec.
265    pub fn read_codec_mut(&mut self) -> &mut TdsCodec {
266        self.reader.codec_mut()
267    }
268}
269
270impl<T> std::fmt::Debug for Connection<T>
271where
272    T: AsyncRead + AsyncWrite + std::fmt::Debug,
273{
274    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
275        f.debug_struct("Connection")
276            .field("cancelling", &self.is_cancelling())
277            .field("has_partial_message", &self.assembler.has_partial())
278            .finish_non_exhaustive()
279    }
280}
281
282/// Handle for cancelling queries on a connection.
283///
284/// This can be cloned and sent to other tasks to enable cancellation
285/// from a different async context.
286pub struct CancelHandle<T>
287where
288    T: AsyncRead + AsyncWrite,
289{
290    writer: Arc<Mutex<PacketWriter<WriteHalf<T>>>>,
291    notify: Arc<Notify>,
292    cancelling: Arc<std::sync::atomic::AtomicBool>,
293}
294
295impl<T> CancelHandle<T>
296where
297    T: AsyncRead + AsyncWrite + Unpin,
298{
299    /// Send an Attention packet to cancel the current query.
300    ///
301    /// This can be called from a different task while the main task
302    /// is blocked reading results.
303    pub async fn cancel(&self) -> Result<(), CodecError> {
304        // Mark cancellation in progress
305        self.cancelling
306            .store(true, std::sync::atomic::Ordering::Release);
307
308        tracing::debug!("sending Attention packet for query cancellation");
309
310        // Send the Attention packet
311        let mut writer = self.writer.lock().await;
312
313        // Create and send attention packet
314        let header = PacketHeader::new(
315            PacketType::Attention,
316            PacketStatus::END_OF_MESSAGE,
317            PACKET_HEADER_SIZE as u16,
318        );
319        let packet = Packet::new(header, BytesMut::new());
320
321        writer.send(packet).await?;
322        writer.flush().await?;
323
324        Ok(())
325    }
326
327    /// Wait for the cancellation to complete.
328    ///
329    /// This waits until the server acknowledges the cancellation
330    /// with a DONE token containing the ATTENTION flag.
331    pub async fn wait_cancelled(&self) {
332        if self.cancelling.load(std::sync::atomic::Ordering::Acquire) {
333            self.notify.notified().await;
334        }
335    }
336
337    /// Check if a cancellation is currently in progress.
338    #[must_use]
339    pub fn is_cancelling(&self) -> bool {
340        self.cancelling.load(std::sync::atomic::Ordering::Acquire)
341    }
342}
343
344impl<T> Clone for CancelHandle<T>
345where
346    T: AsyncRead + AsyncWrite,
347{
348    fn clone(&self) -> Self {
349        Self {
350            writer: Arc::clone(&self.writer),
351            notify: Arc::clone(&self.notify),
352            cancelling: Arc::clone(&self.cancelling),
353        }
354    }
355}
356
357impl<T> std::fmt::Debug for CancelHandle<T>
358where
359    T: AsyncRead + AsyncWrite + Unpin,
360{
361    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
362        f.debug_struct("CancelHandle")
363            .field("cancelling", &self.is_cancelling())
364            .finish_non_exhaustive()
365    }
366}
367
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    /// Build a 13-byte DONE token: 0xFD, little-endian `status`, zeroed cur_cmd and row_count.
    fn done_token(status: u16) -> BytesMut {
        let mut token = [0u8; 13];
        token[0] = 0xFD;
        token[1..3].copy_from_slice(&status.to_le_bytes());
        BytesMut::from(&token[..])
    }

    /// Wrap `payload` in a single-packet TabularResult message.
    fn tabular_packet(payload: BytesMut) -> Packet {
        let header = PacketHeader::new(PacketType::TabularResult, PacketStatus::END_OF_MESSAGE, 0);
        Packet::new(header, payload)
    }

    #[test]
    fn test_attention_packet_header() {
        // An Attention packet is a bare header flagged as the end of its message.
        let expected_len = PACKET_HEADER_SIZE as u16;
        let header = PacketHeader::new(
            PacketType::Attention,
            PacketStatus::END_OF_MESSAGE,
            expected_len,
        );

        assert_eq!(header.packet_type, PacketType::Attention);
        assert!(header.status.contains(PacketStatus::END_OF_MESSAGE));
        assert_eq!(header.length, expected_len);
    }

    #[test]
    fn test_check_attention_done() {
        // Sanity-check the byte layout used for attention detection: a DONE token (0xFD)
        // followed by a little-endian status word whose bit 0x0020 is DONE_ATTN.
        let has_done_attn = |packet: &Packet| {
            packet
                .payload
                .windows(3)
                .any(|w| w[0] == 0xFD && u16::from_le_bytes([w[1], w[2]]) & 0x0020 != 0)
        };

        let acknowledged = tabular_packet(done_token(0x0020));
        let plain_done = tabular_packet(done_token(0x0000));

        assert!(has_done_attn(&acknowledged));
        assert!(!has_done_attn(&plain_done));
    }
}