1mod watcher;
2
3use std::borrow::Cow;
4use std::fmt::Write as _;
5use std::future::Future;
6use std::mem::ManuallyDrop;
7use std::sync::Arc;
8use std::time::Duration;
9
10use const_format::formatcp;
11use derive_where::derive_where;
12use either::{Either, Left, Right};
13use futures::channel::mpsc;
14use ignore_result::Ignore;
15use thiserror::Error;
16use tracing::instrument;
17
18pub use self::watcher::{OneshotWatcher, PersistentWatcher, StateWatcher};
19use super::session::{Depot, MarshalledRequest, Request, Session, SessionOperation, WatchReceiver};
20use crate::acl::{Acl, Acls, AuthUser};
21use crate::chroot::{Chroot, ChrootPath, OwnedChroot};
22use crate::endpoint::{self, IterableEndpoints};
23use crate::error::Error;
24use crate::proto::{
25 self,
26 AuthPacket,
27 CheckVersionRequest,
28 CreateRequest,
29 DeleteRequest,
30 ExistsRequest,
31 GetAclResponse,
32 GetChildren2Response,
33 GetChildrenRequest,
34 GetRequest,
35 MultiHeader,
36 MultiReadResponse,
37 MultiWriteResponse,
38 OpCode,
39 PersistentWatchRequest,
40 ReconfigRequest,
41 RequestBuffer,
42 RequestHeader,
43 SetAclRequest,
44 SetDataRequest,
45 SyncRequest,
46};
47pub use crate::proto::{EnsembleUpdate, Stat};
48use crate::record::{self, Record, StaticRecord};
49#[cfg(any(feature = "sasl-digest-md5", feature = "sasl-gssapi"))]
50use crate::sasl::SaslOptions;
51use crate::session::StateReceiver;
52pub use crate::session::{EventType, SessionId, SessionInfo, SessionState, WatchedEvent};
53#[cfg(feature = "tls")]
54use crate::tls::TlsOptions;
55use crate::util;
56
57pub(crate) type Result<T, E = Error> = std::result::Result<T, E>;
58
/// Mode a node is created with.
///
/// The numeric flag sent to the server for each mode is computed by `CreateMode::as_flags`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CreateMode {
    /// Persistent, non-sequential node. May carry a ttl (see [CreateOptions::with_ttl]).
    Persistent,
    /// Persistent node created through a sequential path. May carry a ttl.
    PersistentSequential,
    /// Ephemeral, non-sequential node.
    Ephemeral,
    /// Ephemeral node created through a sequential path.
    EphemeralSequential,
    /// Container node.
    Container,
}
69
70impl CreateMode {
71 pub const fn with_acls(self, acls: Acls<'_>) -> CreateOptions<'_> {
73 CreateOptions { mode: self, acls, ttl: None }
74 }
75
76 fn is_sequential(self) -> bool {
77 self == CreateMode::PersistentSequential || self == CreateMode::EphemeralSequential
78 }
79
80 fn is_persistent(self) -> bool {
81 self == Self::Persistent || self == Self::PersistentSequential
82 }
83
84 fn is_ephemeral(self) -> bool {
85 self == Self::Ephemeral || self == Self::EphemeralSequential
86 }
87
88 fn is_container(self) -> bool {
89 self == CreateMode::Container
90 }
91
92 fn as_flags(self, ttl: bool) -> i32 {
93 use CreateMode::*;
94 match self {
95 Persistent => {
96 if ttl {
97 5
98 } else {
99 0
100 }
101 },
102 PersistentSequential => {
103 if ttl {
104 6
105 } else {
106 2
107 }
108 },
109 Ephemeral => 1,
110 EphemeralSequential => 3,
111 Container => 4,
112 }
113 }
114}
115
/// Mode of a watch registered through [Client::watch].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum AddWatchMode {
    /// Persistent watch.
    Persistent,

    /// Persistent recursive watch.
    // NOTE(review): presumably also covers descendants of the watched node — confirm against
    // ZooKeeper `addWatch` semantics.
    PersistentRecursive,
}
125
126impl From<AddWatchMode> for proto::AddWatchMode {
127 fn from(mode: AddWatchMode) -> proto::AddWatchMode {
128 match mode {
129 AddWatchMode::Persistent => proto::AddWatchMode::Persistent,
130 AddWatchMode::PersistentRecursive => proto::AddWatchMode::PersistentRecursive,
131 }
132 }
133}
134
135#[derive(Clone, Debug)]
137pub struct CreateOptions<'a> {
138 mode: CreateMode,
139 acls: Acls<'a>,
140 ttl: Option<Duration>,
141}
142
/// Largest ttl, in milliseconds, accepted by `CreateOptions::validate`: a 40-bit all-ones value
/// (`0x00FFFFFFFFFF`).
const TTL_MAX_MILLIS: u128 = (1 << 40) - 1;
147
148impl<'a> CreateOptions<'a> {
149 pub const fn with_ttl(mut self, ttl: Duration) -> Self {
151 self.ttl = Some(ttl);
152 self
153 }
154
155 fn validate(&'a self) -> Result<()> {
156 if let Some(ref ttl) = self.ttl {
157 if !self.mode.is_persistent() {
158 return Err(Error::BadArguments(&"ttl can only be specified with persistent node"));
159 } else if ttl.is_zero() {
160 return Err(Error::BadArguments(&"ttl is zero"));
161 } else if ttl.as_millis() > TTL_MAX_MILLIS {
162 return Err(Error::BadArguments(&formatcp!("ttl cannot larger than {}", TTL_MAX_MILLIS)));
163 }
164 }
165 if self.acls.is_empty() {
166 return Err(Error::InvalidAcl);
167 }
168 Ok(())
169 }
170
171 fn validate_as_directory(&self) -> Result<()> {
172 self.validate()?;
173 if self.mode.is_ephemeral() {
174 return Err(Error::BadArguments(&"directory node must not be ephemeral"));
175 } else if self.mode.is_sequential() {
176 return Err(Error::BadArguments(&"directory node must not be sequential"));
177 }
178 Ok(())
179 }
180}
181
/// Sequence number of a created node.
///
/// `Client::create` reports `CreateSequence(-1)` for nodes created in non-sequential modes.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct CreateSequence(i64);

/// Formats the sequence zero-padded: 10 characters wide while the value fits in `i32`,
/// 19 characters wide otherwise.
impl std::fmt::Display for CreateSequence {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let width = if self.0 <= i64::from(i32::MAX) { 10 } else { 19 };
        write!(f, "{:0width$}", self.0, width = width)
    }
}

