Skip to main content

dubbo_rs_client/
lib.rs

1pub use dubbo_rs_common;
2pub use dubbo_rs_proxy;
3
4use std::sync::Arc;
5
6use anyhow::Result;
7use async_trait::async_trait;
8use dubbo_rs_cluster::{Cluster, StaticDirectory};
9use dubbo_rs_common::node::Node;
10use dubbo_rs_common::url::URL;
11use dubbo_rs_config::ProtocolConfig;
12use dubbo_rs_filter::{Filter, FilterChain};
13use dubbo_rs_loadbalance::LoadBalance;
14use dubbo_rs_protocol::{InvocationContext, Invoker, RPCResult};
15use dubbo_rs_registry::Registry;
16use tonic::transport::{Channel, Endpoint};
17
/// A Dubbo triple-protocol client configured through a fluent builder API.
///
/// Configure it with the `with_*` methods, then call [`dial`](Self::dial)
/// to connect and build the invoker.
pub struct Client {
    /// Optional protocol configuration; stored and returned by
    /// `protocol_config()`, not read by `dial()`.
    protocol_config: Option<ProtocolConfig>,
    /// Target URL in `tri://host:port/service` form; required by `dial()`.
    url: Option<String>,
    /// Connected tonic channel; `None` until `dial()` succeeds.
    channel: Option<Channel>,
    /// Invoker built by `dial()`; `None` until connected.
    invoker: Option<Box<dyn Invoker>>,
    /// Filters registered via `with_filter` / `with_filters`.
    filters: Vec<Box<dyn Filter>>,
    /// Optional cluster fault-tolerance strategy; taken out by `dial()`.
    cluster: Option<Box<dyn Cluster>>,
    /// Optional load-balance strategy.
    /// NOTE(review): stored but not read anywhere in this file.
    loadbalance: Option<Box<dyn LoadBalance>>,
    /// Optional service-discovery registry.
    /// NOTE(review): stored but not read anywhere in this file; `dial()`
    /// always connects to `url`.
    registry: Option<Box<dyn Registry>>,
}
28
29impl Client {
30    #[must_use]
31    pub fn new() -> Self {
32        Self {
33            protocol_config: None,
34            url: None,
35            channel: None,
36            invoker: None,
37            filters: Vec::new(),
38            cluster: None,
39            loadbalance: None,
40            registry: None,
41        }
42    }
43
44    #[must_use]
45    pub fn with_protocol_config(mut self, config: ProtocolConfig) -> Self {
46        self.protocol_config = Some(config);
47        self
48    }
49
50    #[must_use]
51    pub fn with_url(mut self, url: impl Into<String>) -> Self {
52        self.url = Some(url.into());
53        self
54    }
55
56    /// Add a single filter to the client's filter chain.
57    ///
58    /// Filters execute in insertion order, outermost first.
59    #[must_use]
60    pub fn with_filter(mut self, filter: Box<dyn Filter>) -> Self {
61        self.filters.push(filter);
62        self
63    }
64
65    /// Add multiple filters to the client's filter chain.
66    ///
67    /// Filters execute in the order given (index 0 is outermost).
68    #[must_use]
69    pub fn with_filters(mut self, filters: Vec<Box<dyn Filter>>) -> Self {
70        self.filters = filters;
71        self
72    }
73
74    /// Set a cluster fault-tolerance strategy for this client.
75    #[must_use]
76    pub fn with_cluster(mut self, cluster: Box<dyn Cluster>) -> Self {
77        self.cluster = Some(cluster);
78        self
79    }
80
81    /// Set a load-balance strategy for this client.
82    #[must_use]
83    pub fn with_loadbalance(mut self, loadbalance: Box<dyn LoadBalance>) -> Self {
84        self.loadbalance = Some(loadbalance);
85        self
86    }
87
88    /// Set a registry for service discovery.
89    ///
90    /// When configured, the client will subscribe to the registry to
91    /// discover provider addresses dynamically instead of using the
92    /// single URL provided via [`with_url`](Self::with_url).
93    #[must_use]
94    pub fn with_registry(mut self, registry: Box<dyn Registry>) -> Self {
95        self.registry = Some(registry);
96        self
97    }
98
99    /// Establish a gRPC connection to the remote server.
100    ///
101    /// Parses the URL to extract host and port, then creates a tonic
102    /// `Channel`.  If filters are configured, wraps the invoker in a
103    /// [`FilterChain`]. Call this before making RPC requests.
104    ///
105    /// # Errors
106    ///
107    /// Returns an error if no URL is set, the URL is malformed, or the
108    /// connection cannot be established.
109    pub async fn dial(&mut self) -> Result<()> {
110        let url_str = self
111            .url
112            .as_ref()
113            .ok_or_else(|| anyhow::anyhow!("No URL set — call with_url() before dial()"))?;
114
115        let (host, port) = parse_triple_url(url_str)?;
116        let addr = format!("http://{host}:{port}");
117
118        let channel = Endpoint::from_shared(addr)?.connect().await?;
119        self.channel = Some(channel.clone());
120
121        let service_path = extract_service_path(url_str);
122        let mut url = URL::new("tri", &service_path);
123        url.ip = host.to_string();
124        url.port = port.to_string();
125
126        let base_invoker: Box<dyn Invoker> = Box::new(TonicInvoker {
127            channel,
128            url: url.clone(),
129        });
130
131        // If a cluster strategy is configured, wrap the invoker in a
132        // StaticDirectory and join with the cluster.
133        if let Some(cluster) = self.cluster.take() {
134            let dir = StaticDirectory::new(url.clone());
135            let arc_invoker: Arc<dyn Invoker> = Arc::from(base_invoker);
136            dir.add_invoker(arc_invoker);
137            let cluster_invoker = cluster
138                .join(Box::new(dir))
139                .await
140                .map_err(|e| anyhow::anyhow!("cluster join failed: {e}"))?;
141            self.invoker = Some(cluster_invoker);
142        } else if self.filters.is_empty() {
143            self.invoker = Some(base_invoker);
144        } else {
145            let filters: Vec<Box<dyn Filter>> = std::mem::take(&mut self.filters);
146            let chain = FilterChain::new(filters, base_invoker);
147            self.invoker = Some(chain.build());
148        }
149
150        Ok(())
151    }
152
153    /// Return a reference to the underlying tonic `Channel`, if connected.
154    #[must_use]
155    pub fn channel(&self) -> Option<&Channel> {
156        self.channel.as_ref()
157    }
158
159    /// Return a reference to the Dubbo `Invoker`, if connected.
160    ///
161    /// The invoker is wrapped with the configured filter chain.
162    #[must_use]
163    pub fn invoker(&self) -> Option<&dyn Invoker> {
164        self.invoker.as_deref()
165    }
166
167    #[must_use]
168    pub fn protocol_config(&self) -> Option<&ProtocolConfig> {
169        self.protocol_config.as_ref()
170    }
171
172    #[must_use]
173    pub fn url(&self) -> &str {
174        self.url.as_deref().unwrap_or("")
175    }
176}
177
178impl Default for Client {
179    fn default() -> Self {
180        Self::new()
181    }
182}
183
/// A Dubbo [`Invoker`] backed by a tonic gRPC [`Channel`].
///
/// Its `invoke` implementation always returns an error; callers are pointed
/// at the raw channel instead.
// `channel` is held but never read (the `Invoker::invoke` impl does not use
// it), so `dead_code` is silenced rather than dropping the field.
#[allow(dead_code)]
struct TonicInvoker {
    /// Connected channel to the provider.
    channel: Channel,
    /// Provider URL; returned by `Node::get_url`.
    url: URL,
}
190
191impl Node for TonicInvoker {
192    fn get_url(&self) -> &URL {
193        &self.url
194    }
195
196    fn is_available(&self) -> bool {
197        true
198    }
199
200    fn destroy(&self) {}
201}
202
203#[async_trait]
204impl Invoker for TonicInvoker {
205    async fn invoke(&self, _ctx: &mut InvocationContext) -> Result<RPCResult, anyhow::Error> {
206        Err(anyhow::anyhow!(
207            "TonicInvoker does not support direct invoke. \
208             Use the tonic Channel directly via Client::channel() \
209             for gRPC calls, or wrap this invoker in a protocol-specific invoker."
210        ))
211    }
212}
213
214/// Parse a triple URL like `<tri://127.0.0.1:50051/com.example.Service>` into (host, port).
215fn parse_triple_url(url_str: &str) -> Result<(&str, &str)> {
216    let stripped = url_str
217        .strip_prefix("tri://")
218        .ok_or_else(|| anyhow::anyhow!("URL must start with 'tri://': {url_str}"))?;
219
220    let addr_end = stripped.find('/').unwrap_or(stripped.len());
221    let addr = &stripped[..addr_end];
222
223    let (host, port) = addr
224        .split_once(':')
225        .ok_or_else(|| anyhow::anyhow!("URL must contain host:port: {url_str}"))?;
226
227    Ok((host, port))
228}
229
/// Extract the service path from a triple URL.
///
/// For example, `tri://127.0.0.1:50051/com.example.Service` yields
/// `/com.example.Service`.
///
/// The optional `tri://` prefix is ignored, and any query string or fragment
/// (`?...` / `#...`) is dropped. This keeps Dubbo URL parameters such as
/// `?version=1.0` out of the path. Returns `"/"` when the URL has no path
/// component.
#[must_use]
fn extract_service_path(url_str: &str) -> String {
    let stripped = url_str.strip_prefix("tri://").unwrap_or(url_str);

    // Cut off the query/fragment first so a '/' inside a query value is not
    // mistaken for the start of the path.
    let before_query = stripped
        .find(|c: char| c == '?' || c == '#')
        .map_or(stripped, |end| &stripped[..end]);

    match before_query.find('/') {
        Some(slash_pos) => before_query[slash_pos..].to_string(),
        None => "/".to_string(),
    }
}
240
#[cfg(test)]
mod tests {
    use super::*;

