1use std::collections::HashMap;
14use std::error::Error;
15use std::ops::Deref;
16use std::sync::Arc;
17
18use async_lock::{Mutex, RwLock};
19use futures::Future;
20use futures::future::{BoxFuture, LocalBoxFuture, join_all};
21use prost::Message;
22use serde::{Deserialize, Serialize};
23use ts_rs::TS;
24
25use crate::proto::request::ClientReq;
26use crate::proto::response::ClientResp;
27use crate::proto::{
28 ColumnType, GetFeaturesReq, GetFeaturesResp, GetHostedTablesReq, GetHostedTablesResp,
29 HostedTable, JoinType, MakeJoinTableReq, MakeTableReq, RemoveHostedTablesUpdateReq, Request,
30 Response, ServerError, ServerSystemInfoReq,
31};
32use crate::table::{JoinOptions, Table, TableInitOptions, TableOptions};
33use crate::table_data::{TableData, UpdateData};
34use crate::table_ref::TableRef;
35use crate::utils::*;
36use crate::view::{OnUpdateData, ViewWindow};
37use crate::{OnUpdateMode, OnUpdateOptions, asyncfn, clone};
38
/// Snapshot of memory and CPU metrics, generic over the numeric type `T`
/// used for the size and timestamp fields (`u64` by default).
#[derive(Clone, Debug, Serialize, Deserialize, TS)]
pub struct SystemInfo<T = u64> {
    /// Heap size copied from the server's system-info response
    /// (units are not visible here — presumably bytes, TODO confirm).
    pub heap_size: T,

    /// Used size copied from the server's system-info response
    /// (units are not visible here — presumably bytes, TODO confirm).
    pub used_size: T,

    /// CPU time copied from the server's system-info response.
    pub cpu_time: u32,

    /// CPU time epoch copied from the server's system-info response.
    pub cpu_time_epoch: u32,

    /// Set by `Client::system_info` to the milliseconds elapsed since the
    /// UNIX epoch on non-wasm targets, and to `None` on wasm targets.
    pub timestamp: Option<T>,

    /// Client-side heap figure from `crate::utils::get_used()`; only
    /// populated when the `talc-allocator` feature is enabled.
    pub client_heap: Option<T>,

    /// Client-side used figure from `crate::utils::get_used()`; only
    /// populated when the `talc-allocator` feature is enabled.
    pub client_used: Option<T>,
}
70
71impl<U: Copy + 'static> SystemInfo<U> {
72 pub fn cast<T: Copy + 'static>(&self) -> SystemInfo<T>
75 where
76 U: num_traits::AsPrimitive<T>,
77 {
78 SystemInfo {
79 heap_size: self.heap_size.as_(),
80 used_size: self.used_size.as_(),
81 cpu_time: self.cpu_time,
82 cpu_time_epoch: self.cpu_time_epoch,
83 timestamp: self.timestamp.map(|x| x.as_()),
84 client_heap: self.client_heap.map(|x| x.as_()),
85 client_used: self.client_used.map(|x| x.as_()),
86 }
87 }
88}
89
/// Cheaply clonable, shared handle to a [`GetFeaturesResp`].
#[derive(Clone, Debug, Default, PartialEq)]
pub struct Features(Arc<GetFeaturesResp>);
94
95impl Features {
96 pub fn get_group_rollup_modes(&self) -> Vec<crate::config::GroupRollupMode> {
97 self.group_rollup_mode
98 .iter()
99 .map(|x| {
100 crate::config::GroupRollupMode::from(
101 crate::proto::GroupRollupMode::try_from(*x).unwrap(),
102 )
103 })
104 .collect::<Vec<_>>()
105 }
106}
107
108impl Deref for Features {
109 type Target = GetFeaturesResp;
110
111 fn deref(&self) -> &Self::Target {
112 &self.0
113 }
114}
115
116impl GetFeaturesResp {
117 pub fn default_op(&self, col_type: ColumnType) -> Option<&str> {
118 self.filter_ops
119 .get(&(col_type as u32))?
120 .options
121 .first()
122 .map(|x| x.as_str())
123 }
124}
125
/// Boxed, thread-safe callback taking one argument.
type BoxFn<I, O> = Box<dyn Fn(I) -> O + Send + Sync + 'static>;
/// Boxed, thread-safe callback taking two arguments.
type Box2Fn<I, J, O> = Box<dyn Fn(I, J) -> O + Send + Sync + 'static>;

