1use crate::{
5 DistributedRuntime,
6 transports::etcd::{Client, WatchEvent},
7};
8use serde::{Serialize, de::DeserializeOwned};
9
10use std::collections::{HashMap, HashSet};
11use std::marker::PhantomData;
12use std::time::{Duration, Instant};
13
/// Builds the etcd key for one component of a barrier: `barrier/<id>/<suffix>`.
fn barrier_key(id: &str, suffix: &str) -> String {
    ["barrier", id, suffix].join("/")
}
17
// Key suffixes used under the `barrier/<id>/` namespace (see `barrier_key`).
/// Leader's published payload.
const BARRIER_DATA: &str = "data";
/// Prefix under which each worker registers as `worker/<worker_id>`.
const BARRIER_WORKER: &str = "worker";
/// Completion signal: holds the set of worker keys admitted to the barrier.
const BARRIER_COMPLETE: &str = "complete";
/// Abort signal, written when the leader's wait for workers fails.
const BARRIER_ABORT: &str = "abort";
22
23async fn wait_for_key_count<T: DeserializeOwned>(
25 client: &Client,
26 key: String,
27 expected_count: usize,
28 timeout: Option<Duration>,
29) -> Result<HashMap<String, T>, LeaderWorkerBarrierError> {
30 let (_key, _watcher, mut rx) = client
31 .kv_get_and_watch_prefix(&key)
32 .await
33 .map_err(LeaderWorkerBarrierError::EtcdError)?
34 .dissolve();
35
36 let mut data = HashMap::new();
37 let start = Instant::now();
38 let timeout = timeout.unwrap_or(Duration::MAX);
39
40 loop {
41 let elapsed = start.elapsed();
42 if elapsed > timeout {
43 return Err(LeaderWorkerBarrierError::Timeout);
44 }
45
46 let remaining_time = timeout.saturating_sub(elapsed);
47
48 tokio::select! {
49 Some(watch_event) = rx.recv() => {
50 handle_watch_event(watch_event, &mut data)?;
51 }
52 _ = tokio::time::sleep(remaining_time) => {
53 }
55 }
56
57 if data.len() == expected_count {
58 return Ok(data);
59 }
60 }
61}
62
63fn handle_watch_event<T: DeserializeOwned>(
65 event: WatchEvent,
66 data: &mut HashMap<String, T>,
67) -> Result<(), LeaderWorkerBarrierError> {
68 match event {
69 WatchEvent::Put(kv) => {
70 let key = kv.key_str().unwrap().to_string();
71 let value =
72 serde_json::from_slice(kv.value()).map_err(LeaderWorkerBarrierError::SerdeError)?;
73 data.insert(key, value);
74 }
75 WatchEvent::Delete(kv) => {
76 let key = kv.key_str().unwrap();
77 data.remove(key);
78 }
79 }
80 Ok(())
81}
82
/// Atomically creates `key` holding the JSON-serialized `data`, optionally
/// attached to an etcd lease so the key vanishes when the lease expires.
///
/// NOTE(review): every `kv_create` failure is collapsed into `IdNotUnique`.
/// This assumes the only failure mode is "key already exists" and will mask
/// genuine etcd/transport errors — confirm against `Client::kv_create`.
async fn create_barrier_key<T: Serialize>(
    client: &Client,
    key: &str,
    data: T,
    lease_id: Option<i64>,
) -> Result<(), LeaderWorkerBarrierError> {
    let serialized_data =
        serde_json::to_vec(&data).map_err(LeaderWorkerBarrierError::SerdeError)?;

    client
        .kv_create(key, serialized_data, lease_id)
        .await
        .map_err(|_| LeaderWorkerBarrierError::IdNotUnique)?;

    Ok(())
}
103
104async fn wait_for_signal<T: DeserializeOwned>(
106 client: &Client,
107 key: String,
108) -> Result<T, LeaderWorkerBarrierError> {
109 let data = wait_for_key_count::<T>(client, key, 1, None).await?;
110 Ok(data.into_values().next().unwrap())
111}
112
/// Errors produced by the leader/worker barrier primitives.
#[derive(Debug)]
pub enum LeaderWorkerBarrierError {
    /// The `DistributedRuntime` has no etcd client configured.
    EtcdClientNotFound,
    /// A barrier key with this id already exists (duplicate leader or worker
    /// id). NOTE(review): `create_barrier_key` maps every create failure here.
    IdNotUnique,
    /// Underlying etcd/transport failure.
    EtcdError(anyhow::Error),
    /// JSON (de)serialization of a barrier payload failed.
    SerdeError(serde_json::Error),
    /// The leader's wait for workers exceeded the configured timeout.
    Timeout,
    /// The leader published an abort signal before this side could finish.
    Aborted,
    /// The barrier completed without this worker in the admitted set.
    AlreadyCompleted,
}
123
/// Leader side of a leader/worker rendezvous built on etcd.
///
/// The leader publishes `LeaderData`, waits for `num_workers` workers to
/// register their `WorkerData`, then signals completion (or abort on
/// failure). See [`WorkerBarrier`] for the counterpart.
pub struct LeaderBarrier<LeaderData, WorkerData> {
    // Namespace for all etcd keys of this barrier instance.
    barrier_id: String,
    // Number of workers the leader waits for before completing.
    num_workers: usize,
    // Optional cap on the wait for workers; `None` waits indefinitely.
    timeout: Option<Duration>,
    // Carries the payload type parameters without storing any values.
    marker: PhantomData<(LeaderData, WorkerData)>,
}
131
132impl<LeaderData: Serialize + DeserializeOwned, WorkerData: Serialize + DeserializeOwned>
133 LeaderBarrier<LeaderData, WorkerData>
134{
135 pub fn new(barrier_id: String, num_workers: usize, timeout: Option<Duration>) -> Self {
136 Self {
137 barrier_id,
138 num_workers,
139 timeout,
140 marker: PhantomData,
141 }
142 }
143
144 pub async fn sync(
149 self,
150 rt: &DistributedRuntime,
151 data: &LeaderData,
152 ) -> anyhow::Result<HashMap<String, WorkerData>, LeaderWorkerBarrierError> {
153 let etcd_client = rt
154 .etcd_client()
155 .ok_or(LeaderWorkerBarrierError::EtcdClientNotFound)?;
156
157 let lease_id = etcd_client.lease_id();
158
159 self.publish_barrier_data(&etcd_client, data, lease_id)
161 .await?;
162
163 let worker_result = self.wait_for_workers(&etcd_client).await;
165
166 self.signal_completion(&etcd_client, &worker_result, lease_id)
168 .await?;
169
170 worker_result.map(|r| {
171 r.into_iter()
172 .map(|(k, v)| (k.split("/").last().unwrap().to_string(), v))
173 .collect()
174 })
175 }
176
177 async fn publish_barrier_data(
178 &self,
179 client: &Client,
180 data: &LeaderData,
181 lease_id: i64,
182 ) -> Result<(), LeaderWorkerBarrierError> {
183 let key = barrier_key(&self.barrier_id, BARRIER_DATA);
184 create_barrier_key(client, &key, data, Some(lease_id)).await
185 }
186
187 async fn wait_for_workers(
188 &self,
189 client: &Client,
190 ) -> Result<HashMap<String, WorkerData>, LeaderWorkerBarrierError> {
191 let key = barrier_key(&self.barrier_id, BARRIER_WORKER);
192 let workers = wait_for_key_count(client, key, self.num_workers, self.timeout).await?;
193 Ok(workers)
194 }
195
196 async fn signal_completion(
197 &self,
198 client: &Client,
199 worker_result: &Result<HashMap<String, WorkerData>, LeaderWorkerBarrierError>,
200 lease_id: i64,
201 ) -> Result<(), LeaderWorkerBarrierError> {
202 if let Ok(worker_result) = worker_result {
203 let key = barrier_key(&self.barrier_id, BARRIER_COMPLETE);
204
205 let workers = worker_result.keys().collect::<HashSet<_>>();
206
207 create_barrier_key(client, &key, workers, Some(lease_id)).await?;
208 } else {
209 let key = barrier_key(&self.barrier_id, BARRIER_ABORT);
210 create_barrier_key(client, &key, (), Some(lease_id)).await?;
211 }
212
213 Ok(())
214 }
215}
216
/// Worker side of the leader/worker rendezvous; counterpart of
/// [`LeaderBarrier`].
pub struct WorkerBarrier<LeaderData, WorkerData> {
    // Must match the leader's barrier id.
    barrier_id: String,
    // Unique id of this worker within the barrier.
    worker_id: String,
    // Carries the payload type parameters without storing any values.
    marker: PhantomData<(LeaderData, WorkerData)>,
}
223
impl<LeaderData: Serialize + DeserializeOwned, WorkerData: Serialize + DeserializeOwned>
    WorkerBarrier<LeaderData, WorkerData>
{
    /// Creates the worker side of the barrier. `worker_id` must be unique
    /// within `barrier_id`; a duplicate id fails registration with
    /// `IdNotUnique`.
    pub fn new(barrier_id: String, worker_id: String) -> Self {
        Self {
            barrier_id,
            worker_id,
            marker: PhantomData,
        }
    }

    /// Runs the worker side of the barrier:
    /// 1. waits for the leader's payload (racing an abort signal),
    /// 2. registers this worker's `data` under its unique worker key,
    /// 3. waits for the leader's completion (or abort) signal,
    /// 4. returns the leader's payload.
    pub async fn sync(
        self,
        rt: &DistributedRuntime,
        data: &WorkerData,
    ) -> anyhow::Result<LeaderData, LeaderWorkerBarrierError> {
        let etcd_client = rt
            .etcd_client()
            .ok_or(LeaderWorkerBarrierError::EtcdClientNotFound)?;

        let lease_id = etcd_client.lease_id();

        let barrier_data = self.get_barrier_data(&etcd_client).await?;

        let worker_key = self.register_worker(&etcd_client, data, lease_id).await?;

        self.wait_for_completion(&etcd_client, worker_key).await?;

        Ok(barrier_data)
    }

    /// Waits until the leader has published its payload, racing against the
    /// barrier's abort key.
    async fn get_barrier_data(
        &self,
        client: &Client,
    ) -> Result<LeaderData, LeaderWorkerBarrierError> {
        let data_key = barrier_key(&self.barrier_id, BARRIER_DATA);
        let abort_key = barrier_key(&self.barrier_id, BARRIER_ABORT);

        tokio::select! {
            result = wait_for_key_count::<LeaderData>(client, data_key, 1, None) => {
                result?.into_values().next()
                    .ok_or(LeaderWorkerBarrierError::EtcdError(anyhow::anyhow!("No data found")))
            }
            // NOTE(review): the `_` pattern also matches an Err from watching
            // the abort key, which would surface as Aborted — confirm that is
            // the intended mapping for watch failures.
            _ = wait_for_signal::<()>(client, abort_key) => {
                Err(LeaderWorkerBarrierError::Aborted)
            }
        }
    }

    /// Registers this worker under `barrier/<id>/worker/<worker_id>`, tied to
    /// this runtime's lease; returns the full etcd key used.
    async fn register_worker(
        &self,
        client: &Client,
        data: &WorkerData,
        lease_id: i64,
    ) -> Result<String, LeaderWorkerBarrierError> {
        let key = barrier_key(
            &self.barrier_id,
            &format!("{}/{}", BARRIER_WORKER, self.worker_id),
        );
        create_barrier_key(client, &key, data, Some(lease_id)).await?;
        Ok(key)
    }

    /// Waits for the leader's completion signal and checks that this worker's
    /// full key is in the admitted set; otherwise the barrier filled without
    /// us (`AlreadyCompleted`). An abort signal yields `Aborted`.
    async fn wait_for_completion(
        &self,
        client: &Client,
        worker_key: String,
    ) -> Result<(), LeaderWorkerBarrierError> {
        let complete_key = barrier_key(&self.barrier_id, BARRIER_COMPLETE);
        let abort_key = barrier_key(&self.barrier_id, BARRIER_ABORT);

        tokio::select! {
            // NOTE(review): if this signal wait errors, the `Ok(...)` pattern
            // disables the branch and we block on the abort arm alone —
            // confirm an error here cannot strand the worker indefinitely.
            Ok(workers) = wait_for_signal::<HashSet<String>>(client, complete_key) => {
                if workers.contains(&worker_key) {
                    Ok(())
                } else {
                    Err(LeaderWorkerBarrierError::AlreadyCompleted)
                }
            },
            _ = wait_for_signal::<()>(client, abort_key) => Err(LeaderWorkerBarrierError::Aborted),
        }
    }
}
317
// Integration tests: these require a live etcd instance and are gated behind
// the `testing-etcd` feature.
#[cfg(feature = "testing-etcd")]
#[cfg(test)]
mod tests {
    use super::*;

