1use std::collections::HashMap;
14use std::error::Error;
15use std::sync::Arc;
16
17use async_lock::{Mutex, RwLock};
18use futures::Future;
19use futures::future::{BoxFuture, LocalBoxFuture, join_all};
20use prost::Message;
21use serde::{Deserialize, Serialize};
22
23use crate::proto::request::ClientReq;
24use crate::proto::response::ClientResp;
25use crate::proto::{
26 self, ColumnType, GetFeaturesReq, GetFeaturesResp, GetHostedTablesReq, GetHostedTablesResp,
27 HostedTable, MakeTableReq, RemoveHostedTablesUpdateReq, Request, Response, ServerError,
28 ServerSystemInfoReq,
29};
30use crate::table::{Table, TableInitOptions, TableOptions};
31use crate::table_data::{TableData, UpdateData};
32use crate::utils::*;
33use crate::view::ViewWindow;
34use crate::{OnUpdateMode, OnUpdateOptions, asyncfn, clone};
35
/// Runtime information about the server, converted from a
/// [`proto::ServerSystemInfoResp`].
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct SystemInfo {
    /// Copied verbatim from `ServerSystemInfoResp::heap_size`.
    /// NOTE(review): the unit is not visible here (presumably bytes) — confirm.
    pub heap_size: f64,
}
41
42impl From<proto::ServerSystemInfoResp> for SystemInfo {
43 fn from(value: proto::ServerSystemInfoResp) -> Self {
44 SystemInfo {
45 heap_size: value.heap_size,
46 }
47 }
48}
49
/// Shared, reference-counted handle to a [`GetFeaturesResp`].
pub type Features = Arc<GetFeaturesResp>;
53
54impl GetFeaturesResp {
55 pub fn default_op(&self, col_type: ColumnType) -> Option<&str> {
56 self.filter_ops
57 .get(&(col_type as u32))?
58 .options
59 .first()
60 .map(|x| x.as_str())
61 }
62}
63
/// Boxed `Send + Sync` callback taking one argument.
type BoxFn<I, O> = Box<dyn Fn(I) -> O + Send + Sync + 'static>;
/// Boxed `Send + Sync` callback taking two arguments.
type Box2Fn<I, J, O> = Box<dyn Fn(I, J) -> O + Send + Sync + 'static>;

/// Shared, `RwLock`-guarded map from a `u32` id to a callback of type `C`.
type Subscriptions<C> = Arc<RwLock<HashMap<u32, C>>>;
/// Error handler: receives the error and an optional reconnect callback, and
/// returns a boxed future resolving to `Result<(), ClientError>`.
type OnErrorCallback =
    Box2Fn<ClientError, Option<ReconnectCallback>, BoxFuture<'static, Result<(), ClientError>>>;
/// Handler that is invoked at most once with a single [`Response`].
type OnceCallback = Box<dyn FnOnce(Response) -> ClientResult<()> + Send + Sync + 'static>;
/// Shared function that transmits a [`Request`]; the returned future borrows
/// the request for its own lifetime.
type SendCallback = Arc<
    dyn for<'a> Fn(&'a Request) -> BoxFuture<'a, Result<(), Box<dyn Error + Send + Sync>>>
        + Send
        + Sync
        + 'static,
>;
77
/// Transport abstraction accepted by [`Client::new`].
///
/// Implementors must be cloneable and usable across threads
/// (`Clone + Send + Sync + 'static`).
pub trait ClientHandler: Clone + Send + Sync + 'static {
    /// Sends one serialized message (`msg`) and resolves to `Ok(())` on
    /// success or a boxed error on failure.
    fn send_request(
        &self,
        msg: Vec<u8>,
    ) -> impl Future<Output = Result<(), Box<dyn Error + Send + Sync>>> + Send;
}
84
85mod name_registry {
86 use std::collections::HashSet;
87 use std::sync::{Arc, LazyLock, Mutex};
88
89 use crate::ClientError;
90 use crate::view::ClientResult;
91
92 static CLIENT_ID_GEN: LazyLock<Arc<Mutex<u32>>> = LazyLock::new(Arc::default);
93 static REGISTERED_CLIENTS: LazyLock<Arc<Mutex<HashSet<String>>>> = LazyLock::new(Arc::default);
94
95 pub(crate) fn generate_name(name: Option<&str>) -> ClientResult<String> {
96 if let Some(name) = name {
97 if let Some(name) = REGISTERED_CLIENTS
98 .lock()
99 .map_err(ClientError::from)?
100 .get(name)
101 {
102 Err(ClientError::DuplicateNameError(name.to_owned()))
103 } else {
104 Ok(name.to_owned())
105 }
106 } else {
107 let mut guard = CLIENT_ID_GEN.lock()?;
108 *guard += 1;
109 Ok(format!("client-{}", guard))
110 }
111 }
112}
113
/// Shared callback handed (optionally) to error handlers registered through
/// `Client::on_error`. Calling it yields a non-`Send` boxed future that
/// resolves to `Ok(())` or a boxed error.
pub type ReconnectCallback =
    Arc<dyn Fn() -> LocalBoxFuture<'static, Result<(), Box<dyn Error>>> + Send + Sync>;
