use super::args::Args;
use super::{bytes_to_string, StaticString};
use crate::{err, InnerSession, Result, Session};
use async_net::{AsyncToSocketAddrs, TcpStream};
use dashmap::DashMap;
use futures::io::{AsyncReadExt, AsyncWriteExt};
use futures::lock::Mutex;
use ql2::version_dummy::Version;
use reql_macros::CommandOptions;
use scram::client::{ScramClient, ServerFinal, ServerFirst};
use serde::{Deserialize, Serialize};
use std::borrow::Cow;
use std::net::SocketAddr;
use std::sync::atomic::{AtomicBool, AtomicU64};
use std::sync::Arc;
use tracing::trace;
// Size, in bytes, of the scratch buffer used for socket reads during the handshake.
const BUF_SIZE: usize = 1024;
// Terminator appended to every outgoing handshake message and used to delimit
// the messages received from the server.
const NULL_BYTE: u8 = b'\0';
// Protocol version advertised in the authentication request and checked against
// the `[min, max]` range reported by the server.
const PROTOCOL_VERSION: usize = 0;
/// Database name used by [`Options::default`].
pub(crate) const DEFAULT_DB: &str = "test";
/// Connection options.
///
/// The values produced by [`Default`] are `localhost`, port `28015`,
/// database `test`, user `admin` and an empty password.
#[derive(Debug, Clone, CommandOptions, Eq, PartialEq, Ord, PartialOrd, Hash)]
#[non_exhaustive]
pub struct Options {
/// Host to connect to when no explicit socket address is supplied.
pub host: Cow<'static, str>,
/// Port to connect to when no explicit socket address is supplied.
pub port: u16,
/// Database name stored in the session when it is created.
pub db: Cow<'static, str>,
/// User name used for SCRAM authentication during the handshake.
pub user: Cow<'static, str>,
/// Password used for SCRAM authentication during the handshake.
pub password: Cow<'static, str>,
}
impl Default for Options {
    /// Builds the stock options: `localhost:28015`, database [`DEFAULT_DB`],
    /// user `admin` and an empty password.
    fn default() -> Self {
        let host = "localhost".static_string();
        let db = DEFAULT_DB.static_string();
        let user = "admin".static_string();
        let password = "".static_string();
        Self {
            host,
            port: 28015,
            db,
            user,
            password,
        }
    }
}
/// A value that can be turned into the pieces needed to open a connection.
pub trait Arg {
/// Address type handed to the TCP connector when an explicit address is given.
type ToAddrs: AsyncToSocketAddrs;
/// Splits `self` into an optional explicit address and the connection options.
///
/// When the address is `None`, the connection is made to the `host` and
/// `port` stored in the returned [`Options`].
fn into_connect_opts(self) -> (Option<Self::ToAddrs>, Options);
}
impl Arg for () {
type ToAddrs = SocketAddr;
fn into_connect_opts(self) -> (Option<Self::ToAddrs>, Options) {
(None, Default::default())
}
}
impl Arg for Options {
    type ToAddrs = SocketAddr;