    use crate::Runtime;
    use tokio::task::JoinHandle;

    use std::sync::atomic::{AtomicU64, Ordering};

    // Process-unique barrier id so concurrent tests don't collide on shared
    // etcd keys.
    fn unique_id() -> String {
        static COUNTER: AtomicU64 = AtomicU64::new(0);
        let id = COUNTER.fetch_add(1, Ordering::Relaxed);

        format!("test_{}", id)
    }

    // Without an etcd client, both sides must fail fast with
    // EtcdClientNotFound.
    #[tokio::test]
    async fn test_no_etcd() {
        let rt = Runtime::from_current().unwrap();
        let drt = DistributedRuntime::from_settings_without_discovery(rt.clone())
            .await
            .unwrap();

        assert!(drt.etcd_client().is_none());

        let barrier = LeaderBarrier::<String, String>::new("test".to_string(), 2, None);
        let worker = WorkerBarrier::<String, String>::new("test".to_string(), "worker".to_string());

        assert!(matches!(
            barrier.sync(&drt, &"test".to_string()).await,
            Err(LeaderWorkerBarrierError::EtcdClientNotFound)
        ));
        assert!(matches!(
            worker.sync(&drt, &"test".to_string()).await,
            Err(LeaderWorkerBarrierError::EtcdClientNotFound)
        ));
    }