122
/// Handle used to issue requests and to route incoming responses to the
/// callbacks registered for them.
///
/// Every field is behind an `Arc` (or is a shared handle), so cloning a
/// `Client` yields another handle to the same state.
#[derive(Clone)]
pub struct Client {
    /// Name produced by `name_registry::generate_name`; also the basis of `PartialEq`.
    name: Arc<String>,
    /// Feature set fetched by `init`; `None` until `init` has completed.
    features: Arc<Mutex<Option<Features>>>,
    /// Encodes and transmits an outgoing `Request`.
    send: SendCallback,
    /// Source of ids returned by `gen_id`.
    id_gen: IDGen,
    /// Error handlers registered via `on_error`, keyed by the id returned from it.
    subscriptions_errors: Subscriptions<OnErrorCallback>,
    /// One-shot response handlers keyed by request `msg_id`.
    subscriptions_once: Subscriptions<OnceCallback>,
    /// Persistent response handlers keyed by request `msg_id`.
    subscriptions: Subscriptions<BoxFn<Response, BoxFuture<'static, Result<(), ClientError>>>>,
}
149
150impl PartialEq for Client {
151 fn eq(&self, other: &Self) -> bool {
152 self.name == other.name
153 }
154}
155
156impl std::fmt::Debug for Client {
157 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
158 f.debug_struct("Client").finish()
159 }
160}
161
162impl Client {
163 pub fn new_with_callback<T, U>(name: Option<&str>, send_request: T) -> ClientResult<Self>
166 where
167 T: Fn(Vec<u8>) -> U + 'static + Sync + Send,
168 U: Future<Output = Result<(), Box<dyn Error + Send + Sync>>> + Send + 'static,
169 {
170 let name = name_registry::generate_name(name)?;
171 let send_request = Arc::new(send_request);
172 let send: SendCallback = Arc::new(move |req| {
173 let mut bytes: Vec<u8> = Vec::new();
174 req.encode(&mut bytes).unwrap();
175 let send_request = send_request.clone();
176 Box::pin(async move { send_request(bytes).await })
177 });
178
179 Ok(Client {
180 name: Arc::new(name),
181 features: Arc::default(),
182 id_gen: IDGen::default(),
183 send,
184 subscriptions: Subscriptions::default(),
185 subscriptions_errors: Arc::default(),
186 subscriptions_once: Arc::default(),
187 })
188 }
189
    /// Creates a `Client` that sends its messages through `client_handler`.
    ///
    /// Delegates to [`Client::new_with_callback`], forwarding each outgoing
    /// byte buffer to [`ClientHandler::send_request`].
    ///
    /// # Errors
    /// Same as [`Client::new_with_callback`].
    pub fn new<T>(name: Option<&str>, client_handler: T) -> ClientResult<Self>
    where
        T: ClientHandler + 'static + Sync + Send,
    {
        Self::new_with_callback(
            name,
            asyncfn!(client_handler, async move |req| {
                client_handler.send_request(req).await
            }),
        )
    }
202
203 pub fn get_name(&self) -> &'_ str {
204 self.name.as_str()
205 }
206
207 pub async fn handle_response<'a>(&'a self, msg: &'a [u8]) -> ClientResult<bool> {
214 let msg = Response::decode(msg)?;
215 tracing::debug!("RECV {}", msg);
216 let mut wr = self.subscriptions_once.write().await;
217 if let Some(handler) = (*wr).remove(&msg.msg_id) {
218 drop(wr);
219 handler(msg)?;
220 return Ok(true);
221 } else if let Some(handler) = self.subscriptions.try_read().unwrap().get(&msg.msg_id) {
222 drop(wr);
223 handler(msg).await?;
224 return Ok(true);
225 }
226
227 if let Response {
228 client_resp: Some(ClientResp::ServerError(ServerError { message, .. })),
229 ..
230 } = &msg
231 {
232 tracing::error!("{}", message);
233 } else {
234 tracing::debug!("Received unsolicited server response: {}", msg);
235 }
236
237 Ok(false)
238 }
239
240 pub async fn handle_error<T, U>(
242 &self,
243 message: ClientError,
244 reconnect: Option<T>,
245 ) -> ClientResult<()>
246 where
247 T: Fn() -> U + Clone + Send + Sync + 'static,
248 U: Future<Output = ClientResult<()>>,
249 {
250 let subs = self.subscriptions_errors.read().await;
251 let tasks = join_all(subs.values().map(|callback| {
252 callback(
253 message.clone(),
254 reconnect.clone().map(move |f| {
255 Arc::new(move || {
256 clone!(f);
257 Box::pin(async move { Ok(f().await?) }) as LocalBoxFuture<'static, _>
258 }) as ReconnectCallback
259 }),
260 )
261 }));
262
263 tasks.await.into_iter().collect::<Result<(), _>>()?;
264 self.close_and_error_subscriptions(&message).await
265 }
266
267 async fn close_and_error_subscriptions(&self, message: &ClientError) -> ClientResult<()> {
272 let synthetic_error = |msg_id| Response {
273 msg_id,
274 entity_id: "".to_string(),
275 client_resp: Some(ClientResp::ServerError(ServerError {
276 message: format!("{}", message),
277 status_code: 2,
278 })),
279 };
280
281 self.subscriptions.write().await.clear();
282 let callbacks_once = self
283 .subscriptions_once
284 .write()
285 .await
286 .drain()
287 .collect::<Vec<_>>();
288
289 callbacks_once
290 .into_iter()
291 .try_for_each(|(msg_id, f)| f(synthetic_error(msg_id)))
292 }
293
    /// Registers `on_error` as an error handler; registered handlers are
    /// invoked by `handle_error`.
    ///
    /// The handler receives the error and an optional [`ReconnectCallback`];
    /// its output is converted into `Result<(), ClientError>` via `Into`.
    ///
    /// # Returns
    /// The freshly generated id under which the handler was stored.
    pub async fn on_error<T, U, V>(&self, on_error: T) -> ClientResult<u32>
    where
        T: Fn(ClientError, Option<ReconnectCallback>) -> U + Clone + Send + Sync + 'static,
        U: Future<Output = V> + Send + 'static,
        V: Into<Result<(), ClientError>> + Sync + 'static,
    {
        let id = self.gen_id();
        let callback = asyncfn!(on_error, async move |x, y| on_error(x, y).await.into());
        self.subscriptions_errors
            .write()
            .await
            .insert(id, Box::new(move |x, y| Box::pin(callback(x, y))));

        Ok(id)
    }
