Skip to main content

dubbo_rs_registry_etcd/
lib.rs

1pub use dubbo_rs_common;
2pub use dubbo_rs_registry;
3
4use std::sync::atomic::{AtomicBool, Ordering};
5use std::sync::Arc;
6
7use async_trait::async_trait;
8use dashmap::DashMap;
9use dubbo_rs_common::error::RPCError;
10use dubbo_rs_common::node::Node;
11use dubbo_rs_common::url::URL;
12use dubbo_rs_registry::{NotifyListener, Registry, ServiceEvent};
13use etcd_client::{
14    Client as EtcdClient, ConnectOptions, DeleteOptions, EventType, GetOptions, PutOptions,
15    WatchOptions,
16};
17
/// Root key prefix used when the registry URL carries no `root` parameter.
const DEFAULT_DUBBO_ROOT: &str = "/dubbo";
/// TTL passed to the etcd lease grant for provider registrations (etcd lease TTLs are seconds).
const DEFAULT_TTL: i64 = 30;
20
21/// Etcd-based service registry using gRPC (etcd-client).
22///
23/// Connects to etcd via gRPC, matching the approach used by dubbo-java (jetcd).
24/// Supports lease-based ephemeral registration, keep-alive, and watch-based
25/// service discovery.
26pub struct EtcdRegistry {
27    url: URL,
28    endpoints: Vec<String>,
29    root_path: String,
30    client: tokio::sync::OnceCell<EtcdClient>,
31    lease_id: tokio::sync::Mutex<Option<i64>>,
32    subscribed: DashMap<String, Vec<Arc<dyn NotifyListener>>>,
33    shutdown: Arc<AtomicBool>,
34}
35
36impl EtcdRegistry {
37    #[must_use]
38    pub fn new(url: URL) -> Self {
39        let endpoints = url.get_param("endpoints").map_or_else(
40            || vec![format!("{}:{}", url.ip, url.port)],
41            |e| e.split(',').map(|s| s.trim().to_string()).collect(),
42        );
43
44        let root_path = url
45            .get_param("root")
46            .map_or_else(|| DEFAULT_DUBBO_ROOT.to_string(), Clone::clone);
47
48        Self {
49            url,
50            endpoints,
51            root_path,
52            client: tokio::sync::OnceCell::new(),
53            lease_id: tokio::sync::Mutex::new(None),
54            subscribed: DashMap::new(),
55            shutdown: Arc::new(AtomicBool::new(false)),
56        }
57    }
58
59    #[must_use]
60    pub fn with_endpoints(mut self, endpoints: impl Into<String>) -> Self {
61        let e: String = endpoints.into();
62        self.endpoints = e.split(',').map(|s| s.trim().to_string()).collect();
63        self
64    }
65
66    #[must_use]
67    pub fn with_root_path(mut self, path: impl Into<String>) -> Self {
68        self.root_path = path.into();
69        self
70    }
71
72    fn provider_path(&self, service_key: &str) -> String {
73        format!("{}/{service_key}/providers", self.root_path)
74    }
75
76    fn provider_key(&self, service_key: &str, url_str: &str) -> String {
77        let dir = self.provider_path(service_key);
78        format!("{dir}/{url_str}")
79    }
80
81    async fn connect(&self) -> Result<&EtcdClient, RPCError> {
82        self.client
83            .get_or_try_init(|| async {
84                let endpoints: Vec<String> = if self.endpoints.is_empty() {
85                    vec![format!("{}:{}", self.url.ip, self.url.port)]
86                } else {
87                    self.endpoints.clone()
88                };
89                EtcdClient::connect(endpoints, Some(ConnectOptions::default()))
90                    .await
91                    .map_err(|e| RPCError::ServerError(format!("etcd connect failed: {e}")))
92            })
93            .await
94    }
95
96    async fn ensure_lease(&self) -> Result<i64, RPCError> {
97        let mut guard = self.lease_id.lock().await;
98        if let Some(id) = *guard {
99            return Ok(id);
100        }
101
102        let client = self.connect().await?;
103        let mut lease_client = client.lease_client();
104        let resp = lease_client
105            .grant(DEFAULT_TTL, None)
106            .await
107            .map_err(|e| RPCError::ServerError(format!("etcd lease grant failed: {e}")))?;
108        let lease_id = resp.id();
109
110        // Start background keep-alive for the lease
111        let (mut _keeper, mut stream) = lease_client
112            .keep_alive(lease_id)
113            .await
114            .map_err(|e| RPCError::ServerError(format!("etcd keepalive failed: {e}")))?;
115
116        let shutdown = self.shutdown.clone();
117        tokio::spawn(async move {
118            loop {
119                tokio::select! {
120                    result = stream.message() => {
121                        if result.is_err() || !matches!(result, Ok(Some(_))) {
122                            break;
123                        }
124                    }
125                    () = async {
126                        while !shutdown.load(Ordering::SeqCst) {
127                            tokio::time::sleep(std::time::Duration::from_millis(200)).await;
128                        }
129                    } => {
130                        break;
131                    }
132                }
133            }
134        });
135
136        *guard = Some(lease_id);
137        Ok(lease_id)
138    }
139
140    async fn put_with_lease(&self, key: &str, value: &str) -> Result<(), RPCError> {
141        let lease_id = self.ensure_lease().await?;
142        let client = self.connect().await?;
143        let mut kv = client.kv_client();
144        kv.put(key, value, Some(PutOptions::new().with_lease(lease_id)))
145            .await
146            .map_err(|e| RPCError::ServerError(format!("etcd put failed: {e}")))?;
147        Ok(())
148    }
149
150    async fn delete(&self, key: &str) -> Result<(), RPCError> {
151        let client = self.connect().await?;
152        let mut kv = client.kv_client();
153        kv.delete(key, Some(DeleteOptions::new()))
154            .await
155            .map_err(|e| RPCError::ServerError(format!("etcd delete failed: {e}")))?;
156        Ok(())
157    }
158
159    async fn get_prefix(&self, prefix: &str) -> Result<Vec<String>, RPCError> {
160        let client = self.connect().await?;
161        let mut kv = client.kv_client();
162        let resp = kv
163            .get(prefix, Some(GetOptions::new().with_prefix()))
164            .await
165            .map_err(|e| RPCError::ServerError(format!("etcd range failed: {e}")))?;
166
167        let values = resp
168            .kvs()
169            .iter()
170            .map(|kv| String::from_utf8(kv.value().to_vec()).unwrap_or_default())
171            .filter(|v| !v.is_empty())
172            .collect();
173
174        Ok(values)
175    }
176
177    async fn start_watch(&self, service_key: &str) -> Result<(), RPCError> {
178        let dir = self.provider_path(service_key);
179        let client = self.connect().await?;
180        let mut watch = client.watch_client();
181
182        let mut stream = watch
183            .watch(dir.clone(), Some(WatchOptions::new().with_prefix()))
184            .await
185            .map_err(|e| RPCError::ServerError(format!("etcd watch failed: {e}")))?;
186
187        let subscribed = self.subscribed.clone();
188        let watch_service_key = service_key.to_string();
189        let shutdown = self.shutdown.clone();
190
191        tokio::spawn(async move {
192            loop {
193                tokio::select! {
194                    result = stream.message() => {
195                        match result {
196                            Ok(Some(resp)) => {
197                                for event in resp.events() {
198                                    if event.event_type() != EventType::Put {
199                                        continue;
200                                    }
201                                    if let Some(kv) = event.kv() {
202                                        let url_str =
203                                            String::from_utf8(kv.value().to_vec()).unwrap_or_default();
204                                        if url_str.is_empty() {
205                                            continue;
206                                        }
207                                        if let Some(parsed) = parse_provider_url(&url_str) {
208                                            let ev = ServiceEvent::Add(vec![parsed]);
209                                            if let Some(listeners) = subscribed.get(&watch_service_key) {
210                                                for l in listeners.value() {
211                                                    l.notify(ev.clone()).await;
212                                                }
213                                            }
214                                        }
215                                    }
216                                }
217                            }
218                            _ => break,
219                        }
220                    }
221                    () = async {
222                        while !shutdown.load(Ordering::SeqCst) {
223                            tokio::time::sleep(std::time::Duration::from_millis(200)).await;
224                        }
225                    } => {
226                        break;
227                    }
228                }
229            }
230        });
231
232        Ok(())
233    }
234}
235
236impl Drop for EtcdRegistry {
237    fn drop(&mut self) {
238        self.shutdown.store(true, Ordering::SeqCst);
239    }
240}
241
242impl Node for EtcdRegistry {
243    fn get_url(&self) -> &URL {
244        &self.url
245    }
246
247    fn is_available(&self) -> bool {
248        true
249    }
250
251    fn destroy(&self) {
252        self.shutdown.store(true, Ordering::SeqCst);
253    }
254}
255
256#[async_trait]
257impl Registry for EtcdRegistry {
258    async fn register(&self, url: URL) -> Result<(), RPCError> {
259        let service_key = url.get_service_key();
260        let full = url.to_full_string();
261        let key = self.provider_key(&service_key, &full);
262        self.put_with_lease(&key, &full).await
263    }
264
265    async fn unregister(&self, url: URL) -> Result<(), RPCError> {
266        let service_key = url.get_service_key();
267        let full = url.to_full_string();
268        let key = self.provider_key(&service_key, &full);
269        self.delete(&key).await
270    }
271
272    async fn subscribe(&self, url: URL, listener: Arc<dyn NotifyListener>) -> Result<(), RPCError> {
273        let service_key = url.get_service_key();
274        let is_first = {
275            let mut entries = self.subscribed.entry(service_key.clone()).or_default();
276            let is_first = entries.is_empty();
277            entries.push(listener);
278            is_first
279        };
280
281        if is_first {
282            let dir = self.provider_path(&service_key);
283            let values = self.get_prefix(&dir).await?;
284
285            let provider_urls: Vec<URL> = values
286                .iter()
287                .filter_map(|v| parse_provider_url(v))
288                .collect();
289
290            if !provider_urls.is_empty() {
291                let event = ServiceEvent::Add(provider_urls);
292                if let Some(listeners) = self.subscribed.get(&service_key) {
293                    for l in listeners.value() {
294                        l.notify(event.clone()).await;
295                    }
296                }
297            }
298
299            self.start_watch(&service_key).await?;
300        }
301
302        Ok(())
303    }
304
305    async fn unsubscribe(
306        &self,
307        url: URL,
308        _listener: Arc<dyn NotifyListener>,
309    ) -> Result<(), RPCError> {
310        let service_key = url.get_service_key();
311        self.subscribed.remove(&service_key);
312        Ok(())
313    }
314}
315
316/// Parse a provider URL string produced by `URL::to_full_string()`.
317///
318/// Format: `protocol://ip:port/path/version?key=value&...`
319/// Example: `tri://127.0.0.1:50051//com.example.Service/1.0.0?side=provider`
320fn parse_provider_url(s: &str) -> Option<URL> {
321    let (protocol, rest) = s.split_once("://")?;
322    let (ip_port, path_and_more) = rest.split_once('/')?;
323    let (ip, port) = ip_port.split_once(':')?;
324    let (full_path, params_str) = path_and_more.split_once('?').unwrap_or((path_and_more, ""));
325    let last_slash = full_path.rfind('/')?;
326    let path = &full_path[..last_slash];
327    let version = &full_path[last_slash + 1..];
328
329    let mut url = URL::new(protocol, path);
330    url.ip = ip.to_string();
331    url.port = port.to_string();
332    url.set_param("version", version);
333
334    if !params_str.is_empty() {
335        for pair in params_str.split('&') {
336            if let Some((k, v)) = pair.split_once('=') {
337                url.set_param(k, v);
338            }
339        }
340    }
341
342    Some(url)
343}
344
#[cfg(test)]
mod tests {
    use super::*;

