#![cfg_attr(not(feature="unstable"), allow(dead_code, unused_imports))]
use core::future::Future;
use std::collections::HashMap;
use std::fmt;
use std::str;
use std::time::{Duration, Instant};
use async_std::prelude::FutureExt;
use async_std::prelude::StreamExt;
use async_std::future::{timeout, pending};
use async_std::io::prelude::WriteExt;
use bytes::{Bytes, BytesMut};
use futures_util::io::{ReadHalf, WriteHalf};
use typemap::TypeMap;
use tls_api::TlsStream;
use edgedb_protocol::QueryResult;
use edgedb_protocol::client_message::ClientMessage;
use edgedb_protocol::client_message::{DescribeStatement, DescribeAspect};
use edgedb_protocol::client_message::{Execute, ExecuteScript};
use edgedb_protocol::client_message::{Prepare, IoFormat, Cardinality};
use edgedb_protocol::descriptors::OutputTypedesc;
use edgedb_protocol::encoding::Output;
use edgedb_protocol::features::ProtocolVersion;
use edgedb_protocol::query_arg::{QueryArgs, Encoder};
use edgedb_protocol::queryable::{Queryable};
use edgedb_protocol::server_message::ServerMessage;
use edgedb_protocol::server_message::{TransactionState};
use crate::debug::PartialDebug;
use crate::errors::{ClientConnectionError, ProtocolError};
use crate::errors::{ClientConnectionTimeoutError, ClientConnectionEosError};
use crate::errors::{ClientInconsistentError, ClientEncodingError};
use crate::errors::{Error, ErrorKind, ResultExt};
use crate::errors::{NoResultExpected, NoDataError};
use crate::errors::{ProtocolOutOfOrderError, ProtocolEncodingError};
use crate::reader::{self, QueryResponse, Reader};
use crate::server_params::{ServerParam, SystemConfig};
/// Protocol-level consistency state of a [`Connection`].
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
pub(crate) enum State {
    /// The connection is usable; a new sequence may be started.
    Normal {
        /// When the connection last returned to this state (set by
        /// `Sequence::end_clean` and after a successful ping round-trip).
        idle_since: Instant,
    },
    /// A sequence is in progress or was interrupted; `start_sequence`
    /// refuses to run on a connection in this state.
    Dirty,
    /// A background `Sync` ping was fully sent and its ready message has
    /// not been consumed yet.
    AwaitingPing,
}
/// Cached decision about background keep-alive pings, derived from the
/// server's `session_idle_timeout` by `Connection::calc_ping_interval`.
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
pub(crate) enum PingInterval {
    /// Not computed yet, or the server did not report a session idle
    /// timeout.
    Unknown,
    /// Pings are not sent (timeout disabled or too short).
    Disabled,
    /// Send a ping after this much idle time.
    Interval(Duration),
}
/// A single client connection over a TLS stream, split into read and write
/// halves with their own buffers.
#[derive(Debug)]
pub struct Connection {
    /// Cached keep-alive policy; computed lazily by `ping_while`.
    pub(crate) ping_interval: PingInterval,
    pub(crate) input: ReadHalf<TlsStream>,
    pub(crate) output: WriteHalf<TlsStream>,
    /// Buffer for incoming bytes, lent to `Reader`.
    pub(crate) input_buf: BytesMut,
    /// Reusable buffer for encoded outgoing messages, lent to `Writer`.
    pub(crate) output_buf: BytesMut,
    pub(crate) version: ProtocolVersion,
    /// Server parameters keyed by type; read via `get_param`.
    pub(crate) params: TypeMap<dyn typemap::DebugAny + Send + Sync>,
    /// Transaction state; `split` lends it mutably to the `Reader`.
    pub(crate) transaction_state: TransactionState,
    pub(crate) state: State,
}
/// One request/response exchange on a borrowed [`Connection`].
///
/// Created by `Connection::start_sequence`, which marks the connection
/// `Dirty`; the connection only becomes `Normal` again via `end_clean`.
pub struct Sequence<'a> {
    pub writer: Writer<'a>,
    pub reader: Reader<'a>,
    /// Cleared by `end_clean`; most methods assert it is still set.
    pub(crate) active: bool,
    /// The owning connection's state.
    pub(crate) state: &'a mut State,
}
/// Write half of a split connection: encodes client messages into `outbuf`
/// and writes them to `stream`.
pub struct Writer<'a> {
    stream: &'a mut WriteHalf<TlsStream>,
    proto: &'a ProtocolVersion,
    /// Scratch buffer, cleared at the start of every `send_messages` call.
    outbuf: &'a mut BytesMut,
}
/// Options used when preparing a statement (see `Sequence::_query`).
#[derive(Debug, Clone)]
pub struct StatementParams {
    /// Output encoding requested from the server.
    pub io_format: IoFormat,
    /// Sent as `expected_cardinality` in the `Prepare` message.
    pub cardinality: Cardinality,
}
impl StatementParams {
pub fn new() -> StatementParams {
StatementParams {
io_format: IoFormat::Binary,
cardinality: Cardinality::Many,
}
}
pub fn io_format(&mut self, fmt: IoFormat) -> &mut Self {
self.io_format = fmt;
self
}
pub fn cardinality(&mut self, card: Cardinality) -> &mut Self {
self.cardinality = card;
self
}
}
impl<'a> Sequence<'a> {
pub fn response<T: QueryResult>(self, state: T::State)
-> QueryResponse<'a, T>
{
assert!(self.active); reader::QueryResponse {
seq: self,
buffer: Vec::new(),
error: None,
complete: false,
state,
}
}
pub async fn response_blobs(mut self) -> Result<(Vec<Bytes>, Bytes), Error>
{
assert!(self.active); let mut data = Vec::new();
let complete = loop {
match self.reader.message().await? {
ServerMessage::Data(m) => data.extend(m.data),
ServerMessage::CommandComplete(m) => break Ok(m.status_data),
ServerMessage::ErrorResponse(e) => break Err(e),
msg => {
return Err(ProtocolOutOfOrderError::with_message(format!(
"unsolicited packet: {}", PartialDebug(msg))));
}
}
};
match self.reader.message().await? {
ServerMessage::ReadyForCommand(r) => {
self.reader.consume_ready(r);
self.end_clean();
return complete.map(|m| (data, m)).map_err(|e| e.into());
}
msg => {
return Err(ProtocolOutOfOrderError::with_message(format!(
"unsolicited packet: {}", PartialDebug(msg))));
}
}
}
pub fn end_clean(&mut self) {
self.active = false;
*self.state = State::Normal {
idle_since: Instant::now(),
};
}
}
impl Connection {
    /// Returns the protocol version this connection uses.
    pub fn protocol(&self) -> &ProtocolVersion {
        return &self.version
    }
    /// Awaits `Reader::passive_wait` (its result is discarded), then marks
    /// the connection `State::Dirty` and stays pending forever.
    ///
    /// Generic over `T` so it can be raced against a future of any output
    /// type; it never actually produces a `T`.
    pub async fn passive_wait<T>(&mut self) -> T {
        let (_, mut reader, _) = self.split();
        reader.passive_wait().await.ok();
        self.state = State::Dirty;
        pending::<()>().await;
        unreachable!();
    }
    /// Keep-alive loop: while the connection is `Normal`, waits until
    /// `interval` has passed since `idle_since` and then does a `Sync`
    /// round-trip.
    ///
    /// Returns `Ok(())` once the state is no longer `Normal` at the top of
    /// the loop. On an I/O error other than a timeout, the state is set to
    /// `Dirty` and a `ClientConnectionError` is returned.
    async fn do_pings(&mut self, interval: Duration) -> Result<(), Error> {
        use async_std::io;
        let (mut writer, mut reader, state) = self.split();
        // A previous ping may have been interrupted after it was sent;
        // consume its reply first.
        if *state == State::AwaitingPing {
            Self::synchronize_ping(&mut reader, state).await?;
        }
        while let State::Normal { idle_since: last_pong } = *state {
            // Sleep for whatever is left of the interval, but stop early if
            // `passive_wait` finishes with an error.
            match io::timeout(
                interval.saturating_sub(Instant::now() - last_pong),
                reader.passive_wait()
            ).await {
                Err(e) if e.kind() == io::ErrorKind::TimedOut => (),
                Err(e) => {
                    *state = State::Dirty;
                    return Err(ClientConnectionError::with_source(e))?;
                }
                Ok(_) => unreachable!(),
            }
            // Marked Dirty before writing so that, if this future is dropped
            // mid-write (e.g. when the race in `ping_while` completes), the
            // connection is not left looking consistent.
            *state = State::Dirty;
            writer.send_messages(&[ClientMessage::Sync]).await?;
            // The Sync is fully sent; if we are dropped while awaiting the
            // reply, `ping_while` sees AwaitingPing and synchronizes.
            *state = State::AwaitingPing;
            Self::synchronize_ping(&mut reader, state).await?;
        }
        Ok(())
    }
    /// Runs `do_pings`, logs any error at info level, then stays pending
    /// forever. Meant to be raced against real work (see `ping_while`).
    async fn background_pings<T>(&mut self, interval: Duration) -> T {
        self.do_pings(interval).await
            .map_err(|e| {
                log::info!("Connection error during background pings: {}", e)
            })
            .ok();
        debug_assert_eq!(self.state, State::Dirty);
        pending::<()>().await;
        unreachable!();
    }
    /// Waits for the ready message that answers an outstanding ping.
    ///
    /// On success the state becomes `Normal` with a fresh `idle_since`; on
    /// failure it becomes `Dirty` and the error is returned.
    async fn synchronize_ping<'a>(
        reader: &mut Reader<'a>, state: &mut State
    ) -> Result<(), Error> {
        debug_assert_eq!(*state, State::AwaitingPing);
        if let Err(e) = reader.wait_ready().await {
            *state = State::Dirty;
            Err(e)
        } else {
            *state = State::Normal { idle_since: Instant::now() };
            Ok(())
        }
    }
    /// Derives the ping policy from the server's `session_idle_timeout`
    /// (from `SystemConfig`).
    ///
    /// - Missing config or timeout: `Unknown`.
    /// - Zero timeout, or one too short to yield a whole second: `Disabled`.
    /// - Otherwise: `Interval(ceil(0.9 * (timeout - 1s)))` in whole seconds.
    fn calc_ping_interval(&self) -> PingInterval {
        if let Some(config) = self.params.get::<SystemConfig>() {
            if let Some(timeout) = config.session_idle_timeout {
                if timeout.is_zero() {
                    log::info!(
                        "Server disabled session_idle_timeout; \
pings are disabled."
                    );
                    PingInterval::Disabled
                } else {
                    // Leave a margin before the server's idle timeout:
                    // subtract 1s, take 90%, round up to whole seconds.
                    let interval = Duration::from_secs(
                        (
                            timeout.saturating_sub(
                                Duration::from_secs(1)
                            ).as_secs_f64() * 0.9
                        ).ceil() as u64
                    );
                    if interval.is_zero() {
                        log::warn!(
                            "session_idle_timeout={:?} is too short; \
pings are disabled.",
                            timeout,
                        );
                        PingInterval::Disabled
                    } else {
                        log::info!(
                            "Setting ping interval to {:?} as \
session_idle_timeout={:?}",
                            interval, timeout,
                        );
                        PingInterval::Interval(interval)
                    }
                }
            } else {
                PingInterval::Unknown
            }
        } else {
            PingInterval::Unknown
        }
    }
    /// Awaits `other` while sending background pings, if a ping interval
    /// applies. Otherwise it just awaits `other`.
    ///
    /// After `other` wins the race, any ping that was sent but not yet
    /// answered is synchronized (errors ignored; the state then reflects
    /// the outcome).
    #[cfg(feature="unstable")]
    pub async fn ping_while<T, F>(&mut self, other: F) -> T
        where F: Future<Output = T>
    {
        if self.ping_interval == PingInterval::Unknown {
            self.ping_interval = self.calc_ping_interval();
        }
        if let PingInterval::Interval(interval) = self.ping_interval {
            let rv = other.race(self.background_pings(interval)).await;
            if self.state == State::AwaitingPing {
                let (_, ref mut reader, state) = self.split();
                Self::synchronize_ping(reader, state).await.ok();
            }
            rv
        } else {
            other.await
        }
    }
    /// Returns `true` when the connection is in `State::Normal` and can
    /// start a new sequence.
    pub fn is_consistent(&self) -> bool {
        matches!(self.state, State::Normal {
            idle_since: _,
        })
    }
    /// Sends `Terminate` and expects the server to close the stream.
    ///
    /// # Errors
    /// Returns `ProtocolError` if any message arrives instead of
    /// end-of-stream, and propagates any other read error.
    pub async fn terminate(mut self) -> Result<(), Error> {
        let mut seq = self.start_sequence().await?;
        seq.send_messages(&[ClientMessage::Terminate]).await?;
        match seq.message().await {
            Err(e) if e.is::<ClientConnectionEosError>() => Ok(()),
            Err(e) => Err(e),
            Ok(msg) => Err(ProtocolError::with_message(format!(
                "unsolicited message {:?}", msg))),
        }
    }
    /// Begins a new request/response sequence and marks the connection
    /// `Dirty` until the sequence ends cleanly.
    ///
    /// # Errors
    /// Returns `ClientInconsistentError` if the connection is not in
    /// `State::Normal`.
    pub async fn start_sequence<'x>(&'x mut self)
        -> Result<Sequence<'x>, Error>
    {
        let (writer, reader, state) = self.split();
        if !matches!(*state, State::Normal {
            idle_since: _,
        }) {
            return Err(ClientInconsistentError::with_message(
                "Connection is inconsistent state. Please reconnect."));
        }
        *state = State::Dirty;
        Ok(Sequence {
            writer,
            reader,
            state,
            active: true,
        })
    }
    /// Looks up a server parameter by its typemap key.
    pub fn get_param<T: ServerParam>(&self)
        -> Option<&<T as typemap::Key>::Value>
        where <T as typemap::Key>::Value: fmt::Debug + Send + Sync
    {
        self.params.get::<T>()
    }
    /// Returns the current transaction state.
    pub fn transaction_state(&self) -> TransactionState {
        self.transaction_state
    }
    /// Borrows the connection as a `Writer`, a `Reader` and its `State`,
    /// all at once (disjoint field borrows).
    fn split(&mut self) -> (Writer, Reader, &mut State) {
        let reader = Reader {
            proto: &self.version,
            buf: &mut self.input_buf,
            stream: &mut self.input,
            transaction_state: &mut self.transaction_state,
        };
        let writer = Writer {
            proto: &self.version,
            outbuf: &mut self.output_buf,
            stream: &mut self.output,
        };
        (writer, reader, &mut self.state)
    }
}
impl<'a> Writer<'a> {
pub async fn send_messages<'x, I>(&mut self, msgs: I) -> Result<(), Error>
where I: IntoIterator<Item=&'x ClientMessage>
{
self.outbuf.truncate(0);
for msg in msgs {
msg.encode(&mut Output::new(
&self.proto,
self.outbuf,
)).map_err(ClientEncodingError::with_source)?;
}
self.stream.write_all(&self.outbuf[..]).await
.map_err(ClientConnectionError::with_source)?;
Ok(())
}
}
impl<'a> Sequence<'a> {
pub async fn send_messages<'x, I>(&mut self, msgs: I)
-> Result<(), Error>
where I: IntoIterator<Item=&'x ClientMessage>
{
assert!(self.active); self.writer.send_messages(msgs).await
}
pub async fn expect_ready(&mut self) -> Result<(), Error> {
assert!(self.active); self.reader.wait_ready().await?;
self.end_clean();
Ok(())
}
pub fn message(&mut self) -> reader::MessageFuture<'_, 'a> {
assert!(self.active); self.reader.message()
}
pub async fn err_sync(&mut self) -> Result<(), Error> {
assert!(self.active); self.writer.send_messages(&[ClientMessage::Sync]).await?;
timeout(Duration::from_secs(10), self.expect_ready()).await
.map_err(ClientConnectionTimeoutError::with_source)??;
Ok(())
}
pub async fn _process_exec(&mut self) -> Result<Bytes, Error> {
assert!(self.active); let status = loop {
let msg = self.reader.message().await?;
match msg {
ServerMessage::CommandComplete(c) => {
self.reader.wait_ready().await?;
self.end_clean();
break c.status_data;
}
ServerMessage::ErrorResponse(err) => {
self.reader.wait_ready().await?;
self.end_clean();
return Err(err.into());
}
ServerMessage::Data(_) => { }
msg => {
eprintln!("WARNING: unsolicited message {:?}", msg);
}
}
};
Ok(status)
}
pub(crate) async fn _query<A>(&mut self, request: &str, arguments: &A,
bld: &StatementParams)
-> Result<OutputTypedesc, Error>
where A: QueryArgs + ?Sized,
{
assert!(self.active); let statement_name = Bytes::from_static(b"");
self.send_messages(&[
ClientMessage::Prepare(Prepare {
headers: HashMap::new(),
io_format: bld.io_format,
expected_cardinality: bld.cardinality,
statement_name: statement_name.clone(),
command_text: String::from(request),
}),
ClientMessage::Flush,
]).await?;
loop {
let msg = self.reader.message().await?;
match msg {
ServerMessage::PrepareComplete(..) => {
break;
}
ServerMessage::ErrorResponse(err) => {
self.err_sync().await?;
return Err(err.into());
}
_ => {
return Err(ProtocolOutOfOrderError::with_message(format!(
"Unsolicited message {:?}", msg)));
}
}
}
self.send_messages(&[
ClientMessage::DescribeStatement(DescribeStatement {
headers: HashMap::new(),
aspect: DescribeAspect::DataDescription,
statement_name: statement_name.clone(),
}),
ClientMessage::Flush,
]).await?;
let data_description = loop {
let msg = self.reader.message().await?;
match msg {
ServerMessage::CommandDataDescription(data_desc) => {
break data_desc;
}
ServerMessage::ErrorResponse(err) => {
self.err_sync().await?;
return Err(err.into());
}
_ => {
return Err(ProtocolOutOfOrderError::with_message(format!(
"Unsolicited message {:?}", msg)));
}
}
};
let desc = data_description.output()
.map_err(ProtocolEncodingError::with_source)?;
let inp_desc = data_description.input()
.map_err(ProtocolEncodingError::with_source)?;
let mut arg_buf = BytesMut::with_capacity(8);
arguments.encode(&mut Encoder::new(
&inp_desc.as_query_arg_context(),
&mut arg_buf,
))?;
self.send_messages(&[
ClientMessage::Execute(Execute {
headers: HashMap::new(),
statement_name: statement_name.clone(),
arguments: arg_buf.freeze(),
}),
ClientMessage::Sync,
]).await?;
Ok(desc)
}
}
impl Connection {
pub async fn execute<S>(&mut self, request: S)
-> Result<Bytes, Error>
where S: ToString,
{
let mut seq = self.start_sequence().await?;
seq.send_messages(&[
ClientMessage::ExecuteScript(ExecuteScript {
headers: HashMap::new(),
script_text: request.to_string(),
}),
]).await?;
let status = loop {
match seq.message().await? {
ServerMessage::CommandComplete(c) => {
seq.expect_ready().await?;
break c.status_data;
}
ServerMessage::ErrorResponse(err) => {
seq.expect_ready().await?;
return Err(err.into());
}
msg => {
eprintln!("WARNING: unsolicited message {:?}", msg);
}
}
};
Ok(status)
}
pub async fn query<R, A>(&mut self, request: &str, arguments: &A)
-> Result<QueryResponse<'_, R>, Error>
where R: QueryResult,
A: QueryArgs,
{
let mut seq = self.start_sequence().await?;
let desc = seq._query(
request, arguments,
&StatementParams::new(),
).await?;
match desc.root_pos() {
Some(root_pos) => {
let ctx = desc.as_queryable_context();
let state = R::prepare(&ctx, root_pos)?;
Ok(seq.response(state))
}
None => {
let completion_message = seq._process_exec().await?;
Err(NoResultExpected::with_message(
String::from_utf8_lossy(&completion_message[..])
.to_string()))?
}
}
}
pub async fn query_row<R, A>(&mut self, request: &str, arguments: &A)
-> Result<R, Error>
where R: Queryable,
A: QueryArgs,
{
let mut query = self.query(request, arguments).await?;
if let Some(result) = query.next().await.transpose()? {
if let Some(_) = query.next().await.transpose()? {
query.skip_remaining().await?;
return Err(ProtocolError::with_message(
"extra row returned for query_row"
));
}
Ok(result)
} else {
return Err(NoDataError::build());
}
}
pub async fn query_row_opt<R, A>(&mut self, request: &str, arguments: &A)
-> Result<Option<R>, Error>
where R: Queryable,
A: QueryArgs,
{
let mut query = self.query(request, arguments).await?;
if let Some(result) = query.next().await.transpose()? {
if let Some(_) = query.next().await.transpose()? {
return Err(ProtocolError::with_message(
"extra row returned for query_row"
));
}
Ok(Some(result))
} else {
Ok(None)
}
}
pub async fn query_json<A>(&mut self, request: &str, arguments: &A)
-> Result<QueryResponse<'_, String>, Error>
where A: QueryArgs,
{
let mut seq = self.start_sequence().await?;
let desc = seq._query(
request, arguments,
&StatementParams::new().io_format(IoFormat::Json),
).await?;
match desc.root_pos() {
Some(root_pos) => {
let ctx = desc.as_queryable_context();
let state = String::prepare(&ctx, root_pos)?;
Ok(seq.response(state))
}
None => {
let completion_message = seq._process_exec().await?;
Err(NoResultExpected::with_message(
String::from_utf8_lossy(&completion_message[..])
.to_string()))?
}
}
}
pub async fn query_json_els<A>(&mut self, request: &str, arguments: &A)
-> Result<QueryResponse<'_, String>, Error>
where A: QueryArgs,
{
let mut seq = self.start_sequence().await?;
let desc = seq._query(
request, arguments,
&StatementParams::new().io_format(IoFormat::JsonElements),
).await?;
match desc.root_pos() {
Some(root_pos) => {
let ctx = desc.as_queryable_context();
let state = String::prepare(&ctx, root_pos)?;
Ok(seq.response(state))
}
None => {
let completion_message = seq._process_exec().await?;
Err(NoResultExpected::with_message(
String::from_utf8_lossy(&completion_message[..])
.to_string()))?
}
}
}
#[allow(dead_code)]
pub async fn execute_args<A>(&mut self, request: &str, arguments: &A)
-> Result<Bytes, Error>
where A: QueryArgs,
{
let mut seq = self.start_sequence().await?;
seq._query(request, arguments, &StatementParams::new()).await?;
return seq._process_exec().await;
}
pub async fn get_version(&mut self) -> Result<String, Error> {
self.query_row("SELECT sys::get_version_as_str()", &()).await
.context("cannot fetch database version")
}
}