use crate::common::Registry;
use rubbo_core::{Url, RegistryUrl, RubboReference, Invoker, RpcProtocol, ProtocolKind as Protocol, Request, Result as RubboResult};
use rubbo_rpc::TripleProtocol;
use rubbo_registry::{Registry as RegistryTrait, NacosRegistry, InstanceChange};
use rubbo_cluster::{ClusterInvoker, Directory, LoadBalance, RoundRobinLoadBalance, FilterChain, AccessLogFilter};
use rubbo_core::Result;
use std::collections::HashMap;
use std::sync::{Arc, RwLock};
use tracing::{debug, warn};
use futures::StreamExt;
use async_trait::async_trait;
/// Identity of a service reference registered on a [`ConsumerBuilder`]:
/// its interface name plus the group and version it was declared with.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct ServiceMetadata {
    /// Fully-qualified interface name, as returned by `RubboReference::interface_name`.
    interface_name: String,
    /// Service group, as returned by `RubboReference::group`.
    group: String,
    /// Service version, as returned by `RubboReference::version`.
    version: String,
}
/// Builder for a [`Consumer`]; configure it with chained calls, then `build().await`.
pub struct ConsumerBuilder {
    /// Services the consumer should be able to call.
    references: Vec<ServiceMetadata>,
    /// Application name reported to the registry (optional).
    application_name: Option<String>,
    /// Registry used for provider discovery (optional).
    registry_config: Option<Registry>,
    /// RPC protocol used to reach providers (optional).
    protocol: Option<Protocol>,
    /// Load-balancing strategy shared across services (optional).
    load_balance: Option<Arc<Box<dyn LoadBalance>>>,
}
impl Default for ConsumerBuilder {
fn default() -> Self {
Self::new()
}
}
impl ConsumerBuilder {
pub fn new() -> Self {
Self {
application_name: None,
registry_config: None,
protocol: None,
load_balance: None,
references: Vec::new(),
}
}
pub fn application(mut self, name: &str) -> Self {
self.application_name = Some(name.to_string());
self
}
pub fn registry(mut self, config: Registry) -> Self {
self.registry_config = Some(config);
self
}
pub fn protocol(mut self, protocol: Protocol) -> Self {
self.protocol = Some(protocol);
self
}
pub fn load_balance<L: LoadBalance + 'static>(mut self, load_balance: L) -> Self {
self.load_balance = Some(Arc::new(Box::new(load_balance)));
self
}
pub fn reference<T: RubboReference + ?Sized>(mut self) -> Self {
self.references.push(ServiceMetadata {
interface_name: T::interface_name().to_string(),
group: T::group().to_string(),
version: T::version().to_string(),
});
self
}
pub async fn build(self) -> Result<Consumer> {
let registry = if let Some(Registry::Nacos(addr)) = &self.registry_config {
let addr = if let Some(stripped) = addr.strip_prefix("nacos://") {
stripped
} else {
addr
};
let (host, port) = if let Some((h, p)) = addr.split_once(':') {
(h, p.parse::<u16>().unwrap_or(8848))
} else {
(addr, 8848)
};
let registry_url = RegistryUrl::new("nacos", host, port);
let reg = NacosRegistry::new(®istry_url).map_err(|e| rubbo_core::Error::Registry(format!("Failed to create Nacos registry: {}", e)))?;
Some(Arc::new(reg) as Arc<dyn RegistryTrait>)
} else {
None
};
let mut invokers = HashMap::new();
let application_name = self.application_name.clone().unwrap_or_else(|| "rubbo-consumer".to_string());
let protocol_kind = self.protocol.unwrap_or(Protocol::Triple);
if let Some(registry) = ®istry {
for meta in self.references {
let mut url = Url::new("tri", "0.0.0.0", 0);
url.path = meta.interface_name.clone();
url.add_param("interface", &meta.interface_name);
url.add_param("group", &meta.group);
url.add_param("version", &meta.version);
url.add_param("side", "consumer");
url.add_param("application", &application_name);
let directory = RegistryDirectory::new(url.clone(), registry.clone(), protocol_kind.clone()).await?;
let start = std::time::Instant::now();
loop {
let list = directory.list_internal();
if !list.is_empty() {
debug!("Found {} providers for {}", list.len(), meta.interface_name);
if let Some(invoker) = list.first()
&& let Some(s) = invoker.url().get_param("serialization") {
debug!("Detected serialization from provider: {}", s);
url.add_param("serialization", s);
}
break;
}
if start.elapsed() > std::time::Duration::from_secs(5) {
warn!("Timeout waiting for providers for {}. Proceeding with empty directory.", meta.interface_name);
break;
}
tokio::time::sleep(std::time::Duration::from_millis(100)).await;
}
let load_balance = self.load_balance.clone()
.unwrap_or_else(|| Arc::new(Box::new(RoundRobinLoadBalance::new()) as Box<dyn LoadBalance>));
let cluster_invoker = ClusterInvoker::new(
Arc::new(Box::new(directory) as Box<dyn Directory>),
load_balance,
url.clone()
);
let mut filter_chain = FilterChain::new(Arc::new(Box::new(cluster_invoker) as Box<dyn Invoker>));
filter_chain.add_filter(Box::new(AccessLogFilter));
let chain_invoker = filter_chain;
let key = format!("{}:{}:{}", meta.interface_name, meta.group, meta.version);
invokers.insert(key, Arc::new(Box::new(chain_invoker) as Box<dyn Invoker>));
}
}
Ok(Consumer {
application_name,
registry,
invokers,
})
}
}
/// A built consumer holding one ready-to-use invoker per referenced service.
///
/// Produced by `ConsumerBuilder::build`; obtain service stubs via `Consumer::reference`.
pub struct Consumer {
    /// Invokers keyed by `"<interface>:<group>:<version>"`.
    invokers: HashMap<String, Arc<Box<dyn Invoker>>>,
    // Retained for completeness but not read anywhere yet, hence the allow.
    #[allow(dead_code)]
    application_name: String,
    // Retained alongside the invokers but not read directly, hence the allow.
    #[allow(dead_code)]
    registry: Option<Arc<dyn RegistryTrait>>,
}
impl Consumer {
pub async fn reference<T: RubboReference + ?Sized>(&self) -> Result<Arc<T>> {
let key = T::SERVICE_KEY;
if let Some(invoker) = self.invokers.get(key) {
Ok(T::create_stub(invoker.clone()))
} else {
Err(rubbo_core::Error::Other(format!("Service {} not found. Did you forget to add .reference::<T>() to ConsumerBuilder?", T::interface_name())))
}
}
}
/// Provider invokers shared between a `RegistryDirectory` and its subscription task.
type SharedInvokers = Arc<RwLock<Vec<Arc<Box<dyn Invoker>>>>>;

