use super::args::Args;
use crate::cmd::options::{Db, RunOptions};
use crate::proto::{Command, Payload};
use crate::{err, Connection, Result, Session};
use async_stream::try_stream;
use async_trait::async_trait;
use futures::io::{AsyncReadExt, AsyncWriteExt};
use futures::stream::{Stream, StreamExt};
use ql2::query::QueryType;
use ql2::response::{ErrorType, ResponseType};
use serde::de::DeserializeOwned;
use serde::Deserialize;
use serde_json::Value;
use std::str;
use std::sync::atomic::Ordering;
use tracing::trace;
/// Width in bytes of the little-endian `u32` length field of a frame header.
const DATA_SIZE: usize = std::mem::size_of::<u32>();
/// Width in bytes of the little-endian `u64` query token of a frame header.
const TOKEN_SIZE: usize = std::mem::size_of::<u64>();
/// Total width in bytes of a frame header: the token followed by the length.
const HEADER_SIZE: usize = TOKEN_SIZE + DATA_SIZE;
/// A response body as received from the server, deserialized from JSON.
///
/// The single-letter field names mirror the keys of the JSON object on the wire.
#[derive(Deserialize, Debug)]
#[allow(dead_code)]
pub(crate) struct Response {
/// Numeric response type; decoded with `ResponseType::from_i32`.
t: i32,
/// Numeric error type, when present; decoded with `ErrorType::from_i32`.
e: Option<i32>,
/// Result payload: the values to yield on success, or the error messages
/// (an array of strings) on failure.
pub(crate) r: Value,
// NOTE(review): `b`, `p` and `n` are never read in this file; presumably they are
// the backtrace, profile and notes entries of the ReQL response — confirm.
b: Option<Value>,
p: Option<Value>,
n: Option<Value>,
}
impl Response {
    /// Builds a placeholder response: a `SuccessAtom` with an empty result
    /// array and every optional field unset.
    fn new() -> Self {
        let empty_result = Value::Array(vec![]);
        Self {
            r: empty_result,
            t: ResponseType::SuccessAtom as i32,
            e: None,
            n: None,
            p: None,
            b: None,
        }
    }
}
/// Something a query can be run with: it is consumed and turned into the
/// connection to run on plus the options to run with.
#[async_trait]
pub trait Arg {
/// Converts `self` into a `(Connection, RunOptions)` pair.
///
/// `for_changes` is set by the caller to `query.change_feed()`. The
/// implementations in this file either ignore it or forward it unchanged.
async fn into_run_opts(self, for_changes: bool) -> Result<(Connection, RunOptions)>;
}
#[async_trait]
impl Arg for &Session {
async fn into_run_opts(self, _for_changes: bool) -> Result<(Connection, RunOptions)> {
let conn = self.connection()?;
Ok((conn, Default::default()))
}
}
#[async_trait]
impl Arg for Connection {
async fn into_run_opts(self, _for_changes: bool) -> Result<(Connection, RunOptions)> {
Ok((self, Default::default()))
}
}
/// Running with a session plus caller-supplied options: obtain a connection
/// from the session and pass the options through untouched.
#[async_trait]
impl Arg for Args<(&Session, RunOptions)> {
    async fn into_run_opts(self, _for_changes: bool) -> Result<(Connection, RunOptions)> {
        let (session, options) = self.0;
        Ok((session.connection()?, options))
    }
}
/// Running with an explicit connection plus caller-supplied options: the
/// wrapped pair is already in the required shape.
#[async_trait]
impl Arg for Args<(Connection, RunOptions)> {
    async fn into_run_opts(self, _for_changes: bool) -> Result<(Connection, RunOptions)> {
        Ok(self.0)
    }
}
/// Running with a mutable session reference: obtain a connection from the
/// session and delegate to the `Connection` implementation.
#[async_trait]
impl Arg for &mut Session {
    async fn into_run_opts(self, for_changes: bool) -> Result<(Connection, RunOptions)> {
        let conn = self.connection()?;
        conn.into_run_opts(for_changes).await
    }
}
impl RunOptions {
    /// Fills in the `db` option from the session when the caller did not set one.
    ///
    /// The session database is only applied when it differs from `"test"`;
    /// otherwise the options are returned unchanged.
    async fn default_db(mut self, session: &Session) -> RunOptions {
        const DEFAULT_DB: &str = "test";
        let session_db = session.inner.db.lock().await;
        // Nothing to do if the caller chose a database explicitly, or the
        // session is still on the default one.
        if self.db.is_some() || *session_db == DEFAULT_DB {
            return self;
        }
        self.db = Some(Db(session_db.clone()));
        self
    }
}
/// Runs `query` and returns the stream of values it produces, each
/// deserialized into `T`.
///
/// `arg` supplies the connection and run options (see [`Arg`]). The stream:
///
/// * yields the elements of a `SuccessAtom`, `SuccessSequence` or
///   `ServerInfo` response and then ends;
/// * yields the elements of each `SuccessPartial` response and sends a
///   `Continue` request for the next batch, until the connection is flagged
///   as closed;
/// * ends without yielding on `WaitComplete`;
/// * fails with the error built by `response_error` for any other response
///   type, except a `ClientError` mentioning "not in stream cache" on a change
///   feed, which ends the stream quietly.
///
/// Request failures and deserialization failures are yielded as `Err` items.
pub(crate) fn new<A, T>(query: Command, arg: A) -> impl Stream<Item = Result<T>>
where
    A: Arg,
    T: Unpin + DeserializeOwned,
{
    try_stream! {
        // Evaluate the change-feed flag once and reuse it below.
        let change_feed = query.change_feed();
        let (mut conn, mut opts) = arg.into_run_opts(change_feed).await?;
        opts = opts.default_db(&conn.session).await;
        if change_feed {
            conn.session.inner.mark_change_feed();
        }
        let noreply = opts.noreply.unwrap_or_default();
        let mut payload = Payload(QueryType::Start, Some(&query), opts);
        loop {
            let (response_type, resp) = conn.request(&payload, noreply).await?;
            trace!("yielding response; token: {}", conn.token);
            match response_type {
                ResponseType::SuccessAtom => {
                    // An atom arrives wrapped in a one-element array. When the atom
                    // is itself an array, unwrap it so its elements are yielded
                    // individually. The element is moved out rather than cloned:
                    // the outer vector is owned and not used afterwards.
                    let atom_val = match resp.r {
                        Value::Array(mut items) => {
                            let nested = matches!(items.first(), Some(Value::Array(_)));
                            if nested {
                                items.swap_remove(0)
                            } else {
                                Value::Array(items)
                            }
                        }
                        other => other,
                    };
                    for val in serde_json::from_value::<Vec<T>>(atom_val)? {
                        yield val;
                    }
                    break;
                },
                ResponseType::SuccessSequence | ResponseType::ServerInfo => {
                    for val in serde_json::from_value::<Vec<T>>(resp.r)? {
                        yield val;
                    }
                    break;
                }
                ResponseType::SuccessPartial => {
                    if conn.closed() {
                        // Reset the flag before ending the stream.
                        conn.set_closed(false);
                        trace!("connection closed; token: {}", conn.token);
                        break;
                    }
                    // More batches are pending: ask for the next one on the
                    // following loop iteration.
                    payload = Payload(QueryType::Continue, None, Default::default());
                    for val in serde_json::from_value::<Vec<T>>(resp.r)? {
                        yield val;
                    }
                    continue;
                }
                ResponseType::WaitComplete => { break; }
                typ => {
                    let msg = error_message(resp.r)?;
                    match typ {
                        // On a change feed this client error ends the stream
                        // instead of being reported.
                        ResponseType::ClientError if change_feed && msg.contains("not in stream cache") => { break; }
                        _ => Err(response_error(typ, resp.e, msg))?,
                    }
                }
            }
        }
    }
}
impl Payload<'_> {
fn encode(&self, token: u64) -> Result<Vec<u8>> {
let bytes = self.to_bytes()?;
let data_len = bytes.len();
let mut buf = Vec::with_capacity(HEADER_SIZE + data_len);
buf.extend_from_slice(&token.to_le_bytes());
buf.extend_from_slice(&(data_len as u32).to_le_bytes());
buf.extend_from_slice(&bytes);
Ok(buf)
}
}
impl Connection {
fn send_response(&self, db_token: u64, resp: Result<(ResponseType, Response)>) {
if let Some(tx) = self.session.inner.channels.get(&db_token) {
if let Err(error) = tx.unbounded_send(resp) {
if error.is_disconnected() {
self.session.inner.channels.remove(&db_token);
}
}
}
}
pub(crate) async fn request<'a>(
&mut self,
query: &'a Payload<'a>,
noreply: bool,
) -> Result<(ResponseType, Response)> {
self.submit(query, noreply).await;
match self.rx.lock().await.next().await {
Some(resp) => resp,
None => Ok((ResponseType::SuccessAtom, Response::new())),
}
}
async fn submit<'a>(&self, query: &'a Payload<'a>, noreply: bool) {
let mut db_token = self.token;
let result = self.exec(query, noreply, &mut db_token).await;
self.send_response(db_token, result);
}
async fn exec<'a>(
&self,
query: &'a Payload<'a>,
noreply: bool,
db_token: &mut u64,
) -> Result<(ResponseType, Response)> {
let buf = query.encode(self.token)?;
let guard = self.session.inner.stream.lock().await;
let mut stream = guard.clone();
trace!("sending query; token: {}, payload: {}", self.token, query);
stream.write_all(&buf).await?;
trace!("query sent; token: {}", self.token);
if noreply {
return Ok((ResponseType::SuccessAtom, Response::new()));
}
trace!("reading header; token: {}", self.token);
let mut header = [0u8; HEADER_SIZE];
stream.read_exact(&mut header).await?;
let mut buf = [0u8; TOKEN_SIZE];
buf.copy_from_slice(&header[..TOKEN_SIZE]);
*db_token = {
let token = u64::from_le_bytes(buf);
trace!("db_token: {}", token);
if token > self.session.inner.token.load(Ordering::SeqCst) {
self.session.inner.mark_broken();
return Err(err::Driver::ConnectionBroken.into());
}
token
};
let mut buf = [0u8; DATA_SIZE];
buf.copy_from_slice(&header[TOKEN_SIZE..]);
let len = u32::from_le_bytes(buf) as usize;
trace!(
"header read; token: {}, db_token: {}, response_len: {}",
self.token,
db_token,
len
);
trace!("reading body; token: {}", self.token);
let mut buf = vec![0u8; len];
stream.read_exact(&mut buf).await?;
trace!(
"body read; token: {}, db_token: {}, body: {}",
self.token,
db_token,
crate::tools::bytes_to_string(&buf),
);
let resp = serde_json::from_slice::<Response>(&buf)?;
trace!("response successfully parsed; token: {}", self.token,);
let response_type = ResponseType::from_i32(resp.t)
.ok_or_else(|| err::Driver::Other(format!("unknown response type `{}`", resp.t)))?;
if let Some(error_type) = resp.e {
let msg = error_message(resp.r)?;
return Err(response_error(response_type, Some(error_type), msg));
}
Ok((response_type, resp))
}
}
/// Extracts the error text from a response payload.
///
/// The payload must be a JSON array of strings; the strings are joined with
/// single spaces. Any other shape yields a deserialization error.
fn error_message(response: Value) -> Result<String> {
    let parts: Vec<String> = serde_json::from_value(response)?;
    Ok(parts.join(" "))
}
/// Maps a failed response onto the crate's error type.
///
/// * `ClientError` becomes a driver error carrying `msg`.
/// * `CompileError` becomes a compile error carrying `msg`.
/// * `RuntimeError` is refined by `error_type`; a missing or unrecognised
///   error type becomes a driver error prefixed with "unexpected runtime error".
/// * Any other response type becomes a driver error prefixed with
///   "unexpected response".
fn response_error(response_type: ResponseType, error_type: Option<i32>, msg: String) -> err::Error {
    match response_type {
        ResponseType::ClientError => err::Driver::Other(msg).into(),
        ResponseType::CompileError => err::Error::Compile(msg),
        ResponseType::RuntimeError => match error_type.and_then(ErrorType::from_i32) {
            Some(ErrorType::Internal) => err::Runtime::Internal(msg).into(),
            Some(ErrorType::ResourceLimit) => err::Runtime::ResourceLimit(msg).into(),
            Some(ErrorType::QueryLogic) => err::Runtime::QueryLogic(msg).into(),
            Some(ErrorType::NonExistence) => err::Runtime::NonExistence(msg).into(),
            Some(ErrorType::OpFailed) => err::Availability::OpFailed(msg).into(),
            Some(ErrorType::OpIndeterminate) => err::Availability::OpIndeterminate(msg).into(),
            Some(ErrorType::User) => err::Runtime::User(msg).into(),
            Some(ErrorType::PermissionError) => err::Runtime::Permission(msg).into(),
            // Covers both an absent error type and a code with no known variant.
            _ => err::Driver::Other(format!("unexpected runtime error: {}", msg)).into(),
        },
        _ => err::Driver::Other(format!("unexpected response: {}", msg)).into(),
    }
}