/// Shared, lock-protected map from a `u32` id to a callback of type `C`.
type Subscriptions<C> = Arc<RwLock<HashMap<u32, C>>>;
/// Error callback: receives the error plus an optional reconnect hook and
/// returns a boxed future resolving to `Result<(), ClientError>`.
type OnErrorCallback =
    Box2Fn<ClientError, Option<ReconnectCallback>, BoxFuture<'static, Result<(), ClientError>>>;

/// Handler that can be invoked at most once with a [`Response`].
type OnceCallback = Box<dyn FnOnce(Response) -> ClientResult<()> + Send + Sync + 'static>;
/// Shared function that asynchronously sends a [`Request`]; the returned
/// future may borrow the request for its lifetime `'a`.
type SendCallback = Arc<
    dyn for<'a> Fn(&'a Request) -> BoxFuture<'a, Result<(), Box<dyn Error + Send + Sync>>>
        + Send
        + Sync
        + 'static,
>;
140
/// Transport abstraction accepted by `Client::new`.
pub trait ClientHandler: Clone + Send + Sync + 'static {
    /// Sends one message, given as raw bytes (`Client::new` passes the
    /// encoded `Request` here), resolving to `Ok(())` or a boxed error.
    fn send_request(
        &self,
        msg: Vec<u8>,
    ) -> impl Future<Output = Result<(), Box<dyn Error + Send + Sync>>> + Send;
}
148
149mod name_registry {
150 use std::collections::HashSet;
151 use std::sync::{Arc, LazyLock, Mutex};
152
153 use crate::ClientError;
154 use crate::view::ClientResult;
155
156 static CLIENT_ID_GEN: LazyLock<Arc<Mutex<u32>>> = LazyLock::new(Arc::default);
157 static REGISTERED_CLIENTS: LazyLock<Arc<Mutex<HashSet<String>>>> = LazyLock::new(Arc::default);
158
159 pub(crate) fn generate_name(name: Option<&str>) -> ClientResult<String> {
160 if let Some(name) = name {
161 if let Some(name) = REGISTERED_CLIENTS
162 .lock()
163 .map_err(ClientError::from)?
164 .get(name)
165 {
166 Err(ClientError::DuplicateNameError(name.to_owned()))
167 } else {
168 Ok(name.to_owned())
169 }
170 } else {
171 let mut guard = CLIENT_ID_GEN.lock()?;
172 *guard += 1;
173 Ok(format!("client-{guard}"))
174 }
175 }
176}
177
/// Cheaply clonable handle to a reconnect function.
///
/// The wrapped closure is `Send + Sync`, but the future it returns is a
/// `LocalBoxFuture` and therefore is not required to be `Send`.
#[derive(Clone)]
#[allow(clippy::type_complexity)]
pub struct ReconnectCallback(
    Arc<dyn Fn() -> LocalBoxFuture<'static, Result<(), Box<dyn Error>>> + Send + Sync>,
);
189
190impl Deref for ReconnectCallback {
191 type Target = dyn Fn() -> LocalBoxFuture<'static, Result<(), Box<dyn Error>>> + Send + Sync;
192
193 fn deref(&self) -> &Self::Target {
194 &*self.0
195 }
196}
197
198impl ReconnectCallback {
199 pub fn new(
200 f: impl Fn() -> LocalBoxFuture<'static, Result<(), Box<dyn Error>>> + Send + Sync + 'static,
201 ) -> Self {
202 ReconnectCallback(Arc::new(f))
203 }
204}
205
/// Client handle; all state lives behind `Arc`s (or clonable handles), so
/// clones are cheap.
#[derive(Clone)]
pub struct Client {
    /// Client name; the only field compared by `PartialEq`.
    name: Arc<String>,
    /// Cache filled on the first successful `get_features` call.
    features: Arc<Mutex<Option<Features>>>,
    /// Function used to send every outgoing [`Request`].
    send: SendCallback,
    /// Source of the ids returned by `gen_id`.
    id_gen: IDGen,
    /// Error callbacks registered through `on_error`, keyed by generated id.
    subscriptions_errors: Subscriptions<OnErrorCallback>,
    /// Handlers removed and invoked on the first response with a matching
    /// `msg_id`.
    subscriptions_once: Subscriptions<OnceCallback>,
    /// Handlers invoked for every response with a matching `msg_id` until
    /// they are removed.
    subscriptions: Subscriptions<BoxFn<Response, BoxFuture<'static, Result<(), ClientError>>>>,
}
219
220impl PartialEq for Client {
221 fn eq(&self, other: &Self) -> bool {
222 self.name == other.name
223 }
224}
225
226impl std::fmt::Debug for Client {
227 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
228 f.debug_struct("Client").finish()
229 }
230}
231
232impl Client {
233 pub fn new_with_callback<T, U>(name: Option<&str>, send_request: T) -> ClientResult<Self>
236 where
237 T: Fn(Vec<u8>) -> U + 'static + Sync + Send,
238 U: Future<Output = Result<(), Box<dyn Error + Send + Sync>>> + Send + 'static,
239 {
240 let name = name_registry::generate_name(name)?;
241 let send_request = Arc::new(send_request);
242 let send: SendCallback = Arc::new(move |req| {
243 let mut bytes: Vec<u8> = Vec::new();
244 req.encode(&mut bytes).unwrap();
245 let send_request = send_request.clone();
246 Box::pin(async move { send_request(bytes).await })
247 });
248
249 Ok(Client {
250 name: Arc::new(name),
251 features: Arc::default(),
252 id_gen: IDGen::default(),
253 send,
254 subscriptions: Subscriptions::default(),
255 subscriptions_errors: Arc::default(),
256 subscriptions_once: Arc::default(),
257 })
258 }
259
    /// Creates a [`Client`] whose requests are sent through `client_handler`.
    ///
    /// Delegates to [`Client::new_with_callback`], forwarding each encoded
    /// request to [`ClientHandler::send_request`].
    pub fn new<T>(name: Option<&str>, client_handler: T) -> ClientResult<Self>
    where
        T: ClientHandler + 'static + Sync + Send,
    {
        Self::new_with_callback(
            name,
            asyncfn!(client_handler, async move |req| {
                client_handler.send_request(req).await
            }),
        )
    }
272
273 pub fn get_name(&self) -> &'_ str {
274 self.name.as_str()
275 }
276
277 pub async fn handle_response<'a>(&'a self, msg: &'a [u8]) -> ClientResult<bool> {
284 let msg = Response::decode(msg)?;
285 tracing::debug!("RECV {}", msg);
286 let mut wr = self.subscriptions_once.write().await;
287 if let Some(handler) = (*wr).remove(&msg.msg_id) {
288 drop(wr);
289 handler(msg)?;
290 return Ok(true);
291 } else if let Some(handler) = self.subscriptions.try_read().unwrap().get(&msg.msg_id) {
292 drop(wr);
293 handler(msg).await?;
294 return Ok(true);
295 }
296
297 if let Response {
298 client_resp: Some(ClientResp::ServerError(ServerError { message, .. })),
299 ..
300 } = &msg
301 {
302 tracing::error!("{}", message);
303 } else {
304 tracing::debug!("Received unsolicited server response: {}", msg);
305 }
306
307 Ok(false)
308 }
309
    /// Notifies every callback registered through [`Client::on_error`] of
    /// `message`, then fails all outstanding subscriptions.
    ///
    /// When `reconnect` is provided, each callback receives its own clone of
    /// it wrapped in a [`ReconnectCallback`].
    ///
    /// # Errors
    ///
    /// Returns the first error produced by a callback; in that case
    /// `close_and_error_subscriptions` is not run. Otherwise returns the
    /// result of `close_and_error_subscriptions`.
    pub async fn handle_error<T, U>(
        &self,
        message: ClientError,
        reconnect: Option<T>,
    ) -> ClientResult<()>
    where
        T: Fn() -> U + Clone + Send + Sync + 'static,
        U: Future<Output = ClientResult<()>>,
    {
        // This read guard lives until the function returns, so it is held
        // while the callbacks below are awaited.
        let subs = self.subscriptions_errors.read().await;
        let tasks = join_all(subs.values().map(|callback| {
            callback(
                message.clone(),
                reconnect.clone().map(move |f| {
                    ReconnectCallback(Arc::new(move || {
                        clone!(f);
                        Box::pin(async move { Ok(f().await?) }) as LocalBoxFuture<'static, _>
                    }))
                }),
            )
        }));

        // Collecting into `Result<(), _>` keeps the first callback error.
        tasks.await.into_iter().collect::<Result<(), _>>()?;
        self.close_and_error_subscriptions(&message).await
    }