    // Happy path: one leader + one worker, payloads exchanged in both
    // directions.
    #[tokio::test]
    async fn test_simple() {
        let rt = Runtime::from_current().unwrap();
        let drt = DistributedRuntime::from_settings(rt.clone()).await.unwrap();

        let id = unique_id();

        let leader = LeaderBarrier::<String, String>::new(id.clone(), 1, None);
        let worker = WorkerBarrier::<String, String>::new(id.clone(), "worker".to_string());

        let drt_clone = drt.clone();
        let leader_join: JoinHandle<Result<(), LeaderWorkerBarrierError>> =
            tokio::spawn(async move {
                let worker_data = leader.sync(&drt_clone, &"test_data".to_string()).await?;
                assert_eq!(worker_data.len(), 1);
                assert_eq!(
                    worker_data.get("worker").unwrap(),
                    &"test_worker".to_string()
                );
                Ok(())
            });

        let worker_join: JoinHandle<Result<(), LeaderWorkerBarrierError>> =
            tokio::spawn(async move {
                let res = worker.sync(&drt, &"test_worker".to_string()).await?;
                assert_eq!(res, "test_data".to_string());

                Ok(())
            });

        let (leader_res, worker_res) = tokio::join!(leader_join, worker_join);

        assert!(matches!(leader_res, Ok(Ok(_))));
        assert!(matches!(worker_res, Ok(Ok(_))));
    }

