1use std::collections::HashMap;
14use std::error::Error;
15use std::ops::Deref;
16use std::sync::Arc;
17
18use async_lock::{Mutex, RwLock};
19use futures::Future;
20use futures::future::{BoxFuture, LocalBoxFuture, join_all};
21use prost::Message;
22use serde::{Deserialize, Serialize};
23use ts_rs::TS;
24
25use crate::proto::request::ClientReq;
26use crate::proto::response::ClientResp;
27use crate::proto::{
28    ColumnType, GetFeaturesReq, GetFeaturesResp, GetHostedTablesReq, GetHostedTablesResp,
29    HostedTable, MakeTableReq, RemoveHostedTablesUpdateReq, Request, Response, ServerError,
30    ServerSystemInfoReq,
31};
32use crate::table::{Table, TableInitOptions, TableOptions};
33use crate::table_data::{TableData, UpdateData};
34use crate::utils::*;
35use crate::view::{OnUpdateData, ViewWindow};
36use crate::{OnUpdateMode, OnUpdateOptions, asyncfn, clone};
37
/// A snapshot of memory and CPU statistics, as assembled by
/// [`Client::system_info`] from the server's reply plus client-side data.
///
/// `T` is the integer type of the size/timestamp fields (`u64` by default);
/// [`SystemInfo::cast`] converts those fields to another primitive type.
#[derive(Clone, Debug, Serialize, Deserialize, TS)]
pub struct SystemInfo<T = u64> {
    /// Heap size reported by the server.
    /// NOTE(review): units are not visible in this file — presumably bytes; confirm.
    pub heap_size: T,

    /// Used heap size reported by the server (same units as `heap_size`).
    pub used_size: T,

    /// CPU time reported by the server.
    /// NOTE(review): units/meaning are defined server-side — confirm.
    pub cpu_time: u32,

    /// Companion value to `cpu_time` reported by the server.
    /// NOTE(review): presumably the epoch/window `cpu_time` is measured over — confirm.
    pub cpu_time_epoch: u32,

    /// Client wall-clock time in milliseconds since the UNIX epoch, taken when
    /// `Client::system_info` processed the reply; `None` on `wasm` targets.
    pub timestamp: Option<T>,

    /// Client-side heap size; only populated when the `talc-allocator` feature
    /// is enabled, otherwise `None`.
    pub client_heap: Option<T>,

    /// Client-side used heap; only populated when the `talc-allocator` feature
    /// is enabled, otherwise `None`.
    pub client_used: Option<T>,
}
69
70impl<U: Copy + 'static> SystemInfo<U> {
71    pub fn cast<T: Copy + 'static>(&self) -> SystemInfo<T>
74    where
75        U: num_traits::AsPrimitive<T>,
76    {
77        SystemInfo {
78            heap_size: self.heap_size.as_(),
79            used_size: self.used_size.as_(),
80            cpu_time: self.cpu_time,
81            cpu_time_epoch: self.cpu_time_epoch,
82            timestamp: self.timestamp.map(|x| x.as_()),
83            client_heap: self.client_heap.map(|x| x.as_()),
84            client_used: self.client_used.map(|x| x.as_()),
85        }
86    }
87}
88
/// Shared handle to the server's [`GetFeaturesResp`].
///
/// Cloning only bumps the `Arc` reference count; the wrapped response is read
/// through this type's `Deref` impl.
#[derive(Clone, Debug, Default)]
pub struct Features(Arc<GetFeaturesResp>);
93
94impl Deref for Features {
95    type Target = GetFeaturesResp;
96
97    fn deref(&self) -> &Self::Target {
98        &self.0
99    }
100}
101
102impl GetFeaturesResp {
103    pub fn default_op(&self, col_type: ColumnType) -> Option<&str> {
104        self.filter_ops
105            .get(&(col_type as u32))?
106            .options
107            .first()
108            .map(|x| x.as_str())
109    }
110}
111
/// Boxed, thread-safe, reusable one-argument callback.
type BoxFn<I, O> = Box<dyn Fn(I) -> O + Send + Sync + 'static>;
/// Boxed, thread-safe, reusable two-argument callback.
type Box2Fn<I, J, O> = Box<dyn Fn(I, J) -> O + Send + Sync + 'static>;

