pub mod codec;
mod frame;
mod handshake;
pub mod scram;
#[cfg(feature = "redwire-tls")]
mod tls;
pub use codec::FrameError;
pub use frame::{Flags, Frame, MessageKind};
#[cfg(feature = "redwire-tls")]
pub use tls::TlsConfig;
use std::io;
use std::pin::Pin;
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
use tokio::net::TcpStream;
use reddb_wire::query_with_params::{
encode_query_with_params, ParamValue as RedWireParamValue, FEATURE_PARAMS,
};
/// Object-safe union of tokio's `AsyncRead` and `AsyncWrite`, letting a plain TCP
/// socket and a TLS-wrapped one live behind the same trait object.
pub(crate) trait AsyncReadWrite: AsyncRead + AsyncWrite {}

/// Blanket impl: any type that is both readable and writable qualifies.
impl<T> AsyncReadWrite for T where T: AsyncRead + AsyncWrite + ?Sized {}

/// Type-erased, pinned transport that `RedWireClient` performs all frame I/O on.
pub(crate) type Stream = Pin<Box<dyn AsyncReadWrite + Send + Unpin>>;
use crate::error::{ClientError, ErrorCode, Result};
use crate::types::{BulkInsertResult, QueryResult};
use codec::{decode_frame, encode_frame};
use frame::FRAME_HEADER_SIZE;
use handshake::HandshakeOutcome;
/// Marker byte written to the stream right after connecting (after the TLS
/// upgrade, when enabled), immediately followed by [`SUPPORTED_VERSION`].
pub const MAGIC: u8 = 0xfe;
/// RedWire protocol version byte sent right after [`MAGIC`].
pub const SUPPORTED_VERSION: u8 = 1;
/// Authentication method carried in `ConnectOptions` and handed to the
/// handshake together with the rest of the options.
#[derive(Clone)]
pub enum Auth {
    /// No credentials.
    Anonymous,
    /// A bearer token string.
    Bearer(String),
}

/// Hand-written instead of derived so that `{:?}` on an `Auth` (or on
/// `ConnectOptions`, whose `Debug` output embeds it) never writes the bearer
/// token into logs or panic messages.
impl std::fmt::Debug for Auth {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Anonymous => f.write_str("Anonymous"),
            Self::Bearer(_) => f.debug_tuple("Bearer").field(&"<redacted>").finish(),
        }
    }
}
/// One column value for the binary bulk-insert path.
///
/// On the wire each value is a one-byte tag followed by a variant-specific
/// payload (see `BinaryValue::encode`).
#[derive(Clone, Debug)]
pub enum BinaryValue {
    /// Signed 64-bit integer; 8 little-endian bytes on the wire.
    I64(i64),
    /// 64-bit float; 8 little-endian bytes on the wire.
    F64(f64),
    /// UTF-8 text, preceded by its byte length as a little-endian `u32`.
    Text(String),
    /// Boolean, encoded as a single `0` or `1` byte.
    Bool(bool),
    /// Null value; encoded as the tag byte alone.
    Null,
}
impl BinaryValue {
    const TAG_NULL: u8 = 0;
    const TAG_I64: u8 = 1;
    const TAG_F64: u8 = 2;
    const TAG_TEXT: u8 = 3;
    const TAG_BOOL: u8 = 4;