336
337 async fn close_and_error_subscriptions(&self, message: &ClientError) -> ClientResult<()> {
342 let synthetic_error = |msg_id| Response {
343 msg_id,
344 entity_id: "".to_string(),
345 client_resp: Some(ClientResp::ServerError(ServerError {
346 message: format!("{message}"),
347 status_code: 2,
348 })),
349 };
350
351 self.subscriptions.write().await.clear();
352 let callbacks_once = self
353 .subscriptions_once
354 .write()
355 .await
356 .drain()
357 .collect::<Vec<_>>();
358
359 callbacks_once
360 .into_iter()
361 .try_for_each(|(msg_id, f)| f(synthetic_error(msg_id)))
362 }
363
    /// Registers `on_error` to be called by [`Client::handle_error`].
    ///
    /// The callback receives the error and an optional
    /// [`ReconnectCallback`]; its output is converted into
    /// `Result<(), ClientError>` via `Into`.
    ///
    /// Returns the generated id under which the callback is stored.
    pub async fn on_error<T, U, V>(&self, on_error: T) -> ClientResult<u32>
    where
        T: Fn(ClientError, Option<ReconnectCallback>) -> U + Clone + Send + Sync + 'static,
        U: Future<Output = V> + Send + 'static,
        V: Into<Result<(), ClientError>> + Sync + 'static,
    {
        let id = self.gen_id();
        let callback = asyncfn!(on_error, async move |x, y| on_error(x, y).await.into());
        self.subscriptions_errors
            .write()
            .await
            .insert(id, Box::new(move |x, y| Box::pin(callback(x, y))));

        Ok(id)
    }