impl CreateSequence {
    /// Returns the raw sequence number.
    pub fn into_i64(self) -> i64 {
        self.0
    }
}
206
207#[derive(Clone, Debug)]
221pub struct Client {
222 chroot: OwnedChroot,
223 #[allow(dead_code)]
224 version: Version,
225 session: SessionInfo,
226 session_timeout: Duration,
227 requester: Arc<mpsc::UnboundedSender<Request>>,
228 state_watcher: StateWatcher,
229}
230
231impl Client {
    /// Absolute path of the node read by [Client::get_config] and
    /// [Client::get_and_watch_config]; it is always resolved with the default chroot.
    const CONFIG_NODE: &'static str = "/zookeeper/config";
233
    /// Connects to `cluster` using a freshly created [Connector].
    ///
    /// Shorthand for `Client::connector().connect(cluster)`.
    pub async fn connect(cluster: &str) -> Result<Self> {
        Self::connector().connect(cluster).await
    }
238
    /// Creates a new [Connector] for customizing how a client gets connected.
    pub fn connector() -> Connector {
        Connector::new()
    }
243
    /// Assembles a client from its parts; `timeout` is stored as the session timeout.
    pub(crate) fn new(
        chroot: OwnedChroot,
        version: Version,
        session: SessionInfo,
        timeout: Duration,
        requester: Arc<mpsc::UnboundedSender<Request>>,
        state_watcher: StateWatcher,
    ) -> Client {
        Client { chroot, version, session, session_timeout: timeout, requester, state_watcher }
    }
254
    /// Validates `path` as a non-sequential path under this client's chroot.
    fn validate_path<'a>(&'a self, path: &'a str) -> Result<ChrootPath<'a>> {
        ChrootPath::new(self.chroot.as_ref(), path, false)
    }
258
    /// Validates `path` as a sequential path (prefix) under this client's chroot.
    fn validate_sequential_path<'a>(&'a self, path: &'a str) -> Result<ChrootPath<'a>> {
        ChrootPath::new(self.chroot.as_ref(), path, true)
    }
262
    /// Returns the path of this client's chroot.
    pub fn path(&self) -> &str {
        self.chroot.path()
    }
267
    /// Returns the session info held by this client.
    pub fn session(&self) -> &SessionInfo {
        &self.session
    }
272
273 pub fn session_id(&self) -> SessionId {
275 self.session().id()
276 }
277
    /// Consumes this client and returns its session info.
    pub fn into_session(self) -> SessionInfo {
        self.session
    }
282
    /// Returns the session timeout this client was constructed with.
    pub fn session_timeout(&self) -> Duration {
        self.session_timeout
    }
287
    /// Returns the session state as peeked from this client's state watcher.
    pub fn state(&self) -> SessionState {
        self.state_watcher.peek_state()
    }
292
293 pub fn state_watcher(&self) -> StateWatcher {
295 let mut watcher = self.state_watcher.clone();
296 watcher.state();
297 watcher
298 }
299
300 pub fn chroot<'a>(mut self, path: impl Into<Cow<'a, str>>) -> std::result::Result<Client, Client> {
308 if self.chroot.chroot(path) {
309 Ok(self)
310 } else {
311 Err(self)
312 }
313 }
314
315 fn send_request(&self, code: OpCode, body: &impl Record) -> StateReceiver {
316 let request = MarshalledRequest::new(code, body);
317 self.send_marshalled_request(request)
318 }
319
320 fn send_marshalled_request(&self, request: MarshalledRequest) -> StateReceiver {
321 let (operation, receiver) = SessionOperation::new_marshalled(request).with_responser();
322 if let Err(err) = self.requester.unbounded_send(operation.into()) {
323 let state = self.state();
324 err.into_inner().into_responser().send(Err(state.to_error()));
325 }
326 receiver
327 }
328
329 async fn wait<T, E, F>(result: std::result::Result<F, E>) -> std::result::Result<T, E>
330 where
331 F: Future<Output = std::result::Result<T, E>>, {
332 match result {
333 Err(err) => Err(err),
334 Ok(future) => future.await,
335 }
336 }
337
338 async fn resolve<T, E, F>(result: std::result::Result<Either<F, T>, E>) -> std::result::Result<T, E>
339 where
340 F: Future<Output = std::result::Result<T, E>>, {
341 match result {
342 Err(err) => Err(err),
343 Ok(Right(r)) => Ok(r),
344 Ok(Left(future)) => future.await,
345 }
346 }
347
348 async fn map_wait<T, U, Fu, Fn>(result: Result<Fu>, f: Fn) -> Result<U>
349 where
350 Fu: Future<Output = Result<T>>,
351 Fn: FnOnce(T) -> U, {
352 match result {
353 Err(err) => Err(err),
354 Ok(future) => match future.await {
355 Err(err) => Err(err),
356 Ok(t) => Ok(f(t)),
357 },
358 }
359 }
360
361 async fn retry_on_connection_loss<T, F>(operation: impl Fn() -> F) -> Result<T>
362 where
363 F: Future<Output = Result<T>>, {
364 loop {
365 let future = operation();
366 return match future.await {
367 Err(Error::ConnectionLoss) => continue,
368 result => result,
369 };
370 }
371 }
372
373 fn parse_sequence(client_path: &str, prefix: &str) -> Result<CreateSequence> {
374 if let Some(sequence_path) = client_path.strip_prefix(prefix) {
375 match sequence_path.parse::<i64>() {
376 Err(_) => Err(Error::UnexpectedError(format!("sequential node get no i32 path {client_path}"))),
377 Ok(i) => Ok(CreateSequence(i)),
378 }
379 } else {
380 Err(Error::UnexpectedError(format!("sequential path {client_path} does not contain prefix path {prefix}",)))
381 }
382 }
383
    /// Makes directory-like nodes for `path`, creating missing ancestors as needed.
    ///
    /// Already existing nodes are not treated as errors.
    ///
    /// # Errors
    /// * [Error::BadArguments] or [Error::InvalidAcl] if `options` is not valid for directory
    ///   nodes (ephemeral, sequential, bad ttl or empty acls).
    /// * [Error::NoNode] if creation fails with no-node for a top level node.
    /// * Any other error reported by node creation.
    pub async fn mkdir(&self, path: &str, options: &CreateOptions<'_>) -> Result<()> {
        options.validate_as_directory()?;
        self.mkdir_internally(path, options).await
    }
399
400 async fn mkdir_internally(&self, path: &str, options: &CreateOptions<'_>) -> Result<()> {
401 let mut j = path.len();
402 loop {
403 match self.create(&path[..j], Default::default(), options).await {
404 Ok(_) | Err(Error::NodeExists) => {
405 if j >= path.len() {
406 return Ok(());
407 } else if let Some(i) = path[j + 1..].find('/') {
408 j = j + 1 + i;
409 } else {
410 j = path.len();
411 }
412 },
413 Err(Error::NoNode) => {
414 let i = path[..j].rfind('/').unwrap();
415 if i == 0 {
416 return Err(Error::NoNode);
418 }
419 j = i;
420 },
421 Err(err) => return Err(err),
422 }
423 }
424 }
425
426 pub fn create<'a: 'f, 'b: 'f, 'f>(
440 &'a self,
441 path: &'b str,
442 data: &[u8],
443 options: &CreateOptions<'_>,
444 ) -> impl Future<Output = Result<(Stat, CreateSequence)>> + Send + 'f {
445 Self::wait(self.create_internally(path, data, options))
446 }
447
448 fn create_internally<'a: 'f, 'b: 'f, 'f>(
449 &'a self,
450 path: &'b str,
451 data: &[u8],
452 options: &CreateOptions<'_>,
453 ) -> Result<impl Future<Output = Result<(Stat, CreateSequence)>> + Send + 'f> {
454 options.validate()?;
455 let create_mode = options.mode;
456 let sequential = create_mode.is_sequential();
457 let chroot_path = if sequential { self.validate_sequential_path(path)? } else { self.validate_path(path)? };
458 if chroot_path.is_root() {
459 return Err(Error::BadArguments(&"can not create root node"));
460 }
461 let ttl = options.ttl.map(|ttl| ttl.as_millis() as i64).unwrap_or(0);
462 let op_code = if ttl != 0 {
463 OpCode::CreateTtl
464 } else if create_mode.is_container() {
465 OpCode::CreateContainer
466 } else {
467 OpCode::Create
468 };
469 let flags = create_mode.as_flags(ttl != 0);
470 let request = CreateRequest { path: chroot_path, data, acls: options.acls, flags, ttl };
471 let receiver = self.send_request(op_code, &request);
472 Ok(async move {
473 let (body, _) = receiver.await?;
474 let mut buf = body.as_slice();
475 let server_path = record::unmarshal_entity::<&str>(&"server path", &mut buf)?;
476 let client_path = util::strip_root_path(server_path, self.chroot.root())?;
477 let sequence = if sequential { Self::parse_sequence(client_path, path)? } else { CreateSequence(-1) };
478 let stat =
479 if op_code == OpCode::Create { Stat::new_invalid() } else { record::unmarshal::<Stat>(&mut buf)? };
480 Ok((stat, sequence))
481 })
482 }
483
484 pub fn delete(&self, path: &str, expected_version: Option<i32>) -> impl Future<Output = Result<()>> + Send {
491 Self::wait(self.delete_internally(path, expected_version))
492 }
493
494 fn delete_internally(&self, path: &str, expected_version: Option<i32>) -> Result<impl Future<Output = Result<()>>> {
495 let chroot_path = self.validate_path(path)?;
496 if chroot_path.is_root() {
497 return Err(Error::BadArguments(&"can not delete root node"));
498 }
499 let request = DeleteRequest { path: chroot_path, version: expected_version.unwrap_or(-1) };
500 let receiver = self.send_request(OpCode::Delete, &request);
501 Ok(async move {
502 receiver.await?;
503 Ok(())
504 })
505 }
506
    /// Spawns a task that deletes `path` on a best-effort basis (see `delete_foreground`).
    fn delete_background(self, path: String) {
        asyncs::spawn(async move {
            self.delete_foreground(&path).await;
        });
    }
513
    /// Deletes `path` regardless of version, retrying on connection loss and ignoring the
    /// final result.
    async fn delete_foreground(&self, path: &str) {
        Client::retry_on_connection_loss(|| self.delete(path, None)).await.ignore();
    }