/// A `Directory` whose provider list is fed by a registry subscription.
struct RegistryDirectory {
    /// Consumer-side URL used for the subscription; `path` holds the interface name.
    url: Url,
    /// Registry the subscription is made against.
    registry: Arc<dyn RegistryTrait>,
    /// Protocol used to turn provider URLs into invokers.
    protocol: Protocol,
    /// Current provider invokers, updated as registry events arrive.
    invokers: SharedInvokers,
}
impl RegistryDirectory {
    /// Creates a directory for `url` and immediately subscribes to provider changes.
    ///
    /// # Errors
    /// Propagates any error returned by the registry's `subscribe` call.
    async fn new(url: Url, registry: Arc<dyn RegistryTrait>, protocol: Protocol) -> Result<Self> {
        let dir = Self {
            url,
            registry,
            invokers: Arc::new(RwLock::new(Vec::new())),
            protocol,
        };
        dir.subscribe().await?;
        Ok(dir)
    }

    /// Subscribes to the registry and spawns a task that applies provider
    /// upsert/remove events to the shared invoker list.
    ///
    /// The task only holds a weak reference to the list. Once this directory is
    /// dropped, the task stops at the next event instead of running forever.
    async fn subscribe(&self) -> Result<()> {
        let mut stream = self.registry.subscribe(self.url.clone()).await?;
        let invokers_store = Arc::downgrade(&self.invokers);
        let protocol = self.protocol.clone();
        let service_name = self.url.path.clone();
        tokio::spawn(async move {
            while let Some(event) = stream.next().await {
                match event {
                    InstanceChange::Upsert { url: provider_url } => {
                        debug!("Provider update for {}: {}", service_name, provider_url);
                        let invoker: Option<Arc<Box<dyn Invoker>>> = match protocol {
                            Protocol::Triple => {
                                match TripleProtocol.refer(provider_url.clone()).await {
                                    Ok(invoker_arc) => {
                                        let invoker_val = (*invoker_arc).clone();
                                        Some(Arc::new(Box::new(invoker_val) as Box<dyn Invoker>))
                                    }
                                    Err(e) => {
                                        warn!("Failed to create invoker: {}", e);
                                        None
                                    }
                                }
                            }
                            _ => {
                                warn!("Unsupported protocol; ignoring provider {}", provider_url);
                                None
                            }
                        };
                        let Some(invoker) = invoker else {
                            continue;
                        };
                        // The directory is gone: nobody can read the list any more.
                        let Some(store) = invokers_store.upgrade() else {
                            break;
                        };
                        let provider_key = provider_url.to_string();
                        // Recover from poisoning so one panic can't break the directory forever.
                        let mut invokers = store
                            .write()
                            .unwrap_or_else(std::sync::PoisonError::into_inner);
                        // Replace any existing entry for the same provider URL.
                        invokers.retain(|i| i.url().to_string() != provider_key);
                        invokers.push(invoker);
                    }
                    InstanceChange::Remove { url: provider_url } => {
                        debug!("Provider removed for {}: {}", service_name, provider_url);
                        let Some(store) = invokers_store.upgrade() else {
                            break;
                        };
                        let provider_key = provider_url.to_string();
                        let mut invokers = store
                            .write()
                            .unwrap_or_else(std::sync::PoisonError::into_inner);
                        invokers.retain(|i| i.url().to_string() != provider_key);
                    }
                }
            }
            debug!("Stopped watching providers for {}", service_name);
        });
        Ok(())
    }

    /// Returns a snapshot (cloned `Vec`) of the current provider invokers.
    fn list_internal(&self) -> Vec<Arc<Box<dyn Invoker>>> {
        self.invokers
            .read()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
            .clone()
    }
}
#[async_trait]
impl Directory for RegistryDirectory {
    /// Returns a snapshot of every currently known provider; `_req` is not used
    /// for routing here.
    async fn list(&self, _req: &Request) -> RubboResult<Vec<Arc<Box<dyn Invoker>>>> {
        // A poisoned lock would otherwise make every subsequent call panic. The
        // writers only `retain`/`push`, so the list is still usable after recovery.
        Ok(self
            .invokers
            .read()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
            .clone())
    }

    /// Consumer-side URL this directory subscribed with.
    fn url(&self) -> &Url {
        &self.url
    }
}