use bytes::{BufMut, Bytes, BytesMut, buf::UninitSlice};
#[cfg(feature = "async")]
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
#[cfg(feature = "sync")]
use std::io::{Read, Write};
use crate::message::backend::{self, PgMessage};
/// A PostgreSQL frontend connection: a byte stream paired with a buffer of
/// outgoing bytes that have been encoded but not yet written to the stream.
///
/// Bytes are appended to the buffer through the `BufMut` implementation and
/// only reach the stream when the connection is flushed.
#[derive(Debug)]
pub struct PgConnection<S> {
    /// Transport that buffered bytes are flushed to and messages are read from.
    stream: S,
    /// Outgoing bytes waiting to be written to `stream`.
    buf: BytesMut,
}
impl<S> PgConnection<S> {
    /// Wraps `stream` with a write buffer pre-sized to 4096 bytes.
    pub fn new(stream: S) -> Self {
        Self::with_capacity(stream, 4096)
    }

    /// Wraps `stream` with a write buffer pre-sized to `capacity` bytes.
    pub fn with_capacity(stream: S, capacity: usize) -> Self {
        let buf = BytesMut::with_capacity(capacity);
        Self { stream, buf }
    }

    /// Detaches all pending bytes as an immutable `Bytes`.
    ///
    /// The connection's buffer is left empty (its spare capacity is kept for
    /// reuse); nothing is written to the stream.
    pub fn take_buf(&mut self) -> Bytes {
        let pending = self.buf.split();
        pending.freeze()
    }

    /// Returns `true` when bytes are waiting to be flushed.
    pub fn has_pending(&self) -> bool {
        self.pending_len() != 0
    }

    /// Number of bytes currently waiting to be flushed.
    pub fn pending_len(&self) -> usize {
        self.buf.len()
    }

    /// Consumes the connection, handing back the stream and any unflushed bytes.
    pub fn into_parts(self) -> (S, BytesMut) {
        let Self { stream, buf } = self;
        (stream, buf)
    }

    /// Shared access to the underlying stream.
    pub fn stream(&self) -> &S {
        &self.stream
    }

    /// Exclusive access to the underlying stream.
    ///
    /// Writing through this reference bypasses (and may overtake) pending bytes.
    pub fn stream_mut(&mut self) -> &mut S {
        &mut self.stream
    }
}
// SAFETY: every method forwards to the inner `BytesMut`, whose own `BufMut`
// implementation upholds the trait contract (`chunk_mut` exposes writable
// spare capacity, `advance_mut` only commits bytes inside it, and
// `remaining_mut` is consistent with both). `PgConnection` keeps no other
// state that could disagree with `buf`.
unsafe impl<S> BufMut for PgConnection<S> {
    fn remaining_mut(&self) -> usize {
        self.buf.remaining_mut()
    }

    unsafe fn advance_mut(&mut self, cnt: usize) {
        // SAFETY: the caller upholds `BufMut::advance_mut`'s contract for
        // `self`, and `self`'s writable region is exactly `self.buf`'s.
        unsafe { self.buf.advance_mut(cnt) }
    }

    fn chunk_mut(&mut self) -> &mut UninitSlice {
        self.buf.chunk_mut()
    }

    // `BytesMut` specializes the next two methods to reserve the full length up
    // front and fill it in one pass. Forwarding keeps that fast path instead of
    // the trait's generic `chunk_mut`/`advance_mut` loop; the resulting bytes
    // are identical.
    fn put_slice(&mut self, src: &[u8]) {
        self.buf.put_slice(src);
    }

    fn put_bytes(&mut self, val: u8, cnt: usize) {
        self.buf.put_bytes(val, cnt);
    }
}
#[cfg(feature = "async")]
impl<S: AsyncWrite + Unpin> PgConnection<S> {
    /// Writes every pending byte to the stream, then flushes the stream.
    ///
    /// The internal buffer is advanced as the stream accepts data. After an
    /// error, or if this future is dropped part-way (e.g. by losing a
    /// `tokio::select!` race), `pending_len()` reports exactly the bytes not
    /// yet written. Calling `flush` again then resumes where it stopped instead
    /// of resending data the stream already took.
    ///
    /// # Errors
    ///
    /// Returns any I/O error raised while writing to or flushing the stream.
    pub async fn flush(&mut self) -> std::io::Result<()> {
        // `write_all_buf` consumes from the front of `buf` as each write
        // succeeds and returns immediately when `buf` is empty.
        self.stream.write_all_buf(&mut self.buf).await?;
        self.stream.flush().await
    }

