use std::pin::Pin;
use std::task::{Context, Poll};
use bytes::BytesMut;
use futures_core::Stream;
use futures_util::Sink;
use pin_project_lite::pin_project;
use tokio::io::{AsyncRead, AsyncWrite};
use tokio_util::codec::{Framed, FramedRead, FramedWrite};
use crate::error::CodecError;
use crate::packet_codec::{Packet, TdsCodec};
pin_project! {
/// Bidirectional TDS packet transport.
///
/// Wraps a `tokio_util` [`Framed`] so that bytes read from the transport are
/// decoded into [`Packet`]s and outgoing [`Packet`]s are encoded, both through
/// a [`TdsCodec`].
pub struct PacketStream<T> {
// Structurally pinned so the `Stream`/`Sink` impls can poll it through
// `self.project()` without requiring the transport to be `Unpin`.
#[pin]
inner: Framed<T, TdsCodec>,
}
}
impl<T> PacketStream<T>
where
    T: AsyncRead + AsyncWrite,
{
    /// Wraps `transport` in a packet stream that uses a freshly constructed
    /// [`TdsCodec`].
    pub fn new(transport: T) -> Self {
        Self::with_codec(transport, TdsCodec::new())
    }

    /// Wraps `transport` in a packet stream driven by the caller-supplied
    /// `codec`.
    pub fn with_codec(transport: T, codec: TdsCodec) -> Self {
        let framed = Framed::new(transport, codec);
        Self { inner: framed }
    }

    /// Borrows the underlying transport.
    pub fn get_ref(&self) -> &T {
        self.inner.get_ref()
    }

    /// Mutably borrows the underlying transport.
    ///
    /// NOTE: I/O performed directly on the transport bypasses the codec and
    /// the framing layer's buffers.
    pub fn get_mut(&mut self) -> &mut T {
        self.inner.get_mut()
    }

    /// Borrows the codec used for both decoding and encoding.
    pub fn codec(&self) -> &TdsCodec {
        self.inner.codec()
    }

    /// Mutably borrows the codec used for both decoding and encoding.
    pub fn codec_mut(&mut self) -> &mut TdsCodec {
        self.inner.codec_mut()
    }

    /// Consumes the stream and hands back the underlying transport.
    ///
    /// Only the transport is returned; the codec and any buffered bytes are
    /// dropped.
    pub fn into_inner(self) -> T {
        self.inner.into_inner()
    }

    /// Borrows the framing layer's read buffer.
    pub fn read_buffer(&self) -> &BytesMut {
        self.inner.read_buffer()
    }

    /// Mutably borrows the framing layer's read buffer.
    pub fn read_buffer_mut(&mut self) -> &mut BytesMut {
        self.inner.read_buffer_mut()
    }
}
/// Yields packets decoded from the transport by [`TdsCodec`].
///
/// Only `T: AsyncRead` is required. `inner` is structurally pinned by
/// `pin_project!`, so the transport does not have to be `Unpin`; this matches
/// the bounds of `Framed`'s own `Stream` impl.
impl<T> Stream for PacketStream<T>
where
    T: AsyncRead,
{
    type Item = Result<Packet, CodecError>;

    /// Polls the inner `Framed` for the next decoded packet.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        self.project().inner.poll_next(cx)
    }
}
/// Accepts packets and writes them to the transport, encoded by [`TdsCodec`].
///
/// Only `T: AsyncWrite` is required. `inner` is structurally pinned by
/// `pin_project!`, so the transport does not have to be `Unpin`; this matches
/// the bounds of `Framed`'s own `Sink` impl.
impl<T> Sink<Packet> for PacketStream<T>
where
    T: AsyncWrite,
{
    type Error = CodecError;

    /// Delegates readiness to the inner `Framed`.
    fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.project().inner.poll_ready(cx)
    }

    /// Hands `item` to the inner `Framed`, which encodes it with `TdsCodec`.
    fn start_send(self: Pin<&mut Self>, item: Packet) -> Result<(), Self::Error> {
        self.project().inner.start_send(item)
    }

    /// Delegates flushing of buffered output to the inner `Framed`.
    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.project().inner.poll_flush(cx)
    }

    /// Delegates closing to the inner `Framed`.
    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.project().inner.poll_close(cx)
    }
}
impl<T> std::fmt::Debug for PacketStream<T>
where
    T: std::fmt::Debug,
{
    /// Formats only the wrapped transport; the codec and buffers are omitted.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let transport: &T = self.inner.get_ref();
        let mut builder = f.debug_struct("PacketStream");
        builder.field("transport", transport);
        builder.finish()
    }
}
pin_project! {
/// Read-only TDS packet source.
///
/// Wraps a `tokio_util` [`FramedRead`] that decodes bytes from the transport
/// into [`Packet`]s using a [`TdsCodec`].
pub struct PacketReader<T> {
// Structurally pinned so the `Stream` impl can poll it through
// `self.project()` without requiring the transport to be `Unpin`.
#[pin]
inner: FramedRead<T, TdsCodec>,
}
}
impl<T> PacketReader<T>
where
    T: AsyncRead,
{
    /// Creates a reader that decodes packets from `transport` with a freshly
    /// constructed [`TdsCodec`].
    pub fn new(transport: T) -> Self {
        Self::with_codec(transport, TdsCodec::new())
    }

    /// Creates a reader that decodes packets from `transport` with `codec`.
    pub fn with_codec(transport: T, codec: TdsCodec) -> Self {
        Self {
            inner: FramedRead::new(transport, codec),
        }
    }

    /// Borrows the underlying transport.
    pub fn get_ref(&self) -> &T {
        self.inner.get_ref()
    }

    /// Mutably borrows the underlying transport.
    ///
    /// NOTE: reading directly from the transport bypasses the decoder and the
    /// read buffer.
    pub fn get_mut(&mut self) -> &mut T {
        self.inner.get_mut()
    }

    /// Borrows the codec used as the decoder.
    pub fn codec(&self) -> &TdsCodec {
        self.inner.decoder()
    }

    /// Mutably borrows the codec used as the decoder.
    pub fn codec_mut(&mut self) -> &mut TdsCodec {
        self.inner.decoder_mut()
    }

    /// Consumes the reader and returns the underlying transport.
    ///
    /// Mirrors `PacketStream::into_inner`. Only the transport is returned, so
    /// any bytes still sitting in the read buffer are dropped. Inspect
    /// [`read_buffer`](Self::read_buffer) first if they must be preserved.
    pub fn into_inner(self) -> T {
        self.inner.into_inner()
    }

    /// Borrows the framing layer's read buffer.
    pub fn read_buffer(&self) -> &BytesMut {
        self.inner.read_buffer()
    }

    /// Mutably borrows the framing layer's read buffer.
    pub fn read_buffer_mut(&mut self) -> &mut BytesMut {
        self.inner.read_buffer_mut()
    }
}
/// Yields packets decoded from the transport by [`TdsCodec`].
///
/// Only `T: AsyncRead` is required. `inner` is structurally pinned by
/// `pin_project!`, so the transport does not have to be `Unpin`; this matches
/// the bounds of `FramedRead`'s own `Stream` impl.
impl<T> Stream for PacketReader<T>
where
    T: AsyncRead,
{
    type Item = Result<Packet, CodecError>;

    /// Polls the inner `FramedRead` for the next decoded packet.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        self.project().inner.poll_next(cx)
    }
}
impl<T> std::fmt::Debug for PacketReader<T>
where
    T: std::fmt::Debug,
{
    /// Formats only the wrapped transport; the decoder and buffer are omitted.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let transport: &T = self.inner.get_ref();
        let mut builder = f.debug_struct("PacketReader");
        builder.field("transport", transport);
        builder.finish()
    }
}
pin_project! {
/// Write-only TDS packet sink.
///
/// Wraps a `tokio_util` [`FramedWrite`] that encodes [`Packet`]s onto the
/// transport using a [`TdsCodec`].
pub struct PacketWriter<T> {
// Structurally pinned so the `Sink` impl can poll it through
// `self.project()` without requiring the transport to be `Unpin`.
#[pin]
inner: FramedWrite<T, TdsCodec>,
}
}
impl<T> PacketWriter<T>
where
    T: AsyncWrite,
{
    /// Creates a writer that encodes packets onto `transport` with a freshly
    /// constructed [`TdsCodec`].
    pub fn new(transport: T) -> Self {
        Self::with_codec(transport, TdsCodec::new())
    }

    /// Creates a writer that encodes packets onto `transport` with `codec`.
    pub fn with_codec(transport: T, codec: TdsCodec) -> Self {
        Self {
            inner: FramedWrite::new(transport, codec),
        }
    }

    /// Borrows the underlying transport.
    pub fn get_ref(&self) -> &T {
        self.inner.get_ref()
    }

    /// Mutably borrows the underlying transport.
    ///
    /// NOTE: writing directly to the transport bypasses the encoder and any
    /// not-yet-flushed buffered output.
    pub fn get_mut(&mut self) -> &mut T {
        self.inner.get_mut()
    }

    /// Borrows the codec used as the encoder.
    pub fn codec(&self) -> &TdsCodec {
        self.inner.encoder()
    }

    /// Mutably borrows the codec used as the encoder.
    pub fn codec_mut(&mut self) -> &mut TdsCodec {
        self.inner.encoder_mut()
    }

    /// Consumes the writer and returns the underlying transport.
    ///
    /// Mirrors `PacketStream::into_inner`. Only the transport is returned, so
    /// any encoded bytes not yet flushed are dropped. Flush the sink (e.g.
    /// drive `poll_flush` to completion) before calling this.
    pub fn into_inner(self) -> T {
        self.inner.into_inner()
    }
}
/// Accepts packets and writes them to the transport, encoded by [`TdsCodec`].
///
/// Only `T: AsyncWrite` is required. `inner` is structurally pinned by
/// `pin_project!`, so the transport does not have to be `Unpin`; this matches
/// the bounds of `FramedWrite`'s own `Sink` impl.
impl<T> Sink<Packet> for PacketWriter<T>
where
    T: AsyncWrite,
{
    type Error = CodecError;

    /// Delegates readiness to the inner `FramedWrite`.
    fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.project().inner.poll_ready(cx)
    }

    /// Hands `item` to the inner `FramedWrite`, which encodes it with `TdsCodec`.
    fn start_send(self: Pin<&mut Self>, item: Packet) -> Result<(), Self::Error> {
        self.project().inner.start_send(item)
    }

    /// Delegates flushing of buffered output to the inner `FramedWrite`.
    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.project().inner.poll_flush(cx)
    }

    /// Delegates closing to the inner `FramedWrite`.
    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.project().inner.poll_close(cx)
    }
}
impl<T> std::fmt::Debug for PacketWriter<T>
where
    T: std::fmt::Debug,
{
    /// Formats only the wrapped transport; the encoder and buffer are omitted.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let transport: &T = self.inner.get_ref();
        let mut builder = f.debug_struct("PacketWriter");
        builder.field("transport", transport);
        builder.finish()
    }
}