    /// Appends this value to `buf` as a tag byte followed by its payload.
    ///
    /// Layout per variant (integers little-endian):
    /// `I64` → tag + 8 bytes, `F64` → tag + 8 bytes, `Text` → tag + `u32`
    /// byte length + UTF-8 bytes, `Bool` → tag + `0`/`1`, `Null` → tag only.
    pub(crate) fn encode(&self, buf: &mut Vec<u8>) {
        let tag = match self {
            Self::Null => Self::TAG_NULL,
            Self::I64(_) => Self::TAG_I64,
            Self::F64(_) => Self::TAG_F64,
            Self::Text(_) => Self::TAG_TEXT,
            Self::Bool(_) => Self::TAG_BOOL,
        };
        buf.push(tag);
        match self {
            Self::I64(int) => buf.extend_from_slice(&int.to_le_bytes()),
            Self::F64(float) => buf.extend_from_slice(&float.to_le_bytes()),
            Self::Text(text) => {
                let bytes = text.as_bytes();
                // NOTE(review): the `as u32` cast truncates for text over 4 GiB.
                buf.extend_from_slice(&(bytes.len() as u32).to_le_bytes());
                buf.extend_from_slice(bytes);
            }
            Self::Bool(flag) => buf.push(u8::from(*flag)),
            Self::Null => {}
        }
    }
}
/// Parameters for `RedWireClient::connect`.
pub struct ConnectOptions {
    /// Server host name or address.
    pub host: String,
    /// Server TCP port.
    pub port: u16,
    /// Authentication method handed to the handshake.
    pub auth: Auth,
    /// Optional client name; `ConnectOptions::new` fills in `reddb-rs/<version>`.
    pub client_name: Option<String>,
    /// TLS settings; when `Some`, the TCP connection is wrapped in TLS.
    #[cfg(feature = "redwire-tls")]
    pub tls: Option<TlsConfig>,
}
impl std::fmt::Debug for ConnectOptions {
    /// Diagnostic view of the options. When TLS support is compiled in, the TLS
    /// configuration is shown only as a "configured?" boolean.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut view = f.debug_struct("ConnectOptions");
        view.field("host", &self.host);
        view.field("port", &self.port);
        view.field("auth", &self.auth);
        view.field("client_name", &self.client_name);
        #[cfg(feature = "redwire-tls")]
        view.field("tls", &self.tls.is_some());
        view.finish()
    }
}
impl Clone for ConnectOptions {
    /// Copies every field, including the feature-gated TLS configuration when
    /// the `redwire-tls` feature is enabled.
    fn clone(&self) -> Self {
        let host = self.host.clone();
        let auth = self.auth.clone();
        let client_name = self.client_name.clone();
        Self {
            host,
            port: self.port,
            auth,
            client_name,
            #[cfg(feature = "redwire-tls")]
            tls: self.tls.clone(),
        }
    }
}
impl ConnectOptions {
    /// Creates options for `host`/`port` with anonymous auth, the client name
    /// `reddb-rs/<crate version>` and (when compiled in) TLS disabled.
    pub fn new(host: impl Into<String>, port: u16) -> Self {
        let default_name = format!("reddb-rs/{}", env!("CARGO_PKG_VERSION"));
        Self {
            host: host.into(),
            port,
            auth: Auth::Anonymous,
            client_name: Some(default_name),
            #[cfg(feature = "redwire-tls")]
            tls: None,
        }
    }

    /// Builder step: replaces the authentication method.
    pub fn with_auth(self, auth: Auth) -> Self {
        Self { auth, ..self }
    }

    /// Builder step: sets the client name.
    pub fn with_client_name(self, name: impl Into<String>) -> Self {
        Self {
            client_name: Some(name.into()),
            ..self
        }
    }

