1use std::collections::HashMap;
14use std::error::Error;
15use std::ops::Deref;
16use std::sync::Arc;
17
18use async_lock::{Mutex, RwLock};
19use futures::Future;
20use futures::future::{BoxFuture, LocalBoxFuture, join_all};
21use prost::Message;
22use serde::{Deserialize, Serialize};
23use ts_rs::TS;
24
25use crate::proto::request::ClientReq;
26use crate::proto::response::ClientResp;
27use crate::proto::{
28 ColumnType, GetFeaturesReq, GetFeaturesResp, GetHostedTablesReq, GetHostedTablesResp,
29 HostedTable, MakeTableReq, RemoveHostedTablesUpdateReq, Request, Response, ServerError,
30 ServerSystemInfoReq,
31};
32use crate::table::{Table, TableInitOptions, TableOptions};
33use crate::table_data::{TableData, UpdateData};
34use crate::utils::*;
35use crate::view::{OnUpdateData, ViewWindow};
36use crate::{OnUpdateMode, OnUpdateOptions, asyncfn, clone};
37
/// Memory and CPU statistics returned by [`Client::system_info`].
///
/// `heap_size`, `used_size`, `cpu_time` and `cpu_time_epoch` are copied from
/// the server's system-info response; `timestamp`, `client_heap` and
/// `client_used` are filled in locally by the client. `T` (default `u64`) is
/// the numeric type of the size/timestamp fields; use [`SystemInfo::cast`] to
/// convert between numeric types.
#[derive(Clone, Debug, Serialize, Deserialize, TS)]
pub struct SystemInfo<T = u64> {
    /// Heap size as reported by the server.
    pub heap_size: T,

    /// Used heap as reported by the server.
    pub used_size: T,

    /// CPU time as reported by the server (units not visible here — TODO confirm).
    pub cpu_time: u32,

    /// Server-reported companion to `cpu_time`; presumably the reference
    /// point it is measured from — NOTE(review): confirm against the server.
    pub cpu_time_epoch: u32,

    /// Client wall-clock time, in milliseconds since the Unix epoch, taken when
    /// the response was processed. Always `None` on `wasm` targets.
    pub timestamp: Option<T>,

    /// Client heap size; only populated when built with the `talc-allocator`
    /// feature, otherwise `None`.
    pub client_heap: Option<T>,

    /// Client used memory; only populated when built with the `talc-allocator`
    /// feature, otherwise `None`.
    pub client_used: Option<T>,
}
69
70impl<U: Copy + 'static> SystemInfo<U> {
71 pub fn cast<T: Copy + 'static>(&self) -> SystemInfo<T>
74 where
75 U: num_traits::AsPrimitive<T>,
76 {
77 SystemInfo {
78 heap_size: self.heap_size.as_(),
79 used_size: self.used_size.as_(),
80 cpu_time: self.cpu_time,
81 cpu_time_epoch: self.cpu_time_epoch,
82 timestamp: self.timestamp.map(|x| x.as_()),
83 client_heap: self.client_heap.map(|x| x.as_()),
84 client_used: self.client_used.map(|x| x.as_()),
85 }
86 }
87}
88
89#[derive(Clone, Debug, Default)]
92pub struct Features(Arc<GetFeaturesResp>);
93
94impl Deref for Features {
95 type Target = GetFeaturesResp;
96
97 fn deref(&self) -> &Self::Target {
98 &self.0
99 }
100}
101
102impl GetFeaturesResp {
103 pub fn default_op(&self, col_type: ColumnType) -> Option<&str> {
104 self.filter_ops
105 .get(&(col_type as u32))?
106 .options
107 .first()
108 .map(|x| x.as_str())
109 }
110}
111
/// Boxed, thread-safe, single-argument callback.
type BoxFn<I, O> = Box<dyn Fn(I) -> O + Send + Sync + 'static>;
/// Boxed, thread-safe, two-argument callback.
type Box2Fn<I, J, O> = Box<dyn Fn(I, J) -> O + Send + Sync + 'static>;

