1use std::collections::HashMap;
14use std::error::Error;
15use std::sync::atomic::AtomicU32;
16use std::sync::Arc;
17
18use async_lock::{Mutex, RwLock};
19use futures::future::{join_all, BoxFuture, LocalBoxFuture};
20use futures::Future;
21use nanoid::*;
22use prost::Message;
23use serde::{Deserialize, Serialize};
24
25use crate::proto::request::ClientReq;
26use crate::proto::response::ClientResp;
27use crate::proto::{
28 self, ColumnType, GetFeaturesReq, GetFeaturesResp, GetHostedTablesReq, GetHostedTablesResp,
29 HostedTable, MakeTableReq, Request, Response, ServerSystemInfoReq,
30};
31use crate::table::{Table, TableInitOptions, TableOptions};
32use crate::table_data::{TableData, UpdateData};
33use crate::utils::*;
34use crate::view::ViewWindow;
35
/// Server metrics, built from a [`proto::ServerSystemInfoResp`] via the
/// `From` conversion defined alongside this type.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct SystemInfo {
    /// Value of the `heap_size` field of the server's response.
    // NOTE(review): the unit (bytes?) is not visible from this file — confirm
    // against the server implementation.
    pub heap_size: f64,
}
41
42impl From<proto::ServerSystemInfoResp> for SystemInfo {
43 fn from(value: proto::ServerSystemInfoResp) -> Self {
44 SystemInfo {
45 heap_size: value.heap_size,
46 }
47 }
48}
49
/// Shared, reference-counted handle to a [`GetFeaturesResp`].
pub type Features = Arc<GetFeaturesResp>;
53
54impl GetFeaturesResp {
55 pub fn default_op(&self, col_type: ColumnType) -> Option<&str> {
56 self.filter_ops
57 .get(&(col_type as u32))?
58 .options
59 .first()
60 .map(|x| x.as_str())
61 }
62}
63
/// Boxed `Send + Sync` callback taking one argument `I` and returning `O`.
type BoxFn<I, O> = Box<dyn Fn(I) -> O + Send + Sync + 'static>;
/// Boxed `Send + Sync` callback taking two arguments `I`, `J` and returning `O`.
type Box2Fn<I, J, O> = Box<dyn Fn(I, J) -> O + Send + Sync + 'static>;

/// Shared map from a `u32` id to a registered callback of type `C`, guarded
/// by an async `RwLock`.
type Subscriptions<C> = Arc<RwLock<HashMap<u32, C>>>;
/// Error callback: receives an optional message and an optional
/// [`ReconnectCallback`], and returns a boxed future resolving to a
/// `Result<(), ClientError>`.
type OnErrorCallback =
    Box2Fn<Option<String>, Option<ReconnectCallback>, BoxFuture<'static, Result<(), ClientError>>>;
/// Single-use, synchronous response handler (`FnOnce`).
type OnceCallback = Box<dyn FnOnce(Response) -> ClientResult<()> + Send + Sync + 'static>;
/// Shared function that takes a borrowed [`Request`] and returns a future
/// (borrowing the request for `'a`) that resolves once the send completes or
/// fails.
type SendCallback = Arc<
    dyn for<'a> Fn(&'a Request) -> BoxFuture<'a, Result<(), Box<dyn Error + Send + Sync>>>
        + Send
        + Sync
        + 'static,
>;
77
/// Transport abstraction used by [`Client::new`]: an implementor receives the
/// already-encoded bytes of a request and is responsible for delivering them.
///
/// Implementors must be `Clone + Send + Sync + 'static`; [`Client::new`]
/// clones the handler for every outgoing message.
pub trait ClientHandler: Clone + Send + Sync + 'static {
    /// Sends the encoded request bytes `msg`.
    ///
    /// The returned future must be `Send` and resolves to `Ok(())` on success
    /// or a boxed error on failure.
    fn send_request<'a>(
        &'a self,
        msg: &'a [u8],
    ) -> impl Future<Output = Result<(), Box<dyn Error + Send + Sync>>> + Send;
}
84
#[derive(Clone)]
#[doc = include_str!("../../docs/client.md")]
pub struct Client {
    // All fields are behind `Arc`, so cloning a `Client` yields another handle
    // onto the same shared state.

