unreql 0.2.1

A well-documented, easy-to-use RethinkDB driver for Rust
Documentation
use super::args::Args;
use crate::cmd::options::{Db, RunOptions};
use crate::proto::{Command, Payload};
use crate::{err, Connection, Result, Session};
use async_stream::try_stream;
use async_trait::async_trait;
use futures::io::{AsyncReadExt, AsyncWriteExt};
use futures::stream::{Stream, StreamExt};
use ql2::query::QueryType;
use ql2::response::{ErrorType, ResponseType};
use serde::de::DeserializeOwned;
use serde::Deserialize;
use serde_json::Value;
use std::str;
use std::sync::atomic::Ordering;
use tracing::trace;

const DATA_SIZE: usize = 4;
const TOKEN_SIZE: usize = 8;
const HEADER_SIZE: usize = DATA_SIZE + TOKEN_SIZE;

/// One response frame from the server, deserialized from the JSON body that
/// `Connection::exec` reads after the 12-byte header.
#[derive(Deserialize, Debug)]
// NOTE(review): presumably needed because `b`, `p` and `n` are deserialized but
// never read in this file — confirm and narrow the allow if possible.
#[allow(dead_code)]
pub(crate) struct Response {
    /// Raw response-type code; converted with `ResponseType::from_i32` in `exec`.
    t: i32,
    /// Raw error-type code, if any. When present, `exec` turns the response into an error.
    e: Option<i32>,
    /// Result payload: decoded into `Vec<T>` for successful responses, or into
    /// `Vec<String>` (joined into one message) for error responses.
    pub(crate) r: Value,
    /// Not read in this file (presumably the protocol's backtrace field — confirm).
    b: Option<Value>,
    /// Not read in this file (presumably profiling data — confirm).
    p: Option<Value>,
    /// Not read in this file (presumably response notes — confirm).
    n: Option<Value>,
}

impl Response {
    fn new() -> Self {
        Self {
            t: ResponseType::SuccessAtom as i32,
            e: None,
            r: Value::Array(Vec::new()),
            b: None,
            p: None,
            n: None,
        }
    }
}

/// Something a query can be run with: a session, a connection, or either of them
/// paired with explicit `RunOptions` through `Args`.
#[async_trait]
pub trait Arg {
    /// Resolves this argument into the connection to send the query on and the
    /// options to send it with.
    ///
    /// `for_changes` is `true` when the query is a change feed; none of the
    /// implementations in this file act on it.
    async fn into_run_opts(self, for_changes: bool) -> Result<(Connection, RunOptions)>;
}

#[async_trait]
impl Arg for &Session {
    async fn into_run_opts(self, _for_changes: bool) -> Result<(Connection, RunOptions)> {
        let conn = self.connection()?;
        Ok((conn, Default::default()))
    }
}

#[async_trait]
impl Arg for Connection {
    async fn into_run_opts(self, _for_changes: bool) -> Result<(Connection, RunOptions)> {
        Ok((self, Default::default()))
    }
}

#[async_trait]
impl Arg for Args<(&Session, RunOptions)> {
    /// Takes a connection from the session and keeps the caller-supplied options.
    async fn into_run_opts(self, _for_changes: bool) -> Result<(Connection, RunOptions)> {
        let Args((sess, run_opts)) = self;
        Ok((sess.connection()?, run_opts))
    }
}

#[async_trait]
impl Arg for Args<(Connection, RunOptions)> {
    /// Unwraps the caller-supplied connection and options unchanged.
    async fn into_run_opts(self, _for_changes: bool) -> Result<(Connection, RunOptions)> {
        Ok(self.0)
    }
}

#[async_trait]
impl Arg for &mut Session {
    /// Takes a connection from the session and delegates to the `Connection`
    /// implementation, forwarding `for_changes`.
    async fn into_run_opts(self, for_changes: bool) -> Result<(Connection, RunOptions)> {
        let conn = self.connection()?;
        conn.into_run_opts(for_changes).await
    }
}