    // A second leader reusing the same barrier id must fail with IdNotUnique
    // (the data key already exists).
    #[tokio::test]
    async fn test_duplicate_leader() {
        let rt = Runtime::from_current().unwrap();
        let drt = DistributedRuntime::from_settings(rt.clone()).await.unwrap();

        let id = unique_id();

        let leader1 = LeaderBarrier::<String, String>::new(id.clone(), 1, None);
        let leader2 = LeaderBarrier::<String, String>::new(id.clone(), 1, None);

        let worker = WorkerBarrier::<String, String>::new(id.clone(), "worker".to_string());

        let drt_clone = drt.clone();
        let leader1_join: JoinHandle<Result<(), LeaderWorkerBarrierError>> =
            tokio::spawn(async move {
                let worker_data = leader1.sync(&drt_clone, &"test_data".to_string()).await?;
                assert_eq!(worker_data.len(), 1);
                assert_eq!(
                    worker_data.get("worker").unwrap(),
                    &"test_worker".to_string()
                );

                // Second leader attempts the same id only after the first
                // completed, so the keys are still present on the lease.
                let leader2_res = leader2.sync(&drt_clone, &"test_data2".to_string()).await;

                assert!(matches!(
                    leader2_res,
                    Err(LeaderWorkerBarrierError::IdNotUnique)
                ));

                Ok(())
            });

        let worker_join: JoinHandle<Result<(), LeaderWorkerBarrierError>> =
            tokio::spawn(async move {
                let res = worker.sync(&drt, &"test_worker".to_string()).await?;
                assert_eq!(res, "test_data".to_string());

                Ok(())
            });

        let (leader1_res, worker_res) = tokio::join!(leader1_join, worker_join);

        assert!(matches!(leader1_res, Ok(Ok(_))));
        assert!(matches!(worker_res, Ok(Ok(_))));
    }