/// Shared, async-locked registry of callbacks keyed by a `u32` id: a request's
/// `msg_id` for response handlers, or a `gen_id` value for error handlers.
type Subscriptions<C> = Arc<RwLock<HashMap<u32, C>>>;
/// Error listener: receives the error and, when available, a handle that can
/// attempt a reconnect.
type OnErrorCallback =
    Box2Fn<ClientError, Option<ReconnectCallback>, BoxFuture<'static, Result<(), ClientError>>>;

/// Handler consumed by the single response matching its `msg_id`.
type OnceCallback = Box<dyn FnOnce(Response) -> ClientResult<()> + Send + Sync + 'static>;
/// Type-erased transport used to deliver a `Request` to the server.
type SendCallback = Arc<
    dyn for<'a> Fn(&'a Request) -> BoxFuture<'a, Result<(), Box<dyn Error + Send + Sync>>>
        + Send
        + Sync
        + 'static,
>;
126
/// Transport used by [`Client::new`] to deliver outgoing messages.
pub trait ClientHandler: Clone + Send + Sync + 'static {
    /// Sends `msg`, the protobuf encoding of a single `Request` (as produced by
    /// `Client::new_with_callback`), to the server.
    fn send_request(
        &self,
        msg: Vec<u8>,
    ) -> impl Future<Output = Result<(), Box<dyn Error + Send + Sync>>> + Send;
}
134
135mod name_registry {
136    use std::collections::HashSet;
137    use std::sync::{Arc, LazyLock, Mutex};
138
139    use crate::ClientError;
140    use crate::view::ClientResult;
141
142    static CLIENT_ID_GEN: LazyLock<Arc<Mutex<u32>>> = LazyLock::new(Arc::default);
143    static REGISTERED_CLIENTS: LazyLock<Arc<Mutex<HashSet<String>>>> = LazyLock::new(Arc::default);
144
145    pub(crate) fn generate_name(name: Option<&str>) -> ClientResult<String> {
146        if let Some(name) = name {
147            if let Some(name) = REGISTERED_CLIENTS
148                .lock()
149                .map_err(ClientError::from)?
150                .get(name)
151            {
152                Err(ClientError::DuplicateNameError(name.to_owned()))
153            } else {
154                Ok(name.to_owned())
155            }
156        } else {
157            let mut guard = CLIENT_ID_GEN.lock()?;
158            *guard += 1;
159            Ok(format!("client-{guard}"))
160        }
161    }
162}
163
/// Cloneable, shareable handle to an async reconnect routine.
///
/// [`Client::handle_error`] passes one of these to each `on_error` listener
/// when the caller supplied a reconnect function.
#[derive(Clone)]
// The fully spelled-out `Arc<dyn Fn() -> LocalBoxFuture<..>>` field trips
// `clippy::type_complexity`; the lint is silenced rather than adding an alias.
#[allow(clippy::type_complexity)]
pub struct ReconnectCallback(
    Arc<dyn Fn() -> LocalBoxFuture<'static, Result<(), Box<dyn Error>>> + Send + Sync>,
);
175
176impl Deref for ReconnectCallback {
177    type Target = dyn Fn() -> LocalBoxFuture<'static, Result<(), Box<dyn Error>>> + Send + Sync;
178
179    fn deref(&self) -> &Self::Target {
180        &*self.0
181    }
182}
183
184impl ReconnectCallback {
185    pub fn new(
186        f: impl Fn() -> LocalBoxFuture<'static, Result<(), Box<dyn Error>>> + Send + Sync + 'static,
187    ) -> Self {
188        ReconnectCallback(Arc::new(f))
189    }
190}
191
/// Handle for talking to a server over a caller-supplied transport.
///
/// Cloning is cheap: every field is either an `Arc` or shares state through
/// one, so clones observe the same subscriptions and feature cache.
#[derive(Clone)]
pub struct Client {
    /// Name from `name_registry::generate_name`; also the basis of `PartialEq`.
    name: Arc<String>,
    /// Lazily-filled cache of the server's features (see `get_features`).
    features: Arc<Mutex<Option<Features>>>,
    /// Transport that encodes and delivers a `Request`.
    send: SendCallback,
    /// Source of ids for request `msg_id`s and error-listener registrations.
    id_gen: IDGen,
    /// `on_error` listeners, keyed by an id from `gen_id`.
    subscriptions_errors: Subscriptions<OnErrorCallback>,
    /// Handlers removed and run by the first response with a matching `msg_id`.
    subscriptions_once: Subscriptions<OnceCallback>,
    /// Handlers run for every response with a matching `msg_id`, until removed.
    subscriptions: Subscriptions<BoxFn<Response, BoxFuture<'static, Result<(), ClientError>>>>,
}
218
219impl PartialEq for Client {
220    fn eq(&self, other: &Self) -> bool {
221        self.name == other.name
222    }
223}
224
225impl std::fmt::Debug for Client {
226    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
227        f.debug_struct("Client").finish()
228    }
229}
230
231impl Client {
232    pub fn new_with_callback<T, U>(name: Option<&str>, send_request: T) -> ClientResult<Self>
235    where
236        T: Fn(Vec<u8>) -> U + 'static + Sync + Send,
237        U: Future<Output = Result<(), Box<dyn Error + Send + Sync>>> + Send + 'static,
238    {
239        let name = name_registry::generate_name(name)?;
240        let send_request = Arc::new(send_request);
241        let send: SendCallback = Arc::new(move |req| {
242            let mut bytes: Vec<u8> = Vec::new();
243            req.encode(&mut bytes).unwrap();
244            let send_request = send_request.clone();
245            Box::pin(async move { send_request(bytes).await })
246        });
247
248        Ok(Client {
249            name: Arc::new(name),
250            features: Arc::default(),
251            id_gen: IDGen::default(),
252            send,
253            subscriptions: Subscriptions::default(),
254            subscriptions_errors: Arc::default(),
255            subscriptions_once: Arc::default(),
256        })
257    }
258
259    pub fn new<T>(name: Option<&str>, client_handler: T) -> ClientResult<Self>
261    where
262        T: ClientHandler + 'static + Sync + Send,
263    {
264        Self::new_with_callback(
265            name,
266            asyncfn!(client_handler, async move |req| {
267                client_handler.send_request(req).await
268            }),
269        )
270    }
271
272    pub fn get_name(&self) -> &'_ str {
273        self.name.as_str()
274    }
275
276    pub async fn handle_response<'a>(&'a self, msg: &'a [u8]) -> ClientResult<bool> {
283        let msg = Response::decode(msg)?;
284        tracing::debug!("RECV {}", msg);
285        let mut wr = self.subscriptions_once.write().await;
286        if let Some(handler) = (*wr).remove(&msg.msg_id) {
287            drop(wr);
288            handler(msg)?;
289            return Ok(true);
290        } else if let Some(handler) = self.subscriptions.try_read().unwrap().get(&msg.msg_id) {
291            drop(wr);
292            handler(msg).await?;
293            return Ok(true);
294        }
295
296        if let Response {
297            client_resp: Some(ClientResp::ServerError(ServerError { message, .. })),
298            ..
299        } = &msg
300        {
301            tracing::error!("{}", message);
302        } else {
303            tracing::debug!("Received unsolicited server response: {}", msg);
304        }
305
306        Ok(false)
307    }
308
309    pub async fn handle_error<T, U>(
311        &self,
312        message: ClientError,
313        reconnect: Option<T>,
314    ) -> ClientResult<()>
315    where
316        T: Fn() -> U + Clone + Send + Sync + 'static,
317        U: Future<Output = ClientResult<()>>,
318    {
319        let subs = self.subscriptions_errors.read().await;
320        let tasks = join_all(subs.values().map(|callback| {
321            callback(
322                message.clone(),
323                reconnect.clone().map(move |f| {
324                    ReconnectCallback(Arc::new(move || {
325                        clone!(f);
326                        Box::pin(async move { Ok(f().await?) }) as LocalBoxFuture<'static, _>
327                    }))
328                }),
329            )
330        }));
331
332        tasks.await.into_iter().collect::<Result<(), _>>()?;
333        self.close_and_error_subscriptions(&message).await
334    }
335
336    async fn close_and_error_subscriptions(&self, message: &ClientError) -> ClientResult<()> {
341        let synthetic_error = |msg_id| Response {
342            msg_id,
343            entity_id: "".to_string(),
344            client_resp: Some(ClientResp::ServerError(ServerError {
345                message: format!("{message}"),
346                status_code: 2,
347            })),
348        };
349
350        self.subscriptions.write().await.clear();
351        let callbacks_once = self
352            .subscriptions_once
353            .write()
354            .await
355            .drain()
356            .collect::<Vec<_>>();
357
358        callbacks_once
359            .into_iter()
360            .try_for_each(|(msg_id, f)| f(synthetic_error(msg_id)))
361    }
362
363    pub async fn on_error<T, U, V>(&self, on_error: T) -> ClientResult<u32>
364    where
365        T: Fn(ClientError, Option<ReconnectCallback>) -> U + Clone + Send + Sync + 'static,
366        U: Future<Output = V> + Send + 'static,
367        V: Into<Result<(), ClientError>> + Sync + 'static,
368    {
369        let id = self.gen_id();
370        let callback = asyncfn!(on_error, async move |x, y| on_error(x, y).await.into());
371        self.subscriptions_errors
372            .write()
373            .await
374            .insert(id, Box::new(move |x, y| Box::pin(callback(x, y))));
375
376        Ok(id)
377    }
378
379    pub(crate) fn gen_id(&self) -> u32 {
381        self.id_gen.next()
382    }
383
384    pub(crate) async fn unsubscribe(&self, update_id: u32) -> ClientResult<()> {
385        let callback = self
386            .subscriptions
387            .write()
388            .await
389            .remove(&update_id)
390            .ok_or(ClientError::Unknown("remove_update".to_string()))?;
391
392        drop(callback);
393        Ok(())
394    }
395
396    pub(crate) async fn subscribe_once(
398        &self,
399        msg: &Request,
400        on_update: Box<dyn FnOnce(Response) -> ClientResult<()> + Send + Sync + 'static>,
401    ) -> ClientResult<()> {
402        self.subscriptions_once
403            .write()
404            .await
405            .insert(msg.msg_id, on_update);
406
407        tracing::debug!("SEND {}", msg);
408        if let Err(e) = (self.send)(msg).await {
409            self.subscriptions_once.write().await.remove(&msg.msg_id);
410            Err(ClientError::Unknown(e.to_string()))
411        } else {
412            Ok(())
413        }
414    }
415
416    pub(crate) async fn subscribe<T, U>(&self, msg: &Request, on_update: T) -> ClientResult<()>
417    where
418        T: Fn(Response) -> U + Send + Sync + 'static,
419        U: Future<Output = Result<(), ClientError>> + Send + 'static,
420    {
421        self.subscriptions
422            .write()
423            .await
424            .insert(msg.msg_id, Box::new(move |x| Box::pin(on_update(x))));
425
426        tracing::debug!("SEND {}", msg);
427        if let Err(e) = (self.send)(msg).await {
428            self.subscriptions.write().await.remove(&msg.msg_id);
429            Err(ClientError::Unknown(e.to_string()))
430        } else {
431            Ok(())
432        }
433    }
434
435    pub(crate) async fn oneshot(&self, req: &Request) -> ClientResult<ClientResp> {
438        let (sender, receiver) = futures::channel::oneshot::channel::<ClientResp>();
439        let on_update = Box::new(move |res: Response| {
440            sender.send(res.client_resp.unwrap()).map_err(|x| x.into())
441        });
442
443        self.subscribe_once(req, on_update).await?;
444        receiver
445            .await
446            .map_err(|_| ClientError::Unknown(format!("Internal error for req {req}")))
447    }
448
449    pub(crate) async fn get_features(&self) -> ClientResult<Features> {
450        let mut guard = self.features.lock().await;
451        let features = if let Some(features) = &*guard {
452            features.clone()
453        } else {
454            let msg = Request {
455                msg_id: self.gen_id(),
456                entity_id: "".to_owned(),
457                client_req: Some(ClientReq::GetFeaturesReq(GetFeaturesReq {})),
458            };
459
460            let features = Features(Arc::new(match self.oneshot(&msg).await? {
461                ClientResp::GetFeaturesResp(features) => Ok(features),
462                resp => Err(resp),
463            }?));
464
465            *guard = Some(features.clone());
466            features
467        };
468
469        Ok(features)
470    }
471
472    pub async fn table(&self, input: TableData, options: TableInitOptions) -> ClientResult<Table> {
525        let entity_id = match options.name.clone() {
526            Some(x) => x.to_owned(),
527            None => randid(),
528        };
529
530        if let TableData::View(view) = &input {
531            let window = ViewWindow::default();
532            let arrow = view.to_arrow(window).await?;
533            let mut table = self
534                .crate_table_inner(UpdateData::Arrow(arrow).into(), options.into(), entity_id)
535                .await?;
536
537            let table_ = table.clone();
538            let callback = asyncfn!(table_, update, async move |update: OnUpdateData| {
539                let update = UpdateData::Arrow(update.delta.expect("Malformed message").into());
540                let options = crate::UpdateOptions::default();
541                table_.update(update, options).await.unwrap_or_log();
542            });
543
544            let options = OnUpdateOptions {
545                mode: Some(OnUpdateMode::Row),
546            };
547
548            let on_update_token = view.on_update(callback, options).await?;
549            table.view_update_token = Some(on_update_token);
550            Ok(table)
551        } else {
552            self.crate_table_inner(input, options.into(), entity_id)
553                .await
554        }
555    }
556
557    async fn crate_table_inner(
558        &self,
559        input: TableData,
560        options: TableOptions,
561        entity_id: String,
562    ) -> ClientResult<Table> {
563        let msg = Request {
564            msg_id: self.gen_id(),
565            entity_id: entity_id.clone(),
566            client_req: Some(ClientReq::MakeTableReq(MakeTableReq {
567                data: Some(input.into()),
568                options: Some(options.clone().try_into()?),
569            })),
570        };
571
572        let client = self.clone();
573        match self.oneshot(&msg).await? {
574            ClientResp::MakeTableResp(_) => Ok(Table::new(entity_id, client, options)),
575            resp => Err(resp.into()),
576        }
577    }
578
579    async fn get_table_infos(&self) -> ClientResult<Vec<HostedTable>> {
580        let msg = Request {
581            msg_id: self.gen_id(),
582            entity_id: "".to_owned(),
583            client_req: Some(ClientReq::GetHostedTablesReq(GetHostedTablesReq {
584                subscribe: false,
585            })),
586        };
587
588        match self.oneshot(&msg).await? {
589            ClientResp::GetHostedTablesResp(GetHostedTablesResp { table_infos }) => Ok(table_infos),
590            resp => Err(resp.into()),
591        }
592    }
593
594    pub async fn open_table(&self, entity_id: String) -> ClientResult<Table> {
607        let infos = self.get_table_infos().await?;
608
609        if let Some(info) = infos.into_iter().find(|i| i.entity_id == entity_id) {
611            let options = TableOptions {
612                index: info.index,
613                limit: info.limit,
614            };
615
616            let client = self.clone();
617            Ok(Table::new(entity_id, client, options))
618        } else {
619            Err(ClientError::Unknown("Unknown table".to_owned()))
620        }
621    }
622
623    pub async fn get_hosted_table_names(&self) -> ClientResult<Vec<String>> {
636        let msg = Request {
637            msg_id: self.gen_id(),
638            entity_id: "".to_owned(),
639            client_req: Some(ClientReq::GetHostedTablesReq(GetHostedTablesReq {
640                subscribe: false,
641            })),
642        };
643
644        match self.oneshot(&msg).await? {
645            ClientResp::GetHostedTablesResp(GetHostedTablesResp { table_infos }) => {
646                Ok(table_infos.into_iter().map(|i| i.entity_id).collect())
647            },
648            resp => Err(resp.into()),
649        }
650    }
651
652    pub async fn on_hosted_tables_update<T, U>(&self, on_update: T) -> ClientResult<u32>
656    where
657        T: Fn() -> U + Send + Sync + 'static,
658        U: Future<Output = ()> + Send + 'static,
659    {
660        let on_update = Arc::new(on_update);
661        let callback = asyncfn!(on_update, async move |resp: Response| {
662            match resp.client_resp {
663                Some(ClientResp::GetHostedTablesResp(_)) | None => {
664                    on_update().await;
665                    Ok(())
666                },
667                resp => Err(resp.into()),
668            }
669        });
670
671        let msg = Request {
672            msg_id: self.gen_id(),
673            entity_id: "".to_owned(),
674            client_req: Some(ClientReq::GetHostedTablesReq(GetHostedTablesReq {
675                subscribe: true,
676            })),
677        };
678
679        self.subscribe(&msg, callback).await?;
680        Ok(msg.msg_id)
681    }
682
683    pub async fn remove_hosted_tables_update(&self, update_id: u32) -> ClientResult<()> {
686        let msg = Request {
687            msg_id: self.gen_id(),
688            entity_id: "".to_owned(),
689            client_req: Some(ClientReq::RemoveHostedTablesUpdateReq(
690                RemoveHostedTablesUpdateReq { id: update_id },
691            )),
692        };
693
694        self.unsubscribe(update_id).await?;
695        match self.oneshot(&msg).await? {
696            ClientResp::RemoveHostedTablesUpdateResp(_) => Ok(()),
697            resp => Err(resp.into()),
698        }
699    }
700
701    pub async fn system_info(&self) -> ClientResult<SystemInfo> {
705        let msg = Request {
706            msg_id: self.gen_id(),
707            entity_id: "".to_string(),
708            client_req: Some(ClientReq::ServerSystemInfoReq(ServerSystemInfoReq {})),
709        };
710
711        match self.oneshot(&msg).await? {
712            ClientResp::ServerSystemInfoResp(resp) => {
713                #[cfg(not(target_family = "wasm"))]
714                let timestamp = Some(
715                    std::time::SystemTime::now()
716                        .duration_since(std::time::UNIX_EPOCH)?
717                        .as_millis() as u64,
718                );
719
720                #[cfg(target_family = "wasm")]
721                let timestamp = None;
722
723                #[cfg(feature = "talc-allocator")]
724                let (client_used, client_heap) = {
725                    let (client_used, client_heap) = crate::utils::get_used();
726                    (Some(client_used as u64), Some(client_heap as u64))
727                };
728
729                #[cfg(not(feature = "talc-allocator"))]
730                let (client_used, client_heap) = (None, None);
731
732                let info = SystemInfo {
733                    heap_size: resp.heap_size,
734                    used_size: resp.used_size,
735                    cpu_time: resp.cpu_time,
736                    cpu_time_epoch: resp.cpu_time_epoch,
737                    timestamp,
738                    client_heap,
739                    client_used,
740                };
741
742                Ok(info)
743            },
744            resp => Err(resp.into()),
745        }
746    }
747}