mssql_codec/
connection.rs

1//! Split I/O connection for cancellation safety.
2//!
3//! Per ADR-005, the TCP stream is split into separate read and write halves
4//! to allow sending Attention packets while blocked on reading results.
5
6use std::sync::Arc;
7
8use bytes::{Bytes, BytesMut};
9use futures_util::{SinkExt, StreamExt};
10use tds_protocol::packet::{PACKET_HEADER_SIZE, PacketHeader, PacketStatus, PacketType};
11use tokio::io::{AsyncRead, AsyncWrite, ReadHalf, WriteHalf};
12use tokio::sync::{Mutex, Notify};
13
14use crate::error::CodecError;
15use crate::framed::{PacketReader, PacketWriter};
16use crate::message::{Message, MessageAssembler};
17use crate::packet_codec::{Packet, TdsCodec};
18
19/// A TDS connection with split I/O for cancellation safety.
20///
21/// This struct splits the underlying transport into read and write halves,
22/// allowing Attention packets to be sent even while blocked reading results.
23///
24/// # Cancellation
25///
26/// SQL Server uses out-of-band "Attention" packets to cancel running queries.
27/// Without split I/O, the driver would be unable to send cancellation while
28/// blocked awaiting a read (e.g., processing a large result set).
29///
30/// # Example
31///
32/// ```rust,ignore
33/// use mssql_codec::Connection;
34/// use tokio::net::TcpStream;
35///
36/// let stream = TcpStream::connect("localhost:1433").await?;
37/// let conn = Connection::new(stream);
38///
39/// // Can cancel from another task while reading
40/// let cancel_handle = conn.cancel_handle();
41/// tokio::spawn(async move {
42///     tokio::time::sleep(Duration::from_secs(5)).await;
43///     cancel_handle.cancel().await?;
44/// });
45/// ```
46pub struct Connection<T>
47where
48    T: AsyncRead + AsyncWrite,
49{
50    /// Read half wrapped in a packet reader.
51    reader: PacketReader<ReadHalf<T>>,
52    /// Write half protected by mutex for concurrent cancel access.
53    writer: Arc<Mutex<PacketWriter<WriteHalf<T>>>>,
54    /// Message assembler for multi-packet messages.
55    assembler: MessageAssembler,
56    /// Notification for cancellation completion.
57    cancel_notify: Arc<Notify>,
58    /// Flag indicating cancellation is in progress.
59    cancelling: Arc<std::sync::atomic::AtomicBool>,
60}
61
62impl<T> Connection<T>
63where
64    T: AsyncRead + AsyncWrite,
65{
66    /// Create a new connection from a transport.
67    ///
68    /// The transport is immediately split into read and write halves.
69    pub fn new(transport: T) -> Self {
70        let (read_half, write_half) = tokio::io::split(transport);
71
72        Self {
73            reader: PacketReader::new(read_half),
74            writer: Arc::new(Mutex::new(PacketWriter::new(write_half))),
75            assembler: MessageAssembler::new(),
76            cancel_notify: Arc::new(Notify::new()),
77            cancelling: Arc::new(std::sync::atomic::AtomicBool::new(false)),
78        }
79    }
80
81    /// Create a new connection with custom codecs.
82    pub fn with_codecs(transport: T, read_codec: TdsCodec, write_codec: TdsCodec) -> Self {
83        let (read_half, write_half) = tokio::io::split(transport);
84
85        Self {
86            reader: PacketReader::with_codec(read_half, read_codec),
87            writer: Arc::new(Mutex::new(PacketWriter::with_codec(
88                write_half,
89                write_codec,
90            ))),
91            assembler: MessageAssembler::new(),
92            cancel_notify: Arc::new(Notify::new()),
93            cancelling: Arc::new(std::sync::atomic::AtomicBool::new(false)),
94        }
95    }
96
97    /// Get a handle for cancelling queries on this connection.
98    ///
99    /// The handle can be cloned and sent to other tasks.
100    #[must_use]
101    pub fn cancel_handle(&self) -> CancelHandle<T> {
102        CancelHandle {
103            writer: Arc::clone(&self.writer),
104            notify: Arc::clone(&self.cancel_notify),
105            cancelling: Arc::clone(&self.cancelling),
106        }
107    }
108
109    /// Check if a cancellation is currently in progress.
110    #[must_use]
111    pub fn is_cancelling(&self) -> bool {
112        self.cancelling.load(std::sync::atomic::Ordering::Acquire)
113    }
114
115    /// Read the next complete message from the connection.
116    ///
117    /// This handles multi-packet message reassembly automatically.
118    pub async fn read_message(&mut self) -> Result<Option<Message>, CodecError> {
119        loop {
120            // Check for cancellation
121            if self.is_cancelling() {
122                // Drain until we see DONE with ATTENTION flag
123                return self.drain_after_cancel().await;
124            }
125
126            match self.reader.next().await {
127                Some(Ok(packet)) => {
128                    if let Some(message) = self.assembler.push(packet) {
129                        return Ok(Some(message));
130                    }
131                    // Continue reading packets until message complete
132                }
133                Some(Err(e)) => return Err(e),
134                None => {
135                    // Connection closed
136                    if self.assembler.has_partial() {
137                        return Err(CodecError::ConnectionClosed);
138                    }
139                    return Ok(None);
140                }
141            }
142        }
143    }
144
145    /// Read a single packet from the connection.
146    ///
147    /// This is lower-level than `read_message` and doesn't perform reassembly.
148    pub async fn read_packet(&mut self) -> Result<Option<Packet>, CodecError> {
149        match self.reader.next().await {
150            Some(result) => result.map(Some),
151            None => Ok(None),
152        }
153    }
154
155    /// Send a packet on the connection.
156    pub async fn send_packet(&mut self, packet: Packet) -> Result<(), CodecError> {
157        let mut writer = self.writer.lock().await;
158        writer.send(packet).await
159    }
160
161    /// Send a complete message, splitting into multiple packets if needed.
162    ///
163    /// If `reset_connection` is true, the RESETCONNECTION flag is set on the
164    /// first packet. This causes SQL Server to reset connection state (temp
165    /// tables, SET options, isolation level, etc.) before executing the command.
166    /// Per TDS spec, this flag MUST only be set on the first packet of a message.
167    pub async fn send_message(
168        &mut self,
169        packet_type: PacketType,
170        payload: Bytes,
171        max_packet_size: usize,
172    ) -> Result<(), CodecError> {
173        self.send_message_with_reset(packet_type, payload, max_packet_size, false)
174            .await
175    }
176
177    /// Send a complete message with optional connection reset.
178    ///
179    /// If `reset_connection` is true, the RESETCONNECTION flag is set on the
180    /// first packet. This causes SQL Server to reset connection state (temp
181    /// tables, SET options, isolation level, etc.) before executing the command.
182    /// Per TDS spec, this flag MUST only be set on the first packet of a message.
183    pub async fn send_message_with_reset(
184        &mut self,
185        packet_type: PacketType,
186        payload: Bytes,
187        max_packet_size: usize,
188        reset_connection: bool,
189    ) -> Result<(), CodecError> {
190        let max_payload = max_packet_size - PACKET_HEADER_SIZE;
191        let chunks: Vec<_> = payload.chunks(max_payload).collect();
192        let total_chunks = chunks.len();
193
194        let mut writer = self.writer.lock().await;
195
196        for (i, chunk) in chunks.into_iter().enumerate() {
197            let is_first = i == 0;
198            let is_last = i == total_chunks - 1;
199
200            // Build status flags
201            let mut status = if is_last {
202                PacketStatus::END_OF_MESSAGE
203            } else {
204                PacketStatus::NORMAL
205            };
206
207            // Per TDS spec, RESETCONNECTION must be on the first packet only
208            if is_first && reset_connection {
209                status |= PacketStatus::RESET_CONNECTION;
210            }
211
212            let header = PacketHeader::new(packet_type, status, 0);
213            let packet = Packet::new(header, BytesMut::from(chunk));
214
215            writer.send(packet).await?;
216        }
217
218        Ok(())
219    }
220
221    /// Flush the write buffer.
222    pub async fn flush(&mut self) -> Result<(), CodecError> {
223        let mut writer = self.writer.lock().await;
224        writer.flush().await
225    }
226
227    /// Drain packets after cancellation until DONE with ATTENTION is received.
228    async fn drain_after_cancel(&mut self) -> Result<Option<Message>, CodecError> {
229        tracing::debug!("draining packets after cancellation");
230
231        // Clear any partial message
232        self.assembler.clear();
233
234        loop {
235            match self.reader.next().await {
236                Some(Ok(packet)) => {
237                    // Check for DONE token with ATTENTION flag
238                    // The DONE token is at the start of the payload
239                    if packet.header.packet_type == PacketType::TabularResult
240                        && !packet.payload.is_empty()
241                    {
242                        // TokenType::Done = 0xFD
243                        // Check if this packet contains a Done token
244                        // and the status has ATTN flag (0x0020)
245                        if self.check_attention_done(&packet) {
246                            tracing::debug!("received DONE with ATTENTION, cancellation complete");
247                            self.cancelling
248                                .store(false, std::sync::atomic::Ordering::Release);
249                            self.cancel_notify.notify_waiters();
250                            return Ok(None);
251                        }
252                    }
253                    // Continue draining
254                }
255                Some(Err(e)) => {
256                    self.cancelling
257                        .store(false, std::sync::atomic::Ordering::Release);
258                    return Err(e);
259                }
260                None => {
261                    self.cancelling
262                        .store(false, std::sync::atomic::Ordering::Release);
263                    return Ok(None);
264                }
265            }
266        }
267    }
268
269    /// Check if a packet contains a DONE token with ATTENTION flag.
270    fn check_attention_done(&self, packet: &Packet) -> bool {
271        // Look for DONE token (0xFD) with ATTN status flag (bit 5)
272        // DONE token format: token_type(1) + status(2) + cur_cmd(2) + row_count(8)
273        let payload = &packet.payload;
274
275        for i in 0..payload.len() {
276            if payload[i] == 0xFD && i + 3 <= payload.len() {
277                // Found DONE token, check status
278                let status = u16::from_le_bytes([payload[i + 1], payload[i + 2]]);
279                // DONE_ATTN = 0x0020
280                if status & 0x0020 != 0 {
281                    return true;
282                }
283            }
284        }
285
286        false
287    }
288
289    /// Get a reference to the read codec.
290    pub fn read_codec(&self) -> &TdsCodec {
291        self.reader.codec()
292    }
293
294    /// Get a mutable reference to the read codec.
295    pub fn read_codec_mut(&mut self) -> &mut TdsCodec {
296        self.reader.codec_mut()
297    }
298}
299
300impl<T> std::fmt::Debug for Connection<T>
301where
302    T: AsyncRead + AsyncWrite + std::fmt::Debug,
303{
304    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
305        f.debug_struct("Connection")
306            .field("cancelling", &self.is_cancelling())
307            .field("has_partial_message", &self.assembler.has_partial())
308            .finish_non_exhaustive()
309    }
310}
311
312/// Handle for cancelling queries on a connection.
313///
314/// This can be cloned and sent to other tasks to enable cancellation
315/// from a different async context.
316pub struct CancelHandle<T>
317where
318    T: AsyncRead + AsyncWrite,
319{
320    writer: Arc<Mutex<PacketWriter<WriteHalf<T>>>>,
321    notify: Arc<Notify>,
322    cancelling: Arc<std::sync::atomic::AtomicBool>,
323}
324
325impl<T> CancelHandle<T>
326where
327    T: AsyncRead + AsyncWrite + Unpin,
328{
329    /// Send an Attention packet to cancel the current query.
330    ///
331    /// This can be called from a different task while the main task
332    /// is blocked reading results.
333    pub async fn cancel(&self) -> Result<(), CodecError> {
334        // Mark cancellation in progress
335        self.cancelling
336            .store(true, std::sync::atomic::Ordering::Release);
337
338        tracing::debug!("sending Attention packet for query cancellation");
339
340        // Send the Attention packet
341        let mut writer = self.writer.lock().await;
342
343        // Create and send attention packet
344        let header = PacketHeader::new(
345            PacketType::Attention,
346            PacketStatus::END_OF_MESSAGE,
347            PACKET_HEADER_SIZE as u16,
348        );
349        let packet = Packet::new(header, BytesMut::new());
350
351        writer.send(packet).await?;
352        writer.flush().await?;
353
354        Ok(())
355    }
356
357    /// Wait for the cancellation to complete.
358    ///
359    /// This waits until the server acknowledges the cancellation
360    /// with a DONE token containing the ATTENTION flag.
361    pub async fn wait_cancelled(&self) {
362        if self.cancelling.load(std::sync::atomic::Ordering::Acquire) {
363            self.notify.notified().await;
364        }
365    }
366
367    /// Check if a cancellation is currently in progress.
368    #[must_use]
369    pub fn is_cancelling(&self) -> bool {
370        self.cancelling.load(std::sync::atomic::Ordering::Acquire)
371    }
372}
373
374impl<T> Clone for CancelHandle<T>
375where
376    T: AsyncRead + AsyncWrite,
377{
378    fn clone(&self) -> Self {
379        Self {
380            writer: Arc::clone(&self.writer),
381            notify: Arc::clone(&self.notify),
382            cancelling: Arc::clone(&self.cancelling),
383        }
384    }
385}
386
387impl<T> std::fmt::Debug for CancelHandle<T>
388where
389    T: AsyncRead + AsyncWrite + Unpin,
390{
391    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
392        f.debug_struct("CancelHandle")
393            .field("cancelling", &self.is_cancelling())
394            .finish_non_exhaustive()
395    }
396}
397
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    /// Token type byte of a DONE token.
    const DONE_TOKEN: u8 = 0xFD;
    /// DONE status bit that acknowledges an Attention.
    const DONE_ATTN: u16 = 0x0020;

    /// Builds a TabularResult packet whose payload is a single 13-byte DONE
    /// token: type(1) + status(2, LE) + cur_cmd(2) + row_count(8).
    fn done_packet(status: u16) -> Packet {
        let mut token = [0u8; 13];
        token[0] = DONE_TOKEN;
        token[1..3].copy_from_slice(&status.to_le_bytes());

        let header = PacketHeader::new(PacketType::TabularResult, PacketStatus::END_OF_MESSAGE, 0);
        Packet::new(header, BytesMut::from(&token[..]))
    }

    /// Re-implements the DONE/DONE_ATTN byte scan used while draining after a
    /// cancel, so the detection rule can be checked without a live connection.
    fn has_attention_done(packet: &Packet) -> bool {
        packet
            .payload
            .windows(3)
            .any(|w| w[0] == DONE_TOKEN && u16::from_le_bytes([w[1], w[2]]) & DONE_ATTN != 0)
    }

    #[test]
    fn test_attention_packet_header() {
        // An Attention packet is header-only, so its length equals the header size.
        let header_only_len = PACKET_HEADER_SIZE as u16;
        let header = PacketHeader::new(
            PacketType::Attention,
            PacketStatus::END_OF_MESSAGE,
            header_only_len,
        );

        assert_eq!(header.packet_type, PacketType::Attention);
        assert!(header.status.contains(PacketStatus::END_OF_MESSAGE));
        assert_eq!(header.length, header_only_len);
    }

    #[test]
    fn test_check_attention_done() {
        assert!(has_attention_done(&done_packet(DONE_ATTN)));
        assert!(!has_attention_done(&done_packet(0x0000)));
    }
}