distributed_scheduler/driver/
etcd.rs

//! Etcd driver implementation.
2use std::collections::HashSet;
3use std::sync::{atomic::AtomicBool, Arc};
4
5use etcd_client::*;
6use tokio::sync::{Mutex, RwLock};
7
8use super::{utils, Driver};
9
10const DEFAULT_LEASE_TTL: i64 = 3;
11
12#[derive(Clone)]
13/// Etcd driver, used to manage nodes in an etcd cluster.
14pub struct EtcdDriver {
15    client: Arc<Mutex<Client>>,
16
17    service_name: String,
18    node_id: String,
19
20    stop: Arc<AtomicBool>,
21    node_list: Arc<RwLock<HashSet<String>>>,
22
23    lease_ttl: i64,
24}
25
26impl std::fmt::Debug for EtcdDriver {
27    fn fmt(
28        &self,
29        f: &mut std::fmt::Formatter,
30    ) -> std::fmt::Result {
31        f.debug_struct("EtcdDriver")
32            .field("service_name", &self.service_name)
33            .field("node_id", &self.node_id)
34            .field("stop", &self.stop)
35            .field("node_list", &self.node_list)
36            .finish()
37    }
38}
39
40#[derive(Debug, thiserror::Error)]
41pub enum Error {
42    #[error("Etcd error: {0}")]
43    Etcd(#[from] etcd_client::Error),
44    #[error("Empty service name")]
45    EmptyServiceName,
46    #[error("Empty node id")]
47    EmptyNodeId,
48    #[error("Driver not started")]
49    DriverNotStarted,
50}
51
52impl EtcdDriver {
53    /// Create a new etcd driver with the given client, service name, and node id.
54    pub async fn new(
55        client: Client,
56        service_name: &str,
57        node_id: &str,
58    ) -> Result<Self, Error> {
59        if service_name.is_empty() {
60            return Err(Error::EmptyServiceName);
61        }
62
63        if node_id.is_empty() {
64            return Err(Error::EmptyNodeId);
65        }
66
67        Ok(Self {
68            client: Arc::new(Mutex::new(client)),
69            node_id: utils::get_key_prefix(service_name) + node_id,
70            service_name: service_name.into(),
71            stop: Arc::new(AtomicBool::new(true)),
72            node_list: Arc::new(RwLock::new(HashSet::new())),
73            lease_ttl: DEFAULT_LEASE_TTL,
74        })
75    }
76
77    /// Set the timeout for the driver.
78    pub fn with_timeout(
79        mut self,
80        timeout: i64,
81    ) -> Self {
82        self.lease_ttl = timeout;
83        self
84    }
85}
86
87#[async_trait::async_trait]
88impl Driver for EtcdDriver {
89    type Error = Error;
90
91    fn node_id(&self) -> String {
92        self.node_id.clone()
93    }
94
95    /// Read from local node list, because the node list is updated by a background task.
96    async fn get_nodes(&self) -> Result<Vec<String>, Self::Error> {
97        if self.stop.load(std::sync::atomic::Ordering::SeqCst) {
98            return Err(Error::DriverNotStarted);
99        }
100
101        Ok(self.node_list.read().await.iter().cloned().collect())
102    }
103
104    /// Start a routine to watch for node changes and register the current node. Use lease to keep
105    /// the node key alive.
106    async fn start(&mut self) -> Result<(), Self::Error> {
107        let mut client = self.client.lock().await;
108        self.stop.store(false, std::sync::atomic::Ordering::SeqCst);
109
110        // init node list
111        let mut node_list = self.node_list.write().await;
112        for kv in client
113            .get(
114                utils::get_key_prefix(&self.service_name),
115                Some(GetOptions::new().with_prefix()),
116            )
117            .await?
118            .kvs()
119        {
120            node_list.insert(kv.key_str()?.into());
121        }
122
123        // watch for node changes
124        {
125            let (mut watcher, mut watch_stream) = client
126                .watch(
127                    utils::get_key_prefix(&self.service_name),
128                    Some(WatchOptions::new().with_prefix()),
129                )
130                .await?;
131            let node_list = self.node_list.clone();
132            let stop = self.stop.clone();
133            tokio::spawn(async move {
134                loop {
135                    if stop.load(std::sync::atomic::Ordering::SeqCst) {
136                        watcher.cancel().await.expect("Failed to cancel watcher");
137                        break;
138                    }
139
140                    match watch_stream.message().await {
141                        Ok(Some(resp)) => {
142                            if resp.canceled() {
143                                tracing::warn!("Watch stream canceled: {:?}", resp);
144                                break;
145                            }
146
147                            for event in resp.events() {
148                                let key = match event.kv() {
149                                    Some(kv) if kv.key_str().is_ok() => kv.key_str().unwrap().to_string(),
150                                    _ => continue,
151                                };
152
153                                match event.event_type() {
154                                    EventType::Put => node_list.write().await.insert(key),
155                                    EventType::Delete => node_list.write().await.remove(&key),
156                                };
157                            }
158                        }
159                        Ok(None) => panic!("Watch stream closed"),
160                        Err(e) => panic!("Watch error: {:?}", e),
161                    }
162                }
163            });
164        }
165
166        // register current node
167        {
168            tracing::info!("Registering node: {}", self.node_id);
169
170            // grant a lease for the node key
171            let lease = client.lease_grant(self.lease_ttl, None).await?;
172            let lease_id = lease.id();
173
174            // keep the lease alive
175            let (mut keeper, mut ka_stream) = client.lease_keep_alive(lease.id()).await?;
176            let stop = self.stop.clone();
177            let inner_client = self.client.clone();
178
179            // spawn a task to keep the lease alive
180            tokio::spawn(async move {
181                keeper.keep_alive().await.expect("Failed to keep alive");
182
183                loop {
184                    if stop.load(std::sync::atomic::Ordering::SeqCst) {
185                        inner_client
186                            .lock()
187                            .await
188                            .lease_revoke(lease_id)
189                            .await
190                            .expect("Failed to revoke lease");
191                        break;
192                    }
193
194                    match ka_stream.message().await {
195                        Ok(Some(_)) => keeper.keep_alive().await.expect("Failed to keep alive"),
196                        Ok(None) => panic!("Keep alive stream closed"),
197                        Err(e) => panic!("Keep alive error: {:?}", e),
198                    }
199                }
200            });
201
202            // put the node key
203            client
204                .put(
205                    self.node_id.as_str(),
206                    self.node_id.as_str(),
207                    Some(PutOptions::new().with_lease(lease_id)),
208                )
209                .await?;
210        }
211
212        Ok(())
213    }
214}
215
216impl Drop for EtcdDriver {
217    fn drop(&mut self) {
218        self.stop.store(true, std::sync::atomic::Ordering::SeqCst);
219    }
220}