Skip to main content

mssql_codec/
connection.rs

1//! Split I/O connection for cancellation safety.
2//!
3//! Per ADR-005, the TCP stream is split into separate read and write halves
4//! to allow sending Attention packets while blocked on reading results.
5
6use std::sync::Arc;
7
8use bytes::{Bytes, BytesMut};
9use futures_util::{SinkExt, StreamExt};
10use tds_protocol::packet::{PACKET_HEADER_SIZE, PacketHeader, PacketStatus, PacketType};
11use tokio::io::{AsyncRead, AsyncWrite, ReadHalf, WriteHalf};
12use tokio::sync::{Mutex, Notify};
13
14use crate::error::CodecError;
15use crate::framed::{PacketReader, PacketWriter};
16use crate::message::{Message, MessageAssembler};
17use crate::packet_codec::{Packet, TdsCodec};
18
19/// A TDS connection with split I/O for cancellation safety.
20///
21/// This struct splits the underlying transport into read and write halves,
22/// allowing Attention packets to be sent even while blocked reading results.
23///
24/// # Cancellation
25///
26/// SQL Server uses out-of-band "Attention" packets to cancel running queries.
27/// Without split I/O, the driver would be unable to send cancellation while
28/// blocked awaiting a read (e.g., processing a large result set).
29///
30/// # Example
31///
32/// ```rust,ignore
33/// use mssql_codec::Connection;
34/// use tokio::net::TcpStream;
35///
36/// let stream = TcpStream::connect("localhost:1433").await?;
37/// let conn = Connection::new(stream);
38///
39/// // Can cancel from another task while reading
40/// let cancel_handle = conn.cancel_handle();
41/// tokio::spawn(async move {
42///     tokio::time::sleep(Duration::from_secs(5)).await;
43///     cancel_handle.cancel().await?;
44/// });
45/// ```
46pub struct Connection<T>
47where
48    T: AsyncRead + AsyncWrite,
49{
50    /// Read half wrapped in a packet reader.
51    reader: PacketReader<ReadHalf<T>>,
52    /// Write half protected by mutex for concurrent cancel access.
53    writer: Arc<Mutex<PacketWriter<WriteHalf<T>>>>,
54    /// Message assembler for multi-packet messages.
55    assembler: MessageAssembler,
56    /// Notification for cancellation completion.
57    cancel_notify: Arc<Notify>,
58    /// Flag indicating cancellation is in progress.
59    cancelling: Arc<std::sync::atomic::AtomicBool>,
60}
61
62impl<T> Connection<T>
63where
64    T: AsyncRead + AsyncWrite,
65{
66    /// Create a new connection from a transport.
67    ///
68    /// The transport is immediately split into read and write halves.
69    pub fn new(transport: T) -> Self {
70        let (read_half, write_half) = tokio::io::split(transport);
71
72        Self {
73            reader: PacketReader::new(read_half),
74            writer: Arc::new(Mutex::new(PacketWriter::new(write_half))),
75            assembler: MessageAssembler::new(),
76            cancel_notify: Arc::new(Notify::new()),
77            cancelling: Arc::new(std::sync::atomic::AtomicBool::new(false)),
78        }
79    }
80
81    /// Create a new connection with custom codecs.
82    pub fn with_codecs(transport: T, read_codec: TdsCodec, write_codec: TdsCodec) -> Self {
83        let (read_half, write_half) = tokio::io::split(transport);
84
85        Self {
86            reader: PacketReader::with_codec(read_half, read_codec),
87            writer: Arc::new(Mutex::new(PacketWriter::with_codec(
88                write_half,
89                write_codec,
90            ))),
91            assembler: MessageAssembler::new(),
92            cancel_notify: Arc::new(Notify::new()),
93            cancelling: Arc::new(std::sync::atomic::AtomicBool::new(false)),
94        }
95    }
96
97    /// Get a handle for cancelling queries on this connection.
98    ///
99    /// The handle can be cloned and sent to other tasks.
100    #[must_use]
101    pub fn cancel_handle(&self) -> CancelHandle<T> {
102        CancelHandle {
103            writer: Arc::clone(&self.writer),
104            notify: Arc::clone(&self.cancel_notify),
105            cancelling: Arc::clone(&self.cancelling),
106        }
107    }
108
109    /// Check if a cancellation is currently in progress.
110    #[must_use]
111    pub fn is_cancelling(&self) -> bool {
112        self.cancelling.load(std::sync::atomic::Ordering::Acquire)
113    }
114
115    /// Read the next complete message from the connection.
116    ///
117    /// This handles multi-packet message reassembly automatically.
118    ///
119    /// Returns [`CodecError::Cancelled`] when the in-flight request was
120    /// cancelled via Attention and the server's DONE_ATTN acknowledgement has
121    /// been consumed — the connection is then clean for the next request.
122    pub async fn read_message(&mut self) -> Result<Option<Message>, CodecError> {
123        loop {
124            // Check for cancellation
125            if self.is_cancelling() {
126                // Drain until we see DONE with ATTENTION flag
127                return self.drain_after_cancel().await;
128            }
129
130            match self.reader.next().await {
131                Some(Ok(packet)) => {
132                    if let Some(message) = self.assembler.push(packet) {
133                        // The cancel flag may have been set while this read was
134                        // parked in `next()`. In that case the message belongs
135                        // to the request being cancelled (the server discards
136                        // it and acknowledges with DONE_ATTN), so it must not
137                        // be surfaced as a response — otherwise `cancelling`
138                        // stays latched and a later drain eats the *next*
139                        // request's response.
140                        if self.is_cancelling() {
141                            if Self::payload_ends_with_attention_done(&message.payload) {
142                                tracing::debug!(
143                                    "received DONE with ATTENTION, cancellation complete"
144                                );
145                                self.finish_cancel();
146                                return Err(CodecError::Cancelled);
147                            }
148                            tracing::debug!("discarding message from cancelled request");
149                            continue;
150                        }
151                        return Ok(Some(message));
152                    }
153                    // Continue reading packets until message complete
154                }
155                Some(Err(e)) => return Err(e),
156                None => {
157                    // Connection closed
158                    if self.assembler.has_partial() {
159                        return Err(CodecError::ConnectionClosed);
160                    }
161                    return Ok(None);
162                }
163            }
164        }
165    }
166
167    /// Read a single packet from the connection.
168    ///
169    /// This is lower-level than `read_message` and doesn't perform reassembly.
170    pub async fn read_packet(&mut self) -> Result<Option<Packet>, CodecError> {
171        match self.reader.next().await {
172            Some(result) => result.map(Some),
173            None => Ok(None),
174        }
175    }
176
177    /// Send a packet on the connection.
178    pub async fn send_packet(&mut self, packet: Packet) -> Result<(), CodecError> {
179        let mut writer = self.writer.lock().await;
180        writer.send(packet).await
181    }
182
183    /// Send a complete message, splitting into multiple packets if needed.
184    ///
185    /// If `reset_connection` is true, the RESETCONNECTION flag is set on the
186    /// first packet. This causes SQL Server to reset connection state (temp
187    /// tables, SET options, isolation level, etc.) before executing the command.
188    /// Per TDS spec, this flag MUST only be set on the first packet of a message.
189    pub async fn send_message(
190        &mut self,
191        packet_type: PacketType,
192        payload: Bytes,
193        max_packet_size: usize,
194    ) -> Result<(), CodecError> {
195        self.send_message_with_reset(packet_type, payload, max_packet_size, false)
196            .await
197    }
198
199    /// Send a complete message with optional connection reset.
200    ///
201    /// If `reset_connection` is true, the RESETCONNECTION flag is set on the
202    /// first packet. This causes SQL Server to reset connection state (temp
203    /// tables, SET options, isolation level, etc.) before executing the command.
204    /// Per TDS spec, this flag MUST only be set on the first packet of a message.
205    pub async fn send_message_with_reset(
206        &mut self,
207        packet_type: PacketType,
208        payload: Bytes,
209        max_packet_size: usize,
210        reset_connection: bool,
211    ) -> Result<(), CodecError> {
212        let max_payload = max_packet_size - PACKET_HEADER_SIZE;
213        // An empty payload must still produce one header-only EOM packet:
214        // `[]chunks()` yields zero chunks, which would send nothing at all and
215        // leave the caller waiting for a response that never comes (issue
216        // #165). A zero-length-payload message is valid TDS framing.
217        let chunks: Vec<&[u8]> = if payload.is_empty() {
218            vec![&[]]
219        } else {
220            payload.chunks(max_payload).collect()
221        };
222        let total_chunks = chunks.len();
223
224        let mut writer = self.writer.lock().await;
225
226        for (i, chunk) in chunks.into_iter().enumerate() {
227            let is_first = i == 0;
228            let is_last = i == total_chunks - 1;
229
230            // Build status flags
231            let mut status = if is_last {
232                PacketStatus::END_OF_MESSAGE
233            } else {
234                PacketStatus::NORMAL
235            };
236
237            // Per TDS spec, RESETCONNECTION must be on the first packet only
238            if is_first && reset_connection {
239                status |= PacketStatus::RESET_CONNECTION;
240            }
241
242            let header = PacketHeader::new(packet_type, status, 0);
243            let packet = Packet::new(header, BytesMut::from(chunk));
244
245            writer.send(packet).await?;
246        }
247
248        Ok(())
249    }
250
251    /// Flush the write buffer.
252    pub async fn flush(&mut self) -> Result<(), CodecError> {
253        let mut writer = self.writer.lock().await;
254        writer.flush().await
255    }
256
257    /// Drain messages after cancellation until DONE with ATTENTION is received.
258    ///
259    /// Returns [`CodecError::Cancelled`] once the acknowledgement is consumed;
260    /// the connection is then clean for the next request.
261    async fn drain_after_cancel(&mut self) -> Result<Option<Message>, CodecError> {
262        tracing::debug!("draining packets after cancellation");
263
264        // Clear any partial message
265        self.assembler.clear();
266
267        loop {
268            match self.reader.next().await {
269                Some(Ok(packet)) => {
270                    // Assemble complete messages so the acknowledgement check
271                    // runs on the message trailer — a per-packet check would
272                    // miss a DONE token straddling a packet boundary.
273                    if let Some(message) = self.assembler.push(packet) {
274                        if message.packet_type == PacketType::TabularResult
275                            && Self::payload_ends_with_attention_done(&message.payload)
276                        {
277                            tracing::debug!("received DONE with ATTENTION, cancellation complete");
278                            self.finish_cancel();
279                            return Err(CodecError::Cancelled);
280                        }
281                        tracing::debug!("discarding message from cancelled request");
282                    }
283                    // Continue draining
284                }
285                Some(Err(e)) => {
286                    self.cancelling
287                        .store(false, std::sync::atomic::Ordering::Release);
288                    return Err(e);
289                }
290                None => {
291                    // EOF while waiting for the acknowledgement: the
292                    // connection really is gone.
293                    self.cancelling
294                        .store(false, std::sync::atomic::Ordering::Release);
295                    return Err(CodecError::ConnectionClosed);
296                }
297            }
298        }
299    }
300
301    /// Mark the in-flight cancellation as acknowledged and wake waiters.
302    fn finish_cancel(&self) {
303        self.cancelling
304            .store(false, std::sync::atomic::Ordering::Release);
305        self.cancel_notify.notify_waiters();
306    }
307
308    /// Check whether a message payload terminates in a DONE token carrying
309    /// the ATTN status flag (the attention acknowledgement).
310    ///
311    /// Every tabular response message ends with a fixed 13-byte DONE-family
312    /// token (token(1) + status(2) + cur_cmd(2) + row_count(8)), and per
313    /// MS-TDS 2.2.7.6 the acknowledgement is a DONE (0xFD) with DONE_ATTN as
314    /// the final token of the cancelled stream. Anchoring the check to the
315    /// trailer means row bytes that happen to contain `0xFD, 0x20` (entirely
316    /// possible in binary/integer cell data arriving during the cancel
317    /// window) cannot be mistaken for the acknowledgement — an interior byte
318    /// scan was proven to clear the cancel flag early and leak the real
319    /// acknowledgement into the next request.
320    fn payload_ends_with_attention_done(payload: &[u8]) -> bool {
321        let Some(start) = payload.len().checked_sub(13) else {
322            return false;
323        };
324        // DONE token type = 0xFD; DONE_ATTN = 0x0020 in the LE status word.
325        payload[start] == 0xFD
326            && u16::from_le_bytes([payload[start + 1], payload[start + 2]]) & 0x0020 != 0
327    }
328
329    /// Get a reference to the read codec.
330    pub fn read_codec(&self) -> &TdsCodec {
331        self.reader.codec()
332    }
333
334    /// Get a mutable reference to the read codec.
335    pub fn read_codec_mut(&mut self) -> &mut TdsCodec {
336        self.reader.codec_mut()
337    }
338}
339
340impl<T> std::fmt::Debug for Connection<T>
341where
342    T: AsyncRead + AsyncWrite + std::fmt::Debug,
343{
344    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
345        f.debug_struct("Connection")
346            .field("cancelling", &self.is_cancelling())
347            .field("has_partial_message", &self.assembler.has_partial())
348            .finish_non_exhaustive()
349    }
350}
351
352/// Handle for cancelling queries on a connection.
353///
354/// This can be cloned and sent to other tasks to enable cancellation
355/// from a different async context.
356pub struct CancelHandle<T>
357where
358    T: AsyncRead + AsyncWrite,
359{
360    writer: Arc<Mutex<PacketWriter<WriteHalf<T>>>>,
361    notify: Arc<Notify>,
362    cancelling: Arc<std::sync::atomic::AtomicBool>,
363}
364
365impl<T> CancelHandle<T>
366where
367    T: AsyncRead + AsyncWrite + Unpin,
368{
369    /// Send an Attention packet to cancel the current query.
370    ///
371    /// This can be called from a different task while the main task
372    /// is blocked reading results.
373    pub async fn cancel(&self) -> Result<(), CodecError> {
374        // Mark cancellation in progress
375        self.cancelling
376            .store(true, std::sync::atomic::Ordering::Release);
377
378        tracing::debug!("sending Attention packet for query cancellation");
379
380        // Send the Attention packet
381        let mut writer = self.writer.lock().await;
382
383        // Create and send attention packet
384        let header = PacketHeader::new(
385            PacketType::Attention,
386            PacketStatus::END_OF_MESSAGE,
387            PACKET_HEADER_SIZE as u16,
388        );
389        let packet = Packet::new(header, BytesMut::new());
390
391        writer.send(packet).await?;
392        writer.flush().await?;
393
394        Ok(())
395    }
396
397    /// Wait for the cancellation to complete.
398    ///
399    /// This waits until the server acknowledges the cancellation
400    /// with a DONE token containing the ATTENTION flag.
401    pub async fn wait_cancelled(&self) {
402        if self.cancelling.load(std::sync::atomic::Ordering::Acquire) {
403            self.notify.notified().await;
404        }
405    }
406
407    /// Check if a cancellation is currently in progress.
408    #[must_use]
409    pub fn is_cancelling(&self) -> bool {
410        self.cancelling.load(std::sync::atomic::Ordering::Acquire)
411    }
412}
413
414impl<T> Clone for CancelHandle<T>
415where
416    T: AsyncRead + AsyncWrite,
417{
418    fn clone(&self) -> Self {
419        Self {
420            writer: Arc::clone(&self.writer),
421            notify: Arc::clone(&self.notify),
422            cancelling: Arc::clone(&self.cancelling),
423        }
424    }
425}
426
427impl<T> std::fmt::Debug for CancelHandle<T>
428where
429    T: AsyncRead + AsyncWrite + Unpin,
430{
431    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
432        f.debug_struct("CancelHandle")
433            .field("cancelling", &self.is_cancelling())
434            .finish_non_exhaustive()
435    }
436}
437
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
    //! Unit tests for packet framing on send and for the Attention/cancel
    //! protocol on the read path. The I/O tests use an in-memory
    //! `tokio::io::duplex` pair: `client_io` backs the `Connection`, and
    //! `server_io` plays SQL Server.
    use super::*;

    /// Issue #165: sending a message with an empty payload must still emit
    /// exactly one header-only EOM packet. Previously `chunks()` on an empty
    /// payload yielded zero chunks, so nothing was sent and the caller would
    /// hang waiting for a response.
    #[tokio::test]
    async fn test_send_empty_payload_emits_one_eom_packet() {
        use tokio::io::AsyncReadExt;

        let (client_io, mut server_io) = tokio::io::duplex(4096);
        let mut conn = Connection::new(client_io);

        conn.send_message(PacketType::SqlBatch, Bytes::new(), 4096)
            .await
            .expect("empty message should send");

        // Exactly one header-only packet (8 bytes) must arrive.
        // Header bytes read below: [0] packet type, [1] status,
        // [2..4] big-endian total length.
        let mut header = [0u8; PACKET_HEADER_SIZE];
        server_io
            .read_exact(&mut header)
            .await
            .expect("one header-only packet must be sent");
        assert_eq!(header[0], PacketType::SqlBatch as u8);
        assert!(
            PacketStatus::from_bits_truncate(header[1]).contains(PacketStatus::END_OF_MESSAGE),
            "the single packet must be flagged END_OF_MESSAGE"
        );
        let length = u16::from_be_bytes([header[2], header[3]]);
        assert_eq!(
            length as usize, PACKET_HEADER_SIZE,
            "length must be header-only (no payload)"
        );

        // And nothing more.
        // Dropping `conn` releases the client side of the duplex so that
        // `read_to_end` can finish.
        drop(conn);
        let mut rest = Vec::new();
        server_io.read_to_end(&mut rest).await.expect("read rest");
        assert!(rest.is_empty(), "no second packet may follow");
    }

    /// Issue #165: across a multi-packet send, RESET_CONNECTION must be set on
    /// the first packet only (per MS-TDS), and END_OF_MESSAGE on the last only.
    #[tokio::test]
    async fn test_reset_flag_on_first_packet_only_across_multi_packet_send() {
        use tokio::io::AsyncReadExt;

        let (client_io, mut server_io) = tokio::io::duplex(4096);
        let mut conn = Connection::new(client_io);

        // max_packet_size 16 → max_payload 8; a 12-byte payload spans 2 packets.
        let payload = Bytes::from(vec![0xABu8; 12]);
        conn.send_message_with_reset(PacketType::SqlBatch, payload, 16, true)
            .await
            .expect("multi-packet send should succeed");
        // Close the client side so `read_to_end` below returns everything sent.
        drop(conn);

        let mut all = Vec::new();
        server_io.read_to_end(&mut all).await.expect("read packets");

        // Packet 1: header(8) + payload(8) = 16 bytes.
        let s0 = PacketStatus::from_bits_truncate(all[1]);
        assert!(
            s0.contains(PacketStatus::RESET_CONNECTION),
            "first packet must carry RESET_CONNECTION"
        );
        assert!(
            !s0.contains(PacketStatus::END_OF_MESSAGE),
            "first packet of two must not be END_OF_MESSAGE"
        );

        // Packet 2 starts at offset 16: header(8) + payload(4).
        let s1 = PacketStatus::from_bits_truncate(all[16 + 1]);
        assert!(
            !s1.contains(PacketStatus::RESET_CONNECTION),
            "RESET_CONNECTION must not repeat on later packets"
        );
        assert!(
            s1.contains(PacketStatus::END_OF_MESSAGE),
            "last packet must be END_OF_MESSAGE"
        );
        assert_eq!(all.len(), 16 + 8 + 4, "exactly two packets must be sent");
    }

    /// The Attention header, built with the same arguments `CancelHandle::cancel`
    /// uses, keeps its type, EOM status and header-only length.
    #[test]
    fn test_attention_packet_header() {
        // Verify attention packet header construction
        let header = PacketHeader::new(
            PacketType::Attention,
            PacketStatus::END_OF_MESSAGE,
            PACKET_HEADER_SIZE as u16,
        );

        assert_eq!(header.packet_type, PacketType::Attention);
        assert!(header.status.contains(PacketStatus::END_OF_MESSAGE));
        assert_eq!(header.length, PACKET_HEADER_SIZE as u16);
    }

    /// `payload_ends_with_attention_done` accepts a trailing DONE with the
    /// ATTN bit set. It rejects a trailing DONE without the bit, and it
    /// ignores `0xFD, 0x20` bytes that appear before the trailer.
    #[test]
    fn test_check_attention_done() {
        // Test DONE token with ATTN flag detection
        // DONE token: 0xFD + status(2 bytes) + cur_cmd(2 bytes) + row_count(8 bytes)
        // DONE_ATTN flag is 0x0020

        // Create a mock packet with DONE token and ATTN flag
        let header = PacketHeader::new(PacketType::TabularResult, PacketStatus::END_OF_MESSAGE, 0);

        // DONE token with ATTN flag set (status = 0x0020)
        let payload_with_attn = BytesMut::from(
            &[
                0xFD, 0x20, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
            ][..],
        );
        let packet_with_attn = Packet::new(header, payload_with_attn);

        // DONE token without ATTN flag (status = 0x0000)
        let payload_no_attn = BytesMut::from(
            &[
                0xFD, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
            ][..],
        );
        let packet_no_attn = Packet::new(header, payload_no_attn);

        assert!(
            Connection::<tokio::io::DuplexStream>::payload_ends_with_attention_done(
                &packet_with_attn.payload
            )
        );
        assert!(
            !Connection::<tokio::io::DuplexStream>::payload_ends_with_attention_done(
                &packet_no_attn.payload
            )
        );

        // Interior 0xFD,0x20 bytes (e.g. row data) must not register: only
        // the trailing token position counts.
        let mut interior = vec![0xD1, 0x08, 0xFD, 0x20, 0xAA, 0xBB];
        interior.extend_from_slice(&[
            0xFD, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
        ]);
        assert!(
            !Connection::<tokio::io::DuplexStream>::payload_ends_with_attention_done(&interior)
        );
    }

    /// Build a raw single-packet TabularResult TDS message around `payload`.
    ///
    /// Header bytes: type 0x04, status 0x01 (END_OF_MESSAGE), big-endian
    /// total length (payload + 8-byte header), then spid/packet id/window.
    fn raw_message(payload: &[u8]) -> Vec<u8> {
        let mut v = vec![0x04, 0x01]; // TabularResult, END_OF_MESSAGE
        v.extend_from_slice(&((payload.len() + 8) as u16).to_be_bytes());
        v.extend_from_slice(&[0, 0, 1, 0]); // spid, packet id, window
        v.extend_from_slice(payload);
        v
    }

    /// DONE token bytes with the given status.
    ///
    /// The status is written little-endian after the 0xFD token byte. The
    /// cur_cmd bytes are fixed at `0xC1, 0x00` and the 8-byte row count is
    /// zero; the tests only inspect the status.
    fn done_token(status: u16) -> [u8; 13] {
        let s = status.to_le_bytes();
        [
            0xFD, s[0], s[1], 0xC1, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
        ]
    }

    /// Regression test for the cancel-mid-read race.
    ///
    /// When `cancel()` fires while `read_message()` is already parked on the
    /// socket, the cancelled request's response stream (here: DONE(ERROR)
    /// followed by the DONE(ATTN) acknowledgement) arrives through the
    /// *normal* read path. It must be discarded — not surfaced as a query
    /// response — and the read must end in `CodecError::Cancelled` with the
    /// `cancelling` flag cleared, so the next request's response is delivered
    /// intact. Before the fix, the first DONE was returned as the response,
    /// the flag stayed latched, and a later drain ate the next response.
    #[tokio::test]
    async fn test_cancel_mid_read_discards_cancelled_stream() {
        use std::task::{Context, Poll};
        use tokio::io::AsyncWriteExt;

        let (client_io, mut server_io) = tokio::io::duplex(4096);
        let mut conn = Connection::new(client_io);
        let cancel = conn.cancel_handle();

        // Park a read with nothing to deliver yet (mimics waiting on a slow
        // query). A noop waker is fine: the future is re-polled via `.await`
        // below after data is written.
        let mut read_fut = Box::pin(conn.read_message());
        let waker = std::task::Waker::noop();
        let mut cx = Context::from_waker(waker);
        assert!(matches!(read_fut.as_mut().poll(&mut cx), Poll::Pending));

        // Cancel while the read is parked, then deliver the cancelled
        // request's stream plus the next request's response.
        cancel.cancel().await.expect("send attention");
        server_io
            .write_all(&raw_message(&done_token(0x0002))) // DONE_ERROR
            .await
            .unwrap();
        server_io
            .write_all(&raw_message(&done_token(0x0020))) // DONE_ATTN ack
            .await
            .unwrap();
        server_io
            .write_all(&raw_message(&done_token(0x0010))) // next response
            .await
            .unwrap();

        let result = read_fut.await;
        assert!(
            matches!(result, Err(CodecError::Cancelled)),
            "parked read must consume the cancelled stream and report \
             Cancelled, got {result:?}"
        );
        assert!(!conn.is_cancelling(), "cancel flag must be cleared");

        // The next request's response must come through untouched.
        let message = conn
            .read_message()
            .await
            .expect("next read")
            .expect("next message");
        assert_eq!(message.payload[0], 0xFD);
        assert_eq!(
            u16::from_le_bytes([message.payload[1], message.payload[2]]),
            0x0010,
            "next response must not be eaten by a stale drain"
        );
    }

    /// Cancellation requested before the read starts takes the drain path and
    /// must behave identically to the mid-read race.
    #[tokio::test]
    async fn test_cancel_before_read_drains_to_attention_ack() {
        use tokio::io::AsyncWriteExt;

        let (client_io, mut server_io) = tokio::io::duplex(4096);
        let mut conn = Connection::new(client_io);
        let cancel = conn.cancel_handle();

        // The flag is raised before any read, so `read_message` goes straight
        // to `drain_after_cancel`.
        cancel.cancel().await.expect("send attention");
        server_io
            .write_all(&raw_message(&done_token(0x0022))) // ERROR | ATTN ack
            .await
            .unwrap();
        server_io
            .write_all(&raw_message(&done_token(0x0010))) // next response
            .await
            .unwrap();

        let result = conn.read_message().await;
        assert!(matches!(result, Err(CodecError::Cancelled)));
        assert!(!conn.is_cancelling());

        let message = conn
            .read_message()
            .await
            .expect("next read")
            .expect("next message");
        assert_eq!(
            u16::from_le_bytes([message.payload[1], message.payload[2]]),
            0x0010
        );
    }

    /// PR #143 review, Blocker 1: row bytes that happen to contain
    /// `0xFD, 0x20` must NOT be mistaken for the DONE_ATTN acknowledgement.
    ///
    /// During the cancel window the cancelled request's *data* (rows already
    /// in flight) can arrive before the real acknowledgement. A byte-scan
    /// for any interior 0xFD with bit 5 set false-positives on such data,
    /// clears the cancel flag early, and the genuine ack then poisons the
    /// next request — the exact failure the cancellation fix claims to
    /// eliminate.
    #[tokio::test]
    async fn test_cancel_race_row_bytes_do_not_fake_the_attention_ack() {
        use std::task::{Context, Poll};
        use tokio::io::AsyncWriteExt;

        let (client_io, mut server_io) = tokio::io::duplex(4096);
        let mut conn = Connection::new(client_io);
        let cancel = conn.cancel_handle();

        // Park a read, then cancel while it waits (the realistic ordering).
        let mut read_fut = Box::pin(conn.read_message());
        let waker = std::task::Waker::noop();
        let mut cx = Context::from_waker(waker);
        assert!(matches!(read_fut.as_mut().poll(&mut cx), Poll::Pending));
        cancel.cancel().await.expect("send attention");

        // Message 1: the cancelled request's data — row-ish bytes whose
        // *interior* contains 0xFD followed by a byte with bit 5 set (e.g. a
        // BIGINT cell value), terminated by a DONE with MORE and no ATTN.
        let mut row_data = vec![0xD1, 0x08]; // ROW token, length-ish prefix
        row_data.extend_from_slice(&[0xFD, 0x20, 0xAA, 0xBB, 0xCC, 0xDD, 0xEE, 0xFF]);
        row_data.extend_from_slice(&done_token(0x0001)); // DONE_MORE, no ATTN
        server_io.write_all(&raw_message(&row_data)).await.unwrap();

        // Message 2: the genuine acknowledgement.
        server_io
            .write_all(&raw_message(&done_token(0x0020)))
            .await
            .unwrap();

        // Message 3: the next request's response.
        server_io
            .write_all(&raw_message(&done_token(0x0010)))
            .await
            .unwrap();

        let result = read_fut.await;
        assert!(
            matches!(result, Err(CodecError::Cancelled)),
            "cancelled read must end in Cancelled, got {result:?}"
        );
        assert!(!conn.is_cancelling());

        // The next read must deliver message 3 — not the stale ack from
        // message 2.
        let message = conn
            .read_message()
            .await
            .expect("next read")
            .expect("next message");
        let status = u16::from_le_bytes([message.payload[1], message.payload[2]]);
        assert_eq!(
            status, 0x0010,
            "next request's response must come through intact; 0x0020 means \
             the interior row bytes were mistaken for the ack and the real \
             ack leaked into the next request"
        );
    }
}