/// Callback table keyed by a `u32` id, shared (via `Arc`) between clones of a
/// [`Client`] and guarded by an async `RwLock`.
type Subscriptions<C> = Arc<RwLock<HashMap<u32, C>>>;
/// Handler registered through `Client::on_error`: receives the error and,
/// when one is available, a [`ReconnectCallback`].
type OnErrorCallback =
    Box2Fn<ClientError, Option<ReconnectCallback>, BoxFuture<'static, Result<(), ClientError>>>;

/// Handler invoked at most once with the response to a single request.
type OnceCallback = Box<dyn FnOnce(Response) -> ClientResult<()> + Send + Sync + 'static>;
/// Encodes a [`Request`] and hands the bytes to the transport. The returned
/// future may borrow the request (`'a`).
type SendCallback = Arc<
    dyn for<'a> Fn(&'a Request) -> BoxFuture<'a, Result<(), Box<dyn Error + Send + Sync>>>
        + Send
        + Sync
        + 'static,
>;
126
/// Transport used by [`Client::new`] to deliver requests to a server.
///
/// `Client::new` wraps [`ClientHandler::send_request`] as the send callback
/// for [`Client::new_with_callback`].
pub trait ClientHandler: Clone + Send + Sync + 'static {
    /// Delivers one request, already protobuf-encoded by the client, as `msg`.
    ///
    /// NOTE(review): responses are presumably fed back via
    /// `Client::handle_response` by the transport — confirm with implementors.
    fn send_request(
        &self,
        msg: Vec<u8>,
    ) -> impl Future<Output = Result<(), Box<dyn Error + Send + Sync>>> + Send;
}
134
135mod name_registry {
136 use std::collections::HashSet;
137 use std::sync::{Arc, LazyLock, Mutex};
138
139 use crate::ClientError;
140 use crate::view::ClientResult;
141
142 static CLIENT_ID_GEN: LazyLock<Arc<Mutex<u32>>> = LazyLock::new(Arc::default);
143 static REGISTERED_CLIENTS: LazyLock<Arc<Mutex<HashSet<String>>>> = LazyLock::new(Arc::default);
144
145 pub(crate) fn generate_name(name: Option<&str>) -> ClientResult<String> {
146 if let Some(name) = name {
147 if let Some(name) = REGISTERED_CLIENTS
148 .lock()
149 .map_err(ClientError::from)?
150 .get(name)
151 {
152 Err(ClientError::DuplicateNameError(name.to_owned()))
153 } else {
154 Ok(name.to_owned())
155 }
156 } else {
157 let mut guard = CLIENT_ID_GEN.lock()?;
158 *guard += 1;
159 Ok(format!("client-{guard}"))
160 }
161 }
162}
163
164#[derive(Clone)]
171#[allow(clippy::type_complexity)]
172pub struct ReconnectCallback(
173 Arc<dyn Fn() -> LocalBoxFuture<'static, Result<(), Box<dyn Error>>> + Send + Sync>,
174);
175
176impl Deref for ReconnectCallback {
177 type Target = dyn Fn() -> LocalBoxFuture<'static, Result<(), Box<dyn Error>>> + Send + Sync;
178
179 fn deref(&self) -> &Self::Target {
180 &*self.0
181 }
182}
183
184impl ReconnectCallback {
185 pub fn new(
186 f: impl Fn() -> LocalBoxFuture<'static, Result<(), Box<dyn Error>>> + Send + Sync + 'static,
187 ) -> Self {
188 ReconnectCallback(Arc::new(f))
189 }
190}
191
/// Asynchronous, transport-agnostic request/response client.
///
/// Most state lives behind `Arc`, so clones share the same subscriptions,
/// cached features and send callback. Two clients compare equal when their
/// names are equal.
#[derive(Clone)]
pub struct Client {
    /// Name assigned by `name_registry::generate_name`.
    name: Arc<String>,
    /// Server feature set, fetched lazily and cached by `get_features`.
    features: Arc<Mutex<Option<Features>>>,
    /// Encodes a request and hands it to the transport.
    send: SendCallback,
    /// Source of request and registration ids.
    id_gen: IDGen,
    /// Handlers registered with `on_error`, keyed by registration id.
    subscriptions_errors: Subscriptions<OnErrorCallback>,
    /// One-shot response handlers keyed by request `msg_id`; removed when they fire.
    subscriptions_once: Subscriptions<OnceCallback>,
    /// Long-lived response handlers keyed by request `msg_id`.
    subscriptions: Subscriptions<BoxFn<Response, BoxFuture<'static, Result<(), ClientError>>>>,
}
218
219impl PartialEq for Client {
220 fn eq(&self, other: &Self) -> bool {
221 self.name == other.name
222 }
223}
224
225impl std::fmt::Debug for Client {
226 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
227 f.debug_struct("Client").finish()
228 }
229}
230
231impl Client {
232 pub fn new_with_callback<T, U>(name: Option<&str>, send_request: T) -> ClientResult<Self>
235 where
236 T: Fn(Vec<u8>) -> U + 'static + Sync + Send,
237 U: Future<Output = Result<(), Box<dyn Error + Send + Sync>>> + Send + 'static,
238 {
239 let name = name_registry::generate_name(name)?;
240 let send_request = Arc::new(send_request);
241 let send: SendCallback = Arc::new(move |req| {
242 let mut bytes: Vec<u8> = Vec::new();
243 req.encode(&mut bytes).unwrap();
244 let send_request = send_request.clone();
245 Box::pin(async move { send_request(bytes).await })
246 });
247
248 Ok(Client {
249 name: Arc::new(name),
250 features: Arc::default(),
251 id_gen: IDGen::default(),
252 send,
253 subscriptions: Subscriptions::default(),
254 subscriptions_errors: Arc::default(),
255 subscriptions_once: Arc::default(),
256 })
257 }
258
259 pub fn new<T>(name: Option<&str>, client_handler: T) -> ClientResult<Self>
261 where
262 T: ClientHandler + 'static + Sync + Send,
263 {
264 Self::new_with_callback(
265 name,
266 asyncfn!(client_handler, async move |req| {
267 client_handler.send_request(req).await
268 }),
269 )
270 }
271
272 pub fn get_name(&self) -> &'_ str {
273 self.name.as_str()
274 }
275
276 pub async fn handle_response<'a>(&'a self, msg: &'a [u8]) -> ClientResult<bool> {
283 let msg = Response::decode(msg)?;
284 tracing::debug!("RECV {}", msg);
285 let mut wr = self.subscriptions_once.write().await;
286 if let Some(handler) = (*wr).remove(&msg.msg_id) {
287 drop(wr);
288 handler(msg)?;
289 return Ok(true);
290 } else if let Some(handler) = self.subscriptions.try_read().unwrap().get(&msg.msg_id) {
291 drop(wr);
292 handler(msg).await?;
293 return Ok(true);
294 }
295
296 if let Response {
297 client_resp: Some(ClientResp::ServerError(ServerError { message, .. })),
298 ..
299 } = &msg
300 {
301 tracing::error!("{}", message);
302 } else {
303 tracing::debug!("Received unsolicited server response: {}", msg);
304 }
305
306 Ok(false)
307 }
308
309 pub async fn handle_error<T, U>(
311 &self,
312 message: ClientError,
313 reconnect: Option<T>,
314 ) -> ClientResult<()>
315 where
316 T: Fn() -> U + Clone + Send + Sync + 'static,
317 U: Future<Output = ClientResult<()>>,
318 {
319 let subs = self.subscriptions_errors.read().await;
320 let tasks = join_all(subs.values().map(|callback| {
321 callback(
322 message.clone(),
323 reconnect.clone().map(move |f| {
324 ReconnectCallback(Arc::new(move || {
325 clone!(f);
326 Box::pin(async move { Ok(f().await?) }) as LocalBoxFuture<'static, _>
327 }))
328 }),
329 )
330 }));
331
332 tasks.await.into_iter().collect::<Result<(), _>>()?;
333 self.close_and_error_subscriptions(&message).await
334 }
335
336 async fn close_and_error_subscriptions(&self, message: &ClientError) -> ClientResult<()> {
341 let synthetic_error = |msg_id| Response {
342 msg_id,
343 entity_id: "".to_string(),
344 client_resp: Some(ClientResp::ServerError(ServerError {
345 message: format!("{message}"),
346 status_code: 2,
347 })),
348 };
349
350 self.subscriptions.write().await.clear();
351 let callbacks_once = self
352 .subscriptions_once
353 .write()
354 .await
355 .drain()
356 .collect::<Vec<_>>();
357
358 callbacks_once
359 .into_iter()
360 .try_for_each(|(msg_id, f)| f(synthetic_error(msg_id)))
361 }
362
363 pub async fn on_error<T, U, V>(&self, on_error: T) -> ClientResult<u32>
364 where
365 T: Fn(ClientError, Option<ReconnectCallback>) -> U + Clone + Send + Sync + 'static,
366 U: Future<Output = V> + Send + 'static,
367 V: Into<Result<(), ClientError>> + Sync + 'static,
368 {
369 let id = self.gen_id();
370 let callback = asyncfn!(on_error, async move |x, y| on_error(x, y).await.into());
371 self.subscriptions_errors
372 .write()
373 .await
374 .insert(id, Box::new(move |x, y| Box::pin(callback(x, y))));
375
376 Ok(id)
377 }
378
379 pub(crate) fn gen_id(&self) -> u32 {
381 self.id_gen.next()
382 }
383
384 pub(crate) async fn unsubscribe(&self, update_id: u32) -> ClientResult<()> {
385 let callback = self
386 .subscriptions
387 .write()
388 .await
389 .remove(&update_id)
390 .ok_or(ClientError::Unknown("remove_update".to_string()))?;
391
392 drop(callback);
393 Ok(())
394 }
395
396 pub(crate) async fn subscribe_once(
398 &self,
399 msg: &Request,
400 on_update: Box<dyn FnOnce(Response) -> ClientResult<()> + Send + Sync + 'static>,
401 ) -> ClientResult<()> {
402 self.subscriptions_once
403 .write()
404 .await
405 .insert(msg.msg_id, on_update);
406
407 tracing::debug!("SEND {}", msg);
408 if let Err(e) = (self.send)(msg).await {
409 self.subscriptions_once.write().await.remove(&msg.msg_id);
410 Err(ClientError::Unknown(e.to_string()))
411 } else {
412 Ok(())
413 }
414 }
415
416 pub(crate) async fn subscribe<T, U>(&self, msg: &Request, on_update: T) -> ClientResult<()>
417 where
418 T: Fn(Response) -> U + Send + Sync + 'static,
419 U: Future<Output = Result<(), ClientError>> + Send + 'static,
420 {
421 self.subscriptions
422 .write()
423 .await
424 .insert(msg.msg_id, Box::new(move |x| Box::pin(on_update(x))));
425
426 tracing::debug!("SEND {}", msg);
427 if let Err(e) = (self.send)(msg).await {
428 self.subscriptions.write().await.remove(&msg.msg_id);
429 Err(ClientError::Unknown(e.to_string()))
430 } else {
431 Ok(())
432 }
433 }
434
435 pub(crate) async fn oneshot(&self, req: &Request) -> ClientResult<ClientResp> {
438 let (sender, receiver) = futures::channel::oneshot::channel::<ClientResp>();
439 let on_update = Box::new(move |res: Response| {
440 sender.send(res.client_resp.unwrap()).map_err(|x| x.into())
441 });
442
443 self.subscribe_once(req, on_update).await?;
444 receiver
445 .await
446 .map_err(|_| ClientError::Unknown(format!("Internal error for req {req}")))
447 }
448
449 pub(crate) async fn get_features(&self) -> ClientResult<Features> {
450 let mut guard = self.features.lock().await;
451 let features = if let Some(features) = &*guard {
452 features.clone()
453 } else {
454 let msg = Request {
455 msg_id: self.gen_id(),
456 entity_id: "".to_owned(),
457 client_req: Some(ClientReq::GetFeaturesReq(GetFeaturesReq {})),
458 };
459
460 let features = Features(Arc::new(match self.oneshot(&msg).await? {
461 ClientResp::GetFeaturesResp(features) => Ok(features),
462 resp => Err(resp),
463 }?));
464
465 *guard = Some(features.clone());
466 features
467 };
468
469 Ok(features)
470 }
471
472 pub async fn table(&self, input: TableData, options: TableInitOptions) -> ClientResult<Table> {
525 let entity_id = match options.name.clone() {
526 Some(x) => x.to_owned(),
527 None => randid(),
528 };
529
530 if let TableData::View(view) = &input {
531 let window = ViewWindow::default();
532 let arrow = view.to_arrow(window).await?;
533 let mut table = self
534 .crate_table_inner(UpdateData::Arrow(arrow).into(), options.into(), entity_id)
535 .await?;
536
537 let table_ = table.clone();
538 let callback = asyncfn!(table_, update, async move |update: OnUpdateData| {
539 let update = UpdateData::Arrow(update.delta.expect("Malformed message").into());
540 let options = crate::UpdateOptions::default();
541 table_.update(update, options).await.unwrap_or_log();
542 });
543
544 let options = OnUpdateOptions {
545 mode: Some(OnUpdateMode::Row),
546 };
547
548 let on_update_token = view.on_update(callback, options).await?;
549 table.view_update_token = Some(on_update_token);
550 Ok(table)
551 } else {
552 self.crate_table_inner(input, options.into(), entity_id)
553 .await
554 }
555 }
556
557 async fn crate_table_inner(
558 &self,
559 input: TableData,
560 options: TableOptions,
561 entity_id: String,
562 ) -> ClientResult<Table> {
563 let msg = Request {
564 msg_id: self.gen_id(),
565 entity_id: entity_id.clone(),
566 client_req: Some(ClientReq::MakeTableReq(MakeTableReq {
567 data: Some(input.into()),
568 options: Some(options.clone().try_into()?),
569 })),
570 };
571
572 let client = self.clone();
573 match self.oneshot(&msg).await? {
574 ClientResp::MakeTableResp(_) => Ok(Table::new(entity_id, client, options)),
575 resp => Err(resp.into()),
576 }
577 }
578
579 async fn get_table_infos(&self) -> ClientResult<Vec<HostedTable>> {
580 let msg = Request {
581 msg_id: self.gen_id(),
582 entity_id: "".to_owned(),
583 client_req: Some(ClientReq::GetHostedTablesReq(GetHostedTablesReq {
584 subscribe: false,
585 })),
586 };
587
588 match self.oneshot(&msg).await? {
589 ClientResp::GetHostedTablesResp(GetHostedTablesResp { table_infos }) => Ok(table_infos),
590 resp => Err(resp.into()),
591 }
592 }
593
594 pub async fn open_table(&self, entity_id: String) -> ClientResult<Table> {
607 let infos = self.get_table_infos().await?;
608
609 if let Some(info) = infos.into_iter().find(|i| i.entity_id == entity_id) {
611 let options = TableOptions {
612 index: info.index,
613 limit: info.limit,
614 };
615
616 let client = self.clone();
617 Ok(Table::new(entity_id, client, options))
618 } else {
619 Err(ClientError::Unknown("Unknown table".to_owned()))
620 }
621 }
622
623 pub async fn get_hosted_table_names(&self) -> ClientResult<Vec<String>> {
636 let msg = Request {
637 msg_id: self.gen_id(),
638 entity_id: "".to_owned(),
639 client_req: Some(ClientReq::GetHostedTablesReq(GetHostedTablesReq {
640 subscribe: false,
641 })),
642 };
643
644 match self.oneshot(&msg).await? {
645 ClientResp::GetHostedTablesResp(GetHostedTablesResp { table_infos }) => {
646 Ok(table_infos.into_iter().map(|i| i.entity_id).collect())
647 },
648 resp => Err(resp.into()),
649 }
650 }
651
652 pub async fn on_hosted_tables_update<T, U>(&self, on_update: T) -> ClientResult<u32>
656 where
657 T: Fn() -> U + Send + Sync + 'static,
658 U: Future<Output = ()> + Send + 'static,
659 {
660 let on_update = Arc::new(on_update);
661 let callback = asyncfn!(on_update, async move |resp: Response| {
662 match resp.client_resp {
663 Some(ClientResp::GetHostedTablesResp(_)) | None => {
664 on_update().await;
665 Ok(())
666 },
667 resp => Err(resp.into()),
668 }
669 });
670
671 let msg = Request {
672 msg_id: self.gen_id(),
673 entity_id: "".to_owned(),
674 client_req: Some(ClientReq::GetHostedTablesReq(GetHostedTablesReq {
675 subscribe: true,
676 })),
677 };
678
679 self.subscribe(&msg, callback).await?;
680 Ok(msg.msg_id)
681 }
682
683 pub async fn remove_hosted_tables_update(&self, update_id: u32) -> ClientResult<()> {
686 let msg = Request {
687 msg_id: self.gen_id(),
688 entity_id: "".to_owned(),
689 client_req: Some(ClientReq::RemoveHostedTablesUpdateReq(
690 RemoveHostedTablesUpdateReq { id: update_id },
691 )),
692 };
693
694 self.unsubscribe(update_id).await?;
695 match self.oneshot(&msg).await? {
696 ClientResp::RemoveHostedTablesUpdateResp(_) => Ok(()),
697 resp => Err(resp.into()),
698 }
699 }
700
701 pub async fn system_info(&self) -> ClientResult<SystemInfo> {
705 let msg = Request {
706 msg_id: self.gen_id(),
707 entity_id: "".to_string(),
708 client_req: Some(ClientReq::ServerSystemInfoReq(ServerSystemInfoReq {})),
709 };
710
711 match self.oneshot(&msg).await? {
712 ClientResp::ServerSystemInfoResp(resp) => {
713 #[cfg(not(target_family = "wasm"))]
714 let timestamp = Some(
715 std::time::SystemTime::now()
716 .duration_since(std::time::UNIX_EPOCH)?
717 .as_millis() as u64,
718 );
719
720 #[cfg(target_family = "wasm")]
721 let timestamp = None;
722
723 #[cfg(feature = "talc-allocator")]
724 let (client_used, client_heap) = {
725 let (client_used, client_heap) = crate::utils::get_used();
726 (Some(client_used as u64), Some(client_heap as u64))
727 };
728
729 #[cfg(not(feature = "talc-allocator"))]
730 let (client_used, client_heap) = (None, None);
731
732 let info = SystemInfo {
733 heap_size: resp.heap_size,
734 used_size: resp.used_size,
735 cpu_time: resp.cpu_time,
736 cpu_time_epoch: resp.cpu_time_epoch,
737 timestamp,
738 client_heap,
739 client_used,
740 };
741
742 Ok(info)
743 },
744 resp => Err(resp.into()),
745 }
746 }
747}