1use std::collections::HashMap;
14use std::error::Error;
15use std::ops::Deref;
16use std::sync::Arc;
17
18use async_lock::{Mutex, RwLock};
19use futures::Future;
20use futures::future::{BoxFuture, LocalBoxFuture, join_all};
21use prost::Message;
22use serde::{Deserialize, Serialize};
23use ts_rs::TS;
24
25use crate::proto::request::ClientReq;
26use crate::proto::response::ClientResp;
27use crate::proto::{
28 ColumnType, GetFeaturesReq, GetFeaturesResp, GetHostedTablesReq, GetHostedTablesResp,
29 HostedTable, MakeTableReq, RemoveHostedTablesUpdateReq, Request, Response, ServerError,
30 ServerSystemInfoReq,
31};
32use crate::table::{Table, TableInitOptions, TableOptions};
33use crate::table_data::{TableData, UpdateData};
34use crate::utils::*;
35use crate::view::{OnUpdateData, ViewWindow};
36use crate::{OnUpdateMode, OnUpdateOptions, asyncfn, clone};
37
/// Memory / CPU usage snapshot assembled by [`Client::system_info`] from the
/// server's `ServerSystemInfoResp` plus client-side measurements.
#[derive(Clone, Debug, Serialize, Deserialize, TS)]
pub struct SystemInfo {
    /// Heap size reported by the server (presumably bytes — TODO confirm).
    pub heap_size: u64,

    /// Used heap size reported by the server (presumably bytes — TODO confirm).
    pub used_size: u64,

    /// CPU time reported by the server; units are not visible here — TODO confirm.
    pub cpu_time: u32,

    /// Server-reported companion value to `cpu_time`
    /// (NOTE(review): exact semantics not visible here — confirm).
    pub cpu_time_epoch: u32,

    /// Client wall-clock time in milliseconds since the UNIX epoch, sampled
    /// after the server response is received; always `None` on wasm targets.
    pub timestamp: Option<u64>,

    /// Client heap size from `crate::utils::get_used`; `None` unless the
    /// crate is built with the `talc-allocator` feature.
    pub client_heap: Option<u64>,

    /// Client used-heap size from `crate::utils::get_used`; `None` unless the
    /// crate is built with the `talc-allocator` feature.
    pub client_used: Option<u64>,
}
69
/// Cheaply cloneable, shared handle to the [`GetFeaturesResp`] cached by
/// [`Client::init`]; dereferences to the response itself.
#[derive(Clone, Default)]
pub struct Features(Arc<GetFeaturesResp>);
74
75impl Deref for Features {
76 type Target = GetFeaturesResp;
77
78 fn deref(&self) -> &Self::Target {
79 &self.0
80 }
81}
82
83impl GetFeaturesResp {
84 pub fn default_op(&self, col_type: ColumnType) -> Option<&str> {
85 self.filter_ops
86 .get(&(col_type as u32))?
87 .options
88 .first()
89 .map(|x| x.as_str())
90 }
91}
92
/// Boxed, thread-safe single-argument callback.
type BoxFn<I, O> = Box<dyn Fn(I) -> O + Send + Sync + 'static>;
/// Boxed, thread-safe two-argument callback.
type Box2Fn<I, J, O> = Box<dyn Fn(I, J) -> O + Send + Sync + 'static>;

/// Shared, async-locked table of callbacks keyed by a `u32` id (request
/// `msg_id`s or ids from `Client::gen_id`).
type Subscriptions<C> = Arc<RwLock<HashMap<u32, C>>>;
/// Callback registered via `Client::on_error`; receives the error and an
/// optional reconnect hook.
type OnErrorCallback =
    Box2Fn<ClientError, Option<ReconnectCallback>, BoxFuture<'static, Result<(), ClientError>>>;

/// Single-use response handler; removed from its table before it is invoked.
type OnceCallback = Box<dyn FnOnce(Response) -> ClientResult<()> + Send + Sync + 'static>;
/// Transport hook that encodes and sends one [`Request`]; the returned future
/// may borrow the request for the duration of the send.
type SendCallback = Arc<
    dyn for<'a> Fn(&'a Request) -> BoxFuture<'a, Result<(), Box<dyn Error + Send + Sync>>>
        + Send
        + Sync
        + 'static,