    use std::sync::{Arc, LazyLock};

    use async_trait::async_trait;
    use dubbo_rs_common::error::RPCError;
    use dubbo_rs_common::node::Node;
    use dubbo_rs_common::url::URL;
    use dubbo_rs_registry::{NotifyListener, Registry};

    const TEST_URL: &str = "tri://127.0.0.1:50051/test";

    // ── dial() failure paths ──────────────────────────────────────────

    #[tokio::test]
    async fn test_client_dial_missing_url() {
        let mut client = Client::new();
        assert!(client.dial().await.is_err());
    }

    #[tokio::test]
    async fn test_client_dial_invalid_url() {
        let mut client = Client::new().with_url("not-a-url");
        assert!(client.dial().await.is_err());
    }

    #[tokio::test]
    async fn test_client_dial_bad_prefix() {
        let mut client = Client::new().with_url("http://127.0.0.1:50051/test");
        assert!(client.dial().await.is_err());
    }

    // ── Builder state before dial() ───────────────────────────────────

    #[test]
    fn test_client_channel_before_dial() {
        let client = Client::new().with_url(TEST_URL);
        assert!(client.channel().is_none());
    }

    #[test]
    fn test_invoker_before_dial() {
        let client = Client::new().with_url(TEST_URL);
        assert!(client.invoker().is_none());
    }