309
310 pub async fn init(&self) -> ClientResult<()> {
311 let msg = Request {
312 msg_id: self.gen_id(),
313 entity_id: "".to_owned(),
314 client_req: Some(ClientReq::GetFeaturesReq(GetFeaturesReq {})),
315 };
316
317 *self.features.lock().await = Some(Arc::new(match self.oneshot(&msg).await? {
318 ClientResp::GetFeaturesResp(features) => Ok(features),
319 resp => Err(resp),
320 }?));
321
322 Ok(())
323 }
324
    /// Returns the next id from this client's `IDGen`. Within this file the
    /// ids are used as request `msg_id`s and as error-handler keys.
    pub(crate) fn gen_id(&self) -> u32 {
        self.id_gen.next()
    }
329
330 pub(crate) async fn unsubscribe(&self, update_id: u32) -> ClientResult<()> {
331 let callback = self
332 .subscriptions
333 .write()
334 .await
335 .remove(&update_id)
336 .ok_or(ClientError::Unknown("remove_update".to_string()))?;
337
338 drop(callback);
339 Ok(())
340 }
341
342 pub(crate) async fn subscribe_once(
344 &self,
345 msg: &Request,
346 on_update: Box<dyn FnOnce(Response) -> ClientResult<()> + Send + Sync + 'static>,
347 ) -> ClientResult<()> {
348 self.subscriptions_once
349 .write()
350 .await
351 .insert(msg.msg_id, on_update);
352
353 tracing::debug!("SEND {}", msg);
354 if let Err(e) = (self.send)(msg).await {
355 self.subscriptions_once.write().await.remove(&msg.msg_id);
356 Err(ClientError::Unknown(e.to_string()))
357 } else {
358 Ok(())
359 }
360 }
361
362 pub(crate) async fn subscribe<T, U>(&self, msg: &Request, on_update: T) -> ClientResult<()>
363 where
364 T: Fn(Response) -> U + Send + Sync + 'static,
365 U: Future<Output = Result<(), ClientError>> + Send + 'static,
366 {
367 self.subscriptions
368 .write()
369 .await
370 .insert(msg.msg_id, Box::new(move |x| Box::pin(on_update(x))));
371
372 tracing::debug!("SEND {}", msg);
373 if let Err(e) = (self.send)(msg).await {
374 self.subscriptions.write().await.remove(&msg.msg_id);
375 Err(ClientError::Unknown(e.to_string()))
376 } else {
377 Ok(())
378 }
379 }
380
381 pub(crate) async fn oneshot(&self, req: &Request) -> ClientResult<ClientResp> {
384 let (sender, receiver) = futures::channel::oneshot::channel::<ClientResp>();
385 let on_update = Box::new(move |res: Response| {
386 sender.send(res.client_resp.unwrap()).map_err(|x| x.into())
387 });
388
389 self.subscribe_once(req, on_update).await?;
390 receiver
391 .await
392 .map_err(|_| ClientError::Unknown(format!("Internal error for req {}", req)))
393 }
394
395 pub(crate) fn get_features(&self) -> ClientResult<Features> {
396 Ok(self
397 .features
398 .try_lock()
399 .ok_or(ClientError::NotInitialized)?
400 .as_ref()
401 .ok_or(ClientError::NotInitialized)?
402 .clone())
403 }
404
    /// Creates a new table on the server from `input`.
    ///
    /// The table's entity id is `options.name` when provided, otherwise a
    /// value from `randid()`.
    ///
    /// When `input` is a `TableData::View`, the view is first serialized with
    /// `to_arrow` over a default `ViewWindow`, a table is created from that
    /// Arrow data, and an `on_update` subscription (mode `Row`) is installed
    /// on the view which applies each delta to the new table. The resulting
    /// subscription token is stored in `table.view_update_token`.
    ///
    /// # Panics
    /// The update callback panics with "Malformed message" if a view update
    /// arrives without a `delta`.
    ///
    /// # Errors
    /// Propagates failures from `to_arrow`, from table creation and from
    /// `view.on_update`.
    pub async fn table(&self, input: TableData, options: TableInitOptions) -> ClientResult<Table> {
        let entity_id = match options.name.clone() {
            Some(x) => x.to_owned(),
            None => randid(),
        };

        if let TableData::View(view) = &input {
            let window = ViewWindow::default();
            let arrow = view.to_arrow(window).await?;
            let mut table = self
                .crate_table_inner(UpdateData::Arrow(arrow).into(), options.into(), entity_id)
                .await?;

            // The callback owns its own handle to the table.
            let table_ = table.clone();
            let callback = asyncfn!(
                table_,
                update,
                async move |update: crate::proto::ViewOnUpdateResp| {
                    let update = UpdateData::Arrow(update.delta.expect("Malformed message").into());
                    let options = crate::UpdateOptions::default();
                    // Update failures are passed to `unwrap_or_log` rather than propagated.
                    table_.update(update, options).await.unwrap_or_log();
                }
            );

            let options = OnUpdateOptions {
                mode: Some(OnUpdateMode::Row),
            };

            let on_update_token = view.on_update(callback, options).await?;
            table.view_update_token = Some(on_update_token);
            Ok(table)
        } else {
            self.crate_table_inner(input, options.into(), entity_id)
                .await
        }
    }
