use crate::common::{Registry, Serialize};
use rubbo_core::{Url, RegistryUrl, RubboService, Invoker, RpcProtocol, ProtocolKind as Protocol};
use rubbo_rpc::TripleProtocol;
use rubbo_registry::{Registry as RegistryTrait, NacosRegistry};
use rubbo_cluster::{FilterChain, AccessLogFilter};
use rubbo_core::Result;
use std::collections::HashMap;
use std::sync::Arc;
use tracing::debug;
/// Invoker that fans requests out to per-service invokers.
///
/// The lookup key is the `service_name` carried by each incoming request.
struct DispatcherInvoker {
    /// Target invokers, keyed by the service name they answer for.
    invokers: HashMap<String, Arc<Box<dyn Invoker>>>,
    /// URL handed back from `Invoker::url` for the dispatcher itself.
    url: Url,
}
#[async_trait::async_trait]
impl Invoker for DispatcherInvoker {
async fn invoke(&self, req: rubbo_core::Request) -> rubbo_core::Result<rubbo_core::Response> {
if let Some(invoker) = self.invokers.get(&req.service_name) {
invoker.invoke(req).await
} else {
Ok(rubbo_core::Response::error(format!("Service {} not found", req.service_name)))
}
}
fn url(&self) -> &Url {
&self.url
}
}
/// Builder that collects the services and settings of an RPC provider and
/// launches it through `start`.
pub struct ProviderBuilder {
    /// Service implementations to expose.
    services: Vec<Box<dyn RubboService>>,
    /// Value published as the `application` URL parameter, when set.
    application_name: Option<String>,
    /// Registry to announce the services to; nothing is registered when `None`.
    registry_config: Option<Registry>,
    /// Wire protocol; `start` uses `Protocol::Triple` when `None`.
    protocol: Option<Protocol>,
    /// Serialization to advertise; `start` uses `Serialize::Fastjson` when `None`.
    serialization: Option<Serialize>,
    /// Port placed in the provider URL; `start` uses 5051 when `None`.
    bind_port: Option<u16>,
}
impl Default for ProviderBuilder {
fn default() -> Self {
Self::new()
}
}
impl ProviderBuilder {
pub fn new() -> Self {
Self {
services: Vec::new(),
application_name: None,
registry_config: None,
protocol: None,
serialization: None,
bind_port: None,
}
}
pub fn service<S: RubboService + 'static>(mut self, service: S) -> Self {
self.services.push(Box::new(service));
self
}
pub fn services(mut self, services: Vec<Box<dyn RubboService>>) -> Self {
self.services.extend(services);
self
}
pub fn application(mut self, name: &str) -> Self {
self.application_name = Some(name.to_string());
self
}
pub fn registry(mut self, config: Registry) -> Self {
self.registry_config = Some(config);
self
}
pub fn protocol(mut self, protocol: Protocol) -> Self {
self.protocol = Some(protocol);
self
}
pub fn serialization(mut self, serialization: Serialize) -> Self {
self.serialization = Some(serialization);
self
}
pub fn port(mut self, port: u16) -> Self {
self.bind_port = Some(port);
self
}
pub async fn start(self) -> Result<()> {
if self.services.is_empty() {
return Err(rubbo_core::Error::Other("No services provided".to_string()));
}
let protocol_kind = self.protocol.unwrap_or(Protocol::Triple);
let host = std::env::var("RUBBO_PROVIDER_HOST").unwrap_or_else(|_| {
local_ip_address::local_ip()
.map(|ip| ip.to_string())
.unwrap_or_else(|_| "127.0.0.1".to_string())
});
let port = self.bind_port.unwrap_or(5051);
let mut base_url = Url::new(protocol_kind.to_string().as_str(), &host, port);
if let Some(app_name) = self.application_name {
base_url.add_param("application", &app_name);
}
base_url.add_param("side", "provider");
let mut invokers = HashMap::new();
let mut registration_urls = Vec::new();
for service in self.services {
let interface_name = service.interface_name();
let mut service_url = base_url.clone();
service_url.path = interface_name.clone();
service_url.add_param("interface", &interface_name);
service_url.add_param("methods", &service.methods());
if !service.group().is_empty() {
service_url.add_param("group", &service.group());
}
if !service.version().is_empty() {
service_url.add_param("version", &service.version());
}
service_url.add_param("protocol", &protocol_kind.to_string()); service_url.add_param("rubbo", env!("CARGO_PKG_VERSION")); service_url.add_param("release", env!("CARGO_PKG_VERSION")); service_url.add_param("anyhost", "true"); service_url.add_param("timestamp", &std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH).unwrap().as_millis().to_string()); service_url.add_param("category", "providers"); service_url.add_param("dynamic", "true"); service_url.add_param("deprecated", "false"); service_url.add_param("generic", "false"); service_url.add_param("service-name-mapping", "true"); service_url.add_param("side", "provider"); service_url.add_param("pid", &std::process::id().to_string());
let ser = self.serialization.unwrap_or(Serialize::Fastjson);
match ser {
#[cfg(feature = "fastjson")]
Serialize::Fastjson => {
service_url.add_param("serialization", "fastjson");
service_url.add_param("prefer.serialization", "fastjson");
}
#[cfg(feature = "fastjson2")]
Serialize::Fastjson2 => {
service_url.add_param("serialization", "fastjson2");
service_url.add_param("prefer.serialization", "fastjson2");
}
#[cfg(feature = "hessian2")]
Serialize::Hessian2 => {
service_url.add_param("serialization", "hessian2");
service_url.add_param("prefer.serialization", "hessian2");
}
}
let invoker = service.to_invoker(service_url.clone());
invokers.insert(interface_name.clone(), invoker);
let reg_url = service_url.clone();
debug!("Registering service: {}", reg_url.to_string());
registration_urls.push(reg_url);
}
let dispatcher = Arc::new(Box::new(DispatcherInvoker {
invokers,
url: base_url.clone(),
}) as Box<dyn Invoker>);
let mut filter_chain = FilterChain::new(dispatcher);
filter_chain.add_filter(Box::new(AccessLogFilter));
let final_invoker = Arc::new(Box::new(filter_chain) as Box<dyn Invoker>);
match protocol_kind {
Protocol::Triple => {
let protocol = TripleProtocol;
protocol.export(final_invoker).await.map_err(|e| rubbo_core::Error::Other(format!("Failed to export service: {}", e)))?;
}
_ => return Err(rubbo_core::Error::Protocol(format!("Unsupported protocol: {:?}", protocol_kind))),
}
debug!("Server started on port {}", port);
if let Some(Registry::Nacos(addr)) = self.registry_config {
let addr = if let Some(stripped) = addr.strip_prefix("nacos://") {
stripped
} else {
&addr
};
let (host, port) = if let Some((h, p)) = addr.split_once(':') {
(h, p.parse::<u16>().unwrap_or(8848))
} else {
(addr, 8848)
};
let registry_url = RegistryUrl::new("nacos", host, port);
let registry = NacosRegistry::new(®istry_url).map_err(|e| rubbo_core::Error::Registry(format!("Failed to create Nacos registry: {}", e)))?;
for url in registration_urls {
registry.register(url.clone()).await.map_err(|e| rubbo_core::Error::Registry(format!("Failed to register service: {}", e)))?;
debug!("Service registered to Nacos: {}", url.path);
}
}
tokio::signal::ctrl_c().await?;
Ok(())
}
}