517
518 fn delete_ephemeral_background(self, prefix: String, unique: bool) {
519 asyncs::spawn(async move {
520 let (parent, tree, name) = util::split_path(&prefix);
521 let mut children = Self::retry_on_connection_loss(|| self.list_children(parent)).await?;
522 if unique {
523 if let Some(i) = children.iter().position(|s| s.starts_with(name)) {
524 self.delete_foreground(&children[i]).await;
525 };
526 return Ok::<(), Error>(());
527 }
528 children.retain(|s| s.starts_with(name));
529 for child in children.iter_mut() {
530 child.insert_str(0, tree);
531 }
532 let results = Self::retry_on_connection_loss(|| {
533 let mut reader = self.new_multi_reader();
534 for child in children.iter() {
535 reader.add_get_data(child).unwrap();
536 }
537 reader.commit()
538 })
539 .await?;
540 for (i, result) in results.into_iter().enumerate() {
541 let MultiReadResult::Data { stat, .. } = result else {
542 continue;
544 };
545 if stat.ephemeral_owner == self.session_id().0 {
546 self.delete_foreground(&children[i]).await;
547 break;
548 }
549 }
550 Ok(())
551 });
552 }
553
554 fn get_data_internally(
555 &self,
556 chroot: Chroot,
557 path: &str,
558 watch: bool,
559 ) -> Result<impl Future<Output = Result<(Vec<u8>, Stat, WatchReceiver)>> + Send> {
560 let chroot_path = ChrootPath::new(chroot, path, false)?;
561 let request = GetRequest { path: chroot_path, watch };
562 let receiver = self.send_request(OpCode::GetData, &request);
563 Ok(async move {
564 let (mut body, watcher) = receiver.await?;
565 let data_len = body.len() - Stat::record_len();
566 let mut stat_buf = &body[data_len..];
567 let stat = record::unmarshal(&mut stat_buf)?;
568 body.truncate(data_len);
569 drop(body.drain(..4));
570 Ok((body, stat, watcher))
571 })
572 }
573
574 pub fn get_data(&self, path: &str) -> impl Future<Output = Result<(Vec<u8>, Stat)>> + Send {
579 let result = self.get_data_internally(self.chroot.as_ref(), path, false);
580 Self::map_wait(result, |(data, stat, _)| (data, stat))
581 }
582
583 pub fn get_and_watch_data(
593 &self,
594 path: &str,
595 ) -> impl Future<Output = Result<(Vec<u8>, Stat, OneshotWatcher)>> + Send + '_ {
596 let result = self.get_data_internally(self.chroot.as_ref(), path, true);
597 Self::map_wait(result, |(data, stat, watcher)| (data, stat, watcher.into_oneshot(&self.chroot)))
598 }
599
600 fn check_stat_internally(
601 &self,
602 path: &str,
603 watch: bool,
604 ) -> Result<impl Future<Output = Result<(Option<Stat>, WatchReceiver)>>> {
605 let chroot_path = self.validate_path(path)?;
606 let request = ExistsRequest { path: chroot_path, watch };
607 let receiver = self.send_request(OpCode::Exists, &request);
608 Ok(async move {
609 let (body, watcher) = receiver.await?;
610 let mut buf = body.as_slice();
611 let stat = record::try_deserialize(&mut buf)?;
612 Ok((stat, watcher))
613 })
614 }
615
616 pub fn check_stat(&self, path: &str) -> impl Future<Output = Result<Option<Stat>>> + Send {
618 Self::map_wait(self.check_stat_internally(path, false), |(stat, _)| stat)
619 }
620
621 pub fn check_and_watch_stat(
628 &self,
629 path: &str,
630 ) -> impl Future<Output = Result<(Option<Stat>, OneshotWatcher)>> + Send + '_ {
631 let result = self.check_stat_internally(path, true);
632 Self::map_wait(result, |(stat, watcher)| (stat, watcher.into_oneshot(&self.chroot)))
633 }
634
635 pub fn set_data(
642 &self,
643 path: &str,
644 data: &[u8],
645 expected_version: Option<i32>,
646 ) -> impl Future<Output = Result<Stat>> + Send {
647 Self::wait(self.set_data_internally(path, data, expected_version))
648 }
649
650 pub fn set_data_internally(
651 &self,
652 path: &str,
653 data: &[u8],
654 expected_version: Option<i32>,
655 ) -> Result<impl Future<Output = Result<Stat>>> {
656 let chroot_path = self.validate_path(path)?;
657 let request = SetDataRequest { path: chroot_path, data, version: expected_version.unwrap_or(-1) };
658 let receiver = self.send_request(OpCode::SetData, &request);
659 Ok(async move {
660 let (body, _) = receiver.await?;
661 let mut buf = body.as_slice();
662 let stat: Stat = record::unmarshal(&mut buf)?;
663 Ok(stat)
664 })
665 }
666
667 fn list_children_internally(
668 &self,
669 path: &str,
670 watch: bool,
671 ) -> Result<impl Future<Output = Result<(Vec<String>, WatchReceiver)>>> {
672 let chroot_path = self.validate_path(path)?;
673 let request = GetChildrenRequest { path: chroot_path, watch };
674 let receiver = self.send_request(OpCode::GetChildren, &request);
675 Ok(async move {
676 let (body, watcher) = receiver.await?;
677 let mut buf = body.as_slice();
678 let children = record::unmarshal_entity::<Vec<String>>(&"children paths", &mut buf)?;
679 Ok((children, watcher))
680 })
681 }
682
683 pub fn list_children(&self, path: &str) -> impl Future<Output = Result<Vec<String>>> + Send + '_ {
688 Self::map_wait(self.list_children_internally(path, false), |(children, _)| children)
689 }
690
691 pub fn list_and_watch_children(
702 &self,
703 path: &str,
704 ) -> impl Future<Output = Result<(Vec<String>, OneshotWatcher)>> + Send + '_ {
705 let result = self.list_children_internally(path, true);
706 Self::map_wait(result, |(children, watcher)| (children, watcher.into_oneshot(&self.chroot)))
707 }
708
709 fn get_children_internally(
710 &self,
711 path: &str,
712 watch: bool,
713 ) -> Result<impl Future<Output = Result<(Vec<String>, Stat, WatchReceiver)>>> {
714 let chroot_path = self.validate_path(path)?;
715 let request = GetChildrenRequest { path: chroot_path, watch };
716 let receiver = self.send_request(OpCode::GetChildren2, &request);
717 Ok(async move {
718 let (body, watcher) = receiver.await?;
719 let mut buf = body.as_slice();
720 let response = record::unmarshal::<GetChildren2Response>(&mut buf)?;
721 Ok((response.children, response.stat, watcher))
722 })
723 }
724
725 pub fn get_children(&self, path: &str) -> impl Future<Output = Result<(Vec<String>, Stat)>> + Send {
730 let result = self.get_children_internally(path, false);
731 Self::map_wait(result, |(children, stat, _)| (children, stat))
732 }
733
734 pub fn get_and_watch_children(
745 &self,
746 path: &str,
747 ) -> impl Future<Output = Result<(Vec<String>, Stat, OneshotWatcher)>> + Send + '_ {
748 let result = self.get_children_internally(path, true);
749 Self::map_wait(result, |(children, stat, watcher)| (children, stat, watcher.into_oneshot(&self.chroot)))
750 }
751
752 pub fn count_descendants_number(&self, path: &str) -> impl Future<Output = Result<usize>> + Send {
757 Self::wait(self.count_descendants_number_internally(path))
758 }
759
760 fn count_descendants_number_internally(&self, path: &str) -> Result<impl Future<Output = Result<usize>>> {
761 let chroot_path = self.validate_path(path)?;
762 let receiver = self.send_request(OpCode::GetAllChildrenNumber, &chroot_path);
763 Ok(async move {
764 let (body, _) = receiver.await?;
765 let mut buf = body.as_slice();
766 let n = record::unmarshal_entity::<i32>(&"all children number", &mut buf)?;
767 Ok(n as usize)
768 })
769 }
770
771 pub fn list_ephemerals(&self, path: &str) -> impl Future<Output = Result<Vec<String>>> + Send + '_ {
778 Self::wait(self.list_ephemerals_internally(path))
779 }
780
781 fn list_ephemerals_internally(&self, path: &str) -> Result<impl Future<Output = Result<Vec<String>>> + Send + '_> {
782 let path = self.validate_path(path)?;
783 let receiver = self.send_request(OpCode::GetEphemerals, &path);
784 Ok(async move {
785 let (body, _) = receiver.await?;
786 let mut buf = body.as_slice();
787 let mut ephemerals = record::unmarshal_entity::<Vec<String>>(&"ephemerals", &mut buf)?;
788 for ephemeral_path in ephemerals.iter_mut() {
789 util::drain_root_path(ephemeral_path, self.chroot.root())?;
790 }
791 Ok(ephemerals)
792 })
793 }
794
795 pub fn get_acl(&self, path: &str) -> impl Future<Output = Result<(Vec<Acl>, Stat)>> + Send + '_ {
800 Self::wait(self.get_acl_internally(path))
801 }
802
803 fn get_acl_internally(&self, path: &str) -> Result<impl Future<Output = Result<(Vec<Acl>, Stat)>>> {
804 let chroot_path = self.validate_path(path)?;
805 let receiver = self.send_request(OpCode::GetACL, &chroot_path);
806 Ok(async move {
807 let (body, _) = receiver.await?;
808 let mut buf = body.as_slice();
809 let response: GetAclResponse = record::unmarshal(&mut buf)?;
810 Ok((response.acl, response.stat))
811 })
812 }
813
814 pub fn set_acl(
820 &self,
821 path: &str,
822 acl: &[Acl],
823 expected_acl_version: Option<i32>,
824 ) -> impl Future<Output = Result<Stat>> + Send + '_ {
825 Self::wait(self.set_acl_internally(path, acl, expected_acl_version))
826 }
827
828 fn set_acl_internally(
829 &self,
830 path: &str,
831 acl: &[Acl],
832 expected_acl_version: Option<i32>,
833 ) -> Result<impl Future<Output = Result<Stat>>> {
834 let chroot_path = self.validate_path(path)?;
835 let request = SetAclRequest { path: chroot_path, acl, version: expected_acl_version.unwrap_or(-1) };
836 let receiver = self.send_request(OpCode::SetACL, &request);
837 Ok(async move {
838 let (body, _) = receiver.await?;
839 let mut buf = body.as_slice();
840 let stat: Stat = record::unmarshal(&mut buf)?;
841 Ok(stat)
842 })
843 }
844
845 pub fn watch(&self, path: &str, mode: AddWatchMode) -> impl Future<Output = Result<PersistentWatcher>> + Send + '_ {
860 Self::wait(self.watch_internally(path, mode))
861 }
862
863 fn watch_internally(
864 &self,
865 path: &str,
866 mode: AddWatchMode,
867 ) -> Result<impl Future<Output = Result<PersistentWatcher>> + Send + '_> {
868 let chroot_path = self.validate_path(path)?;
869 let proto_mode = proto::AddWatchMode::from(mode);
870 let request = PersistentWatchRequest { path: chroot_path, mode: proto_mode.into() };
871 let receiver = self.send_request(OpCode::AddWatch, &request);
872 Ok(async move {
873 let (_, watcher) = receiver.await?;
874 Ok(watcher.into_persistent(&self.chroot))
875 })
876 }
877
878 pub fn sync(&self, path: &str) -> impl Future<Output = Result<()>> + Send + '_ {
889 Self::wait(self.sync_internally(path))
890 }
891
892 fn sync_internally(&self, path: &str) -> Result<impl Future<Output = Result<()>>> {
893 let chroot_path = self.validate_path(path)?;
894 let request = SyncRequest { path: chroot_path };
895 let receiver = self.send_request(OpCode::Sync, &request);
896 Ok(async move {
897 let (body, _) = receiver.await?;
898 let mut buf = body.as_slice();
899 record::unmarshal_entity::<&str>(&"server path", &mut buf)?;
900 Ok(())
901 })
902 }
903
904 pub fn auth(&self, scheme: &str, auth: &[u8]) -> impl Future<Output = Result<()>> + Send + '_ {
918 let request = AuthPacket { scheme, auth };
919 let receiver = self.send_request(OpCode::Auth, &request);
920 async move {
921 receiver.await?;
922 Ok(())
923 }
924 }
925
926 pub fn list_auth_users(&self) -> impl Future<Output = Result<Vec<AuthUser>>> + Send {
936 let receiver = self.send_request(OpCode::WhoAmI, &());
937 async move {
938 let (body, _) = receiver.await?;
939 let mut buf = body.as_slice();
940 let authed_users = record::unmarshal_entity::<Vec<AuthUser>>(&"authed users", &mut buf)?;
941 Ok(authed_users)
942 }
943 }
944
945 pub fn get_config(&self) -> impl Future<Output = Result<(Vec<u8>, Stat)>> + Send {
947 let result = self.get_data_internally(Chroot::default(), Self::CONFIG_NODE, false);
948 Self::map_wait(result, |(data, stat, _)| (data, stat))
949 }
950
951 pub fn get_and_watch_config(&self) -> impl Future<Output = Result<(Vec<u8>, Stat, OneshotWatcher)>> + Send {
953 let result = self.get_data_internally(Chroot::default(), Self::CONFIG_NODE, true);
954 Self::map_wait(result, |(data, stat, watcher)| (data, stat, watcher.into_oneshot(&OwnedChroot::default())))
955 }
956
957 pub fn update_ensemble<'a, I: Iterator<Item = &'a str> + Clone>(
965 &self,
966 update: EnsembleUpdate<'a, I>,
967 expected_zxid: Option<i64>,
968 ) -> impl Future<Output = Result<(Vec<u8>, Stat)>> + Send {
969 let request = ReconfigRequest { update, version: expected_zxid.unwrap_or(-1) };
970 let receiver = self.send_request(OpCode::Reconfig, &request);
971 async move {
972 let (mut body, _) = receiver.await?;
973 let mut buf = body.as_slice();
974 let data: &str = record::unmarshal_entity(&"reconfig data", &mut buf)?;
975 let stat = record::unmarshal_entity(&"reconfig stat", &mut buf)?;
976 let data_len = data.len();
977 body.truncate(data_len + 4);
978 drop(body.drain(..4));
979 Ok((body, stat))
980 }
981 }
982
    /// Creates a [MultiReader] bound to this client.
    pub fn new_multi_reader(&self) -> MultiReader<'_> {
        MultiReader::new(self)
    }