impl RunOptions {
    /// Returns these options with the session's current database filled in.
    ///
    /// The options come back unchanged when they already name a database, or when
    /// the session's database is `"test"` (`DEFAULT_DB`). Presumably that is the
    /// server's own default, so sending it would be redundant — confirm.
    async fn default_db(self, session: &Session) -> RunOptions {
        const DEFAULT_DB: &str = "test";
        let current_db = session.inner.db.lock().await;
        let keep_as_is = self.db.is_some() || *current_db == DEFAULT_DB;
        if keep_as_is {
            self
        } else {
            Self {
                db: Some(Db(current_db.clone())),
                ..self
            }
        }
    }
}

/// Runs `query` with the connection and options resolved from `arg`, streaming each
/// decoded result item as a `T`.
///
/// The query is first sent as `QueryType::Start`. While the server answers with
/// `SuccessPartial`, a `QueryType::Continue` is sent to fetch the next batch.
///
/// The stream ends after any of:
/// - an atom, a full sequence, server info, or `WaitComplete`;
/// - a partial response that arrives while `conn.closed()` is set (the flag is then
///   reset so the connection can be reused).
///
/// If the options do not name a database, the session's current one is filled in by
/// `RunOptions::default_db`.
///
/// # Errors
///
/// The stream yields an error if any of these fail:
/// - resolving `arg`;
/// - the request itself;
/// - deserializing an item into `T`;
/// - the server returning an error response.
///
/// Exception: for change feeds, a `ClientError` mentioning "not in stream cache" (a
/// feed closed by `conn.close()`) ends the stream quietly instead.
pub(crate) fn new<A, T>(query: Command, arg: A) -> impl Stream<Item = Result<T>>
where
    A: Arg,
    T: Unpin + DeserializeOwned,
{
    try_stream! {
        // Evaluated once and reused for argument resolution and the feed handling below.
        let change_feed = query.change_feed();
        let (mut conn, mut opts) = arg.into_run_opts(change_feed).await?;
        opts = opts.default_db(&conn.session).await;
        if change_feed {
            conn.session.inner.mark_change_feed();
        }
        let noreply = opts.noreply.unwrap_or_default();
        let mut payload = Payload(QueryType::Start, Some(&query), opts);
        loop {
            let (response_type, resp) = conn.request(&payload, noreply).await?;
            trace!("yielding response; token: {}", conn.token);
            match response_type {
                ResponseType::SuccessAtom => {
                    // If the atom is itself an array, flatten one level so its elements
                    // are yielded individually: [[1, 2, 3]] => [1, 2, 3].
                    // The inner array is moved out of the wrapper instead of cloned.
                    let atom_val = match resp.r {
                        Value::Array(mut arr) if matches!(arr.first(), Some(Value::Array(_))) => {
                            arr.swap_remove(0)
                        }
                        other => other,
                    };
                    for val in serde_json::from_value::<Vec<T>>(atom_val)? {
                        yield val;
                    }
                    break;
                },
                ResponseType::SuccessSequence | ResponseType::ServerInfo => {
                    for val in serde_json::from_value::<Vec<T>>(resp.r)? {
                        yield val;
                    }
                    break;
                }
                ResponseType::SuccessPartial => {
                    if conn.closed() {
                        // reopen so we can use the connection in future
                        conn.set_closed(false);
                        trace!("connection closed; token: {}", conn.token);
                        break;
                    }
                    // Subsequent requests ask the server for the next batch.
                    payload = Payload(QueryType::Continue, None, Default::default());
                    for val in serde_json::from_value::<Vec<T>>(resp.r)? {
                        yield val;
                    }
                    continue;
                }
                ResponseType::WaitComplete => { break; }
                typ => {
                    let msg = error_message(resp.r)?;
                    match typ {
                        // This feed has been closed by conn.close().
                        ResponseType::ClientError if change_feed && msg.contains("not in stream cache") => { break; }
                        _ => Err(response_error(typ, resp.e, msg))?,
                    }
                }
            }
        }
    }
}