492
493 async fn crate_table_inner(
494 &self,
495 input: TableData,
496 options: TableOptions,
497 entity_id: String,
498 ) -> ClientResult<Table> {
499 let msg = Request {
500 msg_id: self.gen_id(),
501 entity_id: entity_id.clone(),
502 client_req: Some(ClientReq::MakeTableReq(MakeTableReq {
503 data: Some(input.into()),
504 options: Some(options.clone().try_into()?),
505 })),
506 };
507
508 let client = self.clone();
509 match self.oneshot(&msg).await? {
510 ClientResp::MakeTableResp(_) => Ok(Table::new(entity_id, client, options)),
511 resp => Err(resp.into()),
512 }
513 }
514
515 async fn get_table_infos(&self) -> ClientResult<Vec<HostedTable>> {
516 let msg = Request {
517 msg_id: self.gen_id(),
518 entity_id: "".to_owned(),
519 client_req: Some(ClientReq::GetHostedTablesReq(GetHostedTablesReq {
520 subscribe: false,
521 })),
522 };
523
524 match self.oneshot(&msg).await? {
525 ClientResp::GetHostedTablesResp(GetHostedTablesResp { table_infos }) => Ok(table_infos),
526 resp => Err(resp.into()),
527 }
528 }
529
530 pub async fn open_table(&self, entity_id: String) -> ClientResult<Table> {
543 let infos = self.get_table_infos().await?;
544
545 if let Some(info) = infos.into_iter().find(|i| i.entity_id == entity_id) {
547 let options = TableOptions {
548 index: info.index,
549 limit: info.limit,
550 };
551
552 let client = self.clone();
553 Ok(Table::new(entity_id, client, options))
554 } else {
555 Err(ClientError::Unknown("Unknown table".to_owned()))
556 }
557 }
558
559 pub async fn get_hosted_table_names(&self) -> ClientResult<Vec<String>> {
572 let msg = Request {
573 msg_id: self.gen_id(),
574 entity_id: "".to_owned(),
575 client_req: Some(ClientReq::GetHostedTablesReq(GetHostedTablesReq {
576 subscribe: false,
577 })),
578 };
579
580 match self.oneshot(&msg).await? {
581 ClientResp::GetHostedTablesResp(GetHostedTablesResp { table_infos }) => {
582 Ok(table_infos.into_iter().map(|i| i.entity_id).collect())
583 },
584 resp => Err(resp.into()),
585 }
586 }
587
    /// Subscribes `on_update` to hosted-table notifications from the server.
    ///
    /// A `GetHostedTablesReq` with `subscribe: true` is sent; `on_update` is
    /// then called (without arguments) for every response to that request
    /// which is either a `GetHostedTablesResp` or carries no payload. Any
    /// other payload is turned into an error returned from the handler.
    ///
    /// # Returns
    /// The `msg_id` of the subscription request, which is the id to pass to
    /// `remove_hosted_tables_update`.
    pub async fn on_hosted_tables_update<T, U>(&self, on_update: T) -> ClientResult<u32>
    where
        T: Fn() -> U + Send + Sync + 'static,
        U: Future<Output = ()> + Send + 'static,
    {
        // Wrapped in an `Arc` so it can be handed to `asyncfn!` as a capture.
        let on_update = Arc::new(on_update);
        let callback = asyncfn!(on_update, async move |resp: Response| {
            match resp.client_resp {
                Some(ClientResp::GetHostedTablesResp(_)) | None => {
                    on_update().await;
                    Ok(())
                },
                resp => Err(resp.into()),
            }
        });

        let msg = Request {
            msg_id: self.gen_id(),
            entity_id: "".to_owned(),
            client_req: Some(ClientReq::GetHostedTablesReq(GetHostedTablesReq {
                subscribe: true,
            })),
        };

        self.subscribe(&msg, callback).await?;
        Ok(msg.msg_id)
    }
618
619 pub async fn remove_hosted_tables_update(&self, update_id: u32) -> ClientResult<()> {
622 let msg = Request {
623 msg_id: self.gen_id(),
624 entity_id: "".to_owned(),
625 client_req: Some(ClientReq::RemoveHostedTablesUpdateReq(
626 RemoveHostedTablesUpdateReq { id: update_id },
627 )),
628 };
629
630 self.unsubscribe(update_id).await?;
631 match self.oneshot(&msg).await? {
632 ClientResp::RemoveHostedTablesUpdateResp(_) => Ok(()),
633 resp => Err(resp.into()),
634 }
635 }
636
637 pub async fn system_info(&self) -> ClientResult<SystemInfo> {
641 let msg = Request {
642 msg_id: self.gen_id(),
643 entity_id: "".to_string(),
644 client_req: Some(ClientReq::ServerSystemInfoReq(ServerSystemInfoReq {})),
645 };
646
647 match self.oneshot(&msg).await? {
648 ClientResp::ServerSystemInfoResp(resp) => Ok(resp.into()),
649 resp => Err(resp.into()),
650 }
651 }
652}