Skip to main content

mssql_codec/
connection.rs

1//! Split I/O connection for cancellation safety.
2//!
3//! Per ADR-005, the TCP stream is split into separate read and write halves
4//! to allow sending Attention packets while blocked on reading results.
5
6use std::sync::Arc;
7
8use bytes::{Bytes, BytesMut};
9use futures_util::{SinkExt, StreamExt};
10use tds_protocol::packet::{PACKET_HEADER_SIZE, PacketHeader, PacketStatus, PacketType};
11use tokio::io::{AsyncRead, AsyncWrite, ReadHalf, WriteHalf};
12use tokio::sync::{Mutex, Notify};
13
14use crate::error::CodecError;
15use crate::framed::{PacketReader, PacketWriter};
16use crate::message::{Message, MessageAssembler};
17use crate::packet_codec::{Packet, TdsCodec};
18
19/// A TDS connection with split I/O for cancellation safety.
20///
21/// This struct splits the underlying transport into read and write halves,
22/// allowing Attention packets to be sent even while blocked reading results.
23///
24/// # Cancellation
25///
26/// SQL Server uses out-of-band "Attention" packets to cancel running queries.
27/// Without split I/O, the driver would be unable to send cancellation while
28/// blocked awaiting a read (e.g., processing a large result set).
29///
30/// # Example
31///
32/// ```rust,ignore
33/// use mssql_codec::Connection;
34/// use tokio::net::TcpStream;
35///
36/// let stream = TcpStream::connect("localhost:1433").await?;
37/// let conn = Connection::new(stream);
38///
39/// // Can cancel from another task while reading
40/// let cancel_handle = conn.cancel_handle();
41/// tokio::spawn(async move {
42///     tokio::time::sleep(Duration::from_secs(5)).await;
43///     cancel_handle.cancel().await?;
44/// });
45/// ```
46pub struct Connection<T>
47where
48    T: AsyncRead + AsyncWrite,
49{
50    /// Read half wrapped in a packet reader.
51    reader: PacketReader<ReadHalf<T>>,
52    /// Write half protected by mutex for concurrent cancel access.
53    writer: Arc<Mutex<PacketWriter<WriteHalf<T>>>>,
54    /// Message assembler for multi-packet messages.
55    assembler: MessageAssembler,
56    /// Notification for cancellation completion.
57    cancel_notify: Arc<Notify>,
58    /// Flag indicating cancellation is in progress.
59    cancelling: Arc<std::sync::atomic::AtomicBool>,
60    /// Maximum assembled message size in bytes; 0 means unlimited.
61    max_message_size: usize,
62}
63
64impl<T> Connection<T>
65where
66    T: AsyncRead + AsyncWrite,
67{
68    /// Create a new connection from a transport.
69    ///
70    /// The transport is immediately split into read and write halves.
71    pub fn new(transport: T) -> Self {
72        let (read_half, write_half) = tokio::io::split(transport);
73
74        Self {
75            reader: PacketReader::new(read_half),
76            writer: Arc::new(Mutex::new(PacketWriter::new(write_half))),
77            assembler: MessageAssembler::new(),
78            cancel_notify: Arc::new(Notify::new()),
79            cancelling: Arc::new(std::sync::atomic::AtomicBool::new(false)),
80            max_message_size: 0,
81        }
82    }
83
84    /// Create a new connection with custom codecs.
85    pub fn with_codecs(transport: T, read_codec: TdsCodec, write_codec: TdsCodec) -> Self {
86        let (read_half, write_half) = tokio::io::split(transport);
87
88        Self {
89            reader: PacketReader::with_codec(read_half, read_codec),
90            writer: Arc::new(Mutex::new(PacketWriter::with_codec(
91                write_half,
92                write_codec,
93            ))),
94            assembler: MessageAssembler::new(),
95            cancel_notify: Arc::new(Notify::new()),
96            cancelling: Arc::new(std::sync::atomic::AtomicBool::new(false)),
97            max_message_size: 0,
98        }
99    }
100
101    /// Cap the size of an assembled response message; 0 means unlimited.
102    ///
103    /// Responses are buffered in full before token parsing, so without a
104    /// cap a single large SELECT is unbounded client memory. When the cap
105    /// is exceeded, [`read_message`](Self::read_message) returns
106    /// [`CodecError::MessageTooLarge`] mid-message — the connection is no
107    /// longer usable for that request and should be discarded.
108    pub fn set_max_message_size(&mut self, limit: usize) {
109        self.max_message_size = limit;
110    }
111
112    /// Get a handle for cancelling queries on this connection.
113    ///
114    /// The handle can be cloned and sent to other tasks.
115    #[must_use]
116    pub fn cancel_handle(&self) -> CancelHandle<T> {
117        CancelHandle {
118            writer: Arc::clone(&self.writer),
119            notify: Arc::clone(&self.cancel_notify),
120            cancelling: Arc::clone(&self.cancelling),
121        }
122    }
123
124    /// Check if a cancellation is currently in progress.
125    #[must_use]
126    pub fn is_cancelling(&self) -> bool {
127        self.cancelling.load(std::sync::atomic::Ordering::Acquire)
128    }
129
130    /// Read the next complete message from the connection.
131    ///
132    /// This handles multi-packet message reassembly automatically.
133    ///
134    /// Returns [`CodecError::Cancelled`] when the in-flight request was
135    /// cancelled via Attention and the server's DONE_ATTN acknowledgement has
136    /// been consumed — the connection is then clean for the next request.
137    pub async fn read_message(&mut self) -> Result<Option<Message>, CodecError> {
138        loop {
139            // Check for cancellation
140            if self.is_cancelling() {
141                // Drain until we see DONE with ATTENTION flag
142                return self.drain_after_cancel().await;
143            }
144
145            match self.reader.next().await {
146                Some(Ok(packet)) => {
147                    let result = self.assembler.push(packet);
148                    // Enforce the response-size cap on the accumulating
149                    // buffer AND the completed message (the final packet can
150                    // jump past the limit). Exceeding it abandons the
151                    // message mid-stream: the caller must discard the
152                    // connection.
153                    if self.max_message_size != 0 {
154                        let size = result
155                            .as_ref()
156                            .map_or_else(|| self.assembler.buffer_len(), |m| m.payload.len());
157                        if size > self.max_message_size {
158                            self.assembler.clear();
159                            return Err(CodecError::MessageTooLarge {
160                                size,
161                                limit: self.max_message_size,
162                            });
163                        }
164                    }
165                    if let Some(message) = result {
166                        // The cancel flag may have been set while this read was
167                        // parked in `next()`. In that case the message belongs
168                        // to the request being cancelled (the server discards
169                        // it and acknowledges with DONE_ATTN), so it must not
170                        // be surfaced as a response — otherwise `cancelling`
171                        // stays latched and a later drain eats the *next*
172                        // request's response.
173                        if self.is_cancelling() {
174                            if Self::payload_ends_with_attention_done(&message.payload) {
175                                tracing::debug!(
176                                    "received DONE with ATTENTION, cancellation complete"
177                                );
178                                self.finish_cancel();
179                                return Err(CodecError::Cancelled);
180                            }
181                            tracing::debug!("discarding message from cancelled request");
182                            continue;
183                        }
184                        return Ok(Some(message));
185                    }
186                    // Continue reading packets until message complete
187                }
188                Some(Err(e)) => return Err(e),
189                None => {
190                    // Connection closed
191                    if self.assembler.has_partial() {
192                        return Err(CodecError::ConnectionClosed);
193                    }
194                    return Ok(None);
195                }
196            }
197        }
198    }
199
200    /// Read a single packet from the connection.
201    ///
202    /// This is lower-level than `read_message` and doesn't perform reassembly.
203    pub async fn read_packet(&mut self) -> Result<Option<Packet>, CodecError> {
204        match self.reader.next().await {
205            Some(result) => result.map(Some),
206            None => Ok(None),
207        }
208    }
209
210    /// Send a packet on the connection.
211    pub async fn send_packet(&mut self, packet: Packet) -> Result<(), CodecError> {
212        let mut writer = self.writer.lock().await;
213        writer.send(packet).await
214    }
215
216    /// Send a complete message, splitting into multiple packets if needed.
217    ///
218    /// If `reset_connection` is true, the RESETCONNECTION flag is set on the
219    /// first packet. This causes SQL Server to reset connection state (temp
220    /// tables, SET options, isolation level, etc.) before executing the command.
221    /// Per TDS spec, this flag MUST only be set on the first packet of a message.
222    pub async fn send_message(
223        &mut self,
224        packet_type: PacketType,
225        payload: Bytes,
226        max_packet_size: usize,
227    ) -> Result<(), CodecError> {
228        self.send_message_with_reset(packet_type, payload, max_packet_size, false)
229            .await
230    }
231
232    /// Send a complete message with optional connection reset.
233    ///
234    /// If `reset_connection` is true, the RESETCONNECTION flag is set on the
235    /// first packet. This causes SQL Server to reset connection state (temp
236    /// tables, SET options, isolation level, etc.) before executing the command.
237    /// Per TDS spec, this flag MUST only be set on the first packet of a message.
238    pub async fn send_message_with_reset(
239        &mut self,
240        packet_type: PacketType,
241        payload: Bytes,
242        max_packet_size: usize,
243        reset_connection: bool,
244    ) -> Result<(), CodecError> {
245        let max_payload = max_packet_size - PACKET_HEADER_SIZE;
246        // An empty payload must still produce one header-only EOM packet:
247        // `[]chunks()` yields zero chunks, which would send nothing at all and
248        // leave the caller waiting for a response that never comes (issue
249        // #165). A zero-length-payload message is valid TDS framing.
250        let chunks: Vec<&[u8]> = if payload.is_empty() {
251            vec![&[]]
252        } else {
253            payload.chunks(max_payload).collect()
254        };
255        let total_chunks = chunks.len();
256
257        let mut writer = self.writer.lock().await;
258
259        for (i, chunk) in chunks.into_iter().enumerate() {
260            let is_first = i == 0;
261            let is_last = i == total_chunks - 1;
262
263            // Build status flags
264            let mut status = if is_last {
265                PacketStatus::END_OF_MESSAGE
266            } else {
267                PacketStatus::NORMAL
268            };
269
270            // Per TDS spec, RESETCONNECTION must be on the first packet only
271            if is_first && reset_connection {
272                status |= PacketStatus::RESET_CONNECTION;
273            }
274
275            let header = PacketHeader::new(packet_type, status, 0);
276            let packet = Packet::new(header, BytesMut::from(chunk));
277
278            writer.send(packet).await?;
279        }
280
281        Ok(())
282    }
283
284    /// Flush the write buffer.
285    pub async fn flush(&mut self) -> Result<(), CodecError> {
286        let mut writer = self.writer.lock().await;
287        writer.flush().await
288    }
289
290    /// Drain messages after cancellation until DONE with ATTENTION is received.
291    ///
292    /// Returns [`CodecError::Cancelled`] once the acknowledgement is consumed;
293    /// the connection is then clean for the next request.
294    async fn drain_after_cancel(&mut self) -> Result<Option<Message>, CodecError> {
295        tracing::debug!("draining packets after cancellation");
296
297        // Clear any partial message
298        self.assembler.clear();
299
300        loop {
301            match self.reader.next().await {
302                Some(Ok(packet)) => {
303                    // Assemble complete messages so the acknowledgement check
304                    // runs on the message trailer — a per-packet check would
305                    // miss a DONE token straddling a packet boundary.
306                    if let Some(message) = self.assembler.push(packet) {
307                        if message.packet_type == PacketType::TabularResult
308                            && Self::payload_ends_with_attention_done(&message.payload)
309                        {
310                            tracing::debug!("received DONE with ATTENTION, cancellation complete");
311                            self.finish_cancel();
312                            return Err(CodecError::Cancelled);
313                        }
314                        tracing::debug!("discarding message from cancelled request");
315                    }
316                    // Continue draining
317                }
318                Some(Err(e)) => {
319                    self.cancelling
320                        .store(false, std::sync::atomic::Ordering::Release);
321                    return Err(e);
322                }
323                None => {
324                    // EOF while waiting for the acknowledgement: the
325                    // connection really is gone.
326                    self.cancelling
327                        .store(false, std::sync::atomic::Ordering::Release);
328                    return Err(CodecError::ConnectionClosed);
329                }
330            }
331        }
332    }
333
334    /// Mark the in-flight cancellation as acknowledged and wake waiters.
335    fn finish_cancel(&self) {
336        self.cancelling
337            .store(false, std::sync::atomic::Ordering::Release);
338        self.cancel_notify.notify_waiters();
339    }
340
341    /// Check whether a message payload terminates in a DONE token carrying
342    /// the ATTN status flag (the attention acknowledgement).
343    ///
344    /// Every tabular response message ends with a fixed 13-byte DONE-family
345    /// token (token(1) + status(2) + cur_cmd(2) + row_count(8)), and per
346    /// MS-TDS 2.2.7.6 the acknowledgement is a DONE (0xFD) with DONE_ATTN as
347    /// the final token of the cancelled stream. Anchoring the check to the
348    /// trailer means row bytes that happen to contain `0xFD, 0x20` (entirely
349    /// possible in binary/integer cell data arriving during the cancel
350    /// window) cannot be mistaken for the acknowledgement — an interior byte
351    /// scan was proven to clear the cancel flag early and leak the real
352    /// acknowledgement into the next request.
353    fn payload_ends_with_attention_done(payload: &[u8]) -> bool {
354        let Some(start) = payload.len().checked_sub(13) else {
355            return false;
356        };
357        // DONE token type = 0xFD; DONE_ATTN = 0x0020 in the LE status word.
358        payload[start] == 0xFD
359            && u16::from_le_bytes([payload[start + 1], payload[start + 2]]) & 0x0020 != 0
360    }
361
362    /// Get a reference to the read codec.
363    pub fn read_codec(&self) -> &TdsCodec {
364        self.reader.codec()
365    }
366
367    /// Get a mutable reference to the read codec.
368    pub fn read_codec_mut(&mut self) -> &mut TdsCodec {
369        self.reader.codec_mut()
370    }
371}
372
373impl<T> std::fmt::Debug for Connection<T>
374where
375    T: AsyncRead + AsyncWrite + std::fmt::Debug,
376{
377    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
378        f.debug_struct("Connection")
379            .field("cancelling", &self.is_cancelling())
380            .field("has_partial_message", &self.assembler.has_partial())
381            .finish_non_exhaustive()
382    }
383}
384
385/// Handle for cancelling queries on a connection.
386///
387/// This can be cloned and sent to other tasks to enable cancellation
388/// from a different async context.
389pub struct CancelHandle<T>
390where
391    T: AsyncRead + AsyncWrite,
392{
393    writer: Arc<Mutex<PacketWriter<WriteHalf<T>>>>,
394    notify: Arc<Notify>,
395    cancelling: Arc<std::sync::atomic::AtomicBool>,
396}
397
398impl<T> CancelHandle<T>
399where
400    T: AsyncRead + AsyncWrite + Unpin,
401{
402    /// Send an Attention packet to cancel the current query.
403    ///
404    /// This can be called from a different task while the main task
405    /// is blocked reading results.
406    pub async fn cancel(&self) -> Result<(), CodecError> {
407        // Mark cancellation in progress
408        self.cancelling
409            .store(true, std::sync::atomic::Ordering::Release);
410
411        tracing::debug!("sending Attention packet for query cancellation");
412
413        // Send the Attention packet
414        let mut writer = self.writer.lock().await;
415
416        // Create and send attention packet
417        let header = PacketHeader::new(
418            PacketType::Attention,
419            PacketStatus::END_OF_MESSAGE,
420            PACKET_HEADER_SIZE as u16,
421        );
422        let packet = Packet::new(header, BytesMut::new());
423
424        writer.send(packet).await?;
425        writer.flush().await?;
426
427        Ok(())
428    }
429
430    /// Wait for the cancellation to complete.
431    ///
432    /// This waits until the server acknowledges the cancellation
433    /// with a DONE token containing the ATTENTION flag.
434    pub async fn wait_cancelled(&self) {
435        if self.cancelling.load(std::sync::atomic::Ordering::Acquire) {
436            self.notify.notified().await;
437        }
438    }
439
440    /// Check if a cancellation is currently in progress.
441    #[must_use]
442    pub fn is_cancelling(&self) -> bool {
443        self.cancelling.load(std::sync::atomic::Ordering::Acquire)
444    }
445}
446
447impl<T> Clone for CancelHandle<T>
448where
449    T: AsyncRead + AsyncWrite,
450{
451    fn clone(&self) -> Self {
452        Self {
453            writer: Arc::clone(&self.writer),
454            notify: Arc::clone(&self.notify),
455            cancelling: Arc::clone(&self.cancelling),
456        }
457    }
458}
459
460impl<T> std::fmt::Debug for CancelHandle<T>
461where
462    T: AsyncRead + AsyncWrite + Unpin,
463{
464    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
465        f.debug_struct("CancelHandle")
466            .field("cancelling", &self.is_cancelling())
467            .finish_non_exhaustive()
468    }
469}
470
471#[cfg(test)]
472#[allow(clippy::unwrap_used, clippy::expect_used)]
473mod tests {
474    use super::*;
475
476    /// Issue #165: sending a message with an empty payload must still emit
477    /// exactly one header-only EOM packet. Previously `chunks()` on an empty
478    /// payload yielded zero chunks, so nothing was sent and the caller would
479    /// hang waiting for a response.
480    #[tokio::test]
481    async fn test_send_empty_payload_emits_one_eom_packet() {
482        use tokio::io::AsyncReadExt;
483
484        let (client_io, mut server_io) = tokio::io::duplex(4096);
485        let mut conn = Connection::new(client_io);
486
487        conn.send_message(PacketType::SqlBatch, Bytes::new(), 4096)
488            .await
489            .expect("empty message should send");
490
491        // Exactly one header-only packet (8 bytes) must arrive.
492        let mut header = [0u8; PACKET_HEADER_SIZE];
493        server_io
494            .read_exact(&mut header)
495            .await
496            .expect("one header-only packet must be sent");
497        assert_eq!(header[0], PacketType::SqlBatch as u8);
498        assert!(
499            PacketStatus::from_bits_truncate(header[1]).contains(PacketStatus::END_OF_MESSAGE),
500            "the single packet must be flagged END_OF_MESSAGE"
501        );
502        let length = u16::from_be_bytes([header[2], header[3]]);
503        assert_eq!(
504            length as usize, PACKET_HEADER_SIZE,
505            "length must be header-only (no payload)"
506        );
507
508        // And nothing more.
509        drop(conn);
510        let mut rest = Vec::new();
511        server_io.read_to_end(&mut rest).await.expect("read rest");
512        assert!(rest.is_empty(), "no second packet may follow");
513    }
514
515    /// Issue #165: across a multi-packet send, RESET_CONNECTION must be set on
516    /// the first packet only (per MS-TDS), and END_OF_MESSAGE on the last only.
517    #[tokio::test]
518    async fn test_reset_flag_on_first_packet_only_across_multi_packet_send() {
519        use tokio::io::AsyncReadExt;
520
521        let (client_io, mut server_io) = tokio::io::duplex(4096);
522        let mut conn = Connection::new(client_io);
523
524        // max_packet_size 16 → max_payload 8; a 12-byte payload spans 2 packets.
525        let payload = Bytes::from(vec![0xABu8; 12]);
526        conn.send_message_with_reset(PacketType::SqlBatch, payload, 16, true)
527            .await
528            .expect("multi-packet send should succeed");
529        drop(conn);
530
531        let mut all = Vec::new();
532        server_io.read_to_end(&mut all).await.expect("read packets");
533
534        // Packet 1: header(8) + payload(8) = 16 bytes.
535        let s0 = PacketStatus::from_bits_truncate(all[1]);
536        assert!(
537            s0.contains(PacketStatus::RESET_CONNECTION),
538            "first packet must carry RESET_CONNECTION"
539        );
540        assert!(
541            !s0.contains(PacketStatus::END_OF_MESSAGE),
542            "first packet of two must not be END_OF_MESSAGE"
543        );
544
545        // Packet 2 starts at offset 16: header(8) + payload(4).
546        let s1 = PacketStatus::from_bits_truncate(all[16 + 1]);
547        assert!(
548            !s1.contains(PacketStatus::RESET_CONNECTION),
549            "RESET_CONNECTION must not repeat on later packets"
550        );
551        assert!(
552            s1.contains(PacketStatus::END_OF_MESSAGE),
553            "last packet must be END_OF_MESSAGE"
554        );
555        assert_eq!(all.len(), 16 + 8 + 4, "exactly two packets must be sent");
556    }
557
558    #[test]
559    fn test_attention_packet_header() {
560        // Verify attention packet header construction
561        let header = PacketHeader::new(
562            PacketType::Attention,
563            PacketStatus::END_OF_MESSAGE,
564            PACKET_HEADER_SIZE as u16,
565        );
566
567        assert_eq!(header.packet_type, PacketType::Attention);
568        assert!(header.status.contains(PacketStatus::END_OF_MESSAGE));
569        assert_eq!(header.length, PACKET_HEADER_SIZE as u16);
570    }
571
572    #[test]
573    fn test_check_attention_done() {
574        // Test DONE token with ATTN flag detection
575        // DONE token: 0xFD + status(2 bytes) + cur_cmd(2 bytes) + row_count(8 bytes)
576        // DONE_ATTN flag is 0x0020
577
578        // Create a mock packet with DONE token and ATTN flag
579        let header = PacketHeader::new(PacketType::TabularResult, PacketStatus::END_OF_MESSAGE, 0);
580
581        // DONE token with ATTN flag set (status = 0x0020)
582        let payload_with_attn = BytesMut::from(
583            &[
584                0xFD, 0x20, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
585            ][..],
586        );
587        let packet_with_attn = Packet::new(header, payload_with_attn);
588
589        // DONE token without ATTN flag (status = 0x0000)
590        let payload_no_attn = BytesMut::from(
591            &[
592                0xFD, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
593            ][..],
594        );
595        let packet_no_attn = Packet::new(header, payload_no_attn);
596
597        assert!(
598            Connection::<tokio::io::DuplexStream>::payload_ends_with_attention_done(
599                &packet_with_attn.payload
600            )
601        );
602        assert!(
603            !Connection::<tokio::io::DuplexStream>::payload_ends_with_attention_done(
604                &packet_no_attn.payload
605            )
606        );
607
608        // Interior 0xFD,0x20 bytes (e.g. row data) must not register: only
609        // the trailing token position counts.
610        let mut interior = vec![0xD1, 0x08, 0xFD, 0x20, 0xAA, 0xBB];
611        interior.extend_from_slice(&[
612            0xFD, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
613        ]);
614        assert!(
615            !Connection::<tokio::io::DuplexStream>::payload_ends_with_attention_done(&interior)
616        );
617    }
618
619    /// Build a raw single-packet TabularResult TDS message around `payload`.
620    fn raw_message(payload: &[u8]) -> Vec<u8> {
621        let mut v = vec![0x04, 0x01]; // TabularResult, END_OF_MESSAGE
622        v.extend_from_slice(&((payload.len() + 8) as u16).to_be_bytes());
623        v.extend_from_slice(&[0, 0, 1, 0]); // spid, packet id, window
624        v.extend_from_slice(payload);
625        v
626    }
627
628    /// DONE token bytes with the given status.
629    fn done_token(status: u16) -> [u8; 13] {
630        let s = status.to_le_bytes();
631        [
632            0xFD, s[0], s[1], 0xC1, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
633        ]
634    }
635
636    /// Regression test for the cancel-mid-read race.
637    ///
638    /// When `cancel()` fires while `read_message()` is already parked on the
639    /// socket, the cancelled request's response stream (here: DONE(ERROR)
640    /// followed by the DONE(ATTN) acknowledgement) arrives through the
641    /// *normal* read path. It must be discarded — not surfaced as a query
642    /// response — and the read must end in `CodecError::Cancelled` with the
643    /// `cancelling` flag cleared, so the next request's response is delivered
644    /// intact. Before the fix, the first DONE was returned as the response,
645    /// the flag stayed latched, and a later drain ate the next response.
646    #[tokio::test]
647    async fn test_cancel_mid_read_discards_cancelled_stream() {
648        use std::task::{Context, Poll};
649        use tokio::io::AsyncWriteExt;
650
651        let (client_io, mut server_io) = tokio::io::duplex(4096);
652        let mut conn = Connection::new(client_io);
653        let cancel = conn.cancel_handle();
654
655        // Park a read with nothing to deliver yet (mimics waiting on a slow
656        // query). A noop waker is fine: the future is re-polled via `.await`
657        // below after data is written.
658        let mut read_fut = Box::pin(conn.read_message());
659        let waker = std::task::Waker::noop();
660        let mut cx = Context::from_waker(waker);
661        assert!(matches!(read_fut.as_mut().poll(&mut cx), Poll::Pending));
662
663        // Cancel while the read is parked, then deliver the cancelled
664        // request's stream plus the next request's response.
665        cancel.cancel().await.expect("send attention");
666        server_io
667            .write_all(&raw_message(&done_token(0x0002))) // DONE_ERROR
668            .await
669            .unwrap();
670        server_io
671            .write_all(&raw_message(&done_token(0x0020))) // DONE_ATTN ack
672            .await
673            .unwrap();
674        server_io
675            .write_all(&raw_message(&done_token(0x0010))) // next response
676            .await
677            .unwrap();
678
679        let result = read_fut.await;
680        assert!(
681            matches!(result, Err(CodecError::Cancelled)),
682            "parked read must consume the cancelled stream and report \
683             Cancelled, got {result:?}"
684        );
685        assert!(!conn.is_cancelling(), "cancel flag must be cleared");
686
687        // The next request's response must come through untouched.
688        let message = conn
689            .read_message()
690            .await
691            .expect("next read")
692            .expect("next message");
693        assert_eq!(message.payload[0], 0xFD);
694        assert_eq!(
695            u16::from_le_bytes([message.payload[1], message.payload[2]]),
696            0x0010,
697            "next response must not be eaten by a stale drain"
698        );
699    }
700
701    /// Cancellation requested before the read starts takes the drain path and
702    /// must behave identically to the mid-read race.
703    #[tokio::test]
704    async fn test_cancel_before_read_drains_to_attention_ack() {
705        use tokio::io::AsyncWriteExt;
706
707        let (client_io, mut server_io) = tokio::io::duplex(4096);
708        let mut conn = Connection::new(client_io);
709        let cancel = conn.cancel_handle();
710
711        cancel.cancel().await.expect("send attention");
712        server_io
713            .write_all(&raw_message(&done_token(0x0022))) // ERROR | ATTN ack
714            .await
715            .unwrap();
716        server_io
717            .write_all(&raw_message(&done_token(0x0010))) // next response
718            .await
719            .unwrap();
720
721        let result = conn.read_message().await;
722        assert!(matches!(result, Err(CodecError::Cancelled)));
723        assert!(!conn.is_cancelling());
724
725        let message = conn
726            .read_message()
727            .await
728            .expect("next read")
729            .expect("next message");
730        assert_eq!(
731            u16::from_le_bytes([message.payload[1], message.payload[2]]),
732            0x0010
733        );
734    }
735
736    /// PR #143 review, Blocker 1: row bytes that happen to contain
737    /// `0xFD, 0x20` must NOT be mistaken for the DONE_ATTN acknowledgement.
738    ///
739    /// During the cancel window the cancelled request's *data* (rows already
740    /// in flight) can arrive before the real acknowledgement. A byte-scan
741    /// for any interior 0xFD with bit 5 set false-positives on such data,
742    /// clears the cancel flag early, and the genuine ack then poisons the
743    /// next request — the exact failure the cancellation fix claims to
744    /// eliminate.
745    #[tokio::test]
746    async fn test_cancel_race_row_bytes_do_not_fake_the_attention_ack() {
747        use std::task::{Context, Poll};
748        use tokio::io::AsyncWriteExt;
749
750        let (client_io, mut server_io) = tokio::io::duplex(4096);
751        let mut conn = Connection::new(client_io);
752        let cancel = conn.cancel_handle();
753
754        // Park a read, then cancel while it waits (the realistic ordering).
755        let mut read_fut = Box::pin(conn.read_message());
756        let waker = std::task::Waker::noop();
757        let mut cx = Context::from_waker(waker);
758        assert!(matches!(read_fut.as_mut().poll(&mut cx), Poll::Pending));
759        cancel.cancel().await.expect("send attention");
760
761        // Message 1: the cancelled request's data — row-ish bytes whose
762        // *interior* contains 0xFD followed by a byte with bit 5 set (e.g. a
763        // BIGINT cell value), terminated by a DONE with MORE and no ATTN.
764        let mut row_data = vec![0xD1, 0x08]; // ROW token, length-ish prefix
765        row_data.extend_from_slice(&[0xFD, 0x20, 0xAA, 0xBB, 0xCC, 0xDD, 0xEE, 0xFF]);
766        row_data.extend_from_slice(&done_token(0x0001)); // DONE_MORE, no ATTN
767        server_io.write_all(&raw_message(&row_data)).await.unwrap();
768
769        // Message 2: the genuine acknowledgement.
770        server_io
771            .write_all(&raw_message(&done_token(0x0020)))
772            .await
773            .unwrap();
774
775        // Message 3: the next request's response.
776        server_io
777            .write_all(&raw_message(&done_token(0x0010)))
778            .await
779            .unwrap();
780
781        let result = read_fut.await;
782        assert!(
783            matches!(result, Err(CodecError::Cancelled)),
784            "cancelled read must end in Cancelled, got {result:?}"
785        );
786        assert!(!conn.is_cancelling());
787
788        // The next read must deliver message 3 — not the stale ack from
789        // message 2.
790        let message = conn
791            .read_message()
792            .await
793            .expect("next read")
794            .expect("next message");
795        let status = u16::from_le_bytes([message.payload[1], message.payload[2]]);
796        assert_eq!(
797            status, 0x0010,
798            "next request's response must come through intact; 0x0020 means \
799             the interior row bytes were mistaken for the ack and the real \
800             ack leaked into the next request"
801        );
802    }
803
/// Build a raw non-EOM (continuation) TabularResult packet.
///
/// The 8-byte header follows the TDS packet header layout:
/// - type `0x04` (TabularResult);
/// - status `0x00` (EOM clear, so more packets follow);
/// - a big-endian total length that includes the header itself;
/// - SPID `0`, packet ID `1` and window `0`.
///
/// # Panics
///
/// Panics if `payload` is too long for the header's 16-bit length field
/// (more than `u16::MAX - 8` bytes). An `as u16` cast would silently wrap
/// instead and yield a packet whose header disagrees with its body.
fn raw_packet_non_eom(payload: &[u8]) -> Vec<u8> {
    const HEADER_LEN: usize = 8;
    let total_len = u16::try_from(payload.len() + HEADER_LEN)
        .expect("test packet too large for the 16-bit TDS length field");

    let mut packet = Vec::with_capacity(HEADER_LEN + payload.len());
    packet.extend_from_slice(&[0x04, 0x00]); // TabularResult, more packets follow
    packet.extend_from_slice(&total_len.to_be_bytes());
    packet.extend_from_slice(&[0, 0, 1, 0]); // SPID, packet ID, window
    packet.extend_from_slice(payload);
    packet
}
812
813    /// Issue #167: the assembled response can be capped so a huge SELECT is
814    /// a loud error instead of unbounded client memory. The cap must fire
815    /// mid-assembly, before the message completes.
816    #[tokio::test]
817    async fn test_max_message_size_cap_fires_mid_assembly() {
818        use tokio::io::AsyncWriteExt;
819        let (client, mut server) = tokio::io::duplex(1 << 16);
820        let mut conn = Connection::new(client);
821        conn.set_max_message_size(1000);
822
823        let chunk = vec![0xAA; 600];
824        let mut stream = raw_packet_non_eom(&chunk);
825        stream.extend_from_slice(&raw_packet_non_eom(&chunk)); // buffer hits 1200
826        server.write_all(&stream).await.unwrap();
827
828        let result = conn.read_message().await;
829        let Err(CodecError::MessageTooLarge { size, limit }) = result else {
830            unreachable!("expected MessageTooLarge, got {result:?}");
831        };
832        assert_eq!(limit, 1000);
833        assert!(size > 1000);
834    }
835
836    /// The cap must also catch a message whose final (EOM) packet jumps it
837    /// past the limit.
838    #[tokio::test]
839    async fn test_max_message_size_cap_fires_on_completing_packet() {
840        use tokio::io::AsyncWriteExt;
841        let (client, mut server) = tokio::io::duplex(1 << 16);
842        let mut conn = Connection::new(client);
843        conn.set_max_message_size(1000);
844
845        let chunk = vec![0xAA; 600];
846        let mut stream = raw_packet_non_eom(&chunk);
847        stream.extend_from_slice(&raw_message(&chunk)); // completes at 1200
848        server.write_all(&stream).await.unwrap();
849
850        assert!(matches!(
851            conn.read_message().await,
852            Err(CodecError::MessageTooLarge { .. })
853        ));
854    }
855
856    /// Zero (the default) means unlimited — existing behavior unchanged.
857    #[tokio::test]
858    async fn test_max_message_size_zero_is_unlimited() {
859        use tokio::io::AsyncWriteExt;
860        let (client, mut server) = tokio::io::duplex(1 << 16);
861        let mut conn = Connection::new(client);
862
863        let payload = vec![0xAA; 5000];
864        server.write_all(&raw_message(&payload)).await.unwrap();
865        let msg = conn.read_message().await.unwrap().unwrap();
866        assert_eq!(msg.payload.len(), 5000);
867    }
868}