    // Cached server features; `None` until `init` stores a value, read by
    // `get_features`.
    features: Arc<Mutex<Option<Features>>>,
    // Encodes a `Request` and hands the bytes to the user-supplied transport.
    send: SendCallback,
    // Counter used by `gen_id` to allocate message / subscription ids.
    id_gen: Arc<AtomicU32>,
    // Callbacks registered through `on_error`, invoked by `handle_error`.
    subscriptions_errors: Subscriptions<OnErrorCallback>,
    // Handlers keyed by `msg_id` that are removed when the first matching
    // response arrives (see `handle_response`).
    subscriptions_once: Subscriptions<OnceCallback>,
    // Handlers keyed by `msg_id` that stay registered until `unsubscribe`.
    subscriptions: Subscriptions<BoxFn<Response, BoxFuture<'static, Result<(), ClientError>>>>,
}
95
96impl std::fmt::Debug for Client {
97 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
98 f.debug_struct("Client")
99 .field("id_gen", &self.id_gen)
100 .finish()
101 }
102}
103
/// Shared callback handed to error handlers (see `Client::handle_error`).
///
/// Calling it returns a `LocalBoxFuture` (i.e. the future itself is not
/// required to be `Send`) resolving to `Ok(())` or a boxed error.
// NOTE(review): presumably invoking it re-establishes the connection — the
// implementation is supplied by the caller and is not visible here.
pub type ReconnectCallback =
    Arc<dyn Fn() -> LocalBoxFuture<'static, Result<(), Box<dyn Error>>> + Send + Sync>;
