use std::{collections::HashMap, sync::Arc};
use futures_util::{
SinkExt, StreamExt,
stream::{SplitSink, SplitStream},
};
use reifydb_type::{
error::{Diagnostic, Error},
params::Params,
value::frame::frame::Frame,
};
use reifydb_wire_format::decode::decode_frames;
use serde_json::{Value, from_str, to_string};
use tokio::{
net::TcpStream,
select, spawn,
sync::{Mutex, mpsc, oneshot},
};
use tokio_tungstenite::{MaybeTlsStream, WebSocketStream, connect_async_with_config, tungstenite::Message};
use crate::{
AdminRequest, AdminResult, AuthRequest, ChangePayload, CommandRequest, CommandResult, LoginResult,
QueryRequest, QueryResult, Request, RequestPayload, Response, ResponseMeta, ResponsePayload, ServerPush,
SubscribeRequest, UnsubscribeRequest, WireFormat, params_to_wire,
session::{parse_admin_response, parse_command_response, parse_query_response},
utils::generate_request_id,
};
/// In-flight requests keyed by request id, each holding the one-shot slot
/// that the connection task fills when the matching reply arrives.
type PendingRequests = Arc<Mutex<HashMap<String, oneshot::Sender<ClientResponse>>>>;

/// A server reply as handed back to the caller waiting on a request.
enum ClientResponse {
    /// Frames decoded from a binary RBCF envelope, with optional response metadata.
    Frames(Vec<Frame>, Option<ResponseMeta>),
    /// A JSON text reply (boxed).
    Json(Box<Response>),
}
/// WebSocket client for a ReifyDB server.
///
/// All socket I/O happens on a background task spawned by `connect`; this handle
/// only talks to that task through channels.
// NOTE: fields are dropped in declaration order after `Drop::drop` runs.
pub struct WsClient {
// Hands each request, together with the one-shot slot for its reply, to the connection task.
request_tx: mpsc::Sender<(Request, oneshot::Sender<ClientResponse>)>,
// Asks the connection task to send a Close frame and exit; taken by `close` or `Drop`.
shutdown_tx: Option<mpsc::Sender<()>>,
// Set by a successful `authenticate`/`login`, cleared by `logout`.
is_authenticated: bool,
// Change notifications for active subscriptions, fed by the connection task.
change_rx: mpsc::Receiver<ChangePayload>,
// Result encoding requested from the server; `connect` rejects `WireFormat::Proto`.
format: WireFormat,
}
impl WsClient {
pub async fn connect(url: &str, format: WireFormat) -> Result<Self, Error> {
if format == WireFormat::Proto {
return Err(Error(Box::new(Diagnostic {
code: "INVALID_FORMAT".to_string(),
message: "WireFormat::Proto is not supported for WsClient".to_string(),
..Default::default()
})));
}
let url = if !url.starts_with("ws://") && !url.starts_with("wss://") {
format!("ws://{}", url)
} else {
url.to_string()
};
let (ws_stream, _) = connect_async_with_config(&url, None, true).await.unwrap();
let (write, read) = ws_stream.split();
let (request_tx, request_rx) = mpsc::channel::<(Request, oneshot::Sender<ClientResponse>)>(32);
let (shutdown_tx, shutdown_rx) = mpsc::channel::<()>(1);
let (change_tx, change_rx) = mpsc::channel::<ChangePayload>(100);
let pending: PendingRequests = Arc::new(Mutex::new(HashMap::new()));
let pending_clone = pending.clone();
spawn(async move {
Self::connection_loop(write, read, request_rx, shutdown_rx, pending_clone, change_tx).await;
});
Ok(Self {
request_tx,
shutdown_tx: Some(shutdown_tx),
is_authenticated: false,
change_rx,
format,
})
}
async fn connection_loop(
mut write: SplitSink<WebSocketStream<MaybeTlsStream<TcpStream>>, Message>,
mut read: SplitStream<WebSocketStream<MaybeTlsStream<TcpStream>>>,
mut request_rx: mpsc::Receiver<(Request, oneshot::Sender<ClientResponse>)>,
mut shutdown_rx: mpsc::Receiver<()>,
pending: PendingRequests,
change_tx: mpsc::Sender<ChangePayload>,
) {
loop {
select! {
Some(msg) = read.next() => {
match msg {
Ok(Message::Text(text)) => {
if let Ok(response) = from_str::<Response>(&text) {
let mut pending_guard = pending.lock().await;
if let Some(tx) = pending_guard.remove(&response.id) {
let _ = tx.send(ClientResponse::Json(Box::new(response)));
}
}
else if let Ok(push) = from_str::<ServerPush>(&text) {
match push {
ServerPush::Change(change) => {
let _ = change_tx.send(change).await;
}
}
}
}
Ok(Message::Binary(data)) => {
if data.len() < 5 { continue; }
let kind = data[0];
let id_len = u32::from_le_bytes([data[1], data[2], data[3], data[4]]) as usize;
let meta_len_pos = 5 + id_len;
if data.len() < meta_len_pos + 4 { continue; }
let id = String::from_utf8_lossy(&data[5..meta_len_pos]).to_string();
let meta_len = u32::from_le_bytes([
data[meta_len_pos],
data[meta_len_pos + 1],
data[meta_len_pos + 2],
data[meta_len_pos + 3],
]) as usize;
let meta_start = meta_len_pos + 4;
if data.len() < meta_start + meta_len { continue; }
let meta = if meta_len > 0 {
from_str::<ResponseMeta>(
&String::from_utf8_lossy(&data[meta_start..meta_start + meta_len])
).ok()
} else {
None
};
let rbcf_data = &data[meta_start + meta_len..];
let frames = match decode_frames(rbcf_data) {
Ok(f) => f,
Err(_) => continue,
};
match kind {
0x00 => {
let mut pending_guard = pending.lock().await;
if let Some(tx) = pending_guard.remove(&id) {
let _ = tx.send(ClientResponse::Frames(frames, meta));
}
}
0x01 => {
let _ = change_tx.send(ChangePayload {
subscription_id: id,
content_type: "application/vnd.reifydb.rbcf".to_string(),
body: Value::Null,
frames: Some(frames),
}).await;
}
_ => {}
}
}
Ok(Message::Ping(data)) => {
let _ = write.send(Message::Pong(data)).await;
}
Ok(Message::Close(_)) => {
break;
}
Err(_) => {
break;
}
_ => {}
}
}
Some((request, response_tx)) = request_rx.recv() => {
let id = request.id.clone();
{
let mut pending_guard = pending.lock().await;
pending_guard.insert(id, response_tx);
}
if let Ok(json) = to_string(&request)
&& write.send(Message::Text(json.into())).await.is_err() {
break;
}
}
_ = shutdown_rx.recv() => {
let _ = write.send(Message::Close(None)).await;
break;
}
}
}
let mut pending_guard = pending.lock().await;
pending_guard.clear();
}
fn wire_format(&self) -> Option<String> {
match self.format {
WireFormat::Rbcf => Some("rbcf".to_string()),
WireFormat::Json => Some("frames".to_string()),
WireFormat::Proto => None,
}
}
pub async fn authenticate(&mut self, token: &str) -> Result<(), Error> {
let id = generate_request_id();
let request = Request {
id,
payload: RequestPayload::Auth(AuthRequest {
token: Some(token.to_string()),
method: None,
credentials: None,
}),
};
let response = self.send_request_json(request).await?;
match response.payload {
ResponsePayload::Auth(_) => {
self.is_authenticated = true;
Ok(())
}
ResponsePayload::Err(err) => Err(Error(Box::new(err.diagnostic))),
_ => panic!("Unexpected response type for auth"), }
}
pub async fn login_with_password(&mut self, identifier: &str, password: &str) -> Result<LoginResult, Error> {
let mut credentials = HashMap::new();
credentials.insert("identifier".to_string(), identifier.to_string());
credentials.insert("password".to_string(), password.to_string());
self.login("password", credentials).await
}
pub async fn login_with_token(&mut self, token: &str) -> Result<LoginResult, Error> {
let mut credentials = HashMap::new();
credentials.insert("token".to_string(), token.to_string());
self.login("token", credentials).await
}
pub async fn login(
&mut self,
method: &str,
credentials: HashMap<String, String>,
) -> Result<LoginResult, Error> {
let id = generate_request_id();
let request = Request {
id,
payload: RequestPayload::Auth(AuthRequest {
token: None,
method: Some(method.to_string()),
credentials: Some(credentials),
}),
};
let response = self.send_request_json(request).await?;
match response.payload {
ResponsePayload::Auth(auth) => {
if auth.status.as_deref() == Some("authenticated") {
let token = auth.token.unwrap_or_default();
let identity = auth.identity.unwrap_or_default();
self.is_authenticated = true;
Ok(LoginResult {
token,
identity,
})
} else {
panic!("Authentication failed") }
}
ResponsePayload::Err(err) => Err(Error(Box::new(err.diagnostic))),
_ => panic!("Unexpected response type for login"), }
}
pub async fn logout(&mut self) -> Result<(), Error> {
if !self.is_authenticated {
return Ok(());
}
let id = generate_request_id();
let request = Request {
id,
payload: RequestPayload::Logout,
};
let response = self.send_request_json(request).await?;
match response.payload {
ResponsePayload::Logout(_) => {
self.is_authenticated = false;
Ok(())
}
ResponsePayload::Err(err) => Err(Error(Box::new(err.diagnostic))),
_ => panic!("Unexpected response type for logout"), }
}
pub async fn admin(&self, rql: &str, params: Option<Params>) -> Result<Vec<Frame>, Error> {
Ok(self.admin_with_meta(rql, params).await?.frames)
}
pub async fn admin_with_meta(&self, rql: &str, params: Option<Params>) -> Result<AdminResult, Error> {
let id = generate_request_id();
let request = Request {
id,
payload: RequestPayload::Admin(AdminRequest {
rql: rql.to_string(),
params: params.and_then(params_to_wire),
format: self.wire_format(),
}),
};
match self.send_request(request).await? {
ClientResponse::Frames(frames, meta) => Ok(AdminResult {
frames,
meta,
}),
ClientResponse::Json(resp) => parse_admin_response(*resp),
}
}
pub async fn command(&self, rql: &str, params: Option<Params>) -> Result<Vec<Frame>, Error> {
Ok(self.command_with_meta(rql, params).await?.frames)
}
pub async fn command_with_meta(&self, rql: &str, params: Option<Params>) -> Result<CommandResult, Error> {
let id = generate_request_id();
let request = Request {
id,
payload: RequestPayload::Command(CommandRequest {
rql: rql.to_string(),
params: params.and_then(params_to_wire),
format: self.wire_format(),
}),
};
match self.send_request(request).await? {
ClientResponse::Frames(frames, meta) => Ok(CommandResult {
frames,
meta,
}),
ClientResponse::Json(resp) => parse_command_response(*resp),
}
}
pub async fn query(&self, rql: &str, params: Option<Params>) -> Result<Vec<Frame>, Error> {
Ok(self.query_with_meta(rql, params).await?.frames)
}
pub async fn query_with_meta(&self, rql: &str, params: Option<Params>) -> Result<QueryResult, Error> {
let id = generate_request_id();
let request = Request {
id,
payload: RequestPayload::Query(QueryRequest {
rql: rql.to_string(),
params: params.and_then(params_to_wire),
format: self.wire_format(),
}),
};
match self.send_request(request).await? {
ClientResponse::Frames(frames, meta) => Ok(QueryResult {
frames,
meta,
}),
ClientResponse::Json(resp) => parse_query_response(*resp),
}
}
pub async fn subscribe(&self, rql: &str) -> Result<String, Error> {
let id = generate_request_id();
let request = Request {
id,
payload: RequestPayload::Subscribe(SubscribeRequest {
rql: rql.to_string(),
format: self.wire_format(),
}),
};
let response = self.send_request_json(request).await?;
match response.payload {
ResponsePayload::Subscribed(sub) => Ok(sub.subscription_id),
ResponsePayload::Err(err) => Err(Error(Box::new(err.diagnostic))),
_ => panic!("Unexpected response type for subscribe"), }
}
pub async fn unsubscribe(&self, subscription_id: &str) -> Result<(), Error> {
let id = generate_request_id();
let request = Request {
id,
payload: RequestPayload::Unsubscribe(UnsubscribeRequest {
subscription_id: subscription_id.to_string(),
}),
};
let response = self.send_request_json(request).await?;
match response.payload {
ResponsePayload::Unsubscribed(_) => Ok(()),
ResponsePayload::Err(err) => Err(Error(Box::new(err.diagnostic))),
_ => panic!("Unexpected response type for unsubscribe"), }
}
pub async fn recv(&mut self) -> Option<ChangePayload> {
self.change_rx.recv().await
}
pub fn try_recv(&mut self) -> Result<ChangePayload, mpsc::error::TryRecvError> {
self.change_rx.try_recv()
}
async fn send_request(&self, request: Request) -> Result<ClientResponse, Error> {
let (tx, rx) = oneshot::channel();
self.request_tx.send((request, tx)).await.unwrap();
Ok(rx.await.unwrap()) }
async fn send_request_json(&self, request: Request) -> Result<Response, Error> {
match self.send_request(request).await? {
ClientResponse::Json(resp) => Ok(*resp),
ClientResponse::Frames(_, _) => panic!("unexpected binary response"),
}
}
pub async fn close(mut self) -> Result<(), Error> {
if let Some(tx) = self.shutdown_tx.take() {
let _ = tx.send(()).await;
}
Ok(())
}
pub fn is_authenticated(&self) -> bool {
self.is_authenticated
}
}
impl Drop for WsClient {
    /// Best-effort shutdown when the handle goes away without `close`.
    ///
    /// `drop` cannot await, so the signal is queued with `try_send`. If `close` already
    /// consumed the sender, there is nothing left to do.
    fn drop(&mut self) {
        let Some(shutdown) = self.shutdown_tx.take() else {
            return;
        };
        let _ = shutdown.try_send(());
    }
}