    /// Registry URL for a local etcd on the standard client port.
    fn local_etcd_url() -> URL {
        let mut url = URL::new("etcd", "/com.example.Service");
        url.ip = "127.0.0.1".into();
        url.port = "2379".into();
        url
    }

    #[test]
    fn test_etcd_registry_creation() {
        let registry = EtcdRegistry::new(local_etcd_url());
        assert!(registry.is_available());
        assert_eq!(registry.root_path, "/dubbo");
    }

    #[test]
    fn test_etcd_with_custom_root() {
        let registry = EtcdRegistry::new(local_etcd_url()).with_root_path("/custom");
        assert_eq!(registry.root_path, "/custom");
    }

    #[test]
    fn test_etcd_with_endpoints() {
        let registry = EtcdRegistry::new(local_etcd_url())
            .with_endpoints("http://etcd1:2379,http://etcd2:2379");
        let expected: Vec<String> = ["http://etcd1:2379", "http://etcd2:2379"]
            .iter()
            .map(|s| (*s).to_string())
            .collect();
        assert_eq!(registry.endpoints, expected);
    }

    #[test]
    fn test_provider_path_generation() {
        let registry = EtcdRegistry::new(local_etcd_url());
        assert_eq!(
            registry.provider_path("com.example.Service"),
            "/dubbo/com.example.Service/providers"
        );
    }
}