1use std::{
2 collections::HashMap,
3 pin::Pin,
4 sync::{Arc, Mutex, RwLock, Weak},
5 task::{Context, Poll},
6 time::Duration,
7};
8
9use async_trait::async_trait;
10use exocore_core::{
11 cell::{Cell, CellNodeRole, Node, NodeId},
12 framing::CapnpFrameBuilder,
13 futures::interval,
14 time::{Clock, ConsistentTimestamp, Instant},
15 utils::handle_set::{Handle, HandleSet},
16};
17use exocore_protos::generated::{
18 exocore_store::{EntityQuery, EntityResults, MutationRequest, MutationResult},
19 store_transport_capnp::{
20 mutation_response, query_response, unwatch_query_request, watched_query_response,
21 },
22 MessageType,
23};
24use exocore_transport::{
25 transport::ConnectionStatus, InEvent, InMessage, OutEvent, OutMessage, ServiceType,
26 TransportServiceHandle,
27};
28use futures::{
29 channel::{mpsc, oneshot},
30 prelude::*,
31};
32
33use super::seri::{
34 mutation_result_from_response_frame, mutation_to_request_frame,
35 query_results_from_response_frame, query_to_request_frame, watched_query_to_request_frame,
36};
37use crate::{error::Error, mutation::MutationRequestLike, query::WatchToken};
38
39pub struct Client<T>
42where
43 T: TransportServiceHandle,
44{
45 config: ClientConfiguration,
46 inner: Arc<RwLock<Inner>>,
47 transport_handle: T,
48 handles: HandleSet,
49}
50
51impl<T> Client<T>
52where
53 T: TransportServiceHandle,
54{
55 pub fn new(
56 config: ClientConfiguration,
57 cell: Cell,
58 clock: Clock,
59 transport_handle: T,
60 ) -> Result<Client<T>, Error> {
61 let inner = Arc::new(RwLock::new(Inner {
62 config,
63 cell,
64 clock,
65 transport_out: None,
66 store_node: None,
67 store_node_message_queue: Mutex::new(Vec::new()),
68 nodes_status: HashMap::new(),
69 pending_queries: HashMap::new(),
70 watched_queries: HashMap::new(),
71 pending_mutations: HashMap::new(),
72 }));
73
74 Ok(Client {
75 config,
76 inner,
77 transport_handle,
78 handles: HandleSet::new(),
79 })
80 }
81
82 pub fn get_handle(&self) -> ClientHandle {
83 ClientHandle {
84 inner: Arc::downgrade(&self.inner),
85 handle: self.handles.get_handle(),
86 }
87 }
88
89 pub async fn run(mut self) -> Result<(), Error> {
90 let out_receiver = {
93 let mut inner = self.inner.write()?;
94 let (out_sender, out_receiver) = mpsc::unbounded();
95 inner.transport_out = Some(out_sender);
96 out_receiver
97 };
98
99 let mut transport_sink = self.transport_handle.get_sink();
101 let transport_sender = async move {
102 let mut receiver = out_receiver;
103
104 while let Some(item) = receiver.next().await {
105 transport_sink.send(item).await?;
106 }
107
108 Ok::<(), Error>(())
109 };
110
111 let weak_inner = Arc::downgrade(&self.inner);
113 let mut transport_stream = self.transport_handle.get_stream();
114 let transport_receiver = async move {
115 while let Some(event) = transport_stream.next().await {
116 let res = match event {
117 InEvent::Message(msg) => Inner::handle_incoming_message(&weak_inner, msg),
118 InEvent::NodeStatus(node, status) => {
119 Inner::handle_node_status_change(&weak_inner, node, status)
120 }
121 };
122
123 if let Err(err) = res {
124 if err.is_fatal() {
125 return Err(err);
126 } else {
127 error!("Couldn't process incoming transport message: {}", err);
128 }
129 }
130 }
131
132 Ok::<(), Error>(())
133 };
134
135 let weak_inner = Arc::downgrade(&self.inner);
137 let management_interval = self.config.management_interval;
138 let management_timer = async move {
139 let mut timer = interval(management_interval);
140
141 loop {
142 timer.tick().await;
143 Inner::management_timer_process(&weak_inner)?;
144 }
145
146 #[allow(unreachable_code)]
148 Ok::<(), Error>(())
149 };
150
151 futures::select! {
152 _ = transport_sender.fuse() => {},
153 _ = transport_receiver.fuse() => {},
154 _ = management_timer.fuse() => {},
155 _ = self.transport_handle.fuse() => {},
156 _ = self.handles.on_handles_dropped().fuse() => {},
157 };
158
159 info!("Store client dropped");
160 Ok(())
161 }
162}
163
/// Tuning parameters of the remote store [`Client`].
#[derive(Debug, Clone, Copy)]
pub struct ClientConfiguration {
    /// Time after which a pending query is failed with a timeout error.
    pub query_timeout: Duration,
    /// Time after which a pending mutation is failed with a timeout error.
    pub mutation_timeout: Duration,
    /// Period of the management timer (timeouts check and watched queries keep-alive).
    pub management_interval: Duration,
    /// Age after which a watched query registration is sent again to the store node.
    pub watched_register_interval: Duration,
    /// Capacity of each watched query's results channel; results are dropped when it is full.
    pub watched_channel_size: usize,
    /// Whether a watched query that the remote reports as "unregistered" is registered again.
    pub watched_re_register_remote_dropped: bool,
}

