use std::collections::HashMap;
use std::pin::Pin;
use std::sync::atomic::AtomicU32;
use std::sync::Arc;
use async_lock::RwLock;
use futures::Future;
use nanoid::*;
use prost::Message;
use proto::make_table_data::Data;
use crate::proto::request::ClientReq;
use crate::proto::response::ClientResp;
use crate::proto::*;
use crate::table::SystemInfo;
use crate::utils::*;
use crate::view::View;
use crate::{proto, Table, TableInitOptions};
/// Source data from which a new table can be created.
///
/// Each variant is lowered into the corresponding protobuf
/// `make_table_data::Data` case when sent to the server.
#[derive(Debug)]
pub enum TableData {
    /// Create an empty table from `(column name, column type)` pairs.
    Schema(Vec<(String, ColumnType)>),
    /// CSV text.
    Csv(String),
    /// Arrow-encoded bytes.
    Arrow(Vec<u8>),
    /// JSON text, sent as the `FromRows` payload.
    JsonRows(String),
    /// JSON text, sent as the `FromCols` payload.
    JsonColumns(String),
    /// An existing view; only its name is sent to the server.
    View(View),
}
impl From<TableData> for proto::make_table_data::Data {
    /// Lowers a user-facing `TableData` into the protobuf oneof carried by a
    /// `MakeTableReq`. Column types are transmitted as their `i32` value, and
    /// views are referenced by name.
    fn from(value: TableData) -> Self {
        match value {
            TableData::Schema(columns) => {
                let schema = columns
                    .into_iter()
                    .map(|(name, column_type)| KeyTypePair {
                        name,
                        r#type: column_type as i32,
                    })
                    .collect();
                Self::FromSchema(proto::Schema { schema })
            }
            TableData::Csv(csv) => Self::FromCsv(csv),
            TableData::Arrow(bytes) => Self::FromArrow(bytes),
            TableData::JsonRows(rows) => Self::FromRows(rows),
            TableData::JsonColumns(cols) => Self::FromCols(cols),
            TableData::View(view) => Self::FromView(view.name),
        }
    }
}
/// Boxed future returned by a `SendCallback`.
type SendFuture = Pin<Box<dyn Future<Output = ()> + Send + Sync + 'static>>;

/// Transport hook: receives the client and an outgoing request envelope.
type SendCallback = Arc<dyn Fn(&Client, &RequestEnvelope) -> SendFuture + Send + Sync + 'static>;

/// Handler consumed by the first response carrying its message id.
type OnceCallback = Box<dyn FnOnce(ClientResp) -> Result<(), ClientError> + Send + Sync + 'static>;

/// Handler that may be invoked for every response carrying its message id.
type ManyCallback = Box<dyn Fn(ClientResp) -> Result<(), ClientError> + Send + Sync + 'static>;

/// Shared, lock-protected map from message id to handler.
type Subscriptions<C> = Arc<RwLock<HashMap<u32, C>>>;
#[doc = include_str!("../../docs/client.md")]
#[derive(Clone)]
pub struct Client {
    /// Transport callback that delivers an outgoing request; installed by
    /// the constructors or replaced by `set_send_handler`.
    send: SendCallback,
    /// Counter handing out request `msg_id`s, shared by all clones.
    id_gen: Arc<AtomicU32>,
    /// Handlers run at most once, keyed by the `msg_id` they wait for.
    subscriptions_once: Subscriptions<OnceCallback>,
    /// Handlers run for every matching response until unsubscribed.
    subscriptions_many: Subscriptions<ManyCallback>,
}
impl std::fmt::Debug for Client {
    /// Shows only the id counter: the send callback and subscription tables
    /// hold boxed `dyn Fn` closures, which do not implement `Debug`.
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = formatter.debug_struct("Client");
        builder.field("id_gen", &self.id_gen);
        builder.finish()
    }
}
/// Serializes a request envelope to its protobuf wire bytes.
///
/// `encode_to_vec` sizes the buffer from `encoded_len` up front and cannot
/// fail, so there is no error path (and no `unwrap`) to worry about.
fn encode(req: &RequestEnvelope) -> Vec<u8> {
    req.encode_to_vec()
}
impl Client {
    /// Creates a client whose requests are delivered by an asynchronous
    /// `send_handler`.
    ///
    /// Every outgoing request envelope is protobuf-encoded and passed to the
    /// handler together with this client. Request methods await the returned
    /// future before waiting for their response; responses are expected to be
    /// fed back through `Client::receive`.
    pub fn new<T>(send_handler: T) -> Self
    where
        T: Fn(&Client, &Vec<u8>) -> Pin<Box<dyn Future<Output = ()> + Send + Sync + 'static>>
            + Send
            + Sync
            + 'static,
    {
        Client {
            id_gen: Arc::new(AtomicU32::new(1)),
            subscriptions_once: Subscriptions::default(),
            subscriptions_many: Subscriptions::default(),
            send: Arc::new(move |client, msg| send_handler(client, &encode(msg))),
        }
    }

