#![cfg_attr(not(feature="unstable"), allow(dead_code))]
use std::cmp::{min, max};
use std::convert::TryInto;
use std::future::{Future};
use std::pin::Pin;
use std::slice;
use std::task::{Poll, Context};
use async_std::io::Read as AsyncRead;
use async_std::stream::{Stream, StreamExt};
use bytes::{Bytes, BytesMut, BufMut};
use futures_util::io::ReadHalf;
use tls_api::TlsStream;
use edgedb_errors::{ClientConnectionError, ClientConnectionEosError};
use edgedb_errors::{Error, ErrorKind};
use edgedb_errors::{ProtocolOutOfOrderError, ProtocolEncodingError};
use edgedb_protocol::encoding::Input;
use edgedb_protocol::features::ProtocolVersion;
use edgedb_protocol::server_message::{ReadyForCommand, TransactionState};
use edgedb_protocol::server_message::{ServerMessage, ErrorResponse};
use edgedb_protocol::{QueryResult};
use crate::client;
use crate::debug::PartialDebug;
/// Default, and minimum, number of bytes requested from the stream per read
/// in `Reader::poll_message`.
const BUFFER_SIZE: usize = 8192;
/// Upper bound on the number of bytes requested in a single read while a large
/// frame is still incomplete. NOTE(review): this caps one read only; it does
/// not limit the total size of the buffer or of a frame.
const MAX_BUFFER: usize = 1_048_576;
/// Reads and decodes server messages from the read half of a connection.
///
/// Every field is a borrow, so a `Reader` is a short-lived view over state
/// owned elsewhere (presumably by the connection — confirm at construction).
pub struct Reader<'a> {
/// Protocol version handed to the decoder for every incoming frame.
pub(crate) proto: &'a ProtocolVersion,
/// Read half of the TLS stream that raw bytes are pulled from.
pub(crate) stream: &'a mut ReadHalf<TlsStream>,
/// Bytes read from `stream` that have not yet been consumed as a complete
/// frame; bytes past the current frame are kept for the next message.
pub(crate) buf: &'a mut BytesMut,
/// Overwritten with the state carried by each `ReadyForCommand` passed to
/// `consume_ready`.
pub(crate) transaction_state: &'a mut TransactionState,
}
/// Future returned by `Reader::message`; resolves to the next `ServerMessage`
/// read through the borrowed reader.
pub struct MessageFuture<'a, 'r: 'a> {
/// Reader that is polled for the message; `'r` is the lifetime of the state
/// the reader itself borrows.
reader: &'a mut Reader<'r>,
}
/// Stream of decoded rows of type `T` for a single query, driven by the
/// messages read through `seq.reader`.
pub struct QueryResponse<'a, T: QueryResult> {
/// Client sequence in progress; its `reader` supplies the server messages.
pub(crate) seq: client::Sequence<'a>,
/// Set once a `CommandComplete` message has been received.
pub(crate) complete: bool,
/// `ErrorResponse` received from the server; it is handed to the caller when
/// the following `ReadyForCommand` arrives.
pub(crate) error: Option<ErrorResponse>,
/// Undecoded rows from `Data` messages, stored in reverse arrival order so
/// that `Vec::pop` yields the oldest row first.
pub(crate) buffer: Vec<Bytes>,
/// Decoder state passed to `T::decode` for every row.
pub(crate) state: T::State,
}
// `poll_next` reaches the fields through `*self` on a `Pin<&mut Self>`, and
// `skip_remaining` calls `StreamExt::next`; both require `Self: Unpin`. The
// auto-derived impl would depend on every field type (including `T::State`)
// being `Unpin`, so it is asserted manually: no field is ever structurally
// pinned by this type.
impl<T: QueryResult> Unpin for QueryResponse<'_, T> {}
impl<'r> Reader<'r> {
    /// Returns a future that resolves to the next message from the server.
    pub fn message(&mut self) -> MessageFuture<'_, 'r> {
        MessageFuture { reader: self }
    }

    /// Records the transaction state reported by a `ReadyForCommand` message.
    pub fn consume_ready(&mut self, ready: ReadyForCommand) {
        *self.transaction_state = ready.transaction_state;
    }