    /// Builder step: enables TLS with the given configuration.
    #[cfg(feature = "redwire-tls")]
    pub fn with_tls(self, tls: TlsConfig) -> Self {
        Self {
            tls: Some(tls),
            ..self
        }
    }
}
/// An authenticated RedWire connection.
///
/// Every request method writes one frame and then reads exactly one reply
/// frame, so calls take `&mut self` and run strictly one at a time.
pub struct RedWireClient {
    /// Transport carrying frames: plain TCP or, with `redwire-tls`, TLS.
    stream: Stream,
    /// Correlation id for the next request; starts at 1 and wraps on overflow.
    next_correlation_id: u64,
    /// Session identifier reported by the handshake on success.
    session_id: String,
    /// Feature bitmask reported by the handshake (e.g. parameter support).
    server_features: u32,
}
impl std::fmt::Debug for RedWireClient {
    /// Shows only the session metadata; the transport and the correlation
    /// counter are left out.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut view = f.debug_struct("RedWireClient");
        view.field("session_id", &self.session_id);
        view.field("server_features", &self.server_features);
        view.finish()
    }
}
impl RedWireClient {
pub async fn connect(opts: ConnectOptions) -> Result<Self> {
let addr = format!("{}:{}", opts.host, opts.port);
let tcp = TcpStream::connect(&addr)
.await
.map_err(|e| ClientError::new(ErrorCode::Network, format!("{addr}: {e}")))?;
let mut stream: Stream = match () {
#[cfg(feature = "redwire-tls")]
_ if opts.tls.is_some() => {
let tls_cfg = opts.tls.as_ref().unwrap();
let tls_stream = tls::wrap_client(tcp, &opts.host, tls_cfg).await?;
Box::pin(tls_stream)
}
_ => Box::pin(tcp),
};
stream
.write_all(&[MAGIC, SUPPORTED_VERSION])
.await
.map_err(io_err)?;
let outcome = handshake::run(&mut stream, &opts).await?;
match outcome {
HandshakeOutcome::Authenticated {
session_id,
server_features,
} => Ok(Self {
stream,
next_correlation_id: 1,
session_id,
server_features,
}),
HandshakeOutcome::Refused(reason) => Err(ClientError::new(
ErrorCode::AuthRefused,
format!("redwire auth refused: {reason}"),
)),
}
}
pub async fn query(&mut self, sql: &str) -> Result<QueryResult> {
let corr = self.next_corr();
let req = Frame::new(MessageKind::Query, corr, sql.as_bytes().to_vec());
self.stream
.write_all(&encode_frame(&req))
.await
.map_err(io_err)?;
let resp = self.read_frame().await?;
match resp.kind {
MessageKind::Result => {
let value: serde_json::Value =
serde_json::from_slice(&resp.payload).map_err(|e| {
ClientError::new(ErrorCode::Protocol, format!("decode result: {e}"))
})?;
Ok(QueryResult::from_envelope(value))
}
MessageKind::Error => {
let msg = String::from_utf8_lossy(&resp.payload).to_string();
Err(ClientError::new(ErrorCode::Engine, msg))
}
other => Err(ClientError::new(
ErrorCode::Protocol,
format!("expected Result/Error, got {other:?}"),
)),
}
}
pub async fn query_with(
&mut self,
sql: &str,
params: &[crate::params::Value],
) -> Result<QueryResult> {
if params.is_empty() {
return self.query(sql).await;
}
if self.server_features & FEATURE_PARAMS == 0 {
return Err(ClientError::new(
ErrorCode::ParamsUnsupported,
"server did not advertise RedWire parameter support",
));
}
let wire_params = params.iter().map(param_to_redwire).collect::<Vec<_>>();
let payload = encode_query_with_params(sql, &wire_params)
.map_err(|e| ClientError::new(ErrorCode::Protocol, format!("encode params: {e}")))?;
let corr = self.next_corr();
let req = Frame::new(MessageKind::QueryWithParams, corr, payload);
self.stream
.write_all(&encode_frame(&req))
.await
.map_err(io_err)?;
let resp = self.read_frame().await?;
match resp.kind {
MessageKind::Result => {
let value: serde_json::Value =
serde_json::from_slice(&resp.payload).map_err(|e| {
ClientError::new(ErrorCode::Protocol, format!("decode result: {e}"))
})?;
Ok(QueryResult::from_envelope(value))
}
MessageKind::Error => {
let msg = String::from_utf8_lossy(&resp.payload).to_string();
Err(ClientError::new(ErrorCode::Engine, msg))
}
other => Err(ClientError::new(
ErrorCode::Protocol,
format!("expected Result/Error, got {other:?}"),
)),
}
}
pub async fn insert(&mut self, collection: &str, payload: serde_json::Value) -> Result<u64> {
let mut obj = serde_json::Map::new();
obj.insert(
"collection".into(),
serde_json::Value::String(collection.to_string()),
);
obj.insert("payload".into(), payload);
self.send_insert_frame(serde_json::Value::Object(obj))
.await
.map(|result| result.affected)
}
pub async fn bulk_insert(
&mut self,
collection: &str,
payloads: Vec<serde_json::Value>,
) -> Result<BulkInsertResult> {
let mut obj = serde_json::Map::new();
obj.insert(
"collection".into(),
serde_json::Value::String(collection.to_string()),
);
obj.insert("payloads".into(), serde_json::Value::Array(payloads));
self.send_insert_frame(serde_json::Value::Object(obj)).await
}
async fn send_insert_frame(&mut self, body: serde_json::Value) -> Result<BulkInsertResult> {
let bytes = serde_json::to_vec(&body)
.map_err(|e| ClientError::new(ErrorCode::Protocol, format!("encode insert: {e}")))?;
let corr = self.next_corr();
let req = Frame::new(MessageKind::BulkInsert, corr, bytes);
self.stream
.write_all(&encode_frame(&req))
.await
.map_err(io_err)?;
let resp = self.read_frame().await?;
match resp.kind {
MessageKind::BulkOk => {
let v: serde_json::Value = serde_json::from_slice(&resp.payload).map_err(|e| {
ClientError::new(ErrorCode::Protocol, format!("decode bulk_ok: {e}"))
})?;
Ok(bulk_insert_result_from_json(v))
}
MessageKind::Error => {
let msg = String::from_utf8_lossy(&resp.payload).to_string();
Err(ClientError::new(ErrorCode::Engine, msg))
}
other => Err(ClientError::new(
ErrorCode::Protocol,
format!("expected BulkOk/Error, got {other:?}"),
)),
}
}
pub async fn get(&mut self, collection: &str, id: &str) -> Result<serde_json::Value> {
let mut obj = serde_json::Map::new();
obj.insert(
"collection".into(),
serde_json::Value::String(collection.to_string()),
);
obj.insert("id".into(), serde_json::Value::String(id.to_string()));
let bytes = serde_json::to_vec(&serde_json::Value::Object(obj))
.map_err(|e| ClientError::new(ErrorCode::Protocol, format!("encode get: {e}")))?;
let corr = self.next_corr();
let req = Frame::new(MessageKind::Get, corr, bytes);
self.stream
.write_all(&encode_frame(&req))
.await
.map_err(io_err)?;
let resp = self.read_frame().await?;
match resp.kind {
MessageKind::Result => serde_json::from_slice(&resp.payload)
.map_err(|e| ClientError::new(ErrorCode::Protocol, format!("decode get: {e}"))),
MessageKind::Error => Err(ClientError::new(
ErrorCode::Engine,
String::from_utf8_lossy(&resp.payload).to_string(),
)),
other => Err(ClientError::new(
ErrorCode::Protocol,
format!("expected Result/Error, got {other:?}"),
)),
}
}
pub async fn delete(&mut self, collection: &str, id: &str) -> Result<u64> {
let mut obj = serde_json::Map::new();
obj.insert(
"collection".into(),
serde_json::Value::String(collection.to_string()),
);
obj.insert("id".into(), serde_json::Value::String(id.to_string()));
let bytes = serde_json::to_vec(&serde_json::Value::Object(obj))
.map_err(|e| ClientError::new(ErrorCode::Protocol, format!("encode delete: {e}")))?;
let corr = self.next_corr();
let req = Frame::new(MessageKind::Delete, corr, bytes);
self.stream
.write_all(&encode_frame(&req))
.await
.map_err(io_err)?;
let resp = self.read_frame().await?;
match resp.kind {
MessageKind::DeleteOk => {
let v: serde_json::Value = serde_json::from_slice(&resp.payload).map_err(|e| {
ClientError::new(ErrorCode::Protocol, format!("decode delete_ok: {e}"))
})?;
Ok(v.as_object()
.and_then(|o| o.get("affected"))
.and_then(|x| x.as_u64())
.unwrap_or(0))
}
MessageKind::Error => Err(ClientError::new(
ErrorCode::Engine,
String::from_utf8_lossy(&resp.payload).to_string(),
)),
other => Err(ClientError::new(
ErrorCode::Protocol,
format!("expected DeleteOk/Error, got {other:?}"),
)),
}
}
pub async fn bulk_insert_binary(
&mut self,
collection: &str,
columns: &[&str],
rows: &[Vec<BinaryValue>],
) -> Result<u64> {
let mut payload = Vec::with_capacity(64 + rows.len() * columns.len() * 16);
payload.extend_from_slice(&(collection.len() as u16).to_le_bytes());
payload.extend_from_slice(collection.as_bytes());
payload.extend_from_slice(&(columns.len() as u16).to_le_bytes());
for c in columns {
payload.extend_from_slice(&(c.len() as u16).to_le_bytes());
payload.extend_from_slice(c.as_bytes());
}
payload.extend_from_slice(&(rows.len() as u32).to_le_bytes());
for row in rows {
if row.len() != columns.len() {
return Err(ClientError::new(
ErrorCode::Protocol,
format!("row had {} values for {} columns", row.len(), columns.len()),
));
}
for v in row {
v.encode(&mut payload);
}
}
let corr = self.next_corr();
let req = Frame::new(MessageKind::BulkInsertBinary, corr, payload);
self.stream
.write_all(&encode_frame(&req))
.await
.map_err(io_err)?;
let resp = self.read_frame().await?;
match resp.kind {
MessageKind::BulkOk => {
if resp.payload.len() < 8 {
return Err(ClientError::new(
ErrorCode::Protocol,
"BulkOk truncated: expected 8-byte count",
));
}
Ok(u64::from_le_bytes([
resp.payload[0],
resp.payload[1],
resp.payload[2],
resp.payload[3],
resp.payload[4],
resp.payload[5],
resp.payload[6],
resp.payload[7],
]))
}
MessageKind::Error => Err(ClientError::new(
ErrorCode::Engine,
String::from_utf8_lossy(&resp.payload).to_string(),
)),
other => Err(ClientError::new(
ErrorCode::Protocol,
format!("expected BulkOk/Error, got {other:?}"),
)),
}
}
pub async fn ping(&mut self) -> Result<()> {
let corr = self.next_corr();
let req = Frame::new(MessageKind::Ping, corr, vec![]);
self.stream
.write_all(&encode_frame(&req))
.await
.map_err(io_err)?;
let resp = self.read_frame().await?;
match resp.kind {
MessageKind::Pong => Ok(()),
other => Err(ClientError::new(
ErrorCode::Protocol,
format!("expected Pong, got {other:?}"),
)),
}
}
pub async fn close(mut self) -> Result<()> {
let corr = self.next_corr();
let bye = Frame::new(MessageKind::Bye, corr, vec![]);
let _ = self.stream.write_all(&encode_frame(&bye)).await;
Ok(())
}
fn next_corr(&mut self) -> u64 {
let c = self.next_correlation_id;
self.next_correlation_id = self.next_correlation_id.wrapping_add(1);
c
}
async fn read_frame(&mut self) -> Result<Frame> {
let mut header = [0u8; FRAME_HEADER_SIZE];
self.stream.read_exact(&mut header).await.map_err(io_err)?;
let length = u32::from_le_bytes([header[0], header[1], header[2], header[3]]) as usize;
if length < FRAME_HEADER_SIZE {
return Err(ClientError::new(
ErrorCode::Protocol,
format!("server sent a frame with length {length}"),
));
}
let mut buf = vec![0u8; length];
buf[..FRAME_HEADER_SIZE].copy_from_slice(&header);
if length > FRAME_HEADER_SIZE {
self.stream
.read_exact(&mut buf[FRAME_HEADER_SIZE..length])
.await
.map_err(io_err)?;
}
let (frame, _) = decode_frame(&buf)
.map_err(|e| ClientError::new(ErrorCode::Protocol, format!("decode frame: {e}")))?;
Ok(frame)
}
}
fn param_to_redwire(value: &crate::params::Value) -> RedWireParamValue {
match value {
crate::params::Value::Null => RedWireParamValue::Null,
crate::params::Value::Bool(value) => RedWireParamValue::Bool(*value),
crate::params::Value::Int(value) => RedWireParamValue::Int(*value),
crate::params::Value::Float(value) => RedWireParamValue::Float(*value),
crate::params::Value::Text(value) => RedWireParamValue::Text(value.clone()),
crate::params::Value::Bytes(value) => RedWireParamValue::Bytes(value.clone()),
crate::params::Value::Vector(value) => RedWireParamValue::Vector(value.clone()),
crate::params::Value::Json(value) => {
RedWireParamValue::Json(value.to_json_string().into_bytes())
}
crate::params::Value::Timestamp(value) => RedWireParamValue::Timestamp(*value),
crate::params::Value::Uuid(value) => RedWireParamValue::Uuid(*value),
}
}
/// Converts a transport-level I/O failure into an `ErrorCode::Network` client
/// error whose message is the I/O error's display text.
fn io_err(err: io::Error) -> ClientError {
    let message = err.to_string();
    ClientError::new(ErrorCode::Network, message)
}
#[cfg(test)]
mod tests {
// Unit tests for `param_to_redwire`, which maps client-side query parameters
// onto RedWire wire parameters.
use super::param_to_redwire;
use crate::{JsonValue, Value};
use reddb_wire::query_with_params::ParamValue as WireValue;
// Every client-side variant must map to the matching wire variant with its
// payload unchanged; JSON is compared in its serialized byte form.
#[test]
fn param_to_redwire_preserves_all_wire_variants() {
let uuid = [0x11; 16];
// (input, expected wire value) pairs, one per variant.
let cases = vec![
(Value::Null, WireValue::Null),
(Value::Bool(true), WireValue::Bool(true)),
// NOTE(review): this case uses `Value::Int64`, but `param_to_redwire` matches
// `crate::params::Value::Int`. Assuming `crate::Value` re-exports
// `crate::params::Value` (needed for `param_to_redwire(&input)` to type-check),
// the exhaustive match there allows only one of these names — one is likely
// stale; confirm and align them.
(Value::Int64(42), WireValue::Int(42)),
(Value::Float(1.5), WireValue::Float(1.5)),
(Value::Text("Ada".into()), WireValue::Text("Ada".into())),
(
Value::Bytes(vec![0xde, 0xad]),
WireValue::Bytes(vec![0xde, 0xad]),
),
(
Value::Vector(vec![0.25, 0.5]),
WireValue::Vector(vec![0.25, 0.5]),
),
// JSON is expected in compact form with no whitespace.
(
Value::Json(JsonValue::object([("role", JsonValue::string("admin"))])),
WireValue::Json(br#"{"role":"admin"}"#.to_vec()),
),
(
Value::Timestamp(1_700_000_000),
WireValue::Timestamp(1_700_000_000),
),
(Value::Uuid(uuid), WireValue::Uuid(uuid)),
];
for (input, expected) in cases {
assert_eq!(param_to_redwire(&input), expected);
}
}
}
/// Builds a [`BulkInsertResult`] from a `BulkOk` JSON reply.
///
/// * `affected`: the `"affected"` field, or `0` when absent or not an unsigned integer.
/// * `rids`: taken from `"rids"`. `"ids"` is used only when `"rids"` is missing
///   entirely; a present but non-array value yields an empty list.
/// * `ids`: taken from the `"ids"` array, falling back to a copy of `rids`.
///
/// Array elements that are neither strings nor unsigned integers are skipped.
fn bulk_insert_result_from_json(value: serde_json::Value) -> BulkInsertResult {
    let fields = value.as_object();
    let collect_ids = |list: &serde_json::Value| -> Option<Vec<String>> {
        list.as_array()
            .map(|items| items.iter().filter_map(json_id_to_string).collect())
    };

    let affected = fields
        .and_then(|o| o.get("affected"))
        .and_then(serde_json::Value::as_u64)
        .unwrap_or(0);

    let rids: Vec<String> = fields
        .and_then(|o| o.get("rids").or_else(|| o.get("ids")))
        .and_then(collect_ids)
        .unwrap_or_default();

    let ids = match fields.and_then(|o| o.get("ids")).and_then(collect_ids) {
        Some(list) => list,
        None => rids.clone(),
    };

    BulkInsertResult {
        affected,
        rids,
        ids,
    }
}
/// Renders a JSON id as text. Strings are copied, unsigned integers are
/// formatted in decimal, and any other JSON value yields `None`.
fn json_id_to_string(value: &serde_json::Value) -> Option<String> {
    match value.as_str() {
        Some(text) => Some(text.to_owned()),
        None => value.as_u64().map(|n| n.to_string()),
    }
}