use std::collections::HashMap;
use std::error::Error;
use std::ops::Deref;
use std::sync::Arc;
use async_lock::{Mutex, RwLock};
use futures::Future;
use futures::future::{BoxFuture, LocalBoxFuture, join_all};
use prost::Message;
use serde::{Deserialize, Serialize};
use ts_rs::TS;
use crate::proto::request::ClientReq;
use crate::proto::response::ClientResp;
use crate::proto::{
ColumnType, GetFeaturesReq, GetFeaturesResp, GetHostedTablesReq, GetHostedTablesResp,
HostedTable, JoinType, MakeJoinTableReq, MakeTableReq, RemoveHostedTablesUpdateReq, Request,
Response, ServerError, ServerSystemInfoReq,
};
use crate::table::{JoinOptions, Table, TableInitOptions, TableOptions};
use crate::table_data::{TableData, UpdateData};
use crate::table_ref::TableRef;
use crate::utils::*;
use crate::view::{OnUpdateData, ViewWindow};
use crate::{OnUpdateMode, OnUpdateOptions, asyncfn, clone};
/// Resource-usage snapshot, as produced by [`Client::system_info`].
///
/// `T` is the numeric type of the size/timestamp fields. It defaults to `u64`
/// and can be converted with [`SystemInfo::cast`].
#[derive(Clone, Debug, Serialize, Deserialize, TS)]
pub struct SystemInfo<T = u64> {
    /// `heap_size` as reported by the server's `ServerSystemInfoResp`.
    pub heap_size: T,
    /// `used_size` as reported by the server's `ServerSystemInfoResp`.
    pub used_size: T,
    /// `cpu_time` as reported by the server.
    pub cpu_time: u32,
    /// `cpu_time_epoch` as reported by the server.
    pub cpu_time_epoch: u32,
    /// When filled by `Client::system_info`: the client's wall-clock time in
    /// milliseconds since the Unix epoch, read after the server replied.
    /// Always `None` on wasm targets.
    pub timestamp: Option<T>,
    /// When filled by `Client::system_info`: the "heap" figure from
    /// `crate::utils::get_used()` if the `talc-allocator` feature is enabled,
    /// otherwise `None`.
    pub client_heap: Option<T>,
    /// When filled by `Client::system_info`: the "used" figure from
    /// `crate::utils::get_used()` if the `talc-allocator` feature is enabled,
    /// otherwise `None`.
    pub client_used: Option<T>,
}
impl<U: Copy + 'static> SystemInfo<U> {
    /// Converts every `U`-typed field (sizes, timestamp, client figures) to `T`
    /// through `num_traits::AsPrimitive`, an `as`-style conversion.
    ///
    /// `cpu_time` and `cpu_time_epoch` are always `u32` and are copied as-is.
    pub fn cast<T: Copy + 'static>(&self) -> SystemInfo<T>
    where
        U: num_traits::AsPrimitive<T>,
    {
        // A plain fn item: `Copy`, so it can be called directly and handed to
        // `Option::map` any number of times.
        let convert = <U as num_traits::AsPrimitive<T>>::as_;
        SystemInfo {
            heap_size: convert(self.heap_size),
            used_size: convert(self.used_size),
            cpu_time: self.cpu_time,
            cpu_time_epoch: self.cpu_time_epoch,
            timestamp: self.timestamp.map(convert),
            client_heap: self.client_heap.map(convert),
            client_used: self.client_used.map(convert),
        }
    }
}
/// Cheaply clonable shared handle (an `Arc`) to the server's [`GetFeaturesResp`].
///
/// Dereferences to [`GetFeaturesResp`]. `Client::get_features` caches one per client.
#[derive(Clone, Debug, Default, PartialEq)]
pub struct Features(Arc<GetFeaturesResp>);
impl Features {
pub fn get_group_rollup_modes(&self) -> Vec<crate::config::GroupRollupMode> {
self.group_rollup_mode
.iter()
.map(|x| {
crate::config::GroupRollupMode::from(
crate::proto::GroupRollupMode::try_from(*x).unwrap(),
)
})
.collect::<Vec<_>>()
}
}
impl Deref for Features {
    type Target = GetFeaturesResp;

