mssql_codec/
framed.rs

//! Framed packet stream for async I/O.
//!
//! This module provides both combined and split stream types:
//! - `PacketStream<T>` - Combined read/write stream for bidirectional I/O
//! - `PacketReader<T>` - Read-only stream for receiving packets
//! - `PacketWriter<T>` - Write-only sink for sending packets
//!
//! The split types are used by `Connection` for cancellation safety (ADR-005).
9
10use std::pin::Pin;
11use std::task::{Context, Poll};
12
13use bytes::BytesMut;
14use futures_core::Stream;
15use futures_util::Sink;
16use pin_project_lite::pin_project;
17use tokio::io::{AsyncRead, AsyncWrite};
18use tokio_util::codec::{Framed, FramedRead, FramedWrite};
19
20use crate::error::CodecError;
21use crate::packet_codec::{Packet, TdsCodec};
22
23pin_project! {
24    /// A framed packet stream over an async I/O transport.
25    ///
26    /// This wraps a tokio-util `Framed` codec and provides a higher-level
27    /// interface for sending and receiving TDS packets.
28    pub struct PacketStream<T> {
29        #[pin]
30        inner: Framed<T, TdsCodec>,
31    }
32}
33
34impl<T> PacketStream<T>
35where
36    T: AsyncRead + AsyncWrite,
37{
38    /// Create a new packet stream over the given transport.
39    pub fn new(transport: T) -> Self {
40        Self {
41            inner: Framed::new(transport, TdsCodec::new()),
42        }
43    }
44
45    /// Create a new packet stream with a custom codec.
46    pub fn with_codec(transport: T, codec: TdsCodec) -> Self {
47        Self {
48            inner: Framed::new(transport, codec),
49        }
50    }
51
52    /// Get a reference to the underlying transport.
53    pub fn get_ref(&self) -> &T {
54        self.inner.get_ref()
55    }
56
57    /// Get a mutable reference to the underlying transport.
58    pub fn get_mut(&mut self) -> &mut T {
59        self.inner.get_mut()
60    }
61
62    /// Get a reference to the codec.
63    pub fn codec(&self) -> &TdsCodec {
64        self.inner.codec()
65    }
66
67    /// Get a mutable reference to the codec.
68    pub fn codec_mut(&mut self) -> &mut TdsCodec {
69        self.inner.codec_mut()
70    }
71
72    /// Consume the stream and return the underlying transport.
73    pub fn into_inner(self) -> T {
74        self.inner.into_inner()
75    }
76
77    /// Get a reference to the read buffer.
78    pub fn read_buffer(&self) -> &BytesMut {
79        self.inner.read_buffer()
80    }
81
82    /// Get a mutable reference to the read buffer.
83    pub fn read_buffer_mut(&mut self) -> &mut BytesMut {
84        self.inner.read_buffer_mut()
85    }
86}
87
88impl<T> Stream for PacketStream<T>
89where
90    T: AsyncRead + Unpin,
91{
92    type Item = Result<Packet, CodecError>;
93
94    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
95        self.project().inner.poll_next(cx)
96    }
97}
98
99impl<T> Sink<Packet> for PacketStream<T>
100where
101    T: AsyncWrite + Unpin,
102{
103    type Error = CodecError;
104
105    fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
106        self.project().inner.poll_ready(cx)
107    }
108
109    fn start_send(self: Pin<&mut Self>, item: Packet) -> Result<(), Self::Error> {
110        self.project().inner.start_send(item)
111    }
112
113    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
114        self.project().inner.poll_flush(cx)
115    }
116
117    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
118        self.project().inner.poll_close(cx)
119    }
120}
121
122impl<T> std::fmt::Debug for PacketStream<T>
123where
124    T: std::fmt::Debug,
125{
126    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
127        f.debug_struct("PacketStream")
128            .field("transport", self.inner.get_ref())
129            .finish()
130    }
131}
132
133// =============================================================================
134// Split stream types for cancellation safety (ADR-005)
135// =============================================================================
136
137pin_project! {
138    /// A read-only packet stream for receiving TDS packets.
139    ///
140    /// This is used for the read half of a split connection, enabling
141    /// cancellation safety per ADR-005.
142    pub struct PacketReader<T> {
143        #[pin]
144        inner: FramedRead<T, TdsCodec>,
145    }
146}
147
148impl<T> PacketReader<T>
149where
150    T: AsyncRead,
151{
152    /// Create a new packet reader over the given transport.
153    pub fn new(transport: T) -> Self {
154        Self {
155            inner: FramedRead::new(transport, TdsCodec::new()),
156        }
157    }
158
159    /// Create a new packet reader with a custom codec.
160    pub fn with_codec(transport: T, codec: TdsCodec) -> Self {
161        Self {
162            inner: FramedRead::new(transport, codec),
163        }
164    }
165
166    /// Get a reference to the underlying transport.
167    pub fn get_ref(&self) -> &T {
168        self.inner.get_ref()
169    }
170
171    /// Get a mutable reference to the underlying transport.
172    pub fn get_mut(&mut self) -> &mut T {
173        self.inner.get_mut()
174    }
175
176    /// Get a reference to the codec.
177    pub fn codec(&self) -> &TdsCodec {
178        self.inner.decoder()
179    }
180
181    /// Get a mutable reference to the codec.
182    pub fn codec_mut(&mut self) -> &mut TdsCodec {
183        self.inner.decoder_mut()
184    }
185
186    /// Get a reference to the read buffer.
187    pub fn read_buffer(&self) -> &BytesMut {
188        self.inner.read_buffer()
189    }
190
191    /// Get a mutable reference to the read buffer.
192    pub fn read_buffer_mut(&mut self) -> &mut BytesMut {
193        self.inner.read_buffer_mut()
194    }
195}
196
197impl<T> Stream for PacketReader<T>
198where
199    T: AsyncRead + Unpin,
200{
201    type Item = Result<Packet, CodecError>;
202
203    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
204        self.project().inner.poll_next(cx)
205    }
206}
207
208impl<T> std::fmt::Debug for PacketReader<T>
209where
210    T: std::fmt::Debug,
211{
212    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
213        f.debug_struct("PacketReader")
214            .field("transport", self.inner.get_ref())
215            .finish()
216    }
217}
218
219pin_project! {
220    /// A write-only packet sink for sending TDS packets.
221    ///
222    /// This is used for the write half of a split connection, enabling
223    /// cancellation safety per ADR-005.
224    pub struct PacketWriter<T> {
225        #[pin]
226        inner: FramedWrite<T, TdsCodec>,
227    }
228}
229
230impl<T> PacketWriter<T>
231where
232    T: AsyncWrite,
233{
234    /// Create a new packet writer over the given transport.
235    pub fn new(transport: T) -> Self {
236        Self {
237            inner: FramedWrite::new(transport, TdsCodec::new()),
238        }
239    }
240
241    /// Create a new packet writer with a custom codec.
242    pub fn with_codec(transport: T, codec: TdsCodec) -> Self {
243        Self {
244            inner: FramedWrite::new(transport, codec),
245        }
246    }
247
248    /// Get a reference to the underlying transport.
249    pub fn get_ref(&self) -> &T {
250        self.inner.get_ref()
251    }
252
253    /// Get a mutable reference to the underlying transport.
254    pub fn get_mut(&mut self) -> &mut T {
255        self.inner.get_mut()
256    }
257
258    /// Get a reference to the codec.
259    pub fn codec(&self) -> &TdsCodec {
260        self.inner.encoder()
261    }
262
263    /// Get a mutable reference to the codec.
264    pub fn codec_mut(&mut self) -> &mut TdsCodec {
265        self.inner.encoder_mut()
266    }
267}
268
269impl<T> Sink<Packet> for PacketWriter<T>
270where
271    T: AsyncWrite + Unpin,
272{
273    type Error = CodecError;
274
275    fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
276        self.project().inner.poll_ready(cx)
277    }
278
279    fn start_send(self: Pin<&mut Self>, item: Packet) -> Result<(), Self::Error> {
280        self.project().inner.start_send(item)
281    }
282
283    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
284        self.project().inner.poll_flush(cx)
285    }
286
287    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
288        self.project().inner.poll_close(cx)
289    }
290}
291
292impl<T> std::fmt::Debug for PacketWriter<T>
293where
294    T: std::fmt::Debug,
295{
296    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
297        f.debug_struct("PacketWriter")
298            .field("transport", self.inner.get_ref())
299            .finish()
300    }
301}