    // A second worker reusing the same worker id must fail with IdNotUnique
    // (its registration key already exists).
    #[tokio::test]
    async fn test_duplicate_worker() {
        let rt = Runtime::from_current().unwrap();
        let drt = DistributedRuntime::from_settings(rt.clone()).await.unwrap();

        let id = unique_id();

        let leader = LeaderBarrier::<String, String>::new(id.clone(), 1, None);
        let worker1 = WorkerBarrier::<String, String>::new(id.clone(), "worker".to_string());
        let worker2 = WorkerBarrier::<String, String>::new(id.clone(), "worker".to_string());

        let drt_clone = drt.clone();
        let leader_join: JoinHandle<Result<(), LeaderWorkerBarrierError>> =
            tokio::spawn(async move {
                let worker_data = leader.sync(&drt_clone, &"test_data".to_string()).await?;
                assert_eq!(worker_data.len(), 1);
                assert_eq!(
                    worker_data.get("worker").unwrap(),
                    &"test_worker_1".to_string()
                );

                Ok(())
            });

        let worker_join: JoinHandle<Result<(), LeaderWorkerBarrierError>> =
            tokio::spawn(async move {
                let leader_data = worker1.sync(&drt, &"test_worker_1".to_string()).await?;
                assert_eq!(leader_data, "test_data".to_string());

                let worker2_res = worker2.sync(&drt, &"test_worker_2".to_string()).await;

                assert!(matches!(
                    worker2_res,
                    Err(LeaderWorkerBarrierError::IdNotUnique)
                ));

                Ok(())
            });

        let (leader_res, worker_res) = tokio::join!(leader_join, worker_join);

        assert!(matches!(leader_res, Ok(Ok(_))));
        assert!(matches!(worker_res, Ok(Ok(_))));
    }

    // Leader expects 2 workers but only 1 arrives within 100ms: leader times
    // out and aborts; both workers (including one arriving after the abort)
    // observe Aborted.
    #[tokio::test]
    async fn test_timeout() {
        let rt = Runtime::from_current().unwrap();
        let drt = DistributedRuntime::from_settings(rt.clone()).await.unwrap();

        let id = unique_id();

        let leader = LeaderBarrier::<(), ()>::new(id.clone(), 2, Some(Duration::from_millis(100)));
        let worker1 = WorkerBarrier::<(), ()>::new(id.clone(), "worker1".to_string());
        let worker2 = WorkerBarrier::<(), ()>::new(id.clone(), "worker2".to_string());

        let drt_clone = drt.clone();
        let leader_join: JoinHandle<Result<(), LeaderWorkerBarrierError>> =
            tokio::spawn(async move {
                let res = leader.sync(&drt_clone, &()).await;
                assert!(matches!(res, Err(LeaderWorkerBarrierError::Timeout)));

                Ok(())
            });

        let drt_clone = drt.clone();
        let worker1_join: JoinHandle<Result<(), LeaderWorkerBarrierError>> =
            tokio::spawn(async move {
                let res = worker1.sync(&drt_clone, &()).await;
                assert!(matches!(res, Err(LeaderWorkerBarrierError::Aborted)));

                Ok(())
            });

        let worker2_join: JoinHandle<Result<(), LeaderWorkerBarrierError>> =
            tokio::spawn(async move {
                // Arrive after the leader's timeout; the abort key is already
                // present when this worker starts waiting.
                tokio::time::sleep(Duration::from_millis(200)).await;
                let res = worker2.sync(&drt, &()).await;
                assert!(matches!(res, Err(LeaderWorkerBarrierError::Aborted)));

                Ok(())
            });

        let (leader_res, worker1_res, worker2_res) =
            tokio::join!(leader_join, worker1_join, worker2_join);

        assert!(matches!(leader_res, Ok(Ok(_))));
        assert!(matches!(worker1_res, Ok(Ok(_))));
        assert!(matches!(worker2_res, Ok(Ok(_))));
    }