    /// Borrows the shared [`GetFeaturesResp`] held inside the `Arc`.
    fn deref(&self) -> &Self::Target {
        self.0.as_ref()
    }
}
impl GetFeaturesResp {
    /// Returns the first filter operator listed for `col_type` in `filter_ops`.
    ///
    /// Returns `None` when the column type has no entry or its option list is empty.
    pub fn default_op(&self, col_type: ColumnType) -> Option<&str> {
        // `filter_ops` is keyed by the column type's numeric value.
        let type_key = col_type as u32;
        let entry = self.filter_ops.get(&type_key)?;
        entry.options.first().map(|op| op.as_str())
    }
}
/// Boxed, thread-safe single-argument callback.
type BoxFn<I, O> = Box<dyn Fn(I) -> O + Send + Sync + 'static>;
/// Boxed, thread-safe two-argument callback.
type Box2Fn<I, J, O> = Box<dyn Fn(I, J) -> O + Send + Sync + 'static>;
/// Shared registry of callbacks keyed by a `u32` id, behind an `async_lock::RwLock`.
type Subscriptions<C> = Arc<RwLock<HashMap<u32, C>>>;
/// Handler registered through `Client::on_error`.
///
/// It receives the error plus an optional [`ReconnectCallback`].
type OnErrorCallback =
    Box2Fn<ClientError, Option<ReconnectCallback>, BoxFuture<'static, Result<(), ClientError>>>;
/// Handler invoked at most once, with the response to a single request.
type OnceCallback = Box<dyn FnOnce(Response) -> ClientResult<()> + Send + Sync + 'static>;
/// Transport hook that sends one [`Request`].
///
/// The returned future may borrow the request for `'a`. `Client::new_with_callback`
/// builds one that prost-encodes the request before handing the bytes on.
type SendCallback = Arc<
    dyn for<'a> Fn(&'a Request) -> BoxFuture<'a, Result<(), Box<dyn Error + Send + Sync>>>
        + Send
        + Sync
        + 'static,
>;
/// Transport used by [`Client::new`] to deliver encoded request messages.
pub trait ClientHandler: Clone + Send + Sync + 'static {
    /// Sends `msg`, the prost-encoded bytes of a single `Request`.
    ///
    /// If this returns `Err`, the client call that produced `msg` fails with
    /// `ClientError::Unknown` carrying the error's text.
    fn send_request(
        &self,
        msg: Vec<u8>,
    ) -> impl Future<Output = Result<(), Box<dyn Error + Send + Sync>>> + Send;
}
mod name_registry {
    //! Process-wide bookkeeping for client names.

    use std::collections::HashSet;
    use std::sync::{Arc, LazyLock, Mutex};

    use crate::ClientError;
    use crate::view::ClientResult;

    /// Counter behind the `client-N` names generated when no name is supplied.
    static CLIENT_ID_GEN: LazyLock<Arc<Mutex<u32>>> = LazyLock::new(Arc::default);

    /// Names that [`generate_name`] treats as already taken.
    ///
    /// NOTE(review): this set is private to the module and nothing here inserts
    /// into it, so the duplicate-name check below can never fire as written.
    /// Registering names would also require unregistering them when a client
    /// goes away; confirm the intended lifecycle before changing this.
    static REGISTERED_CLIENTS: LazyLock<Arc<Mutex<HashSet<String>>>> = LazyLock::new(Arc::default);