impl Payload<'_> {
    /// Serializes the payload into one wire frame. The frame layout is:
    /// - an 8-byte little-endian query token;
    /// - a 4-byte little-endian body length;
    /// - the serialized body.
    ///
    /// # Errors
    ///
    /// Returns an error in either of these cases:
    /// - serialization fails;
    /// - the body is longer than `u32::MAX` bytes, so its length cannot be stored in
    ///   the 4-byte length field.
    fn encode(&self, token: u64) -> Result<Vec<u8>> {
        use std::convert::TryFrom;

        let bytes = self.to_bytes()?;
        // The length prefix is only DATA_SIZE (4) bytes wide. An `as` cast would silently
        // truncate an oversized length and desynchronize the stream, so reject it instead.
        let data_len = u32::try_from(bytes.len()).map_err(|_| {
            std::io::Error::new(
                std::io::ErrorKind::InvalidInput,
                format!("query payload too large: {} bytes", bytes.len()),
            )
        })?;
        let mut buf = Vec::with_capacity(HEADER_SIZE + bytes.len());
        buf.extend_from_slice(&token.to_le_bytes());
        buf.extend_from_slice(&data_len.to_le_bytes());
        buf.extend_from_slice(&bytes);
        Ok(buf)
    }
}

impl Connection {
    /// Delivers a query outcome to the response channel registered for `db_token`.
    ///
    /// If no channel is registered for that token, the outcome is dropped. If the
    /// receiving side has disconnected, the channel entry is removed.
    fn send_response(&self, db_token: u64, resp: Result<(ResponseType, Response)>) {
        // Finish with the entry returned by `get` before calling `remove`. `channels` is
        // a shared map used through `&self`. If it locks on lookup (e.g. `DashMap`, whose
        // docs warn that `remove` may deadlock while a reference into the map is held),
        // removing the key while that guard is alive would deadlock.
        let disconnected = match self.session.inner.channels.get(&db_token) {
            Some(tx) => matches!(
                tx.unbounded_send(resp),
                Err(ref error) if error.is_disconnected()
            ),
            None => false,
        };
        if disconnected {
            self.session.inner.channels.remove(&db_token);
        }
    }