987
    /// Creates a [MultiWriter] bound to this client.
    pub fn new_multi_writer(&self) -> MultiWriter<'_> {
        MultiWriter::new(self)
    }
992
993 pub fn new_check_writer(&self, path: &str, version: Option<i32>) -> Result<CheckWriter<'_>> {
996 let mut writer = self.new_multi_writer();
997 writer.add_check_version(path, version.unwrap_or(-1))?;
998 Ok(CheckWriter { writer })
999 }
1000
1001 async fn create_lock(
1002 &self,
1003 prefix: LockPrefix<'_>,
1004 data: &[u8],
1005 options: LockOptions<'_>,
1006 ) -> Result<(String, usize)> {
1007 let kind = prefix.kind();
1008 let prefix = prefix.into();
1009 self.validate_sequential_path(&prefix)?;
1010 let (parent, _, _) = util::split_path(&prefix);
1011 let guard = LockingGuard { zk: self, prefix: &prefix, unique: kind.is_unique() };
1012 loop {
1013 let mut result = self.create(&prefix, data, &CreateMode::EphemeralSequential.with_acls(options.acls)).await;
1014 if result == Err(Error::NoNode) {
1015 if let Some(options) = &options.parent {
1016 match Self::retry_on_connection_loss(|| self.mkdir_internally(parent, options)).await {
1017 Ok(_) => continue,
1018 Err(Error::NoNode) => result = Err(Error::NoNode),
1019 Err(err) => return Err(err),
1020 }
1021 }
1022 }
1023 let sequence = match result {
1024 Err(Error::ConnectionLoss) => {
1025 if let Some(sequence) = self.find_lock(&prefix, kind).await? {
1026 sequence
1027 } else {
1028 continue;
1029 }
1030 },
1031 Err(err) => {
1032 if err.has_no_data_change() {
1033 std::mem::forget(guard);
1034 return Err(err);
1035 } else {
1036 return Err(err);
1037 }
1038 },
1039 Ok((_stat, sequence)) => sequence,
1040 };
1041 std::mem::forget(guard);
1042 let prefix_len = prefix.len();
1043 let mut path = prefix;
1044 write!(&mut path, "{sequence}").unwrap();
1045 let sequence_len = path.len() - prefix_len;
1046 return Ok((path, sequence_len));
1047 }
1048 }
1049
1050 async fn find_lock(&self, prefix: &str, kind: LockPrefixKind<'_>) -> Result<Option<CreateSequence>> {
1051 let (parent, tree, name) = util::split_path(prefix);
1052 let mut children = Self::retry_on_connection_loss(|| self.list_children(parent)).await?;
1053 if kind.is_unique() {
1054 let Some(i) = children.iter().position(|s| s.starts_with(name)) else {
1055 return Ok(None);
1056 };
1057 let sequence = Self::parse_sequence(&children[i], name)?;
1058 return Ok(Some(sequence));
1059 }
1060 children.retain(|s| s.starts_with(name));
1061 if children.is_empty() {
1062 return Ok(None);
1063 }
1064 for child in children.iter_mut() {
1065 child.insert_str(0, tree);
1066 }
1067 let results = Self::retry_on_connection_loss(|| {
1068 let mut reader = self.new_multi_reader();
1069 for child in children.iter() {
1070 reader.add_get_data(child).unwrap();
1071 }
1072 reader.commit()
1073 })
1074 .await?;
1075 for (i, result) in results.into_iter().enumerate() {
1076 let MultiReadResult::Data { stat, .. } = result else {
1077 continue;
1079 };
1080 if stat.ephemeral_owner == self.session_id().0 {
1081 let sequence = Self::parse_sequence(&children[i], name)?;
1082 return Ok(Some(sequence));
1083 }
1084 }
1085 Ok(None)
1086 }
1087
1088 async fn wait_lock(&self, lock: &str, kind: LockPrefixKind<'_>, sequence_len: usize) -> Result<()> {
1089 let (parent, tree, this) = util::split_path(lock);
1090 loop {
1091 let mut children = Self::retry_on_connection_loss(|| self.list_children(parent)).await?;
1092 children.retain(|s| {
1093 s.len() >= sequence_len && kind.filter(s) && s[s.len() - sequence_len..].parse::<i32>().is_ok()
1094 });
1095 children.sort_unstable_by(|a, b| a[a.len() - sequence_len..].cmp(&b[b.len() - sequence_len..]));
1096 match children.binary_search_by(|a| a[a.len() - sequence_len..].cmp(&this[this.len() - sequence_len..])) {
1097 Ok(0) => return Ok(()),
1098 Ok(i) => {
1099 let mut child = children.swap_remove(i - 1);
1100 child.insert_str(0, tree);
1101 let watcher = match Self::retry_on_connection_loss(|| self.get_and_watch_data(&child)).await {
1102 Err(Error::NoNode) => continue,
1103 Err(err) => return Err(err),
1104 Ok((_data, _stat, watcher)) => watcher,
1105 };
1106 watcher.changed().await;
1107 },
1108 Err(_) => return Err(Error::RuntimeInconsistent),
1109 }
1110 }
1111 }
1112
1113 pub async fn lock(
1144 &self,
1145 prefix: LockPrefix<'_>,
1146 data: &[u8],
1147 options: impl Into<LockOptions<'_>>,
1148 ) -> Result<LockClient<'_>> {
1149 let options = options.into();
1150 if options.acls.is_empty() {
1151 return Err(Error::InvalidAcl);
1152 }
1153 let prefix_kind = prefix.kind();
1154 let (lock, sequence_len) = self.create_lock(prefix, data, options).await?;
1155 let client = LockClient { client: self, lock: Cow::from(lock) };
1156 match self.wait_lock(&client.lock, prefix_kind, sequence_len).await {
1157 Err(err @ (Error::RuntimeInconsistent | Error::SessionExpired)) => {
1158 std::mem::forget(client);
1159 Err(err)
1160 },
1161 Err(err) => Err(err),
1162 Ok(_) => Ok(client),
1163 }
1164 }
1165}
1166
1167#[derive(Clone, Debug)]
1170pub struct LockOptions<'a> {
1171 acls: Acls<'a>,
1172 parent: Option<CreateOptions<'a>>,
1173}
1174
1175impl<'a> LockOptions<'a> {
1176 pub fn new(acls: Acls<'a>) -> Self {
1177 Self { acls, parent: None }
1178 }
1179
1180 pub fn with_ancestor_options(mut self, options: CreateOptions<'a>) -> Result<Self> {
1186 options.validate_as_directory()?;
1187 self.parent = Some(options);
1188 Ok(self)
1189 }
1190}
1191
1192impl<'a> From<Acls<'a>> for LockOptions<'a> {
1193 fn from(acls: Acls<'a>) -> Self {
1194 LockOptions::new(acls)
1195 }
1196}
1197
/// Borrowed view of a lock prefix that carries just enough to recognize competing lock nodes
/// by their names.
#[derive(Copy, Clone)]
enum LockPrefixKind<'a> {
    /// Competitors contain `lock_name` in their names.
    Curator { lock_name: &'a str },
    /// Competitors contain `lock_name` in their names.
    Custom { lock_name: &'a str },
    /// Competitors start with `prefix` in their names.
    Shared { prefix: &'a str },
}

impl LockPrefixKind<'_> {
    /// Whether the child named `name` is accepted as a competitor for this kind of prefix.
    fn filter(&self, name: &str) -> bool {
        match *self {
            Self::Curator { lock_name } | Self::Custom { lock_name } => name.contains(lock_name),
            Self::Shared { prefix } => name.starts_with(prefix),
        }
    }

    /// Whether this is the curator kind, the only kind treated as unique.
    fn is_unique(&self) -> bool {
        match self {
            Self::Curator { .. } => true,
            Self::Custom { .. } | Self::Shared { .. } => false,
        }
    }
}
1218
/// Internal representation behind [LockPrefix], one variant per constructor.
#[derive(Debug)]
enum LockPrefixInner<'a> {
    /// Built by `LockPrefix::new_curator` from a lock directory and a lock name.
    Curator { dir: &'a str, name: &'a str },
    /// Built by `LockPrefix::new_custom` from an owned path prefix and a lock name.
    Custom { prefix: String, name: &'a str },
    /// Built by `LockPrefix::new_shared` from a borrowed path prefix.
    Shared { prefix: &'a str },
}
1225
/// Prefix for the lock node path used by [Client::lock].
///
/// Construct it with [LockPrefix::new_curator], [LockPrefix::new_shared] or
/// [LockPrefix::new_custom].
#[derive(Debug)]
pub struct LockPrefix<'a> {
    inner: LockPrefixInner<'a>,
}
1238
1239impl<'a> LockPrefix<'a> {
1240 pub fn new_curator(dir: &'a str, name: &'a str) -> Result<Self> {
1247 crate::util::validate_path(Chroot::default(), dir, false)?;
1248 if name.find('/').is_some() {
1249 return Err(Error::BadArguments(&"lock name must not contain /"));
1250 }
1251 Ok(Self { inner: LockPrefixInner::Curator { dir, name } })
1252 }
1253
1254 pub fn new_shared(prefix: &'a str) -> Result<Self> {
1266 crate::util::validate_path(Chroot::default(), prefix, true)?;
1267 Ok(Self { inner: LockPrefixInner::Shared { prefix } })
1268 }
1269
1270 pub fn new_custom(prefix: String, name: &'a str) -> Result<Self> {
1286 crate::util::validate_path(Chroot::default(), &prefix, true)?;
1287 if !name.is_empty() {
1288 let (_dir, _tree, this) = util::split_path(&prefix);
1289 if !this.contains(name) {
1290 return Err(Error::BadArguments(&"lock path prefix must contain lock name"));
1291 }
1292 }
1293 Ok(Self { inner: LockPrefixInner::Custom { prefix, name } })
1294 }
1295
1296 fn kind(&self) -> LockPrefixKind<'a> {
1297 match &self.inner {
1298 LockPrefixInner::Curator { name, .. } => LockPrefixKind::Curator { lock_name: name },
1299 LockPrefixInner::Shared { prefix } => {
1300 let (_parent, _tree, name) = util::split_path(prefix);
1301 LockPrefixKind::Shared { prefix: name }
1302 },
1303 LockPrefixInner::Custom { name, .. } => LockPrefixKind::Custom { lock_name: name },
1304 }
1305 }
1306
1307 fn into(self) -> String {
1308 match self.inner {
1309 LockPrefixInner::Curator { dir, name } => format!("{}/_c_{}-{}", dir, uuid::Uuid::new_v4(), name),
1310 LockPrefixInner::Shared { prefix } => prefix.to_string(),
1311 LockPrefixInner::Custom { prefix, .. } => prefix,
1312 }
1313 }
1314}
1315
1316struct LockingGuard<'a> {
1317 zk: &'a Client,
1318 prefix: &'a str,
1319 unique: bool,
1320}
1321
1322impl Drop for LockingGuard<'_> {
1323 fn drop(&mut self) {
1324 self.zk.clone().delete_ephemeral_background(self.prefix.to_string(), self.unique);
1325 }
1326}
1327
1328#[derive(Debug)]
1330pub struct LockClient<'a> {
1331 client: &'a Client,
1332 lock: Cow<'a, str>,
1333}
1334
1335impl<'a> LockClient<'a> {
1336 async fn resolve_one_write(
1337 future: impl Future<Output = std::result::Result<Vec<MultiWriteResult>, CheckWriteError>>,
1338 ) -> Result<MultiWriteResult> {
1339 let mut results = future.await?;
1340 Ok(results.remove(0))
1341 }
1342
1343 pub fn client(&self) -> &'a Client {
1345 self.client
1346 }
1347
1348 pub fn lock_path(&self) -> &str {
1353 &self.lock
1354 }
1355
1356 pub fn create(
1364 &self,
1365 path: &str,
1366 data: &[u8],
1367 options: &CreateOptions<'_>,
1368 ) -> impl Future<Output = Result<(Stat, CreateSequence)>> + Send + 'a {
1369 Client::wait(self.create_internally(path, data, options))
1370 }
1371
1372 fn create_internally(
1373 &self,
1374 path: &str,
1375 data: &[u8],
1376 options: &CreateOptions<'_>,
1377 ) -> Result<impl Future<Output = Result<(Stat, CreateSequence)>> + Send + 'a> {
1378 let mut writer = self.client.new_check_writer(&self.lock, None)?;
1379 writer.add_create(path, data, options)?;
1380 let write = writer.commit();
1381 let path_len = path.len();
1386 Ok(async move {
1387 let result = Self::resolve_one_write(write).await?;
1388 let (created_path, stat) = result.into_create()?;
1389 let sequence = if created_path.len() <= path_len {
1390 CreateSequence(-1)
1391 } else {
1392 Client::parse_sequence(&created_path, &created_path[..path_len])?
1393 };
1394 Ok((stat, sequence))
1395 })
1396 }
1397
1398 pub fn set_data(
1400 &self,
1401 path: &str,
1402 data: &[u8],
1403 expected_version: Option<i32>,
1404 ) -> impl Future<Output = Result<Stat>> + Send + 'a {
1405 Client::wait(self.set_data_internally(path, data, expected_version))
1406 }
1407
1408 fn set_data_internally(
1409 &self,
1410 path: &str,
1411 data: &[u8],
1412 expected_version: Option<i32>,
1413 ) -> Result<impl Future<Output = Result<Stat>> + Send + 'a> {
1414 let mut writer = self.new_check_writer();
1415 writer.add_set_data(path, data, expected_version)?;
1416 let write = writer.commit();
1417 Ok(async move {
1418 let result = Self::resolve_one_write(write).await?;
1419 let stat = result.into_set_data()?;
1420 Ok(stat)
1421 })
1422 }
1423
1424 pub fn delete(&self, path: &str, expected_version: Option<i32>) -> impl Future<Output = Result<()>> + Send + 'a {
1426 Client::wait(self.delete_internally(path, expected_version))
1427 }
1428
1429 fn delete_internally(
1430 &self,
1431 path: &str,
1432 expected_version: Option<i32>,
1433 ) -> Result<impl Future<Output = Result<()>> + Send + 'a> {
1434 let mut writer = self.new_check_writer();
1435 writer.add_delete(path, expected_version)?;
1436 let write = writer.commit();
1437 Ok(async move {
1438 let result = Self::resolve_one_write(write).await?;
1439 result.into_delete()
1440 })
1441 }
1442
1443 pub fn new_check_writer(&self) -> CheckWriter<'a> {
1445 unsafe { self.client.new_check_writer(&self.lock, None).unwrap_unchecked() }
1446 }
1447
1448 pub fn into_owned(self) -> OwnedLockClient {
1450 let client = self.client.clone();
1451 let mut drop = ManuallyDrop::new(self);
1452 let lock = std::mem::take(drop.lock.to_mut());
1453 OwnedLockClient { client: ManuallyDrop::new(client), lock }
1454 }
1455}
1456
1457impl Drop for LockClient<'_> {
1459 fn drop(&mut self) {
1460 let path = std::mem::take(self.lock.to_mut());
1461 let client = self.client.clone();
1462 client.delete_background(path);
1463 }
1464}
1465
1466#[derive(Clone, Debug)]
1468pub struct OwnedLockClient {
1469 client: ManuallyDrop<Client>,
1470 lock: String,
1471}
1472
1473impl OwnedLockClient {
1474 fn lock_client(&self) -> std::mem::ManuallyDrop<LockClient<'_>> {
1475 std::mem::ManuallyDrop::new(LockClient { client: &self.client, lock: Cow::from(&self.lock) })
1476 }
1477
1478 pub fn client(&self) -> &Client {
1480 &self.client
1481 }
1482
1483 pub fn lock_path(&self) -> &str {
1485 &self.lock
1486 }
1487
1488 pub fn create<'a: 'f, 'b: 'f, 'f>(
1490 &'a self,
1491 path: &'b str,
1492 data: &[u8],
1493 options: &CreateOptions<'_>,
1494 ) -> impl Future<Output = Result<(Stat, CreateSequence)>> + Send + 'f {
1495 self.lock_client().create(path, data, options)
1496 }
1497
1498 pub fn set_data(
1500 &self,
1501 path: &str,
1502 data: &[u8],
1503 expected_version: Option<i32>,
1504 ) -> impl Future<Output = Result<Stat>> + Send + '_ {
1505 self.lock_client().set_data(path, data, expected_version)
1506 }
1507
1508 pub fn delete(&self, path: &str, expected_version: Option<i32>) -> impl Future<Output = Result<()>> + Send + '_ {
1510 self.lock_client().delete(path, expected_version)
1511 }
1512
1513 pub fn new_check_writer(&self) -> CheckWriter<'_> {
1515 unsafe { self.client.new_check_writer(&self.lock, None).unwrap_unchecked() }
1516 }
1517}
1518
1519impl Drop for OwnedLockClient {
1521 fn drop(&mut self) {
1522 let client = unsafe { ManuallyDrop::take(&mut self.client) };
1523 let path = std::mem::take(&mut self.lock);
1524 client.delete_background(path);
1525 }
1526}
1527
/// Server version as a `(major, minor, patch)` triple.
///
/// The derived `PartialOrd` compares fields in declaration order, so field order matters.
#[derive(Copy, Clone, Debug, PartialEq, PartialOrd)]
pub(crate) struct Version(u32, u32, u32);
1530
1531#[derive(Clone)]
1535#[derive_where(Debug)]
1536pub struct Connector {
1537 #[cfg(feature = "tls")]
1538 tls: Option<TlsOptions>,
1539 #[cfg(any(feature = "sasl-digest-md5", feature = "sasl-gssapi"))]
1540 sasl: Option<SaslOptions>,
1541 #[derive_where(skip(Debug))]
1542 authes: Vec<MarshalledRequest>,
1543 session: Option<SessionInfo>,
1544 readonly: bool,
1545 detached: bool,
1546 fail_eagerly: bool,
1547 server_version: Version,
1548 session_timeout: Duration,
1549 connection_timeout: Duration,
1550}
1551
1552impl Connector {
1553 fn new() -> Self {
1554 Self {
1555 #[cfg(feature = "tls")]
1556 tls: None,
1557 #[cfg(any(feature = "sasl-digest-md5", feature = "sasl-gssapi"))]
1558 sasl: None,
1559 authes: Default::default(),
1560 session: None,
1561 readonly: false,
1562 detached: false,
1563 fail_eagerly: false,
1564 server_version: Version(u32::MAX, u32::MAX, u32::MAX),
1565 session_timeout: Duration::ZERO,
1566 connection_timeout: Duration::ZERO,
1567 }
1568 }
1569
1570 pub fn with_session_timeout(mut self, timeout: Duration) -> Self {
1574 self.session_timeout = timeout;
1575 self
1576 }
1577
1578 #[deprecated(since = "0.11.0", note = "use Connector::with_session_timeout instead")]
1582 pub fn session_timeout(&mut self, timeout: Duration) -> &mut Self {
1583 self.session_timeout = timeout;
1584 self
1585 }
1586
1587 pub fn with_connection_timeout(mut self, timeout: Duration) -> Self {
1591 self.connection_timeout = timeout;
1592 self
1593 }
1594
1595 #[deprecated(since = "0.11.0", note = "use Connector::with_connection_timeout instead")]
1599 pub fn connection_timeout(&mut self, timeout: Duration) -> &mut Self {
1600 self.connection_timeout = timeout;
1601 self
1602 }
1603
1604 pub fn with_readonly(mut self, readonly: bool) -> Self {
1606 self.readonly = readonly;
1607 self
1608 }
1609
1610 #[deprecated(since = "0.11.0", note = "use Connector::with_readonly instead")]
1612 pub fn readonly(&mut self, readonly: bool) -> &mut Self {
1613 self.readonly = readonly;
1614 self
1615 }
1616
1617 pub fn with_auth(mut self, scheme: &str, auth: &[u8]) -> Self {
1619 let packet = AuthPacket { scheme, auth };
1620 let request = MarshalledRequest::new(OpCode::Auth, &packet);
1621 self.authes.push(request);
1622 self
1623 }
1624
1625 #[deprecated(since = "0.11.0", note = "use Connector::with_auth instead")]
1627 pub fn auth(&mut self, scheme: String, auth: Vec<u8>) -> &mut Self {
1628 let packet = AuthPacket { scheme: &scheme, auth: &auth };
1629 let request = MarshalledRequest::new(OpCode::Auth, &packet);
1630 self.authes.push(request);
1631 self
1632 }
1633
1634 pub fn with_session(mut self, session: SessionInfo) -> Self {
1636 self.session = Some(session);
1637 self
1638 }
1639
1640 #[deprecated(since = "0.11.0", note = "use Connector::with_session instead")]
1642 pub fn session(&mut self, session: SessionInfo) -> &mut Self {
1643 self.session = Some(session);
1644 self
1645 }
1646
1647 pub fn with_server_version(mut self, major: u32, minor: u32, patch: u32) -> Self {
1657 self.server_version = Version(major, minor, patch);
1658 self
1659 }
1660
1661 #[deprecated(since = "0.11.0", note = "use Connector::with_server_version instead")]
1671 pub fn server_version(&mut self, major: u32, minor: u32, patch: u32) -> &mut Self {
1672 self.server_version = Version(major, minor, patch);
1673 self
1674 }
1675
1676 pub fn with_detached(mut self) -> Self {
1678 self.detached = true;
1679 self
1680 }
1681
1682 #[deprecated(since = "0.11.0", note = "use Connector::with_detached instead")]
1684 pub fn detached(&mut self) -> &mut Self {
1685 self.detached = true;
1686 self
1687 }
1688
1689 #[cfg(feature = "tls")]
1691 #[cfg_attr(docsrs, doc(cfg(feature = "tls")))]
1692 pub fn with_tls(mut self, options: TlsOptions) -> Self {
1693 self.tls = Some(options);
1694 self
1695 }
1696
1697 #[cfg(feature = "tls")]
1699 #[cfg_attr(docsrs, doc(cfg(feature = "tls")))]
1700 #[deprecated(since = "0.11.0", note = "use Connector::with_tls instead")]
1701 pub fn tls(&mut self, options: TlsOptions) -> &mut Self {
1702 self.tls = Some(options);
1703 self
1704 }
1705
1706 #[cfg(any(feature = "sasl-digest-md5", feature = "sasl-gssapi"))]
1708 #[cfg_attr(docsrs, doc(cfg(any(feature = "sasl", feature = "sasl-gssapi", feature = "sasl-digest-md5"))))]
1709 pub fn with_sasl(mut self, options: impl Into<SaslOptions>) -> Self {
1710 self.sasl = Some(options.into());
1711 self
1712 }
1713
1714 #[cfg(any(feature = "sasl-digest-md5", feature = "sasl-gssapi"))]
1716 #[cfg_attr(docsrs, doc(cfg(any(feature = "sasl", feature = "sasl-gssapi", feature = "sasl-digest-md5"))))]
1717 #[deprecated(since = "0.11.0", note = "use Connector::with_sasl instead")]
1718 pub fn sasl(&mut self, options: impl Into<SaslOptions>) -> &mut Self {
1719 self.sasl = Some(options.into());
1720 self
1721 }
1722
1723 pub fn with_fail_eagerly(mut self) -> Self {
1728 self.fail_eagerly = true;
1729 self
1730 }
1731
1732 #[deprecated(since = "0.11.0", note = "use Connector::with_fail_eagerly instead")]
1737 pub fn fail_eagerly(&mut self) -> &mut Self {
1738 self.fail_eagerly = true;
1739 self
1740 }
1741
1742 #[instrument(name = "connect", skip_all, fields(session))]
1743 async fn connect_internally(self, secure: bool, cluster: &str) -> Result<Client> {
1744 let (endpoints, chroot) = endpoint::parse_connect_string(cluster, secure)?;
1745 let builder = Session::builder()
1746 .with_session(self.session)
1747 .with_authes(self.authes)
1748 .with_readonly(self.readonly)
1749 .with_detached(self.detached)
1750 .with_session_timeout(self.session_timeout)
1751 .with_connection_timeout(self.connection_timeout);
1752 #[cfg(feature = "tls")]
1753 let builder = builder.with_tls(self.tls);
1754 #[cfg(any(feature = "sasl-digest-md5", feature = "sasl-gssapi"))]
1755 let builder = builder.with_sasl(self.sasl);
1756 let (sender, receiver) = mpsc::unbounded();
1757 let sender = Arc::new(sender);
1758 let mut session = builder.build(Arc::downgrade(&sender))?;
1759 let mut endpoints = IterableEndpoints::from(endpoints.as_slice());
1760 endpoints.reset();
1761 if !self.fail_eagerly {
1762 endpoints.cycle();
1763 }
1764 let mut buf = Vec::with_capacity(4096);
1765 let mut depot = Depot::new();
1766 let conn = session.start(&mut endpoints, &mut buf, &mut depot).await?;
1767 let session_info = session.session.clone();
1768 let session_timeout = session.session_timeout;
1769 let mut state_watcher = StateWatcher::new(session.subscribe_state());
1770 state_watcher.state();
1772 asyncs::spawn(async move {
1773 session.serve(endpoints, conn, buf, depot, receiver).await;
1774 });
1775 let client =
1776 Client::new(chroot.to_owned(), self.server_version, session_info, session_timeout, sender, state_watcher);
1777 Ok(client)
1778 }
1779
1780 #[cfg(feature = "tls")]
1785 pub async fn secure_connect(self, cluster: &str) -> Result<Client> {
1786 self.connect_internally(true, cluster).await
1787 }
1788
1789 pub async fn connect(self, cluster: &str) -> Result<Client> {
1801 self.connect_internally(false, cluster).await
1802 }
1803}
1804
1805trait MultiBuffer {
1806 fn buffer(&mut self) -> &mut Vec<u8>;
1807
1808 fn op_code() -> OpCode;
1809
1810 fn build_request(&mut self) -> MarshalledRequest {
1811 let buffer = self.buffer();
1812 if buffer.is_empty() {
1813 return Default::default();
1814 }
1815 let header = MultiHeader { op: OpCode::Error, done: true, err: -1 };
1816 buffer.append_record(&header);
1817 buffer.finish();
1818 MarshalledRequest(std::mem::take(buffer))
1819 }
1820
1821 fn add_operation(&mut self, op: OpCode, request: &impl Record) {
1822 let buffer = self.buffer();
1823 if buffer.is_empty() {
1824 let n = RequestHeader::record_len() + MultiHeader::record_len() + request.serialized_len();
1825 buffer.prepare_and_reserve(n);
1826 buffer.append_record(&RequestHeader::with_code(Self::op_code()));
1827 }
1828 let header = MultiHeader { op, done: false, err: -1 };
1829 self.buffer().append_record2(&header, request);
1830 }
1831}
1832
/// Individual result of an operation committed through [`MultiReader`].
#[non_exhaustive]
#[derive(Debug)]
pub enum MultiReadResult {
    /// Data and stat of a node.
    Data { data: Vec<u8>, stat: Stat },