impl Default for ClientConfiguration {
    /// Returns the default client tuning: 10s query timeout, 5s mutation
    /// timeout, 1s management interval, 10s watched query re-registration,
    /// 1000-entry watched results channel, and re-registration of queries
    /// dropped by the remote.
    fn default() -> Self {
        Self {
            management_interval: Duration::from_secs(1),
            mutation_timeout: Duration::from_secs(5),
            query_timeout: Duration::from_secs(10),
            watched_register_interval: Duration::from_secs(10),
            watched_re_register_remote_dropped: true,
            watched_channel_size: 1_000,
        }
    }
}
187
/// State shared between the [`Client`] run loop and its [`ClientHandle`]s;
/// the client holds it as an `Arc<RwLock<Inner>>` and handles hold weak references.
pub(super) struct Inner {
    /// Copy of the client configuration.
    config: ClientConfiguration,
    /// Cell of the client; searched for connected nodes having the store role.
    cell: Cell,
    /// Clock used to generate request ids and watch tokens.
    clock: Clock,
    /// Sender toward the transport sink; `None` until `Client::run` sets it.
    transport_out: Option<mpsc::UnboundedSender<OutEvent>>,
    /// Node messages are sent to; `None` until a connected store node is found.
    store_node: Option<Node>,
    /// Messages waiting for a store node to be selected.
    store_node_message_queue: Mutex<Vec<OutMessage>>,
    /// Last known connection status of each node, as reported by the transport.
    nodes_status: HashMap<NodeId, ConnectionStatus>,
    /// Queries waiting for a response, keyed by request (rendez-vous) id.
    pending_queries: HashMap<ConsistentTimestamp, PendingRequest<EntityResults>>,
    /// Watched queries, keyed by the request id used for their registration.
    watched_queries: HashMap<ConsistentTimestamp, WatchedQueryRequest>,
    /// Mutations waiting for a response, keyed by request (rendez-vous) id.
    pending_mutations: HashMap<ConsistentTimestamp, PendingRequest<MutationResult>>,
}
200
201impl Inner {
202 fn handle_node_status_change(
203 weak_inner: &Weak<RwLock<Inner>>,
204 node_id: NodeId,
205 node_new_status: ConnectionStatus,
206 ) -> Result<(), Error> {
207 let inner = weak_inner.upgrade().ok_or(Error::Dropped)?;
208 let mut inner = inner.write()?;
209
210 inner.nodes_status.insert(node_id, node_new_status);
211
212 let was_already_connected = inner.store_node.is_some();
213
214 let node_is_connected = |node_id: &NodeId| -> bool {
215 let store_node_status = inner.nodes_status.get(node_id);
216 store_node_status == Some(&ConnectionStatus::Connected)
217 };
218
219 if let Some(store_node) = &inner.store_node {
222 if node_is_connected(store_node.id()) {
223 if node_new_status == ConnectionStatus::Connected {
226 inner.send_watched_queries_keepalive(true);
227 }
228
229 return Ok(());
230 }
231 }
232
233 let new_store_node = {
235 let cell_nodes = inner.cell.nodes();
236 let cell_nodes_iter = cell_nodes.iter();
237
238 let store_node = cell_nodes_iter
239 .with_role(CellNodeRole::Store)
240 .find(|n| node_is_connected(n.node().id()));
241
242 store_node.map(|n| n.node()).cloned()
243 };
244 if let Some(new_store_node) = new_store_node {
245 info!("Switching store server to {}", new_store_node);
246 inner.store_node = Some(new_store_node);
247 }
248
249 if !was_already_connected {
250 inner.drain_store_node_message_queue()?;
251 }
252
253 inner.send_watched_queries_keepalive(true);
254
255 Ok(())
256 }
257
258 fn handle_incoming_message(
259 weak_inner: &Weak<RwLock<Inner>>,
260 in_message: InMessage,
261 ) -> Result<(), Error> {
262 let inner = weak_inner.upgrade().ok_or(Error::Dropped)?;
263 let mut inner = inner.write()?;
264
265 if let Some(store_node) = &inner.store_node {
266 if in_message.source.id() != store_node.id() {
267 warn!("Got message from a node other than store node (from {} != current {}). Dropping it.", in_message.source, store_node);
268 return Ok(());
269 }
270 }
271
272 let Some(rendez_vous_id) = in_message.rendez_vous_id else {
273 return Err(anyhow!(
274 "Got an InMessage without a rendez_vous_id (type={:?} from={})",
275 in_message.typ,
276 in_message.source
277 )
278 .into());
279 };
280
281 match IncomingMessage::parse_incoming_message(&in_message) {
282 Ok(IncomingMessage::MutationResponse(mutation)) => {
283 if let Some(pending_request) = inner.pending_mutations.remove(&rendez_vous_id) {
284 let _ = pending_request.result_sender.send(Ok(mutation));
285 } else {
286 return Err(anyhow!(
287 "Couldn't find pending mutation for mutation response (request_id={:?} type={:?} from={})",
288 rendez_vous_id, in_message.typ, in_message.source
289 ).into());
290 }
291 }
292 Ok(IncomingMessage::QueryResponse(result)) => {
293 if let Some(pending_request) = inner.pending_queries.remove(&rendez_vous_id) {
294 let _ = pending_request.result_sender.send(Ok(result));
295 } else if let Some(watched_query) = inner.watched_queries.get_mut(&rendez_vous_id) {
296 let _ = watched_query.result_sender.try_send(Ok(result));
297 } else {
298 return Err(anyhow!(
299 "Couldn't find pending query for query response (request_id={:?} type={:?} from={})",
300 rendez_vous_id, in_message.typ, in_message.source
301 ).into());
302 }
303 }
304 Err(Error::Remote(err))
305 if inner.config.watched_re_register_remote_dropped
306 && err.contains("unregistered") =>
307 {
308 if let Some(watched_query) = inner.watched_queries.get_mut(&rendez_vous_id) {
309 debug!("Query got unregistered by remote. Re-registering...");
310 watched_query.force_register();
311 }
312 }
313 Err(err) => {
314 if let Some(pending_request) = inner.pending_mutations.remove(&rendez_vous_id) {
315 let _ = pending_request.result_sender.send(Err(err));
316 } else if let Some(mut watched_query) =
317 inner.watched_queries.remove(&rendez_vous_id)
318 {
319 let _ = watched_query.result_sender.try_send(Err(err));
320 } else if let Some(pending_request) = inner.pending_queries.remove(&rendez_vous_id)
321 {
322 let _ = pending_request.result_sender.send(Err(err));
323 }
324 }
325 }
326
327 Ok(())
328 }
329
330 fn management_timer_process(weak_inner: &Weak<RwLock<Inner>>) -> Result<(), Error> {
331 let inner = weak_inner.upgrade().ok_or(Error::Dropped)?;
332 let mut inner = inner.write()?;
333
334 let query_timeout = inner.config.query_timeout;
335 Inner::check_map_requests_timeouts(&mut inner.pending_queries, query_timeout);
336
337 let mutation_timeout = inner.config.mutation_timeout;
338 Inner::check_map_requests_timeouts(&mut inner.pending_mutations, mutation_timeout);
339
340 inner.send_watched_queries_keepalive(false);
341
342 Ok(())
343 }
344
345 fn send_mutation(
346 &mut self,
347 request: MutationRequest,
348 ) -> Result<oneshot::Receiver<Result<MutationResult, Error>>, Error> {
349 let (result_sender, receiver) = oneshot::channel();
350
351 let request_id = self.clock.consistent_time(self.cell.local_node());
352 let request_frame = mutation_to_request_frame(request)?;
353 let message =
354 OutMessage::from_framed_message(&self.cell, ServiceType::Store, request_frame)?
355 .with_expiration(Some(Instant::now() + self.config.mutation_timeout))
356 .with_rdv(request_id);
357 self.send_store_node_message(message)?;
358
359 self.pending_mutations.insert(
360 request_id,
361 PendingRequest {
362 request_id,
363 result_sender,
364 send_time: Instant::now(),
365 },
366 );
367
368 Ok(receiver)
369 }
370
371 fn send_query(
372 &mut self,
373 query: EntityQuery,
374 ) -> Result<oneshot::Receiver<Result<EntityResults, Error>>, Error> {
375 let (result_sender, receiver) = oneshot::channel();
376
377 let request_id = self.clock.consistent_time(self.cell.local_node());
378 let request_frame = query_to_request_frame(&query)?;
379 let message =
380 OutMessage::from_framed_message(&self.cell, ServiceType::Store, request_frame)?
381 .with_expiration(Some(Instant::now() + self.config.query_timeout))
382 .with_rdv(request_id);
383 self.send_store_node_message(message)?;
384
385 self.pending_queries.insert(
386 request_id,
387 PendingRequest {
388 request_id,
389 result_sender,
390 send_time: Instant::now(),
391 },
392 );
393
394 Ok(receiver)
395 }
396
397 fn watch_query(
398 &mut self,
399 query: EntityQuery,
400 ) -> Result<
401 (
402 ConsistentTimestamp,
403 mpsc::Receiver<Result<EntityResults, Error>>,
404 ),
405 Error,
406 > {
407 let (result_sender, result_receiver) = mpsc::channel(self.config.watched_channel_size);
408 let request_id = self.clock.consistent_time(self.cell.local_node());
409 let watched_query = WatchedQueryRequest {
410 request_id,
411 result_sender,
412 query,
413 last_register: Some(Instant::now()),
414 };
415
416 self.send_watch_query(&watched_query)?;
417 self.watched_queries.insert(request_id, watched_query);
418
419 Ok((request_id, result_receiver))
420 }
421
422 fn send_watch_query(&self, watched_query: &WatchedQueryRequest) -> Result<(), Error> {
423 let request_frame = watched_query_to_request_frame(&watched_query.query)?;
424 let message =
425 OutMessage::from_framed_message(&self.cell, ServiceType::Store, request_frame)?
426 .with_rdv(watched_query.request_id);
427
428 self.send_store_node_message(message)
429 }
430
431 fn send_unwatch_query(&self, token: WatchToken) -> Result<(), Error> {
432 let mut frame_builder = CapnpFrameBuilder::<unwatch_query_request::Owned>::new();
433 let mut message_builder = frame_builder.get_builder();
434 message_builder.set_token(token);
435
436 let message =
437 OutMessage::from_framed_message(&self.cell, ServiceType::Store, frame_builder)?;
438
439 self.send_store_node_message(message)
440 }
441
442 fn check_map_requests_timeouts<T>(
443 requests: &mut HashMap<ConsistentTimestamp, PendingRequest<T>>,
444 timeout: Duration,
445 ) {
446 let mut timed_out_requests = Vec::new();
447 for request in requests.values() {
448 if request.send_time.elapsed() > timeout {
449 timed_out_requests.push(request.request_id);
450 }
451 }
452
453 for request_id in timed_out_requests {
454 if let Some(request) = requests.remove(&request_id) {
455 let _ = request
456 .result_sender
457 .send(Err(Error::Timeout(request.send_time.elapsed(), timeout)));
458 }
459 }
460 }
461
462 fn send_watched_queries_keepalive(&mut self, force: bool) {
463 let register_interval = self.config.watched_register_interval;
464
465 let mut sent_queries = Vec::new();
466 for (token, query) in &self.watched_queries {
467 if force
468 || query
469 .last_register
470 .map_or(true, |i| i.elapsed() > register_interval)
471 {
472 if let Err(err) = self.send_watch_query(query) {
473 error!("Couldn't send watch query: {}", err);
474 }
475 sent_queries.push(*token);
476 }
477 }
478
479 for token in &sent_queries {
480 let query = self.watched_queries.get_mut(token).unwrap();
481 query.last_register = Some(Instant::now());
482 }
483 }
484
485 fn send_store_node_message(&self, message: OutMessage) -> Result<(), Error> {
486 let store_node = if let Some(store_node) = &self.store_node {
487 store_node.clone()
488 } else {
489 info!("No store node yet, queueing message");
490 let mut store_node_message_queue = self.store_node_message_queue.lock()?;
491 store_node_message_queue.push(message);
492 return Ok(());
493 };
494
495 let transport = self.transport_out.as_ref().ok_or_else(|| {
496 Error::Fatal(anyhow!("Tried to send message, but transport_out was none"))
497 })?;
498
499 transport
500 .unbounded_send(OutEvent::Message(message.with_destination(store_node)))
501 .map_err(|_err| {
502 Error::Fatal(anyhow!(
503 "Tried to send message, but transport_out channel is closed"
504 ))
505 })?;
506
507 Ok(())
508 }
509
510 fn drain_store_node_message_queue(&self) -> Result<(), Error> {
511 let store_node_message_queue: Vec<OutMessage> = {
512 let mut store_node_message_queue = self.store_node_message_queue.lock()?;
513 std::mem::take(store_node_message_queue.as_mut())
514 };
515
516 for message in store_node_message_queue {
517 self.send_store_node_message(message)?;
518 }
519
520 Ok(())
521 }
522}
523
524enum IncomingMessage {
526 MutationResponse(MutationResult),
527 QueryResponse(EntityResults),
528}
529
530impl IncomingMessage {
531 fn parse_incoming_message(in_message: &InMessage) -> Result<IncomingMessage, Error> {
532 match in_message.typ {
533 <mutation_response::Owned as MessageType>::MESSAGE_TYPE => {
534 let mutation_frame = in_message.get_data_as_framed_message()?;
535 let mutation_result = mutation_result_from_response_frame(mutation_frame)?;
536 Ok(IncomingMessage::MutationResponse(mutation_result))
537 }
538 <query_response::Owned as MessageType>::MESSAGE_TYPE
539 | <watched_query_response::Owned as MessageType>::MESSAGE_TYPE => {
540 let query_frame = in_message.get_data_as_framed_message()?;
541 let query_result = query_results_from_response_frame(query_frame)?;
542 Ok(IncomingMessage::QueryResponse(query_result))
543 }
544 other => Err(anyhow!("Received message of unknown type: {}", other).into()),
545 }
546 }
547}
548
/// A query or mutation waiting for its response from the store node.
struct PendingRequest<T> {
    /// Rendez-vous id the response is matched against.
    request_id: ConsistentTimestamp,
    /// Completed with the response, the error reported for it, or a timeout error.
    result_sender: oneshot::Sender<Result<T, Error>>,
    /// When the request was registered; compared against the timeout.
    send_time: Instant,
}
555
556struct WatchedQueryRequest {
557 request_id: ConsistentTimestamp,
558 query: EntityQuery,
559 result_sender: mpsc::Sender<Result<EntityResults, Error>>,
560 last_register: Option<Instant>,
561}
562
563impl WatchedQueryRequest {
564 fn force_register(&mut self) {
565 self.last_register = None;
566 }
567}
568
569#[derive(Clone)]
571pub struct ClientHandle {
572 inner: Weak<RwLock<Inner>>,
573 handle: Handle,
574}
575
576impl ClientHandle {
577 pub async fn on_start(&self) {
578 self.handle.on_set_started().await;
579 }
580
581 pub fn store_node(&self) -> Option<Node> {
582 let inner = self.inner.upgrade()?;
583 let inner = inner.read().ok()?;
584 inner.store_node.clone()
585 }
586}
587
588#[async_trait]
589impl crate::store::Store for ClientHandle {
590 type WatchedQueryStream = WatchedQueryStream;
591
592 async fn mutate<M: Into<MutationRequestLike> + Send>(
593 &self,
594 request: M,
595 ) -> Result<MutationResult, Error> {
596 let result = {
597 let inner = self.inner.upgrade().ok_or(Error::Dropped)?;
598 let mut inner = inner.write()?;
599
600 inner.send_mutation(request.into().0)?
601 };
602
603 result.await.map_err(|_err| Error::Cancelled)?
604 }
605
606 async fn query(&self, query: EntityQuery) -> Result<EntityResults, Error> {
607 let receiver = {
608 let inner = self.inner.upgrade().ok_or(Error::Dropped)?;
609 let mut inner = inner.write()?;
610
611 match inner.send_query(query) {
612 Ok(receiver) => receiver,
613 Err(err) => return Err(err),
614 }
615 };
616
617 receiver.await.map_err(|_err| Error::Cancelled)?
618 }
619
620 fn watched_query(&self, mut query: EntityQuery) -> Result<Self::WatchedQueryStream, Error> {
621 let inner = self.inner.upgrade().ok_or(Error::Dropped)?;
622 let mut inner = inner.write()?;
623
624 let mut watch_token = query.watch_token;
625 if watch_token == 0 {
626 watch_token = inner.clock.consistent_time(inner.cell.local_node()).into();
627 query.watch_token = watch_token;
628 }
629
630 let (request_id, receiver) = inner.watch_query(query)?;
631
632 Ok(WatchedQueryStream {
633 inner: self.inner.clone(),
634 watch_token: Some(watch_token),
635 request_id,
636 result: receiver,
637 })
638 }
639}
640
641pub struct WatchedQueryStream {
643 inner: Weak<RwLock<Inner>>,
644 watch_token: Option<WatchToken>,
645 request_id: ConsistentTimestamp,
646 result: mpsc::Receiver<Result<EntityResults, Error>>,
647}
648
649impl Stream for WatchedQueryStream {
650 type Item = Result<EntityResults, Error>;
651
652 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
653 self.result.poll_next_unpin(cx)
654 }
655}
656
657impl Drop for WatchedQueryStream {
658 fn drop(&mut self) {
659 if let Some(inner) = self.inner.upgrade() {
660 if let Ok(mut inner) = inner.write() {
661 inner.watched_queries.remove(&self.request_id);
662
663 if let Some(watch_token) = self.watch_token {
664 let _ = inner.send_unwatch_query(watch_token);
665 }
666 }
667 }
668 }
669}
670
#[cfg(test)]
mod tests {
    use exocore_core::{
        cell::{FullCell, LocalNode},
        futures::spawn_future,
        tests_utils::expect_eventually,
    };
    use exocore_transport::testing::MockTransport;