    /// Writes `bytes` directly to the stream.
    ///
    /// This bypasses the internal buffer. Any bytes still pending there are
    /// neither written nor flushed first, so `bytes` reach the stream ahead of
    /// them. Call `flush` beforehand if ordering matters.
    ///
    /// # Errors
    ///
    /// Returns any I/O error raised while writing to the stream.
    pub async fn write_raw(&mut self, bytes: &[u8]) -> std::io::Result<()> {
        self.stream.write_all(bytes).await
    }
}
#[cfg(feature = "async")]
impl<S> PgConnection<S>
where
    S: AsyncRead + Unpin,
{
    /// Reads the next backend message from the stream.
    ///
    /// The write buffer is not touched. Pending outgoing bytes are not flushed
    /// by this call, so flush first if the server must see them before it replies.
    pub async fn recv(&mut self) -> std::io::Result<PgMessage> {
        let reader = &mut self.stream;
        backend::read_message(reader).await
    }
}
#[cfg(feature = "sync")]
impl<S: Write> PgConnection<S> {
    /// Writes every pending byte to the stream, then flushes the stream.
    ///
    /// Bytes are dropped from the internal buffer only once the stream has
    /// accepted them. If an error is returned, `pending_len()` reports exactly
    /// the bytes not yet written, and a retry will not resend data the stream
    /// already took.
    ///
    /// # Errors
    ///
    /// Returns:
    /// - any error from `Write::write` other than `ErrorKind::Interrupted`
    ///   (which is retried);
    /// - `ErrorKind::WriteZero` if the stream stops accepting bytes;
    /// - any error from `Write::flush`.
    pub fn flush_sync(&mut self) -> std::io::Result<()> {
        // Scoped here so `Buf::advance` is only in scope where it is needed.
        use bytes::Buf;
        use std::io::ErrorKind;

        while !self.buf.is_empty() {
            match self.stream.write(&self.buf) {
                // Same error `Write::write_all` reports for a stalled writer.
                Ok(0) => {
                    return Err(std::io::Error::new(
                        ErrorKind::WriteZero,
                        "failed to write whole buffer",
                    ));
                }
                Ok(written) => self.buf.advance(written),
                Err(err) if err.kind() == ErrorKind::Interrupted => {}
                Err(err) => return Err(err),
            }
        }
        self.stream.flush()
    }

    /// Writes `bytes` directly to the stream.
    ///
    /// This bypasses the internal buffer. Any bytes still pending there are
    /// neither written nor flushed first, so `bytes` reach the stream ahead of
    /// them. Call `flush_sync` beforehand if ordering matters.
    ///
    /// # Errors
    ///
    /// Returns any I/O error raised while writing to the stream.
    pub fn write_raw_sync(&mut self, bytes: &[u8]) -> std::io::Result<()> {
        self.stream.write_all(bytes)
    }
}
#[cfg(feature = "sync")]
impl<S> PgConnection<S>
where
    S: Read,
{
    /// Blocking counterpart of `recv`: reads the next backend message from the stream.
    ///
    /// The write buffer is not touched. Pending outgoing bytes are not flushed
    /// by this call, so flush first if the server must see them before it replies.
    pub fn recv_sync(&mut self) -> std::io::Result<PgMessage> {
        let reader = &mut self.stream;
        backend::read_message_sync(reader)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::message::PgProtocol;

    /// A connection over an empty in-memory sink that is never flushed to.
    fn empty_conn() -> PgConnection<Vec<u8>> {
        PgConnection::new(Vec::new())
    }

    #[test]
    fn test_frontend_message_methods() {
        let mut conn = empty_conn();
        conn.query("SELECT 1");
        assert!(conn.has_pending());
        assert_ne!(conn.pending_len(), 0);
    }

    #[test]
    fn test_take_buf() {
        let mut conn = empty_conn();
        conn.sync();
        let taken = conn.take_buf();
        assert!(!taken.is_empty());
        // Taking the bytes leaves nothing pending in the connection.
        assert!(!conn.has_pending());
    }

    #[test]
    fn test_into_parts() {
        let mut conn = empty_conn();
        conn.query("test");
        let (sink, pending) = conn.into_parts();
        // Nothing was flushed: the stream is untouched and the bytes are still buffered.
        assert!(sink.is_empty());
        assert!(!pending.is_empty());
    }

    #[test]
    fn test_chaining() {
        let mut conn = empty_conn();
        conn.query("SELECT 1").sync().terminate();
        assert_ne!(conn.pending_len(), 0);
    }

    #[test]
    fn test_builder_chaining() {
        let mut conn = empty_conn();
        conn.parse(None)
            .query("SELECT $1")
            .finish()
            .execute(None, 0)
            .sync();
        assert_ne!(conn.pending_len(), 0);
    }

    #[cfg(feature = "async")]
    mod async_tests {
        use super::*;

        #[tokio::test]
        async fn test_flush() {
            let mut written = Vec::new();
            let mut conn = PgConnection::new(&mut written);
            conn.sync();
            conn.flush().await.unwrap();
            // Expect exactly five bytes, starting with the 'S' tag.
            assert_eq!(written.len(), 5);
            assert_eq!(written[0], b'S');
        }

        #[tokio::test]
        async fn test_recv() {
            let ready: &[u8] = &[b'Z', 0, 0, 0, 5, b'I'];
            let mut conn = PgConnection::new(ready);
            let message = conn.recv().await.unwrap();
            assert!(matches!(message, PgMessage::ReadyForQuery(_)));
        }
    }

    #[cfg(feature = "sync")]
    mod sync_tests {
        use super::*;
        use std::io::Cursor;

        #[test]
        fn test_flush_sync() {
            let mut written = Vec::new();
            let mut conn = PgConnection::new(&mut written);
            conn.sync();
            conn.flush_sync().unwrap();
            // Expect exactly five bytes, starting with the 'S' tag.
            assert_eq!(written.len(), 5);
            assert_eq!(written[0], b'S');
        }

        #[test]
        fn test_recv_sync() {
            let ready: &[u8] = &[b'Z', 0, 0, 0, 5, b'I'];
            let mut conn = PgConnection::new(Cursor::new(ready));
            let message = conn.recv_sync().unwrap();
            assert!(matches!(message, PgMessage::ReadyForQuery(_)));
        }
    }
}