    /// Returns `name` if it is given and not registered.
    ///
    /// Otherwise, when no name is given, returns a fresh `client-N` name
    /// minted from a process-wide counter.
    ///
    /// # Errors
    ///
    /// Returns `ClientError::DuplicateNameError` if `name` is registered, or an
    /// error converted from a poisoned mutex.
    pub(crate) fn generate_name(name: Option<&str>) -> ClientResult<String> {
        let Some(requested) = name else {
            let mut counter = CLIENT_ID_GEN.lock()?;
            *counter += 1;
            return Ok(format!("client-{counter}"));
        };

        let registered = REGISTERED_CLIENTS.lock().map_err(ClientError::from)?;
        if registered.contains(requested) {
            Err(ClientError::DuplicateNameError(requested.to_owned()))
        } else {
            Ok(requested.to_owned())
        }
    }
}
/// Cloneable, shareable handle to a reconnect action.
///
/// [`Client::handle_error`] passes it to `on_error` callbacks when its caller
/// supplied a reconnect function.
#[derive(Clone)]
// The tuple field's `Arc<dyn Fn() -> LocalBoxFuture<..>>` type trips clippy's
// `type_complexity` lint; the lint is allowed here rather than adding an alias.
#[allow(clippy::type_complexity)]
pub struct ReconnectCallback(
    Arc<dyn Fn() -> LocalBoxFuture<'static, Result<(), Box<dyn Error>>> + Send + Sync>,
);
impl Deref for ReconnectCallback {
    type Target = dyn Fn() -> LocalBoxFuture<'static, Result<(), Box<dyn Error>>> + Send + Sync;