    /// Creates a client around a synchronous `send_handler`.
    ///
    /// Each request is protobuf-encoded up front; the handler is called with
    /// the client and the encoded bytes when the send future is polled.
    pub fn new_sync<T>(send_handler: T) -> Self
    where
        T: Fn(&Client, &Vec<u8>) + Send + Sync + 'static + Clone,
    {
        Client {
            id_gen: Arc::new(AtomicU32::new(1)),
            subscriptions_once: Subscriptions::default(),
            subscriptions_many: Subscriptions::default(),
            send: Arc::new(move |client, msg| {
                // The returned future must be 'static and so needs owned data.
                // Encode now and move only the wire bytes into it, rather than
                // deep-cloning the envelope (which may carry a whole dataset).
                let client = client.clone();
                let bytes = encode(msg);
                let send_handler = send_handler.clone();
                Box::pin(async move {
                    send_handler(&client, &bytes);
                })
            }),
        }
    }

    /// Replaces the handler used for subsequent requests sent through this
    /// `Client` value. Clones made earlier keep their previous handler.
    pub fn set_send_handler<T>(&mut self, send_handler: T)
    where
        T: Fn(&Client, &Vec<u8>) -> Pin<Box<dyn Future<Output = ()> + Send + Sync + 'static>>
            + Send
            + Sync
            + 'static,
    {
        self.send = Arc::new(move |client, msg| send_handler(client, &encode(msg)))
    }

    /// Dispatches a protobuf-encoded `ResponseEnvelope` from the server.
    ///
    /// A one-shot handler registered for the response's `msg_id` is removed
    /// and invoked. Otherwise, a persistent subscription for that id is
    /// invoked. Responses matching neither are logged and ignored.
    ///
    /// # Errors
    /// Returns an error if decoding fails, if the envelope has no payload
    /// (`ClientError::Option`), or if the handler itself fails.
    ///
    /// # Panics
    /// Panics if a subscription table is locked concurrently.
    pub fn receive(&self, msg: &Vec<u8>) -> Result<(), ClientError> {
        let msg = ResponseEnvelope::decode(msg.as_slice())?;
        let payload = msg
            .payload
            .ok_or(ClientError::Option)?
            .client_resp
            .ok_or(ClientError::Option)?;

        // Remove the one-shot handler in its own statement so the write guard
        // is released before the handler runs and before the
        // `subscriptions_many` lock is taken.
        let once = self
            .subscriptions_once
            .try_write()
            .expect("subscriptions_once locked concurrently in receive")
            .remove(&msg.msg_id);

        if let Some(handler) = once {
            handler(payload)?;
        } else if let Some(handler) = self
            .subscriptions_many
            .try_read()
            .expect("subscriptions_many locked concurrently in receive")
            .get(&msg.msg_id)
        {
            handler(payload)?;
        } else {
            tracing::warn!("Received unsolicited server message");
        }

        Ok(())
    }

    #[doc = include_str!("../../docs/client/table.md")]
    pub async fn table(&self, input: TableData, options: TableInitOptions) -> ClientResult<Table> {
        // Use the caller-supplied table name, otherwise a random id.
        let entity_id = options.name.clone().unwrap_or_else(|| nanoid!());
        let msg = RequestEnvelope {
            msg_id: self.gen_id(),
            entity_id: entity_id.clone(),
            entity_type: EntityType::Table as i32,
            payload: Some(Request {
                client_req: Some(ClientReq::MakeTableReq(MakeTableReq {
                    data: Some(MakeTableData {
                        data: Some(input.into()),
                    }),
                    options: Some(options.clone().try_into()?),
                })),
            }),
        };

        let client = self.clone();
        match self.oneshot(&msg).await {
            ClientResp::MakeTableResp(_) => Ok(Table::new(entity_id, client, options)),
            resp => Err(resp.into()),
        }
    }

    #[doc = include_str!("../../docs/client/open_table.md")]
    pub async fn open_table(&self, entity_id: String) -> ClientResult<Table> {
        let names = self.get_hosted_table_names().await?;
        if names.contains(&entity_id) {
            let options = TableInitOptions::default();
            let client = self.clone();
            Ok(Table::new(entity_id, client, options))
        } else {
            Err(ClientError::Unknown("Unknown table".to_owned()))
        }
    }

    #[doc = include_str!("../../docs/client/get_hosted_table_names.md")]
    pub async fn get_hosted_table_names(&self) -> ClientResult<Vec<String>> {
        let msg = RequestEnvelope {
            msg_id: self.gen_id(),
            entity_id: "".to_owned(),
            entity_type: EntityType::Table as i32,
            payload: Some(Request {
                client_req: Some(ClientReq::GetHostedTablesReq(GetHostedTablesReq {})),
            }),
        };

        match self.oneshot(&msg).await {
            ClientResp::GetHostedTablesResp(GetHostedTablesResp { table_names }) => Ok(table_names),
            resp => Err(resp.into()),
        }
    }

