1use std::collections::HashMap;
14use std::error::Error;
15use std::sync::Arc;
16use std::sync::atomic::AtomicU32;
17
18use async_lock::{Mutex, RwLock};
19use futures::Future;
20use futures::future::{BoxFuture, LocalBoxFuture, join_all};
21use nanoid::*;
22use prost::Message;
23use serde::{Deserialize, Serialize};
24
25use crate::proto::request::ClientReq;
26use crate::proto::response::ClientResp;
27use crate::proto::{
28 self, ColumnType, GetFeaturesReq, GetFeaturesResp, GetHostedTablesReq, GetHostedTablesResp,
29 HostedTable, MakeTableReq, RemoveHostedTablesUpdateReq, Request, Response, ServerSystemInfoReq,
30};
31use crate::table::{Table, TableInitOptions, TableOptions};
32use crate::table_data::{TableData, UpdateData};
33use crate::utils::*;
34use crate::view::ViewWindow;
35use crate::{OnUpdateMode, OnUpdateOptions, asyncfn, clone};
36
37#[derive(Clone, Debug, Serialize, Deserialize)]
39pub struct SystemInfo {
40 pub heap_size: f64,
41}
42
43impl From<proto::ServerSystemInfoResp> for SystemInfo {
44 fn from(value: proto::ServerSystemInfoResp) -> Self {
45 SystemInfo {
46 heap_size: value.heap_size,
47 }
48 }
49}
50
51pub type Features = Arc<GetFeaturesResp>;
54
55impl GetFeaturesResp {
56 pub fn default_op(&self, col_type: ColumnType) -> Option<&str> {
57 self.filter_ops
58 .get(&(col_type as u32))?
59 .options
60 .first()
61 .map(|x| x.as_str())
62 }
63}
64
65type BoxFn<I, O> = Box<dyn Fn(I) -> O + Send + Sync + 'static>;
66type Box2Fn<I, J, O> = Box<dyn Fn(I, J) -> O + Send + Sync + 'static>;
67
68type Subscriptions<C> = Arc<RwLock<HashMap<u32, C>>>;
69type OnErrorCallback =
70 Box2Fn<Option<String>, Option<ReconnectCallback>, BoxFuture<'static, Result<(), ClientError>>>;
71type OnceCallback = Box<dyn FnOnce(Response) -> ClientResult<()> + Send + Sync + 'static>;
72type SendCallback = Arc<
73 dyn for<'a> Fn(&'a Request) -> BoxFuture<'a, Result<(), Box<dyn Error + Send + Sync>>>
74 + Send
75 + Sync
76 + 'static,
77>;
78
/// Transport abstraction used by [`Client::new`] to deliver serialized requests.
///
/// Implementors must be cheaply cloneable and shareable across threads.
pub trait ClientHandler: Send + Sync + Clone + 'static {
    /// Sends one already-encoded request `msg`.
    ///
    /// The returned future must be `Send`; it resolves to `Ok(())` once the
    /// implementation considers the bytes handed off, or to a boxed error.
    fn send_request(
        &self,
        msg: Vec<u8>,
    ) -> impl Future<Output = Result<(), Box<dyn Error + Send + Sync>>> + Send;
}
85
86#[derive(Clone)]
87#[doc = include_str!("../../docs/client.md")]
88pub struct Client {
89 features: Arc<Mutex<Option<Features>>>,
90 send: SendCallback,
91 id_gen: Arc<AtomicU32>,
92 subscriptions_errors: Subscriptions<OnErrorCallback>,
93 subscriptions_once: Subscriptions<OnceCallback>,
94 subscriptions: Subscriptions<BoxFn<Response, BoxFuture<'static, Result<(), ClientError>>>>,
95}
96
97impl std::fmt::Debug for Client {
98 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
99 f.debug_struct("Client")
100 .field("id_gen", &self.id_gen)
101 .finish()
102 }
103}
104
105pub type ReconnectCallback =
112 Arc<dyn Fn() -> LocalBoxFuture<'static, Result<(), Box<dyn Error>>> + Send + Sync>;
113
114impl Client {
115 pub fn new_with_callback<T, U>(send_request: T) -> Self
118 where
119 T: Fn(Vec<u8>) -> U + 'static + Sync + Send,
120 U: Future<Output = Result<(), Box<dyn Error + Send + Sync>>> + Send + 'static,
121 {
122 let send_request = Arc::new(send_request);
123 let send: SendCallback = Arc::new(move |req| {
124 let mut bytes: Vec<u8> = Vec::new();
125 req.encode(&mut bytes).unwrap();
126 let send_request = send_request.clone();
127 Box::pin(async move { send_request(bytes).await })
128 });
129
130 Client {
131 features: Arc::default(),
132 id_gen: Arc::new(AtomicU32::new(1)),
133 send,
134 subscriptions: Subscriptions::default(),
135 subscriptions_errors: Arc::default(),
136 subscriptions_once: Arc::default(),
137 }
138 }
139
140 pub fn new<T>(client_handler: T) -> Self
142 where
143 T: ClientHandler + 'static + Sync + Send,
144 {
145 Self::new_with_callback(asyncfn!(client_handler, async move |req| {
146 client_handler.send_request(req).await
147 }))
148 }
149
150 pub async fn handle_response<'a>(&'a self, msg: &'a [u8]) -> ClientResult<bool> {
157 let msg = Response::decode(msg)?;
158 tracing::debug!("RECV {}", msg);
159 let mut wr = self.subscriptions_once.write().await;
160 if let Some(handler) = (*wr).remove(&msg.msg_id) {
161 drop(wr);
162 handler(msg)?;
163 return Ok(true);
164 } else if let Some(handler) = self.subscriptions.try_read().unwrap().get(&msg.msg_id) {
165 drop(wr);
166 handler(msg).await?;
167 return Ok(true);
168 }
169
170 tracing::warn!("Received unsolicited server message");
171 Ok(false)
172 }
173
174 pub async fn handle_error<T, U>(
200 &self,
201 message: Option<String>,
202 reconnect: Option<T>,
203 ) -> ClientResult<()>
204 where
205 T: Fn() -> U + Clone + Send + Sync + 'static,
206 U: Future<Output = ClientResult<()>>,
207 {
208 let subs = self.subscriptions_errors.read().await;
209 let tasks = join_all(subs.values().map(|callback| {
210 callback(
211 message.clone(),
212 reconnect.clone().map(move |f| {
213 Arc::new(move || {
214 clone!(f);
215 Box::pin(async move { Ok(f().await?) }) as LocalBoxFuture<'static, _>
216 }) as ReconnectCallback
217 }),
218 )
219 }));
220
221 tasks.await.into_iter().collect::<Result<(), _>>()?;
222 Ok(())
223 }
224
225 pub async fn on_error<T, U, V>(&self, on_error: T) -> ClientResult<u32>
226 where
227 T: Fn(Option<String>, Option<ReconnectCallback>) -> U + Clone + Send + Sync + 'static,
228 U: Future<Output = V> + Send + 'static,
229 V: Into<Result<(), ClientError>> + Sync + 'static,
230 {
231 let id = self.gen_id();
232 let callback = asyncfn!(on_error, async move |x, y| on_error(x, y).await.into());
233 self.subscriptions_errors
234 .write()
235 .await
236 .insert(id, Box::new(move |x, y| Box::pin(callback(x, y))));
237
238 Ok(id)
239 }
240
241 pub async fn init(&self) -> ClientResult<()> {
242 let msg = Request {
243 msg_id: self.gen_id(),
244 entity_id: "".to_owned(),
245 client_req: Some(ClientReq::GetFeaturesReq(GetFeaturesReq {})),
246 };
247
248 *self.features.lock().await = Some(Arc::new(match self.oneshot(&msg).await? {
249 ClientResp::GetFeaturesResp(features) => Ok(features),
250 resp => Err(resp),
251 }?));
252
253 Ok(())
254 }
255
256 pub(crate) fn gen_id(&self) -> u32 {
258 self.id_gen
259 .fetch_add(1, std::sync::atomic::Ordering::Acquire)
260 }
261
262 pub(crate) async fn unsubscribe(&self, update_id: u32) -> ClientResult<()> {
263 let callback = self
264 .subscriptions
265 .write()
266 .await
267 .remove(&update_id)
268 .ok_or(ClientError::Unknown("remove_update".to_string()))?;
269
270 drop(callback);
271 Ok(())
272 }
273
274 pub(crate) async fn subscribe_once(
276 &self,
277 msg: &Request,
278 on_update: Box<dyn FnOnce(Response) -> ClientResult<()> + Send + Sync + 'static>,
279 ) -> ClientResult<()> {
280 self.subscriptions_once
281 .write()
282 .await
283 .insert(msg.msg_id, on_update);
284
285 tracing::debug!("SEND {}", msg);
286 if let Err(e) = (self.send)(msg).await {
287 self.subscriptions_once.write().await.remove(&msg.msg_id);
288 Err(ClientError::Unknown(e.to_string()))
289 } else {
290 Ok(())
291 }
292 }
293
294 pub(crate) async fn subscribe<T, U>(&self, msg: &Request, on_update: T) -> ClientResult<()>
313 where
314 T: Fn(Response) -> U + Send + Sync + 'static,
315 U: Future<Output = Result<(), ClientError>> + Send + 'static,
316 {
317 self.subscriptions
318 .write()
319 .await
320 .insert(msg.msg_id, Box::new(move |x| Box::pin(on_update(x))));
321
322 tracing::debug!("SEND {}", msg);
323 if let Err(e) = (self.send)(msg).await {
324 self.subscriptions.write().await.remove(&msg.msg_id);
325 Err(ClientError::Unknown(e.to_string()))
326 } else {
327 Ok(())
328 }
329 }
330
331 pub(crate) async fn oneshot(&self, msg: &Request) -> ClientResult<ClientResp> {
334 let (sender, receiver) = futures::channel::oneshot::channel::<ClientResp>();
335 let on_update = Box::new(move |msg: Response| {
336 sender.send(msg.client_resp.unwrap()).map_err(|x| x.into())
337 });
338
339 self.subscribe_once(msg, on_update).await?;
340 receiver
341 .await
342 .map_err(|_| ClientError::Unknown("Internal error".to_owned()))
343 }
344
345 pub(crate) fn get_features(&self) -> ClientResult<Features> {
346 Ok(self
347 .features
348 .try_lock()
349 .ok_or(ClientError::NotInitialized)?
350 .as_ref()
351 .ok_or(ClientError::NotInitialized)?
352 .clone())
353 }
354
355 #[doc = include_str!("../../docs/client/table.md")]
356 pub async fn table(&self, input: TableData, options: TableInitOptions) -> ClientResult<Table> {
357 let entity_id = match options.name.clone() {
358 Some(x) => x.to_owned(),
359 None => nanoid!(),
360 };
361
362 if let TableData::View(view) = &input {
363 let window = ViewWindow::default();
364 let arrow = view.to_arrow(window).await?;
365 let mut table = self
366 .crate_table_inner(UpdateData::Arrow(arrow).into(), options.into(), entity_id)
367 .await?;
368
369 let table_ = table.clone();
370 let callback = asyncfn!(
371 table_,
372 update,
373 async move |update: crate::proto::ViewOnUpdateResp| {
374 let update = UpdateData::Arrow(update.delta.expect("Malformed message").into());
375 let options = crate::UpdateOptions::default();
376 table_.update(update, options).await.unwrap_or_log();
377 }
378 );
379
380 let options = OnUpdateOptions {
381 mode: Some(OnUpdateMode::Row),
382 };
383
384 let on_update_token = view.on_update(callback, options).await?;
385 table.view_update_token = Some(on_update_token);
386 Ok(table)
387 } else {
388 self.crate_table_inner(input, options.into(), entity_id)
389 .await
390 }
391 }
392
393 async fn crate_table_inner(
394 &self,
395 input: TableData,
396 options: TableOptions,
397 entity_id: String,
398 ) -> ClientResult<Table> {
399 let msg = Request {
400 msg_id: self.gen_id(),
401 entity_id: entity_id.clone(),
402 client_req: Some(ClientReq::MakeTableReq(MakeTableReq {
403 data: Some(input.into()),
404 options: Some(options.clone().try_into()?),
405 })),
406 };
407
408 let client = self.clone();
409 match self.oneshot(&msg).await? {
410 ClientResp::MakeTableResp(_) => Ok(Table::new(entity_id, client, options)),
411 resp => Err(resp.into()),
412 }
413 }
414
415 async fn get_table_infos(&self) -> ClientResult<Vec<HostedTable>> {
416 let msg = Request {
417 msg_id: self.gen_id(),
418 entity_id: "".to_owned(),
419 client_req: Some(ClientReq::GetHostedTablesReq(GetHostedTablesReq {
420 subscribe: false,
421 })),
422 };
423
424 match self.oneshot(&msg).await? {
425 ClientResp::GetHostedTablesResp(GetHostedTablesResp { table_infos }) => Ok(table_infos),
426 resp => Err(resp.into()),
427 }
428 }
429
430 #[doc = include_str!("../../docs/client/open_table.md")]
431 pub async fn open_table(&self, entity_id: String) -> ClientResult<Table> {
432 let infos = self.get_table_infos().await?;
433
434 if let Some(info) = infos.into_iter().find(|i| i.entity_id == entity_id) {
436 let options = TableOptions {
437 index: info.index,
438 limit: info.limit,
439 };
440
441 let client = self.clone();
442 Ok(Table::new(entity_id, client, options))
443 } else {
444 Err(ClientError::Unknown("Unknown table".to_owned()))
445 }
446 }
447
448 #[doc = include_str!("../../docs/client/get_hosted_table_names.md")]
449 pub async fn get_hosted_table_names(&self) -> ClientResult<Vec<String>> {
450 let msg = Request {
451 msg_id: self.gen_id(),
452 entity_id: "".to_owned(),
453 client_req: Some(ClientReq::GetHostedTablesReq(GetHostedTablesReq {
454 subscribe: false,
455 })),
456 };
457
458 match self.oneshot(&msg).await? {
459 ClientResp::GetHostedTablesResp(GetHostedTablesResp { table_infos }) => {
460 Ok(table_infos.into_iter().map(|i| i.entity_id).collect())
461 },
462 resp => Err(resp.into()),
463 }
464 }
465
466 #[doc = include_str!("../../docs/client/on_hosted_tables_update.md")]
467 pub async fn on_hosted_tables_update<T, U>(&self, on_update: T) -> ClientResult<u32>
468 where
469 T: Fn() -> U + Send + Sync + 'static,
470 U: Future<Output = ()> + Send + 'static,
471 {
472 let on_update = Arc::new(on_update);
473 let callback = asyncfn!(on_update, async move |resp: Response| {
474 match resp.client_resp {
475 Some(ClientResp::GetHostedTablesResp(_)) | None => {
476 on_update().await;
477 Ok(())
478 },
479 resp => Err(ClientError::OptionResponseFailed(resp.into())),
480 }
481 });
482
483 let msg = Request {
484 msg_id: self.gen_id(),
485 entity_id: "".to_owned(),
486 client_req: Some(ClientReq::GetHostedTablesReq(GetHostedTablesReq {
487 subscribe: true,
488 })),
489 };
490
491 self.subscribe(&msg, callback).await?;
492 Ok(msg.msg_id)
493 }
494
495 #[doc = include_str!("../../docs/client/remove_hosted_tables_update.md")]
496 pub async fn remove_hosted_tables_update(&self, update_id: u32) -> ClientResult<()> {
497 let msg = Request {
498 msg_id: self.gen_id(),
499 entity_id: "".to_owned(),
500 client_req: Some(ClientReq::RemoveHostedTablesUpdateReq(
501 RemoveHostedTablesUpdateReq { id: update_id },
502 )),
503 };
504
505 self.unsubscribe(update_id).await?;
506 match self.oneshot(&msg).await? {
507 ClientResp::RemoveHostedTablesUpdateResp(_) => Ok(()),
508 resp => Err(resp.into()),
509 }
510 }
511
512 #[doc = include_str!("../../docs/client/system_info.md")]
513 pub async fn system_info(&self) -> ClientResult<SystemInfo> {
514 let msg = Request {
515 msg_id: self.gen_id(),
516 entity_id: "".to_string(),
517 client_req: Some(ClientReq::ServerSystemInfoReq(ServerSystemInfoReq {})),
518 };
519
520 match self.oneshot(&msg).await? {
521 ClientResp::ServerSystemInfoResp(resp) => Ok(resp.into()),
522 resp => Err(resp.into()),
523 }
524 }
525}