    /// Submits `query` and waits for the next response delivered to this connection's
    /// receiver.
    ///
    /// If the receiver yields nothing, an empty `SuccessAtom` response is returned.
    pub(crate) async fn request<'a>(
        &mut self,
        query: &'a Payload<'a>,
        noreply: bool,
    ) -> Result<(ResponseType, Response)> {
        self.submit(query, noreply).await;
        match self.rx.lock().await.next().await {
            Some(resp) => resp,
            None => Ok((ResponseType::SuccessAtom, Response::new())),
        }
    }

    /// Executes `query` and routes the outcome, success or error, to the channel of the
    /// token carried by the response that was read.
    ///
    /// If `exec` fails before reading a response token, the outcome goes to this
    /// connection's own token.
    async fn submit<'a>(&self, query: &'a Payload<'a>, noreply: bool) {
        let mut db_token = self.token;
        let result = self.exec(query, noreply, &mut db_token).await;
        self.send_response(db_token, result);
    }

    /// Writes `query` to the session's stream and, unless `noreply` is set, reads one
    /// response frame back.
    ///
    /// The stream lock is held until the function returns. The frame read may carry a
    /// token other than `self.token`; that token is stored in `db_token` so the caller
    /// can route the response to the matching channel.
    ///
    /// # Errors
    ///
    /// Returns an error in any of these cases:
    /// - encoding fails;
    /// - an I/O error occurs;
    /// - the body is not a valid `Response`;
    /// - the response type code is unknown;
    /// - the response carries an error-type code;
    /// - the response token exceeds the last token issued by the session (the session
    ///   is then marked broken).
    async fn exec<'a>(
        &self,
        query: &'a Payload<'a>,
        noreply: bool,
        db_token: &mut u64,
    ) -> Result<(ResponseType, Response)> {
        let frame = query.encode(self.token)?;

        let guard = self.session.inner.stream.lock().await;
        let mut stream = guard.clone();

        trace!("sending query; token: {}, payload: {}", self.token, query);
        stream.write_all(&frame).await?;
        trace!("query sent; token: {}", self.token);

        if noreply {
            return Ok((ResponseType::SuccessAtom, Response::new()));
        }

        trace!("reading header; token: {}", self.token);
        let mut header = [0u8; HEADER_SIZE];
        stream.read_exact(&mut header).await?;

        let mut token_bytes = [0u8; TOKEN_SIZE];
        token_bytes.copy_from_slice(&header[..TOKEN_SIZE]);
        *db_token = {
            let token = u64::from_le_bytes(token_bytes);
            trace!("db_token: {}", token);
            // A token higher than any this session has issued cannot belong to one of
            // our queries; treat the connection as broken.
            if token > self.session.inner.token.load(Ordering::SeqCst) {
                self.session.inner.mark_broken();
                return Err(err::Driver::ConnectionBroken.into());
            }
            token
        };

        let mut len_bytes = [0u8; DATA_SIZE];
        len_bytes.copy_from_slice(&header[TOKEN_SIZE..]);
        let len = u32::from_le_bytes(len_bytes) as usize;
        trace!(
            "header read; token: {}, db_token: {}, response_len: {}",
            self.token,
            db_token,
            len
        );

        trace!("reading body; token: {}", self.token);
        let mut body = vec![0u8; len];
        stream.read_exact(&mut body).await?;

        trace!(
            "body read; token: {}, db_token: {}, body: {}",
            self.token,
            db_token,
            crate::tools::bytes_to_string(&body),
        );

        let resp = serde_json::from_slice::<Response>(&body)?;
        trace!("response successfully parsed; token: {}", self.token,);

        let response_type = ResponseType::from_i32(resp.t)
            .ok_or_else(|| err::Driver::Other(format!("unknown response type `{}`", resp.t)))?;

        if let Some(error_type) = resp.e {
            let msg = error_message(resp.r)?;
            return Err(response_error(response_type, Some(error_type), msg));
        }

        Ok((response_type, resp))
    }
}

/// Joins the strings of an error response's `r` payload into one space-separated message.
///
/// # Errors
///
/// Fails if `response` cannot be deserialized as an array of strings.
fn error_message(response: Value) -> Result<String> {
    serde_json::from_value::<Vec<String>>(response)
        .map(|parts| parts.join(" "))
        .map_err(Into::into)
}

/// Converts an error response into the driver's error type.
///
/// - `ClientError` and `CompileError` map directly, carrying `msg`.
/// - `RuntimeError` is classified by the optional `error_type` code. A missing or
///   unrecognised code becomes a `Driver::Other` "unexpected runtime error".
/// - Any other response type becomes a `Driver::Other` "unexpected response".
fn response_error(response_type: ResponseType, error_type: Option<i32>, msg: String) -> err::Error {
    match response_type {
        ResponseType::ClientError => err::Driver::Other(msg).into(),
        ResponseType::CompileError => err::Error::Compile(msg),
        ResponseType::RuntimeError => {
            let kind = error_type.and_then(ErrorType::from_i32);
            match kind {
                Some(ErrorType::Internal) => err::Runtime::Internal(msg).into(),
                Some(ErrorType::ResourceLimit) => err::Runtime::ResourceLimit(msg).into(),
                Some(ErrorType::QueryLogic) => err::Runtime::QueryLogic(msg).into(),
                Some(ErrorType::NonExistence) => err::Runtime::NonExistence(msg).into(),
                Some(ErrorType::OpFailed) => err::Availability::OpFailed(msg).into(),
                Some(ErrorType::OpIndeterminate) => err::Availability::OpIndeterminate(msg).into(),
                Some(ErrorType::User) => err::Runtime::User(msg).into(),
                Some(ErrorType::PermissionError) => err::Runtime::Permission(msg).into(),
                // Missing code, unknown code, or a variant not listed above.
                _ => err::Driver::Other(format!("unexpected runtime error: {}", msg)).into(),
            }
        }
        _ => err::Driver::Other(format!("unexpected response: {}", msg)).into(),
    }
}