1pub use dubbo_rs_common;
2pub use dubbo_rs_proxy;
3
4use std::sync::Arc;
5
6use anyhow::Result;
7use async_trait::async_trait;
8use dubbo_rs_cluster::{Cluster, StaticDirectory};
9use dubbo_rs_common::node::Node;
10use dubbo_rs_common::url::URL;
11use dubbo_rs_config::ProtocolConfig;
12use dubbo_rs_filter::{Filter, FilterChain};
13use dubbo_rs_loadbalance::LoadBalance;
14use dubbo_rs_protocol::{InvocationContext, Invoker, RPCResult};
15use dubbo_rs_registry::Registry;
16use tonic::transport::{Channel, Endpoint};
17
/// A Triple-protocol client, configured with the builder-style `with_*`
/// methods and connected with [`Client::dial`].
pub struct Client {
    /// Protocol settings supplied via `with_protocol_config`; returned by
    /// `protocol_config()`.
    protocol_config: Option<ProtocolConfig>,
    /// Target URL (expected form `tri://host:port[/path]`), parsed by `dial`.
    url: Option<String>,
    /// Connection to the target; set by `dial` once the connection is
    /// established.
    channel: Option<Channel>,
    /// Invoker chain; set by `dial` once the chain has been built.
    invoker: Option<Box<dyn Invoker>>,
    /// Filters to wrap around the connection invoker; `dial` takes them out
    /// of the client when it builds the `FilterChain`.
    filters: Vec<Box<dyn Filter>>,
    /// Cluster strategy; `dial` takes it out of the client and joins it over
    /// a `StaticDirectory`.
    cluster: Option<Box<dyn Cluster>>,
    /// Load-balancing strategy.
    /// NOTE(review): stored only — nothing in this file reads it yet.
    loadbalance: Option<Box<dyn LoadBalance>>,
    /// Service registry.
    /// NOTE(review): stored only — nothing in this file reads it yet.
    registry: Option<Box<dyn Registry>>,
}
28
29impl Client {
30 #[must_use]
31 pub fn new() -> Self {
32 Self {
33 protocol_config: None,
34 url: None,
35 channel: None,
36 invoker: None,
37 filters: Vec::new(),
38 cluster: None,
39 loadbalance: None,
40 registry: None,
41 }
42 }
43
44 #[must_use]
45 pub fn with_protocol_config(mut self, config: ProtocolConfig) -> Self {
46 self.protocol_config = Some(config);
47 self
48 }
49
50 #[must_use]
51 pub fn with_url(mut self, url: impl Into<String>) -> Self {
52 self.url = Some(url.into());
53 self
54 }
55
56 #[must_use]
60 pub fn with_filter(mut self, filter: Box<dyn Filter>) -> Self {
61 self.filters.push(filter);
62 self
63 }
64
65 #[must_use]
69 pub fn with_filters(mut self, filters: Vec<Box<dyn Filter>>) -> Self {
70 self.filters = filters;
71 self
72 }
73
74 #[must_use]
76 pub fn with_cluster(mut self, cluster: Box<dyn Cluster>) -> Self {
77 self.cluster = Some(cluster);
78 self
79 }
80
81 #[must_use]
83 pub fn with_loadbalance(mut self, loadbalance: Box<dyn LoadBalance>) -> Self {
84 self.loadbalance = Some(loadbalance);
85 self
86 }
87
88 #[must_use]
94 pub fn with_registry(mut self, registry: Box<dyn Registry>) -> Self {
95 self.registry = Some(registry);
96 self
97 }
98
99 pub async fn dial(&mut self) -> Result<()> {
110 let url_str = self
111 .url
112 .as_ref()
113 .ok_or_else(|| anyhow::anyhow!("No URL set — call with_url() before dial()"))?;
114
115 let (host, port) = parse_triple_url(url_str)?;
116 let addr = format!("http://{host}:{port}");
117
118 let channel = Endpoint::from_shared(addr)?.connect().await?;
119 self.channel = Some(channel.clone());
120
121 let service_path = extract_service_path(url_str);
122 let mut url = URL::new("tri", &service_path);
123 url.ip = host.to_string();
124 url.port = port.to_string();
125
126 let base_invoker: Box<dyn Invoker> = Box::new(TonicInvoker {
127 channel,
128 url: url.clone(),
129 });
130
131 if let Some(cluster) = self.cluster.take() {
134 let dir = StaticDirectory::new(url.clone());
135 let arc_invoker: Arc<dyn Invoker> = Arc::from(base_invoker);
136 dir.add_invoker(arc_invoker);
137 let cluster_invoker = cluster
138 .join(Box::new(dir))
139 .await
140 .map_err(|e| anyhow::anyhow!("cluster join failed: {e}"))?;
141 self.invoker = Some(cluster_invoker);
142 } else if self.filters.is_empty() {
143 self.invoker = Some(base_invoker);
144 } else {
145 let filters: Vec<Box<dyn Filter>> = std::mem::take(&mut self.filters);
146 let chain = FilterChain::new(filters, base_invoker);
147 self.invoker = Some(chain.build());
148 }
149
150 Ok(())
151 }
152
153 #[must_use]
155 pub fn channel(&self) -> Option<&Channel> {
156 self.channel.as_ref()
157 }
158
159 #[must_use]
163 pub fn invoker(&self) -> Option<&dyn Invoker> {
164 self.invoker.as_deref()
165 }
166
167 #[must_use]
168 pub fn protocol_config(&self) -> Option<&ProtocolConfig> {
169 self.protocol_config.as_ref()
170 }
171
172 #[must_use]
173 pub fn url(&self) -> &str {
174 self.url.as_deref().unwrap_or("")
175 }
176}
177
178impl Default for Client {
179 fn default() -> Self {
180 Self::new()
181 }
182}
183
184#[allow(dead_code)]
186struct TonicInvoker {
187 channel: Channel,
188 url: URL,
189}
190
191impl Node for TonicInvoker {
192 fn get_url(&self) -> &URL {
193 &self.url
194 }
195
196 fn is_available(&self) -> bool {
197 true
198 }
199
200 fn destroy(&self) {}
201}
202
203#[async_trait]
204impl Invoker for TonicInvoker {
205 async fn invoke(&self, _ctx: &mut InvocationContext) -> Result<RPCResult, anyhow::Error> {
206 Err(anyhow::anyhow!(
207 "TonicInvoker does not support direct invoke. \
208 Use the tonic Channel directly via Client::channel() \
209 for gRPC calls, or wrap this invoker in a protocol-specific invoker."
210 ))
211 }
212}
213
214fn parse_triple_url(url_str: &str) -> Result<(&str, &str)> {
216 let stripped = url_str
217 .strip_prefix("tri://")
218 .ok_or_else(|| anyhow::anyhow!("URL must start with 'tri://': {url_str}"))?;
219
220 let addr_end = stripped.find('/').unwrap_or(stripped.len());
221 let addr = &stripped[..addr_end];
222
223 let (host, port) = addr
224 .split_once(':')
225 .ok_or_else(|| anyhow::anyhow!("URL must contain host:port: {url_str}"))?;
226
227 Ok((host, port))
228}
229
/// Returns the service-path part of a Triple URL, including its leading `/`.
///
/// The `tri://` prefix is dropped when present; otherwise the input is used
/// unchanged. Everything from the first `/` onward is returned. When there is
/// no `/`, the root path `"/"` is returned.
#[must_use]
fn extract_service_path(url_str: &str) -> String {
    let without_scheme = match url_str.strip_prefix("tri://") {
        Some(rest) => rest,
        None => url_str,
    };

    without_scheme
        .find('/')
        .map_or_else(|| String::from("/"), |idx| without_scheme[idx..].to_owned())
}
240
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_client_dial_missing_url() {
        let mut client = Client::new();
        let result = client.dial().await;
        assert!(result.is_err());
        // A failed dial must not leave a half-initialised client behind.
        assert!(client.channel().is_none());
        assert!(client.invoker().is_none());
    }

    #[tokio::test]
    async fn test_client_dial_invalid_url() {
        let mut client = Client::new().with_url("not-a-url");
        let result = client.dial().await;
        assert!(result.is_err());
        assert!(client.channel().is_none());
        assert!(client.invoker().is_none());
    }

    #[tokio::test]
    async fn test_client_dial_bad_prefix() {
        let mut client = Client::new().with_url("http://127.0.0.1:50051/test");
        let result = client.dial().await;
        assert!(result.is_err());
    }

    #[test]
    fn test_client_channel_before_dial() {
        let client = Client::new().with_url("tri://127.0.0.1:50051/test");
        assert!(client.channel().is_none());
    }

    /// A freshly constructed client has no configuration, connection,
    /// invoker, filters or optional components. This also covers the former
    /// `test_client_default_protocol_config`, which was an exact duplicate.
    #[test]
    fn test_client_builder_default() {
        let client = Client::new();
        assert!(client.protocol_config().is_none());
        assert_eq!(client.url(), "");
        assert!(client.channel().is_none());
        assert!(client.invoker().is_none());
        assert!(client.filters.is_empty());
        assert!(client.cluster.is_none());
        assert!(client.loadbalance.is_none());
        assert!(client.registry.is_none());
    }

    #[test]
    fn test_client_default_trait_matches_new() {
        let client = Client::default();
        assert!(client.protocol_config().is_none());
        assert_eq!(client.url(), "");
        assert!(client.channel().is_none());
        assert!(client.invoker().is_none());
        assert!(client.filters.is_empty());
    }

    #[test]
    fn test_client_builder_with_config() {
        let config = ProtocolConfig::new("tri", "127.0.0.1", 50051);
        let client = Client::new().with_protocol_config(config);
        assert_eq!(client.protocol_config().unwrap().port, 50051);
        assert_eq!(client.protocol_config().unwrap().host, "127.0.0.1");
    }

    #[test]
    fn test_client_builder_with_url() {
        let client = Client::new().with_url("tri://127.0.0.1:50051/com.example.GreetService");
        assert_eq!(
            client.url(),
            "tri://127.0.0.1:50051/com.example.GreetService"
        );
    }

    #[test]
    fn test_parse_triple_url() {
        let (host, port) =
            parse_triple_url("tri://192.168.1.1:20880/com.example.DemoService").unwrap();
        assert_eq!(host, "192.168.1.1");
        assert_eq!(port, "20880");
    }

    #[test]
    fn test_parse_triple_url_no_port() {
        let result = parse_triple_url("tri://127.0.0.1/service");
        assert!(result.is_err());
    }

    #[test]
    fn test_parse_triple_url_empty_host() {
        let (host, port) = parse_triple_url("tri://:50051/service").unwrap();
        assert_eq!(host, "");
        assert_eq!(port, "50051");
    }

    #[test]
    fn test_parse_triple_url_no_path() {
        let (host, port) = parse_triple_url("tri://127.0.0.1:50051").unwrap();
        assert_eq!(host, "127.0.0.1");
        assert_eq!(port, "50051");
    }

    #[test]
    fn test_client_default_url() {
        let client = Client::new();
        assert_eq!(client.url(), "");
    }

    #[test]
    fn test_parse_triple_url_long_path() {
        let (host, port) = parse_triple_url("tri://host:8080/com/example/Service").unwrap();
        assert_eq!(host, "host");
        assert_eq!(port, "8080");
    }

    #[test]
    fn test_invoker_before_dial() {
        let client = Client::new().with_url("tri://127.0.0.1:50051/test");
        assert!(client.invoker().is_none());
    }

    #[test]
    fn test_extract_service_path() {
        assert_eq!(
            extract_service_path("tri://127.0.0.1:50051/com.example.Service"),
            "/com.example.Service"
        );
        assert_eq!(extract_service_path("tri://127.0.0.1:50051"), "/");
        assert_eq!(extract_service_path("tri://127.0.0.1:50051/"), "/");
    }

    #[test]
    fn test_with_filter_chain_builder() {
        use dubbo_rs_filter::EchoFilter;

        let client = Client::new()
            .with_url("tri://127.0.0.1:50051/test")
            .with_filter(Box::new(EchoFilter));

        assert_eq!(client.filters.len(), 1);
        assert!(client.channel().is_none());
        assert!(client.invoker().is_none());
    }

    #[test]
    fn test_client_builder_with_filters() {
        use dubbo_rs_filter::EchoFilter;

        let filters: Vec<Box<dyn Filter>> = vec![Box::new(EchoFilter)];

        let client = Client::new()
            .with_url("tri://127.0.0.1:50051/test")
            .with_filters(filters);

        assert_eq!(client.filters.len(), 1);
        assert!(client.invoker().is_none());
    }

    /// `with_filter` appends, while `with_filters` replaces the whole list.
    #[test]
    fn test_client_builder_with_filters_replaces_previous() {
        use dubbo_rs_filter::EchoFilter;

        let appended = Client::new()
            .with_filter(Box::new(EchoFilter))
            .with_filter(Box::new(EchoFilter));
        assert_eq!(appended.filters.len(), 2);

        let replacement: Vec<Box<dyn Filter>> = vec![Box::new(EchoFilter)];
        let replaced = appended.with_filters(replacement);
        assert_eq!(replaced.filters.len(), 1);
    }

    #[test]
    fn test_client_builder_with_cluster() {
        use dubbo_rs_cluster::FailoverCluster;

        let client = Client::new()
            .with_url("tri://127.0.0.1:50051/test")
            .with_cluster(Box::new(FailoverCluster::new().with_retries(5)));
        assert!(client.cluster.is_some());
        assert!(client.invoker().is_none());
    }

    #[test]
    fn test_client_builder_with_loadbalance() {
        use dubbo_rs_loadbalance::RandomLoadBalance;

        let client = Client::new()
            .with_url("tri://127.0.0.1:50051/test")
            .with_loadbalance(Box::new(RandomLoadBalance));
        assert!(client.loadbalance.is_some());
        assert!(client.invoker().is_none());
    }

    #[test]
    fn test_client_builder_with_registry() {
        let registry = TestRegistry;

        let client = Client::new()
            .with_url("tri://127.0.0.1:50051/test")
            .with_registry(Box::new(registry));
        assert!(client.registry.is_some());
        assert!(client.invoker().is_none());
    }

    #[test]
    fn test_client_full_builder_chain() {
        use dubbo_rs_cluster::FailoverCluster;
        use dubbo_rs_filter::EchoFilter;
        use dubbo_rs_loadbalance::RandomLoadBalance;

        let client = Client::new()
            .with_url("tri://127.0.0.1:50051/com.example.Service")
            .with_protocol_config(ProtocolConfig::new("tri", "127.0.0.1", 50051))
            .with_filter(Box::new(EchoFilter))
            .with_cluster(Box::new(FailoverCluster::new()))
            .with_loadbalance(Box::new(RandomLoadBalance))
            .with_registry(Box::new(TestRegistry));

        assert_eq!(client.url(), "tri://127.0.0.1:50051/com.example.Service");
        assert_eq!(client.protocol_config().unwrap().port, 50051);
        assert_eq!(client.filters.len(), 1);
        assert!(client.cluster.is_some());
        assert!(client.loadbalance.is_some());
        assert!(client.registry.is_some());
        assert!(client.invoker().is_none());
    }

    #[test]
    fn test_extract_service_path_edge_cases() {
        assert_eq!(
            extract_service_path("tri://192.168.1.1:20880/path/to/service"),
            "/path/to/service"
        );
        assert_eq!(extract_service_path("tri://host:8080"), "/");
        assert_eq!(extract_service_path("tri://host:8080/"), "/");
        assert_eq!(extract_service_path(""), "/");
        // Without the `tri://` prefix the input is searched as-is.
        assert_eq!(extract_service_path("127.0.0.1:50051/svc"), "/svc");
    }

    use async_trait::async_trait;
    use dubbo_rs_common::node::Node;
    use dubbo_rs_registry::Registry;

    /// Registry stub whose operations all succeed without doing anything.
    struct TestRegistry;

    impl Node for TestRegistry {
        fn get_url(&self) -> &dubbo_rs_common::url::URL {
            static DEFAULT_URL: std::sync::LazyLock<dubbo_rs_common::url::URL> =
                std::sync::LazyLock::new(|| dubbo_rs_common::url::URL::new("test", "/"));
            &DEFAULT_URL
        }
        fn is_available(&self) -> bool {
            true
        }
        fn destroy(&self) {}
    }

    #[async_trait]
    impl Registry for TestRegistry {
        async fn register(
            &self,
            _url: dubbo_rs_common::url::URL,
        ) -> std::result::Result<(), dubbo_rs_common::error::RPCError> {
            Ok(())
        }
        async fn unregister(
            &self,
            _url: dubbo_rs_common::url::URL,
        ) -> std::result::Result<(), dubbo_rs_common::error::RPCError> {
            Ok(())
        }
        async fn subscribe(
            &self,
            _url: dubbo_rs_common::url::URL,
            _listener: std::sync::Arc<dyn dubbo_rs_registry::NotifyListener>,
        ) -> std::result::Result<(), dubbo_rs_common::error::RPCError> {
            Ok(())
        }
        async fn unsubscribe(
            &self,
            _url: dubbo_rs_common::url::URL,
            _listener: std::sync::Arc<dyn dubbo_rs_registry::NotifyListener>,
        ) -> std::result::Result<(), dubbo_rs_common::error::RPCError> {
            Ok(())
        }
    }
}