    /// Reads messages until a `ReadyForCommand` arrives and records its
    /// transaction state. Every other message (including `ErrorResponse`)
    /// is discarded.
    ///
    /// # Errors
    ///
    /// Propagates any read or decode error from [`Reader::message`].
    pub async fn wait_ready(&mut self) -> Result<(), Error> {
        loop {
            let msg = self.message().await?;
            if let ServerMessage::ReadyForCommand(ready) = msg {
                self.consume_ready(ready);
                return Ok(());
            }
        }
    }

    /// Polls for the next complete server message.
    ///
    /// A frame is a one-byte message type followed by a big-endian `u32`
    /// length that counts itself and the payload but not the type byte, so a
    /// whole frame spans `len + 1` bytes. Bytes are read from `stream` into
    /// `buf` until a whole frame is buffered; that frame is split off and
    /// decoded with `proto`. Bytes beyond it stay in `buf` for the next call.
    ///
    /// # Errors
    ///
    /// * `ClientConnectionEosError` if the stream reaches end-of-file.
    /// * `ClientConnectionError` if the underlying read fails.
    /// * `ProtocolEncodingError` if the frame length is unrepresentable or
    ///   the frame cannot be decoded.
    fn poll_message(&mut self, cx: &mut Context)
        -> Poll<Result<ServerMessage, Error>>
    {
        let Reader { ref mut buf, ref mut stream, .. } = self;
        let frame_len = loop {
            let mut next_read = BUFFER_SIZE;
            let buf_len = buf.len();
            // The length is readable as soon as the full 5-byte header
            // (type + length) is buffered; `>= 5` so a frame with an empty
            // body (exactly 5 bytes) is recognized without further reads.
            if buf_len >= 5 {
                let len = u32::from_be_bytes(
                    buf[1..5].try_into().expect("header slice is 4 bytes"))
                    as usize;
                // `len + 1` can only overflow where `usize` is 32 bits wide.
                let total = match len.checked_add(1) {
                    Some(total) => total,
                    None => {
                        return Poll::Ready(Err(
                            ProtocolEncodingError::with_message(
                                "frame length overflows usize")));
                    }
                };
                if buf_len >= total {
                    break total;
                }
                // Ask for the missing bytes, but never less than
                // BUFFER_SIZE nor more than MAX_BUFFER in one read.
                next_read = max(min(total - buf_len, MAX_BUFFER), BUFFER_SIZE);
            }
            debug_assert!(next_read > 0);

            // Grow the buffer with zeroed bytes and read into that tail.
            // Reading into initialized memory avoids handing `poll_read` a
            // `&mut [u8]` over uninitialized capacity, which would violate
            // the safety contract of `slice::from_raw_parts_mut`.
            let filled = buf.len();
            buf.resize(filled + next_read, 0);
            let read = Pin::new(&mut *stream)
                .poll_read(cx, &mut buf[filled..]);
            // Drop the part of the zeroed tail that the read did not fill.
            let kept = match &read {
                Poll::Ready(Ok(n)) => filled + *n,
                _ => filled,
            };
            buf.truncate(kept);

            match read {
                Poll::Ready(Ok(0)) => {
                    return Poll::Ready(
                        Err(ClientConnectionEosError::build())
                    );
                }
                Poll::Ready(Ok(_)) => continue,
                Poll::Ready(Err(e)) => {
                    return Poll::Ready(
                        Err(ClientConnectionError::with_source(e))
                    );
                }
                Poll::Pending => return Poll::Pending,
            }
        };
        let frame = buf.split_to(frame_len).freeze();
        let result = ServerMessage::decode(&mut Input::new(
            self.proto.clone(),
            frame,
        )).map_err(ProtocolEncodingError::with_source)?;
        log::debug!(target: "edgedb::incoming::frame",
            "Frame Contents: {:#?}", result);
        Poll::Ready(Ok(result))
    }
}
impl Future for MessageFuture<'_, '_> {
    type Output = Result<ServerMessage, Error>;