    /// Children of a node.
    Children { children: Vec<String> },

    /// Error reported for the operation.
    Error { err: Error },
}
1846
1847pub struct MultiReader<'a> {
1849 client: &'a Client,
1850 buf: Vec<u8>,
1851}
1852
1853impl MultiBuffer for MultiReader<'_> {
1854 fn buffer(&mut self) -> &mut Vec<u8> {
1855 &mut self.buf
1856 }
1857
1858 fn op_code() -> OpCode {
1859 OpCode::MultiRead
1860 }
1861}
1862
1863impl<'a> MultiReader<'a> {
1864 fn new(client: &'a Client) -> MultiReader<'a> {
1865 MultiReader { client, buf: Default::default() }
1866 }
1867
1868 pub fn add_get_data(&mut self, path: &str) -> Result<()> {
1872 let chroot_path = self.client.validate_path(path)?;
1873 let request = GetRequest { path: chroot_path, watch: false };
1874 self.add_operation(OpCode::GetData, &request);
1875 Ok(())
1876 }
1877
1878 pub fn add_get_children(&mut self, path: &str) -> Result<()> {
1882 let chroot_path = self.client.validate_path(path)?;
1883 let request = GetChildrenRequest { path: chroot_path, watch: false };
1884 self.add_operation(OpCode::GetChildren, &request);
1885 Ok(())
1886 }
1887
1888 pub fn commit(&mut self) -> impl Future<Output = Result<Vec<MultiReadResult>>> + Send + 'a {
1893 let request = self.build_request();
1894 Client::resolve(self.commit_internally(request))
1895 }
1896
1897 fn commit_internally(
1898 &self,
1899 request: MarshalledRequest,
1900 ) -> Result<Either<impl Future<Output = Result<Vec<MultiReadResult>>> + Send + 'a, Vec<MultiReadResult>>> {
1901 if request.is_empty() {
1902 return Ok(Right(Vec::default()));
1903 }
1904 let receiver = self.client.send_marshalled_request(request);
1905 Ok(Left(async move {
1906 let (body, _) = receiver.await?;
1907 let response = record::unmarshal::<Vec<MultiReadResponse>>(&mut body.as_slice())?;
1908 let mut results = Vec::with_capacity(response.len());
1909 for result in response {
1910 match result {
1911 MultiReadResponse::Data { data, stat } => results.push(MultiReadResult::Data { data, stat }),
1912 MultiReadResponse::Children { children } => results.push(MultiReadResult::Children { children }),
1913 MultiReadResponse::Error(err) => results.push(MultiReadResult::Error { err }),
1914 }
1915 }
1916 Ok(results)
1917 }))
1918 }
1919
1920 pub fn abort(&mut self) {
1922 self.buf.clear();
1923 }
1924}
1925
1926#[non_exhaustive]
1928#[derive(Debug, PartialEq, Eq)]
1929pub enum MultiWriteResult {
1930 Check,
1932
1933 Delete,
1935
1936 Create {
1938 path: String,
1940
1941 stat: Stat,
1948 },
1949
1950 SetData {
1952 stat: Stat,
1954 },
1955}
1956
1957impl MultiWriteResult {
1958 fn kind(&self) -> &'static str {
1959 match self {
1960 MultiWriteResult::Check => "MultiWriteResult::Check",
1961 MultiWriteResult::Create { .. } => "MultiWriteResult::Create",
1962 MultiWriteResult::Delete => "MultiWriteResult::Delete",
1963 MultiWriteResult::SetData { .. } => "MultiWriteResult::SetData",
1964 }
1965 }
1966
1967 fn into_check(self) -> Result<()> {
1968 match self {
1969 MultiWriteResult::Check => Ok(()),
1970 _ => Err(Error::UnexpectedError(format!("expect MultiWriteResult::Check, got {}", self.kind()))),
1971 }
1972 }
1973
1974 fn into_create(self) -> Result<(String, Stat)> {
1975 match self {
1976 MultiWriteResult::Create { path, stat } => Ok((path, stat)),
1977 _ => Err(Error::UnexpectedError(format!("expect MultiWriteResult::Create, got {}", self.kind()))),
1978 }
1979 }
1980
1981 fn into_set_data(self) -> Result<Stat> {
1982 match self {
1983 MultiWriteResult::SetData { stat } => Ok(stat),
1984 _ => Err(Error::UnexpectedError(format!("expect MultiWriteResult::SetData, got {}", self.kind()))),
1985 }
1986 }
1987
1988 fn into_delete(self) -> Result<()> {
1989 match self {
1990 MultiWriteResult::Delete => Ok(()),
1991 _ => Err(Error::UnexpectedError(format!("expect MultiWriteResult::Delete, got {}", self.kind()))),
1992 }
1993 }
1994}
1995
1996#[derive(Error, Clone, Debug, PartialEq, Eq)]
1998pub enum MultiWriteError {
1999 #[error("{source}")]
2000 RequestFailed {
2001 #[from]
2002 source: Error,
2003 },
2004
2005 #[error("operation at index {index} failed: {source}")]
2006 OperationFailed { index: usize, source: Error },
2007}
2008
2009impl From<MultiWriteError> for Error {
2010 fn from(err: MultiWriteError) -> Self {
2011 match err {
2012 MultiWriteError::RequestFailed { source } => source,
2013 MultiWriteError::OperationFailed { source, .. } => source,
2014 }
2015 }
2016}
2017
2018#[derive(Error, Clone, Debug, PartialEq, Eq)]
2020pub enum CheckWriteError {
2021 #[error("request failed: {source}")]
2022 RequestFailed {
2023 #[from]
2024 source: Error,
2025 },
2026
2027 #[error("path check failed: {source}")]
2028 CheckFailed { source: Error },
2029
2030 #[error("operation at index {index} failed: {source}")]
2031 OperationFailed { index: usize, source: Error },
2032}
2033
2034impl From<MultiWriteError> for CheckWriteError {
2035 fn from(err: MultiWriteError) -> Self {
2036 match err {
2037 MultiWriteError::RequestFailed { source } => CheckWriteError::RequestFailed { source },
2038 MultiWriteError::OperationFailed { index: 0, source } => CheckWriteError::CheckFailed { source },
2039 MultiWriteError::OperationFailed { index, source } => {
2040 CheckWriteError::OperationFailed { index: index - 1, source }
2041 },
2042 }
2043 }
2044}
2045
2046impl From<CheckWriteError> for Error {
2047 fn from(err: CheckWriteError) -> Self {
2048 match err {
2049 CheckWriteError::RequestFailed { source } => source,
2050 CheckWriteError::CheckFailed { source: Error::NoNode | Error::BadVersion } => Error::RuntimeInconsistent,
2051 CheckWriteError::CheckFailed { source } => source,
2052 CheckWriteError::OperationFailed { source, .. } => source,
2053 }
2054 }
2055}
2056
2057pub struct CheckWriter<'a> {
2059 writer: MultiWriter<'a>,
2060}
2061
2062impl<'a> CheckWriter<'a> {
2063 pub fn add_check_version(&mut self, path: &str, version: i32) -> Result<()> {
2065 self.writer.add_check_version(path, version)
2066 }
2067
2068 pub fn add_create(&mut self, path: &str, data: &[u8], options: &CreateOptions<'_>) -> Result<()> {
2070 self.writer.add_create(path, data, options)
2071 }
2072
2073 pub fn add_set_data(&mut self, path: &str, data: &[u8], expected_version: Option<i32>) -> Result<()> {
2075 self.writer.add_set_data(path, data, expected_version)
2076 }
2077
2078 pub fn add_delete(&mut self, path: &str, expected_version: Option<i32>) -> Result<()> {
2080 self.writer.add_delete(path, expected_version)
2081 }
2082
2083 pub fn commit(
2085 mut self,
2086 ) -> impl Future<Output = std::result::Result<Vec<MultiWriteResult>, CheckWriteError>> + Send + 'a {
2087 let commit = self.writer.commit();
2088 async move {
2089 let mut results = commit.await?;
2090 if results.is_empty() {
2091 Err(CheckWriteError::RequestFailed {
2092 source: Error::UnexpectedError("expect path check, got none".to_string()),
2093 })
2094 } else {
2095 results.remove(0).into_check()?;
2096 Ok(results)
2097 }
2098 }
2099 }
2100}
2101
2102pub struct MultiWriter<'a> {
2104 client: &'a Client,
2105 buf: Vec<u8>,
2106}
2107
2108impl MultiBuffer for MultiWriter<'_> {
2109 fn buffer(&mut self) -> &mut Vec<u8> {
2110 &mut self.buf
2111 }
2112
2113 fn op_code() -> OpCode {
2114 OpCode::Multi
2115 }
2116}
2117
2118impl<'a> MultiWriter<'a> {
2119 fn new(client: &'a Client) -> MultiWriter<'a> {
2120 MultiWriter { client, buf: Default::default() }
2121 }
2122
2123 pub fn add_check_version(&mut self, path: &str, version: i32) -> Result<()> {
2128 let chroot_path = self.client.validate_path(path)?;
2129 let request = CheckVersionRequest { path: chroot_path, version };
2130 self.add_operation(OpCode::Check, &request);
2131 Ok(())
2132 }
2133
2134 pub fn add_create(&mut self, path: &str, data: &[u8], options: &CreateOptions<'_>) -> Result<()> {
2145 options.validate()?;
2146 let ttl = options.ttl.map(|ttl| ttl.as_millis() as i64).unwrap_or(0);
2147 let create_mode = options.mode;
2148 let sequential = create_mode.is_sequential();
2149 let chroot_path =
2150 if sequential { self.client.validate_sequential_path(path)? } else { self.client.validate_path(path)? };
2151 let op_code = if ttl != 0 {
2152 OpCode::CreateTtl
2153 } else if create_mode.is_container() {
2154 OpCode::CreateContainer
2155 } else {
2156 OpCode::Create
2157 };
2158 let flags = create_mode.as_flags(ttl != 0);
2159 let request = CreateRequest { path: chroot_path, data, acls: options.acls, flags, ttl };
2160 self.add_operation(op_code, &request);
2161 Ok(())
2162 }
2163
2164 pub fn add_set_data(&mut self, path: &str, data: &[u8], expected_version: Option<i32>) -> Result<()> {
2168 let chroot_path = self.client.validate_path(path)?;
2169 let request = SetDataRequest { path: chroot_path, data, version: expected_version.unwrap_or(-1) };
2170 self.add_operation(OpCode::SetData, &request);
2171 Ok(())
2172 }
2173
2174 pub fn add_delete(&mut self, path: &str, expected_version: Option<i32>) -> Result<()> {
2178 let chroot_path = self.client.validate_path(path)?;
2179 if chroot_path.is_root() {
2180 return Err(Error::BadArguments(&"can not delete root node"));
2181 }
2182 let request = DeleteRequest { path: chroot_path, version: expected_version.unwrap_or(-1) };
2183 self.add_operation(OpCode::Delete, &request);
2184 Ok(())
2185 }
2186
2187 pub fn commit(
2195 &mut self,
2196 ) -> impl Future<Output = std::result::Result<Vec<MultiWriteResult>, MultiWriteError>> + Send + 'a {
2197 let request = self.build_request();
2198 Client::resolve(self.commit_internally(request))
2199 }
2200
2201 #[allow(clippy::type_complexity)]
2202 fn commit_internally(
2203 &self,
2204 request: MarshalledRequest,
2205 ) -> Result<
2206 Either<impl Future<Output = Result<Vec<MultiWriteResult>, MultiWriteError>> + Send + 'a, Vec<MultiWriteResult>>,
2207 MultiWriteError,
2208 > {
2209 if request.is_empty() {
2210 return Ok(Right(Vec::default()));
2211 }
2212 let receiver = self.client.send_marshalled_request(request);
2213 let client = self.client;
2214 Ok(Left(async move {
2215 let (body, _) = receiver.await?;
2216 let response = record::unmarshal::<Vec<MultiWriteResponse>>(&mut body.as_slice())?;
2217 let failed = response.first().map(|r| matches!(r, MultiWriteResponse::Error(_))).unwrap_or(false);
2218 let mut results = if failed { Vec::new() } else { Vec::with_capacity(response.len()) };
2219 for (index, result) in response.into_iter().enumerate() {
2220 match result {
2221 MultiWriteResponse::Check => results.push(MultiWriteResult::Check),
2222 MultiWriteResponse::Delete => results.push(MultiWriteResult::Delete),
2223 MultiWriteResponse::Create { mut path, stat } => {
2224 path = util::strip_root_path(path, client.chroot.root())?;
2225 results.push(MultiWriteResult::Create { path: path.to_string(), stat });
2226 },
2227 MultiWriteResponse::SetData { stat } => results.push(MultiWriteResult::SetData { stat }),
2228 MultiWriteResponse::Error(Error::UnexpectedErrorCode(0)) => {},
2229 MultiWriteResponse::Error(err) => {
2230 return Err(MultiWriteError::OperationFailed { index, source: err })
2231 },
2232 }
2233 }
2234 Ok(results)
2235 }))
2236 }
2237
2238 pub fn abort(&mut self) {
2240 self.buf.clear();
2241 }
2242}
2243
#[cfg(test)]
mod tests {
    use assertor::*;