    // Leader publishes `()` while the worker deserializes the payload as
    // String: the worker fails with SerdeError, so the leader never sees a
    // registration and times out.
    #[tokio::test]
    async fn test_serde_error() {
        let rt = Runtime::from_current().unwrap();
        let drt = DistributedRuntime::from_settings(rt.clone()).await.unwrap();

        let id = unique_id();

        let leader =
            LeaderBarrier::<(), String>::new(id.clone(), 1, Some(Duration::from_millis(100)));
        let worker1 = WorkerBarrier::<String, String>::new(id.clone(), "worker1".to_string());

        let drt_clone = drt.clone();
        let leader_join: JoinHandle<Result<(), LeaderWorkerBarrierError>> =
            tokio::spawn(async move {
                assert!(matches!(
                    leader.sync(&drt_clone, &()).await,
                    Err(LeaderWorkerBarrierError::Timeout)
                ));
                Ok(())
            });

        let worker_join: JoinHandle<Result<(), LeaderWorkerBarrierError>> =
            tokio::spawn(async move {
                assert!(matches!(
                    worker1.sync(&drt, &"test_worker".to_string()).await,
                    Err(LeaderWorkerBarrierError::SerdeError(_))
                ));

                Ok(())
            });

        let (leader_res, worker_res) = tokio::join!(leader_join, worker_join);

        assert!(matches!(leader_res, Ok(Ok(_))));
        assert!(matches!(worker_res, Ok(Ok(_))));
    }

    // Two workers race for a 1-slot barrier: exactly one wins; the loser is
    // not in the completion set and gets AlreadyCompleted.
    #[tokio::test]
    async fn test_too_many_workers() {
        let rt = Runtime::from_current().unwrap();
        let drt = DistributedRuntime::from_settings(rt.clone()).await.unwrap();

        let id = unique_id();

        let leader = LeaderBarrier::<(), ()>::new(id.clone(), 1, None);
        let worker1 = WorkerBarrier::<(), ()>::new(id.clone(), "worker1".to_string());
        let worker2 = WorkerBarrier::<(), ()>::new(id.clone(), "worker2".to_string());

        let drt_clone = drt.clone();
        let leader_join: JoinHandle<Result<(), LeaderWorkerBarrierError>> =
            tokio::spawn(async move {
                leader.sync(&drt_clone, &()).await?;
                Ok(())
            });

        let worker_join: JoinHandle<Result<(), LeaderWorkerBarrierError>> =
            tokio::spawn(async move {
                let drt_clone = drt.clone();
                let worker1_join = tokio::spawn(async move { worker1.sync(&drt_clone, &()).await });

                let worker2_join = tokio::spawn(async move { worker2.sync(&drt, &()).await });

                let (worker1_res, worker2_res) = tokio::join!(worker1_join, worker2_join);

                // The winner is nondeterministic; count successes and require
                // the loser to have failed with AlreadyCompleted.
                let mut num_successes = 0;
                for worker_res in [worker1_res, worker2_res] {
                    if let Ok(Ok(_)) = worker_res {
                        num_successes += 1;
                    } else if let Ok(Err(LeaderWorkerBarrierError::AlreadyCompleted)) = worker_res {
                    } else {
                        panic!();
                    }
                }

                assert_eq!(num_successes, 1);
                Ok(())
            });

        let (leader_res, worker_res) = tokio::join!(leader_join, worker_join);

        assert!(matches!(leader_res, Ok(Ok(_))));
        assert!(matches!(worker_res, Ok(Ok(_))));
    }
}