>;
107
/// Transport used by [`Client::new`] to deliver requests to the server.
pub trait ClientHandler: Clone + Send + Sync + 'static {
    /// Sends one already protobuf-encoded [`Request`] (`msg`) to the server.
    fn send_request(
        &self,
        msg: Vec<u8>,
    ) -> impl Future<Output = Result<(), Box<dyn Error + Send + Sync>>> + Send;
}
115
116mod name_registry {
117 use std::collections::HashSet;
118 use std::sync::{Arc, LazyLock, Mutex};
119
120 use crate::ClientError;
121 use crate::view::ClientResult;
122
123 static CLIENT_ID_GEN: LazyLock<Arc<Mutex<u32>>> = LazyLock::new(Arc::default);
124 static REGISTERED_CLIENTS: LazyLock<Arc<Mutex<HashSet<String>>>> = LazyLock::new(Arc::default);
125
126 pub(crate) fn generate_name(name: Option<&str>) -> ClientResult<String> {
127 if let Some(name) = name {
128 if let Some(name) = REGISTERED_CLIENTS
129 .lock()
130 .map_err(ClientError::from)?
131 .get(name)
132 {
133 Err(ClientError::DuplicateNameError(name.to_owned()))
134 } else {
135 Ok(name.to_owned())
136 }
137 } else {
138 let mut guard = CLIENT_ID_GEN.lock()?;
139 *guard += 1;
140 Ok(format!("client-{guard}"))
141 }
142 }
143}
144
/// Shared, cloneable handle to a reconnect routine. [`Client::handle_error`]
/// passes one (when a reconnect function is supplied) to every `on_error`
/// callback; invoking it runs the wrapped reconnect future.
#[derive(Clone)]
// The full `Arc<dyn Fn() -> LocalBoxFuture<..>>` signature is intentionally
// spelled out so the public type stays self-describing.
#[allow(clippy::type_complexity)]
pub struct ReconnectCallback(
    Arc<dyn Fn() -> LocalBoxFuture<'static, Result<(), Box<dyn Error>>> + Send + Sync>,
);
156
157impl Deref for ReconnectCallback {
158 type Target = dyn Fn() -> LocalBoxFuture<'static, Result<(), Box<dyn Error>>> + Send + Sync;
159
160 fn deref(&self) -> &Self::Target {
161 &*self.0
162 }
163}
164
165impl ReconnectCallback {
166 pub fn new(
167 f: impl Fn() -> LocalBoxFuture<'static, Result<(), Box<dyn Error>>> + Send + Sync + 'static,
168 ) -> Self {
169 ReconnectCallback(Arc::new(f))
170 }
171}
172
/// RPC client: encodes [`Request`]s, sends them through a caller-supplied
/// transport, and dispatches decoded [`Response`]s (fed in through
/// [`Client::handle_response`]) to registered callbacks.
///
/// All state lives behind `Arc`s, so clones are cheap and share the same
/// subscriptions.
#[derive(Clone)]
pub struct Client {
    /// Name from `name_registry::generate_name`; `PartialEq` compares on it.
    name: Arc<String>,
    /// Server features cached by [`Client::init`]; `None` until then.
    features: Arc<Mutex<Option<Features>>>,
    /// Transport hook that encodes and sends a [`Request`].
    send: SendCallback,
    /// Source of message ids and callback ids (see `gen_id`).
    id_gen: IDGen,
    /// Callbacks registered through `on_error`, keyed by their id.
    subscriptions_errors: Subscriptions<OnErrorCallback>,
    /// Single-use response handlers keyed by request `msg_id`.
    subscriptions_once: Subscriptions<OnceCallback>,
    /// Persistent response handlers keyed by request `msg_id`.
    subscriptions: Subscriptions<BoxFn<Response, BoxFuture<'static, Result<(), ClientError>>>>,
}
199
200impl PartialEq for Client {
201 fn eq(&self, other: &Self) -> bool {
202 self.name == other.name
203 }
204}
205
206impl std::fmt::Debug for Client {
207 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
208 f.debug_struct("Client").finish()
209 }
210}
211
212impl Client {
213 pub fn new_with_callback<T, U>(name: Option<&str>, send_request: T) -> ClientResult<Self>
216 where
217 T: Fn(Vec<u8>) -> U + 'static + Sync + Send,
218 U: Future<Output = Result<(), Box<dyn Error + Send + Sync>>> + Send + 'static,
219 {
220 let name = name_registry::generate_name(name)?;
221 let send_request = Arc::new(send_request);
222 let send: SendCallback = Arc::new(move |req| {
223 let mut bytes: Vec<u8> = Vec::new();
224 req.encode(&mut bytes).unwrap();
225 let send_request = send_request.clone();
226 Box::pin(async move { send_request(bytes).await })
227 });
228
229 Ok(Client {
230 name: Arc::new(name),
231 features: Arc::default(),
232 id_gen: IDGen::default(),
233 send,
234 subscriptions: Subscriptions::default(),
235 subscriptions_errors: Arc::default(),
236 subscriptions_once: Arc::default(),
237 })
238 }
239
240 pub fn new<T>(name: Option<&str>, client_handler: T) -> ClientResult<Self>
242 where
243 T: ClientHandler + 'static + Sync + Send,
244 {
245 Self::new_with_callback(
246 name,
247 asyncfn!(client_handler, async move |req| {
248 client_handler.send_request(req).await
249 }),
250 )
251 }
252
253 pub fn get_name(&self) -> &'_ str {
254 self.name.as_str()
255 }
256
257 pub async fn handle_response<'a>(&'a self, msg: &'a [u8]) -> ClientResult<bool> {
264 let msg = Response::decode(msg)?;
265 tracing::debug!("RECV {}", msg);
266 let mut wr = self.subscriptions_once.write().await;
267 if let Some(handler) = (*wr).remove(&msg.msg_id) {
268 drop(wr);
269 handler(msg)?;
270 return Ok(true);
271 } else if let Some(handler) = self.subscriptions.try_read().unwrap().get(&msg.msg_id) {
272 drop(wr);
273 handler(msg).await?;
274 return Ok(true);
275 }
276
277 if let Response {
278 client_resp: Some(ClientResp::ServerError(ServerError { message, .. })),
279 ..
280 } = &msg
281 {
282 tracing::error!("{}", message);
283 } else {
284 tracing::debug!("Received unsolicited server response: {}", msg);
285 }
286
287 Ok(false)
288 }
289
290 pub async fn handle_error<T, U>(
292 &self,
293 message: ClientError,
294 reconnect: Option<T>,
295 ) -> ClientResult<()>
296 where
297 T: Fn() -> U + Clone + Send + Sync + 'static,
298 U: Future<Output = ClientResult<()>>,
299 {
300 let subs = self.subscriptions_errors.read().await;
301 let tasks = join_all(subs.values().map(|callback| {
302 callback(
303 message.clone(),
304 reconnect.clone().map(move |f| {
305 ReconnectCallback(Arc::new(move || {
306 clone!(f);
307 Box::pin(async move { Ok(f().await?) }) as LocalBoxFuture<'static, _>
308 }))
309 }),
310 )
311 }));
312
313 tasks.await.into_iter().collect::<Result<(), _>>()?;
314 self.close_and_error_subscriptions(&message).await
315 }
316
317 async fn close_and_error_subscriptions(&self, message: &ClientError) -> ClientResult<()> {
322 let synthetic_error = |msg_id| Response {
323 msg_id,
324 entity_id: "".to_string(),
325 client_resp: Some(ClientResp::ServerError(ServerError {
326 message: format!("{message}"),
327 status_code: 2,
328 })),
329 };
330
331 self.subscriptions.write().await.clear();
332 let callbacks_once = self
333 .subscriptions_once
334 .write()
335 .await
336 .drain()
337 .collect::<Vec<_>>();
338
339 callbacks_once
340 .into_iter()
341 .try_for_each(|(msg_id, f)| f(synthetic_error(msg_id)))
342 }
343
344 pub async fn on_error<T, U, V>(&self, on_error: T) -> ClientResult<u32>
345 where
346 T: Fn(ClientError, Option<ReconnectCallback>) -> U + Clone + Send + Sync + 'static,
347 U: Future<Output = V> + Send + 'static,
348 V: Into<Result<(), ClientError>> + Sync + 'static,
349 {
350 let id = self.gen_id();
351 let callback = asyncfn!(on_error, async move |x, y| on_error(x, y).await.into());
352 self.subscriptions_errors
353 .write()
354 .await
355 .insert(id, Box::new(move |x, y| Box::pin(callback(x, y))));
356
357 Ok(id)
358 }
359
360 pub async fn init(&self) -> ClientResult<()> {
361 let msg = Request {
362 msg_id: self.gen_id(),
363 entity_id: "".to_owned(),
364 client_req: Some(ClientReq::GetFeaturesReq(GetFeaturesReq {})),
365 };
366
367 *self.features.lock().await = Some(Features(Arc::new(match self.oneshot(&msg).await? {
368 ClientResp::GetFeaturesResp(features) => Ok(features),
369 resp => Err(resp),
370 }?)));
371
372 Ok(())
373 }
374
375 pub(crate) fn gen_id(&self) -> u32 {
377 self.id_gen.next()
378 }
379
380 pub(crate) async fn unsubscribe(&self, update_id: u32) -> ClientResult<()> {
381 let callback = self
382 .subscriptions
383 .write()
384 .await
385 .remove(&update_id)
386 .ok_or(ClientError::Unknown("remove_update".to_string()))?;
387
388 drop(callback);
389 Ok(())
390 }
391
392 pub(crate) async fn subscribe_once(
394 &self,
395 msg: &Request,
396 on_update: Box<dyn FnOnce(Response) -> ClientResult<()> + Send + Sync + 'static>,
397 ) -> ClientResult<()> {
398 self.subscriptions_once
399 .write()
400 .await
401 .insert(msg.msg_id, on_update);
402
403 tracing::debug!("SEND {}", msg);
404 if let Err(e) = (self.send)(msg).await {
405 self.subscriptions_once.write().await.remove(&msg.msg_id);
406 Err(ClientError::Unknown(e.to_string()))
407 } else {
408 Ok(())
409 }
410 }
411
412 pub(crate) async fn subscribe<T, U>(&self, msg: &Request, on_update: T) -> ClientResult<()>
413 where
414 T: Fn(Response) -> U + Send + Sync + 'static,
415 U: Future<Output = Result<(), ClientError>> + Send + 'static,
416 {
417 self.subscriptions
418 .write()
419 .await
420 .insert(msg.msg_id, Box::new(move |x| Box::pin(on_update(x))));
421
422 tracing::debug!("SEND {}", msg);
423 if let Err(e) = (self.send)(msg).await {
424 self.subscriptions.write().await.remove(&msg.msg_id);
425 Err(ClientError::Unknown(e.to_string()))
426 } else {
427 Ok(())
428 }
429 }
430
431 pub(crate) async fn oneshot(&self, req: &Request) -> ClientResult<ClientResp> {
434 let (sender, receiver) = futures::channel::oneshot::channel::<ClientResp>();
435 let on_update = Box::new(move |res: Response| {
436 sender.send(res.client_resp.unwrap()).map_err(|x| x.into())
437 });
438
439 self.subscribe_once(req, on_update).await?;
440 receiver
441 .await
442 .map_err(|_| ClientError::Unknown(format!("Internal error for req {req}")))
443 }
444
445 pub(crate) fn get_features(&self) -> ClientResult<Features> {
446 let features = self
447 .features
448 .try_lock()
449 .ok_or(ClientError::NotInitialized)?
450 .as_ref()
451 .ok_or(ClientError::NotInitialized)?
452 .clone();
453
454 Ok(features)
455 }
456
457 pub async fn table(&self, input: TableData, options: TableInitOptions) -> ClientResult<Table> {
510 let entity_id = match options.name.clone() {
511 Some(x) => x.to_owned(),
512 None => randid(),
513 };
514
515 if let TableData::View(view) = &input {
516 let window = ViewWindow::default();
517 let arrow = view.to_arrow(window).await?;
518 let mut table = self
519 .crate_table_inner(UpdateData::Arrow(arrow).into(), options.into(), entity_id)
520 .await?;
521
522 let table_ = table.clone();
523 let callback = asyncfn!(table_, update, async move |update: OnUpdateData| {
524 let update = UpdateData::Arrow(update.delta.expect("Malformed message").into());
525 let options = crate::UpdateOptions::default();
526 table_.update(update, options).await.unwrap_or_log();
527 });
528
529 let options = OnUpdateOptions {
530 mode: Some(OnUpdateMode::Row),
531 };
532
533 let on_update_token = view.on_update(callback, options).await?;
534 table.view_update_token = Some(on_update_token);
535 Ok(table)
536 } else {
537 self.crate_table_inner(input, options.into(), entity_id)
538 .await
539 }
540 }
541
542 async fn crate_table_inner(
543 &self,
544 input: TableData,
545 options: TableOptions,
546 entity_id: String,
547 ) -> ClientResult<Table> {
548 let msg = Request {
549 msg_id: self.gen_id(),
550 entity_id: entity_id.clone(),
551 client_req: Some(ClientReq::MakeTableReq(MakeTableReq {
552 data: Some(input.into()),
553 options: Some(options.clone().try_into()?),
554 })),
555 };
556
557 let client = self.clone();
558 match self.oneshot(&msg).await? {
559 ClientResp::MakeTableResp(_) => Ok(Table::new(entity_id, client, options)),
560 resp => Err(resp.into()),
561 }
562 }
563
564 async fn get_table_infos(&self) -> ClientResult<Vec<HostedTable>> {
565 let msg = Request {
566 msg_id: self.gen_id(),
567 entity_id: "".to_owned(),
568 client_req: Some(ClientReq::GetHostedTablesReq(GetHostedTablesReq {
569 subscribe: false,
570 })),
571 };
572
573 match self.oneshot(&msg).await? {
574 ClientResp::GetHostedTablesResp(GetHostedTablesResp { table_infos }) => Ok(table_infos),
575 resp => Err(resp.into()),
576 }
577 }
578
579 pub async fn open_table(&self, entity_id: String) -> ClientResult<Table> {
592 let infos = self.get_table_infos().await?;
593
594 if let Some(info) = infos.into_iter().find(|i| i.entity_id == entity_id) {
596 let options = TableOptions {
597 index: info.index,
598 limit: info.limit,
599 };
600
601 let client = self.clone();
602 Ok(Table::new(entity_id, client, options))
603 } else {
604 Err(ClientError::Unknown("Unknown table".to_owned()))
605 }
606 }
607
608 pub async fn get_hosted_table_names(&self) -> ClientResult<Vec<String>> {
621 let msg = Request {
622 msg_id: self.gen_id(),
623 entity_id: "".to_owned(),
624 client_req: Some(ClientReq::GetHostedTablesReq(GetHostedTablesReq {
625 subscribe: false,
626 })),
627 };
628
629 match self.oneshot(&msg).await? {
630 ClientResp::GetHostedTablesResp(GetHostedTablesResp { table_infos }) => {
631 Ok(table_infos.into_iter().map(|i| i.entity_id).collect())
632 },
633 resp => Err(resp.into()),
634 }
635 }
636
637 pub async fn on_hosted_tables_update<T, U>(&self, on_update: T) -> ClientResult<u32>
641 where
642 T: Fn() -> U + Send + Sync + 'static,
643 U: Future<Output = ()> + Send + 'static,
644 {
645 let on_update = Arc::new(on_update);
646 let callback = asyncfn!(on_update, async move |resp: Response| {
647 match resp.client_resp {
648 Some(ClientResp::GetHostedTablesResp(_)) | None => {
649 on_update().await;
650 Ok(())
651 },
652 resp => Err(resp.into()),
653 }
654 });
655
656 let msg = Request {
657 msg_id: self.gen_id(),
658 entity_id: "".to_owned(),
659 client_req: Some(ClientReq::GetHostedTablesReq(GetHostedTablesReq {
660 subscribe: true,
661 })),
662 };
663
664 self.subscribe(&msg, callback).await?;
665 Ok(msg.msg_id)
666 }
667
668 pub async fn remove_hosted_tables_update(&self, update_id: u32) -> ClientResult<()> {
671 let msg = Request {
672 msg_id: self.gen_id(),
673 entity_id: "".to_owned(),
674 client_req: Some(ClientReq::RemoveHostedTablesUpdateReq(
675 RemoveHostedTablesUpdateReq { id: update_id },
676 )),
677 };
678
679 self.unsubscribe(update_id).await?;
680 match self.oneshot(&msg).await? {
681 ClientResp::RemoveHostedTablesUpdateResp(_) => Ok(()),
682 resp => Err(resp.into()),
683 }
684 }
685
686 pub async fn system_info(&self) -> ClientResult<SystemInfo> {
690 let msg = Request {
691 msg_id: self.gen_id(),
692 entity_id: "".to_string(),
693 client_req: Some(ClientReq::ServerSystemInfoReq(ServerSystemInfoReq {})),
694 };
695
696 match self.oneshot(&msg).await? {
697 ClientResp::ServerSystemInfoResp(resp) => {
698 #[cfg(not(target_family = "wasm"))]
699 let timestamp = Some(
700 std::time::SystemTime::now()
701 .duration_since(std::time::UNIX_EPOCH)?
702 .as_millis() as u64,
703 );
704
705 #[cfg(target_family = "wasm")]
706 let timestamp = None;
707
708 #[cfg(feature = "talc-allocator")]
709 let (client_used, client_heap) = {
710 let (client_used, client_heap) = crate::utils::get_used();
711 (Some(client_used as u64), Some(client_heap as u64))
712 };
713
714 #[cfg(not(feature = "talc-allocator"))]
715 let (client_used, client_heap) = (None, None);
716
717 let info = SystemInfo {
718 heap_size: resp.heap_size,
719 used_size: resp.used_size,
720 cpu_time: resp.cpu_time,
721 cpu_time_epoch: resp.cpu_time_epoch,
722 timestamp,
723 client_heap,
724 client_used,
725 };
726
727 Ok(info)
728 },
729 resp => Err(resp.into()),
730 }
731 }
732}