    use super::*;

    /// Pins the validation errors of `CreateOptions::validate` for ACLs and ttl.
    #[test]
    fn test_create_options_validate() {
        // Default-constructed ACLs are rejected.
        assert_that!(CreateMode::Persistent.with_acls(Acls::new(Default::default())).validate().unwrap_err())
            .is_equal_to(Error::InvalidAcl);

        let acls = Acls::anyone_all();

        // A ttl is rejected for ephemeral nodes.
        assert_that!(CreateMode::Ephemeral.with_acls(acls).with_ttl(Duration::from_secs(1)).validate().unwrap_err())
            .is_equal_to(Error::BadArguments(&"ttl can only be specified with persistent node"));

        // A zero ttl is rejected.
        assert_that!(CreateMode::Persistent.with_acls(acls).with_ttl(Duration::ZERO).validate().unwrap_err())
            .is_equal_to(Error::BadArguments(&"ttl is zero"));

        // 0x01FFFFFFFFFF milliseconds exceeds the maximum of 1099511627775.
        assert_that!(CreateMode::Persistent
            .with_acls(acls)
            .with_ttl(Duration::from_millis(0x01FFFFFFFFFF))
            .validate()
            .unwrap_err())
        .is_equal_to(Error::BadArguments(&"ttl cannot larger than 1099511627775"));

        // A ttl within range on a persistent node is accepted.
        assert_that!(CreateMode::Persistent.with_acls(acls).with_ttl(Duration::from_secs(5)).validate())
            .is_equal_to(Ok(()));
    }

