rubbo/
consumer.rs

1use crate::common::Registry;
2use rubbo_core::{Url, RegistryUrl, RubboReference, Invoker, RpcProtocol, ProtocolKind as Protocol, Request, Result as RubboResult};
3use rubbo_rpc::TripleProtocol;
4use rubbo_registry::{Registry as RegistryTrait, NacosRegistry, InstanceChange};
5use rubbo_cluster::{ClusterInvoker, Directory, LoadBalance, RoundRobinLoadBalance, FilterChain, AccessLogFilter};
6use rubbo_core::Result;
7use std::collections::HashMap;
8use std::sync::{Arc, RwLock};
9use tracing::{debug, warn};
10use futures::StreamExt;
11use async_trait::async_trait;
12
/// Identity of a service registered through `ConsumerBuilder::reference`.
///
/// The interface/group/version triple is put on the consumer URL used to subscribe to
/// providers and is also used to key the invoker built for the service.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ServiceMetadata {
    /// Interface name; also used as the consumer URL path.
    interface_name: String,
    /// Service group.
    group: String,
    /// Service version.
    version: String,
}
18
19pub struct ConsumerBuilder {
20    application_name: Option<String>,
21    registry_config: Option<Registry>,
22    protocol: Option<Protocol>,
23    load_balance: Option<Arc<Box<dyn LoadBalance>>>,
24    references: Vec<ServiceMetadata>,
25}
26
27impl Default for ConsumerBuilder {
28    fn default() -> Self {
29        Self::new()
30    }
31}
32
33impl ConsumerBuilder {
34    pub fn new() -> Self {
35        Self {
36            application_name: None,
37            registry_config: None,
38            protocol: None,
39            load_balance: None,
40            references: Vec::new(),
41        }
42    }
43
44    pub fn application(mut self, name: &str) -> Self {
45        self.application_name = Some(name.to_string());
46        self
47    }
48
49    pub fn registry(mut self, config: Registry) -> Self {
50         self.registry_config = Some(config);
51         self
52    }
53
54    pub fn protocol(mut self, protocol: Protocol) -> Self {
55        self.protocol = Some(protocol);
56        self
57    }
58
59    pub fn load_balance<L: LoadBalance + 'static>(mut self, load_balance: L) -> Self {
60        self.load_balance = Some(Arc::new(Box::new(load_balance)));
61        self
62    }
63
64    pub fn reference<T: RubboReference + ?Sized>(mut self) -> Self {
65        self.references.push(ServiceMetadata {
66            interface_name: T::interface_name().to_string(),
67            group: T::group().to_string(),
68            version: T::version().to_string(),
69        });
70        self
71    }
72
73    pub async fn build(self) -> Result<Consumer> {
74        let registry = if let Some(Registry::Nacos(addr)) = &self.registry_config {
75             let addr = if let Some(stripped) = addr.strip_prefix("nacos://") {
76                 stripped
77             } else {
78                 addr
79             };
80             
81             let (host, port) = if let Some((h, p)) = addr.split_once(':') {
82                 (h, p.parse::<u16>().unwrap_or(8848))
83             } else {
84                 (addr, 8848)
85             };
86             
87             let registry_url = RegistryUrl::new("nacos", host, port);
88             let reg = NacosRegistry::new(&registry_url).map_err(|e| rubbo_core::Error::Registry(format!("Failed to create Nacos registry: {}", e)))?;
89             Some(Arc::new(reg) as Arc<dyn RegistryTrait>)
90        } else {
91            None
92        };
93
94        let mut invokers = HashMap::new();
95        let application_name = self.application_name.clone().unwrap_or_else(|| "rubbo-consumer".to_string());
96        let protocol_kind = self.protocol.unwrap_or(Protocol::Triple);
97
98        if let Some(registry) = &registry {
99            for meta in self.references {
100                let mut url = Url::new("tri", "0.0.0.0", 0); 
101                url.path = meta.interface_name.clone();
102                url.add_param("interface", &meta.interface_name);
103                url.add_param("group", &meta.group);
104                url.add_param("version", &meta.version);
105                url.add_param("side", "consumer");
106                url.add_param("application", &application_name);
107                
108                // Create RegistryDirectory
109                let directory = RegistryDirectory::new(url.clone(), registry.clone(), protocol_kind.clone()).await?;
110                
111                // Wait for at least one provider (optional, but good for initial check)
112                // In a real production system, we might not want to block here indefinitely or fail hard,
113                // but for this example/dev stage, it's safer to wait.
114                // We can poll directory.list() until it's not empty, with timeout.
115                let start = std::time::Instant::now();
116                loop {
117                    let list = directory.list_internal();
118                    if !list.is_empty() {
119                        debug!("Found {} providers for {}", list.len(), meta.interface_name);
120                        
121                        // Sniff serialization from provider
122                        if let Some(invoker) = list.first()
123                            && let Some(s) = invoker.url().get_param("serialization") {
124                                debug!("Detected serialization from provider: {}", s);
125                                url.add_param("serialization", s);
126                            }
127                        
128                        break;
129                    }
130                    if start.elapsed() > std::time::Duration::from_secs(5) {
131                        warn!("Timeout waiting for providers for {}. Proceeding with empty directory.", meta.interface_name);
132                        break;
133                    }
134                    tokio::time::sleep(std::time::Duration::from_millis(100)).await;
135                }
136
137                // Create ClusterInvoker with LoadBalance
138                let load_balance = self.load_balance.clone()
139                    .unwrap_or_else(|| Arc::new(Box::new(RoundRobinLoadBalance::new()) as Box<dyn LoadBalance>));
140                
141                let cluster_invoker = ClusterInvoker::new(
142                    Arc::new(Box::new(directory) as Box<dyn Directory>),
143                    load_balance,
144                    url.clone()
145                );
146
147                // Wrap with Filters
148                let mut filter_chain = FilterChain::new(Arc::new(Box::new(cluster_invoker) as Box<dyn Invoker>));
149                filter_chain.add_filter(Box::new(AccessLogFilter));
150                
151                // The FilterChain implements Invoker now (via rubbo-cluster/src/filter.rs changes)
152                let chain_invoker = filter_chain;
153                
154                let key = format!("{}:{}:{}", meta.interface_name, meta.group, meta.version);
155                invokers.insert(key, Arc::new(Box::new(chain_invoker) as Box<dyn Invoker>));
156            }
157        }
158
159        Ok(Consumer {
160            application_name,
161            registry,
162            invokers,
163        })
164    }
165}
166
167pub struct Consumer {
168    #[allow(dead_code)]
169    application_name: String,
170    #[allow(dead_code)]
171    registry: Option<Arc<dyn RegistryTrait>>,
172    invokers: HashMap<String, Arc<Box<dyn Invoker>>>,
173}
174
175impl Consumer {
176    pub async fn reference<T: RubboReference + ?Sized>(&self) -> Result<Arc<T>> {
177        let key = T::SERVICE_KEY;
178        
179        if let Some(invoker) = self.invokers.get(key) {
180             Ok(T::create_stub(invoker.clone()))
181        } else {
182             Err(rubbo_core::Error::Other(format!("Service {} not found. Did you forget to add .reference::<T>() to ConsumerBuilder?", T::interface_name())))
183        }
184    }
185}
186
187// --- Helper Structs ---
188
189type SharedInvokers = Arc<RwLock<Vec<Arc<Box<dyn Invoker>>>>>;
190
191struct RegistryDirectory {
192    url: Url,
193    registry: Arc<dyn RegistryTrait>,
194    invokers: SharedInvokers,
195    protocol: Protocol,
196}
197
198impl RegistryDirectory {
199    async fn new(url: Url, registry: Arc<dyn RegistryTrait>, protocol: Protocol) -> Result<Self> {
200        let invokers = Arc::new(RwLock::new(Vec::new()));
201        let dir = Self {
202            url: url.clone(),
203            registry: registry.clone(),
204            invokers: invokers.clone(),
205            protocol,
206        };
207        
208        dir.subscribe().await?;
209        Ok(dir)
210    }
211
212    async fn subscribe(&self) -> Result<()> {
213        let mut stream = self.registry.subscribe(self.url.clone()).await?;
214        let invokers_store = self.invokers.clone();
215        let protocol = self.protocol.clone();
216        let service_name = self.url.path.clone();
217
218        tokio::spawn(async move {
219            while let Some(event) = stream.next().await {
220                match event {
221                    InstanceChange::Upsert { url: provider_url } => {
222                        debug!("Provider update for {}: {}", service_name, provider_url);
223                        
224                        let invoker: Option<Arc<Box<dyn Invoker>>> = match protocol {
225                            Protocol::Triple => {
226                                match TripleProtocol.refer(provider_url.clone()).await {
227                                    Ok(invoker_arc) => {
228                                         let invoker_val = (*invoker_arc).clone();
229                                         Some(Arc::new(Box::new(invoker_val) as Box<dyn Invoker>))
230                                    },
231                                    Err(e) => {
232                                        warn!("Failed to create invoker: {}", e);
233                                        None
234                                    }
235                                }
236                            },
237                            _ => None,
238                        };
239
240                        if let Some(invoker) = invoker {
241                            let mut invokers = invokers_store.write().unwrap();
242                            // Simple deduplication based on URL string
243                            invokers.retain(|i| i.url().to_string() != provider_url.to_string());
244                            invokers.push(invoker);
245                        }
246                    },
247                    InstanceChange::Remove { url: provider_url } => {
248                        debug!("Provider removed for {}: {}", service_name, provider_url);
249                        let mut invokers = invokers_store.write().unwrap();
250                        invokers.retain(|i| i.url().to_string() != provider_url.to_string());
251                    }
252                }
253            }
254        });
255        Ok(())
256    }
257    
258    fn list_internal(&self) -> Vec<Arc<Box<dyn Invoker>>> {
259        self.invokers.read().unwrap().clone()
260    }
261}
262
263#[async_trait]
264impl Directory for RegistryDirectory {
265    async fn list(&self, _req: &Request) -> RubboResult<Vec<Arc<Box<dyn Invoker>>>> {
266        Ok(self.invokers.read().unwrap().clone())
267    }
268
269    fn url(&self) -> &Url {
270        &self.url
271    }
272}