// distributed_scheduler/driver/etcd.rs
use std::collections::HashSet;
3use std::sync::{atomic::AtomicBool, Arc};
4
5use etcd_client::*;
6use tokio::sync::{Mutex, RwLock};
7
8use super::{utils, Driver};
9
/// Lease TTL given to a freshly constructed [`EtcdDriver`]; it is the value
/// handed to `lease_grant` unless `with_timeout` overrides it.
/// (etcd interprets lease TTLs as seconds.)
const DEFAULT_LEASE_TTL: i64 = 3;
11
12#[derive(Clone)]
13pub struct EtcdDriver {
15 client: Arc<Mutex<Client>>,
16
17 service_name: String,
18 node_id: String,
19
20 stop: Arc<AtomicBool>,
21 node_list: Arc<RwLock<HashSet<String>>>,
22
23 lease_ttl: i64,
24}
25
26impl std::fmt::Debug for EtcdDriver {
27 fn fmt(
28 &self,
29 f: &mut std::fmt::Formatter,
30 ) -> std::fmt::Result {
31 f.debug_struct("EtcdDriver")
32 .field("service_name", &self.service_name)
33 .field("node_id", &self.node_id)
34 .field("stop", &self.stop)
35 .field("node_list", &self.node_list)
36 .finish()
37 }
38}
39
40#[derive(Debug, thiserror::Error)]
41pub enum Error {
42 #[error("Etcd error: {0}")]
43 Etcd(#[from] etcd_client::Error),
44 #[error("Empty service name")]
45 EmptyServiceName,
46 #[error("Empty node id")]
47 EmptyNodeId,
48 #[error("Driver not started")]
49 DriverNotStarted,
50}
51
52impl EtcdDriver {
53 pub async fn new(
55 client: Client,
56 service_name: &str,
57 node_id: &str,
58 ) -> Result<Self, Error> {
59 if service_name.is_empty() {
60 return Err(Error::EmptyServiceName);
61 }
62
63 if node_id.is_empty() {
64 return Err(Error::EmptyNodeId);
65 }
66
67 Ok(Self {
68 client: Arc::new(Mutex::new(client)),
69 node_id: utils::get_key_prefix(service_name) + node_id,
70 service_name: service_name.into(),
71 stop: Arc::new(AtomicBool::new(true)),
72 node_list: Arc::new(RwLock::new(HashSet::new())),
73 lease_ttl: DEFAULT_LEASE_TTL,
74 })
75 }
76
77 pub fn with_timeout(
79 mut self,
80 timeout: i64,
81 ) -> Self {
82 self.lease_ttl = timeout;
83 self
84 }
85}
86
87#[async_trait::async_trait]
88impl Driver for EtcdDriver {
89 type Error = Error;
90
91 fn node_id(&self) -> String {
92 self.node_id.clone()
93 }
94
95 async fn get_nodes(&self) -> Result<Vec<String>, Self::Error> {
97 if self.stop.load(std::sync::atomic::Ordering::SeqCst) {
98 return Err(Error::DriverNotStarted);
99 }
100
101 Ok(self.node_list.read().await.iter().cloned().collect())
102 }
103
104 async fn start(&mut self) -> Result<(), Self::Error> {
107 let mut client = self.client.lock().await;
108 self.stop.store(false, std::sync::atomic::Ordering::SeqCst);
109
110 let mut node_list = self.node_list.write().await;
112 for kv in client
113 .get(
114 utils::get_key_prefix(&self.service_name),
115 Some(GetOptions::new().with_prefix()),
116 )
117 .await?
118 .kvs()
119 {
120 node_list.insert(kv.key_str()?.into());
121 }
122
123 {
125 let (mut watcher, mut watch_stream) = client
126 .watch(
127 utils::get_key_prefix(&self.service_name),
128 Some(WatchOptions::new().with_prefix()),
129 )
130 .await?;
131 let node_list = self.node_list.clone();
132 let stop = self.stop.clone();
133 tokio::spawn(async move {
134 loop {
135 if stop.load(std::sync::atomic::Ordering::SeqCst) {
136 watcher.cancel().await.expect("Failed to cancel watcher");
137 break;
138 }
139
140 match watch_stream.message().await {
141 Ok(Some(resp)) => {
142 if resp.canceled() {
143 tracing::warn!("Watch stream canceled: {:?}", resp);
144 break;
145 }
146
147 for event in resp.events() {
148 let key = match event.kv() {
149 Some(kv) if kv.key_str().is_ok() => kv.key_str().unwrap().to_string(),
150 _ => continue,
151 };
152
153 match event.event_type() {
154 EventType::Put => node_list.write().await.insert(key),
155 EventType::Delete => node_list.write().await.remove(&key),
156 };
157 }
158 }
159 Ok(None) => panic!("Watch stream closed"),
160 Err(e) => panic!("Watch error: {:?}", e),
161 }
162 }
163 });
164 }
165
166 {
168 tracing::info!("Registering node: {}", self.node_id);
169
170 let lease = client.lease_grant(self.lease_ttl, None).await?;
172 let lease_id = lease.id();
173
174 let (mut keeper, mut ka_stream) = client.lease_keep_alive(lease.id()).await?;
176 let stop = self.stop.clone();
177 let inner_client = self.client.clone();
178
179 tokio::spawn(async move {
181 keeper.keep_alive().await.expect("Failed to keep alive");
182
183 loop {
184 if stop.load(std::sync::atomic::Ordering::SeqCst) {
185 inner_client
186 .lock()
187 .await
188 .lease_revoke(lease_id)
189 .await
190 .expect("Failed to revoke lease");
191 break;
192 }
193
194 match ka_stream.message().await {
195 Ok(Some(_)) => keeper.keep_alive().await.expect("Failed to keep alive"),
196 Ok(None) => panic!("Keep alive stream closed"),
197 Err(e) => panic!("Keep alive error: {:?}", e),
198 }
199 }
200 });
201
202 client
204 .put(
205 self.node_id.as_str(),
206 self.node_id.as_str(),
207 Some(PutOptions::new().with_lease(lease_id)),
208 )
209 .await?;
210 }
211
212 Ok(())
213 }
214}
215
216impl Drop for EtcdDriver {
217 fn drop(&mut self) {
218 self.stop.store(true, std::sync::atomic::Ordering::SeqCst);
219 }
220}