    #[test]
    fn test_client_builder_default() {
        assert!(Client::new().protocol_config().is_none());
    }

    #[test]
    fn test_client_default_protocol_config() {
        assert!(Client::new().protocol_config().is_none());
    }

    #[test]
    fn test_client_default_url() {
        assert_eq!(Client::new().url(), "");
    }

    #[test]
    fn test_client_builder_with_config() {
        let client =
            Client::new().with_protocol_config(ProtocolConfig::new("tri", "127.0.0.1", 50051));
        let config = client.protocol_config().expect("protocol config should be set");
        assert_eq!(config.port, 50051);
        assert_eq!(config.host, "127.0.0.1");
    }

    #[test]
    fn test_client_builder_with_url() {
        const TARGET: &str = "tri://127.0.0.1:50051/com.example.GreetService";
        let client = Client::new().with_url(TARGET);
        assert_eq!(client.url(), TARGET);
    }

    #[test]
    fn test_with_filter_chain_builder() {
        use dubbo_rs_filter::EchoFilter;

        let client = Client::new().with_url(TEST_URL).with_filter(Box::new(EchoFilter));

        assert!(client.channel().is_none());
        assert!(client.invoker().is_none());
    }

    #[test]
    fn test_client_builder_with_filters() {
        use dubbo_rs_filter::EchoFilter;

        let client = Client::new()
            .with_url(TEST_URL)
            .with_filters(vec![Box::new(EchoFilter) as Box<dyn Filter>]);

        assert!(client.invoker().is_none());
    }

    #[test]
    fn test_client_builder_with_cluster() {
        use dubbo_rs_cluster::FailoverCluster;

        let failover = FailoverCluster::new().with_retries(5);
        let client = Client::new().with_url(TEST_URL).with_cluster(Box::new(failover));
        assert!(client.invoker().is_none());
    }

