1use std::sync::Arc;
7
8use bytes::{Bytes, BytesMut};
9use futures_util::{SinkExt, StreamExt};
10use tds_protocol::packet::{PACKET_HEADER_SIZE, PacketHeader, PacketStatus, PacketType};
11use tokio::io::{AsyncRead, AsyncWrite, ReadHalf, WriteHalf};
12use tokio::sync::{Mutex, Notify};
13
14use crate::error::CodecError;
15use crate::framed::{PacketReader, PacketWriter};
16use crate::message::{Message, MessageAssembler};
17use crate::packet_codec::{Packet, TdsCodec};
18
/// A packet-oriented connection over an async transport `T`, split into a
/// read half and a write half.
///
/// The read half is owned directly; the write half sits behind an
/// `Arc<Mutex<..>>` so that a [`CancelHandle`] can share it.
pub struct Connection<T>
where
    T: AsyncRead + AsyncWrite,
{
    /// Packet stream over the read half of the transport.
    reader: PacketReader<ReadHalf<T>>,
    /// Packet sink over the write half; shared with every [`CancelHandle`].
    writer: Arc<Mutex<PacketWriter<WriteHalf<T>>>>,
    /// Accumulates incoming packets until a complete [`Message`] is available.
    assembler: MessageAssembler,
    /// Signalled with `notify_waiters` once a cancellation has finished.
    cancel_notify: Arc<Notify>,
    /// Set by [`CancelHandle::cancel`]; cleared when the cancellation ends.
    cancelling: Arc<std::sync::atomic::AtomicBool>,
    /// Upper bound on an assembled message's size in bytes; `0` disables the check.
    max_message_size: usize,
}
63
64impl<T> Connection<T>
65where
66 T: AsyncRead + AsyncWrite,
67{
68 pub fn new(transport: T) -> Self {
72 let (read_half, write_half) = tokio::io::split(transport);
73
74 Self {
75 reader: PacketReader::new(read_half),
76 writer: Arc::new(Mutex::new(PacketWriter::new(write_half))),
77 assembler: MessageAssembler::new(),
78 cancel_notify: Arc::new(Notify::new()),
79 cancelling: Arc::new(std::sync::atomic::AtomicBool::new(false)),
80 max_message_size: 0,
81 }
82 }
83
84 pub fn with_codecs(transport: T, read_codec: TdsCodec, write_codec: TdsCodec) -> Self {
86 let (read_half, write_half) = tokio::io::split(transport);
87
88 Self {
89 reader: PacketReader::with_codec(read_half, read_codec),
90 writer: Arc::new(Mutex::new(PacketWriter::with_codec(
91 write_half,
92 write_codec,
93 ))),
94 assembler: MessageAssembler::new(),
95 cancel_notify: Arc::new(Notify::new()),
96 cancelling: Arc::new(std::sync::atomic::AtomicBool::new(false)),
97 max_message_size: 0,
98 }
99 }
100
    /// Sets the maximum size, in bytes, that a message may reach while being
    /// assembled by `read_message`.
    ///
    /// A `limit` of `0` disables the check.
    pub fn set_max_message_size(&mut self, limit: usize) {
        self.max_message_size = limit;
    }
111
112 #[must_use]
116 pub fn cancel_handle(&self) -> CancelHandle<T> {
117 CancelHandle {
118 writer: Arc::clone(&self.writer),
119 notify: Arc::clone(&self.cancel_notify),
120 cancelling: Arc::clone(&self.cancelling),
121 }
122 }
123
124 #[must_use]
126 pub fn is_cancelling(&self) -> bool {
127 self.cancelling.load(std::sync::atomic::Ordering::Acquire)
128 }
129
130 pub async fn read_message(&mut self) -> Result<Option<Message>, CodecError> {
138 loop {
139 if self.is_cancelling() {
141 return self.drain_after_cancel().await;
143 }
144
145 match self.reader.next().await {
146 Some(Ok(packet)) => {
147 let result = self.assembler.push(packet);
148 if self.max_message_size != 0 {
154 let size = result
155 .as_ref()
156 .map_or_else(|| self.assembler.buffer_len(), |m| m.payload.len());
157 if size > self.max_message_size {
158 self.assembler.clear();
159 return Err(CodecError::MessageTooLarge {
160 size,
161 limit: self.max_message_size,
162 });
163 }
164 }
165 if let Some(message) = result {
166 if self.is_cancelling() {
174 if Self::payload_ends_with_attention_done(&message.payload) {
175 tracing::debug!(
176 "received DONE with ATTENTION, cancellation complete"
177 );
178 self.finish_cancel();
179 return Err(CodecError::Cancelled);
180 }
181 tracing::debug!("discarding message from cancelled request");
182 continue;
183 }
184 return Ok(Some(message));
185 }
186 }
188 Some(Err(e)) => return Err(e),
189 None => {
190 if self.assembler.has_partial() {
192 return Err(CodecError::ConnectionClosed);
193 }
194 return Ok(None);
195 }
196 }
197 }
198 }
199
200 pub async fn read_packet(&mut self) -> Result<Option<Packet>, CodecError> {
204 match self.reader.next().await {
205 Some(result) => result.map(Some),
206 None => Ok(None),
207 }
208 }
209
210 pub async fn send_packet(&mut self, packet: Packet) -> Result<(), CodecError> {
212 let mut writer = self.writer.lock().await;
213 writer.send(packet).await
214 }
215
216 pub async fn send_message(
223 &mut self,
224 packet_type: PacketType,
225 payload: Bytes,
226 max_packet_size: usize,
227 ) -> Result<(), CodecError> {
228 self.send_message_with_reset(packet_type, payload, max_packet_size, false)
229 .await
230 }
231
232 pub async fn send_message_with_reset(
239 &mut self,
240 packet_type: PacketType,
241 payload: Bytes,
242 max_packet_size: usize,
243 reset_connection: bool,
244 ) -> Result<(), CodecError> {
245 let max_payload = max_packet_size - PACKET_HEADER_SIZE;
246 let chunks: Vec<&[u8]> = if payload.is_empty() {
251 vec![&[]]
252 } else {
253 payload.chunks(max_payload).collect()
254 };
255 let total_chunks = chunks.len();
256
257 let mut writer = self.writer.lock().await;
258
259 for (i, chunk) in chunks.into_iter().enumerate() {
260 let is_first = i == 0;
261 let is_last = i == total_chunks - 1;
262
263 let mut status = if is_last {
265 PacketStatus::END_OF_MESSAGE
266 } else {
267 PacketStatus::NORMAL
268 };
269
270 if is_first && reset_connection {
272 status |= PacketStatus::RESET_CONNECTION;
273 }
274
275 let header = PacketHeader::new(packet_type, status, 0);
276 let packet = Packet::new(header, BytesMut::from(chunk));
277
278 writer.send(packet).await?;
279 }
280
281 Ok(())
282 }
283
284 pub async fn flush(&mut self) -> Result<(), CodecError> {
286 let mut writer = self.writer.lock().await;
287 writer.flush().await
288 }
289
290 async fn drain_after_cancel(&mut self) -> Result<Option<Message>, CodecError> {
295 tracing::debug!("draining packets after cancellation");
296
297 self.assembler.clear();
299
300 loop {
301 match self.reader.next().await {
302 Some(Ok(packet)) => {
303 if let Some(message) = self.assembler.push(packet) {
307 if message.packet_type == PacketType::TabularResult
308 && Self::payload_ends_with_attention_done(&message.payload)
309 {
310 tracing::debug!("received DONE with ATTENTION, cancellation complete");
311 self.finish_cancel();
312 return Err(CodecError::Cancelled);
313 }
314 tracing::debug!("discarding message from cancelled request");
315 }
316 }
318 Some(Err(e)) => {
319 self.cancelling
320 .store(false, std::sync::atomic::Ordering::Release);
321 return Err(e);
322 }
323 None => {
324 self.cancelling
327 .store(false, std::sync::atomic::Ordering::Release);
328 return Err(CodecError::ConnectionClosed);
329 }
330 }
331 }
332 }
333
334 fn finish_cancel(&self) {
336 self.cancelling
337 .store(false, std::sync::atomic::Ordering::Release);
338 self.cancel_notify.notify_waiters();
339 }
340
341 fn payload_ends_with_attention_done(payload: &[u8]) -> bool {
354 let Some(start) = payload.len().checked_sub(13) else {
355 return false;
356 };
357 payload[start] == 0xFD
359 && u16::from_le_bytes([payload[start + 1], payload[start + 2]]) & 0x0020 != 0
360 }
361
    /// Returns a shared reference to the codec used by the packet reader.
    pub fn read_codec(&self) -> &TdsCodec {
        self.reader.codec()
    }
366
    /// Returns a mutable reference to the codec used by the packet reader.
    pub fn read_codec_mut(&mut self) -> &mut TdsCodec {
        self.reader.codec_mut()
    }
371}
372
373impl<T> std::fmt::Debug for Connection<T>
374where
375 T: AsyncRead + AsyncWrite + std::fmt::Debug,
376{
377 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
378 f.debug_struct("Connection")
379 .field("cancelling", &self.is_cancelling())
380 .field("has_partial_message", &self.assembler.has_partial())
381 .finish_non_exhaustive()
382 }
383}
384
/// A handle that can request cancellation of the request in flight on a
/// [`Connection`].
///
/// It shares the connection's packet writer, notifier and cancel flag.
pub struct CancelHandle<T>
where
    T: AsyncRead + AsyncWrite,
{
    /// The connection's packet writer, used to send the Attention packet.
    writer: Arc<Mutex<PacketWriter<WriteHalf<T>>>>,
    /// Notifier awaited by `wait_cancelled`.
    notify: Arc<Notify>,
    /// Cancel flag shared with the connection.
    cancelling: Arc<std::sync::atomic::AtomicBool>,
}
397
398impl<T> CancelHandle<T>
399where
400 T: AsyncRead + AsyncWrite + Unpin,
401{
402 pub async fn cancel(&self) -> Result<(), CodecError> {
407 self.cancelling
409 .store(true, std::sync::atomic::Ordering::Release);
410
411 tracing::debug!("sending Attention packet for query cancellation");
412
413 let mut writer = self.writer.lock().await;
415
416 let header = PacketHeader::new(
418 PacketType::Attention,
419 PacketStatus::END_OF_MESSAGE,
420 PACKET_HEADER_SIZE as u16,
421 );
422 let packet = Packet::new(header, BytesMut::new());
423
424 writer.send(packet).await?;
425 writer.flush().await?;
426
427 Ok(())
428 }
429
430 pub async fn wait_cancelled(&self) {
435 if self.cancelling.load(std::sync::atomic::Ordering::Acquire) {
436 self.notify.notified().await;
437 }
438 }
439
440 #[must_use]
442 pub fn is_cancelling(&self) -> bool {
443 self.cancelling.load(std::sync::atomic::Ordering::Acquire)
444 }
445}
446
447impl<T> Clone for CancelHandle<T>
448where
449 T: AsyncRead + AsyncWrite,
450{
451 fn clone(&self) -> Self {
452 Self {
453 writer: Arc::clone(&self.writer),
454 notify: Arc::clone(&self.notify),
455 cancelling: Arc::clone(&self.cancelling),
456 }
457 }
458}
459
460impl<T> std::fmt::Debug for CancelHandle<T>
461where
462 T: AsyncRead + AsyncWrite + Unpin,
463{
464 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
465 f.debug_struct("CancelHandle")
466 .field("cancelling", &self.is_cancelling())
467 .finish_non_exhaustive()
468 }
469}
470
471#[cfg(test)]
472#[allow(clippy::unwrap_used, clippy::expect_used)]
473mod tests {
474 use super::*;
475
476 #[tokio::test]
481 async fn test_send_empty_payload_emits_one_eom_packet() {
482 use tokio::io::AsyncReadExt;
483
484 let (client_io, mut server_io) = tokio::io::duplex(4096);
485 let mut conn = Connection::new(client_io);
486
487 conn.send_message(PacketType::SqlBatch, Bytes::new(), 4096)
488 .await
489 .expect("empty message should send");
490
491 let mut header = [0u8; PACKET_HEADER_SIZE];
493 server_io
494 .read_exact(&mut header)
495 .await
496 .expect("one header-only packet must be sent");
497 assert_eq!(header[0], PacketType::SqlBatch as u8);
498 assert!(
499 PacketStatus::from_bits_truncate(header[1]).contains(PacketStatus::END_OF_MESSAGE),
500 "the single packet must be flagged END_OF_MESSAGE"
501 );
502 let length = u16::from_be_bytes([header[2], header[3]]);
503 assert_eq!(
504 length as usize, PACKET_HEADER_SIZE,
505 "length must be header-only (no payload)"
506 );
507
508 drop(conn);
510 let mut rest = Vec::new();
511 server_io.read_to_end(&mut rest).await.expect("read rest");
512 assert!(rest.is_empty(), "no second packet may follow");
513 }
514
515 #[tokio::test]
518 async fn test_reset_flag_on_first_packet_only_across_multi_packet_send() {
519 use tokio::io::AsyncReadExt;
520
521 let (client_io, mut server_io) = tokio::io::duplex(4096);
522 let mut conn = Connection::new(client_io);
523
524 let payload = Bytes::from(vec![0xABu8; 12]);
526 conn.send_message_with_reset(PacketType::SqlBatch, payload, 16, true)
527 .await
528 .expect("multi-packet send should succeed");
529 drop(conn);
530
531 let mut all = Vec::new();
532 server_io.read_to_end(&mut all).await.expect("read packets");
533
534 let s0 = PacketStatus::from_bits_truncate(all[1]);
536 assert!(
537 s0.contains(PacketStatus::RESET_CONNECTION),
538 "first packet must carry RESET_CONNECTION"
539 );
540 assert!(
541 !s0.contains(PacketStatus::END_OF_MESSAGE),
542 "first packet of two must not be END_OF_MESSAGE"
543 );
544
545 let s1 = PacketStatus::from_bits_truncate(all[16 + 1]);
547 assert!(
548 !s1.contains(PacketStatus::RESET_CONNECTION),
549 "RESET_CONNECTION must not repeat on later packets"
550 );
551 assert!(
552 s1.contains(PacketStatus::END_OF_MESSAGE),
553 "last packet must be END_OF_MESSAGE"
554 );
555 assert_eq!(all.len(), 16 + 8 + 4, "exactly two packets must be sent");
556 }
557
558 #[test]
559 fn test_attention_packet_header() {
560 let header = PacketHeader::new(
562 PacketType::Attention,
563 PacketStatus::END_OF_MESSAGE,
564 PACKET_HEADER_SIZE as u16,
565 );
566
567 assert_eq!(header.packet_type, PacketType::Attention);
568 assert!(header.status.contains(PacketStatus::END_OF_MESSAGE));
569 assert_eq!(header.length, PACKET_HEADER_SIZE as u16);
570 }
571
572 #[test]
573 fn test_check_attention_done() {
574 let header = PacketHeader::new(PacketType::TabularResult, PacketStatus::END_OF_MESSAGE, 0);
580
581 let payload_with_attn = BytesMut::from(
583 &[
584 0xFD, 0x20, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
585 ][..],
586 );
587 let packet_with_attn = Packet::new(header, payload_with_attn);
588
589 let payload_no_attn = BytesMut::from(
591 &[
592 0xFD, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
593 ][..],
594 );
595 let packet_no_attn = Packet::new(header, payload_no_attn);
596
597 assert!(
598 Connection::<tokio::io::DuplexStream>::payload_ends_with_attention_done(
599 &packet_with_attn.payload
600 )
601 );
602 assert!(
603 !Connection::<tokio::io::DuplexStream>::payload_ends_with_attention_done(
604 &packet_no_attn.payload
605 )
606 );
607
608 let mut interior = vec![0xD1, 0x08, 0xFD, 0x20, 0xAA, 0xBB];
611 interior.extend_from_slice(&[
612 0xFD, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
613 ]);
614 assert!(
615 !Connection::<tokio::io::DuplexStream>::payload_ends_with_attention_done(&interior)
616 );
617 }
618
619 fn raw_message(payload: &[u8]) -> Vec<u8> {
621 let mut v = vec![0x04, 0x01]; v.extend_from_slice(&((payload.len() + 8) as u16).to_be_bytes());
623 v.extend_from_slice(&[0, 0, 1, 0]); v.extend_from_slice(payload);
625 v
626 }
627
628 fn done_token(status: u16) -> [u8; 13] {
630 let s = status.to_le_bytes();
631 [
632 0xFD, s[0], s[1], 0xC1, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
633 ]
634 }
635
    /// A read that is already parked when the cancel arrives must discard the
    /// cancelled request's messages, report `Cancelled` on the attention
    /// acknowledgement, and leave the following response untouched.
    #[tokio::test]
    async fn test_cancel_mid_read_discards_cancelled_stream() {
        use std::task::{Context, Poll};
        use tokio::io::AsyncWriteExt;

        let (client_io, mut server_io) = tokio::io::duplex(4096);
        let mut conn = Connection::new(client_io);
        let cancel = conn.cancel_handle();

        // Poll once by hand so the read is parked on the reader *before* the
        // cancel flag is raised.
        let mut read_fut = Box::pin(conn.read_message());
        let waker = std::task::Waker::noop();
        let mut cx = Context::from_waker(waker);
        assert!(matches!(read_fut.as_mut().poll(&mut cx), Poll::Pending));

        cancel.cancel().await.expect("send attention");
        // A message from the cancelled request (no attention bit) ...
        server_io
            .write_all(&raw_message(&done_token(0x0002)))
            .await
            .unwrap();
        // ... then the attention acknowledgement (bit 0x0020 set) ...
        server_io
            .write_all(&raw_message(&done_token(0x0020)))
            .await
            .unwrap();
        // ... then the response belonging to the next request.
        server_io
            .write_all(&raw_message(&done_token(0x0010)))
            .await
            .unwrap();

        let result = read_fut.await;
        assert!(
            matches!(result, Err(CodecError::Cancelled)),
            "parked read must consume the cancelled stream and report \
             Cancelled, got {result:?}"
        );
        assert!(!conn.is_cancelling(), "cancel flag must be cleared");

        // The follow-up read must see the third message, not a stale drain.
        let message = conn
            .read_message()
            .await
            .expect("next read")
            .expect("next message");
        assert_eq!(message.payload[0], 0xFD);
        assert_eq!(
            u16::from_le_bytes([message.payload[1], message.payload[2]]),
            0x0010,
            "next response must not be eaten by a stale drain"
        );
    }
700
701 #[tokio::test]
704 async fn test_cancel_before_read_drains_to_attention_ack() {
705 use tokio::io::AsyncWriteExt;
706
707 let (client_io, mut server_io) = tokio::io::duplex(4096);
708 let mut conn = Connection::new(client_io);
709 let cancel = conn.cancel_handle();
710
711 cancel.cancel().await.expect("send attention");
712 server_io
713 .write_all(&raw_message(&done_token(0x0022))) .await
715 .unwrap();
716 server_io
717 .write_all(&raw_message(&done_token(0x0010))) .await
719 .unwrap();
720
721 let result = conn.read_message().await;
722 assert!(matches!(result, Err(CodecError::Cancelled)));
723 assert!(!conn.is_cancelling());
724
725 let message = conn
726 .read_message()
727 .await
728 .expect("next read")
729 .expect("next message");
730 assert_eq!(
731 u16::from_le_bytes([message.payload[1], message.payload[2]]),
732 0x0010
733 );
734 }
735
    /// Bytes inside a message that merely look like an attention DONE must not
    /// be taken for the acknowledgement; only a trailing one counts.
    #[tokio::test]
    async fn test_cancel_race_row_bytes_do_not_fake_the_attention_ack() {
        use std::task::{Context, Poll};
        use tokio::io::AsyncWriteExt;

        let (client_io, mut server_io) = tokio::io::duplex(4096);
        let mut conn = Connection::new(client_io);
        let cancel = conn.cancel_handle();

        // Park the read first so the cancel lands while it is waiting.
        let mut read_fut = Box::pin(conn.read_message());
        let waker = std::task::Waker::noop();
        let mut cx = Context::from_waker(waker);
        assert!(matches!(read_fut.as_mut().poll(&mut cx), Poll::Pending));
        cancel.cancel().await.expect("send attention");

        // A message whose interior contains `0xFD, 0x20` but whose trailing
        // DONE token does not carry the attention bit.
        let mut row_data = vec![0xD1, 0x08];
        row_data.extend_from_slice(&[0xFD, 0x20, 0xAA, 0xBB, 0xCC, 0xDD, 0xEE, 0xFF]);
        row_data.extend_from_slice(&done_token(0x0001));
        server_io.write_all(&raw_message(&row_data)).await.unwrap();

        // The real attention acknowledgement.
        server_io
            .write_all(&raw_message(&done_token(0x0020)))
            .await
            .unwrap();

        // The response to the next request.
        server_io
            .write_all(&raw_message(&done_token(0x0010)))
            .await
            .unwrap();

        let result = read_fut.await;
        assert!(
            matches!(result, Err(CodecError::Cancelled)),
            "cancelled read must end in Cancelled, got {result:?}"
        );
        assert!(!conn.is_cancelling());

        let message = conn
            .read_message()
            .await
            .expect("next read")
            .expect("next message");
        let status = u16::from_le_bytes([message.payload[1], message.payload[2]]);
        assert_eq!(
            status, 0x0010,
            "next request's response must come through intact; 0x0020 means \
             the interior row bytes were mistaken for the ack and the real \
             ack leaked into the next request"
        );
    }
803
804 fn raw_packet_non_eom(payload: &[u8]) -> Vec<u8> {
806 let mut v = vec![0x04, 0x00]; v.extend_from_slice(&((payload.len() + 8) as u16).to_be_bytes());
808 v.extend_from_slice(&[0, 0, 1, 0]);
809 v.extend_from_slice(payload);
810 v
811 }
812
813 #[tokio::test]
817 async fn test_max_message_size_cap_fires_mid_assembly() {
818 use tokio::io::AsyncWriteExt;
819 let (client, mut server) = tokio::io::duplex(1 << 16);
820 let mut conn = Connection::new(client);
821 conn.set_max_message_size(1000);
822
823 let chunk = vec![0xAA; 600];
824 let mut stream = raw_packet_non_eom(&chunk);
825 stream.extend_from_slice(&raw_packet_non_eom(&chunk)); server.write_all(&stream).await.unwrap();
827
828 let result = conn.read_message().await;
829 let Err(CodecError::MessageTooLarge { size, limit }) = result else {
830 unreachable!("expected MessageTooLarge, got {result:?}");
831 };
832 assert_eq!(limit, 1000);
833 assert!(size > 1000);
834 }
835
836 #[tokio::test]
839 async fn test_max_message_size_cap_fires_on_completing_packet() {
840 use tokio::io::AsyncWriteExt;
841 let (client, mut server) = tokio::io::duplex(1 << 16);
842 let mut conn = Connection::new(client);
843 conn.set_max_message_size(1000);
844
845 let chunk = vec![0xAA; 600];
846 let mut stream = raw_packet_non_eom(&chunk);
847 stream.extend_from_slice(&raw_message(&chunk)); server.write_all(&stream).await.unwrap();
849
850 assert!(matches!(
851 conn.read_message().await,
852 Err(CodecError::MessageTooLarge { .. })
853 ));
854 }
855
856 #[tokio::test]
858 async fn test_max_message_size_zero_is_unlimited() {
859 use tokio::io::AsyncWriteExt;
860 let (client, mut server) = tokio::io::duplex(1 << 16);
861 let mut conn = Connection::new(client);
862
863 let payload = vec![0xAA; 5000];
864 server.write_all(&raw_message(&payload)).await.unwrap();
865 let msg = conn.read_message().await.unwrap().unwrap();
866 assert_eq!(msg.payload.len(), 5000);
867 }
868}