    /// Borrows the wrapped reconnect closure.
    fn deref(&self) -> &Self::Target {
        self.0.as_ref()
    }
}
impl ReconnectCallback {
pub fn new(
f: impl Fn() -> LocalBoxFuture<'static, Result<(), Box<dyn Error>>> + Send + Sync + 'static,
) -> Self {
ReconnectCallback(Arc::new(f))
}
}
/// Client speaking the protobuf [`Request`]/[`Response`] protocol over a
/// user-supplied transport.
///
/// Outgoing requests go through `send`. Incoming messages must be passed to
/// [`Client::handle_response`], which routes each one to the callback
/// registered under its `msg_id`.
///
/// Clones share all `Arc`-held state: name, feature cache, transport and the
/// callback registries. NOTE(review): whether clones also share the id
/// generator depends on `IDGen`'s `Clone` impl, which is not visible here;
/// confirm.
#[derive(Clone)]
pub struct Client {
    /// Name from `name_registry::generate_name`; the only field compared by `PartialEq`.
    name: Arc<String>,
    /// Server features, fetched on first use and cached by `get_features`.
    features: Arc<Mutex<Option<Features>>>,
    /// Sends one request (built in `new_with_callback`).
    send: SendCallback,
    /// Source of request `msg_id`s and `on_error` subscription ids (see `gen_id`).
    id_gen: IDGen,
    /// `on_error` callbacks, keyed by the id that `on_error` returned.
    subscriptions_errors: Subscriptions<OnErrorCallback>,
    /// Callbacks removed and invoked on the first response carrying their `msg_id`.
    subscriptions_once: Subscriptions<OnceCallback>,
    /// Callbacks invoked for every response carrying their `msg_id`, until `unsubscribe`.
    subscriptions: Subscriptions<BoxFn<Response, BoxFuture<'static, Result<(), ClientError>>>>,
}
/// Two clients are equal exactly when their names are equal; no other state is compared.
impl PartialEq for Client {
    fn eq(&self, other: &Self) -> bool {
        self.name.as_str() == other.name.as_str()
    }
}
/// Prints only the type name, `Client`.
///
/// The fields, mostly boxed closures, are not printed.
impl std::fmt::Debug for Client {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("Client")
    }
}
impl Client {
/// Creates a client whose outgoing messages are handed to `send_request`.
///
/// Each [`Request`] is prost-encoded into a fresh `Vec<u8>` and passed to
/// `send_request`. If the returned future fails, the request that produced
/// the bytes fails as well.
///
/// # Errors
///
/// Fails if `name_registry::generate_name` rejects `name` or cannot lock its state.
pub fn new_with_callback<T, U>(name: Option<&str>, send_request: T) -> ClientResult<Self>
where
    T: Fn(Vec<u8>) -> U + 'static + Sync + Send,
    U: Future<Output = Result<(), Box<dyn Error + Send + Sync>>> + Send + 'static,
{
    let name = Arc::new(name_registry::generate_name(name)?);
    let transport = Arc::new(send_request);
    let send: SendCallback = Arc::new(move |req| {
        // Encoding into a growable `Vec` cannot run out of buffer space.
        let payload = req.encode_to_vec();
        let transport = Arc::clone(&transport);
        Box::pin(async move { transport(payload).await })
    });
    Ok(Client {
        name,
        send,
        features: Arc::default(),
        id_gen: IDGen::default(),
        subscriptions: Arc::default(),
        subscriptions_errors: Arc::default(),
        subscriptions_once: Arc::default(),
    })
}
/// Creates a client that sends every encoded request through `client_handler`.
///
/// This is [`Client::new_with_callback`] with a callback that forwards the
/// bytes to [`ClientHandler::send_request`].
///
/// # Errors
///
/// Same as [`Client::new_with_callback`].
pub fn new<T>(name: Option<&str>, client_handler: T) -> ClientResult<Self>
where
    T: ClientHandler + 'static + Sync + Send,
{
    Self::new_with_callback(
        name,
        asyncfn!(client_handler, async move |payload| {
            client_handler.send_request(payload).await
        }),
    )
}
/// Returns this client's name, either as supplied or as generated at construction.
pub fn get_name(&self) -> &'_ str {
    &self.name
}
/// Decodes one serialized [`Response`] and dispatches it to the callback
/// registered for its `msg_id`.
///
/// A one-shot callback (from `subscribe_once`) is removed before it is called.
/// A persistent callback (from `subscribe`) stays registered. It is called
/// under the registry's read lock, but the future it returns is awaited only
/// after that lock is released, so it may itself subscribe or unsubscribe.
///
/// Returns `Ok(true)` when a callback consumed the message and `Ok(false)`
/// otherwise. Unclaimed server errors are logged at `error` level; anything
/// else unclaimed is logged at `debug` level.
///
/// # Errors
///
/// Fails if `msg` cannot be decoded or if the invoked callback returns an error.
pub async fn handle_response<'a>(&'a self, msg: &'a [u8]) -> ClientResult<bool> {
    let msg = Response::decode(msg)?;
    tracing::debug!("RECV {}", msg);
    let msg_id = msg.msg_id;

    // The write guard is a temporary, released at the end of this statement,
    // before the handler is invoked.
    let once_handler = self.subscriptions_once.write().await.remove(&msg_id);
    if let Some(handler) = once_handler {
        handler(msg)?;
        return Ok(true);
    }

    // Handlers return `'static` futures, so the future can be created under the
    // read lock and awaited after releasing it. Waiting for the lock (rather
    // than failing fast) tolerates a concurrent `subscribe`/`unsubscribe`.
    // Releasing it before the await lets the handler call them without
    // deadlocking.
    let subscriptions = self.subscriptions.read().await;
    let dispatched = match subscriptions.get(&msg_id) {
        Some(handler) => Ok(handler(msg)),
        None => Err(msg),
    };
    drop(subscriptions);

    let msg = match dispatched {
        Ok(update) => {
            update.await?;
            return Ok(true);
        },
        Err(msg) => msg,
    };

    if let Response {
        client_resp: Some(ClientResp::ServerError(ServerError { message, .. })),
        ..
    } = &msg
    {
        tracing::error!("{}", message);
    } else {
        tracing::debug!("Received unsolicited server response: {}", msg);
    }

    Ok(false)
}
/// Reports a connection-level error.
///
/// First runs every `on_error` callback, then fails all outstanding requests
/// via `close_and_error_subscriptions`.
///
/// Each callback receives a clone of `message`. When `reconnect` is given, it
/// also receives a [`ReconnectCallback`] wrapping it. The callbacks' futures
/// are awaited together (`join_all`) after the `on_error` registry lock has
/// been released, so a callback may call `on_error` without deadlocking.
///
/// # Errors
///
/// Returns the first callback error, in which case outstanding requests are
/// not failed. Otherwise returns any error from `close_and_error_subscriptions`.
pub async fn handle_error<T, U>(
    &self,
    message: ClientError,
    reconnect: Option<T>,
) -> ClientResult<()>
where
    T: Fn() -> U + Clone + Send + Sync + 'static,
    U: Future<Output = ClientResult<()>>,
{
    // Callbacks return `'static` futures, so they can be started under the read
    // lock and awaited after it is dropped.
    let subs = self.subscriptions_errors.read().await;
    let pending = subs
        .values()
        .map(|callback| {
            callback(
                message.clone(),
                reconnect.clone().map(move |f| {
                    ReconnectCallback(Arc::new(move || {
                        clone!(f);
                        Box::pin(async move { Ok(f().await?) }) as LocalBoxFuture<'static, _>
                    }))
                }),
            )
        })
        .collect::<Vec<_>>();
    drop(subs);

    join_all(pending).await.into_iter().collect::<Result<(), _>>()?;
    self.close_and_error_subscriptions(&message).await
}
/// Fails every outstanding request after a connection-level error.
///
/// All persistent subscriptions are dropped without being invoked. Every
/// pending one-shot handler is removed and invoked with a synthetic
/// [`ServerError`] response whose message is `message`. All one-shot handlers
/// are invoked even if an earlier one fails.
///
/// # Errors
///
/// Returns the first error returned by a one-shot handler.
async fn close_and_error_subscriptions(&self, message: &ClientError) -> ClientResult<()> {
    let synthetic_error = |msg_id| Response {
        msg_id,
        entity_id: "".to_string(),
        client_resp: Some(ClientResp::ServerError(ServerError {
            message: format!("{message}"),
            status_code: 2,
        })),
    };

    self.subscriptions.write().await.clear();
    // Drain under the lock, but invoke the handlers after it is released.
    let callbacks_once = self
        .subscriptions_once
        .write()
        .await
        .drain()
        .collect::<Vec<_>>();

    // A handler can fail (e.g. a `oneshot` whose receiver was dropped). That
    // must not stop the remaining waiters from being notified, so remember only
    // the first failure and keep going.
    let mut first_error = None;
    for (msg_id, callback) in callbacks_once {
        if let Err(err) = callback(synthetic_error(msg_id)) {
            if first_error.is_none() {
                first_error = Some(err);
            }
        }
    }

    match first_error {
        Some(err) => Err(err),
        None => Ok(()),
    }
}
pub async fn on_error<T, U, V>(&self, on_error: T) -> ClientResult<u32>
where
T: Fn(ClientError, Option<ReconnectCallback>) -> U + Clone + Send + Sync + 'static,
U: Future<Output = V> + Send + 'static,
V: Into<Result<(), ClientError>> + Sync + 'static,
{
let id = self.gen_id();
let callback = asyncfn!(on_error, async move |x, y| on_error(x, y).await.into());
self.subscriptions_errors
.write()
.await
.insert(id, Box::new(move |x, y| Box::pin(callback(x, y))));
Ok(id)
}
pub(crate) fn gen_id(&self) -> u32 {
self.id_gen.next()
}
/// Removes the persistent subscription registered under `update_id`.
///
/// # Errors
///
/// Returns `ClientError::Unknown("remove_update")` if no such subscription exists.
pub(crate) async fn unsubscribe(&self, update_id: u32) -> ClientResult<()> {
    // The write guard is a temporary, released at the end of this statement.
    let removed = self.subscriptions.write().await.remove(&update_id);
    match removed {
        // The callback is dropped here, outside the lock.
        Some(callback) => {
            drop(callback);
            Ok(())
        },
        None => Err(ClientError::Unknown("remove_update".to_string())),
    }
}
/// Registers `on_update` for the single response to `msg`, then sends `msg`.
///
/// # Errors
///
/// If sending fails, the handler is unregistered (and dropped). The send error
/// is then returned as `ClientError::Unknown`.
pub(crate) async fn subscribe_once(
    &self,
    msg: &Request,
    on_update: Box<dyn FnOnce(Response) -> ClientResult<()> + Send + Sync + 'static>,
) -> ClientResult<()> {
    let msg_id = msg.msg_id;
    self.subscriptions_once.write().await.insert(msg_id, on_update);
    tracing::debug!("SEND {}", msg);
    match (self.send)(msg).await {
        Ok(()) => Ok(()),
        Err(err) => {
            self.subscriptions_once.write().await.remove(&msg_id);
            Err(ClientError::Unknown(err.to_string()))
        },
    }
}
pub(crate) async fn subscribe<T, U>(&self, msg: &Request, on_update: T) -> ClientResult<()>
where
T: Fn(Response) -> U + Send + Sync + 'static,
U: Future<Output = Result<(), ClientError>> + Send + 'static,
{
self.subscriptions
.write()
.await
.insert(msg.msg_id, Box::new(move |x| Box::pin(on_update(x))));
tracing::debug!("SEND {}", msg);
if let Err(e) = (self.send)(msg).await {
self.subscriptions.write().await.remove(&msg.msg_id);
Err(ClientError::Unknown(e.to_string()))
} else {
Ok(())
}
}
/// Sends `req` and waits for the [`ClientResp`] carried by its reply.
///
/// # Errors
///
/// Fails in these cases:
/// - sending fails;
/// - the reply handler is dropped without delivering a payload, reported as
///   `ClientError::Unknown("Internal error for req ...")`.
///
/// A reply with no `client_resp` falls into the second case. It is also
/// reported as an error from `handle_response` instead of panicking there.
pub(crate) async fn oneshot(&self, req: &Request) -> ClientResult<ClientResp> {
    let (sender, receiver) = futures::channel::oneshot::channel::<ClientResp>();
    let on_update = Box::new(move |res: Response| match res.client_resp {
        Some(resp) => sender.send(resp).map_err(|x| x.into()),
        // `sender` is dropped when this closure returns, which resolves
        // `receiver` below with an error instead of leaving it pending.
        None => Err(ClientError::Unknown(format!(
            "Response to msg_id {} has no client_resp",
            res.msg_id
        ))),
    });
    self.subscribe_once(req, on_update).await?;
    receiver
        .await
        .map_err(|_| ClientError::Unknown(format!("Internal error for req {req}")))
}
/// Returns the server's [`Features`], requesting them on first use and
/// returning the cached copy afterwards.
///
/// The cache is shared by clones of this client. Its lock is held for the
/// whole request, so concurrent callers wait for the first request instead of
/// issuing their own.
///
/// # Errors
///
/// Fails if the request fails or the server replies with anything other than
/// `GetFeaturesResp`.
pub(crate) async fn get_features(&self) -> ClientResult<Features> {
    let mut cache = self.features.lock().await;
    if let Some(cached) = cache.as_ref() {
        return Ok(cached.clone());
    }

    let request = Request {
        msg_id: self.gen_id(),
        entity_id: "".to_owned(),
        client_req: Some(ClientReq::GetFeaturesReq(GetFeaturesReq {})),
    };
    let resp = match self.oneshot(&request).await? {
        ClientResp::GetFeaturesResp(resp) => resp,
        other => return Err(other.into()),
    };

    let features = Features(Arc::new(resp));
    *cache = Some(features.clone());
    Ok(features)
}
pub async fn table(&self, input: TableData, options: TableInitOptions) -> ClientResult<Table> {
let entity_id = match options.name.clone() {
Some(x) => x.to_owned(),
None => randid(),
};
if let TableData::View(view) = &input {
let window = ViewWindow::default();
let arrow = view.to_arrow(window).await?;
let mut table = self
.crate_table_inner(UpdateData::Arrow(arrow).into(), options.into(), entity_id)
.await?;
let table_ = table.clone();
let callback = asyncfn!(table_, update, async move |update: OnUpdateData| {
let update = UpdateData::Arrow(update.delta.expect("Malformed message").into());
let options = crate::UpdateOptions::default();
table_.update(update, options).await.unwrap_or_log();
});
let options = OnUpdateOptions {
mode: Some(OnUpdateMode::Row),
};
let on_update_token = view.on_update(callback, options).await?;
table.view_update_token = Some(on_update_token);
Ok(table)
} else {
self.crate_table_inner(input, options.into(), entity_id)
.await
}
}
/// Sends a `MakeTableReq` that creates table `entity_id` from `input` with
/// `options`, and returns a [`Table`] handle for it.
///
/// # Errors
///
/// Fails if:
/// - `options` cannot be converted to its protobuf form;
/// - the request fails;
/// - the server replies with anything other than `MakeTableResp`.
async fn crate_table_inner(
    &self,
    input: TableData,
    options: TableOptions,
    entity_id: String,
) -> ClientResult<Table> {
    let request = Request {
        msg_id: self.gen_id(),
        entity_id: entity_id.clone(),
        client_req: Some(ClientReq::MakeTableReq(MakeTableReq {
            data: Some(input.into()),
            options: Some(options.clone().try_into()?),
        })),
    };
    let response = self.oneshot(&request).await?;
    if matches!(response, ClientResp::MakeTableResp(_)) {
        Ok(Table::new(entity_id, self.clone(), options))
    } else {
        Err(response.into())
    }
}
pub async fn join(
&self,
left: TableRef,
right: TableRef,
on: &str,
options: JoinOptions,
) -> ClientResult<Table> {
let entity_id = options.name.unwrap_or_else(randid);
let join_type: JoinType = options.join_type.unwrap_or_default();
let right_on_column = options.right_on.unwrap_or_default();
let msg = Request {
msg_id: self.gen_id(),
entity_id: entity_id.clone(),
client_req: Some(ClientReq::MakeJoinTableReq(MakeJoinTableReq {
left_table_id: left.table_name().to_owned(),
right_table_id: right.table_name().to_owned(),
on_column: on.to_owned(),
join_type: join_type.into(),
right_on_column,
})),
};
let client = self.clone();
match self.oneshot(&msg).await? {
ClientResp::MakeJoinTableResp(_) => Ok(Table::new(entity_id, client, TableOptions {
index: Some(on.to_owned()),
limit: None,
})),
resp => Err(resp.into()),
}
}
/// Fetches metadata for every table hosted by the server, without subscribing
/// to changes.
///
/// # Errors
///
/// Fails if the request fails or the server replies with anything other than
/// `GetHostedTablesResp`.
async fn get_table_infos(&self) -> ClientResult<Vec<HostedTable>> {
    let request = Request {
        msg_id: self.gen_id(),
        entity_id: "".to_owned(),
        client_req: Some(ClientReq::GetHostedTablesReq(GetHostedTablesReq {
            subscribe: false,
        })),
    };
    let response = self.oneshot(&request).await?;
    match response {
        ClientResp::GetHostedTablesResp(hosted) => Ok(hosted.table_infos),
        other => Err(other.into()),
    }
}
/// Returns a [`Table`] handle for the hosted table named `entity_id`.
///
/// The handle uses the index and limit the server reports for that table.
///
/// # Errors
///
/// Fails if listing hosted tables fails, or with `ClientError::Unknown` if no
/// hosted table has that name.
pub async fn open_table(&self, entity_id: String) -> ClientResult<Table> {
    let info = self
        .get_table_infos()
        .await?
        .into_iter()
        .find(|info| info.entity_id == entity_id)
        .ok_or_else(|| ClientError::Unknown(format!("Unknown table \"{}\"", entity_id)))?;
    let options = TableOptions {
        index: info.index,
        limit: info.limit,
    };
    Ok(Table::new(entity_id, self.clone(), options))
}
pub async fn get_hosted_table_names(&self) -> ClientResult<Vec<String>> {
let msg = Request {
msg_id: self.gen_id(),
entity_id: "".to_owned(),
client_req: Some(ClientReq::GetHostedTablesReq(GetHostedTablesReq {
subscribe: false,
})),
};
match self.oneshot(&msg).await? {
ClientResp::GetHostedTablesResp(GetHostedTablesResp { table_infos }) => {
Ok(table_infos.into_iter().map(|i| i.entity_id).collect())
},
resp => Err(resp.into()),
}
}
/// Subscribes `on_update` to changes in the server's set of hosted tables.
///
/// Returns the subscription id (the request's `msg_id`), which is the id
/// [`Client::remove_hosted_tables_update`] takes.
///
/// `on_update` runs for each `GetHostedTablesResp` or empty response. Any
/// other response becomes an error returned from the dispatching callback.
pub async fn on_hosted_tables_update<T, U>(&self, on_update: T) -> ClientResult<u32>
where
    T: Fn() -> U + Send + Sync + 'static,
    U: Future<Output = ()> + Send + 'static,
{
    let msg = Request {
        msg_id: self.gen_id(),
        entity_id: "".to_owned(),
        client_req: Some(ClientReq::GetHostedTablesReq(GetHostedTablesReq {
            subscribe: true,
        })),
    };
    let on_update = Arc::new(on_update);
    let callback = asyncfn!(on_update, async move |resp: Response| {
        match resp.client_resp {
            None | Some(ClientResp::GetHostedTablesResp(_)) => {
                on_update().await;
                Ok(())
            },
            other => Err(other.into()),
        }
    });
    self.subscribe(&msg, callback).await?;
    Ok(msg.msg_id)
}
/// Cancels a subscription created by [`Client::on_hosted_tables_update`].
///
/// The local callback is removed first. The server is then asked to stop
/// sending updates for `update_id`.
///
/// # Errors
///
/// Fails if:
/// - `update_id` is not a registered subscription (no request is sent then);
/// - the request fails;
/// - the server replies with anything other than `RemoveHostedTablesUpdateResp`.
pub async fn remove_hosted_tables_update(&self, update_id: u32) -> ClientResult<()> {
    let request = Request {
        msg_id: self.gen_id(),
        entity_id: "".to_owned(),
        client_req: Some(ClientReq::RemoveHostedTablesUpdateReq(
            RemoveHostedTablesUpdateReq { id: update_id },
        )),
    };
    self.unsubscribe(update_id).await?;
    let response = self.oneshot(&request).await?;
    if matches!(response, ClientResp::RemoveHostedTablesUpdateResp(_)) {
        Ok(())
    } else {
        Err(response.into())
    }
}
/// Fetches the server's resource usage and adds client-side figures.
///
/// - `timestamp` is the client's wall-clock time in milliseconds since the
///   Unix epoch. It is always `None` on wasm targets.
/// - `client_used` and `client_heap` come from `crate::utils::get_used()` when
///   the `talc-allocator` feature is enabled, and are `None` otherwise.
///
/// # Errors
///
/// Fails if:
/// - the request fails;
/// - the server replies with anything other than `ServerSystemInfoResp`;
/// - on non-wasm targets, the system clock reads earlier than the Unix epoch.
pub async fn system_info(&self) -> ClientResult<SystemInfo> {
    let request = Request {
        msg_id: self.gen_id(),
        entity_id: "".to_string(),
        client_req: Some(ClientReq::ServerSystemInfoReq(ServerSystemInfoReq {})),
    };
    let server = match self.oneshot(&request).await? {
        ClientResp::ServerSystemInfoResp(server) => server,
        other => return Err(other.into()),
    };

    #[cfg(not(target_family = "wasm"))]
    let timestamp = {
        let since_epoch =
            std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH)?;
        Some(since_epoch.as_millis() as u64)
    };
    #[cfg(target_family = "wasm")]
    let timestamp = None;

    #[cfg(feature = "talc-allocator")]
    let (client_used, client_heap) = {
        let (used, heap) = crate::utils::get_used();
        (Some(used as u64), Some(heap as u64))
    };
    #[cfg(not(feature = "talc-allocator"))]
    let (client_used, client_heap) = (None, None);

    Ok(SystemInfo {
        heap_size: server.heap_size,
        used_size: server.used_size,
        cpu_time: server.cpu_time,
        cpu_time_epoch: server.cpu_time_epoch,
        timestamp,
        client_heap,
        client_used,
    })
}
}