    use super::*;
    use crate::{query::QueryBuilder, store::Store};

    /// The client queues messages while no store node is connected. It then
    /// picks the first store node that connects and flushes the queue to it.
    /// Finally, it switches to another store node when the current one
    /// disconnects.
    #[tokio::test(flavor = "multi_thread")]
    async fn connects_to_online_node() -> anyhow::Result<()> {
        let client_node = LocalNode::generate();
        let full_cell = FullCell::generate(client_node.clone())?;
        let clock = Clock::new();
        let transport = MockTransport::default();

        // Two nodes having the store role, each with its own testable transport.
        let mut server_nodes = Vec::new();
        let mut server_transports = Vec::new();
        for _i in 0..2 {
            let node = LocalNode::generate();
            let mut cell_nodes = full_cell.cell().nodes_mut();

            cell_nodes.add(node.node().clone());
            let cell_node = cell_nodes.get_mut(node.id()).unwrap();
            cell_node.add_role(CellNodeRole::Store);

            server_nodes.push(node.clone());

            let transport = transport
                .get_transport(node, ServiceType::Store)
                .into_testable_handle(full_cell.cell().clone());

            server_transports.push(transport);
        }

        let transport_handle = transport.get_transport(client_node, ServiceType::Store);
        let config = ClientConfiguration::default();
        let client = Client::new(config, full_cell.cell().clone(), clock, transport_handle)?;
        let client_inner = client.inner.clone();
        let client_handle = client.get_handle();

        spawn_future(async move {
            let _ = client.run().await;
        });

        // The query can't complete yet; it only needs to be emitted.
        tokio::spawn(async move {
            let _ = client_handle.query(QueryBuilder::test(true).build()).await;
        });

        // No store node is connected: the query must be queued, not sent.
        expect_eventually(|| -> bool {
            let inner = client_inner.read().unwrap();
            assert!(inner.store_node.as_ref().is_none());

            let msg_queue = inner.store_node_message_queue.lock().unwrap();
            msg_queue.len() == 1
        });

        assert!(!server_transports[0].has_msg().await.unwrap());

        // Once the first store node connects, it gets selected...
        transport.notify_node_connection_status(server_nodes[0].id(), ConnectionStatus::Connected);

        expect_eventually(|| -> bool {
            let inner = client_inner.read().unwrap();
            if inner.store_node.is_none() {
                return false;
            }

            inner.store_node.as_ref().unwrap().id() == server_nodes[0].id()
        });

        // ...and the queued query is flushed to it.
        assert!(server_transports[0].has_msg().await.unwrap());

        // When it disconnects and the other store node connects, the client switches over.
        transport
            .notify_node_connection_status(server_nodes[0].id(), ConnectionStatus::Disconnected);
        transport.notify_node_connection_status(server_nodes[1].id(), ConnectionStatus::Connected);

        expect_eventually(|| -> bool {
            let inner = client_inner.read().unwrap();
            inner.store_node.as_ref().unwrap().id() == server_nodes[1].id()
        });

        Ok(())
    }
}