dynamo_runtime/utils/leader_worker_barrier.rs

// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

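//! A leader/worker barrier built on etcd.
//!
//! The leader publishes its payload under the barrier's `data` key and waits
//! for the expected number of workers to register under `worker/`. It then
//! writes a `complete` (or, on failure, `abort`) key that every worker is
//! watching, so all participants finish the handshake together.
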
use crate::{
    DistributedRuntime,
    transports::etcd::{Client, WatchEvent},
};
use serde::{Serialize, de::DeserializeOwned};

use std::collections::{HashMap, HashSet};
use std::marker::PhantomData;
use std::time::{Duration, Instant};

fn barrier_key(id: &str, suffix: &str) -> String {
    format!("barrier/{}/{}", id, suffix)
}

const BARRIER_DATA: &str = "data";
const BARRIER_WORKER: &str = "worker";
const BARRIER_COMPLETE: &str = "complete";
const BARRIER_ABORT: &str = "abort";

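// Illustrative key layout for a barrier with id "my_barrier" (the worker id
// "worker-0" is a hypothetical, caller-chosen name):
//
//   barrier/my_barrier/data             - leader's payload
//   barrier/my_barrier/worker/worker-0  - one entry per registered worker
//   barrier/my_barrier/complete         - completion signal (set of accepted worker keys)
//   barrier/my_barrier/abort            - abort signal
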
/// Watches for a specific number of items to appear under a key prefix
async fn wait_for_key_count<T: DeserializeOwned>(
    client: &Client,
    key: String,
    expected_count: usize,
    timeout: Option<Duration>,
) -> Result<HashMap<String, T>, LeaderWorkerBarrierError> {
    let (_key, mut rx) = client
        .kv_get_and_watch_prefix(&key)
        .await
        .map_err(LeaderWorkerBarrierError::EtcdError)?
        .dissolve();

    let mut data = HashMap::new();
    let start = Instant::now();
    let timeout = timeout.unwrap_or(Duration::MAX);

    loop {
        let elapsed = start.elapsed();
        if elapsed > timeout {
            return Err(LeaderWorkerBarrierError::Timeout);
        }

        let remaining_time = timeout.saturating_sub(elapsed);

        tokio::select! {
            Some(watch_event) = rx.recv() => {
                handle_watch_event(watch_event, &mut data)?;
            }
            _ = tokio::time::sleep(remaining_time) => {
                // The deadline has passed; loop back so the timeout check above fires.
            }
        }

        if data.len() == expected_count {
            return Ok(data);
        }
    }
}

/// Handles a single watch event by updating the data map
fn handle_watch_event<T: DeserializeOwned>(
    event: WatchEvent,
    data: &mut HashMap<String, T>,
) -> Result<(), LeaderWorkerBarrierError> {
    match event {
        WatchEvent::Put(kv) => {
            // Keys are produced by `barrier_key` from Rust strings, so they are
            // valid UTF-8 and `key_str` cannot fail here.
            let key = kv.key_str().unwrap().to_string();
            let value =
                serde_json::from_slice(kv.value()).map_err(LeaderWorkerBarrierError::SerdeError)?;
            data.insert(key, value);
        }
        WatchEvent::Delete(kv) => {
            let key = kv.key_str().unwrap();
            data.remove(key);
        }
    }
    Ok(())
}

/// Creates a key-value pair in etcd, returning a specific error if the key already exists
async fn create_barrier_key<T: Serialize>(
    client: &Client,
    key: &str,
    data: T,
    lease_id: Option<u64>,
) -> Result<(), LeaderWorkerBarrierError> {
    let serialized_data =
        serde_json::to_vec(&data).map_err(LeaderWorkerBarrierError::SerdeError)?;

    // TODO: This can fail for many reasons, the most common of which is that the key already exists.
    // Currently, the etcd client returns a very generic error, so we can't distinguish between them.
    // For now, just assume it's because the key already exists.
    client
        .kv_create(key, serialized_data, lease_id)
        .await
        .map_err(|_| LeaderWorkerBarrierError::IdNotUnique)?;

    Ok(())
}

/// Waits for a single key to appear (used for completion/abort signals)
async fn wait_for_signal<T: DeserializeOwned>(
    client: &Client,
    key: String,
) -> Result<T, LeaderWorkerBarrierError> {
    let data = wait_for_key_count::<T>(client, key, 1, None).await?;
    Ok(data.into_values().next().unwrap())
}

#[derive(Debug)]
pub enum LeaderWorkerBarrierError {
    /// The distributed runtime has no etcd client configured.
    EtcdClientNotFound,
    /// A key for this barrier/worker id already exists in etcd.
    IdNotUnique,
    /// The etcd client returned an error.
    EtcdError(anyhow::Error),
    /// Barrier data could not be serialized or deserialized.
    SerdeError(serde_json::Error),
    /// The expected number of workers did not arrive in time.
    Timeout,
    /// The leader aborted the barrier.
    Aborted,
    /// The barrier completed without this worker.
    AlreadyCompleted,
}

/// A barrier for a leader to wait for a specific number of workers to join.
pub struct LeaderBarrier<LeaderData, WorkerData> {
    barrier_id: String,
    num_workers: usize,
    timeout: Option<Duration>,
    marker: PhantomData<(LeaderData, WorkerData)>,
}

impl<LeaderData: Serialize + DeserializeOwned, WorkerData: Serialize + DeserializeOwned>
    LeaderBarrier<LeaderData, WorkerData>
{
    pub fn new(barrier_id: String, num_workers: usize, timeout: Option<Duration>) -> Self {
        Self {
            barrier_id,
            num_workers,
            timeout,
            marker: PhantomData,
        }
    }

    /// Synchronize the leader with the workers.
    ///
    /// The leader will publish the barrier data, and the workers will wait for the barrier data to appear.
    /// The leader will then signal completion or abort, and the workers will wait for the signal to appear.
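    ///
    /// A minimal usage sketch; the types and bindings (`LeaderConfig`,
    /// `WorkerInfo`, `drt`, `leader_config`) are illustrative, not part of
    /// this module:
    ///
    /// ```ignore
    /// let leader = LeaderBarrier::<LeaderConfig, WorkerInfo>::new(
    ///     "my_barrier".to_string(),
    ///     4,                             // expected number of workers
    ///     Some(Duration::from_secs(30)), // fail with Timeout if they don't all arrive
    /// );
    /// let workers: HashMap<String, WorkerInfo> = leader.sync(&drt, &leader_config).await?;
    /// ```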
    pub async fn sync(
        self,
        rt: &DistributedRuntime,
        data: &LeaderData,
    ) -> Result<HashMap<String, WorkerData>, LeaderWorkerBarrierError> {
        let etcd_client = rt
            .etcd_client()
            .ok_or(LeaderWorkerBarrierError::EtcdClientNotFound)?;

        let lease_id = etcd_client.lease_id();

        // Publish barrier data
        self.publish_barrier_data(&etcd_client, data, lease_id)
            .await?;

        // Wait for workers to join
        let worker_result = self.wait_for_workers(&etcd_client).await;

        // Signal completion or abort
        self.signal_completion(&etcd_client, &worker_result, lease_id)
            .await?;

        // Strip the key prefix, returning results keyed by bare worker id.
        worker_result.map(|r| {
            r.into_iter()
                .map(|(k, v)| (k.split('/').last().unwrap().to_string(), v))
                .collect()
        })
    }

    async fn publish_barrier_data(
        &self,
        client: &Client,
        data: &LeaderData,
        lease_id: u64,
    ) -> Result<(), LeaderWorkerBarrierError> {
        let key = barrier_key(&self.barrier_id, BARRIER_DATA);
        create_barrier_key(client, &key, data, Some(lease_id)).await
    }

    async fn wait_for_workers(
        &self,
        client: &Client,
    ) -> Result<HashMap<String, WorkerData>, LeaderWorkerBarrierError> {
        let key = barrier_key(&self.barrier_id, BARRIER_WORKER);
        let workers = wait_for_key_count(client, key, self.num_workers, self.timeout).await?;
        Ok(workers)
    }

    async fn signal_completion(
        &self,
        client: &Client,
        worker_result: &Result<HashMap<String, WorkerData>, LeaderWorkerBarrierError>,
        lease_id: u64,
    ) -> Result<(), LeaderWorkerBarrierError> {
        if let Ok(worker_result) = worker_result {
            let key = barrier_key(&self.barrier_id, BARRIER_COMPLETE);

            // Publish the exact set of worker keys that made the barrier, so a
            // late registrant can detect that it was not part of this round.
            let workers = worker_result.keys().collect::<HashSet<_>>();

            create_barrier_key(client, &key, workers, Some(lease_id)).await?;
        } else {
            let key = barrier_key(&self.barrier_id, BARRIER_ABORT);
            create_barrier_key(client, &key, (), Some(lease_id)).await?;
        }

        Ok(())
    }
}

/// A barrier to synchronize a worker with a leader.
pub struct WorkerBarrier<LeaderData, WorkerData> {
    barrier_id: String,
    worker_id: String,
    marker: PhantomData<(LeaderData, WorkerData)>,
}

impl<LeaderData: Serialize + DeserializeOwned, WorkerData: Serialize + DeserializeOwned>
    WorkerBarrier<LeaderData, WorkerData>
{
    pub fn new(barrier_id: String, worker_id: String) -> Self {
        Self {
            barrier_id,
            worker_id,
            marker: PhantomData,
        }
    }

    /// Synchronize the worker with the leader.
    ///
    /// The worker will wait for the barrier data to appear, and then register as a worker.
    /// The worker will then wait for the completion or abort signal to appear.
    ///
    /// If the leader signals completion, the worker will return the barrier data.
    /// If the leader signals abort, the worker will return an error.
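    ///
    /// A minimal usage sketch; the types and bindings (`LeaderConfig`,
    /// `WorkerInfo`, `drt`, `worker_info`) are illustrative, not part of
    /// this module:
    ///
    /// ```ignore
    /// let worker = WorkerBarrier::<LeaderConfig, WorkerInfo>::new(
    ///     "my_barrier".to_string(),
    ///     "worker-0".to_string(), // must be unique per worker
    /// );
    /// let config: LeaderConfig = worker.sync(&drt, &worker_info).await?;
    /// ```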
    pub async fn sync(
        self,
        rt: &DistributedRuntime,
        data: &WorkerData,
    ) -> Result<LeaderData, LeaderWorkerBarrierError> {
        let etcd_client = rt
            .etcd_client()
            .ok_or(LeaderWorkerBarrierError::EtcdClientNotFound)?;

        let lease_id = etcd_client.lease_id();

        // Get barrier data while watching for abort signal
        let barrier_data = self.get_barrier_data(&etcd_client).await?;

        // Register as a worker
        let worker_key = self.register_worker(&etcd_client, data, lease_id).await?;

        // Wait for completion or abort signal
        self.wait_for_completion(&etcd_client, worker_key).await?;

        Ok(barrier_data)
    }

    async fn get_barrier_data(
        &self,
        client: &Client,
    ) -> Result<LeaderData, LeaderWorkerBarrierError> {
        let data_key = barrier_key(&self.barrier_id, BARRIER_DATA);
        let abort_key = barrier_key(&self.barrier_id, BARRIER_ABORT);

        tokio::select! {
            result = wait_for_key_count::<LeaderData>(client, data_key, 1, None) => {
                result?.into_values().next()
                    .ok_or(LeaderWorkerBarrierError::EtcdError(anyhow::anyhow!("No data found")))
            }
            _ = wait_for_signal::<()>(client, abort_key) => {
                Err(LeaderWorkerBarrierError::Aborted)
            }
        }
    }

    async fn register_worker(
        &self,
        client: &Client,
        data: &WorkerData,
        lease_id: u64,
    ) -> Result<String, LeaderWorkerBarrierError> {
        let key = barrier_key(
            &self.barrier_id,
            &format!("{}/{}", BARRIER_WORKER, self.worker_id),
        );
        create_barrier_key(client, &key, data, Some(lease_id)).await?;
        Ok(key)
    }

    async fn wait_for_completion(
        &self,
        client: &Client,
        worker_key: String,
    ) -> Result<(), LeaderWorkerBarrierError> {
        let complete_key = barrier_key(&self.barrier_id, BARRIER_COMPLETE);
        let abort_key = barrier_key(&self.barrier_id, BARRIER_ABORT);

        tokio::select! {
            Ok(workers) = wait_for_signal::<HashSet<String>>(client, complete_key) => {
                // The completion signal carries the set of worker keys the leader
                // accepted; if ours is missing, the barrier completed without us.
                if workers.contains(&worker_key) {
                    Ok(())
                } else {
                    Err(LeaderWorkerBarrierError::AlreadyCompleted)
                }
            },
            _ = wait_for_signal::<()>(client, abort_key) => Err(LeaderWorkerBarrierError::Aborted),
        }
    }
}

#[cfg(feature = "testing-etcd")]
#[cfg(test)]
mod tests {
    use super::*;

    use crate::Runtime;
    use tokio::task::JoinHandle;

    use std::sync::atomic::{AtomicU64, Ordering};

    fn unique_id() -> String {
        static COUNTER: AtomicU64 = AtomicU64::new(0);
        let id = COUNTER.fetch_add(1, Ordering::Relaxed);

        format!("test_{}", id)
    }

    #[tokio::test]
    async fn test_no_etcd() {
        let rt = Runtime::from_current().unwrap();
        let drt = DistributedRuntime::from_settings_without_discovery(rt.clone())
            .await
            .unwrap();

        assert!(drt.etcd_client().is_none());

        let barrier = LeaderBarrier::<String, String>::new("test".to_string(), 2, None);
        let worker = WorkerBarrier::<String, String>::new("test".to_string(), "worker".to_string());

        assert!(matches!(
            barrier.sync(&drt, &"test".to_string()).await,
            Err(LeaderWorkerBarrierError::EtcdClientNotFound)
        ));
        assert!(matches!(
            worker.sync(&drt, &"test".to_string()).await,
            Err(LeaderWorkerBarrierError::EtcdClientNotFound)
        ));
    }

    #[tokio::test]
    async fn test_simple() {
        let rt = Runtime::from_current().unwrap();
        let drt = DistributedRuntime::from_settings(rt.clone()).await.unwrap();

        let id = unique_id();

        let leader = LeaderBarrier::<String, String>::new(id.clone(), 1, None);
        let worker = WorkerBarrier::<String, String>::new(id.clone(), "worker".to_string());

        let drt_clone = drt.clone();
        let leader_join: JoinHandle<Result<(), LeaderWorkerBarrierError>> =
            tokio::spawn(async move {
                let worker_data = leader.sync(&drt_clone, &"test_data".to_string()).await?;
                assert_eq!(worker_data.len(), 1);
                assert_eq!(
                    worker_data.get("worker").unwrap(),
                    &"test_worker".to_string()
                );
                Ok(())
            });

        let worker_join: JoinHandle<Result<(), LeaderWorkerBarrierError>> =
            tokio::spawn(async move {
                let res = worker.sync(&drt, &"test_worker".to_string()).await?;
                assert_eq!(res, "test_data".to_string());

                Ok(())
            });

        let (leader_res, worker_res) = tokio::join!(leader_join, worker_join);

        assert!(matches!(leader_res, Ok(Ok(_))));
        assert!(matches!(worker_res, Ok(Ok(_))));
    }

    #[tokio::test]
    async fn test_duplicate_leader() {
        let rt = Runtime::from_current().unwrap();
        let drt = DistributedRuntime::from_settings(rt.clone()).await.unwrap();

        let id = unique_id();

        let leader1 = LeaderBarrier::<String, String>::new(id.clone(), 1, None);
        let leader2 = LeaderBarrier::<String, String>::new(id.clone(), 1, None);

        let worker = WorkerBarrier::<String, String>::new(id.clone(), "worker".to_string());

        let drt_clone = drt.clone();
        let leader1_join: JoinHandle<Result<(), LeaderWorkerBarrierError>> =
            tokio::spawn(async move {
                let worker_data = leader1.sync(&drt_clone, &"test_data".to_string()).await?;
                assert_eq!(worker_data.len(), 1);
                assert_eq!(
                    worker_data.get("worker").unwrap(),
                    &"test_worker".to_string()
                );

                // Now, try to sync leader 2.
                let leader2_res = leader2.sync(&drt_clone, &"test_data2".to_string()).await;

                // Leader 2 should fail because its barrier ID is the same as leader 1's.
                assert!(matches!(
                    leader2_res,
                    Err(LeaderWorkerBarrierError::IdNotUnique)
                ));

                Ok(())
            });

        let worker_join: JoinHandle<Result<(), LeaderWorkerBarrierError>> =
            tokio::spawn(async move {
                let res = worker.sync(&drt, &"test_worker".to_string()).await?;
                assert_eq!(res, "test_data".to_string());

                Ok(())
            });

        let (leader1_res, worker_res) = tokio::join!(leader1_join, worker_join);

        assert!(matches!(leader1_res, Ok(Ok(_))));
        assert!(matches!(worker_res, Ok(Ok(_))));
    }

    #[tokio::test]
    async fn test_duplicate_worker() {
        let rt = Runtime::from_current().unwrap();
        let drt = DistributedRuntime::from_settings(rt.clone()).await.unwrap();

        let id = unique_id();

        let leader = LeaderBarrier::<String, String>::new(id.clone(), 1, None);
        let worker1 = WorkerBarrier::<String, String>::new(id.clone(), "worker".to_string());
        let worker2 = WorkerBarrier::<String, String>::new(id.clone(), "worker".to_string());

        let drt_clone = drt.clone();
        let leader_join: JoinHandle<Result<(), LeaderWorkerBarrierError>> =
            tokio::spawn(async move {
                let worker_data = leader.sync(&drt_clone, &"test_data".to_string()).await?;
                assert_eq!(worker_data.len(), 1);
                assert_eq!(
                    worker_data.get("worker").unwrap(),
                    &"test_worker_1".to_string()
                );

                Ok(())
            });

        let worker_join: JoinHandle<Result<(), LeaderWorkerBarrierError>> =
            tokio::spawn(async move {
                let leader_data = worker1.sync(&drt, &"test_worker_1".to_string()).await?;
                assert_eq!(leader_data, "test_data".to_string());

                let worker2_res = worker2.sync(&drt, &"test_worker_2".to_string()).await;

                assert!(matches!(
                    worker2_res,
                    Err(LeaderWorkerBarrierError::IdNotUnique)
                ));

                Ok(())
            });

        let (leader_res, worker_res) = tokio::join!(leader_join, worker_join);

        assert!(matches!(leader_res, Ok(Ok(_))));
        assert!(matches!(worker_res, Ok(Ok(_))));
    }

    #[tokio::test]
    async fn test_timeout() {
        let rt = Runtime::from_current().unwrap();
        let drt = DistributedRuntime::from_settings(rt.clone()).await.unwrap();

        let id = unique_id();

        let leader = LeaderBarrier::<(), ()>::new(id.clone(), 2, Some(Duration::from_millis(100)));
        let worker1 = WorkerBarrier::<(), ()>::new(id.clone(), "worker1".to_string());
        let worker2 = WorkerBarrier::<(), ()>::new(id.clone(), "worker2".to_string());

        let drt_clone = drt.clone();
        let leader_join: JoinHandle<Result<(), LeaderWorkerBarrierError>> =
            tokio::spawn(async move {
                let res = leader.sync(&drt_clone, &()).await;
                assert!(matches!(res, Err(LeaderWorkerBarrierError::Timeout)));

                Ok(())
            });

        let drt_clone = drt.clone();
        let worker1_join: JoinHandle<Result<(), LeaderWorkerBarrierError>> =
            tokio::spawn(async move {
                let res = worker1.sync(&drt_clone, &()).await;
                assert!(matches!(res, Err(LeaderWorkerBarrierError::Aborted)));

                Ok(())
            });

        let worker2_join: JoinHandle<Result<(), LeaderWorkerBarrierError>> =
            tokio::spawn(async move {
                tokio::time::sleep(Duration::from_millis(200)).await;
                let res = worker2.sync(&drt, &()).await;
                assert!(matches!(res, Err(LeaderWorkerBarrierError::Aborted)));

                Ok(())
            });

        let (leader_res, worker1_res, worker2_res) =
            tokio::join!(leader_join, worker1_join, worker2_join);

        assert!(matches!(leader_res, Ok(Ok(_))));
        assert!(matches!(worker1_res, Ok(Ok(_))));
        assert!(matches!(worker2_res, Ok(Ok(_))));
    }

    #[tokio::test]
    async fn test_serde_error() {
        let rt = Runtime::from_current().unwrap();
        let drt = DistributedRuntime::from_settings(rt.clone()).await.unwrap();

        let id = unique_id();

        // Have the leader send a () when the worker expects a String.
        let leader =
            LeaderBarrier::<(), String>::new(id.clone(), 1, Some(Duration::from_millis(100)));
        let worker1 = WorkerBarrier::<String, String>::new(id.clone(), "worker1".to_string());

        let drt_clone = drt.clone();
        let leader_join: JoinHandle<Result<(), LeaderWorkerBarrierError>> =
            tokio::spawn(async move {
                assert!(matches!(
                    leader.sync(&drt_clone, &()).await,
                    Err(LeaderWorkerBarrierError::Timeout)
                ));
                Ok(())
            });

        let worker_join: JoinHandle<Result<(), LeaderWorkerBarrierError>> =
            tokio::spawn(async move {
                assert!(matches!(
                    worker1.sync(&drt, &"test_worker".to_string()).await,
                    Err(LeaderWorkerBarrierError::SerdeError(_))
                ));

                Ok(())
            });

        let (leader_res, worker_res) = tokio::join!(leader_join, worker_join);

        assert!(matches!(leader_res, Ok(Ok(_))));
        assert!(matches!(worker_res, Ok(Ok(_))));
    }

    #[tokio::test]
    async fn test_too_many_workers() {
        let rt = Runtime::from_current().unwrap();
        let drt = DistributedRuntime::from_settings(rt.clone()).await.unwrap();

        let id = unique_id();

        let leader = LeaderBarrier::<(), ()>::new(id.clone(), 1, None);
        let worker1 = WorkerBarrier::<(), ()>::new(id.clone(), "worker1".to_string());
        let worker2 = WorkerBarrier::<(), ()>::new(id.clone(), "worker2".to_string());

        let drt_clone = drt.clone();
        let leader_join: JoinHandle<Result<(), LeaderWorkerBarrierError>> =
            tokio::spawn(async move {
                leader.sync(&drt_clone, &()).await?;
                Ok(())
            });

        let worker_join: JoinHandle<Result<(), LeaderWorkerBarrierError>> =
            tokio::spawn(async move {
                let drt_clone = drt.clone();
                let worker1_join = tokio::spawn(async move { worker1.sync(&drt_clone, &()).await });

                let worker2_join = tokio::spawn(async move { worker2.sync(&drt, &()).await });

                let (worker1_res, worker2_res) = tokio::join!(worker1_join, worker2_join);

                // Exactly one worker should win the race; the other must see
                // that the barrier completed without it.
                let mut num_successes = 0;
                for worker_res in [worker1_res, worker2_res] {
                    if let Ok(Ok(_)) = worker_res {
                        num_successes += 1;
                    } else if let Ok(Err(LeaderWorkerBarrierError::AlreadyCompleted)) = worker_res {
                    } else {
                        panic!("expected success or AlreadyCompleted");
                    }
                }

                assert_eq!(num_successes, 1);
                Ok(())
            });

        let (leader_res, worker_res) = tokio::join!(leader_join, worker_join);

        assert!(matches!(leader_res, Ok(Ok(_))));
        assert!(matches!(worker_res, Ok(Ok(_))));
    }
}