379
    /// Returns the next id from this client's `IDGen`; used for `msg_id`s
    /// and error-callback ids.
    pub(crate) fn gen_id(&self) -> u32 {
        self.id_gen.next()
    }
384
385 pub(crate) async fn unsubscribe(&self, update_id: u32) -> ClientResult<()> {
386 let callback = self
387 .subscriptions
388 .write()
389 .await
390 .remove(&update_id)
391 .ok_or(ClientError::Unknown("remove_update".to_string()))?;
392
393 drop(callback);
394 Ok(())
395 }
396
397 pub(crate) async fn subscribe_once(
399 &self,
400 msg: &Request,
401 on_update: Box<dyn FnOnce(Response) -> ClientResult<()> + Send + Sync + 'static>,
402 ) -> ClientResult<()> {
403 self.subscriptions_once
404 .write()
405 .await
406 .insert(msg.msg_id, on_update);
407
408 tracing::debug!("SEND {}", msg);
409 if let Err(e) = (self.send)(msg).await {
410 self.subscriptions_once.write().await.remove(&msg.msg_id);
411 Err(ClientError::Unknown(e.to_string()))
412 } else {
413 Ok(())
414 }
415 }
416
417 pub(crate) async fn subscribe<T, U>(&self, msg: &Request, on_update: T) -> ClientResult<()>
418 where
419 T: Fn(Response) -> U + Send + Sync + 'static,
420 U: Future<Output = Result<(), ClientError>> + Send + 'static,
421 {
422 self.subscriptions
423 .write()
424 .await
425 .insert(msg.msg_id, Box::new(move |x| Box::pin(on_update(x))));
426
427 tracing::debug!("SEND {}", msg);
428 if let Err(e) = (self.send)(msg).await {
429 self.subscriptions.write().await.remove(&msg.msg_id);
430 Err(ClientError::Unknown(e.to_string()))
431 } else {
432 Ok(())
433 }
434 }
435
436 pub(crate) async fn oneshot(&self, req: &Request) -> ClientResult<ClientResp> {
439 let (sender, receiver) = futures::channel::oneshot::channel::<ClientResp>();
440 let on_update = Box::new(move |res: Response| {
441 sender.send(res.client_resp.unwrap()).map_err(|x| x.into())
442 });
443
444 self.subscribe_once(req, on_update).await?;
445 receiver
446 .await
447 .map_err(|_| ClientError::Unknown(format!("Internal error for req {req}")))
448 }
449
450 pub(crate) async fn get_features(&self) -> ClientResult<Features> {
451 let mut guard = self.features.lock().await;
452 let features = if let Some(features) = &*guard {
453 features.clone()
454 } else {
455 let msg = Request {
456 msg_id: self.gen_id(),
457 entity_id: "".to_owned(),
458 client_req: Some(ClientReq::GetFeaturesReq(GetFeaturesReq {})),
459 };
460
461 let features = Features(Arc::new(match self.oneshot(&msg).await? {
462 ClientResp::GetFeaturesResp(features) => Ok(features),
463 resp => Err(resp),
464 }?));
465
466 *guard = Some(features.clone());
467 features
468 };
469
470 Ok(features)
471 }
472
    /// Creates a [`Table`] from `input` by sending a `MakeTableReq`.
    ///
    /// The table's entity id is `options.name` when set, otherwise a value
    /// from `randid()`.
    ///
    /// When `input` is a view, the view is first serialized to Arrow with a
    /// default [`ViewWindow`]; after the table is created, a row-mode
    /// `on_update` callback is registered on the view that applies each
    /// delta to the new table, and the resulting token is stored in
    /// `view_update_token`.
    ///
    /// # Panics
    ///
    /// The view callback panics with "Malformed message" if an update
    /// arrives without a `delta`.
    pub async fn table(&self, input: TableData, options: TableInitOptions) -> ClientResult<Table> {
        let entity_id = match options.name.clone() {
            Some(x) => x.to_owned(),
            None => randid(),
        };

        if let TableData::View(view) = &input {
            let window = ViewWindow::default();
            let arrow = view.to_arrow(window).await?;
            let mut table = self
                .crate_table_inner(UpdateData::Arrow(arrow).into(), options.into(), entity_id)
                .await?;

            let table_ = table.clone();
            let callback = asyncfn!(table_, update, async move |update: OnUpdateData| {
                let update = UpdateData::Arrow(update.delta.expect("Malformed message").into());
                let options = crate::UpdateOptions::default();
                // Update failures are handed to `unwrap_or_log` rather than
                // propagated.
                table_.update(update, options).await.unwrap_or_log();
            });

            let options = OnUpdateOptions {
                mode: Some(OnUpdateMode::Row),
            };

            let on_update_token = view.on_update(callback, options).await?;
            table.view_update_token = Some(on_update_token);
            Ok(table)
        } else {
            self.crate_table_inner(input, options.into(), entity_id)
                .await
        }
    }
