1use crate::tree::Tree;
2use crate::{ChildData, Event};
3use crate::{EventStream, Result, SharedChildData};
4use async_recursion::async_recursion;
5use futures::StreamExt;
6use futures::{stream, Stream};
7use std::collections::{HashMap, HashSet};
8use std::mem;
9use std::ops::DerefMut;
10use std::sync::Arc;
11use std::time::Duration;
12use tokio::sync::{RwLock, RwLockWriteGuard};
13use tokio_util::sync::CancellationToken;
14use zookeeper_client::{EventType, SessionState, WatchedEvent};
15
16type Path = String;
17struct Storage {
18 data: HashMap<Path, SharedChildData>,
19 tree: Tree<Path>,
20}
21
22impl Storage {
23 pub fn new(root: String) -> Storage {
24 Storage {
25 data: HashMap::new(),
26 tree: Tree::new(root),
27 }
28 }
29
30 pub fn replace(&mut self, data: HashMap<Path, SharedChildData>, tree: Tree<Path>) -> Storage {
31 Storage {
32 data: mem::replace(&mut self.data, data),
33 tree: mem::replace(&mut self.tree, tree),
34 }
35 }
36}
37
/// ZooKeeper server version as `(major, minor, patch)`, forwarded to the
/// connector's `server_version`.
#[derive(Debug, Clone)]
pub(crate) struct Version(u32, u32, u32);
40
/// One authentication credential handed to the ZooKeeper connector.
#[derive(Debug, Clone)]
pub struct AuthPacket {
    /// Authentication scheme name passed to `Connector::auth`.
    pub scheme: String,
    /// Raw credential bytes for `scheme`.
    pub auth: Vec<u8>,
}
46
47#[derive(Clone, Debug)]
49pub struct CacheBuilder {
50 path: String,
52 authes: Vec<AuthPacket>,
54 server_version: Version,
56 session_timeout: Duration,
58 connection_timeout: Duration,
60 reconnect_timeout: Duration,
62}
63
64impl Default for CacheBuilder {
65 fn default() -> Self {
66 Self {
67 path: "/".to_string(),
68 authes: vec![],
69 server_version: Version(u32::MAX, u32::MAX, u32::MAX),
70 session_timeout: Duration::ZERO,
71 connection_timeout: Duration::ZERO,
72 reconnect_timeout: Duration::from_secs(1),
73 }
74 }
75}
76
77impl From<&CacheBuilder> for zookeeper_client::Connector {
78 fn from(val: &CacheBuilder) -> Self {
79 let mut connector = zookeeper_client::Client::connector();
80 connector.server_version(
81 val.server_version.0,
82 val.server_version.1,
83 val.server_version.2,
84 );
85 for auth in val.authes.clone() {
86 connector.auth(auth.scheme, auth.auth);
87 }
88 connector.session_timeout(val.session_timeout);
89 connector.connection_timeout(val.connection_timeout);
90 connector.readonly(true);
91 connector
92 }
93}
94
95impl CacheBuilder {
109 pub fn new(path: impl Into<String>) -> Self {
110 Self {
111 path: path.into(),
112 ..Default::default()
113 }
114 }
115
116 pub fn with_auth(mut self, scheme: String, auth: Vec<u8>) -> Self {
117 self.authes.push(AuthPacket { scheme, auth });
118 self
119 }
120
121 pub fn with_version(mut self, major: u32, minor: u32, patch: u32) -> Self {
122 self.server_version = Version(major, minor, patch);
123 self
124 }
125
126 pub fn with_session_timeout(mut self, timeout: Duration) -> Self {
127 self.session_timeout = timeout;
128 self
129 }
130
131 pub fn with_connect_timeout(mut self, timeout: Duration) -> Self {
132 self.connection_timeout = timeout;
133 self
134 }
135
136 pub fn with_reconnect_timeout(mut self, timeout: Duration) -> Self {
137 self.reconnect_timeout = timeout;
138 self
139 }
140
141 pub async fn build(
142 self,
143 addr: impl Into<String>,
144 ) -> Result<(Cache, impl Stream<Item = Event>)> {
145 Cache::new(addr, self).await
146 }
147}
148
149pub struct Cache {
164 addr: String,
165 builder: CacheBuilder,
166 storage: Arc<RwLock<Storage>>,
167 event_sender: tokio::sync::mpsc::UnboundedSender<Event>,
168 token: CancellationToken,
169}
170
171impl Drop for Cache {
172 fn drop(&mut self) {
173 self.token.cancel();
174 }
175}
176
177impl Cache {
178 pub async fn new(
179 addr: impl Into<String>,
180 builder: CacheBuilder,
181 ) -> Result<(Self, impl Stream<Item = Event>)> {
182 let mut connector: zookeeper_client::Connector = (&builder).into();
183 let addr = addr.into();
184 let client = connector.connect(&addr).await?;
185 let storage = Arc::new(RwLock::new(Storage::new(builder.path.clone())));
186 let (sender, watcher) = tokio::sync::mpsc::unbounded_channel();
187 let events = EventStream { watcher };
188 let cache = Self {
189 addr,
190 builder: builder.clone(),
191 storage,
192 event_sender: sender,
193 token: CancellationToken::new(),
194 };
195 let (sender, watcher) = tokio::sync::mpsc::unbounded_channel();
196 Self::init_nodes(
197 &client,
198 &builder.path,
199 cache.storage.write().await.deref_mut(),
200 &sender,
201 &cache.event_sender,
202 )
203 .await?;
204 cache.watch(client, sender, watcher).await;
205 Ok((cache, events))
206 }
207
208 pub async fn get(&self, path: &str) -> Option<SharedChildData> {
218 self.storage.read().await.data.get(path).cloned()
219 }
220
221 async fn init_nodes(
222 client: &zookeeper_client::Client,
223 path: &str,
224 storage: &mut Storage,
225 sender: &tokio::sync::mpsc::UnboundedSender<WatchedEvent>,
226 event_sender: &tokio::sync::mpsc::UnboundedSender<Event>,
227 ) -> Result<()> {
228 let new = Arc::new(RwLock::new(Storage::new(path.to_string())));
229 Self::fetch_all(client, path, &mut new.write().await, sender, true).await?;
230 let new = new.write().await;
232 Self::compare_storage(path, storage, &new, event_sender).await;
233 storage.replace(new.data.clone(), new.tree.clone());
234 Ok(())
235 }
236
237 async fn watch(
238 &self,
239 mut client: zookeeper_client::Client,
240 sender: tokio::sync::mpsc::UnboundedSender<WatchedEvent>,
241 mut watcher: tokio::sync::mpsc::UnboundedReceiver<WatchedEvent>,
242 ) {
243 let addr = self.addr.clone();
244 let storage = self.storage.clone();
245 let sender = sender.clone();
246 let builder = self.builder.clone();
247 let event_sender = self.event_sender.clone();
248 let token = self.token.clone();
249 tokio::spawn(async move {
250 let mut control = HandleControl::Handle;
251 loop {
252 tokio::select! {
253 _ = token.cancelled() => {
254 return
255 }
256 event = watcher.recv() => {
257 match event{
258 Some(event) => {
259 match control {
260 HandleControl::Handle => {},
261 HandleControl::Continue => {
262 if event.event_type == EventType::Session && event.session_state.is_terminated(){
263 continue;
264 } else {
265 control = HandleControl::Handle;
266 }
267 }
268 };
269 if let Some(reconnect) = Self::handle_event(&addr, &client, &builder, &storage, event, &sender, &event_sender, &token).await{
270 client = reconnect;
271 control = HandleControl::Continue;
273 }
274 }
275 None => break
276 }
277 }
278 }
279 }
280 });
281 }
282
283 #[allow(clippy::too_many_arguments)]
284 async fn handle_event(
285 addr: &str,
286 client: &zookeeper_client::Client,
287 builder: &CacheBuilder,
288 storage: &Arc<RwLock<Storage>>,
289 event: WatchedEvent,
290 sender: &tokio::sync::mpsc::UnboundedSender<WatchedEvent>,
291 event_sender: &tokio::sync::mpsc::UnboundedSender<Event>,
292 token: &CancellationToken,
293 ) -> Option<zookeeper_client::Client> {
294 match event.event_type {
295 EventType::Session => {
296 if let Some(client) =
297 Self::handle_session(addr, builder, storage, event, sender, event_sender, token)
298 .await
299 {
300 return Some(client);
301 }
302 }
303 EventType::NodeDeleted => {
304 Self::handle_node_deleted(storage, event, event_sender).await;
305 }
306 EventType::NodeDataChanged => {
307 Self::handle_node_data_changed(client, storage, event, sender, event_sender).await;
308 }
309 EventType::NodeChildrenChanged => {
310 Self::handle_children_change(client, storage, event, sender, event_sender).await;
311 }
312 EventType::NodeCreated => {
313 Self::handle_node_created(client, storage, event, sender, event_sender).await;
314 }
315 }
316 None
317 }
318
319 async fn handle_session(
320 addr: &str,
321 builder: &CacheBuilder,
322 storage: &Arc<RwLock<Storage>>,
323 event: WatchedEvent,
324 sender: &tokio::sync::mpsc::UnboundedSender<WatchedEvent>,
325 event_sender: &tokio::sync::mpsc::UnboundedSender<Event>,
326 token: &CancellationToken,
327 ) -> Option<zookeeper_client::Client> {
328 match event.session_state {
330 SessionState::Expired | SessionState::Closed => {
331 let mut interval = tokio::time::interval(builder.reconnect_timeout);
332 let mut connector: zookeeper_client::Connector = builder.into();
333 let client = loop {
334 tokio::select! {
335 _ = token.cancelled() => {
336 return None
337 }
338 _ = interval.tick() => {
339 match connector.connect(addr).await {
340 Ok(zk) => break zk,
341 Err(_err) => {
342 }
343 };
344 }
345 }
346 };
347 {
348 loop {
349 match Self::init_nodes(
350 &client,
351 &builder.path,
352 storage.write().await.deref_mut(),
353 sender,
354 event_sender,
355 )
356 .await
357 {
358 Ok(_) => break,
359 Err(_err) => {
360 interval.tick().await;
361 }
362 }
363 }
364 }
365 return Some(client);
366 }
367 _ => {}
368 };
369 None
370 }
371
372 async fn handle_node_created(
374 client: &zookeeper_client::Client,
375 storage: &Arc<RwLock<Storage>>,
376 event: WatchedEvent,
377 sender: &tokio::sync::mpsc::UnboundedSender<WatchedEvent>,
378 event_sender: &tokio::sync::mpsc::UnboundedSender<Event>,
379 ) {
380 let mut storage = storage.write().await;
381 if let Ok(status) = Self::get_root_node(client, &event.path, &mut storage, sender).await {
382 match status {
383 RootStatus::Ephemeral(data) => {
384 let _ = event_sender.send(Event::Add(data));
385 }
386 RootStatus::Persistent(data) => {
387 if let Err(err) = Self::list_children(client, &event.path, sender).await {
388 debug_assert_eq!(err, zookeeper_client::Error::NoNode);
389 }
390 let _ = event_sender.send(Event::Add(data));
391 }
392 _ => {}
393 }
394 }
395 }
396
397 async fn handle_node_deleted(
398 storage: &Arc<RwLock<Storage>>,
399 event: WatchedEvent,
400 event_sender: &tokio::sync::mpsc::UnboundedSender<Event>,
401 ) {
402 let mut storage = storage.write().await;
403 storage.tree.remove_child(&event.path);
404 match storage.data.get(&event.path) {
405 None => {}
406 Some(_data) => {}
407 }
408 match storage.data.remove(&event.path) {
409 None => {}
410 Some(child_data) => {
411 let _ = event_sender.send(Event::Delete(child_data));
412 }
413 }
414 }
415
416 async fn handle_node_data_changed(
417 client: &zookeeper_client::Client,
418 storage: &Arc<RwLock<Storage>>,
419 event: WatchedEvent,
420 sender: &tokio::sync::mpsc::UnboundedSender<WatchedEvent>,
421 event_sender: &tokio::sync::mpsc::UnboundedSender<Event>,
422 ) {
423 let mut storage = storage.write().await;
424 let old = storage.data.get(&event.path).unwrap().clone();
425 if let Err(err) = Self::get_data(client, &event.path, &mut storage, sender).await {
426 debug_assert_eq!(err, zookeeper_client::Error::NoNode);
427 storage.tree.remove_child(&event.path);
429 let child_data = storage.data.remove(&event.path).unwrap();
430 let _ = event_sender.send(Event::Delete(child_data));
431 return;
432 };
433 let new = storage.data.get(&event.path).unwrap().clone();
434 let _ = event_sender.send(Event::Update { old, new });
435 }
436
437 async fn handle_children_change(
438 client: &zookeeper_client::Client,
439 storage: &Arc<RwLock<Storage>>,
440 event: WatchedEvent,
441 sender: &tokio::sync::mpsc::UnboundedSender<WatchedEvent>,
442 event_sender: &tokio::sync::mpsc::UnboundedSender<Event>,
443 ) {
444 let old_children = storage
445 .read()
446 .await
447 .tree
448 .children(&event.path)
449 .into_iter()
450 .map(|child| child.to_string())
451 .collect::<Vec<_>>();
452 let new_children = match Self::list_children(client, &event.path, sender).await {
453 Ok(children) => children
454 .iter()
455 .map(|child| make_path(&event.path, child))
456 .collect::<Vec<_>>(),
457 Err(err) => {
458 debug_assert_eq!(err, zookeeper_client::Error::NoNode);
459 return;
460 }
461 };
462 let (added, _) = compare(&old_children, &new_children);
463 let added = added
465 .into_iter()
466 .map(|added| {
467 let zk = client.clone();
468 let path = event.path.clone();
469 let sender = sender.clone();
470 let event_sender = event_sender.clone();
471 (zk, storage, path, added, sender, event_sender)
472 })
473 .collect::<Vec<_>>();
474 stream::iter(added)
475 .for_each_concurrent(
476 20,
478 |(zk, storage, parent, child_path, sender, event_sender)| async move {
479 let mut storage = storage.write().await;
480 if let Err(err) =
481 Self::get_data(&zk, &child_path, &mut storage, &sender.clone()).await
482 {
483 debug_assert_eq!(err, zookeeper_client::Error::NoNode);
484 return;
485 }
486 storage.tree.add_child(&parent, child_path.clone());
487 let child_data = storage.data.get(&child_path).unwrap();
488 let _ = event_sender.send(Event::Add(child_data.clone()));
489 },
490 )
491 .await;
492 }
493
494 async fn get_data(
495 client: &zookeeper_client::Client,
496 path: &str,
497 storage: &mut RwLockWriteGuard<'_, Storage>,
498 sender: &tokio::sync::mpsc::UnboundedSender<WatchedEvent>,
499 ) -> std::result::Result<SharedChildData, zookeeper_client::Error> {
500 let (data, stat, watcher) = client.get_and_watch_data(path).await?;
501 let data = Arc::new(ChildData {
502 path: path.to_string(),
503 data,
504 stat,
505 });
506 storage.data.insert(path.to_string(), data.clone());
507 {
508 let sender = sender.clone();
509 tokio::spawn(async move {
510 let _ = sender.send(watcher.changed().await);
511 });
512 }
513 Ok(data)
514 }
515
516 async fn list_children(
517 client: &zookeeper_client::Client,
518 path: &str,
519 sender: &tokio::sync::mpsc::UnboundedSender<WatchedEvent>,
520 ) -> std::result::Result<Vec<String>, zookeeper_client::Error> {
521 let (children, watcher) = client.list_and_watch_children(path).await?;
522 {
523 let sender = sender.clone();
524 tokio::spawn(async move {
525 let _ = sender.send(watcher.changed().await);
526 });
527 }
528 Ok(children)
529 }
530
531 #[async_recursion]
533 async fn get_root_node(
534 client: &zookeeper_client::Client,
535 path: &str,
536 storage: &mut RwLockWriteGuard<'_, Storage>,
537 sender: &tokio::sync::mpsc::UnboundedSender<WatchedEvent>,
538 ) -> std::result::Result<RootStatus, zookeeper_client::Error> {
539 match client.check_and_watch_stat(path).await? {
540 (None, watcher) => {
541 let sender = sender.clone();
542 tokio::spawn(async move {
543 let _ = sender.send(watcher.changed().await);
544 });
545 Ok(RootStatus::NotExist)
546 }
547 (Some(_), _) => {
548 match Self::get_data(client, path, storage, sender).await {
549 Ok(data) if data.stat.ephemeral_owner != 0 => {
550 Ok(RootStatus::Ephemeral(data.clone()))
551 }
552 Ok(data) => Ok(RootStatus::Persistent(data.clone())),
553 Err(err) => {
554 debug_assert_eq!(err, zookeeper_client::Error::NoNode);
555 Self::get_root_node(client, path, storage, sender).await
557 }
558 }
559 }
560 }
561 }
562
563 #[async_recursion]
564 async fn fetch_all(
565 client: &zookeeper_client::Client,
566 path: &str,
567 storage: &mut RwLockWriteGuard<Storage>,
568 sender: &tokio::sync::mpsc::UnboundedSender<WatchedEvent>,
569 root: bool,
570 ) -> std::result::Result<(), zookeeper_client::Error> {
571 let persistent = if root {
572 matches!(
573 Self::get_root_node(client, path, storage, sender).await?,
574 RootStatus::Persistent(_)
575 )
576 } else {
577 Self::get_data(client, path, storage, sender)
578 .await?
579 .stat
580 .ephemeral_owner
581 == 0
582 };
583 if persistent {
584 let children = match Self::list_children(client, path, sender).await {
585 Ok(children) => children,
586 Err(_) => return Ok(()),
587 };
588 storage.tree.add_children(
589 path,
590 children
591 .iter()
592 .map(|child| make_path(path, child))
593 .collect(),
594 );
595 for child in children.iter() {
596 if let Err(zookeeper_client::Error::NoNode) = Self::fetch_all(
597 client,
598 make_path(path, child).as_str(),
599 storage,
600 sender,
601 false,
602 )
603 .await
604 {
605 continue;
606 }
607 }
608 }
609 Ok(())
610 }
611
612 #[async_recursion]
613 async fn compare_storage(
614 path: &str,
615 old: &Storage,
616 new: &Storage,
617 sender: &tokio::sync::mpsc::UnboundedSender<Event>,
618 ) {
619 let old_data = old.data.get(path);
620 let new_data = new.data.get(path);
621 match (old_data, new_data) {
622 (Some(data), None) => {
623 let _ = sender.send(Event::Delete(data.clone()));
624 }
625 (None, Some(data)) => {
626 let _ = sender.send(Event::Add(data.clone()));
627 }
628 (Some(old), Some(new)) => {
629 if !old.eq(new) {
630 let _ = sender.send(Event::Update {
631 old: old.clone(),
632 new: new.clone(),
633 });
634 }
635 }
636 _ => {}
637 }
638 let mut old_children = old.tree.children(path);
639 let mut new_children = new.tree.children(path);
640 old_children.append(&mut new_children);
641 let children = old_children.into_iter().collect::<HashSet<_>>();
642 for child in children.iter() {
643 Self::compare_storage(child, old, new, sender).await;
644 }
645 }
646}
647
/// Joins a parent path and a child name. A `/` separator is inserted only
/// when `parent` does not already end with one.
fn make_path(parent: &str, child: &str) -> String {
    let separator = if parent.ends_with('/') { "" } else { "/" };
    format!("{}{}{}", parent, separator, child)
}
655
/// Diffs two child lists. Returns `(added, removed)`:
/// - `added` holds the entries of `new` that are absent from `old`.
/// - `removed` holds the entries of `old` that are absent from `new`.
///
/// Each result keeps the first-occurrence order of its input slice and
/// contains no duplicates. Unlike hash-set iteration, this makes the output
/// deterministic.
fn compare(old: &[String], new: &[String]) -> (Vec<String>, Vec<String>) {
    /// Entries of `items` not in `exclude`, deduplicated, in input order.
    fn missing_from(items: &[String], exclude: &HashSet<&String>) -> Vec<String> {
        let mut seen = HashSet::new();
        items
            .iter()
            .filter(|item| !exclude.contains(item) && seen.insert(*item))
            .cloned()
            .collect()
    }

    let old_set: HashSet<&String> = old.iter().collect();
    let new_set: HashSet<&String> = new.iter().collect();
    (missing_from(new, &old_set), missing_from(old, &new_set))
}
671
672#[derive(Clone, Debug)]
673enum RootStatus {
674 NotExist,
675 Ephemeral(SharedChildData),
676 Persistent(SharedChildData),
677}
678
/// Watch-loop mode used to filter stale session events after a reconnect.
#[derive(Debug, Clone)]
enum HandleControl {
    /// Process every event normally.
    Handle,
    /// Just reconnected: skip terminal session events until a different event arrives.
    Continue,
}
684
#[cfg(test)]
mod tests {
    /// Children present only in `new` are reported as added, and children
    /// present only in `old` as deleted.
    #[test]
    fn compare() {
        let strings =
            |items: &[&str]| items.iter().map(|item| item.to_string()).collect::<Vec<String>>();
        let previous = strings(&["1", "2", "3"]);
        let current = strings(&["2", "3", "4"]);
        let (added, deleted) = super::compare(&previous, &current);
        assert_eq!(added, strings(&["4"]));
        assert_eq!(deleted, strings(&["1"]));
    }
}