    /// Ancestor options must describe a directory node: neither ephemeral nor sequential.
    #[test]
    fn test_lock_options_with_ancestor_options() {
        let options = LockOptions::new(Acls::anyone_all());
        assert_that!(options
            .clone()
            .with_ancestor_options(CreateMode::Ephemeral.with_acls(Acls::anyone_all()))
            .unwrap_err())
        .is_equal_to(Error::BadArguments(&"directory node must not be ephemeral"));
        assert_that!(options
            .with_ancestor_options(CreateMode::PersistentSequential.with_acls(Acls::anyone_all()))
            .unwrap_err())
        .is_equal_to(Error::BadArguments(&"directory node must not be sequential"));
    }

    /// Integration test against a ZooKeeper 3.9.0 container: reconnecting with a detached
    /// session fails or succeeds depending on the `last_zxid` the session reports.
    #[test_log::test(asyncs::test)]
    async fn session_last_zxid_seen() {
        use testcontainers::clients::Cli as DockerCli;
        use testcontainers::core::{Healthcheck, WaitFor};
        use testcontainers::images::generic::GenericImage;

        // Wait until the server reports a healthy status before connecting.
        let healthcheck = Healthcheck::default()
            .with_cmd(["./bin/zkServer.sh", "status"].iter())
            .with_interval(Duration::from_secs(2))
            .with_retries(60);
        let image =
            GenericImage::new("zookeeper", "3.9.0").with_healthcheck(healthcheck).with_wait_for(WaitFor::Healthcheck);
        let docker = DockerCli::default();
        let container = docker.run(image);
        let endpoint = format!("127.0.0.1:{}", container.get_host_port(2181));

        let client1 = Client::connector().with_detached().connect(&endpoint).await.unwrap();
        client1.create("/n1", b"", &CreateMode::Persistent.with_acls(Acls::anyone_all())).await.unwrap();

        let mut session = client1.into_session();

        session.last_zxid = i64::MAX;
        // With the largest possible last_zxid the eager connect fails with `NoHosts`.
        assert_that!(Client::connector()
            .with_fail_eagerly()
            .with_session(session.clone())
            .connect(&endpoint)
            .await
            .unwrap_err())
        .is_equal_to(Error::NoHosts);

        session.last_zxid = 0;
        // With last_zxid reset to zero the session can be resumed and used.
        let client2 =
            Client::connector().with_fail_eagerly().with_session(session.clone()).connect(&endpoint).await.unwrap();
        client2.create("/n2", b"", &CreateMode::Persistent.with_acls(Acls::anyone_all())).await.unwrap();
    }
}