560
561 async fn crate_table_inner(
562 &self,
563 input: TableData,
564 options: TableOptions,
565 entity_id: String,
566 ) -> ClientResult<Table> {
567 let msg = Request {
568 msg_id: self.gen_id(),
569 entity_id: entity_id.clone(),
570 client_req: Some(ClientReq::MakeTableReq(MakeTableReq {
571 data: Some(input.into()),
572 options: Some(options.clone().try_into()?),
573 })),
574 };
575
576 let client = self.clone();
577 match self.oneshot(&msg).await? {
578 ClientResp::MakeTableResp(_) => Ok(Table::new(entity_id, client, options)),
579 resp => Err(resp.into()),
580 }
581 }
582
583 pub async fn join(
595 &self,
596 left: TableRef,
597 right: TableRef,
598 on: &str,
599 options: JoinOptions,
600 ) -> ClientResult<Table> {
601 let entity_id = options.name.unwrap_or_else(randid);
602 let join_type: JoinType = options.join_type.unwrap_or_default();
603 let right_on_column = options.right_on.unwrap_or_default();
604 let msg = Request {
605 msg_id: self.gen_id(),
606 entity_id: entity_id.clone(),
607 client_req: Some(ClientReq::MakeJoinTableReq(MakeJoinTableReq {
608 left_table_id: left.table_name().to_owned(),
609 right_table_id: right.table_name().to_owned(),
610 on_column: on.to_owned(),
611 join_type: join_type.into(),
612 right_on_column,
613 })),
614 };
615
616 let client = self.clone();
617 match self.oneshot(&msg).await? {
618 ClientResp::MakeJoinTableResp(_) => Ok(Table::new(entity_id, client, TableOptions {
619 index: Some(on.to_owned()),
620 limit: None,
621 })),
622 resp => Err(resp.into()),
623 }
624 }
625
626 async fn get_table_infos(&self) -> ClientResult<Vec<HostedTable>> {
627 let msg = Request {
628 msg_id: self.gen_id(),
629 entity_id: "".to_owned(),
630 client_req: Some(ClientReq::GetHostedTablesReq(GetHostedTablesReq {
631 subscribe: false,
632 })),
633 };
634
635 match self.oneshot(&msg).await? {
636 ClientResp::GetHostedTablesResp(GetHostedTablesResp { table_infos }) => Ok(table_infos),
637 resp => Err(resp.into()),
638 }
639 }
640
641 pub async fn open_table(&self, entity_id: String) -> ClientResult<Table> {
658 let infos = self.get_table_infos().await?;
659
660 if let Some(info) = infos.into_iter().find(|i| i.entity_id == entity_id) {
662 let options = TableOptions {
663 index: info.index,
664 limit: info.limit,
665 };
666
667 let client = self.clone();
668 Ok(Table::new(entity_id, client, options))
669 } else {
670 Err(ClientError::Unknown(format!(
671 "Unknown table \"{}\"",
672 entity_id
673 )))
674 }
675 }
676
677 pub async fn get_hosted_table_names(&self) -> ClientResult<Vec<String>> {
694 let msg = Request {
695 msg_id: self.gen_id(),
696 entity_id: "".to_owned(),
697 client_req: Some(ClientReq::GetHostedTablesReq(GetHostedTablesReq {
698 subscribe: false,
699 })),
700 };
701
702 match self.oneshot(&msg).await? {
703 ClientResp::GetHostedTablesResp(GetHostedTablesResp { table_infos }) => {
704 Ok(table_infos.into_iter().map(|i| i.entity_id).collect())
705 },
706 resp => Err(resp.into()),
707 }
708 }
709
710 pub async fn on_hosted_tables_update<T, U>(&self, on_update: T) -> ClientResult<u32>
714 where
715 T: Fn() -> U + Send + Sync + 'static,
716 U: Future<Output = ()> + Send + 'static,
717 {
718 let on_update = Arc::new(on_update);
719 let callback = asyncfn!(on_update, async move |resp: Response| {
720 match resp.client_resp {
721 Some(ClientResp::GetHostedTablesResp(_)) | None => {
722 on_update().await;
723 Ok(())
724 },
725 resp => Err(resp.into()),
726 }
727 });
728
729 let msg = Request {
730 msg_id: self.gen_id(),
731 entity_id: "".to_owned(),
732 client_req: Some(ClientReq::GetHostedTablesReq(GetHostedTablesReq {
733 subscribe: true,
734 })),
735 };
736
737 self.subscribe(&msg, callback).await?;
738 Ok(msg.msg_id)
739 }
740
741 pub async fn remove_hosted_tables_update(&self, update_id: u32) -> ClientResult<()> {
744 let msg = Request {
745 msg_id: self.gen_id(),
746 entity_id: "".to_owned(),
747 client_req: Some(ClientReq::RemoveHostedTablesUpdateReq(
748 RemoveHostedTablesUpdateReq { id: update_id },
749 )),
750 };
751
752 self.unsubscribe(update_id).await?;
753 match self.oneshot(&msg).await? {
754 ClientResp::RemoveHostedTablesUpdateResp(_) => Ok(()),
755 resp => Err(resp.into()),
756 }
757 }
758
759 pub async fn system_info(&self) -> ClientResult<SystemInfo> {
763 let msg = Request {
764 msg_id: self.gen_id(),
765 entity_id: "".to_string(),
766 client_req: Some(ClientReq::ServerSystemInfoReq(ServerSystemInfoReq {})),
767 };
768
769 match self.oneshot(&msg).await? {
770 ClientResp::ServerSystemInfoResp(resp) => {
771 #[cfg(not(target_family = "wasm"))]
772 let timestamp = Some(
773 std::time::SystemTime::now()
774 .duration_since(std::time::UNIX_EPOCH)?
775 .as_millis() as u64,
776 );
777
778 #[cfg(target_family = "wasm")]
779 let timestamp = None;
780
781 #[cfg(feature = "talc-allocator")]
782 let (client_used, client_heap) = {
783 let (client_used, client_heap) = crate::utils::get_used();
784 (Some(client_used as u64), Some(client_heap as u64))
785 };
786
787 #[cfg(not(feature = "talc-allocator"))]
788 let (client_used, client_heap) = (None, None);
789
790 let info = SystemInfo {
791 heap_size: resp.heap_size,
792 used_size: resp.used_size,
793 cpu_time: resp.cpu_time,
794 cpu_time_epoch: resp.cpu_time_epoch,
795 timestamp,
796 client_heap,
797 client_used,
798 };
799
800 Ok(info)
801 },
802 resp => Err(resp.into()),
803 }
804 }
805}