1use std::sync::Arc;
7
8use bytes::{Bytes, BytesMut};
9use futures_util::{SinkExt, StreamExt};
10use tds_protocol::packet::{PACKET_HEADER_SIZE, PacketHeader, PacketStatus, PacketType};
11use tokio::io::{AsyncRead, AsyncWrite, ReadHalf, WriteHalf};
12use tokio::sync::{Mutex, Notify};
13
14use crate::error::CodecError;
15use crate::framed::{PacketReader, PacketWriter};
16use crate::message::{Message, MessageAssembler};
17use crate::packet_codec::{Packet, TdsCodec};
18
19pub struct Connection<T>
47where
48 T: AsyncRead + AsyncWrite,
49{
50 reader: PacketReader<ReadHalf<T>>,
52 writer: Arc<Mutex<PacketWriter<WriteHalf<T>>>>,
54 assembler: MessageAssembler,
56 cancel_notify: Arc<Notify>,
58 cancelling: Arc<std::sync::atomic::AtomicBool>,
60}
61
62impl<T> Connection<T>
63where
64 T: AsyncRead + AsyncWrite,
65{
66 pub fn new(transport: T) -> Self {
70 let (read_half, write_half) = tokio::io::split(transport);
71
72 Self {
73 reader: PacketReader::new(read_half),
74 writer: Arc::new(Mutex::new(PacketWriter::new(write_half))),
75 assembler: MessageAssembler::new(),
76 cancel_notify: Arc::new(Notify::new()),
77 cancelling: Arc::new(std::sync::atomic::AtomicBool::new(false)),
78 }
79 }
80
81 pub fn with_codecs(transport: T, read_codec: TdsCodec, write_codec: TdsCodec) -> Self {
83 let (read_half, write_half) = tokio::io::split(transport);
84
85 Self {
86 reader: PacketReader::with_codec(read_half, read_codec),
87 writer: Arc::new(Mutex::new(PacketWriter::with_codec(
88 write_half,
89 write_codec,
90 ))),
91 assembler: MessageAssembler::new(),
92 cancel_notify: Arc::new(Notify::new()),
93 cancelling: Arc::new(std::sync::atomic::AtomicBool::new(false)),
94 }
95 }
96
97 #[must_use]
101 pub fn cancel_handle(&self) -> CancelHandle<T> {
102 CancelHandle {
103 writer: Arc::clone(&self.writer),
104 notify: Arc::clone(&self.cancel_notify),
105 cancelling: Arc::clone(&self.cancelling),
106 }
107 }
108
109 #[must_use]
111 pub fn is_cancelling(&self) -> bool {
112 self.cancelling.load(std::sync::atomic::Ordering::Acquire)
113 }
114
115 pub async fn read_message(&mut self) -> Result<Option<Message>, CodecError> {
119 loop {
120 if self.is_cancelling() {
122 return self.drain_after_cancel().await;
124 }
125
126 match self.reader.next().await {
127 Some(Ok(packet)) => {
128 if let Some(message) = self.assembler.push(packet) {
129 return Ok(Some(message));
130 }
131 }
133 Some(Err(e)) => return Err(e),
134 None => {
135 if self.assembler.has_partial() {
137 return Err(CodecError::ConnectionClosed);
138 }
139 return Ok(None);
140 }
141 }
142 }
143 }
144
145 pub async fn read_packet(&mut self) -> Result<Option<Packet>, CodecError> {
149 match self.reader.next().await {
150 Some(result) => result.map(Some),
151 None => Ok(None),
152 }
153 }
154
155 pub async fn send_packet(&mut self, packet: Packet) -> Result<(), CodecError> {
157 let mut writer = self.writer.lock().await;
158 writer.send(packet).await
159 }
160
161 pub async fn send_message(
168 &mut self,
169 packet_type: PacketType,
170 payload: Bytes,
171 max_packet_size: usize,
172 ) -> Result<(), CodecError> {
173 self.send_message_with_reset(packet_type, payload, max_packet_size, false)
174 .await
175 }
176
177 pub async fn send_message_with_reset(
184 &mut self,
185 packet_type: PacketType,
186 payload: Bytes,
187 max_packet_size: usize,
188 reset_connection: bool,
189 ) -> Result<(), CodecError> {
190 let max_payload = max_packet_size - PACKET_HEADER_SIZE;
191 let chunks: Vec<_> = payload.chunks(max_payload).collect();
192 let total_chunks = chunks.len();
193
194 let mut writer = self.writer.lock().await;
195
196 for (i, chunk) in chunks.into_iter().enumerate() {
197 let is_first = i == 0;
198 let is_last = i == total_chunks - 1;
199
200 let mut status = if is_last {
202 PacketStatus::END_OF_MESSAGE
203 } else {
204 PacketStatus::NORMAL
205 };
206
207 if is_first && reset_connection {
209 status |= PacketStatus::RESET_CONNECTION;
210 }
211
212 let header = PacketHeader::new(packet_type, status, 0);
213 let packet = Packet::new(header, BytesMut::from(chunk));
214
215 writer.send(packet).await?;
216 }
217
218 Ok(())
219 }
220
221 pub async fn flush(&mut self) -> Result<(), CodecError> {
223 let mut writer = self.writer.lock().await;
224 writer.flush().await
225 }
226
227 async fn drain_after_cancel(&mut self) -> Result<Option<Message>, CodecError> {
229 tracing::debug!("draining packets after cancellation");
230
231 self.assembler.clear();
233
234 loop {
235 match self.reader.next().await {
236 Some(Ok(packet)) => {
237 if packet.header.packet_type == PacketType::TabularResult
240 && !packet.payload.is_empty()
241 {
242 if self.check_attention_done(&packet) {
246 tracing::debug!("received DONE with ATTENTION, cancellation complete");
247 self.cancelling
248 .store(false, std::sync::atomic::Ordering::Release);
249 self.cancel_notify.notify_waiters();
250 return Ok(None);
251 }
252 }
253 }
255 Some(Err(e)) => {
256 self.cancelling
257 .store(false, std::sync::atomic::Ordering::Release);
258 return Err(e);
259 }
260 None => {
261 self.cancelling
262 .store(false, std::sync::atomic::Ordering::Release);
263 return Ok(None);
264 }
265 }
266 }
267 }
268
269 fn check_attention_done(&self, packet: &Packet) -> bool {
271 let payload = &packet.payload;
274
275 for i in 0..payload.len() {
276 if payload[i] == 0xFD && i + 3 <= payload.len() {
277 let status = u16::from_le_bytes([payload[i + 1], payload[i + 2]]);
279 if status & 0x0020 != 0 {
281 return true;
282 }
283 }
284 }
285
286 false
287 }
288
289 pub fn read_codec(&self) -> &TdsCodec {
291 self.reader.codec()
292 }
293
294 pub fn read_codec_mut(&mut self) -> &mut TdsCodec {
296 self.reader.codec_mut()
297 }
298}
299
300impl<T> std::fmt::Debug for Connection<T>
301where
302 T: AsyncRead + AsyncWrite + std::fmt::Debug,
303{
304 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
305 f.debug_struct("Connection")
306 .field("cancelling", &self.is_cancelling())
307 .field("has_partial_message", &self.assembler.has_partial())
308 .finish_non_exhaustive()
309 }
310}
311
312pub struct CancelHandle<T>
317where
318 T: AsyncRead + AsyncWrite,
319{
320 writer: Arc<Mutex<PacketWriter<WriteHalf<T>>>>,
321 notify: Arc<Notify>,
322 cancelling: Arc<std::sync::atomic::AtomicBool>,
323}
324
325impl<T> CancelHandle<T>
326where
327 T: AsyncRead + AsyncWrite + Unpin,
328{
329 pub async fn cancel(&self) -> Result<(), CodecError> {
334 self.cancelling
336 .store(true, std::sync::atomic::Ordering::Release);
337
338 tracing::debug!("sending Attention packet for query cancellation");
339
340 let mut writer = self.writer.lock().await;
342
343 let header = PacketHeader::new(
345 PacketType::Attention,
346 PacketStatus::END_OF_MESSAGE,
347 PACKET_HEADER_SIZE as u16,
348 );
349 let packet = Packet::new(header, BytesMut::new());
350
351 writer.send(packet).await?;
352 writer.flush().await?;
353
354 Ok(())
355 }
356
357 pub async fn wait_cancelled(&self) {
362 if self.cancelling.load(std::sync::atomic::Ordering::Acquire) {
363 self.notify.notified().await;
364 }
365 }
366
367 #[must_use]
369 pub fn is_cancelling(&self) -> bool {
370 self.cancelling.load(std::sync::atomic::Ordering::Acquire)
371 }
372}
373
374impl<T> Clone for CancelHandle<T>
375where
376 T: AsyncRead + AsyncWrite,
377{
378 fn clone(&self) -> Self {
379 Self {
380 writer: Arc::clone(&self.writer),
381 notify: Arc::clone(&self.notify),
382 cancelling: Arc::clone(&self.cancelling),
383 }
384 }
385}
386
387impl<T> std::fmt::Debug for CancelHandle<T>
388where
389 T: AsyncRead + AsyncWrite + Unpin,
390{
391 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
392 f.debug_struct("CancelHandle")
393 .field("cancelling", &self.is_cancelling())
394 .finish_non_exhaustive()
395 }
396}
397
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    /// Mirrors the DONE/ATTN byte scan used by `Connection`. The real method
    /// cannot be called here without constructing a `Connection`.
    fn has_attention_done(packet: &Packet) -> bool {
        packet
            .payload
            .windows(3)
            .any(|w| w[0] == 0xFD && u16::from_le_bytes([w[1], w[2]]) & 0x0020 != 0)
    }

    #[test]
    fn test_attention_packet_header() {
        let expected_len = PACKET_HEADER_SIZE as u16;
        let header = PacketHeader::new(
            PacketType::Attention,
            PacketStatus::END_OF_MESSAGE,
            expected_len,
        );

        assert_eq!(header.packet_type, PacketType::Attention);
        assert!(header.status.contains(PacketStatus::END_OF_MESSAGE));
        assert_eq!(header.length, expected_len);
    }

    #[test]
    fn test_check_attention_done() {
        let header = PacketHeader::new(PacketType::TabularResult, PacketStatus::END_OF_MESSAGE, 0);

        // 0xFD followed by a little-endian status word; all other bytes are zero.
        let mut bytes = [0u8; 13];
        bytes[0] = 0xFD;
        let packet_no_attn = Packet::new(header, BytesMut::from(&bytes[..]));

        bytes[1] = 0x20;
        let packet_with_attn = Packet::new(header, BytesMut::from(&bytes[..]));

        assert!(has_attention_done(&packet_with_attn));
        assert!(!has_attention_done(&packet_no_attn));
    }
}
460}