    /// Delegates to the borrowed reader's `poll_message`.
    fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
        // The only field is a `&mut Reader`, which makes this future `Unpin`,
        // so the pin can be unwrapped into a plain mutable reference.
        let this: &mut Self = Pin::get_mut(self);
        this.reader.poll_message(cx)
    }
}
impl<T> QueryResponse<'_, T>
where
    T: QueryResult,
{
    /// Reads and discards every remaining row until the stream ends.
    ///
    /// # Errors
    ///
    /// Returns the first error yielded by the stream, if any.
    pub async fn skip_remaining(mut self) -> Result<(), Error> {
        loop {
            match self.next().await {
                None => return Ok(()),
                Some(Ok(_)) => {}
                Some(Err(e)) => return Err(e),
            }
        }
    }

    /// Runs `seq._process_exec()` and returns the bytes it produces
    /// (presumably the command-completion payload — confirm against
    /// `client::Sequence::_process_exec`).
    pub async fn get_completion(mut self) -> Result<Bytes, Error> {
        let completion = self.seq._process_exec().await?;
        Ok(completion)
    }
}
impl<T: QueryResult> Stream for QueryResponse<'_, T> {
    type Item = Result<T, Error>;

    /// Yields the next decoded row of the query result.
    ///
    /// Rows already in `buffer` are returned first. When it is empty,
    /// messages are polled from `seq.reader` until a `Data` message refills
    /// it or the response ends:
    ///
    /// * `Data` (no pending error): its rows are queued. It is an
    ///   out-of-order error if `CommandComplete` was already seen.
    /// * `CommandComplete` (no pending error): marks the response complete.
    ///   It is an out-of-order error if it repeats.
    /// * `ErrorResponse`: stored and reported when `ReadyForCommand` arrives.
    /// * `ReadyForCommand`: with a stored error, records the ready state,
    ///   ends the sequence cleanly and yields the error. Otherwise it does
    ///   the same and yields `None` once `CommandComplete` was seen; before
    ///   that, it is an out-of-order error.
    /// * Anything else (including `Data`/`CommandComplete` after an error):
    ///   `ProtocolOutOfOrderError`.
    ///
    /// # Panics
    ///
    /// Panics if `seq.active` is false (presumably cleared by `end_clean`,
    /// i.e. when polled after the stream has finished — confirm).
    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context)
        -> Poll<Option<Self::Item>>
    {
        assert!(self.seq.active,
            "QueryResponse polled after its query sequence became inactive");
        let QueryResponse {
            ref mut buffer,
            ref mut complete,
            ref mut error,
            ref mut seq,
            ref mut state,
        } = *self;
        // Only read from the connection when no received row is queued.
        while buffer.is_empty() {
            match seq.reader.poll_message(cx) {
                Poll::Ready(Ok(ServerMessage::Data(data))) if error.is_none()
                => {
                    if *complete {
                        return Poll::Ready(Some(
                            Err(ProtocolOutOfOrderError::with_message(format!(
                                "unsolicited packet: {}", PartialDebug(data))))
                        ));
                    }
                    // Stored reversed so `pop` below returns rows in order.
                    buffer.extend(data.data.into_iter().rev());
                }
                Poll::Ready(Ok(m @ ServerMessage::CommandComplete(_)))
                    if error.is_none()
                => {
                    if *complete {
                        return Poll::Ready(Some(
                            Err(ProtocolOutOfOrderError::with_message(format!(
                                "unsolicited packet: {}", PartialDebug(m))))
                        ));
                    }
                    *complete = true;
                }
                Poll::Ready(Ok(ServerMessage::ReadyForCommand(r))) => {
                    if let Some(error) = error.take() {
                        seq.reader.consume_ready(r);
                        seq.end_clean();
                        return Poll::Ready(Some(Err(error.into())));
                    }
                    if !*complete {
                        let pkt = ServerMessage::ReadyForCommand(r);
                        return Poll::Ready(Some(
                            Err(ProtocolOutOfOrderError::with_message(
                                format!("unsolicited packet: {}",
                                        PartialDebug(pkt))))
                        ));
                    }
                    seq.reader.consume_ready(r);
                    seq.end_clean();
                    return Poll::Ready(None);
                }
                Poll::Ready(Ok(ServerMessage::ErrorResponse(e))) => {
                    // Keep reading: the error is surfaced only once the
                    // server signals ReadyForCommand.
                    *error = Some(e);
                }
                Poll::Ready(Ok(message)) => {
                    return Poll::Ready(Some(
                        Err(ProtocolOutOfOrderError::with_message(format!(
                            "unsolicited packet: {}", PartialDebug(message))))
                    ));
                }
                Poll::Ready(Err(e)) => {
                    return Poll::Ready(Some(Err(e)));
                }
                Poll::Pending => return Poll::Pending,
            }
        }
        let chunk = buffer.pop()
            .expect("loop exits only once the row buffer is non-empty");
        Poll::Ready(Some(T::decode(state, &chunk)
            .map_err(ProtocolEncodingError::with_source)))
    }
}