1use std::collections::HashMap;
14use std::error::Error;
15use std::sync::Arc;
16
17use async_lock::{Mutex, RwLock};
18use futures::Future;
19use futures::future::{BoxFuture, LocalBoxFuture, join_all};
20use nanoid::*;
21use prost::Message;
22use serde::{Deserialize, Serialize};
23
24use crate::proto::request::ClientReq;
25use crate::proto::response::ClientResp;
26use crate::proto::{
27 self, ColumnType, GetFeaturesReq, GetFeaturesResp, GetHostedTablesReq, GetHostedTablesResp,
28 HostedTable, MakeTableReq, RemoveHostedTablesUpdateReq, Request, Response, ServerError,
29 ServerSystemInfoReq,
30};
31use crate::table::{Table, TableInitOptions, TableOptions};
32use crate::table_data::{TableData, UpdateData};
33use crate::utils::*;
34use crate::view::ViewWindow;
35use crate::{OnUpdateMode, OnUpdateOptions, asyncfn, clone};
36
37#[derive(Clone, Debug, Serialize, Deserialize)]
39pub struct SystemInfo {
40 pub heap_size: f64,
41}
42
43impl From<proto::ServerSystemInfoResp> for SystemInfo {
44 fn from(value: proto::ServerSystemInfoResp) -> Self {
45 SystemInfo {
46 heap_size: value.heap_size,
47 }
48 }
49}
50
51pub type Features = Arc<GetFeaturesResp>;
54
55impl GetFeaturesResp {
56 pub fn default_op(&self, col_type: ColumnType) -> Option<&str> {
57 self.filter_ops
58 .get(&(col_type as u32))?
59 .options
60 .first()
61 .map(|x| x.as_str())
62 }
63}
64
65type BoxFn<I, O> = Box<dyn Fn(I) -> O + Send + Sync + 'static>;
66type Box2Fn<I, J, O> = Box<dyn Fn(I, J) -> O + Send + Sync + 'static>;
67
68type Subscriptions<C> = Arc<RwLock<HashMap<u32, C>>>;
69type OnErrorCallback =
70 Box2Fn<ClientError, Option<ReconnectCallback>, BoxFuture<'static, Result<(), ClientError>>>;
71type OnceCallback = Box<dyn FnOnce(Response) -> ClientResult<()> + Send + Sync + 'static>;
72type SendCallback = Arc<
73 dyn for<'a> Fn(&'a Request) -> BoxFuture<'a, Result<(), Box<dyn Error + Send + Sync>>>
74 + Send
75 + Sync
76 + 'static,
77>;
78
/// Transport abstraction consumed by [`Client::new`].
///
/// Each outgoing request reaches `send_request` already protobuf-encoded.
/// If the returned future resolves to `Err`, the client reports the failure
/// to whoever issued that request.
pub trait ClientHandler: Clone + Send + Sync + 'static {
    /// Deliver one encoded request message to the server.
    fn send_request(
        &self,
        msg: Vec<u8>,
    ) -> impl Send + Future<Output = Result<(), Box<dyn Error + Send + Sync>>>;
}
85
86mod name_registry {
87 use std::collections::HashSet;
88 use std::sync::{Arc, LazyLock, Mutex};
89
90 use crate::ClientError;
91 use crate::view::ClientResult;
92
93 static CLIENT_ID_GEN: LazyLock<Arc<Mutex<u32>>> = LazyLock::new(Arc::default);
94 static REGISTERED_CLIENTS: LazyLock<Arc<Mutex<HashSet<String>>>> = LazyLock::new(Arc::default);
95
96 pub(crate) fn generate_name(name: Option<&str>) -> ClientResult<String> {
97 if let Some(name) = name {
98 if let Some(name) = REGISTERED_CLIENTS
99 .lock()
100 .map_err(ClientError::from)?
101 .get(name)
102 {
103 Err(ClientError::DuplicateNameError(name.to_owned()))
104 } else {
105 Ok(name.to_owned())
106 }
107 } else {
108 let mut guard = CLIENT_ID_GEN.lock()?;
109 *guard += 1;
110 Ok(format!("client-{}", guard))
111 }
112 }
113}
114
115pub type ReconnectCallback =
122 Arc<dyn Fn() -> LocalBoxFuture<'static, Result<(), Box<dyn Error>>> + Send + Sync>;
123
/// Handle for issuing requests to a server over a caller-supplied transport.
///
/// `name`, `features`, `send` and the three subscription tables are all
/// `Arc`s, so clones share them. Whether clones also share `id_gen` depends
/// on `IDGen`'s `Clone` impl (not visible here; TODO confirm).
#[derive(Clone)]
pub struct Client {
    /// Client name from `name_registry::generate_name`; the only field that
    /// `PartialEq` compares.
    name: Arc<String>,
    /// Server feature set, filled in by `Client::init`; `None` until then.
    features: Arc<Mutex<Option<Features>>>,
    /// Encodes a `Request` and hands the bytes to the transport.
    send: SendCallback,
    /// Source of `msg_id`s for outgoing requests (via `gen_id`).
    id_gen: IDGen,
    /// Error subscribers registered with `on_error`, keyed by id.
    subscriptions_errors: Subscriptions<OnErrorCallback>,
    /// Handlers removed and invoked by the first response with their `msg_id`.
    subscriptions_once: Subscriptions<OnceCallback>,
    /// Long-lived handlers invoked for every response with their `msg_id`.
    subscriptions: Subscriptions<BoxFn<Response, BoxFuture<'static, Result<(), ClientError>>>>,
}
150
151impl PartialEq for Client {
152 fn eq(&self, other: &Self) -> bool {
153 self.name == other.name
154 }
155}
156
157impl std::fmt::Debug for Client {
158 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
159 f.debug_struct("Client").finish()
160 }
161}
162
163impl Client {
164 pub fn new_with_callback<T, U>(name: Option<&str>, send_request: T) -> ClientResult<Self>
167 where
168 T: Fn(Vec<u8>) -> U + 'static + Sync + Send,
169 U: Future<Output = Result<(), Box<dyn Error + Send + Sync>>> + Send + 'static,
170 {
171 let name = name_registry::generate_name(name)?;
172 let send_request = Arc::new(send_request);
173 let send: SendCallback = Arc::new(move |req| {
174 let mut bytes: Vec<u8> = Vec::new();
175 req.encode(&mut bytes).unwrap();
176 let send_request = send_request.clone();
177 Box::pin(async move { send_request(bytes).await })
178 });
179
180 Ok(Client {
181 name: Arc::new(name),
182 features: Arc::default(),
183 id_gen: IDGen::default(),
184 send,
185 subscriptions: Subscriptions::default(),
186 subscriptions_errors: Arc::default(),
187 subscriptions_once: Arc::default(),
188 })
189 }
190
191 pub fn new<T>(name: Option<&str>, client_handler: T) -> ClientResult<Self>
193 where
194 T: ClientHandler + 'static + Sync + Send,
195 {
196 Self::new_with_callback(
197 name,
198 asyncfn!(client_handler, async move |req| {
199 client_handler.send_request(req).await
200 }),
201 )
202 }
203
204 pub fn get_name(&self) -> &'_ str {
205 self.name.as_str()
206 }
207
208 pub async fn handle_response<'a>(&'a self, msg: &'a [u8]) -> ClientResult<bool> {
215 let msg = Response::decode(msg)?;
216 tracing::debug!("RECV {}", msg);
217 let mut wr = self.subscriptions_once.write().await;
218 if let Some(handler) = (*wr).remove(&msg.msg_id) {
219 drop(wr);
220 handler(msg)?;
221 return Ok(true);
222 } else if let Some(handler) = self.subscriptions.try_read().unwrap().get(&msg.msg_id) {
223 drop(wr);
224 handler(msg).await?;
225 return Ok(true);
226 }
227
228 if let Response {
229 client_resp: Some(ClientResp::ServerError(ServerError { message, .. })),
230 ..
231 } = &msg
232 {
233 tracing::error!("{}", message);
234 } else {
235 tracing::debug!("Received unsolicited server response: {}", msg);
236 }
237
238 Ok(false)
239 }
240
241 pub async fn handle_error<T, U>(
243 &self,
244 message: ClientError,
245 reconnect: Option<T>,
246 ) -> ClientResult<()>
247 where
248 T: Fn() -> U + Clone + Send + Sync + 'static,
249 U: Future<Output = ClientResult<()>>,
250 {
251 let subs = self.subscriptions_errors.read().await;
252 let tasks = join_all(subs.values().map(|callback| {
253 callback(
254 message.clone(),
255 reconnect.clone().map(move |f| {
256 Arc::new(move || {
257 clone!(f);
258 Box::pin(async move { Ok(f().await?) }) as LocalBoxFuture<'static, _>
259 }) as ReconnectCallback
260 }),
261 )
262 }));
263
264 tasks.await.into_iter().collect::<Result<(), _>>()?;
265 self.close_and_error_subscriptions(&message).await
266 }
267
268 async fn close_and_error_subscriptions(&self, message: &ClientError) -> ClientResult<()> {
273 let synthetic_error = |msg_id| Response {
274 msg_id,
275 entity_id: "".to_string(),
276 client_resp: Some(ClientResp::ServerError(ServerError {
277 message: format!("{}", message),
278 status_code: 2,
279 })),
280 };
281
282 self.subscriptions.write().await.clear();
283 let callbacks_once = self
284 .subscriptions_once
285 .write()
286 .await
287 .drain()
288 .collect::<Vec<_>>();
289
290 callbacks_once
291 .into_iter()
292 .try_for_each(|(msg_id, f)| f(synthetic_error(msg_id)))
293 }
294
295 pub async fn on_error<T, U, V>(&self, on_error: T) -> ClientResult<u32>
296 where
297 T: Fn(ClientError, Option<ReconnectCallback>) -> U + Clone + Send + Sync + 'static,
298 U: Future<Output = V> + Send + 'static,
299 V: Into<Result<(), ClientError>> + Sync + 'static,
300 {
301 let id = self.gen_id();
302 let callback = asyncfn!(on_error, async move |x, y| on_error(x, y).await.into());
303 self.subscriptions_errors
304 .write()
305 .await
306 .insert(id, Box::new(move |x, y| Box::pin(callback(x, y))));
307
308 Ok(id)
309 }
310
311 pub async fn init(&self) -> ClientResult<()> {
312 let msg = Request {
313 msg_id: self.gen_id(),
314 entity_id: "".to_owned(),
315 client_req: Some(ClientReq::GetFeaturesReq(GetFeaturesReq {})),
316 };
317
318 *self.features.lock().await = Some(Arc::new(match self.oneshot(&msg).await? {
319 ClientResp::GetFeaturesResp(features) => Ok(features),
320 resp => Err(resp),
321 }?));
322
323 Ok(())
324 }
325
326 pub(crate) fn gen_id(&self) -> u32 {
328 self.id_gen.next()
329 }
330
331 pub(crate) async fn unsubscribe(&self, update_id: u32) -> ClientResult<()> {
332 let callback = self
333 .subscriptions
334 .write()
335 .await
336 .remove(&update_id)
337 .ok_or(ClientError::Unknown("remove_update".to_string()))?;
338
339 drop(callback);
340 Ok(())
341 }
342
343 pub(crate) async fn subscribe_once(
345 &self,
346 msg: &Request,
347 on_update: Box<dyn FnOnce(Response) -> ClientResult<()> + Send + Sync + 'static>,
348 ) -> ClientResult<()> {
349 self.subscriptions_once
350 .write()
351 .await
352 .insert(msg.msg_id, on_update);
353
354 tracing::debug!("SEND {}", msg);
355 if let Err(e) = (self.send)(msg).await {
356 self.subscriptions_once.write().await.remove(&msg.msg_id);
357 Err(ClientError::Unknown(e.to_string()))
358 } else {
359 Ok(())
360 }
361 }
362
363 pub(crate) async fn subscribe<T, U>(&self, msg: &Request, on_update: T) -> ClientResult<()>
364 where
365 T: Fn(Response) -> U + Send + Sync + 'static,
366 U: Future<Output = Result<(), ClientError>> + Send + 'static,
367 {
368 self.subscriptions
369 .write()
370 .await
371 .insert(msg.msg_id, Box::new(move |x| Box::pin(on_update(x))));
372
373 tracing::debug!("SEND {}", msg);
374 if let Err(e) = (self.send)(msg).await {
375 self.subscriptions.write().await.remove(&msg.msg_id);
376 Err(ClientError::Unknown(e.to_string()))
377 } else {
378 Ok(())
379 }
380 }
381
382 pub(crate) async fn oneshot(&self, req: &Request) -> ClientResult<ClientResp> {
385 let (sender, receiver) = futures::channel::oneshot::channel::<ClientResp>();
386 let on_update = Box::new(move |res: Response| {
387 sender.send(res.client_resp.unwrap()).map_err(|x| x.into())
388 });
389
390 self.subscribe_once(req, on_update).await?;
391 receiver
392 .await
393 .map_err(|_| ClientError::Unknown(format!("Internal error for req {}", req)))
394 }
395
396 pub(crate) fn get_features(&self) -> ClientResult<Features> {
397 Ok(self
398 .features
399 .try_lock()
400 .ok_or(ClientError::NotInitialized)?
401 .as_ref()
402 .ok_or(ClientError::NotInitialized)?
403 .clone())
404 }
405
406 pub async fn table(&self, input: TableData, options: TableInitOptions) -> ClientResult<Table> {
458 let entity_id = match options.name.clone() {
459 Some(x) => x.to_owned(),
460 None => nanoid!(),
461 };
462
463 if let TableData::View(view) = &input {
464 let window = ViewWindow::default();
465 let arrow = view.to_arrow(window).await?;
466 let mut table = self
467 .crate_table_inner(UpdateData::Arrow(arrow).into(), options.into(), entity_id)
468 .await?;
469
470 let table_ = table.clone();
471 let callback = asyncfn!(
472 table_,
473 update,
474 async move |update: crate::proto::ViewOnUpdateResp| {
475 let update = UpdateData::Arrow(update.delta.expect("Malformed message").into());
476 let options = crate::UpdateOptions::default();
477 table_.update(update, options).await.unwrap_or_log();
478 }
479 );
480
481 let options = OnUpdateOptions {
482 mode: Some(OnUpdateMode::Row),
483 };
484
485 let on_update_token = view.on_update(callback, options).await?;
486 table.view_update_token = Some(on_update_token);
487 Ok(table)
488 } else {
489 self.crate_table_inner(input, options.into(), entity_id)
490 .await
491 }
492 }
493
494 async fn crate_table_inner(
495 &self,
496 input: TableData,
497 options: TableOptions,
498 entity_id: String,
499 ) -> ClientResult<Table> {
500 let msg = Request {
501 msg_id: self.gen_id(),
502 entity_id: entity_id.clone(),
503 client_req: Some(ClientReq::MakeTableReq(MakeTableReq {
504 data: Some(input.into()),
505 options: Some(options.clone().try_into()?),
506 })),
507 };
508
509 let client = self.clone();
510 match self.oneshot(&msg).await? {
511 ClientResp::MakeTableResp(_) => Ok(Table::new(entity_id, client, options)),
512 resp => Err(resp.into()),
513 }
514 }
515
516 async fn get_table_infos(&self) -> ClientResult<Vec<HostedTable>> {
517 let msg = Request {
518 msg_id: self.gen_id(),
519 entity_id: "".to_owned(),
520 client_req: Some(ClientReq::GetHostedTablesReq(GetHostedTablesReq {
521 subscribe: false,
522 })),
523 };
524
525 match self.oneshot(&msg).await? {
526 ClientResp::GetHostedTablesResp(GetHostedTablesResp { table_infos }) => Ok(table_infos),
527 resp => Err(resp.into()),
528 }
529 }
530
531 pub async fn open_table(&self, entity_id: String) -> ClientResult<Table> {
544 let infos = self.get_table_infos().await?;
545
546 if let Some(info) = infos.into_iter().find(|i| i.entity_id == entity_id) {
548 let options = TableOptions {
549 index: info.index,
550 limit: info.limit,
551 };
552
553 let client = self.clone();
554 Ok(Table::new(entity_id, client, options))
555 } else {
556 Err(ClientError::Unknown("Unknown table".to_owned()))
557 }
558 }
559
560 pub async fn get_hosted_table_names(&self) -> ClientResult<Vec<String>> {
573 let msg = Request {
574 msg_id: self.gen_id(),
575 entity_id: "".to_owned(),
576 client_req: Some(ClientReq::GetHostedTablesReq(GetHostedTablesReq {
577 subscribe: false,
578 })),
579 };
580
581 match self.oneshot(&msg).await? {
582 ClientResp::GetHostedTablesResp(GetHostedTablesResp { table_infos }) => {
583 Ok(table_infos.into_iter().map(|i| i.entity_id).collect())
584 },
585 resp => Err(resp.into()),
586 }
587 }
588
589 pub async fn on_hosted_tables_update<T, U>(&self, on_update: T) -> ClientResult<u32>
593 where
594 T: Fn() -> U + Send + Sync + 'static,
595 U: Future<Output = ()> + Send + 'static,
596 {
597 let on_update = Arc::new(on_update);
598 let callback = asyncfn!(on_update, async move |resp: Response| {
599 match resp.client_resp {
600 Some(ClientResp::GetHostedTablesResp(_)) | None => {
601 on_update().await;
602 Ok(())
603 },
604 resp => Err(resp.into()),
605 }
606 });
607
608 let msg = Request {
609 msg_id: self.gen_id(),
610 entity_id: "".to_owned(),
611 client_req: Some(ClientReq::GetHostedTablesReq(GetHostedTablesReq {
612 subscribe: true,
613 })),
614 };
615
616 self.subscribe(&msg, callback).await?;
617 Ok(msg.msg_id)
618 }
619
620 pub async fn remove_hosted_tables_update(&self, update_id: u32) -> ClientResult<()> {
623 let msg = Request {
624 msg_id: self.gen_id(),
625 entity_id: "".to_owned(),
626 client_req: Some(ClientReq::RemoveHostedTablesUpdateReq(
627 RemoveHostedTablesUpdateReq { id: update_id },
628 )),
629 };
630
631 self.unsubscribe(update_id).await?;
632 match self.oneshot(&msg).await? {
633 ClientResp::RemoveHostedTablesUpdateResp(_) => Ok(()),
634 resp => Err(resp.into()),
635 }
636 }
637
638 pub async fn system_info(&self) -> ClientResult<SystemInfo> {
642 let msg = Request {
643 msg_id: self.gen_id(),
644 entity_id: "".to_string(),
645 client_req: Some(ClientReq::ServerSystemInfoReq(ServerSystemInfoReq {})),
646 };
647
648 match self.oneshot(&msg).await? {
649 ClientResp::ServerSystemInfoResp(resp) => Ok(resp.into()),
650 resp => Err(resp.into()),
651 }
652 }
653}