    #[doc = include_str!("../../docs/client/system_info.md")]
    pub async fn system_info(&self) -> ClientResult<SystemInfo> {
        let msg = RequestEnvelope {
            msg_id: self.gen_id(),
            entity_id: "".to_string(),
            entity_type: EntityType::Table as i32,
            payload: Some(Request {
                client_req: Some(ClientReq::ServerSystemInfoReq(ServerSystemInfoReq {})),
            }),
        };

        match self.oneshot(&msg).await {
            ClientResp::ServerSystemInfoResp(resp) => Ok(resp.into()),
            resp => Err(resp.into()),
        }
    }

    /// Returns the next request id. Ids start at 1 and increase by one per
    /// call across all clones (`fetch_add` wraps on `u32` overflow).
    pub(crate) fn gen_id(&self) -> u32 {
        self.id_gen
            .fetch_add(1, std::sync::atomic::Ordering::Acquire)
    }

    /// Removes the persistent subscription registered under `update_id`.
    ///
    /// # Errors
    /// Returns `ClientError::Unknown` if no such subscription exists.
    pub(crate) fn unsubscribe(&self, update_id: u32) -> ClientResult<()> {
        let callback = self
            .subscriptions_many
            .try_write()
            .expect("subscriptions_many locked concurrently in unsubscribe")
            .remove(&update_id)
            .ok_or_else(|| ClientError::Unknown("remove_update".to_string()))?;

        // Drop the callback (and whatever it captured) only after the write
        // guard above has been released.
        drop(callback);
        Ok(())
    }

    /// Registers `on_update` to run once for the response to `msg`, then
    /// sends `msg`. Registration happens first so the response cannot
    /// arrive before its handler exists.
    pub(crate) async fn subscribe_once(
        &self,
        msg: &RequestEnvelope,
        on_update: Box<dyn FnOnce(ClientResp) -> ClientResult<()> + Send + Sync + 'static>,
    ) {
        // Wait for the lock instead of panicking if it is briefly contended.
        self.subscriptions_once
            .write()
            .await
            .insert(msg.msg_id, on_update);
        tracing::info!("SEND {}", msg);
        (self.send)(self, msg).await;
    }

    /// Registers `on_update` for every response carrying `msg.msg_id` (until
    /// `unsubscribe` is called with that id), then sends `msg`.
    pub(crate) async fn subscribe(
        &self,
        msg: &RequestEnvelope,
        on_update: Box<dyn Fn(ClientResp) -> ClientResult<()> + Send + Sync + 'static>,
    ) {
        self.subscriptions_many
            .write()
            .await
            .insert(msg.msg_id, on_update);
        tracing::info!("SEND {}", msg);
        (self.send)(self, msg).await;
    }

    /// Sends `msg` and resolves to the single response carrying its `msg_id`.
    ///
    /// # Panics
    /// Panics if the registered response handler is dropped without ever
    /// being invoked.
    pub(crate) async fn oneshot(&self, msg: &RequestEnvelope) -> ClientResp {
        let (sender, receiver) = futures::channel::oneshot::channel::<ClientResp>();
        let callback = Box::new(move |msg| sender.send(msg).map_err(|x| x.into()));
        self.subscriptions_once
            .write()
            .await
            .insert(msg.msg_id, callback);
        tracing::info!("SEND {}", msg);
        (self.send)(self, msg).await;
        receiver
            .await
            .expect("oneshot response handler dropped before a response arrived")
    }
}
/// Redacts a `MakeTableReq` data payload for logging.
///
/// Bulk table contents (Arrow bytes, CSV, row/column JSON) are swapped for a
/// fixed placeholder; any other variant (e.g. a view name or schema) is
/// returned unchanged.
fn replace(x: Data) -> Data {
    // One placeholder for every bulk variant, so an empty value in the logs
    // always means the upload really was empty.
    const REDACTED: &str = "<< redacted >>";
    match x {
        // Raw placeholder bytes; no protobuf framing around them.
        Data::FromArrow(_) => Data::FromArrow(REDACTED.as_bytes().to_vec()),
        Data::FromRows(_) => Data::FromRows(REDACTED.to_string()),
        Data::FromCols(_) => Data::FromCols(REDACTED.to_string()),
        Data::FromCsv(_) => Data::FromCsv(REDACTED.to_string()),
        x => x,
    }
}
impl std::fmt::Display for RequestEnvelope {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let mut msg = self.clone();
msg.payload = match msg.payload {
Some(Request {
client_req:
Some(request::ClientReq::MakeTableReq(MakeTableReq {
options,
data: Some(MakeTableData { data: Some(data) }),
})),
}) => Some(Request {
client_req: Some(request::ClientReq::MakeTableReq(MakeTableReq {
options,
data: Some(MakeTableData {
data: Some(replace(data)),
}),
})),
}),
x => x,
};
write!(f, "{}", serde_json::to_string(&msg).unwrap())
}
}