    /// Options only: no explicit address, so the host and port in `self` are used.
    fn into_connect_opts(self) -> (Option<Self::ToAddrs>, Options) {
        let addr: Option<Self::ToAddrs> = None;
        (addr, self)
    }
}
impl<'a> Arg for &'a str {
    type ToAddrs = (&'a str, u16);

    /// Treats the string as a host name and pairs it with the default port.
    ///
    /// The returned options are the defaults; their `host` field is left
    /// untouched because the explicit address takes precedence.
    fn into_connect_opts(self) -> (Option<Self::ToAddrs>, Options) {
        let opts = Options::default();
        let addr = (self, opts.port);
        (Some(addr), opts)
    }
}
impl<T> Arg for Args<(T, Options)>
where
    T: AsyncToSocketAddrs,
{
    type ToAddrs = T;

    /// Explicit address plus explicit options: both are passed through as-is.
    fn into_connect_opts(self) -> (Option<Self::ToAddrs>, Options) {
        let (addr, opts) = self.0;
        (Some(addr), opts)
    }
}
/// Opens a TCP connection, performs the handshake and wraps the result in a
/// [`Session`].
///
/// If `addr` is `Some`, that address is used; otherwise the connection is made
/// to `options.host` / `options.port`.
///
/// # Errors
///
/// Fails if the TCP connection cannot be established or the handshake fails.
pub(crate) async fn new<T>((addr, options): (Option<T>, Options)) -> Result<Session>
where
    T: AsyncToSocketAddrs,
{
    let socket = if let Some(target) = addr {
        TcpStream::connect(target).await?
    } else {
        TcpStream::connect((options.host.as_ref(), options.port)).await?
    };

    // The handshake only borrows the options; `db` is moved out afterwards.
    let stream = Mutex::new(handshake(socket, &options).await?);

    let inner = Arc::new(InnerSession {
        stream,
        db: Mutex::new(options.db),
        channels: DashMap::new(),
        token: AtomicU64::new(0),
        broken: AtomicBool::new(false),
        change_feed: AtomicBool::new(false),
    });
    Ok(Session { inner })
}
/// Performs the version negotiation and SCRAM-SHA-256 authentication exchange
/// on a freshly connected `stream`, returning the stream once authenticated.
///
/// The exchange is:
/// 1. send the protocol magic number and the client-first message;
/// 2. receive and validate the server info message;
/// 3. receive the server-first authentication message;
/// 4. send the client-final message;
/// 5. receive and verify the server-final message.
///
/// Incoming messages are NUL-terminated. They are read with [`read_message`],
/// which copes with a message being split across several reads as well as with
/// several messages arriving in a single read.
///
/// # Errors
///
/// Fails on I/O errors, if the server closes the connection mid-handshake, if
/// a response cannot be parsed or reports failure, or if SCRAM verification
/// fails.
async fn handshake(mut stream: TcpStream, opts: &Options) -> Result<TcpStream> {
    trace!("sending supported version to RethinkDB");
    stream
        .write_all(&(Version::V10 as i32).to_le_bytes())
        .await?;
    let scram = ScramClient::new(opts.user.as_ref(), opts.password.as_ref(), None);
    let (scram, msg) = client_first(scram)?;
    trace!("sending client first message");
    stream.write_all(&msg).await?;

    // Bytes received from the socket that have not yet been consumed as a
    // message. The server may answer the magic number and the client-first
    // message back to back, so leftovers must survive between messages.
    let mut pending = Vec::with_capacity(BUF_SIZE);

    trace!("receiving message(s) from RethinkDB");
    let resp = read_message(&mut stream, &mut pending).await?;
    trace!("received server info; info: {}", bytes_to_string(resp.as_slice()));
    ServerInfo::validate(&resp)?;

    trace!("reading auth response");
    let resp = read_message(&mut stream, &mut pending).await?;
    trace!("received auth response");
    let info = AuthResponse::from_slice(&resp)?;
    let auth = info.authentication.ok_or_else(|| {
        err::Driver::Other(String::from("server did not send authentication info"))
    })?;

    let (scram, msg) = client_final(scram, &auth)?;
    trace!("sending client final message");
    stream.write_all(&msg).await?;

    trace!("reading server final message");
    let resp = read_message(&mut stream, &mut pending).await?;
    trace!("received server final message");
    server_final(scram, &resp)?;

    trace!("client connected successfully");
    Ok(stream)
}

/// Reads one NUL-terminated handshake message, without its terminator.
///
/// `pending` holds bytes already received but not yet consumed; anything that
/// follows the terminator stays in `pending` for the next call. The socket is
/// read repeatedly until a terminator shows up, because a single `read` may
/// return only part of a message.
///
/// If the server closes the connection after sending an unterminated message,
/// the bytes received so far are returned so the caller can still report them;
/// if nothing was received at all, an error is returned.
async fn read_message(stream: &mut TcpStream, pending: &mut Vec<u8>) -> Result<Vec<u8>> {
    let mut chunk = [0u8; BUF_SIZE];
    loop {
        // `bytes` stops at the first NUL, or at the end of the slice when
        // there is none.
        let end = bytes(pending.as_slice(), 0).0;
        if end < pending.len() {
            let msg = pending[..end].to_vec();
            // Drop the message together with its terminator.
            pending.drain(..=end);
            return Ok(msg);
        }

        let read = stream.read(&mut chunk).await?;
        if read == 0 {
            if pending.is_empty() {
                let msg = String::from("connection closed by server during handshake");
                return Err(err::Driver::Other(msg).into());
            }
            return Ok(std::mem::take(pending));
        }
        // Only the bytes actually read are valid; the rest of `chunk` is stale.
        pending.extend_from_slice(&chunk[..read]);
    }
}
/// Returns the run of bytes in `buf` that starts at `offset` and stops just
/// before the first NUL byte (or at the end of `buf` if there is none),
/// together with the index at which the run ends.
///
/// # Panics
///
/// Panics if `offset` is greater than `buf.len()`.
fn bytes(buf: &[u8], offset: usize) -> (usize, &[u8]) {
    let tail = &buf[offset..];
    let len = tail
        .iter()
        .position(|&byte| byte == NULL_BYTE)
        .unwrap_or(tail.len());
    (offset + len, &tail[..len])
}
/// Server info message, parsed from the first JSON response of the handshake.
#[derive(Serialize, Deserialize, Debug)]
struct ServerInfo<'a> {
/// Whether the server accepted the connection attempt.
success: bool,
/// Lowest protocol version the server reports supporting.
min_protocol_version: usize,
/// Highest protocol version the server reports supporting.
max_protocol_version: usize,
/// Server version string, borrowed from the response buffer.
server_version: &'a str,
}
impl ServerInfo<'_> {
    /// Parses `resp` as a server info message and checks that the server
    /// accepted the connection and supports [`PROTOCOL_VERSION`].
    ///
    /// # Errors
    ///
    /// Returns a JSON error if `resp` cannot be parsed, a runtime error
    /// carrying the raw response if `success` is false, and a driver error if
    /// the protocol version lies outside the server's supported range.
    fn validate(resp: &[u8]) -> Result<()> {
        let info: ServerInfo<'_> = serde_json::from_slice(resp)?;
        if !info.success {
            return Err(err::Runtime::Internal(bytes_to_string(resp)).into());
        }

        let supported = info.min_protocol_version..=info.max_protocol_version;
        if supported.contains(&PROTOCOL_VERSION) {
            return Ok(());
        }

        let msg = format!(
            "unsupported protocol version {version}, expected between {min} and {max}",
            version = PROTOCOL_VERSION,
            min = info.min_protocol_version,
            max = info.max_protocol_version,
        );
        Err(err::Driver::Other(msg).into())
    }
}
/// Client-first handshake message, serialized to JSON before being sent.
#[derive(Serialize, Deserialize, Debug)]
struct AuthRequest {
/// Protocol version the client speaks.
protocol_version: usize,
/// Name of the authentication mechanism.
authentication_method: &'static str,
/// SCRAM client-first payload.
authentication: String,
}
/// Builds the NUL-terminated client-first message.
///
/// Returns the next SCRAM state together with the bytes to send to the server.
///
/// # Errors
///
/// Fails if the request cannot be serialized to JSON.
fn client_first(scram: ScramClient<'_>) -> Result<(ServerFirst<'_>, Vec<u8>)> {
    let (next_state, authentication) = scram.client_first();
    let mut payload = serde_json::to_vec(&AuthRequest {
        protocol_version: PROTOCOL_VERSION,
        authentication_method: "SCRAM-SHA-256",
        authentication,
    })?;
    payload.push(NULL_BYTE);
    Ok((next_state, payload))
}
/// Client-final handshake message, serialized to JSON before being sent.
#[derive(Serialize, Deserialize, Debug)]
struct AuthConfirmation {
/// SCRAM client-final payload.
authentication: String,
}
fn client_final(scram: ServerFirst<'_>, auth: &str) -> Result<(ServerFinal, Vec<u8>)> {
let scram = scram
.handle_server_first(auth)
.map_err(|x| x.to_string())
.map_err(err::Driver::Other)?;
let (scram, client_final) = scram.client_final();
let conf = AuthConfirmation {
authentication: client_final,
};
let mut msg = serde_json::to_vec(&conf)?;
msg.push(NULL_BYTE);
Ok((scram, msg))
}
/// Authentication response sent by the server during the handshake.
#[derive(Serialize, Deserialize, Debug)]
struct AuthResponse {
/// Whether the authentication step succeeded.
success: bool,
/// SCRAM payload from the server, if any.
authentication: Option<String>,
/// Error code reported on failure; codes 10 through 20 are reported as
/// authentication errors when an error message is also present.
error_code: Option<usize>,
/// Error message reported on failure.
error: Option<String>,
}
impl AuthResponse {
    /// Parses `resp` as an authentication response and turns a reported
    /// failure into an error.
    ///
    /// # Errors
    ///
    /// Returns a JSON error if `resp` cannot be parsed. If `success` is false,
    /// returns an authentication error when the error code is in `10..=20` and
    /// a message is present, and otherwise a runtime error carrying the raw
    /// response.
    fn from_slice(resp: &[u8]) -> Result<Self> {
        let parsed: AuthResponse = serde_json::from_slice(resp)?;
        if parsed.success {
            return Ok(parsed);
        }

        match (parsed.error_code, parsed.error) {
            (Some(10..=20), Some(reason)) => Err(err::Driver::Auth(reason).into()),
            _ => Err(err::Runtime::Internal(bytes_to_string(resp)).into()),
        }
    }
}
fn server_final(scram: ServerFinal, resp: &[u8]) -> Result<()> {
let info = AuthResponse::from_slice(resp)?;
if let Some(auth) = info.authentication {
if let Err(error) = scram.handle_server_final(&auth) {
return Err(err::Driver::Other(error.to_string()).into());
}
}
Ok(())
}