1use std::pin::Pin;
11use std::task::{Context, Poll};
12
13use bytes::BytesMut;
14use futures_core::Stream;
15use futures_util::Sink;
16use pin_project_lite::pin_project;
17use tokio::io::{AsyncRead, AsyncWrite};
18use tokio_util::codec::{Framed, FramedRead, FramedWrite};
19
20use crate::error::CodecError;
21use crate::packet_codec::{Packet, TdsCodec};
22
pin_project! {
    /// A bidirectional packet transport.
    ///
    /// Wraps a transport `T` in a `tokio_util` `Framed` driven by `TdsCodec`, so it can be
    /// read as a `Stream` of `Result<Packet, CodecError>` and written as a `Sink<Packet>`.
    pub struct PacketStream<T> {
        // Structurally pinned: the `Stream`/`Sink` impls project `Pin<&mut Self>` onto
        // `Pin<&mut Framed<..>>` and delegate every poll to it.
        #[pin]
        inner: Framed<T, TdsCodec>,
    }
}
33
34impl<T> PacketStream<T>
35where
36 T: AsyncRead + AsyncWrite,
37{
38 pub fn new(transport: T) -> Self {
40 Self {
41 inner: Framed::new(transport, TdsCodec::new()),
42 }
43 }
44
45 pub fn with_codec(transport: T, codec: TdsCodec) -> Self {
47 Self {
48 inner: Framed::new(transport, codec),
49 }
50 }
51
52 pub fn get_ref(&self) -> &T {
54 self.inner.get_ref()
55 }
56
57 pub fn get_mut(&mut self) -> &mut T {
59 self.inner.get_mut()
60 }
61
62 pub fn codec(&self) -> &TdsCodec {
64 self.inner.codec()
65 }
66
67 pub fn codec_mut(&mut self) -> &mut TdsCodec {
69 self.inner.codec_mut()
70 }
71
72 pub fn into_inner(self) -> T {
74 self.inner.into_inner()
75 }
76
77 pub fn read_buffer(&self) -> &BytesMut {
79 self.inner.read_buffer()
80 }
81
82 pub fn read_buffer_mut(&mut self) -> &mut BytesMut {
84 self.inner.read_buffer_mut()
85 }
86}
87
88impl<T> Stream for PacketStream<T>
89where
90 T: AsyncRead + Unpin,
91{
92 type Item = Result<Packet, CodecError>;
93
94 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
95 self.project().inner.poll_next(cx)
96 }
97}
98
99impl<T> Sink<Packet> for PacketStream<T>
100where
101 T: AsyncWrite + Unpin,
102{
103 type Error = CodecError;
104
105 fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
106 self.project().inner.poll_ready(cx)
107 }
108
109 fn start_send(self: Pin<&mut Self>, item: Packet) -> Result<(), Self::Error> {
110 self.project().inner.start_send(item)
111 }
112
113 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
114 self.project().inner.poll_flush(cx)
115 }
116
117 fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
118 self.project().inner.poll_close(cx)
119 }
120}
121
122impl<T> std::fmt::Debug for PacketStream<T>
123where
124 T: std::fmt::Debug,
125{
126 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
127 f.debug_struct("PacketStream")
128 .field("transport", self.inner.get_ref())
129 .finish()
130 }
131}
132
pin_project! {
    /// A read-only packet source.
    ///
    /// Wraps a transport `T` in a `tokio_util` `FramedRead` driven by `TdsCodec`, so it can
    /// be consumed as a `Stream` of `Result<Packet, CodecError>`.
    pub struct PacketReader<T> {
        // Structurally pinned: the `Stream` impl projects `Pin<&mut Self>` onto
        // `Pin<&mut FramedRead<..>>` and delegates polling to it.
        #[pin]
        inner: FramedRead<T, TdsCodec>,
    }
}
147
148impl<T> PacketReader<T>
149where
150 T: AsyncRead,
151{
152 pub fn new(transport: T) -> Self {
154 Self {
155 inner: FramedRead::new(transport, TdsCodec::new()),
156 }
157 }
158
159 pub fn with_codec(transport: T, codec: TdsCodec) -> Self {
161 Self {
162 inner: FramedRead::new(transport, codec),
163 }
164 }
165
166 pub fn get_ref(&self) -> &T {
168 self.inner.get_ref()
169 }
170
171 pub fn get_mut(&mut self) -> &mut T {
173 self.inner.get_mut()
174 }
175
176 pub fn codec(&self) -> &TdsCodec {
178 self.inner.decoder()
179 }
180
181 pub fn codec_mut(&mut self) -> &mut TdsCodec {
183 self.inner.decoder_mut()
184 }
185
186 pub fn read_buffer(&self) -> &BytesMut {
188 self.inner.read_buffer()
189 }
190
191 pub fn read_buffer_mut(&mut self) -> &mut BytesMut {
193 self.inner.read_buffer_mut()
194 }
195}
196
197impl<T> Stream for PacketReader<T>
198where
199 T: AsyncRead + Unpin,
200{
201 type Item = Result<Packet, CodecError>;
202
203 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
204 self.project().inner.poll_next(cx)
205 }
206}
207
208impl<T> std::fmt::Debug for PacketReader<T>
209where
210 T: std::fmt::Debug,
211{
212 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
213 f.debug_struct("PacketReader")
214 .field("transport", self.inner.get_ref())
215 .finish()
216 }
217}
218
pin_project! {
    /// A write-only packet sink.
    ///
    /// Wraps a transport `T` in a `tokio_util` `FramedWrite` driven by `TdsCodec`, so
    /// `Packet`s can be sent through it as a `Sink<Packet>`.
    pub struct PacketWriter<T> {
        // Structurally pinned: the `Sink` impl projects `Pin<&mut Self>` onto
        // `Pin<&mut FramedWrite<..>>` and delegates every poll to it.
        #[pin]
        inner: FramedWrite<T, TdsCodec>,
    }
}
229
230impl<T> PacketWriter<T>
231where
232 T: AsyncWrite,
233{
234 pub fn new(transport: T) -> Self {
236 Self {
237 inner: FramedWrite::new(transport, TdsCodec::new()),
238 }
239 }
240
241 pub fn with_codec(transport: T, codec: TdsCodec) -> Self {
243 Self {
244 inner: FramedWrite::new(transport, codec),
245 }
246 }
247
248 pub fn get_ref(&self) -> &T {
250 self.inner.get_ref()
251 }
252
253 pub fn get_mut(&mut self) -> &mut T {
255 self.inner.get_mut()
256 }
257
258 pub fn codec(&self) -> &TdsCodec {
260 self.inner.encoder()
261 }
262
263 pub fn codec_mut(&mut self) -> &mut TdsCodec {
265 self.inner.encoder_mut()
266 }
267}
268
269impl<T> Sink<Packet> for PacketWriter<T>
270where
271 T: AsyncWrite + Unpin,
272{
273 type Error = CodecError;
274
275 fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
276 self.project().inner.poll_ready(cx)
277 }
278
279 fn start_send(self: Pin<&mut Self>, item: Packet) -> Result<(), Self::Error> {
280 self.project().inner.start_send(item)
281 }
282
283 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
284 self.project().inner.poll_flush(cx)
285 }
286
287 fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
288 self.project().inner.poll_close(cx)
289 }
290}
291
292impl<T> std::fmt::Debug for PacketWriter<T>
293where
294 T: std::fmt::Debug,
295{
296 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
297 f.debug_struct("PacketWriter")
298 .field("transport", self.inner.get_ref())
299 .finish()
300 }
301}