use std::{collections::HashMap, sync::Arc};
use futures_util::{SinkExt, StreamExt};
use reifydb_type::{error::Error, params::Params};
use tokio::sync::{Mutex, mpsc, oneshot};
use tokio_tungstenite::{connect_async, tungstenite::Message};
use crate::{
AdminRequest, AdminResult, AuthRequest, ChangePayload, CommandRequest, CommandResult, LoginResult,
QueryRequest, QueryResult, Request, RequestPayload, Response, ResponsePayload, ServerPush, SubscribeRequest,
UnsubscribeRequest, params_to_wire,
session::{parse_admin_response, parse_command_response, parse_query_response},
utils::generate_request_id,
};
/// In-flight requests awaiting a reply: request id -> oneshot sender that delivers the
/// matching `Response` back to the caller of `send_request`.
type PendingRequests = Arc<Mutex<HashMap<String, oneshot::Sender<Response>>>>;
/// Asynchronous WebSocket client.
///
/// All socket I/O happens on a background task spawned by `WsClient::connect`; this handle
/// talks to that task through channels.
pub struct WsClient {
/// Hands each outgoing request, paired with the oneshot sender for its response, to the
/// background connection task.
request_tx: mpsc::Sender<(Request, oneshot::Sender<Response>)>,
/// Tells the background task to send a Close frame and stop. Becomes `None` once shutdown
/// has been requested (by `close` or by `Drop`).
shutdown_tx: Option<mpsc::Sender<()>>,
/// Local flag: set after a successful `authenticate`/`login`, cleared by `logout`.
/// It is not re-checked against the server.
is_authenticated: bool,
/// Receives `ServerPush::Change` payloads forwarded by the background task.
change_rx: mpsc::Receiver<ChangePayload>,
}
impl WsClient {
pub async fn connect(url: &str) -> Result<Self, Error> {
let url = if !url.starts_with("ws://") && !url.starts_with("wss://") {
format!("ws://{}", url)
} else {
url.to_string()
};
let (ws_stream, _) = connect_async(&url).await.unwrap();
let (write, read) = ws_stream.split();
let (request_tx, request_rx) = mpsc::channel::<(Request, oneshot::Sender<Response>)>(32);
let (shutdown_tx, shutdown_rx) = mpsc::channel::<()>(1);
let (change_tx, change_rx) = mpsc::channel::<ChangePayload>(100);
let pending: PendingRequests = Arc::new(Mutex::new(HashMap::new()));
let pending_clone = pending.clone();
tokio::spawn(async move {
Self::connection_loop(write, read, request_rx, shutdown_rx, pending_clone, change_tx).await;
});
Ok(Self {
request_tx,
shutdown_tx: Some(shutdown_tx),
is_authenticated: false,
change_rx,
})
}
async fn connection_loop(
mut write: futures_util::stream::SplitSink<
tokio_tungstenite::WebSocketStream<tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>>,
Message,
>,
mut read: futures_util::stream::SplitStream<
tokio_tungstenite::WebSocketStream<tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>>,
>,
mut request_rx: mpsc::Receiver<(Request, oneshot::Sender<Response>)>,
mut shutdown_rx: mpsc::Receiver<()>,
pending: PendingRequests,
change_tx: mpsc::Sender<ChangePayload>,
) {
loop {
tokio::select! {
Some(msg) = read.next() => {
match msg {
Ok(Message::Text(text)) => {
if let Ok(response) = serde_json::from_str::<Response>(&text) {
let mut pending_guard = pending.lock().await;
if let Some(tx) = pending_guard.remove(&response.id) {
let _ = tx.send(response);
}
}
else if let Ok(push) = serde_json::from_str::<ServerPush>(&text) {
match push {
ServerPush::Change(change) => {
let _ = change_tx.send(change).await;
}
}
}
}
Ok(Message::Ping(data)) => {
let _ = write.send(Message::Pong(data)).await;
}
Ok(Message::Close(_)) => {
break;
}
Err(_) => {
break;
}
_ => {}
}
}
Some((request, response_tx)) = request_rx.recv() => {
let id = request.id.clone();
{
let mut pending_guard = pending.lock().await;
pending_guard.insert(id, response_tx);
}
if let Ok(json) = serde_json::to_string(&request)
&& write.send(Message::Text(json.into())).await.is_err() {
break;
}
}
_ = shutdown_rx.recv() => {
let _ = write.send(Message::Close(None)).await;
break;
}
}
}
let mut pending_guard = pending.lock().await;
pending_guard.clear();
}
pub async fn authenticate(&mut self, token: &str) -> Result<(), Error> {
let id = generate_request_id();
let request = Request {
id,
payload: RequestPayload::Auth(AuthRequest {
token: Some(token.to_string()),
method: None,
credentials: None,
}),
};
let response = self.send_request(request).await?;
match response.payload {
ResponsePayload::Auth(_) => {
self.is_authenticated = true;
Ok(())
}
ResponsePayload::Err(err) => Err(Error(Box::new(err.diagnostic))),
_ => panic!("Unexpected response type for auth"), }
}
pub async fn login_with_password(&mut self, identifier: &str, password: &str) -> Result<LoginResult, Error> {
let mut credentials = HashMap::new();
credentials.insert("identifier".to_string(), identifier.to_string());
credentials.insert("password".to_string(), password.to_string());
self.login("password", credentials).await
}
pub async fn login_with_token(&mut self, token: &str) -> Result<LoginResult, Error> {
let mut credentials = HashMap::new();
credentials.insert("token".to_string(), token.to_string());
self.login("token", credentials).await
}
pub async fn login(
&mut self,
method: &str,
credentials: HashMap<String, String>,
) -> Result<LoginResult, Error> {
let id = generate_request_id();
let request = Request {
id,
payload: RequestPayload::Auth(AuthRequest {
token: None,
method: Some(method.to_string()),
credentials: Some(credentials),
}),
};
let response = self.send_request(request).await?;
match response.payload {
ResponsePayload::Auth(auth) => {
if auth.status.as_deref() == Some("authenticated") {
let token = auth.token.unwrap_or_default();
let identity = auth.identity.unwrap_or_default();
self.is_authenticated = true;
Ok(LoginResult {
token,
identity,
})
} else {
panic!("Authentication failed") }
}
ResponsePayload::Err(err) => Err(Error(Box::new(err.diagnostic))),
_ => panic!("Unexpected response type for login"), }
}
pub async fn logout(&mut self) -> Result<(), Error> {
if !self.is_authenticated {
return Ok(());
}
let id = generate_request_id();
let request = Request {
id,
payload: RequestPayload::Logout,
};
let response = self.send_request(request).await?;
match response.payload {
ResponsePayload::Logout(_) => {
self.is_authenticated = false;
Ok(())
}
ResponsePayload::Err(err) => Err(Error(Box::new(err.diagnostic))),
_ => panic!("Unexpected response type for logout"), }
}
pub async fn admin(&self, rql: &str, params: Option<Params>) -> Result<AdminResult, Error> {
let id = generate_request_id();
let request = Request {
id,
payload: RequestPayload::Admin(AdminRequest {
statements: vec![rql.to_string()],
params: params.and_then(params_to_wire),
}),
};
let response = self.send_request(request).await?;
parse_admin_response(response)
}
pub async fn admin_batch(&self, statements: Vec<&str>, params: Option<Params>) -> Result<AdminResult, Error> {
let id = generate_request_id();
let request = Request {
id,
payload: RequestPayload::Admin(AdminRequest {
statements: statements.into_iter().map(String::from).collect(),
params: params.and_then(params_to_wire),
}),
};
let response = self.send_request(request).await?;
parse_admin_response(response)
}
pub async fn command(&self, rql: &str, params: Option<Params>) -> Result<CommandResult, Error> {
let id = generate_request_id();
let request = Request {
id,
payload: RequestPayload::Command(CommandRequest {
statements: vec![rql.to_string()],
params: params.and_then(params_to_wire),
}),
};
let response = self.send_request(request).await?;
parse_command_response(response)
}
pub async fn query(&self, rql: &str, params: Option<Params>) -> Result<QueryResult, Error> {
let id = generate_request_id();
let request = Request {
id,
payload: RequestPayload::Query(QueryRequest {
statements: vec![rql.to_string()],
params: params.and_then(params_to_wire),
}),
};
let response = self.send_request(request).await?;
parse_query_response(response)
}
pub async fn command_batch(
&self,
statements: Vec<&str>,
params: Option<Params>,
) -> Result<CommandResult, Error> {
let id = generate_request_id();
let request = Request {
id,
payload: RequestPayload::Command(CommandRequest {
statements: statements.into_iter().map(String::from).collect(),
params: params.and_then(params_to_wire),
}),
};
let response = self.send_request(request).await?;
parse_command_response(response)
}
pub async fn query_batch(&self, statements: Vec<&str>, params: Option<Params>) -> Result<QueryResult, Error> {
let id = generate_request_id();
let request = Request {
id,
payload: RequestPayload::Query(QueryRequest {
statements: statements.into_iter().map(String::from).collect(),
params: params.and_then(params_to_wire),
}),
};
let response = self.send_request(request).await?;
parse_query_response(response)
}
pub async fn subscribe(&self, query: &str) -> Result<String, Error> {
let id = generate_request_id();
let request = Request {
id,
payload: RequestPayload::Subscribe(SubscribeRequest {
query: query.to_string(),
}),
};
let response = self.send_request(request).await?;
match response.payload {
ResponsePayload::Subscribed(sub) => Ok(sub.subscription_id),
ResponsePayload::Err(err) => Err(Error(Box::new(err.diagnostic))),
_ => panic!("Unexpected response type for subscribe"), }
}
pub async fn unsubscribe(&self, subscription_id: &str) -> Result<(), Error> {
let id = generate_request_id();
let request = Request {
id,
payload: RequestPayload::Unsubscribe(UnsubscribeRequest {
subscription_id: subscription_id.to_string(),
}),
};
let response = self.send_request(request).await?;
match response.payload {
ResponsePayload::Unsubscribed(_) => Ok(()),
ResponsePayload::Err(err) => Err(Error(Box::new(err.diagnostic))),
_ => panic!("Unexpected response type for unsubscribe"), }
}
pub async fn recv(&mut self) -> Option<ChangePayload> {
self.change_rx.recv().await
}
pub fn try_recv(&mut self) -> Result<ChangePayload, mpsc::error::TryRecvError> {
self.change_rx.try_recv()
}
async fn send_request(&self, request: Request) -> Result<Response, Error> {
let (tx, rx) = oneshot::channel();
self.request_tx.send((request, tx)).await.unwrap();
Ok(rx.await.unwrap()) }
pub async fn close(mut self) -> Result<(), Error> {
if let Some(tx) = self.shutdown_tx.take() {
let _ = tx.send(()).await;
}
Ok(())
}
pub fn is_authenticated(&self) -> bool {
self.is_authenticated
}
}
impl Drop for WsClient {
    /// Best-effort shutdown when the client goes out of scope without `close`.
    ///
    /// This uses the non-waiting `try_send` because `drop` cannot `.await`. If `close` already
    /// took the sender, there is nothing left to signal.
    fn drop(&mut self) {
        let Some(shutdown) = self.shutdown_tx.take() else {
            return;
        };
        // A full or closed channel means shutdown is already underway; ignore the result.
        let _ = shutdown.try_send(());
    }
}