    #[test]
    fn test_client_builder_with_loadbalance() {
        use dubbo_rs_loadbalance::RandomLoadBalance;

        let client = Client::new()
            .with_url(TEST_URL)
            .with_loadbalance(Box::new(RandomLoadBalance));
        assert!(client.invoker().is_none());
    }

    #[test]
    fn test_client_builder_with_registry() {
        let client = Client::new().with_url(TEST_URL).with_registry(Box::new(TestRegistry));
        assert!(client.invoker().is_none());
    }

    #[test]
    fn test_client_full_builder_chain() {
        use dubbo_rs_cluster::FailoverCluster;
        use dubbo_rs_filter::EchoFilter;
        use dubbo_rs_loadbalance::RandomLoadBalance;

        const SERVICE_URL: &str = "tri://127.0.0.1:50051/com.example.Service";

        let client = Client::new()
            .with_url(SERVICE_URL)
            .with_protocol_config(ProtocolConfig::new("tri", "127.0.0.1", 50051))
            .with_filter(Box::new(EchoFilter))
            .with_cluster(Box::new(FailoverCluster::new()))
            .with_loadbalance(Box::new(RandomLoadBalance))
            .with_registry(Box::new(TestRegistry));

        assert_eq!(client.url(), SERVICE_URL);
        let config = client.protocol_config().expect("protocol config should be set");
        assert_eq!(config.port, 50051);
        assert!(client.invoker().is_none());
    }

    // ── URL parsing helpers ───────────────────────────────────────────

    #[test]
    fn test_parse_triple_url() {
        let parsed = parse_triple_url("tri://192.168.1.1:20880/com.example.DemoService")
            .expect("well-formed triple URL");
        assert_eq!(parsed, ("192.168.1.1", "20880"));
    }

    #[test]
    fn test_parse_triple_url_no_port() {
        assert!(parse_triple_url("tri://127.0.0.1/service").is_err());
    }

    #[test]
    fn test_parse_triple_url_empty_host() {
        let parsed = parse_triple_url("tri://:50051/service").expect("empty host is accepted");
        assert_eq!(parsed, ("", "50051"));
    }

    #[test]
    fn test_parse_triple_url_no_path() {
        let parsed = parse_triple_url("tri://127.0.0.1:50051").expect("path is optional");
        assert_eq!(parsed, ("127.0.0.1", "50051"));
    }

    #[test]
    fn test_parse_triple_url_long_path() {
        let parsed =
            parse_triple_url("tri://host:8080/com/example/Service").expect("multi-segment path");
        assert_eq!(parsed, ("host", "8080"));
    }

    #[test]
    fn test_extract_service_path() {
        let cases = [
            ("tri://127.0.0.1:50051/com.example.Service", "/com.example.Service"),
            ("tri://127.0.0.1:50051", "/"),
            ("tri://127.0.0.1:50051/", "/"),
        ];
        for (input, expected) in cases {
            assert_eq!(extract_service_path(input), expected);
        }
    }

    #[test]
    fn test_extract_service_path_edge_cases() {
        let cases = [
            ("tri://192.168.1.1:20880/path/to/service", "/path/to/service"),
            ("tri://host:8080", "/"),
            ("tri://host:8080/", "/"),
            ("", "/"),
        ];
        for (input, expected) in cases {
            assert_eq!(extract_service_path(input), expected);
        }
    }

    // ── Test helpers ──────────────────────────────────────────────────

    /// Registry stub whose operations all succeed without side effects;
    /// only used to exercise the builder methods.
    struct TestRegistry;

    impl Node for TestRegistry {
        fn get_url(&self) -> &URL {
            static TEST_REGISTRY_URL: LazyLock<URL> = LazyLock::new(|| URL::new("test", "/"));
            &TEST_REGISTRY_URL
        }

        fn is_available(&self) -> bool {
            true
        }

        fn destroy(&self) {}
    }

    #[async_trait]
    impl Registry for TestRegistry {
        async fn register(&self, _url: URL) -> std::result::Result<(), RPCError> {
            Ok(())
        }

        async fn unregister(&self, _url: URL) -> std::result::Result<(), RPCError> {
            Ok(())
        }

        async fn subscribe(
            &self,
            _url: URL,
            _listener: Arc<dyn NotifyListener>,
        ) -> std::result::Result<(), RPCError> {
            Ok(())
        }

        async fn unsubscribe(
            &self,
            _url: URL,
            _listener: Arc<dyn NotifyListener>,
        ) -> std::result::Result<(), RPCError> {
            Ok(())
        }
    }
}