use std::{collections::HashMap, sync::Arc};
use futures_util::{
SinkExt, StreamExt,
stream::{SplitSink, SplitStream},
};
use reifydb_type::{
error::{Diagnostic, Error},
params::Params,
value::frame::frame::Frame,
};
use reifydb_wire_format::decode::decode_frames;
use serde_json::{Value, from_str, to_string};
use tokio::{
net::TcpStream,
select, spawn,
sync::{Mutex, mpsc, oneshot},
};
use tokio_tungstenite::{MaybeTlsStream, WebSocketStream, connect_async_with_config, tungstenite::Message};
use crate::{
AdminRequest, AdminResult, AuthRequest, BatchChangeEntry, BatchChangePayload, BatchClosedPayload,
BatchMemberClosedPayload, BatchMemberInfo, BatchSubscribeRequest, BatchUnsubscribeRequest, CallRequest,
ChangePayload, CommandRequest, CommandResult, LoginResult, QueryRequest, QueryResult, Request, RequestPayload,
Response, ResponseMeta, ResponsePayload, ServerPush, SubscribeRequest, UnsubscribeRequest, WireFormat,
params_to_wire,
session::{parse_admin_response, parse_call_response, parse_command_response, parse_query_response},
utils::generate_request_id,
};
enum ClientResponse {
Json(Box<Response>),
Frames(Vec<Frame>, Option<ResponseMeta>),
}
type PendingRequests = Arc<Mutex<HashMap<String, oneshot::Sender<ClientResponse>>>>;
#[derive(Debug, Clone)]
pub enum BatchPushEvent {
Change(BatchChangePayload),
MemberClosed(BatchMemberClosedPayload),
Closed(BatchClosedPayload),
}
type BatchRouters = Arc<Mutex<HashMap<String, mpsc::Sender<BatchPushEvent>>>>;
pub struct WsClient {
request_tx: mpsc::Sender<(Request, oneshot::Sender<ClientResponse>)>,
shutdown_tx: Option<mpsc::Sender<()>>,
is_authenticated: bool,
change_rx: mpsc::Receiver<ChangePayload>,
batch_routers: BatchRouters,
format: WireFormat,
}
impl WsClient {
pub async fn connect(url: &str, format: WireFormat) -> Result<Self, Error> {
if format == WireFormat::Proto {
return Err(Error(Box::new(Diagnostic {
code: "INVALID_FORMAT".to_string(),
message: "WireFormat::Proto is not supported for WsClient".to_string(),
..Default::default()
})));
}
let url = if !url.starts_with("ws://") && !url.starts_with("wss://") {
format!("ws://{}", url)
} else {
url.to_string()
};
let (ws_stream, _) = connect_async_with_config(&url, None, true).await.unwrap();
let (write, read) = ws_stream.split();
let (request_tx, request_rx) = mpsc::channel::<(Request, oneshot::Sender<ClientResponse>)>(32);
let (shutdown_tx, shutdown_rx) = mpsc::channel::<()>(1);
let (change_tx, change_rx) = mpsc::channel::<ChangePayload>(100);
let pending: PendingRequests = Arc::new(Mutex::new(HashMap::new()));
let batch_routers: BatchRouters = Arc::new(Mutex::new(HashMap::new()));
let pending_clone = pending.clone();
let batch_routers_clone = batch_routers.clone();
spawn(async move {
Self::connection_loop(
write,
read,
request_rx,
shutdown_rx,
pending_clone,
change_tx,
batch_routers_clone,
)
.await;
});
Ok(Self {
request_tx,
shutdown_tx: Some(shutdown_tx),
is_authenticated: false,
change_rx,
batch_routers,
format,
})
}
async fn connection_loop(
mut write: SplitSink<WebSocketStream<MaybeTlsStream<TcpStream>>, Message>,
mut read: SplitStream<WebSocketStream<MaybeTlsStream<TcpStream>>>,
mut request_rx: mpsc::Receiver<(Request, oneshot::Sender<ClientResponse>)>,
mut shutdown_rx: mpsc::Receiver<()>,
pending: PendingRequests,
change_tx: mpsc::Sender<ChangePayload>,
batch_routers: BatchRouters,
) {
loop {
select! {
Some(msg) = read.next() => {
match msg {
Ok(Message::Text(text)) => {
if let Ok(response) = from_str::<Response>(&text) {
let mut pending_guard = pending.lock().await;
if let Some(tx) = pending_guard.remove(&response.id) {
let _ = tx.send(ClientResponse::Json(Box::new(response)));
}
}
else if let Ok(push) = from_str::<ServerPush>(&text) {
match push {
ServerPush::Change(change) => {
let _ = change_tx.send(change).await;
}
ServerPush::BatchChange(batch) => {
let sender = {
let routers = batch_routers.lock().await;
routers.get(&batch.batch_id).cloned()
};
if let Some(tx) = sender {
let _ = tx.send(BatchPushEvent::Change(batch)).await;
}
}
ServerPush::BatchMemberClosed(m) => {
let sender = {
let routers = batch_routers.lock().await;
routers.get(&m.batch_id).cloned()
};
if let Some(tx) = sender {
let _ = tx.send(BatchPushEvent::MemberClosed(m)).await;
}
}
ServerPush::BatchClosed(c) => {
let batch_id = c.batch_id.clone();
let sender = {
let mut routers = batch_routers.lock().await;
routers.remove(&batch_id)
};
if let Some(tx) = sender {
let _ = tx.send(BatchPushEvent::Closed(c)).await;
}
}
}
}
}
Ok(Message::Binary(data)) => {
if data.is_empty() { continue; }
let kind = data[0];
if kind == 0x02 {
if let Some(payload) = parse_rbcf_batch_envelope(&data) {
let batch_id = payload.batch_id.clone();
let sender = {
let routers = batch_routers.lock().await;
routers.get(&batch_id).cloned()
};
if let Some(tx) = sender {
let _ = tx.send(BatchPushEvent::Change(payload)).await;
}
}
continue;
}
if data.len() < 5 { continue; }
let id_len = u32::from_le_bytes([data[1], data[2], data[3], data[4]]) as usize;
let meta_len_pos = 5 + id_len;
if data.len() < meta_len_pos + 4 { continue; }
let id = String::from_utf8_lossy(&data[5..meta_len_pos]).to_string();
let meta_len = u32::from_le_bytes([
data[meta_len_pos],
data[meta_len_pos + 1],
data[meta_len_pos + 2],
data[meta_len_pos + 3],
]) as usize;
let meta_start = meta_len_pos + 4;
if data.len() < meta_start + meta_len { continue; }
let meta = if meta_len > 0 {
from_str::<ResponseMeta>(
&String::from_utf8_lossy(&data[meta_start..meta_start + meta_len])
).ok()
} else {
None
};
let rbcf_data = &data[meta_start + meta_len..];
let frames = match decode_frames(rbcf_data) {
Ok(f) => f,
Err(_) => continue,
};
match kind {
0x00 => {
let mut pending_guard = pending.lock().await;
if let Some(tx) = pending_guard.remove(&id) {
let _ = tx.send(ClientResponse::Frames(frames, meta));
}
}
0x01 => {
let _ = change_tx.send(ChangePayload {
subscription_id: id,
content_type: "application/vnd.reifydb.rbcf".to_string(),
body: Value::Null,
frames: Some(frames),
}).await;
}
_ => {}
}
}
Ok(Message::Ping(data)) => {
let _ = write.send(Message::Pong(data)).await;
}
Ok(Message::Close(_)) => {
break;
}
Err(_) => {
break;
}
_ => {}
}
}
Some((request, response_tx)) = request_rx.recv() => {
let id = request.id.clone();
{
let mut pending_guard = pending.lock().await;
pending_guard.insert(id, response_tx);
}
if let Ok(json) = to_string(&request)
&& write.send(Message::Text(json.into())).await.is_err() {
break;
}
}
_ = shutdown_rx.recv() => {
let _ = write.send(Message::Close(None)).await;
break;
}
}
}
let mut pending_guard = pending.lock().await;
pending_guard.clear();
}
fn wire_format(&self) -> Option<String> {
match self.format {
WireFormat::Rbcf => Some("rbcf".to_string()),
WireFormat::Json => Some("frames".to_string()),
WireFormat::Proto => None,
}
}
pub async fn authenticate(&mut self, token: &str) -> Result<(), Error> {
let id = generate_request_id();
let request = Request {
id,
payload: RequestPayload::Auth(AuthRequest {
token: Some(token.to_string()),
method: None,
credentials: None,
}),
};
let response = self.send_request_json(request).await?;
match response.payload {
ResponsePayload::Auth(_) => {
self.is_authenticated = true;
Ok(())
}
ResponsePayload::Err(err) => Err(Error(Box::new(err.diagnostic))),
_ => panic!("Unexpected response type for auth"), }
}
pub async fn login_with_password(&mut self, identifier: &str, password: &str) -> Result<LoginResult, Error> {
let mut credentials = HashMap::new();
credentials.insert("identifier".to_string(), identifier.to_string());
credentials.insert("password".to_string(), password.to_string());
self.login("password", credentials).await
}
pub async fn login_with_token(&mut self, token: &str) -> Result<LoginResult, Error> {
let mut credentials = HashMap::new();
credentials.insert("token".to_string(), token.to_string());
self.login("token", credentials).await
}
pub async fn login(
&mut self,
method: &str,
credentials: HashMap<String, String>,
) -> Result<LoginResult, Error> {
let id = generate_request_id();
let request = Request {
id,
payload: RequestPayload::Auth(AuthRequest {
token: None,
method: Some(method.to_string()),
credentials: Some(credentials),
}),
};
let response = self.send_request_json(request).await?;
match response.payload {
ResponsePayload::Auth(auth) => {
if auth.status.as_deref() == Some("authenticated") {
let token = auth.token.unwrap_or_default();
let identity = auth.identity.unwrap_or_default();
self.is_authenticated = true;
Ok(LoginResult {
token,
identity,
})
} else {
panic!("Authentication failed") }
}
ResponsePayload::Err(err) => Err(Error(Box::new(err.diagnostic))),
_ => panic!("Unexpected response type for login"), }
}
pub async fn logout(&mut self) -> Result<(), Error> {
if !self.is_authenticated {
return Ok(());
}
let id = generate_request_id();
let request = Request {
id,
payload: RequestPayload::Logout,
};
let response = self.send_request_json(request).await?;
match response.payload {
ResponsePayload::Logout(_) => {
self.is_authenticated = false;
Ok(())
}
ResponsePayload::Err(err) => Err(Error(Box::new(err.diagnostic))),
_ => panic!("Unexpected response type for logout"), }
}
pub async fn admin(&self, rql: &str, params: Option<Params>) -> Result<Vec<Frame>, Error> {
Ok(self.admin_with_meta(rql, params).await?.frames)
}
pub async fn admin_with_meta(&self, rql: &str, params: Option<Params>) -> Result<AdminResult, Error> {
let id = generate_request_id();
let request = Request {
id,
payload: RequestPayload::Admin(AdminRequest {
rql: rql.to_string(),
params: params.and_then(params_to_wire),
format: self.wire_format(),
}),
};
match self.send_request(request).await? {
ClientResponse::Frames(frames, meta) => Ok(AdminResult {
frames,
meta,
}),
ClientResponse::Json(resp) => parse_admin_response(*resp),
}
}
pub async fn command(&self, rql: &str, params: Option<Params>) -> Result<Vec<Frame>, Error> {
Ok(self.command_with_meta(rql, params).await?.frames)
}
pub async fn command_with_meta(&self, rql: &str, params: Option<Params>) -> Result<CommandResult, Error> {
let id = generate_request_id();
let request = Request {
id,
payload: RequestPayload::Command(CommandRequest {
rql: rql.to_string(),
params: params.and_then(params_to_wire),
format: self.wire_format(),
}),
};
match self.send_request(request).await? {
ClientResponse::Frames(frames, meta) => Ok(CommandResult {
frames,
meta,
}),
ClientResponse::Json(resp) => parse_command_response(*resp),
}
}
pub async fn query(&self, rql: &str, params: Option<Params>) -> Result<Vec<Frame>, Error> {
Ok(self.query_with_meta(rql, params).await?.frames)
}
pub async fn query_with_meta(&self, rql: &str, params: Option<Params>) -> Result<QueryResult, Error> {
let id = generate_request_id();
let request = Request {
id,
payload: RequestPayload::Query(QueryRequest {
rql: rql.to_string(),
params: params.and_then(params_to_wire),
format: self.wire_format(),
}),
};
match self.send_request(request).await? {
ClientResponse::Frames(frames, meta) => Ok(QueryResult {
frames,
meta,
}),
ClientResponse::Json(resp) => parse_query_response(*resp),
}
}
pub async fn call(&self, name: &str, params: Option<Params>) -> Result<Vec<Frame>, Error> {
Ok(self.call_with_meta(name, params).await?.frames)
}
pub async fn call_with_meta(&self, name: &str, params: Option<Params>) -> Result<CommandResult, Error> {
let id = generate_request_id();
let request = Request {
id,
payload: RequestPayload::Call(CallRequest {
name: name.to_string(),
params: params.and_then(params_to_wire),
}),
};
match self.send_request(request).await? {
ClientResponse::Frames(frames, meta) => Ok(CommandResult {
frames,
meta,
}),
ClientResponse::Json(resp) => parse_call_response(*resp),
}
}
pub async fn subscribe(&self, rql: &str) -> Result<String, Error> {
let id = generate_request_id();
let request = Request {
id,
payload: RequestPayload::Subscribe(SubscribeRequest {
rql: rql.to_string(),
format: self.wire_format(),
}),
};
let response = self.send_request_json(request).await?;
match response.payload {
ResponsePayload::Subscribed(sub) => Ok(sub.subscription_id),
ResponsePayload::Err(err) => Err(Error(Box::new(err.diagnostic))),
_ => panic!("Unexpected response type for subscribe"), }
}
pub async fn unsubscribe(&self, subscription_id: &str) -> Result<(), Error> {
let id = generate_request_id();
let request = Request {
id,
payload: RequestPayload::Unsubscribe(UnsubscribeRequest {
subscription_id: subscription_id.to_string(),
}),
};
let response = self.send_request_json(request).await?;
match response.payload {
ResponsePayload::Unsubscribed(_) => Ok(()),
ResponsePayload::Err(err) => Err(Error(Box::new(err.diagnostic))),
_ => panic!("Unexpected response type for unsubscribe"), }
}
pub async fn batch_subscribe(&self, queries: &[&str]) -> Result<WsBatchSubscription, Error> {
let id = generate_request_id();
let request = Request {
id,
payload: RequestPayload::BatchSubscribe(BatchSubscribeRequest {
queries: queries.iter().map(|q| q.to_string()).collect(),
format: self.wire_format(),
}),
};
let response = self.send_request_json(request).await?;
match response.payload {
ResponsePayload::BatchSubscribed(ack) => {
let (push_tx, push_rx) = mpsc::channel::<BatchPushEvent>(100);
{
let mut routers = self.batch_routers.lock().await;
routers.insert(ack.batch_id.clone(), push_tx);
}
Ok(WsBatchSubscription {
batch_id: ack.batch_id,
members: ack.members,
push_rx,
})
}
ResponsePayload::Err(err) => Err(Error(Box::new(err.diagnostic))),
_ => Err(Error(Box::new(Diagnostic {
code: "UNEXPECTED_RESPONSE".to_string(),
message: "Unexpected response type for BatchSubscribe".to_string(),
..Default::default()
}))),
}
}
pub async fn batch_unsubscribe(&self, batch_id: &str) -> Result<(), Error> {
{
let mut routers = self.batch_routers.lock().await;
routers.remove(batch_id);
}
let id = generate_request_id();
let request = Request {
id,
payload: RequestPayload::BatchUnsubscribe(BatchUnsubscribeRequest {
batch_id: batch_id.to_string(),
}),
};
let response = self.send_request_json(request).await?;
match response.payload {
ResponsePayload::BatchUnsubscribed(_) => Ok(()),
ResponsePayload::Err(err) => Err(Error(Box::new(err.diagnostic))),
_ => Err(Error(Box::new(Diagnostic {
code: "UNEXPECTED_RESPONSE".to_string(),
message: "Unexpected response type for BatchUnsubscribe".to_string(),
..Default::default()
}))),
}
}
pub async fn recv(&mut self) -> Option<ChangePayload> {
self.change_rx.recv().await
}
pub fn try_recv(&mut self) -> Result<ChangePayload, mpsc::error::TryRecvError> {
self.change_rx.try_recv()
}
async fn send_request(&self, request: Request) -> Result<ClientResponse, Error> {
let (tx, rx) = oneshot::channel();
self.request_tx.send((request, tx)).await.unwrap();
Ok(rx.await.unwrap()) }
async fn send_request_json(&self, request: Request) -> Result<Response, Error> {
match self.send_request(request).await? {
ClientResponse::Json(resp) => Ok(*resp),
ClientResponse::Frames(_, _) => panic!("unexpected binary response"),
}
}
pub async fn close(mut self) -> Result<(), Error> {
if let Some(tx) = self.shutdown_tx.take() {
let _ = tx.send(()).await;
}
Ok(())
}
pub fn is_authenticated(&self) -> bool {
self.is_authenticated
}
}
impl Drop for WsClient {
fn drop(&mut self) {
if let Some(tx) = self.shutdown_tx.take() {
let _ = tx.try_send(());
}
}
}
pub struct WsBatchSubscription {
batch_id: String,
members: Vec<BatchMemberInfo>,
push_rx: mpsc::Receiver<BatchPushEvent>,
}
impl WsBatchSubscription {
pub fn batch_id(&self) -> &str {
&self.batch_id
}
pub fn members(&self) -> &[BatchMemberInfo] {
&self.members
}
pub async fn recv(&mut self) -> Option<BatchPushEvent> {
self.push_rx.recv().await
}
}
fn parse_rbcf_batch_envelope(data: &[u8]) -> Option<BatchChangePayload> {
if data.len() < 5 || data[0] != 0x02 {
return None;
}
let batch_id_len = u32::from_le_bytes(data[1..5].try_into().ok()?) as usize;
let batch_id_end = 5 + batch_id_len;
if data.len() < batch_id_end + 4 {
return None;
}
let batch_id = String::from_utf8_lossy(&data[5..batch_id_end]).into_owned();
let num_entries = u32::from_le_bytes(data[batch_id_end..batch_id_end + 4].try_into().ok()?) as usize;
let mut pos = batch_id_end + 4;
let mut entries = Vec::with_capacity(num_entries);
for _ in 0..num_entries {
if data.len() < pos + 4 {
return None;
}
let sub_id_len = u32::from_le_bytes(data[pos..pos + 4].try_into().ok()?) as usize;
pos += 4;
if data.len() < pos + sub_id_len + 4 {
return None;
}
let sub_id = String::from_utf8_lossy(&data[pos..pos + sub_id_len]).into_owned();
pos += sub_id_len;
let rbcf_len = u32::from_le_bytes(data[pos..pos + 4].try_into().ok()?) as usize;
pos += 4;
if data.len() < pos + rbcf_len {
return None;
}
let rbcf_bytes = &data[pos..pos + rbcf_len];
pos += rbcf_len;
let (frames, decode_error) = match decode_frames(rbcf_bytes) {
Ok(f) => (Some(f), None),
Err(e) => (None, Some(e.to_string())),
};
entries.push(BatchChangeEntry {
subscription_id: sub_id,
content_type: "application/vnd.reifydb.rbcf".to_string(),
body: Value::Null,
frames,
decode_error,
});
}
Some(BatchChangePayload {
batch_id,
entries,
})
}