1use std::sync::Arc;
7
8use bytes::{Bytes, BytesMut};
9use futures_util::{SinkExt, StreamExt};
10use tds_protocol::packet::{PACKET_HEADER_SIZE, PacketHeader, PacketStatus, PacketType};
11use tokio::io::{AsyncRead, AsyncWrite, ReadHalf, WriteHalf};
12use tokio::sync::{Mutex, Notify};
13
14use crate::error::CodecError;
15use crate::framed::{PacketReader, PacketWriter};
16use crate::message::{Message, MessageAssembler};
17use crate::packet_codec::{Packet, TdsCodec};
18
19pub struct Connection<T>
47where
48 T: AsyncRead + AsyncWrite,
49{
50 reader: PacketReader<ReadHalf<T>>,
52 writer: Arc<Mutex<PacketWriter<WriteHalf<T>>>>,
54 assembler: MessageAssembler,
56 cancel_notify: Arc<Notify>,
58 cancelling: Arc<std::sync::atomic::AtomicBool>,
60}
61
62impl<T> Connection<T>
63where
64 T: AsyncRead + AsyncWrite,
65{
66 pub fn new(transport: T) -> Self {
70 let (read_half, write_half) = tokio::io::split(transport);
71
72 Self {
73 reader: PacketReader::new(read_half),
74 writer: Arc::new(Mutex::new(PacketWriter::new(write_half))),
75 assembler: MessageAssembler::new(),
76 cancel_notify: Arc::new(Notify::new()),
77 cancelling: Arc::new(std::sync::atomic::AtomicBool::new(false)),
78 }
79 }
80
81 pub fn with_codecs(transport: T, read_codec: TdsCodec, write_codec: TdsCodec) -> Self {
83 let (read_half, write_half) = tokio::io::split(transport);
84
85 Self {
86 reader: PacketReader::with_codec(read_half, read_codec),
87 writer: Arc::new(Mutex::new(PacketWriter::with_codec(
88 write_half,
89 write_codec,
90 ))),
91 assembler: MessageAssembler::new(),
92 cancel_notify: Arc::new(Notify::new()),
93 cancelling: Arc::new(std::sync::atomic::AtomicBool::new(false)),
94 }
95 }
96
97 #[must_use]
101 pub fn cancel_handle(&self) -> CancelHandle<T> {
102 CancelHandle {
103 writer: Arc::clone(&self.writer),
104 notify: Arc::clone(&self.cancel_notify),
105 cancelling: Arc::clone(&self.cancelling),
106 }
107 }
108
109 #[must_use]
111 pub fn is_cancelling(&self) -> bool {
112 self.cancelling.load(std::sync::atomic::Ordering::Acquire)
113 }
114
115 pub async fn read_message(&mut self) -> Result<Option<Message>, CodecError> {
119 loop {
120 if self.is_cancelling() {
122 return self.drain_after_cancel().await;
124 }
125
126 match self.reader.next().await {
127 Some(Ok(packet)) => {
128 if let Some(message) = self.assembler.push(packet) {
129 return Ok(Some(message));
130 }
131 }
133 Some(Err(e)) => return Err(e),
134 None => {
135 if self.assembler.has_partial() {
137 return Err(CodecError::ConnectionClosed);
138 }
139 return Ok(None);
140 }
141 }
142 }
143 }
144
145 pub async fn read_packet(&mut self) -> Result<Option<Packet>, CodecError> {
149 match self.reader.next().await {
150 Some(result) => result.map(Some),
151 None => Ok(None),
152 }
153 }
154
155 pub async fn send_packet(&mut self, packet: Packet) -> Result<(), CodecError> {
157 let mut writer = self.writer.lock().await;
158 writer.send(packet).await
159 }
160
161 pub async fn send_message(
163 &mut self,
164 packet_type: PacketType,
165 payload: Bytes,
166 max_packet_size: usize,
167 ) -> Result<(), CodecError> {
168 let max_payload = max_packet_size - PACKET_HEADER_SIZE;
169 let chunks: Vec<_> = payload.chunks(max_payload).collect();
170 let total_chunks = chunks.len();
171
172 let mut writer = self.writer.lock().await;
173
174 for (i, chunk) in chunks.into_iter().enumerate() {
175 let is_last = i == total_chunks - 1;
176 let status = if is_last {
177 PacketStatus::END_OF_MESSAGE
178 } else {
179 PacketStatus::NORMAL
180 };
181
182 let header = PacketHeader::new(packet_type, status, 0);
183 let packet = Packet::new(header, BytesMut::from(chunk));
184
185 writer.send(packet).await?;
186 }
187
188 Ok(())
189 }
190
191 pub async fn flush(&mut self) -> Result<(), CodecError> {
193 let mut writer = self.writer.lock().await;
194 writer.flush().await
195 }
196
197 async fn drain_after_cancel(&mut self) -> Result<Option<Message>, CodecError> {
199 tracing::debug!("draining packets after cancellation");
200
201 self.assembler.clear();
203
204 loop {
205 match self.reader.next().await {
206 Some(Ok(packet)) => {
207 if packet.header.packet_type == PacketType::TabularResult
210 && !packet.payload.is_empty()
211 {
212 if self.check_attention_done(&packet) {
216 tracing::debug!("received DONE with ATTENTION, cancellation complete");
217 self.cancelling
218 .store(false, std::sync::atomic::Ordering::Release);
219 self.cancel_notify.notify_waiters();
220 return Ok(None);
221 }
222 }
223 }
225 Some(Err(e)) => {
226 self.cancelling
227 .store(false, std::sync::atomic::Ordering::Release);
228 return Err(e);
229 }
230 None => {
231 self.cancelling
232 .store(false, std::sync::atomic::Ordering::Release);
233 return Ok(None);
234 }
235 }
236 }
237 }
238
239 fn check_attention_done(&self, packet: &Packet) -> bool {
241 let payload = &packet.payload;
244
245 for i in 0..payload.len() {
246 if payload[i] == 0xFD && i + 3 <= payload.len() {
247 let status = u16::from_le_bytes([payload[i + 1], payload[i + 2]]);
249 if status & 0x0020 != 0 {
251 return true;
252 }
253 }
254 }
255
256 false
257 }
258
259 pub fn read_codec(&self) -> &TdsCodec {
261 self.reader.codec()
262 }
263
264 pub fn read_codec_mut(&mut self) -> &mut TdsCodec {
266 self.reader.codec_mut()
267 }
268}
269
270impl<T> std::fmt::Debug for Connection<T>
271where
272 T: AsyncRead + AsyncWrite + std::fmt::Debug,
273{
274 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
275 f.debug_struct("Connection")
276 .field("cancelling", &self.is_cancelling())
277 .field("has_partial_message", &self.assembler.has_partial())
278 .finish_non_exhaustive()
279 }
280}
281
282pub struct CancelHandle<T>
287where
288 T: AsyncRead + AsyncWrite,
289{
290 writer: Arc<Mutex<PacketWriter<WriteHalf<T>>>>,
291 notify: Arc<Notify>,
292 cancelling: Arc<std::sync::atomic::AtomicBool>,
293}
294
295impl<T> CancelHandle<T>
296where
297 T: AsyncRead + AsyncWrite + Unpin,
298{
299 pub async fn cancel(&self) -> Result<(), CodecError> {
304 self.cancelling
306 .store(true, std::sync::atomic::Ordering::Release);
307
308 tracing::debug!("sending Attention packet for query cancellation");
309
310 let mut writer = self.writer.lock().await;
312
313 let header = PacketHeader::new(
315 PacketType::Attention,
316 PacketStatus::END_OF_MESSAGE,
317 PACKET_HEADER_SIZE as u16,
318 );
319 let packet = Packet::new(header, BytesMut::new());
320
321 writer.send(packet).await?;
322 writer.flush().await?;
323
324 Ok(())
325 }
326
327 pub async fn wait_cancelled(&self) {
332 if self.cancelling.load(std::sync::atomic::Ordering::Acquire) {
333 self.notify.notified().await;
334 }
335 }
336
337 #[must_use]
339 pub fn is_cancelling(&self) -> bool {
340 self.cancelling.load(std::sync::atomic::Ordering::Acquire)
341 }
342}
343
344impl<T> Clone for CancelHandle<T>
345where
346 T: AsyncRead + AsyncWrite,
347{
348 fn clone(&self) -> Self {
349 Self {
350 writer: Arc::clone(&self.writer),
351 notify: Arc::clone(&self.notify),
352 cancelling: Arc::clone(&self.cancelling),
353 }
354 }
355}
356
357impl<T> std::fmt::Debug for CancelHandle<T>
358where
359 T: AsyncRead + AsyncWrite + Unpin,
360{
361 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
362 f.debug_struct("CancelHandle")
363 .field("cancelling", &self.is_cancelling())
364 .finish_non_exhaustive()
365 }
366}
367
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    /// Builds a connection over an in-memory duplex pipe so the private
    /// `check_attention_done` can be tested directly. No I/O is performed.
    fn connection() -> Connection<tokio::io::DuplexStream> {
        let (client, _server) = tokio::io::duplex(64);
        Connection::new(client)
    }

    /// Wraps `payload` in a final TabularResult packet.
    fn tabular(payload: &[u8]) -> Packet {
        let header = PacketHeader::new(PacketType::TabularResult, PacketStatus::END_OF_MESSAGE, 0);
        Packet::new(header, BytesMut::from(payload))
    }

    #[test]
    fn test_attention_packet_header() {
        let header = PacketHeader::new(
            PacketType::Attention,
            PacketStatus::END_OF_MESSAGE,
            PACKET_HEADER_SIZE as u16,
        );

        assert_eq!(header.packet_type, PacketType::Attention);
        assert!(header.status.contains(PacketStatus::END_OF_MESSAGE));
        assert_eq!(header.length, PACKET_HEADER_SIZE as u16);
    }

    #[test]
    fn test_check_attention_done() {
        // Exercise the real method rather than a copy of its logic, so a
        // regression in `check_attention_done` actually fails this test.
        let conn = connection();

        let with_attn = tabular(&[
            0xFD, 0x20, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
        ]);
        let no_attn = tabular(&[
            0xFD, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
        ]);

        assert!(conn.check_attention_done(&with_attn));
        assert!(!conn.check_attention_done(&no_attn));
    }

    #[test]
    fn test_check_attention_done_boundaries() {
        let conn = connection();

        // DONE token preceded by other bytes, with exactly the status after it.
        assert!(conn.check_attention_done(&tabular(&[0x00, 0x00, 0xFD, 0x20, 0x00])));
        // Status truncated at the end of the payload: must not match or panic.
        assert!(!conn.check_attention_done(&tabular(&[0x00, 0x00, 0xFD, 0x20])));
        // Empty payload.
        assert!(!conn.check_attention_done(&tabular(&[])));
    }
}
430}