112
113impl Client {
114 pub fn new_with_callback<T>(send_request: T) -> Self
117 where
118 T: for<'a> Fn(&'a [u8]) -> BoxFuture<'a, Result<(), Box<dyn Error + Send + Sync>>>
119 + 'static
120 + Sync
121 + Send,
122 {
123 let send_request = Arc::new(send_request);
124 let send: SendCallback = Arc::new(move |req| {
125 let mut bytes: Vec<u8> = Vec::new();
126 req.encode(&mut bytes).unwrap();
127 let send_request = send_request.clone();
128 Box::pin(async move { send_request(&bytes).await })
129 });
130
131 Client {
132 features: Arc::default(),
133 id_gen: Arc::new(AtomicU32::new(1)),
134 send,
135 subscriptions: Subscriptions::default(),
136 subscriptions_errors: Arc::default(),
137 subscriptions_once: Arc::default(),
138 }
139 }
140
141 pub fn new<T>(client_handler: T) -> Self
143 where
144 T: ClientHandler + 'static + Sync + Send,
145 {
146 Self::new_with_callback(move |req| {
147 let client_handler = client_handler.clone();
148 Box::pin(async move { client_handler.send_request(req).await })
149 })
150 }
151
152 pub async fn handle_response<'a>(&'a self, msg: &'a [u8]) -> ClientResult<bool> {
159 let msg = Response::decode(msg)?;
160 tracing::debug!("RECV {}", msg);
161 let mut wr = self.subscriptions_once.write().await;
162 if let Some(handler) = (*wr).remove(&msg.msg_id) {
163 drop(wr);
164 handler(msg)?;
165 return Ok(true);
166 } else if let Some(handler) = self.subscriptions.try_read().unwrap().get(&msg.msg_id) {
167 drop(wr);
168 handler(msg).await?;
169 return Ok(true);
170 }
171
172 tracing::warn!("Received unsolicited server message");
173 Ok(false)
174 }
175
176 pub async fn handle_error(
177 &self,
178 message: Option<String>,
179 reconnect: Option<ReconnectCallback>,
180 ) -> ClientResult<()> {
181 let subs = self.subscriptions_errors.read().await;
182 let tasks = join_all(
183 subs.values()
184 .map(|callback| callback(message.clone(), reconnect.clone())),
185 );
186 tasks.await.into_iter().collect::<Result<(), _>>()?;
187 Ok(())
188 }
189
190 pub async fn on_error(&self, on_error: OnErrorCallback) -> ClientResult<u32> {
191 let id = self.gen_id();
192 self.subscriptions_errors.write().await.insert(id, on_error);
193 Ok(id)
194 }
195
196 pub async fn init(&self) -> ClientResult<()> {
197 let msg = Request {
198 msg_id: self.gen_id(),
199 entity_id: "".to_owned(),
200 client_req: Some(ClientReq::GetFeaturesReq(GetFeaturesReq {})),
201 };
202
203 *self.features.lock().await = Some(Arc::new(match self.oneshot(&msg).await? {
204 ClientResp::GetFeaturesResp(features) => Ok(features),
205 resp => Err(resp),
206 }?));
207
208 Ok(())
209 }
210
211 pub(crate) fn gen_id(&self) -> u32 {
213 self.id_gen
214 .fetch_add(1, std::sync::atomic::Ordering::Acquire)
215 }
216
217 pub(crate) async fn unsubscribe(&self, update_id: u32) -> ClientResult<()> {
218 let callback = self
219 .subscriptions
220 .write()
221 .await
222 .remove(&update_id)
223 .ok_or(ClientError::Unknown("remove_update".to_string()))?;
224
225 drop(callback);
226 Ok(())
227 }
228
229 pub(crate) async fn subscribe_once(
231 &self,
232 msg: &Request,
233 on_update: Box<dyn FnOnce(Response) -> ClientResult<()> + Send + Sync + 'static>,
234 ) -> ClientResult<()> {
235 self.subscriptions_once
236 .write()
237 .await
238 .insert(msg.msg_id, on_update);
239
240 tracing::debug!("SEND {}", msg);
241 if let Err(e) = (self.send)(msg).await {
242 self.subscriptions_once.write().await.remove(&msg.msg_id);
243 Err(ClientError::Unknown(e.to_string()))
244 } else {
245 Ok(())
246 }
247 }
248
249 pub(crate) async fn subscribe(
250 &self,
251 msg: &Request,
252 on_update: BoxFn<Response, BoxFuture<'static, Result<(), ClientError>>>,
253 ) -> ClientResult<()> {
254 self.subscriptions
255 .write()
256 .await
257 .insert(msg.msg_id, on_update);
258 tracing::debug!("SEND {}", msg);
259 if let Err(e) = (self.send)(msg).await {
260 self.subscriptions.write().await.remove(&msg.msg_id);
261 Err(ClientError::Unknown(e.to_string()))
262 } else {
263 Ok(())
264 }
265 }
266
267 pub(crate) async fn oneshot(&self, msg: &Request) -> ClientResult<ClientResp> {
270 let (sender, receiver) = futures::channel::oneshot::channel::<ClientResp>();
271 let on_update = Box::new(move |msg: Response| {
272 sender.send(msg.client_resp.unwrap()).map_err(|x| x.into())
273 });
274
275 self.subscribe_once(msg, on_update).await?;
276 receiver
277 .await
278 .map_err(|_| ClientError::Unknown("Internal error".to_owned()))
279 }
280
281 pub(crate) fn get_features(&self) -> ClientResult<Features> {
282 Ok(self
283 .features
284 .try_lock()
285 .ok_or(ClientError::NotInitialized)?
286 .as_ref()
287 .ok_or(ClientError::NotInitialized)?
288 .clone())
289 }
290
291 #[doc = include_str!("../../docs/client/table.md")]
292 pub async fn table(&self, input: TableData, options: TableInitOptions) -> ClientResult<Table> {
293 let entity_id = match options.name.clone() {
294 Some(x) => x.to_owned(),
295 None => nanoid!(),
296 };
297
298 if let TableData::View(view) = &input {
299 let window = ViewWindow::default();
300 let arrow = view.to_arrow(window).await?;
301 let mut table = self
302 .crate_table_inner(UpdateData::Arrow(arrow).into(), options.into(), entity_id)
303 .await?;
304
305 let callback = {
306 let table = table.clone();
307 move |update: crate::proto::ViewOnUpdateResp| {
308 let table = table.clone();
309 let update = update.delta.expect("Missing update");
310 async move {
311 table
312 .update(
313 UpdateData::Arrow(update.into()),
314 crate::UpdateOptions::default(),
315 )
316 .await
317 .unwrap_or_log();
318 }
319 }
320 };
321
322 let on_update_token = view
323 .on_update(callback, crate::view::OnUpdateOptions {
324 mode: Some(crate::view::OnUpdateMode::Row),
325 })
326 .await?;
327
328 table.view_update_token = Some(on_update_token);
329 Ok(table)
330 } else {
331 self.crate_table_inner(input, options.into(), entity_id)
332 .await
333 }
334 }
335
336 async fn crate_table_inner(
337 &self,
338 input: TableData,
339 options: TableOptions,
340 entity_id: String,
341 ) -> ClientResult<Table> {
342 let msg = Request {
343 msg_id: self.gen_id(),
344 entity_id: entity_id.clone(),
345 client_req: Some(ClientReq::MakeTableReq(MakeTableReq {
346 data: Some(input.into()),
347 options: Some(options.clone().try_into()?),
348 })),
349 };
350
351 let client = self.clone();
352 match self.oneshot(&msg).await? {
353 ClientResp::MakeTableResp(_) => Ok(Table::new(entity_id, client, options)),
354 resp => Err(resp.into()),
355 }
356 }
357
358 async fn get_table_infos(&self) -> ClientResult<Vec<HostedTable>> {
359 let msg = Request {
360 msg_id: self.gen_id(),
361 entity_id: "".to_owned(),
362 client_req: Some(ClientReq::GetHostedTablesReq(GetHostedTablesReq {})),
363 };
364
365 match self.oneshot(&msg).await? {
366 ClientResp::GetHostedTablesResp(GetHostedTablesResp { table_infos }) => Ok(table_infos),
367 resp => Err(resp.into()),
368 }
369 }
370
371 #[doc = include_str!("../../docs/client/open_table.md")]
372 pub async fn open_table(&self, entity_id: String) -> ClientResult<Table> {
373 let infos = self.get_table_infos().await?;
374
375 if let Some(info) = infos.into_iter().find(|i| i.entity_id == entity_id) {
377 let options = TableOptions {
378 index: info.index,
379 limit: info.limit,
380 };
381
382 let client = self.clone();
383 Ok(Table::new(entity_id, client, options))
384 } else {
385 Err(ClientError::Unknown("Unknown table".to_owned()))
386 }
387 }
388
389 #[doc = include_str!("../../docs/client/get_hosted_table_names.md")]
390 pub async fn get_hosted_table_names(&self) -> ClientResult<Vec<String>> {
391 let msg = Request {
392 msg_id: self.gen_id(),
393 entity_id: "".to_owned(),
394 client_req: Some(ClientReq::GetHostedTablesReq(GetHostedTablesReq {})),
395 };
396
397 match self.oneshot(&msg).await? {
398 ClientResp::GetHostedTablesResp(GetHostedTablesResp { table_infos }) => {
399 Ok(table_infos.into_iter().map(|i| i.entity_id).collect())
400 },
401 resp => Err(resp.into()),
402 }
403 }
404
405 #[doc = include_str!("../../docs/client/system_info.md")]
406 pub async fn system_info(&self) -> ClientResult<SystemInfo> {
407 let msg = Request {
408 msg_id: self.gen_id(),
409 entity_id: "".to_string(),
410 client_req: Some(ClientReq::ServerSystemInfoReq(ServerSystemInfoReq {})),
411 };
412
413 match self.oneshot(&msg).await? {
414 ClientResp::ServerSystemInfoResp(resp) => Ok(resp.into()),
415 resp => Err(resp.into()),
416 }
417 }
418}